mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-04-13 11:53:48 +00:00
init v1.0.0
This commit is contained in:
@@ -0,0 +1,20 @@
|
||||
package com.xmzs.midjourney;
|
||||
|
||||
import lombok.experimental.UtilityClass;
|
||||
|
||||
@UtilityClass
|
||||
public final class Constants {
|
||||
// 任务扩展属性 start
|
||||
public static final String TASK_PROPERTY_NOTIFY_HOOK = "notifyHook";
|
||||
public static final String TASK_PROPERTY_FINAL_PROMPT = "finalPrompt";
|
||||
public static final String TASK_PROPERTY_MESSAGE_ID = "messageId";
|
||||
public static final String TASK_PROPERTY_MESSAGE_HASH = "messageHash";
|
||||
public static final String TASK_PROPERTY_PROGRESS_MESSAGE_ID = "progressMessageId";
|
||||
public static final String TASK_PROPERTY_FLAGS = "flags";
|
||||
public static final String TASK_PROPERTY_NONCE = "nonce";
|
||||
public static final String TASK_PROPERTY_DISCORD_INSTANCE_ID = "discordInstanceId";
|
||||
// 任务扩展属性 end
|
||||
|
||||
public static final String API_SECRET_HEADER_NAME = "mj-api-secret";
|
||||
public static final String DEFAULT_DISCORD_USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36";
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
package com.xmzs.midjourney;
|
||||
|
||||
import org.springframework.boot.SpringApplication;
|
||||
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
||||
import org.springframework.context.annotation.Import;
|
||||
import org.springframework.scheduling.annotation.EnableScheduling;
|
||||
import spring.config.BeanConfig;
|
||||
import spring.config.WebMvcConfig;
|
||||
|
||||
@EnableScheduling
|
||||
@SpringBootApplication
|
||||
@Import({BeanConfig.class, WebMvcConfig.class})
|
||||
public class ProxyApplication {
|
||||
|
||||
public static void main(String[] args) {
|
||||
SpringApplication.run(ProxyApplication.class, args);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,211 @@
|
||||
package com.xmzs.midjourney;
|
||||
|
||||
import com.xmzs.midjourney.enums.TranslateWay;
|
||||
import lombok.Data;
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@Component
|
||||
@ConfigurationProperties(prefix = "mj")
|
||||
public class ProxyProperties {
|
||||
/**
|
||||
* task存储配置.
|
||||
*/
|
||||
private final TaskStore taskStore = new TaskStore();
|
||||
/**
|
||||
* discord账号选择规则.
|
||||
*/
|
||||
private String accountChooseRule = "BestWaitIdleRule";
|
||||
/**
|
||||
* discord单账号配置.
|
||||
*/
|
||||
private final DiscordAccountConfig discord = new DiscordAccountConfig();
|
||||
/**
|
||||
* discord账号池配置.
|
||||
*/
|
||||
private final List<DiscordAccountConfig> accounts = new ArrayList<>();
|
||||
/**
|
||||
* 代理配置.
|
||||
*/
|
||||
private final ProxyConfig proxy = new ProxyConfig();
|
||||
/**
|
||||
* 反代配置.
|
||||
*/
|
||||
private final NgDiscordConfig ngDiscord = new NgDiscordConfig();
|
||||
/**
|
||||
* 百度翻译配置.
|
||||
*/
|
||||
private final BaiduTranslateConfig baiduTranslate = new BaiduTranslateConfig();
|
||||
/**
|
||||
* openai配置.
|
||||
*/
|
||||
private final OpenaiConfig openai = new OpenaiConfig();
|
||||
/**
|
||||
* 中文prompt翻译方式.
|
||||
*/
|
||||
private TranslateWay translateWay = TranslateWay.NULL;
|
||||
/**
|
||||
* 接口密钥,为空不启用鉴权;调用接口时需要加请求头 mj-api-secret.
|
||||
*/
|
||||
private String apiSecret;
|
||||
/**
|
||||
* 任务状态变更回调地址.
|
||||
*/
|
||||
private String notifyHook;
|
||||
/**
|
||||
* 通知回调线程池大小.
|
||||
*/
|
||||
private int notifyPoolSize = 10;
|
||||
|
||||
@Data
|
||||
public static class DiscordAccountConfig {
|
||||
/**
|
||||
* 服务器ID.
|
||||
*/
|
||||
private String guildId;
|
||||
/**
|
||||
* 频道ID.
|
||||
*/
|
||||
private String channelId;
|
||||
/**
|
||||
* 用户Token.
|
||||
*/
|
||||
private String userToken;
|
||||
/**
|
||||
* 用户UserAgent.
|
||||
*/
|
||||
private String userAgent = Constants.DEFAULT_DISCORD_USER_AGENT;
|
||||
/**
|
||||
* 是否可用.
|
||||
*/
|
||||
private boolean enable = true;
|
||||
/**
|
||||
* 并发数.
|
||||
*/
|
||||
private int coreSize = 3;
|
||||
/**
|
||||
* 等待队列长度.
|
||||
*/
|
||||
private int queueSize = 10;
|
||||
/**
|
||||
* 任务超时时间(分钟).
|
||||
*/
|
||||
private int timeoutMinutes = 5;
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class BaiduTranslateConfig {
|
||||
/**
|
||||
* 百度翻译的APP_ID.
|
||||
*/
|
||||
private String appid;
|
||||
/**
|
||||
* 百度翻译的密钥.
|
||||
*/
|
||||
private String appSecret;
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class OpenaiConfig {
|
||||
/**
|
||||
* 自定义gpt的api-url.
|
||||
*/
|
||||
private String gptApiUrl;
|
||||
/**
|
||||
* gpt的api-key.
|
||||
*/
|
||||
private String gptApiKey;
|
||||
/**
|
||||
* 超时时间.
|
||||
*/
|
||||
private Duration timeout = Duration.ofSeconds(30);
|
||||
/**
|
||||
* 使用的模型.
|
||||
*/
|
||||
private String model = "gpt-3.5-turbo";
|
||||
/**
|
||||
* 返回结果的最大分词数.
|
||||
*/
|
||||
private int maxTokens = 2048;
|
||||
/**
|
||||
* 相似度,取值 0-2.
|
||||
*/
|
||||
private double temperature = 0;
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class TaskStore {
|
||||
/**
|
||||
* 任务过期时间,默认30天.
|
||||
*/
|
||||
private Duration timeout = Duration.ofDays(30);
|
||||
/**
|
||||
* 任务存储方式: redis(默认)、in_memory.
|
||||
*/
|
||||
private Type type = Type.IN_MEMORY;
|
||||
|
||||
public enum Type {
|
||||
/**
|
||||
* redis.
|
||||
*/
|
||||
REDIS,
|
||||
/**
|
||||
* in_memory.
|
||||
*/
|
||||
IN_MEMORY
|
||||
}
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class ProxyConfig {
|
||||
/**
|
||||
* 代理host.
|
||||
*/
|
||||
private String host;
|
||||
/**
|
||||
* 代理端口.
|
||||
*/
|
||||
private Integer port;
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class NgDiscordConfig {
|
||||
/**
|
||||
* https://discord.com 反代.
|
||||
*/
|
||||
private String server;
|
||||
/**
|
||||
* https://cdn.discordapp.com 反代.
|
||||
*/
|
||||
private String cdn;
|
||||
/**
|
||||
* wss://gateway.discord.gg 反代.
|
||||
*/
|
||||
private String wss;
|
||||
/**
|
||||
* https://discord-attachments-uploads-prd.storage.googleapis.com 反代.
|
||||
*/
|
||||
private String uploadServer;
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class TaskQueueConfig {
|
||||
/**
|
||||
* 并发数.
|
||||
*/
|
||||
private int coreSize = 3;
|
||||
/**
|
||||
* 等待队列长度.
|
||||
*/
|
||||
private int queueSize = 10;
|
||||
/**
|
||||
* 任务超时时间(分钟).
|
||||
*/
|
||||
private int timeoutMinutes = 5;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
package com.xmzs.midjourney;
|
||||
|
||||
import lombok.experimental.UtilityClass;
|
||||
|
||||
@UtilityClass
|
||||
public final class ReturnCode {
|
||||
/**
|
||||
* 成功.
|
||||
*/
|
||||
public static final int SUCCESS = 1;
|
||||
/**
|
||||
* 数据未找到.
|
||||
*/
|
||||
public static final int NOT_FOUND = 3;
|
||||
/**
|
||||
* 校验错误.
|
||||
*/
|
||||
public static final int VALIDATION_ERROR = 4;
|
||||
/**
|
||||
* 系统异常.
|
||||
*/
|
||||
public static final int FAILURE = 9;
|
||||
|
||||
/**
|
||||
* 已存在.
|
||||
*/
|
||||
public static final int EXISTED = 21;
|
||||
/**
|
||||
* 排队中.
|
||||
*/
|
||||
public static final int IN_QUEUE = 22;
|
||||
/**
|
||||
* 队列已满.
|
||||
*/
|
||||
public static final int QUEUE_REJECTED = 23;
|
||||
/**
|
||||
* prompt包含敏感词.
|
||||
*/
|
||||
public static final int BANNED_PROMPT = 24;
|
||||
|
||||
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
package com.xmzs.midjourney.controller;
|
||||
|
||||
import com.xmzs.midjourney.domain.DiscordAccount;
|
||||
import com.xmzs.midjourney.loadbalancer.DiscordInstance;
|
||||
import com.xmzs.midjourney.loadbalancer.DiscordLoadBalancer;
|
||||
import io.swagger.annotations.Api;
|
||||
import io.swagger.annotations.ApiOperation;
|
||||
import io.swagger.annotations.ApiParam;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Api(tags = "账号查询")
|
||||
@RestController
|
||||
@RequestMapping("/mj/account")
|
||||
@RequiredArgsConstructor
|
||||
public class AccountController {
|
||||
private final DiscordLoadBalancer loadBalancer;
|
||||
|
||||
@ApiOperation(value = "指定ID获取账号")
|
||||
@GetMapping("/{id}/fetch")
|
||||
public DiscordAccount fetch(@ApiParam(value = "账号ID") @PathVariable String id) {
|
||||
DiscordInstance instance = this.loadBalancer.getDiscordInstance(id);
|
||||
return instance == null ? null : instance.account();
|
||||
}
|
||||
|
||||
@ApiOperation(value = "查询所有账号")
|
||||
@GetMapping("/list")
|
||||
public List<DiscordAccount> list() {
|
||||
return this.loadBalancer.getAllInstances().stream().map(DiscordInstance::account).toList();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,240 @@
|
||||
package com.xmzs.midjourney.controller;
|
||||
|
||||
import cn.hutool.core.text.CharSequenceUtil;
|
||||
import cn.hutool.core.util.RandomUtil;
|
||||
import com.xmzs.midjourney.Constants;
|
||||
import com.xmzs.midjourney.ProxyProperties;
|
||||
import com.xmzs.midjourney.ReturnCode;
|
||||
import com.xmzs.midjourney.dto.BaseSubmitDTO;
|
||||
import com.xmzs.midjourney.dto.SubmitBlendDTO;
|
||||
import com.xmzs.midjourney.dto.SubmitChangeDTO;
|
||||
import com.xmzs.midjourney.dto.SubmitDescribeDTO;
|
||||
import com.xmzs.midjourney.dto.SubmitImagineDTO;
|
||||
import com.xmzs.midjourney.dto.SubmitSimpleChangeDTO;
|
||||
import com.xmzs.midjourney.enums.TaskAction;
|
||||
import com.xmzs.midjourney.enums.TaskStatus;
|
||||
import com.xmzs.midjourney.enums.TranslateWay;
|
||||
import com.xmzs.midjourney.exception.BannedPromptException;
|
||||
import com.xmzs.midjourney.result.SubmitResultVO;
|
||||
import com.xmzs.midjourney.service.TaskService;
|
||||
import com.xmzs.midjourney.service.TaskStoreService;
|
||||
import com.xmzs.midjourney.service.TranslateService;
|
||||
import com.xmzs.midjourney.support.Task;
|
||||
import com.xmzs.midjourney.support.TaskCondition;
|
||||
import com.xmzs.midjourney.util.BannedPromptUtils;
|
||||
import com.xmzs.midjourney.util.ConvertUtils;
|
||||
import com.xmzs.midjourney.util.MimeTypeUtils;
|
||||
import com.xmzs.midjourney.util.SnowFlake;
|
||||
import com.xmzs.midjourney.util.TaskChangeParams;
|
||||
import eu.maxschuster.dataurl.DataUrl;
|
||||
import eu.maxschuster.dataurl.DataUrlSerializer;
|
||||
import eu.maxschuster.dataurl.IDataUrlSerializer;
|
||||
import io.swagger.annotations.Api;
|
||||
import io.swagger.annotations.ApiOperation;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import java.net.MalformedURLException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
@Api(tags = "任务提交")
|
||||
@RestController
|
||||
@RequestMapping("/mj/submit")
|
||||
@RequiredArgsConstructor
|
||||
public class SubmitController {
|
||||
private final TranslateService translateService;
|
||||
private final TaskStoreService taskStoreService;
|
||||
private final ProxyProperties properties;
|
||||
private final TaskService taskService;
|
||||
|
||||
@ApiOperation(value = "提交Imagine任务")
|
||||
@PostMapping("/imagine")
|
||||
public SubmitResultVO imagine(@RequestBody SubmitImagineDTO imagineDTO) {
|
||||
String prompt = imagineDTO.getPrompt();
|
||||
if (CharSequenceUtil.isBlank(prompt)) {
|
||||
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "prompt不能为空");
|
||||
}
|
||||
prompt = prompt.trim();
|
||||
Task task = newTask(imagineDTO);
|
||||
task.setAction(TaskAction.IMAGINE);
|
||||
task.setPrompt(prompt);
|
||||
String promptEn = translatePrompt(prompt);
|
||||
try {
|
||||
BannedPromptUtils.checkBanned(promptEn);
|
||||
} catch (BannedPromptException e) {
|
||||
return SubmitResultVO.fail(ReturnCode.BANNED_PROMPT, "可能包含敏感词")
|
||||
.setProperty("promptEn", promptEn).setProperty("bannedWord", e.getMessage());
|
||||
}
|
||||
List<String> base64Array = Optional.ofNullable(imagineDTO.getBase64Array()).orElse(new ArrayList<>());
|
||||
if (CharSequenceUtil.isNotBlank(imagineDTO.getBase64())) {
|
||||
base64Array.add(imagineDTO.getBase64());
|
||||
}
|
||||
List<DataUrl> dataUrls;
|
||||
try {
|
||||
dataUrls = ConvertUtils.convertBase64Array(base64Array);
|
||||
} catch (MalformedURLException e) {
|
||||
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64格式错误");
|
||||
}
|
||||
task.setPromptEn(promptEn);
|
||||
task.setDescription("/imagine " + prompt);
|
||||
return this.taskService.submitImagine(task, dataUrls);
|
||||
}
|
||||
|
||||
@ApiOperation(value = "绘图变化-simple")
|
||||
@PostMapping("/simple-change")
|
||||
public SubmitResultVO simpleChange(@RequestBody SubmitSimpleChangeDTO simpleChangeDTO) {
|
||||
TaskChangeParams changeParams = ConvertUtils.convertChangeParams(simpleChangeDTO.getContent());
|
||||
if (changeParams == null) {
|
||||
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "content参数错误");
|
||||
}
|
||||
SubmitChangeDTO changeDTO = new SubmitChangeDTO();
|
||||
changeDTO.setAction(changeParams.getAction());
|
||||
changeDTO.setTaskId(changeParams.getId());
|
||||
changeDTO.setIndex(changeParams.getIndex());
|
||||
changeDTO.setState(simpleChangeDTO.getState());
|
||||
changeDTO.setNotifyHook(simpleChangeDTO.getNotifyHook());
|
||||
return change(changeDTO);
|
||||
}
|
||||
|
||||
@ApiOperation(value = "绘图变化")
|
||||
@PostMapping("/change")
|
||||
public SubmitResultVO change(@RequestBody SubmitChangeDTO changeDTO) {
|
||||
if (CharSequenceUtil.isBlank(changeDTO.getTaskId())) {
|
||||
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "taskId不能为空");
|
||||
}
|
||||
if (!Set.of(TaskAction.UPSCALE, TaskAction.VARIATION, TaskAction.REROLL).contains(changeDTO.getAction())) {
|
||||
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "action参数错误");
|
||||
}
|
||||
String description = "/up " + changeDTO.getTaskId();
|
||||
if (TaskAction.REROLL.equals(changeDTO.getAction())) {
|
||||
description += " R";
|
||||
} else {
|
||||
description += " " + changeDTO.getAction().name().charAt(0) + changeDTO.getIndex();
|
||||
}
|
||||
if (TaskAction.UPSCALE.equals(changeDTO.getAction())) {
|
||||
TaskCondition condition = new TaskCondition().setDescription(description);
|
||||
Task existTask = this.taskStoreService.findOne(condition);
|
||||
if (existTask != null) {
|
||||
return SubmitResultVO.of(ReturnCode.EXISTED, "任务已存在", existTask.getId())
|
||||
.setProperty("status", existTask.getStatus())
|
||||
.setProperty("imageUrl", existTask.getImageUrl());
|
||||
}
|
||||
}
|
||||
Task targetTask = this.taskStoreService.get(changeDTO.getTaskId());
|
||||
if (targetTask == null) {
|
||||
return SubmitResultVO.fail(ReturnCode.NOT_FOUND, "关联任务不存在或已失效");
|
||||
}
|
||||
if (!TaskStatus.SUCCESS.equals(targetTask.getStatus())) {
|
||||
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "关联任务状态错误");
|
||||
}
|
||||
if (!Set.of(TaskAction.IMAGINE, TaskAction.VARIATION, TaskAction.REROLL, TaskAction.BLEND).contains(targetTask.getAction())) {
|
||||
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "关联任务不允许执行变化");
|
||||
}
|
||||
Task task = newTask(changeDTO);
|
||||
task.setAction(changeDTO.getAction());
|
||||
task.setPrompt(targetTask.getPrompt());
|
||||
task.setPromptEn(targetTask.getPromptEn());
|
||||
task.setProperty(Constants.TASK_PROPERTY_FINAL_PROMPT, targetTask.getProperty(Constants.TASK_PROPERTY_FINAL_PROMPT));
|
||||
task.setProperty(Constants.TASK_PROPERTY_PROGRESS_MESSAGE_ID, targetTask.getProperty(Constants.TASK_PROPERTY_MESSAGE_ID));
|
||||
task.setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, targetTask.getProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID));
|
||||
task.setDescription(description);
|
||||
int messageFlags = targetTask.getPropertyGeneric(Constants.TASK_PROPERTY_FLAGS);
|
||||
String messageId = targetTask.getPropertyGeneric(Constants.TASK_PROPERTY_MESSAGE_ID);
|
||||
String messageHash = targetTask.getPropertyGeneric(Constants.TASK_PROPERTY_MESSAGE_HASH);
|
||||
if (TaskAction.UPSCALE.equals(changeDTO.getAction())) {
|
||||
return this.taskService.submitUpscale(task, messageId, messageHash, changeDTO.getIndex(), messageFlags);
|
||||
} else if (TaskAction.VARIATION.equals(changeDTO.getAction())) {
|
||||
return this.taskService.submitVariation(task, messageId, messageHash, changeDTO.getIndex(), messageFlags);
|
||||
} else {
|
||||
return this.taskService.submitReroll(task, messageId, messageHash, messageFlags);
|
||||
}
|
||||
}
|
||||
|
||||
@ApiOperation(value = "提交Describe任务")
|
||||
@PostMapping("/describe")
|
||||
public SubmitResultVO describe(@RequestBody SubmitDescribeDTO describeDTO) {
|
||||
if (CharSequenceUtil.isBlank(describeDTO.getBase64())) {
|
||||
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64不能为空");
|
||||
}
|
||||
IDataUrlSerializer serializer = new DataUrlSerializer();
|
||||
DataUrl dataUrl;
|
||||
try {
|
||||
dataUrl = serializer.unserialize(describeDTO.getBase64());
|
||||
} catch (MalformedURLException e) {
|
||||
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64格式错误");
|
||||
}
|
||||
Task task = newTask(describeDTO);
|
||||
task.setAction(TaskAction.DESCRIBE);
|
||||
String taskFileName = task.getId() + "." + MimeTypeUtils.guessFileSuffix(dataUrl.getMimeType());
|
||||
task.setDescription("/describe " + taskFileName);
|
||||
return this.taskService.submitDescribe(task, dataUrl);
|
||||
}
|
||||
|
||||
@ApiOperation(value = "提交Blend任务")
|
||||
@PostMapping("/blend")
|
||||
public SubmitResultVO blend(@RequestBody SubmitBlendDTO blendDTO) {
|
||||
List<String> base64Array = blendDTO.getBase64Array();
|
||||
if (base64Array == null || base64Array.size() < 2 || base64Array.size() > 5) {
|
||||
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64List参数错误");
|
||||
}
|
||||
if (blendDTO.getDimensions() == null) {
|
||||
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "dimensions参数错误");
|
||||
}
|
||||
IDataUrlSerializer serializer = new DataUrlSerializer();
|
||||
List<DataUrl> dataUrlList = new ArrayList<>();
|
||||
try {
|
||||
for (String base64 : base64Array) {
|
||||
DataUrl dataUrl = serializer.unserialize(base64);
|
||||
dataUrlList.add(dataUrl);
|
||||
}
|
||||
} catch (MalformedURLException e) {
|
||||
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64格式错误");
|
||||
}
|
||||
Task task = newTask(blendDTO);
|
||||
task.setAction(TaskAction.BLEND);
|
||||
task.setDescription("/blend " + task.getId() + " " + dataUrlList.size());
|
||||
return this.taskService.submitBlend(task, dataUrlList, blendDTO.getDimensions());
|
||||
}
|
||||
|
||||
private Task newTask(BaseSubmitDTO base) {
|
||||
Task task = new Task();
|
||||
task.setId(System.currentTimeMillis() + RandomUtil.randomNumbers(3));
|
||||
task.setSubmitTime(System.currentTimeMillis());
|
||||
task.setState(base.getState());
|
||||
String notifyHook = CharSequenceUtil.isBlank(base.getNotifyHook()) ? this.properties.getNotifyHook() : base.getNotifyHook();
|
||||
task.setProperty(Constants.TASK_PROPERTY_NOTIFY_HOOK, notifyHook);
|
||||
task.setProperty(Constants.TASK_PROPERTY_NONCE, SnowFlake.INSTANCE.nextId());
|
||||
return task;
|
||||
}
|
||||
|
||||
private String translatePrompt(String prompt) {
|
||||
if (TranslateWay.NULL.equals(this.properties.getTranslateWay()) || CharSequenceUtil.isBlank(prompt)) {
|
||||
return prompt;
|
||||
}
|
||||
List<String> imageUrls = new ArrayList<>();
|
||||
Matcher imageMatcher = Pattern.compile("https?://[a-z0-9-_:@&?=+,.!/~*'%$]+\\x20+", Pattern.CASE_INSENSITIVE).matcher(prompt);
|
||||
while (imageMatcher.find()) {
|
||||
imageUrls.add(imageMatcher.group(0));
|
||||
}
|
||||
String paramStr = "";
|
||||
Matcher paramMatcher = Pattern.compile("\\x20+-{1,2}[a-z]+.*$", Pattern.CASE_INSENSITIVE).matcher(prompt);
|
||||
if (paramMatcher.find()) {
|
||||
paramStr = paramMatcher.group(0);
|
||||
}
|
||||
String imageStr = CharSequenceUtil.join("", imageUrls);
|
||||
String text = prompt.substring(imageStr.length(), prompt.length() - paramStr.length());
|
||||
if (CharSequenceUtil.isNotBlank(text)) {
|
||||
text = this.translateService.translateToEnglish(text).trim();
|
||||
}
|
||||
return imageStr + text + paramStr;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package com.xmzs.midjourney.controller;
|
||||
|
||||
import cn.hutool.core.comparator.CompareUtil;
|
||||
import com.xmzs.midjourney.dto.TaskConditionDTO;
|
||||
import com.xmzs.midjourney.loadbalancer.DiscordLoadBalancer;
|
||||
import com.xmzs.midjourney.service.TaskStoreService;
|
||||
import com.xmzs.midjourney.support.Task;
|
||||
import io.swagger.annotations.Api;
|
||||
import io.swagger.annotations.ApiOperation;
|
||||
import io.swagger.annotations.ApiParam;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
@Api(tags = "任务查询")
|
||||
@RestController
|
||||
@RequestMapping("/mj/task")
|
||||
@RequiredArgsConstructor
|
||||
public class TaskController {
|
||||
private final TaskStoreService taskStoreService;
|
||||
private final DiscordLoadBalancer discordLoadBalancer;
|
||||
|
||||
@ApiOperation(value = "指定ID获取任务")
|
||||
@GetMapping("/{id}/fetch")
|
||||
public Task fetch(@ApiParam(value = "任务ID") @PathVariable String id) {
|
||||
return this.taskStoreService.get(id);
|
||||
}
|
||||
|
||||
@ApiOperation(value = "查询任务队列")
|
||||
@GetMapping("/queue")
|
||||
public List<Task> queue() {
|
||||
return this.discordLoadBalancer.getQueueTaskIds().stream()
|
||||
.map(this.taskStoreService::get).filter(Objects::nonNull)
|
||||
.sorted(Comparator.comparing(Task::getSubmitTime))
|
||||
.toList();
|
||||
}
|
||||
|
||||
@ApiOperation(value = "查询所有任务")
|
||||
@GetMapping("/list")
|
||||
public List<Task> list() {
|
||||
return this.taskStoreService.list().stream()
|
||||
.sorted((t1, t2) -> CompareUtil.compare(t2.getSubmitTime(), t1.getSubmitTime()))
|
||||
.toList();
|
||||
}
|
||||
|
||||
@ApiOperation(value = "根据ID列表查询任务")
|
||||
@PostMapping("/list-by-condition")
|
||||
public List<Task> listByIds(@RequestBody TaskConditionDTO conditionDTO) {
|
||||
if (conditionDTO.getIds() == null) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
return conditionDTO.getIds().stream().map(this.taskStoreService::get).filter(Objects::nonNull).toList();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package com.xmzs.midjourney.domain;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.xmzs.midjourney.Constants;
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@ApiModel("Discord账号")
|
||||
public class DiscordAccount extends DomainObject {
|
||||
|
||||
@ApiModelProperty("服务器ID")
|
||||
private String guildId;
|
||||
@ApiModelProperty("频道ID")
|
||||
private String channelId;
|
||||
@ApiModelProperty("用户Token")
|
||||
private String userToken;
|
||||
@ApiModelProperty("用户UserAgent")
|
||||
private String userAgent = Constants.DEFAULT_DISCORD_USER_AGENT;
|
||||
|
||||
@ApiModelProperty("是否可用")
|
||||
private boolean enable = true;
|
||||
|
||||
@ApiModelProperty("并发数")
|
||||
private int coreSize = 3;
|
||||
@ApiModelProperty("等待队列长度")
|
||||
private int queueSize = 10;
|
||||
@ApiModelProperty("任务超时时间(分钟)")
|
||||
private int timeoutMinutes = 5;
|
||||
|
||||
@JsonIgnore
|
||||
public String getDisplay() {
|
||||
return this.channelId;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
package com.xmzs.midjourney.domain;
|
||||
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
|
||||
public class DomainObject implements Serializable {
|
||||
@Getter
|
||||
@Setter
|
||||
@ApiModelProperty("ID")
|
||||
protected String id;
|
||||
|
||||
@Setter
|
||||
protected Map<String, Object> properties; // 扩展属性,仅支持基本类型
|
||||
|
||||
@JsonIgnore
|
||||
private final transient Object lock = new Object();
|
||||
|
||||
public void sleep() throws InterruptedException {
|
||||
synchronized (this.lock) {
|
||||
this.lock.wait();
|
||||
}
|
||||
}
|
||||
|
||||
public void awake() {
|
||||
synchronized (this.lock) {
|
||||
this.lock.notifyAll();
|
||||
}
|
||||
}
|
||||
|
||||
public DomainObject setProperty(String name, Object value) {
|
||||
getProperties().put(name, value);
|
||||
return this;
|
||||
}
|
||||
|
||||
public DomainObject removeProperty(String name) {
|
||||
getProperties().remove(name);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Object getProperty(String name) {
|
||||
return getProperties().get(name);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public <T> T getPropertyGeneric(String name) {
|
||||
return (T) getProperty(name);
|
||||
}
|
||||
|
||||
public <T> T getProperty(String name, Class<T> clz) {
|
||||
return getProperty(name, clz, null);
|
||||
}
|
||||
|
||||
public <T> T getProperty(String name, Class<T> clz, T defaultValue) {
|
||||
Object value = getProperty(name);
|
||||
return value == null ? defaultValue : clz.cast(value);
|
||||
}
|
||||
|
||||
public Map<String, Object> getProperties() {
|
||||
if (this.properties == null) {
|
||||
this.properties = new HashMap<>();
|
||||
}
|
||||
return this.properties;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.xmzs.midjourney.dto;
|
||||
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
public abstract class BaseSubmitDTO {
|
||||
|
||||
@ApiModelProperty("自定义参数")
|
||||
protected String state;
|
||||
|
||||
@ApiModelProperty("回调地址, 为空时使用全局notifyHook")
|
||||
protected String notifyHook;
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package com.xmzs.midjourney.dto;
|
||||
|
||||
import com.xmzs.midjourney.enums.BlendDimensions;
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@ApiModel("Blend提交参数")
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
public class SubmitBlendDTO extends BaseSubmitDTO {
|
||||
|
||||
@ApiModelProperty(value = "图片base64数组", required = true, example = "[\"data:image/png;base64,xxx1\", \"data:image/png;base64,xxx2\"]")
|
||||
private List<String> base64Array;
|
||||
|
||||
@ApiModelProperty(value = "比例: PORTRAIT(2:3); SQUARE(1:1); LANDSCAPE(3:2)", example = "SQUARE")
|
||||
private BlendDimensions dimensions = BlendDimensions.SQUARE;
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package com.xmzs.midjourney.dto;
|
||||
|
||||
import com.xmzs.midjourney.enums.TaskAction;
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
|
||||
|
||||
@Data
|
||||
@ApiModel("变化任务提交参数")
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
public class SubmitChangeDTO extends BaseSubmitDTO {
|
||||
|
||||
@ApiModelProperty(value = "任务ID", required = true, example = "\"1320098173412546\"")
|
||||
private String taskId;
|
||||
|
||||
@ApiModelProperty(value = "UPSCALE(放大); VARIATION(变换); REROLL(重新生成)", required = true,
|
||||
allowableValues = "UPSCALE, VARIATION, REROLL", example = "UPSCALE")
|
||||
private TaskAction action;
|
||||
|
||||
@ApiModelProperty(value = "序号(1~4), action为UPSCALE,VARIATION时必传", allowableValues = "range[1, 4]", example = "1")
|
||||
private Integer index;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package com.xmzs.midjourney.dto;
|
||||
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
|
||||
@Data
|
||||
@ApiModel("Describe提交参数")
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
public class SubmitDescribeDTO extends BaseSubmitDTO {
|
||||
|
||||
@ApiModelProperty(value = "图片base64", required = true, example = "data:image/png;base64,xxx")
|
||||
private String base64;
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package com.xmzs.midjourney.dto;
|
||||
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@Data
|
||||
@ApiModel("Imagine提交参数")
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
public class SubmitImagineDTO extends BaseSubmitDTO {
|
||||
|
||||
@ApiModelProperty(value = "提示词", required = true, example = "Cat")
|
||||
private String prompt;
|
||||
|
||||
@ApiModelProperty(value = "垫图base64数组")
|
||||
private List<String> base64Array;
|
||||
|
||||
@ApiModelProperty(hidden = true)
|
||||
@Deprecated(since = "3.0", forRemoval = true)
|
||||
private String base64;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package com.xmzs.midjourney.dto;
|
||||
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
|
||||
|
||||
@Data
|
||||
@ApiModel("变化任务提交参数-simple")
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
public class SubmitSimpleChangeDTO extends BaseSubmitDTO {
|
||||
|
||||
@ApiModelProperty(value = "变化描述: ID $action$index", required = true, example = "1320098173412546 U2")
|
||||
private String content;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package com.xmzs.midjourney.dto;
|
||||
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@ApiModel("任务查询参数")
|
||||
public class TaskConditionDTO {
|
||||
|
||||
private List<String> ids;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package com.xmzs.midjourney.enums;
|
||||
|
||||
|
||||
public enum BlendDimensions {
|
||||
|
||||
PORTRAIT("2:3"),
|
||||
|
||||
SQUARE("1:1"),
|
||||
|
||||
LANDSCAPE("3:2");
|
||||
|
||||
private final String value;
|
||||
|
||||
BlendDimensions(String value) {
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
public String getValue() {
|
||||
return this.value;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package com.xmzs.midjourney.enums;
|
||||
|
||||
|
||||
public enum MessageType {
|
||||
/**
|
||||
* 创建.
|
||||
*/
|
||||
CREATE,
|
||||
/**
|
||||
* 修改.
|
||||
*/
|
||||
UPDATE,
|
||||
/**
|
||||
* 删除.
|
||||
*/
|
||||
DELETE;
|
||||
|
||||
public static MessageType of(String type) {
|
||||
return switch (type) {
|
||||
case "MESSAGE_CREATE" -> CREATE;
|
||||
case "MESSAGE_UPDATE" -> UPDATE;
|
||||
case "MESSAGE_DELETE" -> DELETE;
|
||||
default -> null;
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package com.xmzs.midjourney.enums;
|
||||
|
||||
|
||||
public enum TaskAction {
|
||||
/**
|
||||
* 生成图片.
|
||||
*/
|
||||
IMAGINE,
|
||||
/**
|
||||
* 选中放大.
|
||||
*/
|
||||
UPSCALE,
|
||||
/**
|
||||
* 选中其中的一张图,生成四张相似的.
|
||||
*/
|
||||
VARIATION,
|
||||
/**
|
||||
* 重新执行.
|
||||
*/
|
||||
REROLL,
|
||||
/**
|
||||
* 图转prompt.
|
||||
*/
|
||||
DESCRIBE,
|
||||
/**
|
||||
* 多图混合.
|
||||
*/
|
||||
BLEND
|
||||
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package com.xmzs.midjourney.enums;
|
||||
|
||||
|
||||
public enum TaskStatus {
|
||||
/**
|
||||
* 未启动.
|
||||
*/
|
||||
NOT_START,
|
||||
/**
|
||||
* 已提交.
|
||||
*/
|
||||
SUBMITTED,
|
||||
/**
|
||||
* 执行中.
|
||||
*/
|
||||
IN_PROGRESS,
|
||||
/**
|
||||
* 失败.
|
||||
*/
|
||||
FAILURE,
|
||||
/**
|
||||
* 成功.
|
||||
*/
|
||||
SUCCESS
|
||||
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package com.xmzs.midjourney.enums;
|
||||
|
||||
|
||||
public enum TranslateWay {
|
||||
/**
|
||||
* 百度翻译.
|
||||
*/
|
||||
BAIDU,
|
||||
/**
|
||||
* GPT翻译.
|
||||
*/
|
||||
GPT,
|
||||
/**
|
||||
* 不翻译.
|
||||
*/
|
||||
NULL
|
||||
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
package com.xmzs.midjourney.exception;
|
||||
|
||||
public class BannedPromptException extends Exception {
|
||||
|
||||
public BannedPromptException(String message) {
|
||||
super(message);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.xmzs.midjourney.exception;
|
||||
|
||||
public class SnowFlakeException extends RuntimeException {
|
||||
|
||||
public SnowFlakeException(String message) {
|
||||
super(message);
|
||||
}
|
||||
|
||||
public SnowFlakeException(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
|
||||
public SnowFlakeException(Throwable cause) {
|
||||
super(cause);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package com.xmzs.midjourney.loadbalancer;
|
||||
|
||||
|
||||
import com.xmzs.midjourney.domain.DiscordAccount;
|
||||
import com.xmzs.midjourney.result.Message;
|
||||
import com.xmzs.midjourney.result.SubmitResultVO;
|
||||
import com.xmzs.midjourney.service.DiscordService;
|
||||
import com.xmzs.midjourney.support.Task;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.Future;
|
||||
|
||||
public interface DiscordInstance extends DiscordService {
|
||||
|
||||
String getInstanceId();
|
||||
|
||||
DiscordAccount account();
|
||||
|
||||
boolean isAlive();
|
||||
|
||||
void startWss() throws Exception;
|
||||
|
||||
List<Task> getRunningTasks();
|
||||
|
||||
void exitTask(Task task);
|
||||
|
||||
Map<String, Future<?>> getRunningFutures();
|
||||
|
||||
SubmitResultVO submitTask(Task task, Callable<Message<Void>> discordSubmit);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,205 @@
|
||||
package com.xmzs.midjourney.loadbalancer;
|
||||
|
||||
|
||||
import com.xmzs.midjourney.Constants;
|
||||
import com.xmzs.midjourney.ReturnCode;
|
||||
import com.xmzs.midjourney.domain.DiscordAccount;
|
||||
import com.xmzs.midjourney.enums.BlendDimensions;
|
||||
import com.xmzs.midjourney.enums.TaskStatus;
|
||||
import com.xmzs.midjourney.result.Message;
|
||||
import com.xmzs.midjourney.result.SubmitResultVO;
|
||||
import com.xmzs.midjourney.service.DiscordService;
|
||||
import com.xmzs.midjourney.service.DiscordServiceImpl;
|
||||
import com.xmzs.midjourney.service.NotifyService;
|
||||
import com.xmzs.midjourney.service.TaskStoreService;
|
||||
import com.xmzs.midjourney.support.Task;
|
||||
import com.xmzs.midjourney.wss.WebSocketStarter;
|
||||
import com.xmzs.midjourney.wss.user.UserWebSocketStarter;
|
||||
import eu.maxschuster.dataurl.DataUrl;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
import java.util.concurrent.Future;
|
||||
import java.util.concurrent.RejectedExecutionException;
|
||||
|
||||
@Slf4j
|
||||
public class DiscordInstanceImpl implements DiscordInstance {
|
||||
private final DiscordAccount account;
|
||||
private final WebSocketStarter socketStarter;
|
||||
private final DiscordService service;
|
||||
private final TaskStoreService taskStoreService;
|
||||
private final NotifyService notifyService;
|
||||
|
||||
private final ThreadPoolTaskExecutor taskExecutor;
|
||||
private final List<Task> runningTasks;
|
||||
private final Map<String, Future<?>> taskFutureMap = Collections.synchronizedMap(new HashMap<>());
|
||||
|
||||
public DiscordInstanceImpl(DiscordAccount account, UserWebSocketStarter socketStarter, RestTemplate restTemplate,
|
||||
TaskStoreService taskStoreService, NotifyService notifyService, Map<String, String> paramsMap) {
|
||||
this.account = account;
|
||||
this.socketStarter = socketStarter;
|
||||
this.taskStoreService = taskStoreService;
|
||||
this.notifyService = notifyService;
|
||||
this.service = new DiscordServiceImpl(account, restTemplate, paramsMap);
|
||||
this.runningTasks = new CopyOnWriteArrayList<>();
|
||||
this.taskExecutor = new ThreadPoolTaskExecutor();
|
||||
this.taskExecutor.setCorePoolSize(account.getCoreSize());
|
||||
this.taskExecutor.setMaxPoolSize(account.getCoreSize());
|
||||
this.taskExecutor.setQueueCapacity(account.getQueueSize());
|
||||
this.taskExecutor.setThreadNamePrefix("TaskQueue-" + account.getDisplay() + "-");
|
||||
this.taskExecutor.initialize();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getInstanceId() {
|
||||
return this.account.getChannelId();
|
||||
}
|
||||
|
||||
@Override
|
||||
public DiscordAccount account() {
|
||||
return this.account;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isAlive() {
|
||||
return this.account.isEnable();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void startWss() throws Exception {
|
||||
this.socketStarter.setTrying(true);
|
||||
this.socketStarter.start();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Task> getRunningTasks() {
|
||||
return this.runningTasks;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void exitTask(Task task) {
|
||||
try {
|
||||
Future<?> future = this.taskFutureMap.get(task.getId());
|
||||
if (future != null) {
|
||||
future.cancel(true);
|
||||
}
|
||||
saveAndNotify(task);
|
||||
} finally {
|
||||
this.runningTasks.remove(task);
|
||||
this.taskFutureMap.remove(task.getId());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Future<?>> getRunningFutures() {
|
||||
return this.taskFutureMap;
|
||||
}
|
||||
|
||||
@Override
|
||||
public synchronized SubmitResultVO submitTask(Task task, Callable<Message<Void>> discordSubmit) {
|
||||
this.taskStoreService.save(task);
|
||||
int currentWaitNumbers;
|
||||
try {
|
||||
currentWaitNumbers = this.taskExecutor.getThreadPoolExecutor().getQueue().size();
|
||||
Future<?> future = this.taskExecutor.submit(() -> executeTask(task, discordSubmit));
|
||||
this.taskFutureMap.put(task.getId(), future);
|
||||
} catch (RejectedExecutionException e) {
|
||||
this.taskStoreService.delete(task.getId());
|
||||
return SubmitResultVO.fail(ReturnCode.QUEUE_REJECTED, "队列已满,请稍后尝试")
|
||||
.setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, this.getInstanceId());
|
||||
} catch (Exception e) {
|
||||
log.error("submit task error", e);
|
||||
return SubmitResultVO.fail(ReturnCode.FAILURE, "提交失败,系统异常")
|
||||
.setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, this.getInstanceId());
|
||||
}
|
||||
if (currentWaitNumbers == 0) {
|
||||
return SubmitResultVO.of(ReturnCode.SUCCESS, "提交成功", task.getId())
|
||||
.setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, this.getInstanceId());
|
||||
} else {
|
||||
return SubmitResultVO.of(ReturnCode.IN_QUEUE, "排队中,前面还有" + currentWaitNumbers + "个任务", task.getId())
|
||||
.setProperty("numberOfQueues", currentWaitNumbers)
|
||||
.setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, this.getInstanceId());
|
||||
}
|
||||
}
|
||||
|
||||
private void executeTask(Task task, Callable<Message<Void>> discordSubmit) {
|
||||
this.runningTasks.add(task);
|
||||
try {
|
||||
task.start();
|
||||
Message<Void> result = discordSubmit.call();
|
||||
if (result.getCode() != ReturnCode.SUCCESS) {
|
||||
task.fail(result.getDescription());
|
||||
saveAndNotify(task);
|
||||
return;
|
||||
}
|
||||
saveAndNotify(task);
|
||||
do {
|
||||
task.sleep();
|
||||
saveAndNotify(task);
|
||||
} while (task.getStatus() == TaskStatus.IN_PROGRESS);
|
||||
log.debug("task finished, id: {}, status: {}", task.getId(), task.getStatus());
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
} catch (Exception e) {
|
||||
log.error("task execute error", e);
|
||||
task.fail("执行错误,系统异常");
|
||||
saveAndNotify(task);
|
||||
} finally {
|
||||
this.runningTasks.remove(task);
|
||||
this.taskFutureMap.remove(task.getId());
|
||||
}
|
||||
}
|
||||
|
||||
private void saveAndNotify(Task task) {
|
||||
this.taskStoreService.save(task);
|
||||
this.notifyService.notifyTaskChange(task);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Message<Void> imagine(String prompt, String nonce) {
|
||||
return this.service.imagine(prompt, nonce);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Message<Void> upscale(String messageId, int index, String messageHash, int messageFlags, String nonce) {
|
||||
return this.service.upscale(messageId, index, messageHash, messageFlags, nonce);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Message<Void> variation(String messageId, int index, String messageHash, int messageFlags, String nonce) {
|
||||
return this.service.variation(messageId, index, messageHash, messageFlags, nonce);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Message<Void> reroll(String messageId, String messageHash, int messageFlags, String nonce) {
|
||||
return this.service.reroll(messageId, messageHash, messageFlags, nonce);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Message<Void> describe(String finalFileName, String nonce) {
|
||||
return this.service.describe(finalFileName, nonce);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Message<Void> blend(List<String> finalFileNames, BlendDimensions dimensions, String nonce) {
|
||||
return this.service.blend(finalFileNames, dimensions, nonce);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Message<String> upload(String fileName, DataUrl dataUrl) {
|
||||
return this.service.upload(fileName, dataUrl);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Message<String> sendImageMessage(String content, String finalFileName) {
|
||||
return this.service.sendImageMessage(content, finalFileName);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package com.xmzs.midjourney.loadbalancer;
|
||||
|
||||
|
||||
import cn.hutool.core.text.CharSequenceUtil;
|
||||
import com.xmzs.midjourney.loadbalancer.rule.IRule;
|
||||
import com.xmzs.midjourney.support.Task;
|
||||
import com.xmzs.midjourney.support.TaskCondition;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
@Component
|
||||
@RequiredArgsConstructor
|
||||
public class DiscordLoadBalancer {
|
||||
private final IRule rule;
|
||||
|
||||
private final List<DiscordInstance> instances = Collections.synchronizedList(new ArrayList<>());
|
||||
|
||||
public List<DiscordInstance> getAllInstances() {
|
||||
return this.instances;
|
||||
}
|
||||
|
||||
public List<DiscordInstance> getAliveInstances() {
|
||||
return this.instances.stream().filter(DiscordInstance::isAlive).toList();
|
||||
}
|
||||
|
||||
public DiscordInstance chooseInstance() {
|
||||
return this.rule.choose(getAliveInstances());
|
||||
}
|
||||
|
||||
public DiscordInstance getDiscordInstance(String instanceId) {
|
||||
if (CharSequenceUtil.isBlank(instanceId)) {
|
||||
return null;
|
||||
}
|
||||
return this.instances.stream()
|
||||
.filter(instance -> CharSequenceUtil.equals(instanceId, instance.getInstanceId()))
|
||||
.findFirst().orElse(null);
|
||||
}
|
||||
|
||||
public Set<String> getQueueTaskIds() {
|
||||
Set<String> taskIds = Collections.synchronizedSet(new HashSet<>());
|
||||
for (DiscordInstance instance : getAliveInstances()) {
|
||||
taskIds.addAll(instance.getRunningFutures().keySet());
|
||||
}
|
||||
return taskIds;
|
||||
}
|
||||
|
||||
public Stream<Task> findRunningTask(TaskCondition condition) {
|
||||
return getAliveInstances().stream().flatMap(instance -> instance.getRunningTasks().stream().filter(condition));
|
||||
}
|
||||
|
||||
public Task getRunningTask(String id) {
|
||||
for (DiscordInstance instance : getAliveInstances()) {
|
||||
Optional<Task> optional = instance.getRunningTasks().stream().filter(t -> id.equals(t.getId())).findFirst();
|
||||
if (optional.isPresent()) {
|
||||
return optional.get();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public Task getRunningTaskByNonce(String nonce) {
|
||||
if (CharSequenceUtil.isBlank(nonce)) {
|
||||
return null;
|
||||
}
|
||||
TaskCondition condition = new TaskCondition().setNonce(nonce);
|
||||
for (DiscordInstance instance : getAliveInstances()) {
|
||||
Optional<Task> optional = instance.getRunningTasks().stream().filter(condition).findFirst();
|
||||
if (optional.isPresent()) {
|
||||
return optional.get();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
package com.xmzs.midjourney.loadbalancer.rule;
|
||||
|
||||
import cn.hutool.core.util.RandomUtil;
|
||||
import com.xmzs.midjourney.loadbalancer.DiscordInstance;
|
||||
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* 最少等待空闲.
|
||||
* 选择等待数最少的实例,如果都不需要等待,则随机选择
|
||||
*/
|
||||
public class BestWaitIdleRule implements IRule {
|
||||
|
||||
@Override
|
||||
public DiscordInstance choose(List<DiscordInstance> instances) {
|
||||
if (instances.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
Map<Integer, List<DiscordInstance>> map = instances.stream()
|
||||
.collect(Collectors.groupingBy(i -> {
|
||||
int wait = i.getRunningFutures().size() - i.account().getCoreSize();
|
||||
return wait >= 0 ? wait : -1;
|
||||
}));
|
||||
List<DiscordInstance> instanceList = map.entrySet().stream().min(Comparator.comparingInt(Map.Entry::getKey)).orElseThrow().getValue();
|
||||
return RandomUtil.randomEle(instanceList);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package com.xmzs.midjourney.loadbalancer.rule;
|
||||
|
||||
import com.xmzs.midjourney.loadbalancer.DiscordInstance;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface IRule {
|
||||
|
||||
DiscordInstance choose(List<DiscordInstance> instances);
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package com.xmzs.midjourney.loadbalancer.rule;
|
||||
|
||||
import com.xmzs.midjourney.loadbalancer.DiscordInstance;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
/**
|
||||
* 轮询.
|
||||
*/
|
||||
public class RoundRobinRule implements IRule {
|
||||
private final AtomicInteger position = new AtomicInteger(0);
|
||||
|
||||
@Override
|
||||
public DiscordInstance choose(List<DiscordInstance> instances) {
|
||||
if (instances.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
int pos = incrementAndGet();
|
||||
return instances.get(pos % instances.size());
|
||||
}
|
||||
|
||||
private int incrementAndGet() {
|
||||
int current;
|
||||
int next;
|
||||
do {
|
||||
current = this.position.get();
|
||||
next = current == Integer.MAX_VALUE ? 0 : current + 1;
|
||||
} while (!this.position.compareAndSet(current, next));
|
||||
return next;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package com.xmzs.midjourney.result;
|
||||
|
||||
import com.xmzs.midjourney.ReturnCode;
|
||||
import lombok.Getter;
|
||||
|
||||
@Getter
|
||||
public class Message<T> {
|
||||
private final int code;
|
||||
private final String description;
|
||||
private final T result;
|
||||
|
||||
public static <Y> Message<Y> success() {
|
||||
return new Message<>(ReturnCode.SUCCESS, "成功");
|
||||
}
|
||||
|
||||
public static <T> Message<T> success(T result) {
|
||||
return new Message<>(ReturnCode.SUCCESS, "成功", result);
|
||||
}
|
||||
|
||||
public static <T> Message<T> success(int code, String description, T result) {
|
||||
return new Message<>(code, description, result);
|
||||
}
|
||||
|
||||
public static <Y> Message<Y> notFound() {
|
||||
return new Message<>(ReturnCode.NOT_FOUND, "数据未找到");
|
||||
}
|
||||
|
||||
public static <Y> Message<Y> validationError() {
|
||||
return new Message<>(ReturnCode.VALIDATION_ERROR, "校验错误");
|
||||
}
|
||||
|
||||
public static <Y> Message<Y> failure() {
|
||||
return new Message<>(ReturnCode.FAILURE, "系统异常");
|
||||
}
|
||||
|
||||
public static <Y> Message<Y> failure(String description) {
|
||||
return new Message<>(ReturnCode.FAILURE, description);
|
||||
}
|
||||
|
||||
public static <Y> Message<Y> of(int code, String description) {
|
||||
return new Message<>(code, description);
|
||||
}
|
||||
|
||||
public static <T> Message<T> of(int code, String description, T result) {
|
||||
return new Message<>(code, description, result);
|
||||
}
|
||||
|
||||
private Message(int code, String description) {
|
||||
this(code, description, null);
|
||||
}
|
||||
|
||||
private Message(int code, String description, T result) {
|
||||
this.code = code;
|
||||
this.description = description;
|
||||
this.result = result;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package com.xmzs.midjourney.result;
|
||||
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
@ApiModel("提交结果")
|
||||
public class SubmitResultVO {
|
||||
|
||||
@ApiModelProperty(value = "状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)", required = true, example = "1")
|
||||
private int code;
|
||||
|
||||
@ApiModelProperty(value = "描述", required = true, example = "提交成功")
|
||||
private String description;
|
||||
|
||||
@ApiModelProperty(value = "任务ID", example = "1320098173412546")
|
||||
private String result;
|
||||
|
||||
@ApiModelProperty(value = "扩展字段")
|
||||
private Map<String, Object> properties = new HashMap<>();
|
||||
|
||||
public SubmitResultVO setProperty(String name, Object value) {
|
||||
this.properties.put(name, value);
|
||||
return this;
|
||||
}
|
||||
|
||||
public SubmitResultVO removeProperty(String name) {
|
||||
this.properties.remove(name);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Object getProperty(String name) {
|
||||
return this.properties.get(name);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public <T> T getPropertyGeneric(String name) {
|
||||
return (T) getProperty(name);
|
||||
}
|
||||
|
||||
public <T> T getProperty(String name, Class<T> clz) {
|
||||
return clz.cast(getProperty(name));
|
||||
}
|
||||
|
||||
public static SubmitResultVO of(int code, String description, String result) {
|
||||
return new SubmitResultVO(code, description, result);
|
||||
}
|
||||
|
||||
public static SubmitResultVO fail(int code, String description) {
|
||||
return new SubmitResultVO(code, description, null);
|
||||
}
|
||||
|
||||
private SubmitResultVO(int code, String description, String result) {
|
||||
this.code = code;
|
||||
this.description = description;
|
||||
this.result = result;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
package com.xmzs.midjourney.service;
|
||||
|
||||
|
||||
import com.xmzs.midjourney.enums.BlendDimensions;
|
||||
import com.xmzs.midjourney.result.Message;
|
||||
import eu.maxschuster.dataurl.DataUrl;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface DiscordService {
|
||||
|
||||
Message<Void> imagine(String prompt, String nonce);
|
||||
|
||||
Message<Void> upscale(String messageId, int index, String messageHash, int messageFlags, String nonce);
|
||||
|
||||
Message<Void> variation(String messageId, int index, String messageHash, int messageFlags, String nonce);
|
||||
|
||||
Message<Void> reroll(String messageId, String messageHash, int messageFlags, String nonce);
|
||||
|
||||
Message<Void> describe(String finalFileName, String nonce);
|
||||
|
||||
Message<Void> blend(List<String> finalFileNames, BlendDimensions dimensions, String nonce);
|
||||
|
||||
Message<String> upload(String fileName, DataUrl dataUrl);
|
||||
|
||||
Message<String> sendImageMessage(String content, String finalFileName);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,219 @@
|
||||
package com.xmzs.midjourney.service;
|
||||
|
||||
|
||||
import cn.hutool.core.text.CharSequenceUtil;
|
||||
import com.xmzs.midjourney.ReturnCode;
|
||||
import com.xmzs.midjourney.domain.DiscordAccount;
|
||||
import com.xmzs.midjourney.enums.BlendDimensions;
|
||||
import com.xmzs.midjourney.result.Message;
|
||||
import com.xmzs.midjourney.support.DiscordHelper;
|
||||
import com.xmzs.midjourney.support.SpringContextHolder;
|
||||
import eu.maxschuster.dataurl.DataUrl;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.json.JSONArray;
|
||||
import org.json.JSONObject;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.web.client.HttpStatusCodeException;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Slf4j
|
||||
public class DiscordServiceImpl implements DiscordService {
|
||||
private static final String DEFAULT_SESSION_ID = "f1a313a09ce079ce252459dc70231f30";
|
||||
|
||||
private final DiscordAccount account;
|
||||
private final Map<String, String> paramsMap;
|
||||
private final RestTemplate restTemplate;
|
||||
private final DiscordHelper discordHelper;
|
||||
|
||||
private final String discordInteractionUrl;
|
||||
private final String discordAttachmentUrl;
|
||||
private final String discordMessageUrl;
|
||||
|
||||
public DiscordServiceImpl(DiscordAccount account, RestTemplate restTemplate, Map<String, String> paramsMap) {
|
||||
this.account = account;
|
||||
this.restTemplate = restTemplate;
|
||||
this.discordHelper = SpringContextHolder.getApplicationContext().getBean(DiscordHelper.class);
|
||||
this.paramsMap = paramsMap;
|
||||
String discordServer = this.discordHelper.getServer();
|
||||
this.discordInteractionUrl = discordServer + "/api/v9/interactions";
|
||||
this.discordAttachmentUrl = discordServer + "/api/v9/channels/" + account.getChannelId() + "/attachments";
|
||||
this.discordMessageUrl = discordServer + "/api/v9/channels/" + account.getChannelId() + "/messages";
|
||||
}
|
||||
|
||||
@Override
|
||||
public Message<Void> imagine(String prompt, String nonce) {
|
||||
String paramsStr = replaceInteractionParams(this.paramsMap.get("imagine"), nonce);
|
||||
JSONObject params = new JSONObject(paramsStr);
|
||||
params.getJSONObject("data").getJSONArray("options").getJSONObject(0)
|
||||
.put("value", prompt);
|
||||
return postJsonAndCheckStatus(params.toString());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Message<Void> upscale(String messageId, int index, String messageHash, int messageFlags, String nonce) {
|
||||
String paramsStr = replaceInteractionParams(this.paramsMap.get("upscale"), nonce)
|
||||
.replace("$message_id", messageId)
|
||||
.replace("$index", String.valueOf(index))
|
||||
.replace("$message_hash", messageHash);
|
||||
paramsStr = new JSONObject(paramsStr).put("message_flags", messageFlags).toString();
|
||||
return postJsonAndCheckStatus(paramsStr);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Message<Void> variation(String messageId, int index, String messageHash, int messageFlags, String nonce) {
|
||||
String paramsStr = replaceInteractionParams(this.paramsMap.get("variation"), nonce)
|
||||
.replace("$message_id", messageId)
|
||||
.replace("$index", String.valueOf(index))
|
||||
.replace("$message_hash", messageHash);
|
||||
paramsStr = new JSONObject(paramsStr).put("message_flags", messageFlags).toString();
|
||||
return postJsonAndCheckStatus(paramsStr);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Message<Void> reroll(String messageId, String messageHash, int messageFlags, String nonce) {
|
||||
String paramsStr = replaceInteractionParams(this.paramsMap.get("reroll"), nonce)
|
||||
.replace("$message_id", messageId)
|
||||
.replace("$message_hash", messageHash);
|
||||
paramsStr = new JSONObject(paramsStr).put("message_flags", messageFlags).toString();
|
||||
return postJsonAndCheckStatus(paramsStr);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Message<Void> describe(String finalFileName, String nonce) {
|
||||
String fileName = CharSequenceUtil.subAfter(finalFileName, "/", true);
|
||||
String paramsStr = replaceInteractionParams(this.paramsMap.get("describe"), nonce)
|
||||
.replace("$file_name", fileName)
|
||||
.replace("$final_file_name", finalFileName);
|
||||
return postJsonAndCheckStatus(paramsStr);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Message<Void> blend(List<String> finalFileNames, BlendDimensions dimensions, String nonce) {
|
||||
String paramsStr = replaceInteractionParams(this.paramsMap.get("blend"), nonce);
|
||||
JSONObject params = new JSONObject(paramsStr);
|
||||
JSONArray options = params.getJSONObject("data").getJSONArray("options");
|
||||
JSONArray attachments = params.getJSONObject("data").getJSONArray("attachments");
|
||||
for (int i = 0; i < finalFileNames.size(); i++) {
|
||||
String finalFileName = finalFileNames.get(i);
|
||||
String fileName = CharSequenceUtil.subAfter(finalFileName, "/", true);
|
||||
JSONObject attachment = new JSONObject().put("id", String.valueOf(i))
|
||||
.put("filename", fileName)
|
||||
.put("uploaded_filename", finalFileName);
|
||||
attachments.put(attachment);
|
||||
JSONObject option = new JSONObject().put("type", 11)
|
||||
.put("name", "image" + (i + 1))
|
||||
.put("value", i);
|
||||
options.put(option);
|
||||
}
|
||||
options.put(new JSONObject().put("type", 3)
|
||||
.put("name", "dimensions")
|
||||
.put("value", "--ar " + dimensions.getValue()));
|
||||
return postJsonAndCheckStatus(params.toString());
|
||||
}
|
||||
|
||||
private String replaceInteractionParams(String paramsStr, String nonce) {
|
||||
return paramsStr.replace("$guild_id", this.account.getGuildId())
|
||||
.replace("$channel_id", this.account.getChannelId())
|
||||
.replace("$session_id", DEFAULT_SESSION_ID)
|
||||
.replace("$nonce", nonce);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Message<String> upload(String fileName, DataUrl dataUrl) {
|
||||
try {
|
||||
JSONObject fileObj = new JSONObject();
|
||||
fileObj.put("filename", fileName);
|
||||
fileObj.put("file_size", dataUrl.getData().length);
|
||||
fileObj.put("id", "0");
|
||||
JSONObject params = new JSONObject()
|
||||
.put("files", new JSONArray().put(fileObj));
|
||||
ResponseEntity<String> responseEntity = postJson(this.discordAttachmentUrl, params.toString());
|
||||
if (responseEntity.getStatusCode() != HttpStatus.OK) {
|
||||
log.error("上传图片到discord失败, status: {}, msg: {}", responseEntity.getStatusCodeValue(), responseEntity.getBody());
|
||||
return Message.of(ReturnCode.VALIDATION_ERROR, "上传图片到discord失败");
|
||||
}
|
||||
JSONArray array = new JSONObject(responseEntity.getBody()).getJSONArray("attachments");
|
||||
if (array.length() == 0) {
|
||||
return Message.of(ReturnCode.VALIDATION_ERROR, "上传图片到discord失败");
|
||||
}
|
||||
String uploadUrl = array.getJSONObject(0).getString("upload_url");
|
||||
String uploadFilename = array.getJSONObject(0).getString("upload_filename");
|
||||
putFile(uploadUrl, dataUrl);
|
||||
return Message.success(uploadFilename);
|
||||
} catch (Exception e) {
|
||||
log.error("上传图片到discord失败", e);
|
||||
return Message.of(ReturnCode.FAILURE, "上传图片到discord失败");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Message<String> sendImageMessage(String content, String finalFileName) {
|
||||
String fileName = CharSequenceUtil.subAfter(finalFileName, "/", true);
|
||||
String paramsStr = this.paramsMap.get("message").replace("$content", content)
|
||||
.replace("$channel_id", this.account.getChannelId())
|
||||
.replace("$file_name", fileName)
|
||||
.replace("$final_file_name", finalFileName);
|
||||
ResponseEntity<String> responseEntity = postJson(this.discordMessageUrl, paramsStr);
|
||||
if (responseEntity.getStatusCode() != HttpStatus.OK) {
|
||||
log.error("发送图片消息到discord失败, status: {}, msg: {}", responseEntity.getStatusCodeValue(), responseEntity.getBody());
|
||||
return Message.of(ReturnCode.VALIDATION_ERROR, "发送图片消息到discord失败");
|
||||
}
|
||||
JSONObject result = new JSONObject(responseEntity.getBody());
|
||||
JSONArray attachments = result.optJSONArray("attachments");
|
||||
if (!attachments.isEmpty()) {
|
||||
return Message.success(attachments.getJSONObject(0).optString("url"));
|
||||
}
|
||||
return Message.failure("发送图片消息到discord失败: 图片不存在");
|
||||
}
|
||||
|
||||
private void putFile(String uploadUrl, DataUrl dataUrl) {
|
||||
uploadUrl = this.discordHelper.getDiscordUploadUrl(uploadUrl);
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.add("User-Agent", this.account.getUserAgent());
|
||||
headers.setContentType(MediaType.valueOf(dataUrl.getMimeType()));
|
||||
headers.setContentLength(dataUrl.getData().length);
|
||||
HttpEntity<byte[]> requestEntity = new HttpEntity<>(dataUrl.getData(), headers);
|
||||
this.restTemplate.put(uploadUrl, requestEntity);
|
||||
}
|
||||
|
||||
private ResponseEntity<String> postJson(String paramsStr) {
|
||||
return postJson(this.discordInteractionUrl, paramsStr);
|
||||
}
|
||||
|
||||
private ResponseEntity<String> postJson(String url, String paramsStr) {
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
headers.set("Authorization", this.account.getUserToken());
|
||||
headers.set("User-Agent", this.account.getUserAgent());
|
||||
HttpEntity<String> httpEntity = new HttpEntity<>(paramsStr, headers);
|
||||
return this.restTemplate.postForEntity(url, httpEntity, String.class);
|
||||
}
|
||||
|
||||
private Message<Void> postJsonAndCheckStatus(String paramsStr) {
|
||||
try {
|
||||
ResponseEntity<String> responseEntity = postJson(paramsStr);
|
||||
if (responseEntity.getStatusCode() == HttpStatus.NO_CONTENT) {
|
||||
return Message.success();
|
||||
}
|
||||
return Message.of(responseEntity.getStatusCodeValue(), CharSequenceUtil.sub(responseEntity.getBody(), 0, 100));
|
||||
} catch (HttpStatusCodeException e) {
|
||||
return convertHttpStatusCodeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private Message<Void> convertHttpStatusCodeException(HttpStatusCodeException e) {
|
||||
try {
|
||||
JSONObject error = new JSONObject(e.getResponseBodyAsString());
|
||||
return Message.of(error.optInt("code", e.getRawStatusCode()), error.optString("message"));
|
||||
} catch (Exception je) {
|
||||
return Message.of(e.getRawStatusCode(), CharSequenceUtil.sub(e.getMessage(), 0, 100));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package com.xmzs.midjourney.service;
|
||||
|
||||
|
||||
import com.xmzs.midjourney.support.Task;
|
||||
|
||||
public interface NotifyService {
|
||||
|
||||
void notifyTaskChange(Task task);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package com.xmzs.midjourney.service;
|
||||
|
||||
import cn.hutool.cache.CacheUtil;
|
||||
import cn.hutool.cache.impl.TimedCache;
|
||||
import cn.hutool.core.exceptions.CheckedUtil;
|
||||
import cn.hutool.core.text.CharSequenceUtil;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.xmzs.midjourney.Constants;
|
||||
import com.xmzs.midjourney.ProxyProperties;
|
||||
import com.xmzs.midjourney.enums.TaskStatus;
|
||||
import com.xmzs.midjourney.support.Task;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class NotifyServiceImpl implements NotifyService {
|
||||
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
|
||||
private final ThreadPoolTaskExecutor executor;
|
||||
private final TimedCache<String, Object> taskLocks = CacheUtil.newTimedCache(Duration.ofHours(1).toMillis());
|
||||
|
||||
public NotifyServiceImpl(ProxyProperties properties) {
|
||||
this.executor = new ThreadPoolTaskExecutor();
|
||||
this.executor.setCorePoolSize(properties.getNotifyPoolSize());
|
||||
this.executor.setThreadNamePrefix("TaskNotify-");
|
||||
this.executor.initialize();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void notifyTaskChange(Task task) {
|
||||
String notifyHook = task.getPropertyGeneric(Constants.TASK_PROPERTY_NOTIFY_HOOK);
|
||||
if (CharSequenceUtil.isBlank(notifyHook)) {
|
||||
return;
|
||||
}
|
||||
String taskId = task.getId();
|
||||
TaskStatus taskStatus = task.getStatus();
|
||||
Object taskLock = this.taskLocks.get(taskId, (CheckedUtil.Func0Rt<Object>) Object::new);
|
||||
try {
|
||||
String paramsStr = OBJECT_MAPPER.writeValueAsString(task);
|
||||
this.executor.execute(() -> {
|
||||
synchronized (taskLock) {
|
||||
try {
|
||||
ResponseEntity<String> responseEntity = postJson(notifyHook, paramsStr);
|
||||
if (responseEntity.getStatusCode() == HttpStatus.OK) {
|
||||
log.debug("推送任务变更成功, 任务ID: {}, status: {}, notifyHook: {}", taskId, taskStatus, notifyHook);
|
||||
} else {
|
||||
log.warn("推送任务变更失败, 任务ID: {}, notifyHook: {}, code: {}, msg: {}", taskId, notifyHook, responseEntity.getStatusCodeValue(), responseEntity.getBody());
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.warn("推送任务变更失败, 任务ID: {}, notifyHook: {}, 描述: {}", taskId, notifyHook, e.getMessage());
|
||||
}
|
||||
}
|
||||
});
|
||||
} catch (JsonProcessingException e) {
|
||||
log.warn("推送任务变更失败, 任务ID: {}, notifyHook: {}, 描述: {}", taskId, notifyHook, e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
private ResponseEntity<String> postJson(String notifyHook, String paramsJson) {
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
HttpEntity<String> httpEntity = new HttpEntity<>(paramsJson, headers);
|
||||
return new RestTemplate().postForEntity(notifyHook, httpEntity, String.class);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
package com.xmzs.midjourney.service;
|
||||
|
||||
import com.xmzs.midjourney.enums.BlendDimensions;
|
||||
import com.xmzs.midjourney.result.SubmitResultVO;
|
||||
import com.xmzs.midjourney.support.Task;
|
||||
import eu.maxschuster.dataurl.DataUrl;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface TaskService {
|
||||
|
||||
SubmitResultVO submitImagine(Task task, List<DataUrl> dataUrls);
|
||||
|
||||
SubmitResultVO submitUpscale(Task task, String targetMessageId, String targetMessageHash, int index, int messageFlags);
|
||||
|
||||
SubmitResultVO submitVariation(Task task, String targetMessageId, String targetMessageHash, int index, int messageFlags);
|
||||
|
||||
SubmitResultVO submitReroll(Task task, String targetMessageId, String targetMessageHash, int messageFlags);
|
||||
|
||||
SubmitResultVO submitDescribe(Task task, DataUrl dataUrl);
|
||||
|
||||
SubmitResultVO submitBlend(Task task, List<DataUrl> dataUrls, BlendDimensions dimensions);
|
||||
}
|
||||
@@ -0,0 +1,128 @@
|
||||
package com.xmzs.midjourney.service;
|
||||
|
||||
import com.xmzs.midjourney.Constants;
|
||||
import com.xmzs.midjourney.ReturnCode;
|
||||
import com.xmzs.midjourney.enums.BlendDimensions;
|
||||
import com.xmzs.midjourney.loadbalancer.DiscordInstance;
|
||||
import com.xmzs.midjourney.loadbalancer.DiscordLoadBalancer;
|
||||
import com.xmzs.midjourney.result.Message;
|
||||
import com.xmzs.midjourney.result.SubmitResultVO;
|
||||
import com.xmzs.midjourney.support.Task;
|
||||
import com.xmzs.midjourney.util.MimeTypeUtils;
|
||||
import eu.maxschuster.dataurl.DataUrl;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class TaskServiceImpl implements TaskService {
|
||||
private final TaskStoreService taskStoreService;
|
||||
private final DiscordLoadBalancer discordLoadBalancer;
|
||||
|
||||
@Override
|
||||
public SubmitResultVO submitImagine(Task task, List<DataUrl> dataUrls) {
|
||||
DiscordInstance instance = this.discordLoadBalancer.chooseInstance();
|
||||
if (instance == null) {
|
||||
return SubmitResultVO.fail(ReturnCode.NOT_FOUND, "无可用的账号实例");
|
||||
}
|
||||
task.setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, instance.getInstanceId());
|
||||
return instance.submitTask(task, () -> {
|
||||
List<String> imageUrls = new ArrayList<>();
|
||||
for (DataUrl dataUrl : dataUrls) {
|
||||
String taskFileName = task.getId() + "." + MimeTypeUtils.guessFileSuffix(dataUrl.getMimeType());
|
||||
Message<String> uploadResult = instance.upload(taskFileName, dataUrl);
|
||||
if (uploadResult.getCode() != ReturnCode.SUCCESS) {
|
||||
return Message.of(uploadResult.getCode(), uploadResult.getDescription());
|
||||
}
|
||||
String finalFileName = uploadResult.getResult();
|
||||
Message<String> sendImageResult = instance.sendImageMessage("upload image: " + finalFileName, finalFileName);
|
||||
if (sendImageResult.getCode() != ReturnCode.SUCCESS) {
|
||||
return Message.of(sendImageResult.getCode(), sendImageResult.getDescription());
|
||||
}
|
||||
imageUrls.add(sendImageResult.getResult());
|
||||
}
|
||||
if (!imageUrls.isEmpty()) {
|
||||
task.setPrompt(String.join(" ", imageUrls) + " " + task.getPrompt());
|
||||
task.setPromptEn(String.join(" ", imageUrls) + " " + task.getPromptEn());
|
||||
task.setDescription("/imagine " + task.getPrompt());
|
||||
this.taskStoreService.save(task);
|
||||
}
|
||||
return instance.imagine(task.getPromptEn(), task.getPropertyGeneric(Constants.TASK_PROPERTY_NONCE));
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public SubmitResultVO submitUpscale(Task task, String targetMessageId, String targetMessageHash, int index, int messageFlags) {
|
||||
String instanceId = task.getPropertyGeneric(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID);
|
||||
DiscordInstance discordInstance = this.discordLoadBalancer.getDiscordInstance(instanceId);
|
||||
if (discordInstance == null || !discordInstance.isAlive()) {
|
||||
return SubmitResultVO.fail(ReturnCode.NOT_FOUND, "账号不可用: " + instanceId);
|
||||
}
|
||||
return discordInstance.submitTask(task, () -> discordInstance.upscale(targetMessageId, index, targetMessageHash, messageFlags, task.getPropertyGeneric(Constants.TASK_PROPERTY_NONCE)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public SubmitResultVO submitVariation(Task task, String targetMessageId, String targetMessageHash, int index, int messageFlags) {
|
||||
String instanceId = task.getPropertyGeneric(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID);
|
||||
DiscordInstance discordInstance = this.discordLoadBalancer.getDiscordInstance(instanceId);
|
||||
if (discordInstance == null || !discordInstance.isAlive()) {
|
||||
return SubmitResultVO.fail(ReturnCode.NOT_FOUND, "账号不可用: " + instanceId);
|
||||
}
|
||||
return discordInstance.submitTask(task, () -> discordInstance.variation(targetMessageId, index, targetMessageHash, messageFlags, task.getPropertyGeneric(Constants.TASK_PROPERTY_NONCE)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public SubmitResultVO submitReroll(Task task, String targetMessageId, String targetMessageHash, int messageFlags) {
|
||||
String instanceId = task.getPropertyGeneric(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID);
|
||||
DiscordInstance discordInstance = this.discordLoadBalancer.getDiscordInstance(instanceId);
|
||||
if (discordInstance == null || !discordInstance.isAlive()) {
|
||||
return SubmitResultVO.fail(ReturnCode.NOT_FOUND, "账号不可用: " + instanceId);
|
||||
}
|
||||
return discordInstance.submitTask(task, () -> discordInstance.reroll(targetMessageId, targetMessageHash, messageFlags, task.getPropertyGeneric(Constants.TASK_PROPERTY_NONCE)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public SubmitResultVO submitDescribe(Task task, DataUrl dataUrl) {
|
||||
DiscordInstance discordInstance = this.discordLoadBalancer.chooseInstance();
|
||||
if (discordInstance == null) {
|
||||
return SubmitResultVO.fail(ReturnCode.NOT_FOUND, "无可用的账号实例");
|
||||
}
|
||||
task.setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, discordInstance.getInstanceId());
|
||||
return discordInstance.submitTask(task, () -> {
|
||||
String taskFileName = task.getId() + "." + MimeTypeUtils.guessFileSuffix(dataUrl.getMimeType());
|
||||
Message<String> uploadResult = discordInstance.upload(taskFileName, dataUrl);
|
||||
if (uploadResult.getCode() != ReturnCode.SUCCESS) {
|
||||
return Message.of(uploadResult.getCode(), uploadResult.getDescription());
|
||||
}
|
||||
String finalFileName = uploadResult.getResult();
|
||||
return discordInstance.describe(finalFileName, task.getPropertyGeneric(Constants.TASK_PROPERTY_NONCE));
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public SubmitResultVO submitBlend(Task task, List<DataUrl> dataUrls, BlendDimensions dimensions) {
|
||||
DiscordInstance discordInstance = this.discordLoadBalancer.chooseInstance();
|
||||
if (discordInstance == null) {
|
||||
return SubmitResultVO.fail(ReturnCode.NOT_FOUND, "无可用的账号实例");
|
||||
}
|
||||
task.setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, discordInstance.getInstanceId());
|
||||
return discordInstance.submitTask(task, () -> {
|
||||
List<String> finalFileNames = new ArrayList<>();
|
||||
for (DataUrl dataUrl : dataUrls) {
|
||||
String taskFileName = task.getId() + "." + MimeTypeUtils.guessFileSuffix(dataUrl.getMimeType());
|
||||
Message<String> uploadResult = discordInstance.upload(taskFileName, dataUrl);
|
||||
if (uploadResult.getCode() != ReturnCode.SUCCESS) {
|
||||
return Message.of(uploadResult.getCode(), uploadResult.getDescription());
|
||||
}
|
||||
finalFileNames.add(uploadResult.getResult());
|
||||
}
|
||||
return discordInstance.blend(finalFileNames, dimensions, task.getPropertyGeneric(Constants.TASK_PROPERTY_NONCE));
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
package com.xmzs.midjourney.service;
|
||||
|
||||
|
||||
import com.xmzs.midjourney.support.Task;
|
||||
import com.xmzs.midjourney.support.TaskCondition;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface TaskStoreService {
|
||||
|
||||
void save(Task task);
|
||||
|
||||
void delete(String id);
|
||||
|
||||
Task get(String id);
|
||||
|
||||
List<Task> list();
|
||||
|
||||
List<Task> list(TaskCondition condition);
|
||||
|
||||
Task findOne(TaskCondition condition);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package com.xmzs.midjourney.service;
|
||||
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
public interface TranslateService {
|
||||
|
||||
String translateToEnglish(String prompt);
|
||||
|
||||
default boolean containsChinese(String prompt) {
|
||||
return Pattern.compile("[\u4e00-\u9fa5]").matcher(prompt).find();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package com.xmzs.midjourney.service.store;
|
||||
|
||||
import cn.hutool.cache.CacheUtil;
|
||||
import cn.hutool.cache.impl.TimedCache;
|
||||
import cn.hutool.core.collection.ListUtil;
|
||||
import cn.hutool.core.stream.StreamUtil;
|
||||
import com.xmzs.midjourney.service.TaskStoreService;
|
||||
import com.xmzs.midjourney.support.Task;
|
||||
import com.xmzs.midjourney.support.TaskCondition;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
public class InMemoryTaskStoreServiceImpl implements TaskStoreService {
|
||||
private final TimedCache<String, Task> taskMap;
|
||||
|
||||
public InMemoryTaskStoreServiceImpl(Duration timeout) {
|
||||
this.taskMap = CacheUtil.newTimedCache(timeout.toMillis());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void save(Task task) {
|
||||
this.taskMap.put(task.getId(), task);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void delete(String key) {
|
||||
this.taskMap.remove(key);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Task get(String key) {
|
||||
return this.taskMap.get(key);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Task> list() {
|
||||
return ListUtil.toList(this.taskMap.iterator());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Task> list(TaskCondition condition) {
|
||||
return StreamUtil.of(this.taskMap.iterator()).filter(condition).toList();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Task findOne(TaskCondition condition) {
|
||||
return StreamUtil.of(this.taskMap.iterator()).filter(condition).findFirst().orElse(null);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package com.xmzs.midjourney.service.store;
|
||||
|
||||
import com.xmzs.midjourney.service.TaskStoreService;
|
||||
import com.xmzs.midjourney.support.Task;
|
||||
import com.xmzs.midjourney.support.TaskCondition;
|
||||
import org.springframework.data.redis.core.Cursor;
|
||||
import org.springframework.data.redis.core.RedisCallback;
|
||||
import org.springframework.data.redis.core.RedisTemplate;
|
||||
import org.springframework.data.redis.core.ScanOptions;
|
||||
import org.springframework.data.redis.core.ValueOperations;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class RedisTaskStoreServiceImpl implements TaskStoreService {
|
||||
private static final String KEY_PREFIX = "mj-task-store::";
|
||||
|
||||
private final Duration timeout;
|
||||
private final RedisTemplate<String, Task> redisTemplate;
|
||||
|
||||
public RedisTaskStoreServiceImpl(Duration timeout, RedisTemplate<String, Task> redisTemplate) {
|
||||
this.timeout = timeout;
|
||||
this.redisTemplate = redisTemplate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void save(Task task) {
|
||||
this.redisTemplate.opsForValue().set(getRedisKey(task.getId()), task, this.timeout);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void delete(String id) {
|
||||
this.redisTemplate.delete(getRedisKey(id));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Task get(String id) {
|
||||
return this.redisTemplate.opsForValue().get(getRedisKey(id));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Task> list() {
|
||||
Set<String> keys = this.redisTemplate.execute((RedisCallback<Set<String>>) connection -> {
|
||||
Cursor<byte[]> cursor = connection.scan(ScanOptions.scanOptions().match(KEY_PREFIX + "*").count(1000).build());
|
||||
return cursor.stream().map(String::new).collect(Collectors.toSet());
|
||||
});
|
||||
if (keys == null || keys.isEmpty()) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
ValueOperations<String, Task> operations = this.redisTemplate.opsForValue();
|
||||
return keys.stream().map(operations::get)
|
||||
.filter(Objects::nonNull)
|
||||
.toList();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Task> list(TaskCondition condition) {
|
||||
return list().stream().filter(condition).toList();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Task findOne(TaskCondition condition) {
|
||||
return list().stream().filter(condition).findFirst().orElse(null);
|
||||
}
|
||||
|
||||
private String getRedisKey(String id) {
|
||||
return KEY_PREFIX + id;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
package com.xmzs.midjourney.service.translate;
|
||||
|
||||
|
||||
import cn.hutool.core.exceptions.ValidateException;
|
||||
import cn.hutool.core.text.CharSequenceUtil;
|
||||
import cn.hutool.core.util.RandomUtil;
|
||||
import cn.hutool.crypto.digest.MD5;
|
||||
import com.xmzs.midjourney.ProxyProperties;
|
||||
import com.xmzs.midjourney.service.TranslateService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.json.JSONArray;
|
||||
import org.json.JSONObject;
|
||||
import org.springframework.beans.factory.support.BeanDefinitionValidationException;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.util.LinkedMultiValueMap;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
public class BaiduTranslateServiceImpl implements TranslateService {
|
||||
private static final String TRANSLATE_API = "https://fanyi-api.baidu.com/api/trans/vip/translate";
|
||||
|
||||
private final String appid;
|
||||
private final String appSecret;
|
||||
|
||||
public BaiduTranslateServiceImpl(ProxyProperties.BaiduTranslateConfig translateConfig) {
|
||||
this.appid = translateConfig.getAppid();
|
||||
this.appSecret = translateConfig.getAppSecret();
|
||||
if (!CharSequenceUtil.isAllNotBlank(this.appid, this.appSecret)) {
|
||||
throw new BeanDefinitionValidationException("mj.baidu-translate.appid或mj.baidu-translate.app-secret未配置");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String translateToEnglish(String prompt) {
|
||||
if (!containsChinese(prompt)) {
|
||||
return prompt;
|
||||
}
|
||||
String salt = RandomUtil.randomNumbers(5);
|
||||
String sign = MD5.create().digestHex(this.appid + prompt + salt + this.appSecret);
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
|
||||
MultiValueMap<String, String> body = new LinkedMultiValueMap<>();
|
||||
body.add("from", "zh");
|
||||
body.add("to", "en");
|
||||
body.add("appid", this.appid);
|
||||
body.add("salt", salt);
|
||||
body.add("q", prompt);
|
||||
body.add("sign", sign);
|
||||
HttpEntity<MultiValueMap<String, String>> requestEntity = new HttpEntity<>(body, headers);
|
||||
try {
|
||||
ResponseEntity<String> responseEntity = new RestTemplate().exchange(TRANSLATE_API, HttpMethod.POST, requestEntity, String.class);
|
||||
if (responseEntity.getStatusCode() != HttpStatus.OK || CharSequenceUtil.isBlank(responseEntity.getBody())) {
|
||||
throw new ValidateException(responseEntity.getStatusCodeValue() + " - " + responseEntity.getBody());
|
||||
}
|
||||
JSONObject result = new JSONObject(responseEntity.getBody());
|
||||
if (result.has("error_code")) {
|
||||
throw new ValidateException(result.getString("error_code") + " - " + result.getString("error_msg"));
|
||||
}
|
||||
List<String> strings = new ArrayList<>();
|
||||
JSONArray transResult = result.getJSONArray("trans_result");
|
||||
for (int i = 0; i < transResult.length(); i++) {
|
||||
strings.add(transResult.getJSONObject(i).getString("dst"));
|
||||
}
|
||||
return CharSequenceUtil.join("\n", strings);
|
||||
} catch (Exception e) {
|
||||
log.warn("调用百度翻译失败: {}", e.getMessage());
|
||||
}
|
||||
return prompt;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package com.xmzs.midjourney.service.translate;
|
||||
|
||||
|
||||
import cn.hutool.core.text.CharSequenceUtil;
|
||||
import com.xmzs.midjourney.ProxyProperties;
|
||||
import com.xmzs.midjourney.service.TranslateService;
|
||||
import com.unfbx.chatgpt.OpenAiClient;
|
||||
import com.unfbx.chatgpt.entity.chat.ChatChoice;
|
||||
import com.unfbx.chatgpt.entity.chat.ChatCompletion;
|
||||
import com.unfbx.chatgpt.entity.chat.ChatCompletionResponse;
|
||||
import com.unfbx.chatgpt.entity.chat.Message;
|
||||
import com.unfbx.chatgpt.function.KeyRandomStrategy;
|
||||
import com.unfbx.chatgpt.interceptor.OpenAILogger;
|
||||
import com.unfbx.chatgpt.interceptor.OpenAiResponseInterceptor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.OkHttpClient;
|
||||
import okhttp3.logging.HttpLoggingInterceptor;
|
||||
import org.springframework.beans.factory.support.BeanDefinitionValidationException;
|
||||
|
||||
import java.net.InetSocketAddress;
|
||||
import java.net.Proxy;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
@Slf4j
|
||||
public class GPTTranslateServiceImpl implements TranslateService {
|
||||
private final OpenAiClient openAiClient;
|
||||
private final ProxyProperties.OpenaiConfig openaiConfig;
|
||||
|
||||
public GPTTranslateServiceImpl(ProxyProperties properties) {
|
||||
this.openaiConfig = properties.getOpenai();
|
||||
if (CharSequenceUtil.isBlank(this.openaiConfig.getGptApiKey())) {
|
||||
throw new BeanDefinitionValidationException("mj.openai.gpt-api-key未配置");
|
||||
}
|
||||
HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor(new OpenAILogger());
|
||||
httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS);
|
||||
OkHttpClient.Builder okHttpBuilder = new OkHttpClient.Builder()
|
||||
.addInterceptor(httpLoggingInterceptor)
|
||||
.addInterceptor(new OpenAiResponseInterceptor())
|
||||
.connectTimeout(10, TimeUnit.SECONDS)
|
||||
.writeTimeout(30, TimeUnit.SECONDS)
|
||||
.readTimeout(30, TimeUnit.SECONDS);
|
||||
if (CharSequenceUtil.isNotBlank(properties.getProxy().getHost())) {
|
||||
Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(properties.getProxy().getHost(), properties.getProxy().getPort()));
|
||||
okHttpBuilder.proxy(proxy);
|
||||
}
|
||||
OpenAiClient.Builder apiBuilder = OpenAiClient.builder()
|
||||
.apiKey(Collections.singletonList(this.openaiConfig.getGptApiKey()))
|
||||
.keyStrategy(new KeyRandomStrategy())
|
||||
.okHttpClient(okHttpBuilder.build());
|
||||
if (CharSequenceUtil.isNotBlank(this.openaiConfig.getGptApiUrl())) {
|
||||
apiBuilder.apiHost(this.openaiConfig.getGptApiUrl());
|
||||
}
|
||||
this.openAiClient = apiBuilder.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String translateToEnglish(String prompt) {
|
||||
if (!containsChinese(prompt)) {
|
||||
return prompt;
|
||||
}
|
||||
Message m1 = Message.builder().role(Message.Role.SYSTEM).content("把中文翻译成英文").build();
|
||||
Message m2 = Message.builder().role(Message.Role.USER).content(prompt).build();
|
||||
ChatCompletion chatCompletion = ChatCompletion.builder()
|
||||
.messages(Arrays.asList(m1, m2))
|
||||
.model(this.openaiConfig.getModel())
|
||||
.temperature(this.openaiConfig.getTemperature())
|
||||
.maxTokens(this.openaiConfig.getMaxTokens())
|
||||
.build();
|
||||
ChatCompletionResponse chatCompletionResponse = this.openAiClient.chatCompletion(chatCompletion);
|
||||
try {
|
||||
List<ChatChoice> choices = chatCompletionResponse.getChoices();
|
||||
if (!choices.isEmpty()) {
|
||||
return choices.get(0).getMessage().getContent();
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.warn("调用chat-gpt接口翻译中文失败: {}", e.getMessage());
|
||||
}
|
||||
return prompt;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package com.xmzs.midjourney.service.translate;
|
||||
|
||||
|
||||
import com.xmzs.midjourney.service.TranslateService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public class NoTranslateServiceImpl implements TranslateService {
|
||||
|
||||
@Override
|
||||
public String translateToEnglish(String prompt) {
|
||||
return prompt;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package com.xmzs.midjourney.support;
|
||||
|
||||
|
||||
import cn.hutool.core.text.CharSequenceUtil;
|
||||
import com.xmzs.midjourney.Constants;
|
||||
import com.xmzs.midjourney.ProxyProperties;
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
import jakarta.servlet.http.HttpServletResponse;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.servlet.HandlerInterceptor;
|
||||
|
||||
|
||||
@Component
|
||||
@RequiredArgsConstructor
|
||||
public class ApiAuthorizeInterceptor implements HandlerInterceptor {
|
||||
private final ProxyProperties properties;
|
||||
|
||||
@Override
|
||||
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
|
||||
if (CharSequenceUtil.isBlank(this.properties.getApiSecret())) {
|
||||
return true;
|
||||
}
|
||||
String apiSecret = request.getHeader(Constants.API_SECRET_HEADER_NAME);
|
||||
boolean authorized = CharSequenceUtil.equals(apiSecret, this.properties.getApiSecret());
|
||||
if (!authorized) {
|
||||
response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
|
||||
}
|
||||
return authorized;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package com.xmzs.midjourney.support;
|
||||
|
||||
|
||||
import cn.hutool.core.text.CharSequenceUtil;
|
||||
import com.xmzs.midjourney.Constants;
|
||||
import com.xmzs.midjourney.ProxyProperties;
|
||||
import com.xmzs.midjourney.domain.DiscordAccount;
|
||||
import com.xmzs.midjourney.loadbalancer.DiscordInstance;
|
||||
import com.xmzs.midjourney.loadbalancer.DiscordInstanceImpl;
|
||||
import com.xmzs.midjourney.service.NotifyService;
|
||||
import com.xmzs.midjourney.service.TaskStoreService;
|
||||
import com.xmzs.midjourney.wss.handle.MessageHandler;
|
||||
import com.xmzs.midjourney.wss.user.UserMessageListener;
|
||||
import com.xmzs.midjourney.wss.user.UserWebSocketStarter;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@RequiredArgsConstructor
|
||||
public class DiscordAccountHelper {
|
||||
private final DiscordHelper discordHelper;
|
||||
private final ProxyProperties properties;
|
||||
private final RestTemplate restTemplate;
|
||||
private final TaskStoreService taskStoreService;
|
||||
private final NotifyService notifyService;
|
||||
private final List<MessageHandler> messageHandlers;
|
||||
private final Map<String, String> paramsMap;
|
||||
|
||||
public DiscordInstance createDiscordInstance(DiscordAccount account) {
|
||||
if (!CharSequenceUtil.isAllNotBlank(account.getGuildId(), account.getChannelId(), account.getUserToken())) {
|
||||
throw new IllegalArgumentException("guildId, channelId, userToken must not be blank");
|
||||
}
|
||||
if (CharSequenceUtil.isBlank(account.getUserAgent())) {
|
||||
account.setUserAgent(Constants.DEFAULT_DISCORD_USER_AGENT);
|
||||
}
|
||||
var messageListener = new UserMessageListener(account, this.messageHandlers);
|
||||
var webSocketStarter = new UserWebSocketStarter(this.discordHelper.getWss(), account, messageListener, this.properties.getProxy());
|
||||
return new DiscordInstanceImpl(account, webSocketStarter, this.restTemplate,
|
||||
this.taskStoreService, this.notifyService, this.paramsMap);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
package com.xmzs.midjourney.support;
|
||||
|
||||
|
||||
import cn.hutool.core.bean.BeanUtil;
|
||||
import cn.hutool.core.exceptions.ValidateException;
|
||||
import cn.hutool.core.text.CharSequenceUtil;
|
||||
import com.xmzs.midjourney.ProxyProperties;
|
||||
import com.xmzs.midjourney.ReturnCode;
|
||||
import com.xmzs.midjourney.domain.DiscordAccount;
|
||||
import com.xmzs.midjourney.loadbalancer.DiscordInstance;
|
||||
import com.xmzs.midjourney.loadbalancer.DiscordLoadBalancer;
|
||||
import com.xmzs.midjourney.util.AsyncLockUtils;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
import org.springframework.boot.ApplicationArguments;
|
||||
import org.springframework.boot.ApplicationRunner;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
@RequiredArgsConstructor
|
||||
public class DiscordAccountInitializer implements ApplicationRunner {
|
||||
private final DiscordLoadBalancer discordLoadBalancer;
|
||||
private final DiscordAccountHelper discordAccountHelper;
|
||||
private final ProxyProperties properties;
|
||||
|
||||
@Override
|
||||
public void run(ApplicationArguments args) throws Exception {
|
||||
ProxyProperties.ProxyConfig proxy = this.properties.getProxy();
|
||||
if (Strings.isNotBlank(proxy.getHost())) {
|
||||
System.setProperty("http.proxyHost", proxy.getHost());
|
||||
System.setProperty("http.proxyPort", String.valueOf(proxy.getPort()));
|
||||
System.setProperty("https.proxyHost", proxy.getHost());
|
||||
System.setProperty("https.proxyPort", String.valueOf(proxy.getPort()));
|
||||
}
|
||||
|
||||
List<ProxyProperties.DiscordAccountConfig> configAccounts = this.properties.getAccounts();
|
||||
if (CharSequenceUtil.isNotBlank(this.properties.getDiscord().getChannelId())) {
|
||||
configAccounts.add(this.properties.getDiscord());
|
||||
}
|
||||
List<DiscordInstance> instances = this.discordLoadBalancer.getAllInstances();
|
||||
for (ProxyProperties.DiscordAccountConfig configAccount : configAccounts) {
|
||||
DiscordAccount account = new DiscordAccount();
|
||||
BeanUtil.copyProperties(configAccount, account);
|
||||
account.setId(configAccount.getChannelId());
|
||||
try {
|
||||
DiscordInstance instance = this.discordAccountHelper.createDiscordInstance(account);
|
||||
if (!account.isEnable()) {
|
||||
continue;
|
||||
}
|
||||
instance.startWss();
|
||||
AsyncLockUtils.LockObject lock = AsyncLockUtils.waitForLock("wss:" + account.getChannelId(), Duration.ofSeconds(10));
|
||||
if (ReturnCode.SUCCESS != lock.getProperty("code", Integer.class, 0)) {
|
||||
throw new ValidateException(lock.getProperty("description", String.class));
|
||||
}
|
||||
instances.add(instance);
|
||||
} catch (Exception e) {
|
||||
log.error("Account({}) init fail, disabled: {}", account.getDisplay(), e.getMessage());
|
||||
account.setEnable(false);
|
||||
}
|
||||
}
|
||||
Set<String> enableInstanceIds = instances.stream().filter(DiscordInstance::isAlive).map(DiscordInstance::getInstanceId).collect(Collectors.toSet());
|
||||
log.info("当前可用账号数 [{}] - {}", enableInstanceIds.size(), String.join(", ", enableInstanceIds));
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
package com.xmzs.midjourney.support;
|
||||
|
||||
import cn.hutool.core.text.CharSequenceUtil;
|
||||
import com.xmzs.midjourney.ProxyProperties;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
@RequiredArgsConstructor
|
||||
public class DiscordHelper {
|
||||
private final ProxyProperties properties;
|
||||
/**
|
||||
* DISCORD_SERVER_URL.
|
||||
*/
|
||||
public static final String DISCORD_SERVER_URL = "https://discord.com";
|
||||
/**
|
||||
* DISCORD_CDN_URL.
|
||||
*/
|
||||
public static final String DISCORD_CDN_URL = "https://cdn.discordapp.com";
|
||||
/**
|
||||
* DISCORD_WSS_URL.
|
||||
*/
|
||||
public static final String DISCORD_WSS_URL = "wss://gateway.discord.gg";
|
||||
/**
|
||||
* DISCORD_UPLOAD_URL.
|
||||
*/
|
||||
public static final String DISCORD_UPLOAD_URL = "https://discord-attachments-uploads-prd.storage.googleapis.com";
|
||||
|
||||
public String getServer() {
|
||||
if (CharSequenceUtil.isBlank(this.properties.getNgDiscord().getServer())) {
|
||||
return DISCORD_SERVER_URL;
|
||||
}
|
||||
String serverUrl = this.properties.getNgDiscord().getServer();
|
||||
if (serverUrl.endsWith("/")) {
|
||||
serverUrl = serverUrl.substring(0, serverUrl.length() - 1);
|
||||
}
|
||||
return serverUrl;
|
||||
}
|
||||
|
||||
public String getCdn() {
|
||||
if (CharSequenceUtil.isBlank(this.properties.getNgDiscord().getCdn())) {
|
||||
return DISCORD_CDN_URL;
|
||||
}
|
||||
String cdnUrl = this.properties.getNgDiscord().getCdn();
|
||||
if (cdnUrl.endsWith("/")) {
|
||||
cdnUrl = cdnUrl.substring(0, cdnUrl.length() - 1);
|
||||
}
|
||||
return cdnUrl;
|
||||
}
|
||||
|
||||
public String getWss() {
|
||||
if (CharSequenceUtil.isBlank(this.properties.getNgDiscord().getWss())) {
|
||||
return DISCORD_WSS_URL;
|
||||
}
|
||||
String wssUrl = this.properties.getNgDiscord().getWss();
|
||||
if (wssUrl.endsWith("/")) {
|
||||
wssUrl = wssUrl.substring(0, wssUrl.length() - 1);
|
||||
}
|
||||
return wssUrl;
|
||||
}
|
||||
|
||||
public String getDiscordUploadUrl(String uploadUrl) {
|
||||
if (CharSequenceUtil.isBlank(this.properties.getNgDiscord().getUploadServer()) || CharSequenceUtil.isBlank(uploadUrl)) {
|
||||
return uploadUrl;
|
||||
}
|
||||
String uploadServer = this.properties.getNgDiscord().getUploadServer();
|
||||
if (uploadServer.endsWith("/")) {
|
||||
uploadServer = uploadServer.substring(0, uploadServer.length() - 1);
|
||||
}
|
||||
return uploadUrl.replaceFirst(DISCORD_UPLOAD_URL, uploadServer);
|
||||
}
|
||||
|
||||
public String findTaskIdWithCdnUrl(String url) {
|
||||
if (!CharSequenceUtil.startWith(url, DISCORD_CDN_URL)) {
|
||||
return null;
|
||||
}
|
||||
int hashStartIndex = url.lastIndexOf("/");
|
||||
String taskId = CharSequenceUtil.subBefore(url.substring(hashStartIndex + 1), ".", true);
|
||||
if (CharSequenceUtil.length(taskId) == 16) {
|
||||
return taskId;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public String getMessageHash(String imageUrl) {
|
||||
if (CharSequenceUtil.isBlank(imageUrl)) {
|
||||
return null;
|
||||
}
|
||||
if (CharSequenceUtil.endWith(imageUrl, "_grid_0.webp")) {
|
||||
int hashStartIndex = imageUrl.lastIndexOf("/");
|
||||
if (hashStartIndex < 0) {
|
||||
return null;
|
||||
}
|
||||
return CharSequenceUtil.sub(imageUrl, hashStartIndex + 1, imageUrl.length() - "_grid_0.webp".length());
|
||||
}
|
||||
int hashStartIndex = imageUrl.lastIndexOf("_");
|
||||
if (hashStartIndex < 0) {
|
||||
return null;
|
||||
}
|
||||
return CharSequenceUtil.subBefore(imageUrl.substring(hashStartIndex + 1), ".", true);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
package com.xmzs.midjourney.support;
|
||||
|
||||
import org.springframework.beans.BeansException;
|
||||
import org.springframework.context.ApplicationContext;
|
||||
import org.springframework.context.ApplicationContextAware;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
public class SpringContextHolder implements ApplicationContextAware {
|
||||
private static ApplicationContext APPLICATION_CONTEXT;
|
||||
|
||||
@Override
|
||||
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
|
||||
APPLICATION_CONTEXT = applicationContext;
|
||||
}
|
||||
|
||||
public static ApplicationContext getApplicationContext() {
|
||||
if (APPLICATION_CONTEXT == null) {
|
||||
throw new IllegalStateException("SpringContextHolder is not ready.");
|
||||
}
|
||||
return APPLICATION_CONTEXT;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package com.xmzs.midjourney.support;
|
||||
|
||||
import com.xmzs.midjourney.domain.DomainObject;
|
||||
import com.xmzs.midjourney.enums.TaskAction;
|
||||
import com.xmzs.midjourney.enums.TaskStatus;
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
|
||||
import java.io.Serial;
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@ApiModel("任务")
|
||||
public class Task extends DomainObject {
|
||||
@Serial
|
||||
private static final long serialVersionUID = -674915748204390789L;
|
||||
|
||||
@ApiModelProperty("任务类型")
|
||||
private TaskAction action;
|
||||
@ApiModelProperty("任务状态")
|
||||
private TaskStatus status = TaskStatus.NOT_START;
|
||||
|
||||
@ApiModelProperty("提示词")
|
||||
private String prompt;
|
||||
@ApiModelProperty("提示词-英文")
|
||||
private String promptEn;
|
||||
|
||||
@ApiModelProperty("任务描述")
|
||||
private String description;
|
||||
@ApiModelProperty("自定义参数")
|
||||
private String state;
|
||||
|
||||
@ApiModelProperty("提交时间")
|
||||
private Long submitTime;
|
||||
@ApiModelProperty("开始执行时间")
|
||||
private Long startTime;
|
||||
@ApiModelProperty("结束时间")
|
||||
private Long finishTime;
|
||||
|
||||
@ApiModelProperty("图片url")
|
||||
private String imageUrl;
|
||||
|
||||
@ApiModelProperty("任务进度")
|
||||
private String progress;
|
||||
@ApiModelProperty("失败原因")
|
||||
private String failReason;
|
||||
|
||||
public void start() {
|
||||
this.startTime = System.currentTimeMillis();
|
||||
this.status = TaskStatus.SUBMITTED;
|
||||
this.progress = "0%";
|
||||
}
|
||||
|
||||
public void success() {
|
||||
this.finishTime = System.currentTimeMillis();
|
||||
this.status = TaskStatus.SUCCESS;
|
||||
this.progress = "100%";
|
||||
}
|
||||
|
||||
public void fail(String reason) {
|
||||
this.finishTime = System.currentTimeMillis();
|
||||
this.status = TaskStatus.FAILURE;
|
||||
this.failReason = reason;
|
||||
this.progress = "";
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user