init v1.0.0

This commit is contained in:
ageer
2024-02-27 20:52:19 +08:00
parent 1f7f97e86a
commit a079ef44e5
602 changed files with 163057 additions and 95 deletions

View File

@@ -0,0 +1,20 @@
package com.xmzs.midjourney;
import lombok.experimental.UtilityClass;
@UtilityClass
public final class Constants {
// 任务扩展属性 start
public static final String TASK_PROPERTY_NOTIFY_HOOK = "notifyHook";
public static final String TASK_PROPERTY_FINAL_PROMPT = "finalPrompt";
public static final String TASK_PROPERTY_MESSAGE_ID = "messageId";
public static final String TASK_PROPERTY_MESSAGE_HASH = "messageHash";
public static final String TASK_PROPERTY_PROGRESS_MESSAGE_ID = "progressMessageId";
public static final String TASK_PROPERTY_FLAGS = "flags";
public static final String TASK_PROPERTY_NONCE = "nonce";
public static final String TASK_PROPERTY_DISCORD_INSTANCE_ID = "discordInstanceId";
// 任务扩展属性 end
public static final String API_SECRET_HEADER_NAME = "mj-api-secret";
public static final String DEFAULT_DISCORD_USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36";
}

View File

@@ -0,0 +1,19 @@
package com.xmzs.midjourney;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Import;
import org.springframework.scheduling.annotation.EnableScheduling;
import spring.config.BeanConfig;
import spring.config.WebMvcConfig;
@EnableScheduling
@SpringBootApplication
@Import({BeanConfig.class, WebMvcConfig.class})
public class ProxyApplication {
public static void main(String[] args) {
SpringApplication.run(ProxyApplication.class, args);
}
}

View File

@@ -0,0 +1,211 @@
package com.xmzs.midjourney;
import com.xmzs.midjourney.enums.TranslateWay;
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
@Data
@Component
@ConfigurationProperties(prefix = "mj")
public class ProxyProperties {
/**
* task存储配置.
*/
private final TaskStore taskStore = new TaskStore();
/**
* discord账号选择规则.
*/
private String accountChooseRule = "BestWaitIdleRule";
/**
* discord单账号配置.
*/
private final DiscordAccountConfig discord = new DiscordAccountConfig();
/**
* discord账号池配置.
*/
private final List<DiscordAccountConfig> accounts = new ArrayList<>();
/**
* 代理配置.
*/
private final ProxyConfig proxy = new ProxyConfig();
/**
* 反代配置.
*/
private final NgDiscordConfig ngDiscord = new NgDiscordConfig();
/**
* 百度翻译配置.
*/
private final BaiduTranslateConfig baiduTranslate = new BaiduTranslateConfig();
/**
* openai配置.
*/
private final OpenaiConfig openai = new OpenaiConfig();
/**
* 中文prompt翻译方式.
*/
private TranslateWay translateWay = TranslateWay.NULL;
/**
* 接口密钥,为空不启用鉴权;调用接口时需要加请求头 mj-api-secret.
*/
private String apiSecret;
/**
* 任务状态变更回调地址.
*/
private String notifyHook;
/**
* 通知回调线程池大小.
*/
private int notifyPoolSize = 10;
@Data
public static class DiscordAccountConfig {
/**
* 服务器ID.
*/
private String guildId;
/**
* 频道ID.
*/
private String channelId;
/**
* 用户Token.
*/
private String userToken;
/**
* 用户UserAgent.
*/
private String userAgent = Constants.DEFAULT_DISCORD_USER_AGENT;
/**
* 是否可用.
*/
private boolean enable = true;
/**
* 并发数.
*/
private int coreSize = 3;
/**
* 等待队列长度.
*/
private int queueSize = 10;
/**
* 任务超时时间(分钟).
*/
private int timeoutMinutes = 5;
}
@Data
public static class BaiduTranslateConfig {
/**
* 百度翻译的APP_ID.
*/
private String appid;
/**
* 百度翻译的密钥.
*/
private String appSecret;
}
@Data
public static class OpenaiConfig {
/**
* 自定义gpt的api-url.
*/
private String gptApiUrl;
/**
* gpt的api-key.
*/
private String gptApiKey;
/**
* 超时时间.
*/
private Duration timeout = Duration.ofSeconds(30);
/**
* 使用的模型.
*/
private String model = "gpt-3.5-turbo";
/**
* 返回结果的最大分词数.
*/
private int maxTokens = 2048;
/**
* 相似度,取值 0-2.
*/
private double temperature = 0;
}
@Data
public static class TaskStore {
/**
* 任务过期时间默认30天.
*/
private Duration timeout = Duration.ofDays(30);
/**
* 任务存储方式: redis(默认)、in_memory.
*/
private Type type = Type.IN_MEMORY;
public enum Type {
/**
* redis.
*/
REDIS,
/**
* in_memory.
*/
IN_MEMORY
}
}
@Data
public static class ProxyConfig {
/**
* 代理host.
*/
private String host;
/**
* 代理端口.
*/
private Integer port;
}
@Data
public static class NgDiscordConfig {
/**
* https://discord.com 反代.
*/
private String server;
/**
* https://cdn.discordapp.com 反代.
*/
private String cdn;
/**
* wss://gateway.discord.gg 反代.
*/
private String wss;
/**
* https://discord-attachments-uploads-prd.storage.googleapis.com 反代.
*/
private String uploadServer;
}
@Data
public static class TaskQueueConfig {
/**
* 并发数.
*/
private int coreSize = 3;
/**
* 等待队列长度.
*/
private int queueSize = 10;
/**
* 任务超时时间(分钟).
*/
private int timeoutMinutes = 5;
}
}

View File

@@ -0,0 +1,42 @@
package com.xmzs.midjourney;
import lombok.experimental.UtilityClass;
@UtilityClass
public final class ReturnCode {
/**
* 成功.
*/
public static final int SUCCESS = 1;
/**
* 数据未找到.
*/
public static final int NOT_FOUND = 3;
/**
* 校验错误.
*/
public static final int VALIDATION_ERROR = 4;
/**
* 系统异常.
*/
public static final int FAILURE = 9;
/**
* 已存在.
*/
public static final int EXISTED = 21;
/**
* 排队中.
*/
public static final int IN_QUEUE = 22;
/**
* 队列已满.
*/
public static final int QUEUE_REJECTED = 23;
/**
* prompt包含敏感词.
*/
public static final int BANNED_PROMPT = 24;
}

View File

@@ -0,0 +1,36 @@
package com.xmzs.midjourney.controller;
import com.xmzs.midjourney.domain.DiscordAccount;
import com.xmzs.midjourney.loadbalancer.DiscordInstance;
import com.xmzs.midjourney.loadbalancer.DiscordLoadBalancer;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import io.swagger.annotations.ApiParam;
import lombok.RequiredArgsConstructor;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.List;
@Api(tags = "账号查询")
@RestController
@RequestMapping("/mj/account")
@RequiredArgsConstructor
public class AccountController {
private final DiscordLoadBalancer loadBalancer;
@ApiOperation(value = "指定ID获取账号")
@GetMapping("/{id}/fetch")
public DiscordAccount fetch(@ApiParam(value = "账号ID") @PathVariable String id) {
DiscordInstance instance = this.loadBalancer.getDiscordInstance(id);
return instance == null ? null : instance.account();
}
@ApiOperation(value = "查询所有账号")
@GetMapping("/list")
public List<DiscordAccount> list() {
return this.loadBalancer.getAllInstances().stream().map(DiscordInstance::account).toList();
}
}

View File

@@ -0,0 +1,240 @@
package com.xmzs.midjourney.controller;
import cn.hutool.core.text.CharSequenceUtil;
import cn.hutool.core.util.RandomUtil;
import com.xmzs.midjourney.Constants;
import com.xmzs.midjourney.ProxyProperties;
import com.xmzs.midjourney.ReturnCode;
import com.xmzs.midjourney.dto.BaseSubmitDTO;
import com.xmzs.midjourney.dto.SubmitBlendDTO;
import com.xmzs.midjourney.dto.SubmitChangeDTO;
import com.xmzs.midjourney.dto.SubmitDescribeDTO;
import com.xmzs.midjourney.dto.SubmitImagineDTO;
import com.xmzs.midjourney.dto.SubmitSimpleChangeDTO;
import com.xmzs.midjourney.enums.TaskAction;
import com.xmzs.midjourney.enums.TaskStatus;
import com.xmzs.midjourney.enums.TranslateWay;
import com.xmzs.midjourney.exception.BannedPromptException;
import com.xmzs.midjourney.result.SubmitResultVO;
import com.xmzs.midjourney.service.TaskService;
import com.xmzs.midjourney.service.TaskStoreService;
import com.xmzs.midjourney.service.TranslateService;
import com.xmzs.midjourney.support.Task;
import com.xmzs.midjourney.support.TaskCondition;
import com.xmzs.midjourney.util.BannedPromptUtils;
import com.xmzs.midjourney.util.ConvertUtils;
import com.xmzs.midjourney.util.MimeTypeUtils;
import com.xmzs.midjourney.util.SnowFlake;
import com.xmzs.midjourney.util.TaskChangeParams;
import eu.maxschuster.dataurl.DataUrl;
import eu.maxschuster.dataurl.DataUrlSerializer;
import eu.maxschuster.dataurl.IDataUrlSerializer;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import lombok.RequiredArgsConstructor;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.net.MalformedURLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@Api(tags = "任务提交")
@RestController
@RequestMapping("/mj/submit")
@RequiredArgsConstructor
public class SubmitController {
private final TranslateService translateService;
private final TaskStoreService taskStoreService;
private final ProxyProperties properties;
private final TaskService taskService;
@ApiOperation(value = "提交Imagine任务")
@PostMapping("/imagine")
public SubmitResultVO imagine(@RequestBody SubmitImagineDTO imagineDTO) {
String prompt = imagineDTO.getPrompt();
if (CharSequenceUtil.isBlank(prompt)) {
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "prompt不能为空");
}
prompt = prompt.trim();
Task task = newTask(imagineDTO);
task.setAction(TaskAction.IMAGINE);
task.setPrompt(prompt);
String promptEn = translatePrompt(prompt);
try {
BannedPromptUtils.checkBanned(promptEn);
} catch (BannedPromptException e) {
return SubmitResultVO.fail(ReturnCode.BANNED_PROMPT, "可能包含敏感词")
.setProperty("promptEn", promptEn).setProperty("bannedWord", e.getMessage());
}
List<String> base64Array = Optional.ofNullable(imagineDTO.getBase64Array()).orElse(new ArrayList<>());
if (CharSequenceUtil.isNotBlank(imagineDTO.getBase64())) {
base64Array.add(imagineDTO.getBase64());
}
List<DataUrl> dataUrls;
try {
dataUrls = ConvertUtils.convertBase64Array(base64Array);
} catch (MalformedURLException e) {
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64格式错误");
}
task.setPromptEn(promptEn);
task.setDescription("/imagine " + prompt);
return this.taskService.submitImagine(task, dataUrls);
}
@ApiOperation(value = "绘图变化-simple")
@PostMapping("/simple-change")
public SubmitResultVO simpleChange(@RequestBody SubmitSimpleChangeDTO simpleChangeDTO) {
TaskChangeParams changeParams = ConvertUtils.convertChangeParams(simpleChangeDTO.getContent());
if (changeParams == null) {
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "content参数错误");
}
SubmitChangeDTO changeDTO = new SubmitChangeDTO();
changeDTO.setAction(changeParams.getAction());
changeDTO.setTaskId(changeParams.getId());
changeDTO.setIndex(changeParams.getIndex());
changeDTO.setState(simpleChangeDTO.getState());
changeDTO.setNotifyHook(simpleChangeDTO.getNotifyHook());
return change(changeDTO);
}
@ApiOperation(value = "绘图变化")
@PostMapping("/change")
public SubmitResultVO change(@RequestBody SubmitChangeDTO changeDTO) {
if (CharSequenceUtil.isBlank(changeDTO.getTaskId())) {
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "taskId不能为空");
}
if (!Set.of(TaskAction.UPSCALE, TaskAction.VARIATION, TaskAction.REROLL).contains(changeDTO.getAction())) {
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "action参数错误");
}
String description = "/up " + changeDTO.getTaskId();
if (TaskAction.REROLL.equals(changeDTO.getAction())) {
description += " R";
} else {
description += " " + changeDTO.getAction().name().charAt(0) + changeDTO.getIndex();
}
if (TaskAction.UPSCALE.equals(changeDTO.getAction())) {
TaskCondition condition = new TaskCondition().setDescription(description);
Task existTask = this.taskStoreService.findOne(condition);
if (existTask != null) {
return SubmitResultVO.of(ReturnCode.EXISTED, "任务已存在", existTask.getId())
.setProperty("status", existTask.getStatus())
.setProperty("imageUrl", existTask.getImageUrl());
}
}
Task targetTask = this.taskStoreService.get(changeDTO.getTaskId());
if (targetTask == null) {
return SubmitResultVO.fail(ReturnCode.NOT_FOUND, "关联任务不存在或已失效");
}
if (!TaskStatus.SUCCESS.equals(targetTask.getStatus())) {
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "关联任务状态错误");
}
if (!Set.of(TaskAction.IMAGINE, TaskAction.VARIATION, TaskAction.REROLL, TaskAction.BLEND).contains(targetTask.getAction())) {
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "关联任务不允许执行变化");
}
Task task = newTask(changeDTO);
task.setAction(changeDTO.getAction());
task.setPrompt(targetTask.getPrompt());
task.setPromptEn(targetTask.getPromptEn());
task.setProperty(Constants.TASK_PROPERTY_FINAL_PROMPT, targetTask.getProperty(Constants.TASK_PROPERTY_FINAL_PROMPT));
task.setProperty(Constants.TASK_PROPERTY_PROGRESS_MESSAGE_ID, targetTask.getProperty(Constants.TASK_PROPERTY_MESSAGE_ID));
task.setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, targetTask.getProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID));
task.setDescription(description);
int messageFlags = targetTask.getPropertyGeneric(Constants.TASK_PROPERTY_FLAGS);
String messageId = targetTask.getPropertyGeneric(Constants.TASK_PROPERTY_MESSAGE_ID);
String messageHash = targetTask.getPropertyGeneric(Constants.TASK_PROPERTY_MESSAGE_HASH);
if (TaskAction.UPSCALE.equals(changeDTO.getAction())) {
return this.taskService.submitUpscale(task, messageId, messageHash, changeDTO.getIndex(), messageFlags);
} else if (TaskAction.VARIATION.equals(changeDTO.getAction())) {
return this.taskService.submitVariation(task, messageId, messageHash, changeDTO.getIndex(), messageFlags);
} else {
return this.taskService.submitReroll(task, messageId, messageHash, messageFlags);
}
}
@ApiOperation(value = "提交Describe任务")
@PostMapping("/describe")
public SubmitResultVO describe(@RequestBody SubmitDescribeDTO describeDTO) {
if (CharSequenceUtil.isBlank(describeDTO.getBase64())) {
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64不能为空");
}
IDataUrlSerializer serializer = new DataUrlSerializer();
DataUrl dataUrl;
try {
dataUrl = serializer.unserialize(describeDTO.getBase64());
} catch (MalformedURLException e) {
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64格式错误");
}
Task task = newTask(describeDTO);
task.setAction(TaskAction.DESCRIBE);
String taskFileName = task.getId() + "." + MimeTypeUtils.guessFileSuffix(dataUrl.getMimeType());
task.setDescription("/describe " + taskFileName);
return this.taskService.submitDescribe(task, dataUrl);
}
@ApiOperation(value = "提交Blend任务")
@PostMapping("/blend")
public SubmitResultVO blend(@RequestBody SubmitBlendDTO blendDTO) {
List<String> base64Array = blendDTO.getBase64Array();
if (base64Array == null || base64Array.size() < 2 || base64Array.size() > 5) {
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64List参数错误");
}
if (blendDTO.getDimensions() == null) {
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "dimensions参数错误");
}
IDataUrlSerializer serializer = new DataUrlSerializer();
List<DataUrl> dataUrlList = new ArrayList<>();
try {
for (String base64 : base64Array) {
DataUrl dataUrl = serializer.unserialize(base64);
dataUrlList.add(dataUrl);
}
} catch (MalformedURLException e) {
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64格式错误");
}
Task task = newTask(blendDTO);
task.setAction(TaskAction.BLEND);
task.setDescription("/blend " + task.getId() + " " + dataUrlList.size());
return this.taskService.submitBlend(task, dataUrlList, blendDTO.getDimensions());
}
private Task newTask(BaseSubmitDTO base) {
Task task = new Task();
task.setId(System.currentTimeMillis() + RandomUtil.randomNumbers(3));
task.setSubmitTime(System.currentTimeMillis());
task.setState(base.getState());
String notifyHook = CharSequenceUtil.isBlank(base.getNotifyHook()) ? this.properties.getNotifyHook() : base.getNotifyHook();
task.setProperty(Constants.TASK_PROPERTY_NOTIFY_HOOK, notifyHook);
task.setProperty(Constants.TASK_PROPERTY_NONCE, SnowFlake.INSTANCE.nextId());
return task;
}
private String translatePrompt(String prompt) {
if (TranslateWay.NULL.equals(this.properties.getTranslateWay()) || CharSequenceUtil.isBlank(prompt)) {
return prompt;
}
List<String> imageUrls = new ArrayList<>();
Matcher imageMatcher = Pattern.compile("https?://[a-z0-9-_:@&?=+,.!/~*'%$]+\\x20+", Pattern.CASE_INSENSITIVE).matcher(prompt);
while (imageMatcher.find()) {
imageUrls.add(imageMatcher.group(0));
}
String paramStr = "";
Matcher paramMatcher = Pattern.compile("\\x20+-{1,2}[a-z]+.*$", Pattern.CASE_INSENSITIVE).matcher(prompt);
if (paramMatcher.find()) {
paramStr = paramMatcher.group(0);
}
String imageStr = CharSequenceUtil.join("", imageUrls);
String text = prompt.substring(imageStr.length(), prompt.length() - paramStr.length());
if (CharSequenceUtil.isNotBlank(text)) {
text = this.translateService.translateToEnglish(text).trim();
}
return imageStr + text + paramStr;
}
}

View File

@@ -0,0 +1,64 @@
package com.xmzs.midjourney.controller;
import cn.hutool.core.comparator.CompareUtil;
import com.xmzs.midjourney.dto.TaskConditionDTO;
import com.xmzs.midjourney.loadbalancer.DiscordLoadBalancer;
import com.xmzs.midjourney.service.TaskStoreService;
import com.xmzs.midjourney.support.Task;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import io.swagger.annotations.ApiParam;
import lombok.RequiredArgsConstructor;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
@Api(tags = "任务查询")
@RestController
@RequestMapping("/mj/task")
@RequiredArgsConstructor
public class TaskController {
private final TaskStoreService taskStoreService;
private final DiscordLoadBalancer discordLoadBalancer;
@ApiOperation(value = "指定ID获取任务")
@GetMapping("/{id}/fetch")
public Task fetch(@ApiParam(value = "任务ID") @PathVariable String id) {
return this.taskStoreService.get(id);
}
@ApiOperation(value = "查询任务队列")
@GetMapping("/queue")
public List<Task> queue() {
return this.discordLoadBalancer.getQueueTaskIds().stream()
.map(this.taskStoreService::get).filter(Objects::nonNull)
.sorted(Comparator.comparing(Task::getSubmitTime))
.toList();
}
@ApiOperation(value = "查询所有任务")
@GetMapping("/list")
public List<Task> list() {
return this.taskStoreService.list().stream()
.sorted((t1, t2) -> CompareUtil.compare(t2.getSubmitTime(), t1.getSubmitTime()))
.toList();
}
@ApiOperation(value = "根据ID列表查询任务")
@PostMapping("/list-by-condition")
public List<Task> listByIds(@RequestBody TaskConditionDTO conditionDTO) {
if (conditionDTO.getIds() == null) {
return Collections.emptyList();
}
return conditionDTO.getIds().stream().map(this.taskStoreService::get).filter(Objects::nonNull).toList();
}
}

View File

@@ -0,0 +1,38 @@
package com.xmzs.midjourney.domain;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.xmzs.midjourney.Constants;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import lombok.EqualsAndHashCode;
@Data
@EqualsAndHashCode(callSuper = true)
@ApiModel("Discord账号")
public class DiscordAccount extends DomainObject {
@ApiModelProperty("服务器ID")
private String guildId;
@ApiModelProperty("频道ID")
private String channelId;
@ApiModelProperty("用户Token")
private String userToken;
@ApiModelProperty("用户UserAgent")
private String userAgent = Constants.DEFAULT_DISCORD_USER_AGENT;
@ApiModelProperty("是否可用")
private boolean enable = true;
@ApiModelProperty("并发数")
private int coreSize = 3;
@ApiModelProperty("等待队列长度")
private int queueSize = 10;
@ApiModelProperty("任务超时时间(分钟)")
private int timeoutMinutes = 5;
@JsonIgnore
public String getDisplay() {
return this.channelId;
}
}

View File

@@ -0,0 +1,72 @@
package com.xmzs.midjourney.domain;
import com.fasterxml.jackson.annotation.JsonIgnore;
import io.swagger.annotations.ApiModelProperty;
import lombok.Getter;
import lombok.Setter;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
public class DomainObject implements Serializable {
@Getter
@Setter
@ApiModelProperty("ID")
protected String id;
@Setter
protected Map<String, Object> properties; // 扩展属性,仅支持基本类型
@JsonIgnore
private final transient Object lock = new Object();
public void sleep() throws InterruptedException {
synchronized (this.lock) {
this.lock.wait();
}
}
public void awake() {
synchronized (this.lock) {
this.lock.notifyAll();
}
}
public DomainObject setProperty(String name, Object value) {
getProperties().put(name, value);
return this;
}
public DomainObject removeProperty(String name) {
getProperties().remove(name);
return this;
}
public Object getProperty(String name) {
return getProperties().get(name);
}
@SuppressWarnings("unchecked")
public <T> T getPropertyGeneric(String name) {
return (T) getProperty(name);
}
public <T> T getProperty(String name, Class<T> clz) {
return getProperty(name, clz, null);
}
public <T> T getProperty(String name, Class<T> clz, T defaultValue) {
Object value = getProperty(name);
return value == null ? defaultValue : clz.cast(value);
}
public Map<String, Object> getProperties() {
if (this.properties == null) {
this.properties = new HashMap<>();
}
return this.properties;
}
}

View File

@@ -0,0 +1,16 @@
package com.xmzs.midjourney.dto;
import io.swagger.annotations.ApiModelProperty;
import lombok.Getter;
import lombok.Setter;
@Getter
@Setter
public abstract class BaseSubmitDTO {
@ApiModelProperty("自定义参数")
protected String state;
@ApiModelProperty("回调地址, 为空时使用全局notifyHook")
protected String notifyHook;
}

View File

@@ -0,0 +1,21 @@
package com.xmzs.midjourney.dto;
import com.xmzs.midjourney.enums.BlendDimensions;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import lombok.EqualsAndHashCode;
import java.util.List;
@Data
@ApiModel("Blend提交参数")
@EqualsAndHashCode(callSuper = true)
public class SubmitBlendDTO extends BaseSubmitDTO {
@ApiModelProperty(value = "图片base64数组", required = true, example = "[\"data:image/png;base64,xxx1\", \"data:image/png;base64,xxx2\"]")
private List<String> base64Array;
@ApiModelProperty(value = "比例: PORTRAIT(2:3); SQUARE(1:1); LANDSCAPE(3:2)", example = "SQUARE")
private BlendDimensions dimensions = BlendDimensions.SQUARE;
}

View File

@@ -0,0 +1,25 @@
package com.xmzs.midjourney.dto;
import com.xmzs.midjourney.enums.TaskAction;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import lombok.EqualsAndHashCode;
@Data
@ApiModel("变化任务提交参数")
@EqualsAndHashCode(callSuper = true)
public class SubmitChangeDTO extends BaseSubmitDTO {
@ApiModelProperty(value = "任务ID", required = true, example = "\"1320098173412546\"")
private String taskId;
@ApiModelProperty(value = "UPSCALE(放大); VARIATION(变换); REROLL(重新生成)", required = true,
allowableValues = "UPSCALE, VARIATION, REROLL", example = "UPSCALE")
private TaskAction action;
@ApiModelProperty(value = "序号(1~4), action为UPSCALE,VARIATION时必传", allowableValues = "range[1, 4]", example = "1")
private Integer index;
}

View File

@@ -0,0 +1,15 @@
package com.xmzs.midjourney.dto;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import lombok.EqualsAndHashCode;
@Data
@ApiModel("Describe提交参数")
@EqualsAndHashCode(callSuper = true)
public class SubmitDescribeDTO extends BaseSubmitDTO {
@ApiModelProperty(value = "图片base64", required = true, example = "data:image/png;base64,xxx")
private String base64;
}

View File

@@ -0,0 +1,26 @@
package com.xmzs.midjourney.dto;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import lombok.EqualsAndHashCode;
import java.util.List;
@Data
@ApiModel("Imagine提交参数")
@EqualsAndHashCode(callSuper = true)
public class SubmitImagineDTO extends BaseSubmitDTO {
@ApiModelProperty(value = "提示词", required = true, example = "Cat")
private String prompt;
@ApiModelProperty(value = "垫图base64数组")
private List<String> base64Array;
@ApiModelProperty(hidden = true)
@Deprecated(since = "3.0", forRemoval = true)
private String base64;
}

View File

@@ -0,0 +1,17 @@
package com.xmzs.midjourney.dto;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import lombok.EqualsAndHashCode;
@Data
@ApiModel("变化任务提交参数-simple")
@EqualsAndHashCode(callSuper = true)
public class SubmitSimpleChangeDTO extends BaseSubmitDTO {
@ApiModelProperty(value = "变化描述: ID $action$index", required = true, example = "1320098173412546 U2")
private String content;
}

View File

@@ -0,0 +1,14 @@
package com.xmzs.midjourney.dto;
import io.swagger.annotations.ApiModel;
import lombok.Data;
import java.util.List;
@Data
@ApiModel("任务查询参数")
public class TaskConditionDTO {
private List<String> ids;
}

View File

@@ -0,0 +1,21 @@
package com.xmzs.midjourney.enums;
public enum BlendDimensions {
PORTRAIT("2:3"),
SQUARE("1:1"),
LANDSCAPE("3:2");
private final String value;
BlendDimensions(String value) {
this.value = value;
}
public String getValue() {
return this.value;
}
}

View File

@@ -0,0 +1,26 @@
package com.xmzs.midjourney.enums;
public enum MessageType {
/**
* 创建.
*/
CREATE,
/**
* 修改.
*/
UPDATE,
/**
* 删除.
*/
DELETE;
public static MessageType of(String type) {
return switch (type) {
case "MESSAGE_CREATE" -> CREATE;
case "MESSAGE_UPDATE" -> UPDATE;
case "MESSAGE_DELETE" -> DELETE;
default -> null;
};
}
}

View File

@@ -0,0 +1,30 @@
package com.xmzs.midjourney.enums;
public enum TaskAction {
/**
* 生成图片.
*/
IMAGINE,
/**
* 选中放大.
*/
UPSCALE,
/**
* 选中其中的一张图,生成四张相似的.
*/
VARIATION,
/**
* 重新执行.
*/
REROLL,
/**
* 图转prompt.
*/
DESCRIBE,
/**
* 多图混合.
*/
BLEND
}

View File

@@ -0,0 +1,26 @@
package com.xmzs.midjourney.enums;
public enum TaskStatus {
/**
* 未启动.
*/
NOT_START,
/**
* 已提交.
*/
SUBMITTED,
/**
* 执行中.
*/
IN_PROGRESS,
/**
* 失败.
*/
FAILURE,
/**
* 成功.
*/
SUCCESS
}

View File

@@ -0,0 +1,18 @@
package com.xmzs.midjourney.enums;
public enum TranslateWay {
/**
* 百度翻译.
*/
BAIDU,
/**
* GPT翻译.
*/
GPT,
/**
* 不翻译.
*/
NULL
}

View File

@@ -0,0 +1,8 @@
package com.xmzs.midjourney.exception;
public class BannedPromptException extends Exception {
public BannedPromptException(String message) {
super(message);
}
}

View File

@@ -0,0 +1,16 @@
package com.xmzs.midjourney.exception;
public class SnowFlakeException extends RuntimeException {
public SnowFlakeException(String message) {
super(message);
}
public SnowFlakeException(String message, Throwable cause) {
super(message, cause);
}
public SnowFlakeException(Throwable cause) {
super(cause);
}
}

View File

@@ -0,0 +1,33 @@
package com.xmzs.midjourney.loadbalancer;
import com.xmzs.midjourney.domain.DiscordAccount;
import com.xmzs.midjourney.result.Message;
import com.xmzs.midjourney.result.SubmitResultVO;
import com.xmzs.midjourney.service.DiscordService;
import com.xmzs.midjourney.support.Task;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
public interface DiscordInstance extends DiscordService {
String getInstanceId();
DiscordAccount account();
boolean isAlive();
void startWss() throws Exception;
List<Task> getRunningTasks();
void exitTask(Task task);
Map<String, Future<?>> getRunningFutures();
SubmitResultVO submitTask(Task task, Callable<Message<Void>> discordSubmit);
}

View File

@@ -0,0 +1,205 @@
package com.xmzs.midjourney.loadbalancer;
import com.xmzs.midjourney.Constants;
import com.xmzs.midjourney.ReturnCode;
import com.xmzs.midjourney.domain.DiscordAccount;
import com.xmzs.midjourney.enums.BlendDimensions;
import com.xmzs.midjourney.enums.TaskStatus;
import com.xmzs.midjourney.result.Message;
import com.xmzs.midjourney.result.SubmitResultVO;
import com.xmzs.midjourney.service.DiscordService;
import com.xmzs.midjourney.service.DiscordServiceImpl;
import com.xmzs.midjourney.service.NotifyService;
import com.xmzs.midjourney.service.TaskStoreService;
import com.xmzs.midjourney.support.Task;
import com.xmzs.midjourney.wss.WebSocketStarter;
import com.xmzs.midjourney.wss.user.UserWebSocketStarter;
import eu.maxschuster.dataurl.DataUrl;
import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.web.client.RestTemplate;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.Future;
import java.util.concurrent.RejectedExecutionException;
@Slf4j
public class DiscordInstanceImpl implements DiscordInstance {
private final DiscordAccount account;
private final WebSocketStarter socketStarter;
private final DiscordService service;
private final TaskStoreService taskStoreService;
private final NotifyService notifyService;
private final ThreadPoolTaskExecutor taskExecutor;
private final List<Task> runningTasks;
private final Map<String, Future<?>> taskFutureMap = Collections.synchronizedMap(new HashMap<>());
public DiscordInstanceImpl(DiscordAccount account, UserWebSocketStarter socketStarter, RestTemplate restTemplate,
TaskStoreService taskStoreService, NotifyService notifyService, Map<String, String> paramsMap) {
this.account = account;
this.socketStarter = socketStarter;
this.taskStoreService = taskStoreService;
this.notifyService = notifyService;
this.service = new DiscordServiceImpl(account, restTemplate, paramsMap);
this.runningTasks = new CopyOnWriteArrayList<>();
this.taskExecutor = new ThreadPoolTaskExecutor();
this.taskExecutor.setCorePoolSize(account.getCoreSize());
this.taskExecutor.setMaxPoolSize(account.getCoreSize());
this.taskExecutor.setQueueCapacity(account.getQueueSize());
this.taskExecutor.setThreadNamePrefix("TaskQueue-" + account.getDisplay() + "-");
this.taskExecutor.initialize();
}
@Override
public String getInstanceId() {
return this.account.getChannelId();
}
@Override
public DiscordAccount account() {
return this.account;
}
@Override
public boolean isAlive() {
return this.account.isEnable();
}
@Override
public void startWss() throws Exception {
this.socketStarter.setTrying(true);
this.socketStarter.start();
}
@Override
public List<Task> getRunningTasks() {
return this.runningTasks;
}
@Override
public void exitTask(Task task) {
try {
Future<?> future = this.taskFutureMap.get(task.getId());
if (future != null) {
future.cancel(true);
}
saveAndNotify(task);
} finally {
this.runningTasks.remove(task);
this.taskFutureMap.remove(task.getId());
}
}
@Override
public Map<String, Future<?>> getRunningFutures() {
return this.taskFutureMap;
}
@Override
public synchronized SubmitResultVO submitTask(Task task, Callable<Message<Void>> discordSubmit) {
this.taskStoreService.save(task);
int currentWaitNumbers;
try {
currentWaitNumbers = this.taskExecutor.getThreadPoolExecutor().getQueue().size();
Future<?> future = this.taskExecutor.submit(() -> executeTask(task, discordSubmit));
this.taskFutureMap.put(task.getId(), future);
} catch (RejectedExecutionException e) {
this.taskStoreService.delete(task.getId());
return SubmitResultVO.fail(ReturnCode.QUEUE_REJECTED, "队列已满,请稍后尝试")
.setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, this.getInstanceId());
} catch (Exception e) {
log.error("submit task error", e);
return SubmitResultVO.fail(ReturnCode.FAILURE, "提交失败,系统异常")
.setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, this.getInstanceId());
}
if (currentWaitNumbers == 0) {
return SubmitResultVO.of(ReturnCode.SUCCESS, "提交成功", task.getId())
.setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, this.getInstanceId());
} else {
return SubmitResultVO.of(ReturnCode.IN_QUEUE, "排队中,前面还有" + currentWaitNumbers + "个任务", task.getId())
.setProperty("numberOfQueues", currentWaitNumbers)
.setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, this.getInstanceId());
}
}
private void executeTask(Task task, Callable<Message<Void>> discordSubmit) {
this.runningTasks.add(task);
try {
task.start();
Message<Void> result = discordSubmit.call();
if (result.getCode() != ReturnCode.SUCCESS) {
task.fail(result.getDescription());
saveAndNotify(task);
return;
}
saveAndNotify(task);
do {
task.sleep();
saveAndNotify(task);
} while (task.getStatus() == TaskStatus.IN_PROGRESS);
log.debug("task finished, id: {}, status: {}", task.getId(), task.getStatus());
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} catch (Exception e) {
log.error("task execute error", e);
task.fail("执行错误,系统异常");
saveAndNotify(task);
} finally {
this.runningTasks.remove(task);
this.taskFutureMap.remove(task.getId());
}
}
private void saveAndNotify(Task task) {
this.taskStoreService.save(task);
this.notifyService.notifyTaskChange(task);
}
@Override
public Message<Void> imagine(String prompt, String nonce) {
return this.service.imagine(prompt, nonce);
}
@Override
public Message<Void> upscale(String messageId, int index, String messageHash, int messageFlags, String nonce) {
return this.service.upscale(messageId, index, messageHash, messageFlags, nonce);
}
@Override
public Message<Void> variation(String messageId, int index, String messageHash, int messageFlags, String nonce) {
return this.service.variation(messageId, index, messageHash, messageFlags, nonce);
}
@Override
public Message<Void> reroll(String messageId, String messageHash, int messageFlags, String nonce) {
return this.service.reroll(messageId, messageHash, messageFlags, nonce);
}
@Override
public Message<Void> describe(String finalFileName, String nonce) {
return this.service.describe(finalFileName, nonce);
}
@Override
public Message<Void> blend(List<String> finalFileNames, BlendDimensions dimensions, String nonce) {
return this.service.blend(finalFileNames, dimensions, nonce);
}
@Override
public Message<String> upload(String fileName, DataUrl dataUrl) {
return this.service.upload(fileName, dataUrl);
}
@Override
public Message<String> sendImageMessage(String content, String finalFileName) {
return this.service.sendImageMessage(content, finalFileName);
}
}

View File

@@ -0,0 +1,83 @@
package com.xmzs.midjourney.loadbalancer;
import cn.hutool.core.text.CharSequenceUtil;
import com.xmzs.midjourney.loadbalancer.rule.IRule;
import com.xmzs.midjourney.support.Task;
import com.xmzs.midjourney.support.TaskCondition;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
@Component
@RequiredArgsConstructor
public class DiscordLoadBalancer {
private final IRule rule;
private final List<DiscordInstance> instances = Collections.synchronizedList(new ArrayList<>());
public List<DiscordInstance> getAllInstances() {
return this.instances;
}
public List<DiscordInstance> getAliveInstances() {
return this.instances.stream().filter(DiscordInstance::isAlive).toList();
}
public DiscordInstance chooseInstance() {
return this.rule.choose(getAliveInstances());
}
public DiscordInstance getDiscordInstance(String instanceId) {
if (CharSequenceUtil.isBlank(instanceId)) {
return null;
}
return this.instances.stream()
.filter(instance -> CharSequenceUtil.equals(instanceId, instance.getInstanceId()))
.findFirst().orElse(null);
}
public Set<String> getQueueTaskIds() {
Set<String> taskIds = Collections.synchronizedSet(new HashSet<>());
for (DiscordInstance instance : getAliveInstances()) {
taskIds.addAll(instance.getRunningFutures().keySet());
}
return taskIds;
}
public Stream<Task> findRunningTask(TaskCondition condition) {
return getAliveInstances().stream().flatMap(instance -> instance.getRunningTasks().stream().filter(condition));
}
public Task getRunningTask(String id) {
for (DiscordInstance instance : getAliveInstances()) {
Optional<Task> optional = instance.getRunningTasks().stream().filter(t -> id.equals(t.getId())).findFirst();
if (optional.isPresent()) {
return optional.get();
}
}
return null;
}
public Task getRunningTaskByNonce(String nonce) {
if (CharSequenceUtil.isBlank(nonce)) {
return null;
}
TaskCondition condition = new TaskCondition().setNonce(nonce);
for (DiscordInstance instance : getAliveInstances()) {
Optional<Task> optional = instance.getRunningTasks().stream().filter(condition).findFirst();
if (optional.isPresent()) {
return optional.get();
}
}
return null;
}
}

View File

@@ -0,0 +1,31 @@
package com.xmzs.midjourney.loadbalancer.rule;
import cn.hutool.core.util.RandomUtil;
import com.xmzs.midjourney.loadbalancer.DiscordInstance;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
* 最少等待空闲.
* 选择等待数最少的实例,如果都不需要等待,则随机选择
*/
public class BestWaitIdleRule implements IRule {
@Override
public DiscordInstance choose(List<DiscordInstance> instances) {
if (instances.isEmpty()) {
return null;
}
Map<Integer, List<DiscordInstance>> map = instances.stream()
.collect(Collectors.groupingBy(i -> {
int wait = i.getRunningFutures().size() - i.account().getCoreSize();
return wait >= 0 ? wait : -1;
}));
List<DiscordInstance> instanceList = map.entrySet().stream().min(Comparator.comparingInt(Map.Entry::getKey)).orElseThrow().getValue();
return RandomUtil.randomEle(instanceList);
}
}

View File

@@ -0,0 +1,10 @@
package com.xmzs.midjourney.loadbalancer.rule;
import com.xmzs.midjourney.loadbalancer.DiscordInstance;
import java.util.List;
public interface IRule {
DiscordInstance choose(List<DiscordInstance> instances);
}

View File

@@ -0,0 +1,32 @@
package com.xmzs.midjourney.loadbalancer.rule;
import com.xmzs.midjourney.loadbalancer.DiscordInstance;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
/**
* 轮询.
*/
public class RoundRobinRule implements IRule {
private final AtomicInteger position = new AtomicInteger(0);
@Override
public DiscordInstance choose(List<DiscordInstance> instances) {
if (instances.isEmpty()) {
return null;
}
int pos = incrementAndGet();
return instances.get(pos % instances.size());
}
private int incrementAndGet() {
int current;
int next;
do {
current = this.position.get();
next = current == Integer.MAX_VALUE ? 0 : current + 1;
} while (!this.position.compareAndSet(current, next));
return next;
}
}

View File

@@ -0,0 +1,57 @@
package com.xmzs.midjourney.result;
import com.xmzs.midjourney.ReturnCode;
import lombok.Getter;
@Getter
public class Message<T> {
private final int code;
private final String description;
private final T result;
public static <Y> Message<Y> success() {
return new Message<>(ReturnCode.SUCCESS, "成功");
}
public static <T> Message<T> success(T result) {
return new Message<>(ReturnCode.SUCCESS, "成功", result);
}
public static <T> Message<T> success(int code, String description, T result) {
return new Message<>(code, description, result);
}
public static <Y> Message<Y> notFound() {
return new Message<>(ReturnCode.NOT_FOUND, "数据未找到");
}
public static <Y> Message<Y> validationError() {
return new Message<>(ReturnCode.VALIDATION_ERROR, "校验错误");
}
public static <Y> Message<Y> failure() {
return new Message<>(ReturnCode.FAILURE, "系统异常");
}
public static <Y> Message<Y> failure(String description) {
return new Message<>(ReturnCode.FAILURE, description);
}
public static <Y> Message<Y> of(int code, String description) {
return new Message<>(code, description);
}
public static <T> Message<T> of(int code, String description, T result) {
return new Message<>(code, description, result);
}
private Message(int code, String description) {
this(code, description, null);
}
private Message(int code, String description, T result) {
this.code = code;
this.description = description;
this.result = result;
}
}

View File

@@ -0,0 +1,62 @@
package com.xmzs.midjourney.result;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import java.util.HashMap;
import java.util.Map;
@Data
@ApiModel("提交结果")
public class SubmitResultVO {
@ApiModelProperty(value = "状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)", required = true, example = "1")
private int code;
@ApiModelProperty(value = "描述", required = true, example = "提交成功")
private String description;
@ApiModelProperty(value = "任务ID", example = "1320098173412546")
private String result;
@ApiModelProperty(value = "扩展字段")
private Map<String, Object> properties = new HashMap<>();
public SubmitResultVO setProperty(String name, Object value) {
this.properties.put(name, value);
return this;
}
public SubmitResultVO removeProperty(String name) {
this.properties.remove(name);
return this;
}
public Object getProperty(String name) {
return this.properties.get(name);
}
@SuppressWarnings("unchecked")
public <T> T getPropertyGeneric(String name) {
return (T) getProperty(name);
}
public <T> T getProperty(String name, Class<T> clz) {
return clz.cast(getProperty(name));
}
public static SubmitResultVO of(int code, String description, String result) {
return new SubmitResultVO(code, description, result);
}
public static SubmitResultVO fail(int code, String description) {
return new SubmitResultVO(code, description, null);
}
private SubmitResultVO(int code, String description, String result) {
this.code = code;
this.description = description;
this.result = result;
}
}

View File

@@ -0,0 +1,28 @@
package com.xmzs.midjourney.service;
import com.xmzs.midjourney.enums.BlendDimensions;
import com.xmzs.midjourney.result.Message;
import eu.maxschuster.dataurl.DataUrl;
import java.util.List;
public interface DiscordService {
Message<Void> imagine(String prompt, String nonce);
Message<Void> upscale(String messageId, int index, String messageHash, int messageFlags, String nonce);
Message<Void> variation(String messageId, int index, String messageHash, int messageFlags, String nonce);
Message<Void> reroll(String messageId, String messageHash, int messageFlags, String nonce);
Message<Void> describe(String finalFileName, String nonce);
Message<Void> blend(List<String> finalFileNames, BlendDimensions dimensions, String nonce);
Message<String> upload(String fileName, DataUrl dataUrl);
Message<String> sendImageMessage(String content, String finalFileName);
}

View File

@@ -0,0 +1,219 @@
package com.xmzs.midjourney.service;
import cn.hutool.core.text.CharSequenceUtil;
import com.xmzs.midjourney.ReturnCode;
import com.xmzs.midjourney.domain.DiscordAccount;
import com.xmzs.midjourney.enums.BlendDimensions;
import com.xmzs.midjourney.result.Message;
import com.xmzs.midjourney.support.DiscordHelper;
import com.xmzs.midjourney.support.SpringContextHolder;
import eu.maxschuster.dataurl.DataUrl;
import lombok.extern.slf4j.Slf4j;
import org.json.JSONArray;
import org.json.JSONObject;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.client.HttpStatusCodeException;
import org.springframework.web.client.RestTemplate;
import java.util.List;
import java.util.Map;
@Slf4j
public class DiscordServiceImpl implements DiscordService {
private static final String DEFAULT_SESSION_ID = "f1a313a09ce079ce252459dc70231f30";
private final DiscordAccount account;
private final Map<String, String> paramsMap;
private final RestTemplate restTemplate;
private final DiscordHelper discordHelper;
private final String discordInteractionUrl;
private final String discordAttachmentUrl;
private final String discordMessageUrl;
public DiscordServiceImpl(DiscordAccount account, RestTemplate restTemplate, Map<String, String> paramsMap) {
this.account = account;
this.restTemplate = restTemplate;
this.discordHelper = SpringContextHolder.getApplicationContext().getBean(DiscordHelper.class);
this.paramsMap = paramsMap;
String discordServer = this.discordHelper.getServer();
this.discordInteractionUrl = discordServer + "/api/v9/interactions";
this.discordAttachmentUrl = discordServer + "/api/v9/channels/" + account.getChannelId() + "/attachments";
this.discordMessageUrl = discordServer + "/api/v9/channels/" + account.getChannelId() + "/messages";
}
@Override
public Message<Void> imagine(String prompt, String nonce) {
String paramsStr = replaceInteractionParams(this.paramsMap.get("imagine"), nonce);
JSONObject params = new JSONObject(paramsStr);
params.getJSONObject("data").getJSONArray("options").getJSONObject(0)
.put("value", prompt);
return postJsonAndCheckStatus(params.toString());
}
@Override
public Message<Void> upscale(String messageId, int index, String messageHash, int messageFlags, String nonce) {
String paramsStr = replaceInteractionParams(this.paramsMap.get("upscale"), nonce)
.replace("$message_id", messageId)
.replace("$index", String.valueOf(index))
.replace("$message_hash", messageHash);
paramsStr = new JSONObject(paramsStr).put("message_flags", messageFlags).toString();
return postJsonAndCheckStatus(paramsStr);
}
@Override
public Message<Void> variation(String messageId, int index, String messageHash, int messageFlags, String nonce) {
String paramsStr = replaceInteractionParams(this.paramsMap.get("variation"), nonce)
.replace("$message_id", messageId)
.replace("$index", String.valueOf(index))
.replace("$message_hash", messageHash);
paramsStr = new JSONObject(paramsStr).put("message_flags", messageFlags).toString();
return postJsonAndCheckStatus(paramsStr);
}
@Override
public Message<Void> reroll(String messageId, String messageHash, int messageFlags, String nonce) {
String paramsStr = replaceInteractionParams(this.paramsMap.get("reroll"), nonce)
.replace("$message_id", messageId)
.replace("$message_hash", messageHash);
paramsStr = new JSONObject(paramsStr).put("message_flags", messageFlags).toString();
return postJsonAndCheckStatus(paramsStr);
}
@Override
public Message<Void> describe(String finalFileName, String nonce) {
String fileName = CharSequenceUtil.subAfter(finalFileName, "/", true);
String paramsStr = replaceInteractionParams(this.paramsMap.get("describe"), nonce)
.replace("$file_name", fileName)
.replace("$final_file_name", finalFileName);
return postJsonAndCheckStatus(paramsStr);
}
@Override
public Message<Void> blend(List<String> finalFileNames, BlendDimensions dimensions, String nonce) {
String paramsStr = replaceInteractionParams(this.paramsMap.get("blend"), nonce);
JSONObject params = new JSONObject(paramsStr);
JSONArray options = params.getJSONObject("data").getJSONArray("options");
JSONArray attachments = params.getJSONObject("data").getJSONArray("attachments");
for (int i = 0; i < finalFileNames.size(); i++) {
String finalFileName = finalFileNames.get(i);
String fileName = CharSequenceUtil.subAfter(finalFileName, "/", true);
JSONObject attachment = new JSONObject().put("id", String.valueOf(i))
.put("filename", fileName)
.put("uploaded_filename", finalFileName);
attachments.put(attachment);
JSONObject option = new JSONObject().put("type", 11)
.put("name", "image" + (i + 1))
.put("value", i);
options.put(option);
}
options.put(new JSONObject().put("type", 3)
.put("name", "dimensions")
.put("value", "--ar " + dimensions.getValue()));
return postJsonAndCheckStatus(params.toString());
}
private String replaceInteractionParams(String paramsStr, String nonce) {
return paramsStr.replace("$guild_id", this.account.getGuildId())
.replace("$channel_id", this.account.getChannelId())
.replace("$session_id", DEFAULT_SESSION_ID)
.replace("$nonce", nonce);
}
@Override
public Message<String> upload(String fileName, DataUrl dataUrl) {
try {
JSONObject fileObj = new JSONObject();
fileObj.put("filename", fileName);
fileObj.put("file_size", dataUrl.getData().length);
fileObj.put("id", "0");
JSONObject params = new JSONObject()
.put("files", new JSONArray().put(fileObj));
ResponseEntity<String> responseEntity = postJson(this.discordAttachmentUrl, params.toString());
if (responseEntity.getStatusCode() != HttpStatus.OK) {
log.error("上传图片到discord失败, status: {}, msg: {}", responseEntity.getStatusCodeValue(), responseEntity.getBody());
return Message.of(ReturnCode.VALIDATION_ERROR, "上传图片到discord失败");
}
JSONArray array = new JSONObject(responseEntity.getBody()).getJSONArray("attachments");
if (array.length() == 0) {
return Message.of(ReturnCode.VALIDATION_ERROR, "上传图片到discord失败");
}
String uploadUrl = array.getJSONObject(0).getString("upload_url");
String uploadFilename = array.getJSONObject(0).getString("upload_filename");
putFile(uploadUrl, dataUrl);
return Message.success(uploadFilename);
} catch (Exception e) {
log.error("上传图片到discord失败", e);
return Message.of(ReturnCode.FAILURE, "上传图片到discord失败");
}
}
@Override
public Message<String> sendImageMessage(String content, String finalFileName) {
String fileName = CharSequenceUtil.subAfter(finalFileName, "/", true);
String paramsStr = this.paramsMap.get("message").replace("$content", content)
.replace("$channel_id", this.account.getChannelId())
.replace("$file_name", fileName)
.replace("$final_file_name", finalFileName);
ResponseEntity<String> responseEntity = postJson(this.discordMessageUrl, paramsStr);
if (responseEntity.getStatusCode() != HttpStatus.OK) {
log.error("发送图片消息到discord失败, status: {}, msg: {}", responseEntity.getStatusCodeValue(), responseEntity.getBody());
return Message.of(ReturnCode.VALIDATION_ERROR, "发送图片消息到discord失败");
}
JSONObject result = new JSONObject(responseEntity.getBody());
JSONArray attachments = result.optJSONArray("attachments");
if (!attachments.isEmpty()) {
return Message.success(attachments.getJSONObject(0).optString("url"));
}
return Message.failure("发送图片消息到discord失败: 图片不存在");
}
private void putFile(String uploadUrl, DataUrl dataUrl) {
uploadUrl = this.discordHelper.getDiscordUploadUrl(uploadUrl);
HttpHeaders headers = new HttpHeaders();
headers.add("User-Agent", this.account.getUserAgent());
headers.setContentType(MediaType.valueOf(dataUrl.getMimeType()));
headers.setContentLength(dataUrl.getData().length);
HttpEntity<byte[]> requestEntity = new HttpEntity<>(dataUrl.getData(), headers);
this.restTemplate.put(uploadUrl, requestEntity);
}
private ResponseEntity<String> postJson(String paramsStr) {
return postJson(this.discordInteractionUrl, paramsStr);
}
private ResponseEntity<String> postJson(String url, String paramsStr) {
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
headers.set("Authorization", this.account.getUserToken());
headers.set("User-Agent", this.account.getUserAgent());
HttpEntity<String> httpEntity = new HttpEntity<>(paramsStr, headers);
return this.restTemplate.postForEntity(url, httpEntity, String.class);
}
private Message<Void> postJsonAndCheckStatus(String paramsStr) {
try {
ResponseEntity<String> responseEntity = postJson(paramsStr);
if (responseEntity.getStatusCode() == HttpStatus.NO_CONTENT) {
return Message.success();
}
return Message.of(responseEntity.getStatusCodeValue(), CharSequenceUtil.sub(responseEntity.getBody(), 0, 100));
} catch (HttpStatusCodeException e) {
return convertHttpStatusCodeException(e);
}
}
private Message<Void> convertHttpStatusCodeException(HttpStatusCodeException e) {
try {
JSONObject error = new JSONObject(e.getResponseBodyAsString());
return Message.of(error.optInt("code", e.getRawStatusCode()), error.optString("message"));
} catch (Exception je) {
return Message.of(e.getRawStatusCode(), CharSequenceUtil.sub(e.getMessage(), 0, 100));
}
}
}

View File

@@ -0,0 +1,10 @@
package com.xmzs.midjourney.service;
import com.xmzs.midjourney.support.Task;
public interface NotifyService {
void notifyTaskChange(Task task);
}

View File

@@ -0,0 +1,76 @@
package com.xmzs.midjourney.service;
import cn.hutool.cache.CacheUtil;
import cn.hutool.cache.impl.TimedCache;
import cn.hutool.core.exceptions.CheckedUtil;
import cn.hutool.core.text.CharSequenceUtil;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.xmzs.midjourney.Constants;
import com.xmzs.midjourney.ProxyProperties;
import com.xmzs.midjourney.enums.TaskStatus;
import com.xmzs.midjourney.support.Task;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;
import java.time.Duration;
@Slf4j
@Service
public class NotifyServiceImpl implements NotifyService {
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
private final ThreadPoolTaskExecutor executor;
private final TimedCache<String, Object> taskLocks = CacheUtil.newTimedCache(Duration.ofHours(1).toMillis());
public NotifyServiceImpl(ProxyProperties properties) {
this.executor = new ThreadPoolTaskExecutor();
this.executor.setCorePoolSize(properties.getNotifyPoolSize());
this.executor.setThreadNamePrefix("TaskNotify-");
this.executor.initialize();
}
@Override
public void notifyTaskChange(Task task) {
String notifyHook = task.getPropertyGeneric(Constants.TASK_PROPERTY_NOTIFY_HOOK);
if (CharSequenceUtil.isBlank(notifyHook)) {
return;
}
String taskId = task.getId();
TaskStatus taskStatus = task.getStatus();
Object taskLock = this.taskLocks.get(taskId, (CheckedUtil.Func0Rt<Object>) Object::new);
try {
String paramsStr = OBJECT_MAPPER.writeValueAsString(task);
this.executor.execute(() -> {
synchronized (taskLock) {
try {
ResponseEntity<String> responseEntity = postJson(notifyHook, paramsStr);
if (responseEntity.getStatusCode() == HttpStatus.OK) {
log.debug("推送任务变更成功, 任务ID: {}, status: {}, notifyHook: {}", taskId, taskStatus, notifyHook);
} else {
log.warn("推送任务变更失败, 任务ID: {}, notifyHook: {}, code: {}, msg: {}", taskId, notifyHook, responseEntity.getStatusCodeValue(), responseEntity.getBody());
}
} catch (Exception e) {
log.warn("推送任务变更失败, 任务ID: {}, notifyHook: {}, 描述: {}", taskId, notifyHook, e.getMessage());
}
}
});
} catch (JsonProcessingException e) {
log.warn("推送任务变更失败, 任务ID: {}, notifyHook: {}, 描述: {}", taskId, notifyHook, e.getMessage());
}
}
private ResponseEntity<String> postJson(String notifyHook, String paramsJson) {
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
HttpEntity<String> httpEntity = new HttpEntity<>(paramsJson, headers);
return new RestTemplate().postForEntity(notifyHook, httpEntity, String.class);
}
}

View File

@@ -0,0 +1,23 @@
package com.xmzs.midjourney.service;
import com.xmzs.midjourney.enums.BlendDimensions;
import com.xmzs.midjourney.result.SubmitResultVO;
import com.xmzs.midjourney.support.Task;
import eu.maxschuster.dataurl.DataUrl;
import java.util.List;
public interface TaskService {
SubmitResultVO submitImagine(Task task, List<DataUrl> dataUrls);
SubmitResultVO submitUpscale(Task task, String targetMessageId, String targetMessageHash, int index, int messageFlags);
SubmitResultVO submitVariation(Task task, String targetMessageId, String targetMessageHash, int index, int messageFlags);
SubmitResultVO submitReroll(Task task, String targetMessageId, String targetMessageHash, int messageFlags);
SubmitResultVO submitDescribe(Task task, DataUrl dataUrl);
SubmitResultVO submitBlend(Task task, List<DataUrl> dataUrls, BlendDimensions dimensions);
}

View File

@@ -0,0 +1,128 @@
package com.xmzs.midjourney.service;
import com.xmzs.midjourney.Constants;
import com.xmzs.midjourney.ReturnCode;
import com.xmzs.midjourney.enums.BlendDimensions;
import com.xmzs.midjourney.loadbalancer.DiscordInstance;
import com.xmzs.midjourney.loadbalancer.DiscordLoadBalancer;
import com.xmzs.midjourney.result.Message;
import com.xmzs.midjourney.result.SubmitResultVO;
import com.xmzs.midjourney.support.Task;
import com.xmzs.midjourney.util.MimeTypeUtils;
import eu.maxschuster.dataurl.DataUrl;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List;
@Slf4j
@Service
@RequiredArgsConstructor
public class TaskServiceImpl implements TaskService {
private final TaskStoreService taskStoreService;
private final DiscordLoadBalancer discordLoadBalancer;
@Override
public SubmitResultVO submitImagine(Task task, List<DataUrl> dataUrls) {
DiscordInstance instance = this.discordLoadBalancer.chooseInstance();
if (instance == null) {
return SubmitResultVO.fail(ReturnCode.NOT_FOUND, "无可用的账号实例");
}
task.setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, instance.getInstanceId());
return instance.submitTask(task, () -> {
List<String> imageUrls = new ArrayList<>();
for (DataUrl dataUrl : dataUrls) {
String taskFileName = task.getId() + "." + MimeTypeUtils.guessFileSuffix(dataUrl.getMimeType());
Message<String> uploadResult = instance.upload(taskFileName, dataUrl);
if (uploadResult.getCode() != ReturnCode.SUCCESS) {
return Message.of(uploadResult.getCode(), uploadResult.getDescription());
}
String finalFileName = uploadResult.getResult();
Message<String> sendImageResult = instance.sendImageMessage("upload image: " + finalFileName, finalFileName);
if (sendImageResult.getCode() != ReturnCode.SUCCESS) {
return Message.of(sendImageResult.getCode(), sendImageResult.getDescription());
}
imageUrls.add(sendImageResult.getResult());
}
if (!imageUrls.isEmpty()) {
task.setPrompt(String.join(" ", imageUrls) + " " + task.getPrompt());
task.setPromptEn(String.join(" ", imageUrls) + " " + task.getPromptEn());
task.setDescription("/imagine " + task.getPrompt());
this.taskStoreService.save(task);
}
return instance.imagine(task.getPromptEn(), task.getPropertyGeneric(Constants.TASK_PROPERTY_NONCE));
});
}
@Override
public SubmitResultVO submitUpscale(Task task, String targetMessageId, String targetMessageHash, int index, int messageFlags) {
String instanceId = task.getPropertyGeneric(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID);
DiscordInstance discordInstance = this.discordLoadBalancer.getDiscordInstance(instanceId);
if (discordInstance == null || !discordInstance.isAlive()) {
return SubmitResultVO.fail(ReturnCode.NOT_FOUND, "账号不可用: " + instanceId);
}
return discordInstance.submitTask(task, () -> discordInstance.upscale(targetMessageId, index, targetMessageHash, messageFlags, task.getPropertyGeneric(Constants.TASK_PROPERTY_NONCE)));
}
@Override
public SubmitResultVO submitVariation(Task task, String targetMessageId, String targetMessageHash, int index, int messageFlags) {
String instanceId = task.getPropertyGeneric(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID);
DiscordInstance discordInstance = this.discordLoadBalancer.getDiscordInstance(instanceId);
if (discordInstance == null || !discordInstance.isAlive()) {
return SubmitResultVO.fail(ReturnCode.NOT_FOUND, "账号不可用: " + instanceId);
}
return discordInstance.submitTask(task, () -> discordInstance.variation(targetMessageId, index, targetMessageHash, messageFlags, task.getPropertyGeneric(Constants.TASK_PROPERTY_NONCE)));
}
@Override
public SubmitResultVO submitReroll(Task task, String targetMessageId, String targetMessageHash, int messageFlags) {
String instanceId = task.getPropertyGeneric(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID);
DiscordInstance discordInstance = this.discordLoadBalancer.getDiscordInstance(instanceId);
if (discordInstance == null || !discordInstance.isAlive()) {
return SubmitResultVO.fail(ReturnCode.NOT_FOUND, "账号不可用: " + instanceId);
}
return discordInstance.submitTask(task, () -> discordInstance.reroll(targetMessageId, targetMessageHash, messageFlags, task.getPropertyGeneric(Constants.TASK_PROPERTY_NONCE)));
}
@Override
public SubmitResultVO submitDescribe(Task task, DataUrl dataUrl) {
DiscordInstance discordInstance = this.discordLoadBalancer.chooseInstance();
if (discordInstance == null) {
return SubmitResultVO.fail(ReturnCode.NOT_FOUND, "无可用的账号实例");
}
task.setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, discordInstance.getInstanceId());
return discordInstance.submitTask(task, () -> {
String taskFileName = task.getId() + "." + MimeTypeUtils.guessFileSuffix(dataUrl.getMimeType());
Message<String> uploadResult = discordInstance.upload(taskFileName, dataUrl);
if (uploadResult.getCode() != ReturnCode.SUCCESS) {
return Message.of(uploadResult.getCode(), uploadResult.getDescription());
}
String finalFileName = uploadResult.getResult();
return discordInstance.describe(finalFileName, task.getPropertyGeneric(Constants.TASK_PROPERTY_NONCE));
});
}
@Override
public SubmitResultVO submitBlend(Task task, List<DataUrl> dataUrls, BlendDimensions dimensions) {
DiscordInstance discordInstance = this.discordLoadBalancer.chooseInstance();
if (discordInstance == null) {
return SubmitResultVO.fail(ReturnCode.NOT_FOUND, "无可用的账号实例");
}
task.setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, discordInstance.getInstanceId());
return discordInstance.submitTask(task, () -> {
List<String> finalFileNames = new ArrayList<>();
for (DataUrl dataUrl : dataUrls) {
String taskFileName = task.getId() + "." + MimeTypeUtils.guessFileSuffix(dataUrl.getMimeType());
Message<String> uploadResult = discordInstance.upload(taskFileName, dataUrl);
if (uploadResult.getCode() != ReturnCode.SUCCESS) {
return Message.of(uploadResult.getCode(), uploadResult.getDescription());
}
finalFileNames.add(uploadResult.getResult());
}
return discordInstance.blend(finalFileNames, dimensions, task.getPropertyGeneric(Constants.TASK_PROPERTY_NONCE));
});
}
}

View File

@@ -0,0 +1,23 @@
package com.xmzs.midjourney.service;
import com.xmzs.midjourney.support.Task;
import com.xmzs.midjourney.support.TaskCondition;
import java.util.List;
public interface TaskStoreService {
void save(Task task);
void delete(String id);
Task get(String id);
List<Task> list();
List<Task> list(TaskCondition condition);
Task findOne(TaskCondition condition);
}

View File

@@ -0,0 +1,13 @@
package com.xmzs.midjourney.service;
import java.util.regex.Pattern;
public interface TranslateService {
String translateToEnglish(String prompt);
default boolean containsChinese(String prompt) {
return Pattern.compile("[\u4e00-\u9fa5]").matcher(prompt).find();
}
}

View File

@@ -0,0 +1,52 @@
package com.xmzs.midjourney.service.store;
import cn.hutool.cache.CacheUtil;
import cn.hutool.cache.impl.TimedCache;
import cn.hutool.core.collection.ListUtil;
import cn.hutool.core.stream.StreamUtil;
import com.xmzs.midjourney.service.TaskStoreService;
import com.xmzs.midjourney.support.Task;
import com.xmzs.midjourney.support.TaskCondition;
import java.time.Duration;
import java.util.List;
public class InMemoryTaskStoreServiceImpl implements TaskStoreService {
private final TimedCache<String, Task> taskMap;
public InMemoryTaskStoreServiceImpl(Duration timeout) {
this.taskMap = CacheUtil.newTimedCache(timeout.toMillis());
}
@Override
public void save(Task task) {
this.taskMap.put(task.getId(), task);
}
@Override
public void delete(String key) {
this.taskMap.remove(key);
}
@Override
public Task get(String key) {
return this.taskMap.get(key);
}
@Override
public List<Task> list() {
return ListUtil.toList(this.taskMap.iterator());
}
@Override
public List<Task> list(TaskCondition condition) {
return StreamUtil.of(this.taskMap.iterator()).filter(condition).toList();
}
@Override
public Task findOne(TaskCondition condition) {
return StreamUtil.of(this.taskMap.iterator()).filter(condition).findFirst().orElse(null);
}
}

View File

@@ -0,0 +1,74 @@
package com.xmzs.midjourney.service.store;
import com.xmzs.midjourney.service.TaskStoreService;
import com.xmzs.midjourney.support.Task;
import com.xmzs.midjourney.support.TaskCondition;
import org.springframework.data.redis.core.Cursor;
import org.springframework.data.redis.core.RedisCallback;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.ScanOptions;
import org.springframework.data.redis.core.ValueOperations;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
public class RedisTaskStoreServiceImpl implements TaskStoreService {
private static final String KEY_PREFIX = "mj-task-store::";
private final Duration timeout;
private final RedisTemplate<String, Task> redisTemplate;
public RedisTaskStoreServiceImpl(Duration timeout, RedisTemplate<String, Task> redisTemplate) {
this.timeout = timeout;
this.redisTemplate = redisTemplate;
}
@Override
public void save(Task task) {
this.redisTemplate.opsForValue().set(getRedisKey(task.getId()), task, this.timeout);
}
@Override
public void delete(String id) {
this.redisTemplate.delete(getRedisKey(id));
}
@Override
public Task get(String id) {
return this.redisTemplate.opsForValue().get(getRedisKey(id));
}
@Override
public List<Task> list() {
Set<String> keys = this.redisTemplate.execute((RedisCallback<Set<String>>) connection -> {
Cursor<byte[]> cursor = connection.scan(ScanOptions.scanOptions().match(KEY_PREFIX + "*").count(1000).build());
return cursor.stream().map(String::new).collect(Collectors.toSet());
});
if (keys == null || keys.isEmpty()) {
return Collections.emptyList();
}
ValueOperations<String, Task> operations = this.redisTemplate.opsForValue();
return keys.stream().map(operations::get)
.filter(Objects::nonNull)
.toList();
}
@Override
public List<Task> list(TaskCondition condition) {
return list().stream().filter(condition).toList();
}
@Override
public Task findOne(TaskCondition condition) {
return list().stream().filter(condition).findFirst().orElse(null);
}
private String getRedisKey(String id) {
return KEY_PREFIX + id;
}
}

View File

@@ -0,0 +1,80 @@
package com.xmzs.midjourney.service.translate;
import cn.hutool.core.exceptions.ValidateException;
import cn.hutool.core.text.CharSequenceUtil;
import cn.hutool.core.util.RandomUtil;
import cn.hutool.crypto.digest.MD5;
import com.xmzs.midjourney.ProxyProperties;
import com.xmzs.midjourney.service.TranslateService;
import lombok.extern.slf4j.Slf4j;
import org.json.JSONArray;
import org.json.JSONObject;
import org.springframework.beans.factory.support.BeanDefinitionValidationException;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestTemplate;
import java.util.ArrayList;
import java.util.List;
@Slf4j
public class BaiduTranslateServiceImpl implements TranslateService {
private static final String TRANSLATE_API = "https://fanyi-api.baidu.com/api/trans/vip/translate";
private final String appid;
private final String appSecret;
public BaiduTranslateServiceImpl(ProxyProperties.BaiduTranslateConfig translateConfig) {
this.appid = translateConfig.getAppid();
this.appSecret = translateConfig.getAppSecret();
if (!CharSequenceUtil.isAllNotBlank(this.appid, this.appSecret)) {
throw new BeanDefinitionValidationException("mj.baidu-translate.appid或mj.baidu-translate.app-secret未配置");
}
}
@Override
public String translateToEnglish(String prompt) {
if (!containsChinese(prompt)) {
return prompt;
}
String salt = RandomUtil.randomNumbers(5);
String sign = MD5.create().digestHex(this.appid + prompt + salt + this.appSecret);
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
MultiValueMap<String, String> body = new LinkedMultiValueMap<>();
body.add("from", "zh");
body.add("to", "en");
body.add("appid", this.appid);
body.add("salt", salt);
body.add("q", prompt);
body.add("sign", sign);
HttpEntity<MultiValueMap<String, String>> requestEntity = new HttpEntity<>(body, headers);
try {
ResponseEntity<String> responseEntity = new RestTemplate().exchange(TRANSLATE_API, HttpMethod.POST, requestEntity, String.class);
if (responseEntity.getStatusCode() != HttpStatus.OK || CharSequenceUtil.isBlank(responseEntity.getBody())) {
throw new ValidateException(responseEntity.getStatusCodeValue() + " - " + responseEntity.getBody());
}
JSONObject result = new JSONObject(responseEntity.getBody());
if (result.has("error_code")) {
throw new ValidateException(result.getString("error_code") + " - " + result.getString("error_msg"));
}
List<String> strings = new ArrayList<>();
JSONArray transResult = result.getJSONArray("trans_result");
for (int i = 0; i < transResult.length(); i++) {
strings.add(transResult.getJSONObject(i).getString("dst"));
}
return CharSequenceUtil.join("\n", strings);
} catch (Exception e) {
log.warn("调用百度翻译失败: {}", e.getMessage());
}
return prompt;
}
}

View File

@@ -0,0 +1,83 @@
package com.xmzs.midjourney.service.translate;
import cn.hutool.core.text.CharSequenceUtil;
import com.xmzs.midjourney.ProxyProperties;
import com.xmzs.midjourney.service.TranslateService;
import com.unfbx.chatgpt.OpenAiClient;
import com.unfbx.chatgpt.entity.chat.ChatChoice;
import com.unfbx.chatgpt.entity.chat.ChatCompletion;
import com.unfbx.chatgpt.entity.chat.ChatCompletionResponse;
import com.unfbx.chatgpt.entity.chat.Message;
import com.unfbx.chatgpt.function.KeyRandomStrategy;
import com.unfbx.chatgpt.interceptor.OpenAILogger;
import com.unfbx.chatgpt.interceptor.OpenAiResponseInterceptor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient;
import okhttp3.logging.HttpLoggingInterceptor;
import org.springframework.beans.factory.support.BeanDefinitionValidationException;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;
@Slf4j
public class GPTTranslateServiceImpl implements TranslateService {
private final OpenAiClient openAiClient;
private final ProxyProperties.OpenaiConfig openaiConfig;
public GPTTranslateServiceImpl(ProxyProperties properties) {
this.openaiConfig = properties.getOpenai();
if (CharSequenceUtil.isBlank(this.openaiConfig.getGptApiKey())) {
throw new BeanDefinitionValidationException("mj.openai.gpt-api-key未配置");
}
HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor(new OpenAILogger());
httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS);
OkHttpClient.Builder okHttpBuilder = new OkHttpClient.Builder()
.addInterceptor(httpLoggingInterceptor)
.addInterceptor(new OpenAiResponseInterceptor())
.connectTimeout(10, TimeUnit.SECONDS)
.writeTimeout(30, TimeUnit.SECONDS)
.readTimeout(30, TimeUnit.SECONDS);
if (CharSequenceUtil.isNotBlank(properties.getProxy().getHost())) {
Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(properties.getProxy().getHost(), properties.getProxy().getPort()));
okHttpBuilder.proxy(proxy);
}
OpenAiClient.Builder apiBuilder = OpenAiClient.builder()
.apiKey(Collections.singletonList(this.openaiConfig.getGptApiKey()))
.keyStrategy(new KeyRandomStrategy())
.okHttpClient(okHttpBuilder.build());
if (CharSequenceUtil.isNotBlank(this.openaiConfig.getGptApiUrl())) {
apiBuilder.apiHost(this.openaiConfig.getGptApiUrl());
}
this.openAiClient = apiBuilder.build();
}
@Override
public String translateToEnglish(String prompt) {
if (!containsChinese(prompt)) {
return prompt;
}
Message m1 = Message.builder().role(Message.Role.SYSTEM).content("把中文翻译成英文").build();
Message m2 = Message.builder().role(Message.Role.USER).content(prompt).build();
ChatCompletion chatCompletion = ChatCompletion.builder()
.messages(Arrays.asList(m1, m2))
.model(this.openaiConfig.getModel())
.temperature(this.openaiConfig.getTemperature())
.maxTokens(this.openaiConfig.getMaxTokens())
.build();
ChatCompletionResponse chatCompletionResponse = this.openAiClient.chatCompletion(chatCompletion);
try {
List<ChatChoice> choices = chatCompletionResponse.getChoices();
if (!choices.isEmpty()) {
return choices.get(0).getMessage().getContent();
}
} catch (Exception e) {
log.warn("调用chat-gpt接口翻译中文失败: {}", e.getMessage());
}
return prompt;
}
}

View File

@@ -0,0 +1,14 @@
package com.xmzs.midjourney.service.translate;
import com.xmzs.midjourney.service.TranslateService;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class NoTranslateServiceImpl implements TranslateService {
@Override
public String translateToEnglish(String prompt) {
return prompt;
}
}

View File

@@ -0,0 +1,32 @@
package com.xmzs.midjourney.support;
import cn.hutool.core.text.CharSequenceUtil;
import com.xmzs.midjourney.Constants;
import com.xmzs.midjourney.ProxyProperties;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.HandlerInterceptor;
@Component
@RequiredArgsConstructor
public class ApiAuthorizeInterceptor implements HandlerInterceptor {
private final ProxyProperties properties;
@Override
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
if (CharSequenceUtil.isBlank(this.properties.getApiSecret())) {
return true;
}
String apiSecret = request.getHeader(Constants.API_SECRET_HEADER_NAME);
boolean authorized = CharSequenceUtil.equals(apiSecret, this.properties.getApiSecret());
if (!authorized) {
response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
}
return authorized;
}
}

View File

@@ -0,0 +1,43 @@
package com.xmzs.midjourney.support;
import cn.hutool.core.text.CharSequenceUtil;
import com.xmzs.midjourney.Constants;
import com.xmzs.midjourney.ProxyProperties;
import com.xmzs.midjourney.domain.DiscordAccount;
import com.xmzs.midjourney.loadbalancer.DiscordInstance;
import com.xmzs.midjourney.loadbalancer.DiscordInstanceImpl;
import com.xmzs.midjourney.service.NotifyService;
import com.xmzs.midjourney.service.TaskStoreService;
import com.xmzs.midjourney.wss.handle.MessageHandler;
import com.xmzs.midjourney.wss.user.UserMessageListener;
import com.xmzs.midjourney.wss.user.UserWebSocketStarter;
import lombok.RequiredArgsConstructor;
import org.springframework.web.client.RestTemplate;
import java.util.List;
import java.util.Map;
@RequiredArgsConstructor
public class DiscordAccountHelper {
private final DiscordHelper discordHelper;
private final ProxyProperties properties;
private final RestTemplate restTemplate;
private final TaskStoreService taskStoreService;
private final NotifyService notifyService;
private final List<MessageHandler> messageHandlers;
private final Map<String, String> paramsMap;
public DiscordInstance createDiscordInstance(DiscordAccount account) {
if (!CharSequenceUtil.isAllNotBlank(account.getGuildId(), account.getChannelId(), account.getUserToken())) {
throw new IllegalArgumentException("guildId, channelId, userToken must not be blank");
}
if (CharSequenceUtil.isBlank(account.getUserAgent())) {
account.setUserAgent(Constants.DEFAULT_DISCORD_USER_AGENT);
}
var messageListener = new UserMessageListener(account, this.messageHandlers);
var webSocketStarter = new UserWebSocketStarter(this.discordHelper.getWss(), account, messageListener, this.properties.getProxy());
return new DiscordInstanceImpl(account, webSocketStarter, this.restTemplate,
this.taskStoreService, this.notifyService, this.paramsMap);
}
}

View File

@@ -0,0 +1,72 @@
package com.xmzs.midjourney.support;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.exceptions.ValidateException;
import cn.hutool.core.text.CharSequenceUtil;
import com.xmzs.midjourney.ProxyProperties;
import com.xmzs.midjourney.ReturnCode;
import com.xmzs.midjourney.domain.DiscordAccount;
import com.xmzs.midjourney.loadbalancer.DiscordInstance;
import com.xmzs.midjourney.loadbalancer.DiscordLoadBalancer;
import com.xmzs.midjourney.util.AsyncLockUtils;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.logging.log4j.util.Strings;
import org.springframework.boot.ApplicationArguments;
import org.springframework.boot.ApplicationRunner;
import org.springframework.stereotype.Component;
import java.time.Duration;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
@Slf4j
@Component
@RequiredArgsConstructor
public class DiscordAccountInitializer implements ApplicationRunner {
private final DiscordLoadBalancer discordLoadBalancer;
private final DiscordAccountHelper discordAccountHelper;
private final ProxyProperties properties;
@Override
public void run(ApplicationArguments args) throws Exception {
ProxyProperties.ProxyConfig proxy = this.properties.getProxy();
if (Strings.isNotBlank(proxy.getHost())) {
System.setProperty("http.proxyHost", proxy.getHost());
System.setProperty("http.proxyPort", String.valueOf(proxy.getPort()));
System.setProperty("https.proxyHost", proxy.getHost());
System.setProperty("https.proxyPort", String.valueOf(proxy.getPort()));
}
List<ProxyProperties.DiscordAccountConfig> configAccounts = this.properties.getAccounts();
if (CharSequenceUtil.isNotBlank(this.properties.getDiscord().getChannelId())) {
configAccounts.add(this.properties.getDiscord());
}
List<DiscordInstance> instances = this.discordLoadBalancer.getAllInstances();
for (ProxyProperties.DiscordAccountConfig configAccount : configAccounts) {
DiscordAccount account = new DiscordAccount();
BeanUtil.copyProperties(configAccount, account);
account.setId(configAccount.getChannelId());
try {
DiscordInstance instance = this.discordAccountHelper.createDiscordInstance(account);
if (!account.isEnable()) {
continue;
}
instance.startWss();
AsyncLockUtils.LockObject lock = AsyncLockUtils.waitForLock("wss:" + account.getChannelId(), Duration.ofSeconds(10));
if (ReturnCode.SUCCESS != lock.getProperty("code", Integer.class, 0)) {
throw new ValidateException(lock.getProperty("description", String.class));
}
instances.add(instance);
} catch (Exception e) {
log.error("Account({}) init fail, disabled: {}", account.getDisplay(), e.getMessage());
account.setEnable(false);
}
}
Set<String> enableInstanceIds = instances.stream().filter(DiscordInstance::isAlive).map(DiscordInstance::getInstanceId).collect(Collectors.toSet());
log.info("当前可用账号数 [{}] - {}", enableInstanceIds.size(), String.join(", ", enableInstanceIds));
}
}

View File

@@ -0,0 +1,103 @@
package com.xmzs.midjourney.support;
import cn.hutool.core.text.CharSequenceUtil;
import com.xmzs.midjourney.ProxyProperties;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Component;
@Component
@RequiredArgsConstructor
public class DiscordHelper {
private final ProxyProperties properties;
/**
* DISCORD_SERVER_URL.
*/
public static final String DISCORD_SERVER_URL = "https://discord.com";
/**
* DISCORD_CDN_URL.
*/
public static final String DISCORD_CDN_URL = "https://cdn.discordapp.com";
/**
* DISCORD_WSS_URL.
*/
public static final String DISCORD_WSS_URL = "wss://gateway.discord.gg";
/**
* DISCORD_UPLOAD_URL.
*/
public static final String DISCORD_UPLOAD_URL = "https://discord-attachments-uploads-prd.storage.googleapis.com";
public String getServer() {
if (CharSequenceUtil.isBlank(this.properties.getNgDiscord().getServer())) {
return DISCORD_SERVER_URL;
}
String serverUrl = this.properties.getNgDiscord().getServer();
if (serverUrl.endsWith("/")) {
serverUrl = serverUrl.substring(0, serverUrl.length() - 1);
}
return serverUrl;
}
public String getCdn() {
if (CharSequenceUtil.isBlank(this.properties.getNgDiscord().getCdn())) {
return DISCORD_CDN_URL;
}
String cdnUrl = this.properties.getNgDiscord().getCdn();
if (cdnUrl.endsWith("/")) {
cdnUrl = cdnUrl.substring(0, cdnUrl.length() - 1);
}
return cdnUrl;
}
public String getWss() {
if (CharSequenceUtil.isBlank(this.properties.getNgDiscord().getWss())) {
return DISCORD_WSS_URL;
}
String wssUrl = this.properties.getNgDiscord().getWss();
if (wssUrl.endsWith("/")) {
wssUrl = wssUrl.substring(0, wssUrl.length() - 1);
}
return wssUrl;
}
public String getDiscordUploadUrl(String uploadUrl) {
if (CharSequenceUtil.isBlank(this.properties.getNgDiscord().getUploadServer()) || CharSequenceUtil.isBlank(uploadUrl)) {
return uploadUrl;
}
String uploadServer = this.properties.getNgDiscord().getUploadServer();
if (uploadServer.endsWith("/")) {
uploadServer = uploadServer.substring(0, uploadServer.length() - 1);
}
return uploadUrl.replaceFirst(DISCORD_UPLOAD_URL, uploadServer);
}
public String findTaskIdWithCdnUrl(String url) {
if (!CharSequenceUtil.startWith(url, DISCORD_CDN_URL)) {
return null;
}
int hashStartIndex = url.lastIndexOf("/");
String taskId = CharSequenceUtil.subBefore(url.substring(hashStartIndex + 1), ".", true);
if (CharSequenceUtil.length(taskId) == 16) {
return taskId;
}
return null;
}
public String getMessageHash(String imageUrl) {
if (CharSequenceUtil.isBlank(imageUrl)) {
return null;
}
if (CharSequenceUtil.endWith(imageUrl, "_grid_0.webp")) {
int hashStartIndex = imageUrl.lastIndexOf("/");
if (hashStartIndex < 0) {
return null;
}
return CharSequenceUtil.sub(imageUrl, hashStartIndex + 1, imageUrl.length() - "_grid_0.webp".length());
}
int hashStartIndex = imageUrl.lastIndexOf("_");
if (hashStartIndex < 0) {
return null;
}
return CharSequenceUtil.subBefore(imageUrl.substring(hashStartIndex + 1), ".", true);
}
}

View File

@@ -0,0 +1,23 @@
package com.xmzs.midjourney.support;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;
@Component
public class SpringContextHolder implements ApplicationContextAware {
private static ApplicationContext APPLICATION_CONTEXT;
@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
APPLICATION_CONTEXT = applicationContext;
}
public static ApplicationContext getApplicationContext() {
if (APPLICATION_CONTEXT == null) {
throw new IllegalStateException("SpringContextHolder is not ready.");
}
return APPLICATION_CONTEXT;
}
}

View File

@@ -0,0 +1,68 @@
package com.xmzs.midjourney.support;
import com.xmzs.midjourney.domain.DomainObject;
import com.xmzs.midjourney.enums.TaskAction;
import com.xmzs.midjourney.enums.TaskStatus;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import lombok.EqualsAndHashCode;
import java.io.Serial;
@Data
@EqualsAndHashCode(callSuper = true)
@ApiModel("任务")
public class Task extends DomainObject {
@Serial
private static final long serialVersionUID = -674915748204390789L;
@ApiModelProperty("任务类型")
private TaskAction action;
@ApiModelProperty("任务状态")
private TaskStatus status = TaskStatus.NOT_START;
@ApiModelProperty("提示词")
private String prompt;
@ApiModelProperty("提示词-英文")
private String promptEn;
@ApiModelProperty("任务描述")
private String description;
@ApiModelProperty("自定义参数")
private String state;
@ApiModelProperty("提交时间")
private Long submitTime;
@ApiModelProperty("开始执行时间")
private Long startTime;
@ApiModelProperty("结束时间")
private Long finishTime;
@ApiModelProperty("图片url")
private String imageUrl;
@ApiModelProperty("任务进度")
private String progress;
@ApiModelProperty("失败原因")
private String failReason;
public void start() {
this.startTime = System.currentTimeMillis();
this.status = TaskStatus.SUBMITTED;
this.progress = "0%";
}
public void success() {
this.finishTime = System.currentTimeMillis();
this.status = TaskStatus.SUCCESS;
this.progress = "100%";
}
public void fail(String reason) {
this.finishTime = System.currentTimeMillis();
this.status = TaskStatus.FAILURE;
this.failReason = reason;
this.progress = "";
}
}

Some files were not shown because too many files have changed in this diff Show More