增加后台管理,支持docker部署

This commit is contained in:
ageer
2024-05-17 02:00:31 +08:00
parent ef7434ed04
commit 7fe89a931b
59 changed files with 3911 additions and 1501 deletions

View File

@@ -1,36 +1,22 @@
package com.xmzs.midjourney.controller;
import cn.hutool.json.JSONUtil;
import com.xmzs.common.chat.constant.OpenAIConst;
import com.xmzs.common.core.domain.model.LoginUser;
import com.xmzs.common.core.exception.base.BaseException;
import com.xmzs.common.satoken.utils.LoginHelper;
import com.xmzs.midjourney.domain.InsightFace;
import com.xmzs.system.domain.bo.ChatMessageBo;
import com.xmzs.system.service.IChatMessageService;
import com.xmzs.system.service.IChatService;
import com.xmzs.system.service.ISseService;
import com.xmzs.midjourney.domain.MjPriceConfig;
import com.xmzs.midjourney.util.MjOkHttpUtil;
import com.xmzs.system.service.IChatCostService;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import okio.Buffer;
import okio.BufferedSink;
import okio.GzipSink;
import okio.Okio;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.apache.commons.lang3.math.NumberUtils;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.io.IOException;
import java.util.concurrent.TimeUnit;
@Api(tags = "任务查询")
@RestController
@@ -39,62 +25,22 @@ import java.util.concurrent.TimeUnit;
@Slf4j
public class FaceController {
@Value("${chat.apiKey}")
private String apiKey;
@Value("${chat.apiHost}")
private String apiHost;
private final IChatCostService chatCostService;
@Autowired
private IChatService chatService;
private final MjOkHttpUtil mjOkHttpUtil;
@Autowired
private ISseService sseService;
private final MjPriceConfig priceConfig;
@ApiOperation(value = "换脸")
@PostMapping("/insight-face/swap")
public String insightFace(@RequestBody InsightFace insightFace) {
// 查询是否是付费用户
sseService.checkUserGrade();
// 扣除接口费用
chatService.mjTaskDeduct("换脸", OpenAIConst.MJ_COST_TYPE2);
OkHttpClient client = new OkHttpClient.Builder()
.connectTimeout(30, TimeUnit.SECONDS) // 连接超时时间
.writeTimeout(30, TimeUnit.SECONDS) // 写入超时时间
.readTimeout(30, TimeUnit.SECONDS) // 读取超时时间
.build();
// 创建一个Request对象来配置你的请求
// 扣除接口费用并且保存消息记录
chatCostService.taskDeduct("mj","换脸", NumberUtils.toDouble(priceConfig.getFaceSwapping(), 0.3));
// 创建请求体这里使用JSON作为媒体类型
String jsonStr = JSONUtil.toJsonStr(insightFace);
MediaType JSON = MediaType.get("application/json; charset=utf-8");
okhttp3.RequestBody body = okhttp3.RequestBody.create(jsonStr, JSON);
Buffer buffer = new Buffer();
GzipSink gzipSink = new GzipSink(buffer);
BufferedSink gzipBufferedSink = Okio.buffer(gzipSink);
try {
body.writeTo(gzipBufferedSink);
gzipBufferedSink.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
// 创建POST请求
Request request = new Request.Builder()
.header("mj-api-secret", apiKey)
.header("Content-Encoding", "gzip")
.url(apiHost + "mj/insight-face/swap") // 替换为你的URL
.post(body)
.build();
try (Response response = client.newCall(request).execute()) {
if (!response.isSuccessful()) throw new IOException("Unexpected code " + response);
if (response.body() != null) {
return response.body().string();
}
} catch (IOException e) {
log.error("换脸失败! {}", e.getMessage());
}
return null;
String insightFaceJson = JSONUtil.toJsonStr(insightFace);
String url = "mj/insight-face/swap";
Request request = mjOkHttpUtil.createPostRequest(url, insightFaceJson);
return mjOkHttpUtil.executeRequest(request);
}
}

View File

@@ -1,61 +1,24 @@
package com.xmzs.midjourney.controller;
import cn.hutool.core.text.CharSequenceUtil;
import cn.hutool.core.util.RandomUtil;
import cn.hutool.json.JSONUtil;
import com.xmzs.common.chat.constant.OpenAIConst;
import com.xmzs.common.core.domain.model.LoginUser;
import com.xmzs.common.core.exception.base.BaseException;
import com.xmzs.common.satoken.utils.LoginHelper;
import com.xmzs.midjourney.Constants;
import com.xmzs.midjourney.ProxyProperties;
import com.xmzs.midjourney.ReturnCode;
import com.xmzs.midjourney.domain.MjPriceConfig;
import com.xmzs.midjourney.dto.*;
import com.xmzs.midjourney.enums.TaskAction;
import com.xmzs.midjourney.enums.TaskStatus;
import com.xmzs.midjourney.enums.TranslateWay;
import com.xmzs.midjourney.exception.BannedPromptException;
import com.xmzs.midjourney.result.SubmitResultVO;
import com.xmzs.midjourney.service.TaskService;
import com.xmzs.midjourney.service.TaskStoreService;
import com.xmzs.midjourney.service.TranslateService;
import com.xmzs.midjourney.support.Task;
import com.xmzs.midjourney.support.TaskCondition;
import com.xmzs.midjourney.util.BannedPromptUtils;
import com.xmzs.midjourney.util.ConvertUtils;
import com.xmzs.midjourney.util.MimeTypeUtils;
import com.xmzs.midjourney.util.SnowFlake;
import com.xmzs.midjourney.util.TaskChangeParams;
import com.xmzs.system.domain.bo.ChatMessageBo;
import com.xmzs.system.service.IChatMessageService;
import com.xmzs.system.service.IChatService;
import com.xmzs.system.service.ISseService;
import eu.maxschuster.dataurl.DataUrl;
import eu.maxschuster.dataurl.DataUrlSerializer;
import eu.maxschuster.dataurl.IDataUrlSerializer;
import com.xmzs.midjourney.enums.ActionType;
import com.xmzs.midjourney.util.*;
import com.xmzs.system.service.IChatCostService;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import okhttp3.Request;
import org.apache.commons.lang3.math.NumberUtils;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.io.IOException;
import java.net.MalformedURLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import okhttp3.*;
import java.util.Optional;
@Api(tags = "任务提交")
@RestController
@@ -63,285 +26,114 @@ import okhttp3.*;
@RequiredArgsConstructor
@Slf4j
public class SubmitController {
private final TranslateService translateService;
private final ProxyProperties properties;
private final TaskService taskService;
private final TaskStoreService taskStoreService;
@Value("${chat.apiKey}")
private String apiKey;
@Value("${chat.apiHost}")
private String apiHost;
@Autowired
private IChatService chatService;
@Autowired
private IChatMessageService chatMessageService;
@Autowired
private ISseService sseService;
@ApiOperation(value = "提交Imagine任务")
@PostMapping("/imagine")
public SubmitResultVO imagine(@RequestBody SubmitImagineDTO imagineDTO) {
String prompt = imagineDTO.getPrompt();
if (CharSequenceUtil.isBlank(prompt)) {
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "prompt不能为空");
}
prompt = prompt.trim();
Task task = newTask(imagineDTO);
task.setAction(TaskAction.IMAGINE);
task.setPrompt(prompt);
String promptEn = translatePrompt(prompt);
try {
BannedPromptUtils.checkBanned(promptEn);
} catch (BannedPromptException e) {
return SubmitResultVO.fail(ReturnCode.BANNED_PROMPT, "可能包含敏感词")
.setProperty("promptEn", promptEn).setProperty("bannedWord", e.getMessage());
}
List<String> base64Array = Optional.ofNullable(imagineDTO.getBase64Array()).orElse(new ArrayList<>());
if (CharSequenceUtil.isNotBlank(imagineDTO.getBase64())) {
base64Array.add(imagineDTO.getBase64());
}
List<DataUrl> dataUrls;
try {
dataUrls = ConvertUtils.convertBase64Array(base64Array);
} catch (MalformedURLException e) {
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64格式错误");
}
task.setPromptEn(promptEn);
task.setDescription("/imagine " + prompt);
return this.taskService.submitImagine(task, dataUrls);
private final MjPriceConfig priceConfig;
private final IChatCostService chatCostService;
private final MjOkHttpUtil mjOkHttpUtil;
@ApiOperation(value = "绘图变化")
@PostMapping("/change")
public String change(@RequestBody SubmitChangeDTO changeDTO) {
String jsonStr = JSONUtil.toJsonStr(changeDTO);
String url = "mj/submit/change";
Request request = mjOkHttpUtil.createPostRequest(url, jsonStr);
return mjOkHttpUtil.executeRequest(request);
}
@ApiOperation(value = "执行动作")
@PostMapping("/action")
public String action(@RequestBody SubmitActionDTO changeDTO) {
ActionType actionType = ActionType.fromCustomId(getAction(changeDTO.getCustomId()));
Optional.ofNullable(actionType).ifPresentOrElse(
type -> {
switch (type) {
case UP_SAMPLE:
chatCostService.taskDeduct("mj","放大", NumberUtils.toDouble(priceConfig.getUpsample(), 0.1));
break;
case IN_PAINT:
// 局部重绘已经扣费,不执行任何操作
break;
default:
chatCostService.taskDeduct("mj","变化", NumberUtils.toDouble(priceConfig.getChange(), 0.3));
break;
}
},
() -> chatCostService.taskDeduct("mj","变化", NumberUtils.toDouble(priceConfig.getChange(), 0.3))
);
String jsonStr = JSONUtil.toJsonStr(changeDTO);
String url = "mj/submit/action";
Request request = mjOkHttpUtil.createPostRequest(url, jsonStr);
return mjOkHttpUtil.executeRequest(request);
}
@ApiOperation(value = "绘图变化-simple")
@PostMapping("/simple-change")
public SubmitResultVO simpleChange(@RequestBody SubmitSimpleChangeDTO simpleChangeDTO) {
TaskChangeParams changeParams = ConvertUtils.convertChangeParams(simpleChangeDTO.getContent());
if (changeParams == null) {
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "content参数错误");
}
SubmitChangeDTO changeDTO = new SubmitChangeDTO();
changeDTO.setAction(changeParams.getAction());
changeDTO.setTaskId(changeParams.getId());
changeDTO.setIndex(changeParams.getIndex());
changeDTO.setState(simpleChangeDTO.getState());
changeDTO.setNotifyHook(simpleChangeDTO.getNotifyHook());
return change(changeDTO);
public String simpleChange(@RequestBody SubmitSimpleChangeDTO simpleChangeDTO) {
String jsonStr = JSONUtil.toJsonStr(simpleChangeDTO);
String url = "mj/submit/simple-change";
Request request = mjOkHttpUtil.createPostRequest(url, jsonStr);
return mjOkHttpUtil.executeRequest(request);
}
@ApiOperation(value = "绘图变化")
@PostMapping("/change")
public SubmitResultVO change(@RequestBody SubmitChangeDTO changeDTO) {
if (CharSequenceUtil.isBlank(changeDTO.getTaskId())) {
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "taskId不能为空");
}
if (!Set.of(TaskAction.UPSCALE, TaskAction.VARIATION, TaskAction.REROLL).contains(changeDTO.getAction())) {
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "action参数错误");
}
String description = "/up " + changeDTO.getTaskId();
if (TaskAction.REROLL.equals(changeDTO.getAction())) {
description += " R";
} else {
description += " " + changeDTO.getAction().name().charAt(0) + changeDTO.getIndex();
}
if (TaskAction.UPSCALE.equals(changeDTO.getAction())) {
TaskCondition condition = new TaskCondition().setDescription(description);
Task existTask = this.taskStoreService.findOne(condition);
if (existTask != null) {
return SubmitResultVO.of(ReturnCode.EXISTED, "任务已存在", existTask.getId())
.setProperty("status", existTask.getStatus())
.setProperty("imageUrl", existTask.getImageUrl());
}
}
Task targetTask = this.taskStoreService.get(changeDTO.getTaskId());
if (targetTask == null) {
return SubmitResultVO.fail(ReturnCode.NOT_FOUND, "关联任务不存在或已失效");
}
if (!TaskStatus.SUCCESS.equals(targetTask.getStatus())) {
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "关联任务状态错误");
}
if (!Set.of(TaskAction.IMAGINE, TaskAction.VARIATION, TaskAction.REROLL, TaskAction.BLEND).contains(targetTask.getAction())) {
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "关联任务不允许执行变化");
}
Task task = newTask(changeDTO);
task.setAction(changeDTO.getAction());
task.setPrompt(targetTask.getPrompt());
task.setPromptEn(targetTask.getPromptEn());
task.setProperty(Constants.TASK_PROPERTY_FINAL_PROMPT, targetTask.getProperty(Constants.TASK_PROPERTY_FINAL_PROMPT));
task.setProperty(Constants.TASK_PROPERTY_PROGRESS_MESSAGE_ID, targetTask.getProperty(Constants.TASK_PROPERTY_MESSAGE_ID));
task.setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, targetTask.getProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID));
task.setDescription(description);
int messageFlags = targetTask.getPropertyGeneric(Constants.TASK_PROPERTY_FLAGS);
String messageId = targetTask.getPropertyGeneric(Constants.TASK_PROPERTY_MESSAGE_ID);
String messageHash = targetTask.getPropertyGeneric(Constants.TASK_PROPERTY_MESSAGE_HASH);
if (TaskAction.UPSCALE.equals(changeDTO.getAction())) {
return this.taskService.submitUpscale(task, messageId, messageHash, changeDTO.getIndex(), messageFlags);
} else if (TaskAction.VARIATION.equals(changeDTO.getAction())) {
return this.taskService.submitVariation(task, messageId, messageHash, changeDTO.getIndex(), messageFlags);
} else {
return this.taskService.submitReroll(task, messageId, messageHash, messageFlags);
}
@ApiOperation(value = "提交图生图、混图任务")
@PostMapping("/blend")
public String blend(@RequestBody SubmitBlendDTO blendDTO) {
chatCostService.taskDeduct("mj","图生图", NumberUtils.toDouble(priceConfig.getBlend(), 0.3));
String jsonStr = JSONUtil.toJsonStr(blendDTO);
String url = "mj/submit/blend";
Request request = mjOkHttpUtil.createPostRequest(url, jsonStr);
return mjOkHttpUtil.executeRequest(request);
}
@ApiOperation(value = "绘图变化")
@PostMapping("/action")
public String action(@RequestBody SubmitActionDTO changeDTO) {
// 查询是否是付费用户
sseService.checkUserGrade();
// 扣除接口费用
if ("upsample".equals(getAction(changeDTO.getCustomId()))) {
mjTaskDeduct("放大", OpenAIConst.MJ_COST_TYPE2);
} else {
// Inpaint: 局部重绘
// reroll 重绘
// upsample 放大
// zoom 变焦
// upscale 高清放大
// variation 变化
if (!"Inpaint".equals(getAction(changeDTO.getCustomId()))) {
mjTaskDeduct("变化", OpenAIConst.MJ_COST_TYPE1);
}
}
OkHttpClient client = new OkHttpClient.Builder()
.connectTimeout(30, TimeUnit.SECONDS) // 连接超时时间
.writeTimeout(30, TimeUnit.SECONDS) // 写入超时时间
.readTimeout(30, TimeUnit.SECONDS) // 读取超时时间
.build();
@ApiOperation(value = "提交图生文任务")
@PostMapping("/describe")
public String describe(@RequestBody SubmitDescribeDTO describeDTO) {
chatCostService.taskDeduct("mj","图生文", NumberUtils.toDouble(priceConfig.getDescribe(), 0.1));
String jsonStr = JSONUtil.toJsonStr(describeDTO);
String url = "mj/submit/describe";
Request request = mjOkHttpUtil.createPostRequest(url, jsonStr);
return mjOkHttpUtil.executeRequest(request);
}
String jsonStr = JSONUtil.toJsonStr(changeDTO);
@ApiOperation(value = "提交文生图任务")
@PostMapping("/imagine")
public String imagine(@RequestBody SubmitImagineDTO imagineDTO) {
chatCostService.taskDeduct("mj",imagineDTO.getPrompt(), NumberUtils.toDouble(priceConfig.getImagine(), 0.3));
String jsonStr = JSONUtil.toJsonStr(imagineDTO);
String url = "mj/submit/imagine";
Request request = mjOkHttpUtil.createPostRequest(url, jsonStr);
return mjOkHttpUtil.executeRequest(request);
}
MediaType mediaType = MediaType.parse("application/json");
okhttp3.RequestBody body = okhttp3.RequestBody.create(jsonStr, mediaType);
Request request = new Request.Builder()
.url(apiHost + "mj/submit/action")
.method("POST", body)
.header("mj-api-secret", apiKey) // 设置Authorization header
.build();
try {
Response response = client.newCall(request).execute();
return response.body().string();
} catch (IOException e) {
log.error("绘图变化失败:{}", e.getMessage());
}
return null;
@ApiOperation(value = "提交局部重绘任务")
@PostMapping("/modal")
public String modal(@RequestBody SubmitModalDTO submitModalDTO) {
chatCostService.taskDeduct("mj","局部重绘", NumberUtils.toDouble(priceConfig.getInpaint(), 0.1));
String jsonStr = JSONUtil.toJsonStr(submitModalDTO);
String url = "mj/submit/modal";
Request request = mjOkHttpUtil.createPostRequest(url, jsonStr);
return mjOkHttpUtil.executeRequest(request);
}
@ApiOperation(value = "提交提示词分析任务")
@PostMapping("/shorten")
public String shorten(@RequestBody SubmitShortenDTO submitShortenDTO) {
chatCostService.taskDeduct("mj","提示词分析", NumberUtils.toDouble(priceConfig.getShorten(), 0.1));
String jsonStr = JSONUtil.toJsonStr(submitShortenDTO);
String url = "mj/submit/shorten";
Request request = mjOkHttpUtil.createPostRequest(url, jsonStr);
return mjOkHttpUtil.executeRequest(request);
}
public String getAction(String customId) {
// 检查 customId 是否为空
if(customId == null || customId.isEmpty()) {
if (customId == null || customId.isEmpty()) {
return null;
}
// 使用 "::" 分割字符串
String[] parts = customId.split("::");
// "MJ", "Inpaint", "1", "4fca7c14-181c-4...", "SOLO"
if(customId.endsWith("SOLO")) {
return parts[1];
}
// 返回 "upsample" 值,假设它总是在第三个位置
return parts[2];
}
public void mjTaskDeduct(String prompt, double cost) {
//扣除费用
chatService.deductUserBalance(getUserId(), cost);
// 保存消息记录
ChatMessageBo chatMessageBo = new ChatMessageBo();
chatMessageBo.setUserId(getUserId());
chatMessageBo.setModelName("mj");
chatMessageBo.setContent(prompt);
chatMessageBo.setDeductCost(cost);
chatMessageBo.setTotalTokens(0);
chatMessageService.insertByBo(chatMessageBo);
}
/**
* 获取用户Id
*
* @return
*/
public Long getUserId() {
LoginUser loginUser = LoginHelper.getLoginUser();
if (loginUser == null) {
throw new BaseException("用户未登录!");
}
return loginUser.getUserId();
}
@ApiOperation(value = "提交Describe任务")
@PostMapping("/describe")
public SubmitResultVO describe(@RequestBody SubmitDescribeDTO describeDTO) {
if (CharSequenceUtil.isBlank(describeDTO.getBase64())) {
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64不能为空");
}
IDataUrlSerializer serializer = new DataUrlSerializer();
DataUrl dataUrl;
try {
dataUrl = serializer.unserialize(describeDTO.getBase64());
} catch (MalformedURLException e) {
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64格式错误");
}
Task task = newTask(describeDTO);
task.setAction(TaskAction.DESCRIBE);
String taskFileName = task.getId() + "." + MimeTypeUtils.guessFileSuffix(dataUrl.getMimeType());
task.setDescription("/describe " + taskFileName);
return this.taskService.submitDescribe(task, dataUrl);
}
@ApiOperation(value = "提交Blend任务")
@PostMapping("/blend")
public SubmitResultVO blend(@RequestBody SubmitBlendDTO blendDTO) {
List<String> base64Array = blendDTO.getBase64Array();
if (base64Array == null || base64Array.size() < 2 || base64Array.size() > 5) {
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64List参数错误");
}
if (blendDTO.getDimensions() == null) {
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "dimensions参数错误");
}
IDataUrlSerializer serializer = new DataUrlSerializer();
List<DataUrl> dataUrlList = new ArrayList<>();
try {
for (String base64 : base64Array) {
DataUrl dataUrl = serializer.unserialize(base64);
dataUrlList.add(dataUrl);
}
} catch (MalformedURLException e) {
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64格式错误");
}
Task task = newTask(blendDTO);
task.setAction(TaskAction.BLEND);
task.setDescription("/blend " + task.getId() + " " + dataUrlList.size());
return this.taskService.submitBlend(task, dataUrlList, blendDTO.getDimensions());
}
private Task newTask(BaseSubmitDTO base) {
Task task = new Task();
task.setId(System.currentTimeMillis() + RandomUtil.randomNumbers(3));
task.setSubmitTime(System.currentTimeMillis());
task.setState(base.getState());
String notifyHook = CharSequenceUtil.isBlank(base.getNotifyHook()) ? this.properties.getNotifyHook() : base.getNotifyHook();
task.setProperty(Constants.TASK_PROPERTY_NOTIFY_HOOK, notifyHook);
task.setProperty(Constants.TASK_PROPERTY_NONCE, SnowFlake.INSTANCE.nextId());
return task;
}
private String translatePrompt(String prompt) {
if (TranslateWay.NULL.equals(this.properties.getTranslateWay()) || CharSequenceUtil.isBlank(prompt)) {
return prompt;
}
List<String> imageUrls = new ArrayList<>();
Matcher imageMatcher = Pattern.compile("https?://[a-z0-9-_:@&?=+,.!/~*'%$]+\\x20+", Pattern.CASE_INSENSITIVE).matcher(prompt);
while (imageMatcher.find()) {
imageUrls.add(imageMatcher.group(0));
}
String paramStr = "";
Matcher paramMatcher = Pattern.compile("\\x20+-{1,2}[a-z]+.*$", Pattern.CASE_INSENSITIVE).matcher(prompt);
if (paramMatcher.find()) {
paramStr = paramMatcher.group(0);
}
String imageStr = CharSequenceUtil.join("", imageUrls);
String text = prompt.substring(imageStr.length(), prompt.length() - paramStr.length());
if (CharSequenceUtil.isNotBlank(text)) {
text = this.translateService.translateToEnglish(text).trim();
}
return imageStr + text + paramStr;
return customId.endsWith("SOLO") ? parts[1] : parts[2];
}
}

View File

@@ -1,32 +1,15 @@
package com.xmzs.midjourney.controller;
import cn.hutool.core.comparator.CompareUtil;
import cn.hutool.json.JSONUtil;
import com.xmzs.midjourney.dto.SubmitImagineDTO;
import com.xmzs.midjourney.util.MjOkHttpUtil;
import com.xmzs.midjourney.dto.TaskConditionDTO;
import com.xmzs.midjourney.loadbalancer.DiscordLoadBalancer;
import com.xmzs.midjourney.result.SubmitResultVO;
import com.xmzs.midjourney.service.TaskStoreService;
import com.xmzs.midjourney.support.Task;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import io.swagger.annotations.ApiParam;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.io.IOException;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import okhttp3.Request;
import org.springframework.web.bind.annotation.*;
@Api(tags = "任务查询")
@RestController
@@ -34,58 +17,32 @@ import java.util.Objects;
@RequiredArgsConstructor
@Slf4j
public class TaskController {
private final TaskStoreService taskStoreService;
private final DiscordLoadBalancer discordLoadBalancer;
@Value("${chat.apiKey}")
private String apiKey;
@Value("${chat.apiHost}")
private String apiHost;
private final MjOkHttpUtil mjOkHttpUtil;
@ApiOperation(value = "指定ID获取任务")
@GetMapping("/{id}/fetch")
public String fetch(@ApiParam(value = "任务ID") @PathVariable String id) {
OkHttpClient client = new OkHttpClient();
// 创建一个Request对象来配置你的请求
Request request = new Request.Builder()
.header("mj-api-secret", apiKey) // 设置Authorization header
.url(apiHost+"mj/task/" + id + "/fetch")
.build();
try (Response response = client.newCall(request).execute()) {
if (!response.isSuccessful()) throw new IOException("Unexpected code " + response);
if (response.body() != null) {
return response.body().string();
}
} catch (IOException e) {
log.error("任务:{}查询失败:{}",id,e.getMessage());
}
return null;
String url = "mj/task/" + id + "/fetch";
Request request = mjOkHttpUtil.createGetRequest(url);
return mjOkHttpUtil.executeRequest(request);
}
@ApiOperation(value = "查询任务队列")
@GetMapping("/queue")
public List<Task> queue() {
return this.discordLoadBalancer.getQueueTaskIds().stream()
.map(this.taskStoreService::get).filter(Objects::nonNull)
.sorted(Comparator.comparing(Task::getSubmitTime))
.toList();
}
@ApiOperation(value = "查询所有任务")
@GetMapping("/list")
public List<Task> list() {
return this.taskStoreService.list().stream()
.sorted((t1, t2) -> CompareUtil.compare(t2.getSubmitTime(), t1.getSubmitTime()))
.toList();
}
@ApiOperation(value = "根据ID列表查询任务")
@PostMapping("/list-by-condition")
public List<Task> listByIds(@RequestBody TaskConditionDTO conditionDTO) {
if (conditionDTO.getIds() == null) {
return Collections.emptyList();
}
return conditionDTO.getIds().stream().map(this.taskStoreService::get).filter(Objects::nonNull).toList();
public String listByIds(@RequestBody TaskConditionDTO conditionDTO) {
String url = "mj/task/list-by-condition";
String conditionJson = JSONUtil.toJsonStr(conditionDTO);
Request request = mjOkHttpUtil.createPostRequest(url,conditionJson);
return mjOkHttpUtil.executeRequest(request);
}
@ApiOperation(value = "获取任务图片的seed")
@GetMapping("/{id}/image-seed")
public String getSeed(@ApiParam(value = "任务ID") @PathVariable String id) {
String url = "mj/task/" + id + "/image-seed";
Request request = mjOkHttpUtil.createGetRequest(url);
return mjOkHttpUtil.executeRequest(request);
}
}

View File

@@ -0,0 +1,56 @@
package com.xmzs.midjourney.domain;
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;
/**
* 绘画费用信息
*
* @author Admin
*/
@Data
@Component
@ConfigurationProperties(prefix = "mj")
public class MjPriceConfig {
/**
* 放大图像
*/
private String upsample;
/**
* 变化
*/
private String change;
/**
* 图生图
*/
private String blend;
/**
* 图生文
*/
private String describe;
/**
* 文生图
*/
private String imagine;
/**
* 局部重绘
*/
private String inpaint;
/**
* 提示词分析
*/
private String shorten;
/**
* 换脸
*/
private String faceSwapping;
}

View File

@@ -1,10 +1,7 @@
package com.xmzs.midjourney.dto;
import com.xmzs.midjourney.enums.TaskAction;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import lombok.EqualsAndHashCode;
@Data

View File

@@ -0,0 +1,19 @@
package com.xmzs.midjourney.dto;
import io.swagger.annotations.ApiModel;
import lombok.Data;
import lombok.EqualsAndHashCode;
@Data
@EqualsAndHashCode(callSuper = true)
@ApiModel("局部重绘提交参数")
public class SubmitModalDTO extends BaseSubmitDTO{
private String maskBase64;
private String taskId;
private String prompt;
}

View File

@@ -0,0 +1,17 @@
package com.xmzs.midjourney.dto;
import io.swagger.annotations.ApiModel;
import lombok.Data;
import lombok.EqualsAndHashCode;
@Data
@EqualsAndHashCode(callSuper = true)
@ApiModel("prompt分析提交参数")
public class SubmitShortenDTO extends BaseSubmitDTO{
private String botType;
private String prompt;
}

View File

@@ -0,0 +1,32 @@
package com.xmzs.midjourney.enums;
import lombok.Getter;
/**
* @author WangLe
*/
@Getter
public enum ActionType {
IN_PAINT("Inpaint"), // 局部重绘操作
RE_ROLL("reroll"), // 重绘操作
UP_SAMPLE("upsample"), // 放大操作
ZOOM("zoom"), // 变焦操作
UPSCALE("upscale"), // 高清放大操作
VARIATION("variation"); // 变化操作
private final String action;
ActionType(String action) {
this.action = action;
}
public static ActionType fromCustomId(String customId) {
for (ActionType type : values()) {
if (type.getAction().equals(customId)) {
return type;
}
}
return null;
}
}

View File

@@ -1,6 +1,9 @@
package com.xmzs.midjourney.enums;
import lombok.Getter;
@Getter
public enum BlendDimensions {
PORTRAIT("2:3"),
@@ -15,7 +18,4 @@ public enum BlendDimensions {
this.value = value;
}
public String getValue() {
return this.value;
}
}

View File

@@ -0,0 +1,65 @@
package com.xmzs.midjourney.util;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.TimeUnit;
/**
* @author WangLe
*/
@RequiredArgsConstructor
@Component
@Slf4j
public class MjOkHttpUtil {
@Value("${chat.apiKey}")
private List<String> apiKey;
@Value("${chat.apiHost}")
private String apiHost;
private static final String API_SECRET_HEADER = "mj-api-secret";
private final OkHttpClient client = new OkHttpClient.Builder()
.connectTimeout(300, TimeUnit.SECONDS)
.writeTimeout(300, TimeUnit.SECONDS)
.readTimeout(300, TimeUnit.SECONDS)
.build();
public String executeRequest(Request request) {
try (Response response = client.newCall(request).execute()) {
if (!response.isSuccessful()) {
throw new IOException("Unexpected code " + response);
}
return response.body() != null ? response.body().string() : null;
} catch (IOException e) {
// 这里应根据实际情况使用适当的日志记录方式
log.error("请求失败: {}",e.getMessage());
return null;
}
}
public Request createPostRequest(String url, String json) {
MediaType JSON = MediaType.get("application/json; charset=utf-8");
RequestBody body = RequestBody.create(json, JSON);
return new Request.Builder()
.url(apiHost + url)
.post(body)
.header(API_SECRET_HEADER, apiKey.get(0))
.build();
}
public Request createGetRequest(String url) {
return new Request.Builder()
.url(apiHost + url)
.header(API_SECRET_HEADER, apiKey.get(0))
.build();
}
}