mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-04-16 05:13:41 +00:00
增加后台管理,支持docker部署
This commit is contained in:
@@ -1,36 +1,22 @@
|
||||
package com.xmzs.midjourney.controller;
|
||||
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import com.xmzs.common.chat.constant.OpenAIConst;
|
||||
import com.xmzs.common.core.domain.model.LoginUser;
|
||||
import com.xmzs.common.core.exception.base.BaseException;
|
||||
import com.xmzs.common.satoken.utils.LoginHelper;
|
||||
import com.xmzs.midjourney.domain.InsightFace;
|
||||
import com.xmzs.system.domain.bo.ChatMessageBo;
|
||||
import com.xmzs.system.service.IChatMessageService;
|
||||
import com.xmzs.system.service.IChatService;
|
||||
import com.xmzs.system.service.ISseService;
|
||||
import com.xmzs.midjourney.domain.MjPriceConfig;
|
||||
import com.xmzs.midjourney.util.MjOkHttpUtil;
|
||||
|
||||
import com.xmzs.system.service.IChatCostService;
|
||||
import io.swagger.annotations.Api;
|
||||
import io.swagger.annotations.ApiOperation;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.MediaType;
|
||||
import okhttp3.OkHttpClient;
|
||||
import okhttp3.Request;
|
||||
import okhttp3.Response;
|
||||
import okio.Buffer;
|
||||
import okio.BufferedSink;
|
||||
import okio.GzipSink;
|
||||
import okio.Okio;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.apache.commons.lang3.math.NumberUtils;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
@Api(tags = "任务查询")
|
||||
@RestController
|
||||
@@ -39,62 +25,22 @@ import java.util.concurrent.TimeUnit;
|
||||
@Slf4j
|
||||
public class FaceController {
|
||||
|
||||
@Value("${chat.apiKey}")
|
||||
private String apiKey;
|
||||
@Value("${chat.apiHost}")
|
||||
private String apiHost;
|
||||
private final IChatCostService chatCostService;
|
||||
|
||||
@Autowired
|
||||
private IChatService chatService;
|
||||
private final MjOkHttpUtil mjOkHttpUtil;
|
||||
|
||||
@Autowired
|
||||
private ISseService sseService;
|
||||
private final MjPriceConfig priceConfig;
|
||||
|
||||
@ApiOperation(value = "换脸")
|
||||
@PostMapping("/insight-face/swap")
|
||||
public String insightFace(@RequestBody InsightFace insightFace) {
|
||||
// 查询是否是付费用户
|
||||
sseService.checkUserGrade();
|
||||
// 扣除接口费用
|
||||
chatService.mjTaskDeduct("换脸", OpenAIConst.MJ_COST_TYPE2);
|
||||
OkHttpClient client = new OkHttpClient.Builder()
|
||||
.connectTimeout(30, TimeUnit.SECONDS) // 连接超时时间
|
||||
.writeTimeout(30, TimeUnit.SECONDS) // 写入超时时间
|
||||
.readTimeout(30, TimeUnit.SECONDS) // 读取超时时间
|
||||
.build();
|
||||
// 创建一个Request对象来配置你的请求
|
||||
// 扣除接口费用并且保存消息记录
|
||||
chatCostService.taskDeduct("mj","换脸", NumberUtils.toDouble(priceConfig.getFaceSwapping(), 0.3));
|
||||
// 创建请求体(这里使用JSON作为媒体类型)
|
||||
String jsonStr = JSONUtil.toJsonStr(insightFace);
|
||||
|
||||
MediaType JSON = MediaType.get("application/json; charset=utf-8");
|
||||
okhttp3.RequestBody body = okhttp3.RequestBody.create(jsonStr, JSON);
|
||||
Buffer buffer = new Buffer();
|
||||
GzipSink gzipSink = new GzipSink(buffer);
|
||||
BufferedSink gzipBufferedSink = Okio.buffer(gzipSink);
|
||||
try {
|
||||
body.writeTo(gzipBufferedSink);
|
||||
gzipBufferedSink.close();
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
// 创建POST请求
|
||||
Request request = new Request.Builder()
|
||||
.header("mj-api-secret", apiKey)
|
||||
.header("Content-Encoding", "gzip")
|
||||
.url(apiHost + "mj/insight-face/swap") // 替换为你的URL
|
||||
.post(body)
|
||||
.build();
|
||||
|
||||
try (Response response = client.newCall(request).execute()) {
|
||||
if (!response.isSuccessful()) throw new IOException("Unexpected code " + response);
|
||||
if (response.body() != null) {
|
||||
return response.body().string();
|
||||
}
|
||||
} catch (IOException e) {
|
||||
log.error("换脸失败! {}", e.getMessage());
|
||||
}
|
||||
return null;
|
||||
String insightFaceJson = JSONUtil.toJsonStr(insightFace);
|
||||
String url = "mj/insight-face/swap";
|
||||
Request request = mjOkHttpUtil.createPostRequest(url, insightFaceJson);
|
||||
return mjOkHttpUtil.executeRequest(request);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,61 +1,24 @@
|
||||
package com.xmzs.midjourney.controller;
|
||||
|
||||
import cn.hutool.core.text.CharSequenceUtil;
|
||||
import cn.hutool.core.util.RandomUtil;
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import com.xmzs.common.chat.constant.OpenAIConst;
|
||||
import com.xmzs.common.core.domain.model.LoginUser;
|
||||
import com.xmzs.common.core.exception.base.BaseException;
|
||||
import com.xmzs.common.satoken.utils.LoginHelper;
|
||||
import com.xmzs.midjourney.Constants;
|
||||
import com.xmzs.midjourney.ProxyProperties;
|
||||
import com.xmzs.midjourney.ReturnCode;
|
||||
import com.xmzs.midjourney.domain.MjPriceConfig;
|
||||
import com.xmzs.midjourney.dto.*;
|
||||
import com.xmzs.midjourney.enums.TaskAction;
|
||||
import com.xmzs.midjourney.enums.TaskStatus;
|
||||
import com.xmzs.midjourney.enums.TranslateWay;
|
||||
import com.xmzs.midjourney.exception.BannedPromptException;
|
||||
import com.xmzs.midjourney.result.SubmitResultVO;
|
||||
import com.xmzs.midjourney.service.TaskService;
|
||||
import com.xmzs.midjourney.service.TaskStoreService;
|
||||
import com.xmzs.midjourney.service.TranslateService;
|
||||
import com.xmzs.midjourney.support.Task;
|
||||
import com.xmzs.midjourney.support.TaskCondition;
|
||||
import com.xmzs.midjourney.util.BannedPromptUtils;
|
||||
import com.xmzs.midjourney.util.ConvertUtils;
|
||||
import com.xmzs.midjourney.util.MimeTypeUtils;
|
||||
import com.xmzs.midjourney.util.SnowFlake;
|
||||
import com.xmzs.midjourney.util.TaskChangeParams;
|
||||
import com.xmzs.system.domain.bo.ChatMessageBo;
|
||||
import com.xmzs.system.service.IChatMessageService;
|
||||
import com.xmzs.system.service.IChatService;
|
||||
import com.xmzs.system.service.ISseService;
|
||||
import eu.maxschuster.dataurl.DataUrl;
|
||||
import eu.maxschuster.dataurl.DataUrlSerializer;
|
||||
import eu.maxschuster.dataurl.IDataUrlSerializer;
|
||||
import com.xmzs.midjourney.enums.ActionType;
|
||||
import com.xmzs.midjourney.util.*;
|
||||
import com.xmzs.system.service.IChatCostService;
|
||||
import io.swagger.annotations.Api;
|
||||
import io.swagger.annotations.ApiOperation;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.OkHttpClient;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import okhttp3.Request;
|
||||
import org.apache.commons.lang3.math.NumberUtils;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.MalformedURLException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
import okhttp3.*;
|
||||
import java.util.Optional;
|
||||
|
||||
@Api(tags = "任务提交")
|
||||
@RestController
|
||||
@@ -63,285 +26,114 @@ import okhttp3.*;
|
||||
@RequiredArgsConstructor
|
||||
@Slf4j
|
||||
public class SubmitController {
|
||||
private final TranslateService translateService;
|
||||
private final ProxyProperties properties;
|
||||
private final TaskService taskService;
|
||||
private final TaskStoreService taskStoreService;
|
||||
@Value("${chat.apiKey}")
|
||||
private String apiKey;
|
||||
@Value("${chat.apiHost}")
|
||||
private String apiHost;
|
||||
@Autowired
|
||||
private IChatService chatService;
|
||||
@Autowired
|
||||
private IChatMessageService chatMessageService;
|
||||
@Autowired
|
||||
private ISseService sseService;
|
||||
|
||||
@ApiOperation(value = "提交Imagine任务")
|
||||
@PostMapping("/imagine")
|
||||
public SubmitResultVO imagine(@RequestBody SubmitImagineDTO imagineDTO) {
|
||||
String prompt = imagineDTO.getPrompt();
|
||||
if (CharSequenceUtil.isBlank(prompt)) {
|
||||
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "prompt不能为空");
|
||||
}
|
||||
prompt = prompt.trim();
|
||||
Task task = newTask(imagineDTO);
|
||||
task.setAction(TaskAction.IMAGINE);
|
||||
task.setPrompt(prompt);
|
||||
String promptEn = translatePrompt(prompt);
|
||||
try {
|
||||
BannedPromptUtils.checkBanned(promptEn);
|
||||
} catch (BannedPromptException e) {
|
||||
return SubmitResultVO.fail(ReturnCode.BANNED_PROMPT, "可能包含敏感词")
|
||||
.setProperty("promptEn", promptEn).setProperty("bannedWord", e.getMessage());
|
||||
}
|
||||
List<String> base64Array = Optional.ofNullable(imagineDTO.getBase64Array()).orElse(new ArrayList<>());
|
||||
if (CharSequenceUtil.isNotBlank(imagineDTO.getBase64())) {
|
||||
base64Array.add(imagineDTO.getBase64());
|
||||
}
|
||||
List<DataUrl> dataUrls;
|
||||
try {
|
||||
dataUrls = ConvertUtils.convertBase64Array(base64Array);
|
||||
} catch (MalformedURLException e) {
|
||||
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64格式错误");
|
||||
}
|
||||
task.setPromptEn(promptEn);
|
||||
task.setDescription("/imagine " + prompt);
|
||||
return this.taskService.submitImagine(task, dataUrls);
|
||||
private final MjPriceConfig priceConfig;
|
||||
|
||||
private final IChatCostService chatCostService;
|
||||
|
||||
private final MjOkHttpUtil mjOkHttpUtil;
|
||||
|
||||
@ApiOperation(value = "绘图变化")
|
||||
@PostMapping("/change")
|
||||
public String change(@RequestBody SubmitChangeDTO changeDTO) {
|
||||
String jsonStr = JSONUtil.toJsonStr(changeDTO);
|
||||
String url = "mj/submit/change";
|
||||
Request request = mjOkHttpUtil.createPostRequest(url, jsonStr);
|
||||
return mjOkHttpUtil.executeRequest(request);
|
||||
}
|
||||
|
||||
@ApiOperation(value = "执行动作")
|
||||
@PostMapping("/action")
|
||||
public String action(@RequestBody SubmitActionDTO changeDTO) {
|
||||
ActionType actionType = ActionType.fromCustomId(getAction(changeDTO.getCustomId()));
|
||||
Optional.ofNullable(actionType).ifPresentOrElse(
|
||||
type -> {
|
||||
switch (type) {
|
||||
case UP_SAMPLE:
|
||||
chatCostService.taskDeduct("mj","放大", NumberUtils.toDouble(priceConfig.getUpsample(), 0.1));
|
||||
break;
|
||||
case IN_PAINT:
|
||||
// 局部重绘已经扣费,不执行任何操作
|
||||
break;
|
||||
default:
|
||||
chatCostService.taskDeduct("mj","变化", NumberUtils.toDouble(priceConfig.getChange(), 0.3));
|
||||
break;
|
||||
}
|
||||
},
|
||||
() -> chatCostService.taskDeduct("mj","变化", NumberUtils.toDouble(priceConfig.getChange(), 0.3))
|
||||
);
|
||||
|
||||
String jsonStr = JSONUtil.toJsonStr(changeDTO);
|
||||
String url = "mj/submit/action";
|
||||
Request request = mjOkHttpUtil.createPostRequest(url, jsonStr);
|
||||
return mjOkHttpUtil.executeRequest(request);
|
||||
}
|
||||
|
||||
@ApiOperation(value = "绘图变化-simple")
|
||||
@PostMapping("/simple-change")
|
||||
public SubmitResultVO simpleChange(@RequestBody SubmitSimpleChangeDTO simpleChangeDTO) {
|
||||
TaskChangeParams changeParams = ConvertUtils.convertChangeParams(simpleChangeDTO.getContent());
|
||||
if (changeParams == null) {
|
||||
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "content参数错误");
|
||||
}
|
||||
SubmitChangeDTO changeDTO = new SubmitChangeDTO();
|
||||
changeDTO.setAction(changeParams.getAction());
|
||||
changeDTO.setTaskId(changeParams.getId());
|
||||
changeDTO.setIndex(changeParams.getIndex());
|
||||
changeDTO.setState(simpleChangeDTO.getState());
|
||||
changeDTO.setNotifyHook(simpleChangeDTO.getNotifyHook());
|
||||
return change(changeDTO);
|
||||
public String simpleChange(@RequestBody SubmitSimpleChangeDTO simpleChangeDTO) {
|
||||
String jsonStr = JSONUtil.toJsonStr(simpleChangeDTO);
|
||||
String url = "mj/submit/simple-change";
|
||||
Request request = mjOkHttpUtil.createPostRequest(url, jsonStr);
|
||||
return mjOkHttpUtil.executeRequest(request);
|
||||
}
|
||||
|
||||
@ApiOperation(value = "绘图变化")
|
||||
@PostMapping("/change")
|
||||
public SubmitResultVO change(@RequestBody SubmitChangeDTO changeDTO) {
|
||||
if (CharSequenceUtil.isBlank(changeDTO.getTaskId())) {
|
||||
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "taskId不能为空");
|
||||
}
|
||||
if (!Set.of(TaskAction.UPSCALE, TaskAction.VARIATION, TaskAction.REROLL).contains(changeDTO.getAction())) {
|
||||
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "action参数错误");
|
||||
}
|
||||
String description = "/up " + changeDTO.getTaskId();
|
||||
if (TaskAction.REROLL.equals(changeDTO.getAction())) {
|
||||
description += " R";
|
||||
} else {
|
||||
description += " " + changeDTO.getAction().name().charAt(0) + changeDTO.getIndex();
|
||||
}
|
||||
if (TaskAction.UPSCALE.equals(changeDTO.getAction())) {
|
||||
TaskCondition condition = new TaskCondition().setDescription(description);
|
||||
Task existTask = this.taskStoreService.findOne(condition);
|
||||
if (existTask != null) {
|
||||
return SubmitResultVO.of(ReturnCode.EXISTED, "任务已存在", existTask.getId())
|
||||
.setProperty("status", existTask.getStatus())
|
||||
.setProperty("imageUrl", existTask.getImageUrl());
|
||||
}
|
||||
}
|
||||
Task targetTask = this.taskStoreService.get(changeDTO.getTaskId());
|
||||
if (targetTask == null) {
|
||||
return SubmitResultVO.fail(ReturnCode.NOT_FOUND, "关联任务不存在或已失效");
|
||||
}
|
||||
if (!TaskStatus.SUCCESS.equals(targetTask.getStatus())) {
|
||||
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "关联任务状态错误");
|
||||
}
|
||||
if (!Set.of(TaskAction.IMAGINE, TaskAction.VARIATION, TaskAction.REROLL, TaskAction.BLEND).contains(targetTask.getAction())) {
|
||||
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "关联任务不允许执行变化");
|
||||
}
|
||||
Task task = newTask(changeDTO);
|
||||
task.setAction(changeDTO.getAction());
|
||||
task.setPrompt(targetTask.getPrompt());
|
||||
task.setPromptEn(targetTask.getPromptEn());
|
||||
task.setProperty(Constants.TASK_PROPERTY_FINAL_PROMPT, targetTask.getProperty(Constants.TASK_PROPERTY_FINAL_PROMPT));
|
||||
task.setProperty(Constants.TASK_PROPERTY_PROGRESS_MESSAGE_ID, targetTask.getProperty(Constants.TASK_PROPERTY_MESSAGE_ID));
|
||||
task.setProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID, targetTask.getProperty(Constants.TASK_PROPERTY_DISCORD_INSTANCE_ID));
|
||||
task.setDescription(description);
|
||||
int messageFlags = targetTask.getPropertyGeneric(Constants.TASK_PROPERTY_FLAGS);
|
||||
String messageId = targetTask.getPropertyGeneric(Constants.TASK_PROPERTY_MESSAGE_ID);
|
||||
String messageHash = targetTask.getPropertyGeneric(Constants.TASK_PROPERTY_MESSAGE_HASH);
|
||||
if (TaskAction.UPSCALE.equals(changeDTO.getAction())) {
|
||||
return this.taskService.submitUpscale(task, messageId, messageHash, changeDTO.getIndex(), messageFlags);
|
||||
} else if (TaskAction.VARIATION.equals(changeDTO.getAction())) {
|
||||
return this.taskService.submitVariation(task, messageId, messageHash, changeDTO.getIndex(), messageFlags);
|
||||
} else {
|
||||
return this.taskService.submitReroll(task, messageId, messageHash, messageFlags);
|
||||
}
|
||||
@ApiOperation(value = "提交图生图、混图任务")
|
||||
@PostMapping("/blend")
|
||||
public String blend(@RequestBody SubmitBlendDTO blendDTO) {
|
||||
chatCostService.taskDeduct("mj","图生图", NumberUtils.toDouble(priceConfig.getBlend(), 0.3));
|
||||
String jsonStr = JSONUtil.toJsonStr(blendDTO);
|
||||
String url = "mj/submit/blend";
|
||||
Request request = mjOkHttpUtil.createPostRequest(url, jsonStr);
|
||||
return mjOkHttpUtil.executeRequest(request);
|
||||
}
|
||||
|
||||
@ApiOperation(value = "绘图变化")
|
||||
@PostMapping("/action")
|
||||
public String action(@RequestBody SubmitActionDTO changeDTO) {
|
||||
// 查询是否是付费用户
|
||||
sseService.checkUserGrade();
|
||||
// 扣除接口费用
|
||||
if ("upsample".equals(getAction(changeDTO.getCustomId()))) {
|
||||
mjTaskDeduct("放大", OpenAIConst.MJ_COST_TYPE2);
|
||||
} else {
|
||||
// Inpaint: 局部重绘
|
||||
// reroll 重绘
|
||||
// upsample 放大
|
||||
// zoom 变焦
|
||||
// upscale 高清放大
|
||||
// variation 变化
|
||||
if (!"Inpaint".equals(getAction(changeDTO.getCustomId()))) {
|
||||
mjTaskDeduct("变化", OpenAIConst.MJ_COST_TYPE1);
|
||||
}
|
||||
}
|
||||
OkHttpClient client = new OkHttpClient.Builder()
|
||||
.connectTimeout(30, TimeUnit.SECONDS) // 连接超时时间
|
||||
.writeTimeout(30, TimeUnit.SECONDS) // 写入超时时间
|
||||
.readTimeout(30, TimeUnit.SECONDS) // 读取超时时间
|
||||
.build();
|
||||
@ApiOperation(value = "提交图生文任务")
|
||||
@PostMapping("/describe")
|
||||
public String describe(@RequestBody SubmitDescribeDTO describeDTO) {
|
||||
chatCostService.taskDeduct("mj","图生文", NumberUtils.toDouble(priceConfig.getDescribe(), 0.1));
|
||||
String jsonStr = JSONUtil.toJsonStr(describeDTO);
|
||||
String url = "mj/submit/describe";
|
||||
Request request = mjOkHttpUtil.createPostRequest(url, jsonStr);
|
||||
return mjOkHttpUtil.executeRequest(request);
|
||||
}
|
||||
|
||||
String jsonStr = JSONUtil.toJsonStr(changeDTO);
|
||||
@ApiOperation(value = "提交文生图任务")
|
||||
@PostMapping("/imagine")
|
||||
public String imagine(@RequestBody SubmitImagineDTO imagineDTO) {
|
||||
chatCostService.taskDeduct("mj",imagineDTO.getPrompt(), NumberUtils.toDouble(priceConfig.getImagine(), 0.3));
|
||||
String jsonStr = JSONUtil.toJsonStr(imagineDTO);
|
||||
String url = "mj/submit/imagine";
|
||||
Request request = mjOkHttpUtil.createPostRequest(url, jsonStr);
|
||||
return mjOkHttpUtil.executeRequest(request);
|
||||
}
|
||||
|
||||
MediaType mediaType = MediaType.parse("application/json");
|
||||
okhttp3.RequestBody body = okhttp3.RequestBody.create(jsonStr, mediaType);
|
||||
Request request = new Request.Builder()
|
||||
.url(apiHost + "mj/submit/action")
|
||||
.method("POST", body)
|
||||
.header("mj-api-secret", apiKey) // 设置Authorization header
|
||||
.build();
|
||||
try {
|
||||
Response response = client.newCall(request).execute();
|
||||
return response.body().string();
|
||||
} catch (IOException e) {
|
||||
log.error("绘图变化失败:{}", e.getMessage());
|
||||
}
|
||||
return null;
|
||||
@ApiOperation(value = "提交局部重绘任务")
|
||||
@PostMapping("/modal")
|
||||
public String modal(@RequestBody SubmitModalDTO submitModalDTO) {
|
||||
chatCostService.taskDeduct("mj","局部重绘", NumberUtils.toDouble(priceConfig.getInpaint(), 0.1));
|
||||
String jsonStr = JSONUtil.toJsonStr(submitModalDTO);
|
||||
String url = "mj/submit/modal";
|
||||
Request request = mjOkHttpUtil.createPostRequest(url, jsonStr);
|
||||
return mjOkHttpUtil.executeRequest(request);
|
||||
}
|
||||
|
||||
@ApiOperation(value = "提交提示词分析任务")
|
||||
@PostMapping("/shorten")
|
||||
public String shorten(@RequestBody SubmitShortenDTO submitShortenDTO) {
|
||||
chatCostService.taskDeduct("mj","提示词分析", NumberUtils.toDouble(priceConfig.getShorten(), 0.1));
|
||||
String jsonStr = JSONUtil.toJsonStr(submitShortenDTO);
|
||||
String url = "mj/submit/shorten";
|
||||
Request request = mjOkHttpUtil.createPostRequest(url, jsonStr);
|
||||
return mjOkHttpUtil.executeRequest(request);
|
||||
}
|
||||
|
||||
public String getAction(String customId) {
|
||||
// 检查 customId 是否为空
|
||||
if(customId == null || customId.isEmpty()) {
|
||||
if (customId == null || customId.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
// 使用 "::" 分割字符串
|
||||
String[] parts = customId.split("::");
|
||||
// "MJ", "Inpaint", "1", "4fca7c14-181c-4...", "SOLO"
|
||||
if(customId.endsWith("SOLO")) {
|
||||
return parts[1];
|
||||
}
|
||||
// 返回 "upsample" 值,假设它总是在第三个位置
|
||||
return parts[2];
|
||||
}
|
||||
|
||||
public void mjTaskDeduct(String prompt, double cost) {
|
||||
//扣除费用
|
||||
chatService.deductUserBalance(getUserId(), cost);
|
||||
// 保存消息记录
|
||||
ChatMessageBo chatMessageBo = new ChatMessageBo();
|
||||
chatMessageBo.setUserId(getUserId());
|
||||
chatMessageBo.setModelName("mj");
|
||||
chatMessageBo.setContent(prompt);
|
||||
chatMessageBo.setDeductCost(cost);
|
||||
chatMessageBo.setTotalTokens(0);
|
||||
chatMessageService.insertByBo(chatMessageBo);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取用户Id
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public Long getUserId() {
|
||||
LoginUser loginUser = LoginHelper.getLoginUser();
|
||||
if (loginUser == null) {
|
||||
throw new BaseException("用户未登录!");
|
||||
}
|
||||
return loginUser.getUserId();
|
||||
}
|
||||
|
||||
@ApiOperation(value = "提交Describe任务")
|
||||
@PostMapping("/describe")
|
||||
public SubmitResultVO describe(@RequestBody SubmitDescribeDTO describeDTO) {
|
||||
if (CharSequenceUtil.isBlank(describeDTO.getBase64())) {
|
||||
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64不能为空");
|
||||
}
|
||||
IDataUrlSerializer serializer = new DataUrlSerializer();
|
||||
DataUrl dataUrl;
|
||||
try {
|
||||
dataUrl = serializer.unserialize(describeDTO.getBase64());
|
||||
} catch (MalformedURLException e) {
|
||||
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64格式错误");
|
||||
}
|
||||
Task task = newTask(describeDTO);
|
||||
task.setAction(TaskAction.DESCRIBE);
|
||||
String taskFileName = task.getId() + "." + MimeTypeUtils.guessFileSuffix(dataUrl.getMimeType());
|
||||
task.setDescription("/describe " + taskFileName);
|
||||
return this.taskService.submitDescribe(task, dataUrl);
|
||||
}
|
||||
|
||||
@ApiOperation(value = "提交Blend任务")
|
||||
@PostMapping("/blend")
|
||||
public SubmitResultVO blend(@RequestBody SubmitBlendDTO blendDTO) {
|
||||
List<String> base64Array = blendDTO.getBase64Array();
|
||||
if (base64Array == null || base64Array.size() < 2 || base64Array.size() > 5) {
|
||||
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64List参数错误");
|
||||
}
|
||||
if (blendDTO.getDimensions() == null) {
|
||||
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "dimensions参数错误");
|
||||
}
|
||||
IDataUrlSerializer serializer = new DataUrlSerializer();
|
||||
List<DataUrl> dataUrlList = new ArrayList<>();
|
||||
try {
|
||||
for (String base64 : base64Array) {
|
||||
DataUrl dataUrl = serializer.unserialize(base64);
|
||||
dataUrlList.add(dataUrl);
|
||||
}
|
||||
} catch (MalformedURLException e) {
|
||||
return SubmitResultVO.fail(ReturnCode.VALIDATION_ERROR, "base64格式错误");
|
||||
}
|
||||
Task task = newTask(blendDTO);
|
||||
task.setAction(TaskAction.BLEND);
|
||||
task.setDescription("/blend " + task.getId() + " " + dataUrlList.size());
|
||||
return this.taskService.submitBlend(task, dataUrlList, blendDTO.getDimensions());
|
||||
}
|
||||
|
||||
private Task newTask(BaseSubmitDTO base) {
|
||||
Task task = new Task();
|
||||
task.setId(System.currentTimeMillis() + RandomUtil.randomNumbers(3));
|
||||
task.setSubmitTime(System.currentTimeMillis());
|
||||
task.setState(base.getState());
|
||||
String notifyHook = CharSequenceUtil.isBlank(base.getNotifyHook()) ? this.properties.getNotifyHook() : base.getNotifyHook();
|
||||
task.setProperty(Constants.TASK_PROPERTY_NOTIFY_HOOK, notifyHook);
|
||||
task.setProperty(Constants.TASK_PROPERTY_NONCE, SnowFlake.INSTANCE.nextId());
|
||||
return task;
|
||||
}
|
||||
|
||||
private String translatePrompt(String prompt) {
|
||||
if (TranslateWay.NULL.equals(this.properties.getTranslateWay()) || CharSequenceUtil.isBlank(prompt)) {
|
||||
return prompt;
|
||||
}
|
||||
List<String> imageUrls = new ArrayList<>();
|
||||
Matcher imageMatcher = Pattern.compile("https?://[a-z0-9-_:@&?=+,.!/~*'%$]+\\x20+", Pattern.CASE_INSENSITIVE).matcher(prompt);
|
||||
while (imageMatcher.find()) {
|
||||
imageUrls.add(imageMatcher.group(0));
|
||||
}
|
||||
String paramStr = "";
|
||||
Matcher paramMatcher = Pattern.compile("\\x20+-{1,2}[a-z]+.*$", Pattern.CASE_INSENSITIVE).matcher(prompt);
|
||||
if (paramMatcher.find()) {
|
||||
paramStr = paramMatcher.group(0);
|
||||
}
|
||||
String imageStr = CharSequenceUtil.join("", imageUrls);
|
||||
String text = prompt.substring(imageStr.length(), prompt.length() - paramStr.length());
|
||||
if (CharSequenceUtil.isNotBlank(text)) {
|
||||
text = this.translateService.translateToEnglish(text).trim();
|
||||
}
|
||||
return imageStr + text + paramStr;
|
||||
return customId.endsWith("SOLO") ? parts[1] : parts[2];
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,32 +1,15 @@
|
||||
package com.xmzs.midjourney.controller;
|
||||
|
||||
import cn.hutool.core.comparator.CompareUtil;
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import com.xmzs.midjourney.dto.SubmitImagineDTO;
|
||||
import com.xmzs.midjourney.util.MjOkHttpUtil;
|
||||
import com.xmzs.midjourney.dto.TaskConditionDTO;
|
||||
import com.xmzs.midjourney.loadbalancer.DiscordLoadBalancer;
|
||||
import com.xmzs.midjourney.result.SubmitResultVO;
|
||||
import com.xmzs.midjourney.service.TaskStoreService;
|
||||
import com.xmzs.midjourney.support.Task;
|
||||
import io.swagger.annotations.Api;
|
||||
import io.swagger.annotations.ApiOperation;
|
||||
import io.swagger.annotations.ApiParam;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.*;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import okhttp3.Request;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
@Api(tags = "任务查询")
|
||||
@RestController
|
||||
@@ -34,58 +17,32 @@ import java.util.Objects;
|
||||
@RequiredArgsConstructor
|
||||
@Slf4j
|
||||
public class TaskController {
|
||||
private final TaskStoreService taskStoreService;
|
||||
private final DiscordLoadBalancer discordLoadBalancer;
|
||||
|
||||
@Value("${chat.apiKey}")
|
||||
private String apiKey;
|
||||
@Value("${chat.apiHost}")
|
||||
private String apiHost;
|
||||
private final MjOkHttpUtil mjOkHttpUtil;
|
||||
|
||||
@ApiOperation(value = "指定ID获取任务")
|
||||
@GetMapping("/{id}/fetch")
|
||||
public String fetch(@ApiParam(value = "任务ID") @PathVariable String id) {
|
||||
OkHttpClient client = new OkHttpClient();
|
||||
// 创建一个Request对象来配置你的请求
|
||||
Request request = new Request.Builder()
|
||||
.header("mj-api-secret", apiKey) // 设置Authorization header
|
||||
.url(apiHost+"mj/task/" + id + "/fetch")
|
||||
.build();
|
||||
try (Response response = client.newCall(request).execute()) {
|
||||
if (!response.isSuccessful()) throw new IOException("Unexpected code " + response);
|
||||
if (response.body() != null) {
|
||||
return response.body().string();
|
||||
}
|
||||
} catch (IOException e) {
|
||||
log.error("任务:{}查询失败:{}",id,e.getMessage());
|
||||
}
|
||||
return null;
|
||||
String url = "mj/task/" + id + "/fetch";
|
||||
Request request = mjOkHttpUtil.createGetRequest(url);
|
||||
return mjOkHttpUtil.executeRequest(request);
|
||||
}
|
||||
|
||||
@ApiOperation(value = "查询任务队列")
|
||||
@GetMapping("/queue")
|
||||
public List<Task> queue() {
|
||||
return this.discordLoadBalancer.getQueueTaskIds().stream()
|
||||
.map(this.taskStoreService::get).filter(Objects::nonNull)
|
||||
.sorted(Comparator.comparing(Task::getSubmitTime))
|
||||
.toList();
|
||||
}
|
||||
|
||||
@ApiOperation(value = "查询所有任务")
|
||||
@GetMapping("/list")
|
||||
public List<Task> list() {
|
||||
return this.taskStoreService.list().stream()
|
||||
.sorted((t1, t2) -> CompareUtil.compare(t2.getSubmitTime(), t1.getSubmitTime()))
|
||||
.toList();
|
||||
}
|
||||
|
||||
@ApiOperation(value = "根据ID列表查询任务")
|
||||
@PostMapping("/list-by-condition")
|
||||
public List<Task> listByIds(@RequestBody TaskConditionDTO conditionDTO) {
|
||||
if (conditionDTO.getIds() == null) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
return conditionDTO.getIds().stream().map(this.taskStoreService::get).filter(Objects::nonNull).toList();
|
||||
public String listByIds(@RequestBody TaskConditionDTO conditionDTO) {
|
||||
String url = "mj/task/list-by-condition";
|
||||
String conditionJson = JSONUtil.toJsonStr(conditionDTO);
|
||||
Request request = mjOkHttpUtil.createPostRequest(url,conditionJson);
|
||||
return mjOkHttpUtil.executeRequest(request);
|
||||
}
|
||||
|
||||
@ApiOperation(value = "获取任务图片的seed")
|
||||
@GetMapping("/{id}/image-seed")
|
||||
public String getSeed(@ApiParam(value = "任务ID") @PathVariable String id) {
|
||||
String url = "mj/task/" + id + "/image-seed";
|
||||
Request request = mjOkHttpUtil.createGetRequest(url);
|
||||
return mjOkHttpUtil.executeRequest(request);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
package com.xmzs.midjourney.domain;
|
||||
|
||||
import lombok.Data;
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
/**
|
||||
* 绘画费用信息
|
||||
*
|
||||
* @author Admin
|
||||
*/
|
||||
@Data
|
||||
@Component
|
||||
@ConfigurationProperties(prefix = "mj")
|
||||
public class MjPriceConfig {
|
||||
/**
|
||||
* 放大图像
|
||||
*/
|
||||
private String upsample;
|
||||
|
||||
/**
|
||||
* 变化
|
||||
*/
|
||||
private String change;
|
||||
|
||||
/**
|
||||
* 图生图
|
||||
*/
|
||||
private String blend;
|
||||
|
||||
/**
|
||||
* 图生文
|
||||
*/
|
||||
private String describe;
|
||||
|
||||
/**
|
||||
* 文生图
|
||||
*/
|
||||
private String imagine;
|
||||
|
||||
/**
|
||||
* 局部重绘
|
||||
*/
|
||||
private String inpaint;
|
||||
|
||||
/**
|
||||
* 提示词分析
|
||||
*/
|
||||
private String shorten;
|
||||
|
||||
/**
|
||||
* 换脸
|
||||
*/
|
||||
private String faceSwapping;
|
||||
|
||||
}
|
||||
@@ -1,10 +1,7 @@
|
||||
package com.xmzs.midjourney.dto;
|
||||
|
||||
import com.xmzs.midjourney.enums.TaskAction;
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
|
||||
|
||||
@Data
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
package com.xmzs.midjourney.dto;
|
||||
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@ApiModel("局部重绘提交参数")
|
||||
public class SubmitModalDTO extends BaseSubmitDTO{
|
||||
|
||||
private String maskBase64;
|
||||
|
||||
private String taskId;
|
||||
|
||||
private String prompt;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package com.xmzs.midjourney.dto;
|
||||
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@ApiModel("prompt分析提交参数")
|
||||
public class SubmitShortenDTO extends BaseSubmitDTO{
|
||||
|
||||
private String botType;
|
||||
|
||||
private String prompt;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package com.xmzs.midjourney.enums;
|
||||
|
||||
import lombok.Getter;
|
||||
|
||||
/**
|
||||
* @author WangLe
|
||||
*/
|
||||
@Getter
|
||||
public enum ActionType {
|
||||
IN_PAINT("Inpaint"), // 局部重绘操作
|
||||
RE_ROLL("reroll"), // 重绘操作
|
||||
UP_SAMPLE("upsample"), // 放大操作
|
||||
ZOOM("zoom"), // 变焦操作
|
||||
UPSCALE("upscale"), // 高清放大操作
|
||||
VARIATION("variation"); // 变化操作
|
||||
|
||||
private final String action;
|
||||
|
||||
ActionType(String action) {
|
||||
this.action = action;
|
||||
}
|
||||
|
||||
public static ActionType fromCustomId(String customId) {
|
||||
for (ActionType type : values()) {
|
||||
if (type.getAction().equals(customId)) {
|
||||
return type;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package com.xmzs.midjourney.enums;
|
||||
|
||||
|
||||
import lombok.Getter;
|
||||
|
||||
@Getter
|
||||
public enum BlendDimensions {
|
||||
|
||||
PORTRAIT("2:3"),
|
||||
@@ -15,7 +18,4 @@ public enum BlendDimensions {
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
public String getValue() {
|
||||
return this.value;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
package com.xmzs.midjourney.util;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.*;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* @author WangLe
|
||||
*/
|
||||
@RequiredArgsConstructor
|
||||
@Component
|
||||
@Slf4j
|
||||
public class MjOkHttpUtil {
|
||||
|
||||
@Value("${chat.apiKey}")
|
||||
private List<String> apiKey;
|
||||
@Value("${chat.apiHost}")
|
||||
private String apiHost;
|
||||
|
||||
private static final String API_SECRET_HEADER = "mj-api-secret";
|
||||
|
||||
private final OkHttpClient client = new OkHttpClient.Builder()
|
||||
.connectTimeout(300, TimeUnit.SECONDS)
|
||||
.writeTimeout(300, TimeUnit.SECONDS)
|
||||
.readTimeout(300, TimeUnit.SECONDS)
|
||||
.build();
|
||||
|
||||
public String executeRequest(Request request) {
|
||||
try (Response response = client.newCall(request).execute()) {
|
||||
if (!response.isSuccessful()) {
|
||||
throw new IOException("Unexpected code " + response);
|
||||
}
|
||||
return response.body() != null ? response.body().string() : null;
|
||||
} catch (IOException e) {
|
||||
// 这里应根据实际情况使用适当的日志记录方式
|
||||
log.error("请求失败: {}",e.getMessage());
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public Request createPostRequest(String url, String json) {
|
||||
MediaType JSON = MediaType.get("application/json; charset=utf-8");
|
||||
RequestBody body = RequestBody.create(json, JSON);
|
||||
return new Request.Builder()
|
||||
.url(apiHost + url)
|
||||
.post(body)
|
||||
.header(API_SECRET_HEADER, apiKey.get(0))
|
||||
.build();
|
||||
}
|
||||
|
||||
public Request createGetRequest(String url) {
|
||||
return new Request.Builder()
|
||||
.url(apiHost + url)
|
||||
.header(API_SECRET_HEADER, apiKey.get(0))
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
package com.xmzs.system.controller.system;
|
||||
|
||||
import cn.dev33.satoken.annotation.SaCheckPermission;
|
||||
import com.xmzs.common.core.domain.R;
|
||||
import com.xmzs.common.core.validate.AddGroup;
|
||||
import com.xmzs.common.core.validate.EditGroup;
|
||||
import com.xmzs.common.excel.utils.ExcelUtil;
|
||||
import com.xmzs.common.idempotent.annotation.RepeatSubmit;
|
||||
import com.xmzs.common.log.annotation.Log;
|
||||
import com.xmzs.common.log.enums.BusinessType;
|
||||
import com.xmzs.common.mybatis.core.page.PageQuery;
|
||||
import com.xmzs.common.mybatis.core.page.TableDataInfo;
|
||||
import com.xmzs.common.web.core.BaseController;
|
||||
import com.xmzs.system.domain.bo.SysModelBo;
|
||||
import com.xmzs.system.domain.vo.SysModelVo;
|
||||
import com.xmzs.system.service.ISysModelService;
|
||||
import jakarta.servlet.http.HttpServletResponse;
|
||||
import jakarta.validation.constraints.NotEmpty;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.validation.annotation.Validated;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 系统模型
|
||||
*
|
||||
* @author Lion Li
|
||||
* @date 2024-04-04
|
||||
*/
|
||||
@Validated
|
||||
@RequiredArgsConstructor
|
||||
@RestController
|
||||
@RequestMapping("/system/model")
|
||||
public class SysModelController extends BaseController {
|
||||
|
||||
private final ISysModelService sysModelService;
|
||||
|
||||
/**
|
||||
* 查询系统模型列表
|
||||
*/
|
||||
@GetMapping("/list")
|
||||
public TableDataInfo<SysModelVo> list(SysModelBo bo, PageQuery pageQuery) {
|
||||
return sysModelService.queryPageList(bo, pageQuery);
|
||||
}
|
||||
|
||||
/**
|
||||
* 查询系统模型列表
|
||||
*/
|
||||
@GetMapping("/modelList")
|
||||
public R<List<SysModelVo>> modelList(SysModelBo bo) {
|
||||
bo.setModelShow("0");
|
||||
return R.ok(sysModelService.queryList(bo));
|
||||
}
|
||||
|
||||
/**
|
||||
* 导出系统模型列表
|
||||
*/
|
||||
@SaCheckPermission("system:model:export")
|
||||
@Log(title = "系统模型", businessType = BusinessType.EXPORT)
|
||||
@PostMapping("/export")
|
||||
public void export(SysModelBo bo, HttpServletResponse response) {
|
||||
List<SysModelVo> list = sysModelService.queryList(bo);
|
||||
ExcelUtil.exportExcel(list, "系统模型", SysModelVo.class, response);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取系统模型详细信息
|
||||
*
|
||||
* @param id 主键
|
||||
*/
|
||||
@SaCheckPermission("system:model:query")
|
||||
@GetMapping("/{id}")
|
||||
public R<SysModelVo> getInfo(@NotNull(message = "主键不能为空")
|
||||
@PathVariable Long id) {
|
||||
return R.ok(sysModelService.queryById(id));
|
||||
}
|
||||
|
||||
/**
|
||||
* 新增系统模型
|
||||
*/
|
||||
@SaCheckPermission("system:model:add")
|
||||
@Log(title = "系统模型", businessType = BusinessType.INSERT)
|
||||
@RepeatSubmit()
|
||||
@PostMapping()
|
||||
public R<Void> add(@Validated(AddGroup.class) @RequestBody SysModelBo bo) {
|
||||
return toAjax(sysModelService.insertByBo(bo));
|
||||
}
|
||||
|
||||
/**
|
||||
* 修改系统模型
|
||||
*/
|
||||
@SaCheckPermission("system:model:edit")
|
||||
@Log(title = "系统模型", businessType = BusinessType.UPDATE)
|
||||
@RepeatSubmit()
|
||||
@PutMapping()
|
||||
public R<Void> edit(@Validated(EditGroup.class) @RequestBody SysModelBo bo) {
|
||||
return toAjax(sysModelService.updateByBo(bo));
|
||||
}
|
||||
|
||||
/**
|
||||
* 删除系统模型
|
||||
*
|
||||
* @param ids 主键串
|
||||
*/
|
||||
@SaCheckPermission("system:model:remove")
|
||||
@Log(title = "系统模型", businessType = BusinessType.DELETE)
|
||||
@DeleteMapping("/{ids}")
|
||||
public R<Void> remove(@NotEmpty(message = "主键不能为空")
|
||||
@PathVariable Long[] ids) {
|
||||
return toAjax(sysModelService.deleteWithValidByIds(List.of(ids), true));
|
||||
}
|
||||
}
|
||||
@@ -17,7 +17,7 @@ import java.math.BigDecimal;
|
||||
*/
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@TableName("payment_orders")
|
||||
@TableName("sys_pay_order")
|
||||
public class PaymentOrders extends BaseEntity {
|
||||
|
||||
@Serial
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
package com.xmzs.system.domain;
|
||||
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import com.xmzs.common.mybatis.core.domain.BaseEntity;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
|
||||
import java.io.Serial;
|
||||
|
||||
/**
|
||||
* 系统模型对象 sys_model
|
||||
*
|
||||
* @author Lion Li
|
||||
* @date 2024-04-04
|
||||
*/
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@TableName("sys_model")
|
||||
public class SysModel extends BaseEntity {
|
||||
|
||||
@Serial
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
/**
|
||||
* 主键
|
||||
*/
|
||||
@TableId(value = "id")
|
||||
private Long id;
|
||||
|
||||
/**
|
||||
* 模型名称
|
||||
*/
|
||||
private String modelName;
|
||||
|
||||
/**
|
||||
* 模型描述
|
||||
*/
|
||||
private String modelDescribe;
|
||||
|
||||
/**
|
||||
* 模型价格
|
||||
*/
|
||||
private double modelPrice;
|
||||
|
||||
/**
|
||||
* 计费类型
|
||||
*/
|
||||
private String modelType;
|
||||
|
||||
/**
|
||||
* 是否显示
|
||||
*/
|
||||
private String modelShow;
|
||||
|
||||
|
||||
/**
|
||||
* 系统提示词
|
||||
*/
|
||||
private String systemPrompt;
|
||||
|
||||
/**
|
||||
* 备注
|
||||
*/
|
||||
private String remark;
|
||||
|
||||
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package com.xmzs.system.domain.bo;
|
||||
|
||||
import com.xmzs.common.core.validate.AddGroup;
|
||||
import com.xmzs.common.core.validate.EditGroup;
|
||||
import com.xmzs.common.mybatis.core.domain.BaseEntity;
|
||||
import com.xmzs.system.domain.SysModel;
|
||||
import io.github.linpeilie.annotations.AutoMapper;
|
||||
import jakarta.validation.constraints.NotBlank;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
|
||||
/**
|
||||
* 系统模型业务对象 sys_model
|
||||
*
|
||||
* @author Lion Li
|
||||
* @date 2024-04-04
|
||||
*/
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@AutoMapper(target = SysModel.class, reverseConvertGenerate = false)
|
||||
public class SysModelBo extends BaseEntity {
|
||||
|
||||
/**
|
||||
* 主键
|
||||
*/
|
||||
@NotNull(message = "主键不能为空", groups = { EditGroup.class })
|
||||
private Long id;
|
||||
|
||||
/**
|
||||
* 模型名称
|
||||
*/
|
||||
@NotBlank(message = "模型名称不能为空", groups = { AddGroup.class, EditGroup.class })
|
||||
private String modelName;
|
||||
|
||||
|
||||
/**
|
||||
* 模型描述
|
||||
*/
|
||||
@NotBlank(message = "模型描述不能为空", groups = { AddGroup.class, EditGroup.class })
|
||||
private String modelDescribe;
|
||||
|
||||
/**
|
||||
* 模型价格
|
||||
*/
|
||||
@NotNull(message = "模型价格不能为空", groups = { AddGroup.class, EditGroup.class })
|
||||
private double modelPrice;
|
||||
|
||||
/**
|
||||
* 计费类型 (1 token扣费; 2 次数扣费 )
|
||||
*/
|
||||
@NotBlank(message = "计费类型不能为空", groups = { AddGroup.class, EditGroup.class })
|
||||
private String modelType;
|
||||
|
||||
/**
|
||||
* 模型状态 (0 显示; 1 隐藏 )
|
||||
*/
|
||||
private String modelShow;
|
||||
|
||||
|
||||
/**
|
||||
* 系统提示词
|
||||
*/
|
||||
private String systemPrompt;
|
||||
|
||||
/**
|
||||
* 备注
|
||||
*/
|
||||
@NotBlank(message = "备注不能为空", groups = { AddGroup.class, EditGroup.class })
|
||||
private String remark;
|
||||
|
||||
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package com.xmzs.system.domain.vo;
|
||||
|
||||
import com.alibaba.excel.annotation.ExcelIgnoreUnannotated;
|
||||
import com.alibaba.excel.annotation.ExcelProperty;
|
||||
import com.xmzs.system.domain.SysModel;
|
||||
import io.github.linpeilie.annotations.AutoMapper;
|
||||
import lombok.Data;
|
||||
|
||||
import java.io.Serial;
|
||||
import java.io.Serializable;
|
||||
|
||||
|
||||
/**
|
||||
* 系统模型视图对象 sys_model
|
||||
*
|
||||
* @author Lion Li
|
||||
* @date 2024-04-04
|
||||
*/
|
||||
@Data
|
||||
@ExcelIgnoreUnannotated
|
||||
@AutoMapper(target = SysModel.class)
|
||||
public class SysModelVo implements Serializable {
|
||||
|
||||
@Serial
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
/**
|
||||
* 主键
|
||||
*/
|
||||
@ExcelProperty(value = "主键")
|
||||
private Long id;
|
||||
|
||||
/**
|
||||
* 模型名称
|
||||
*/
|
||||
@ExcelProperty(value = "模型名称")
|
||||
private String modelName;
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* 模型描述
|
||||
*/
|
||||
@ExcelProperty(value = "模型描述")
|
||||
private String modelDescribe;
|
||||
|
||||
/**
|
||||
* 模型价格
|
||||
*/
|
||||
@ExcelProperty(value = "模型价格")
|
||||
private double modelPrice;
|
||||
|
||||
/**
|
||||
* 计费类型
|
||||
*/
|
||||
@ExcelProperty(value = "计费类型")
|
||||
private String modelType;
|
||||
|
||||
/**
|
||||
* 是否显示
|
||||
*/
|
||||
private String modelShow;
|
||||
|
||||
|
||||
/**
|
||||
* 系统提示词
|
||||
*/
|
||||
private String systemPrompt;
|
||||
|
||||
/**
|
||||
* 备注
|
||||
*/
|
||||
@ExcelProperty(value = "备注")
|
||||
private String remark;
|
||||
|
||||
}
|
||||
@@ -4,14 +4,16 @@ package com.xmzs.system.listener;
|
||||
import cn.hutool.core.collection.CollectionUtil;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.xmzs.common.chat.config.LocalCache;
|
||||
import com.xmzs.common.chat.entity.chat.ChatCompletion;
|
||||
import com.xmzs.common.chat.entity.chat.ChatCompletionResponse;
|
||||
import com.xmzs.common.chat.utils.TikTokensUtil;
|
||||
import com.xmzs.common.core.utils.SpringUtils;
|
||||
import com.xmzs.common.core.utils.StringUtils;
|
||||
import com.xmzs.system.domain.bo.ChatMessageBo;
|
||||
import com.xmzs.system.domain.bo.SysModelBo;
|
||||
import com.xmzs.system.domain.vo.SysModelVo;
|
||||
import com.xmzs.system.service.IChatMessageService;
|
||||
import com.xmzs.system.service.IChatService;
|
||||
import com.xmzs.system.service.IChatCostService;
|
||||
import com.xmzs.system.service.ISysModelService;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -24,6 +26,7 @@ import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
@@ -45,7 +48,7 @@ public class SSEEventSourceListener extends EventSourceListener {
|
||||
public SSEEventSourceListener(ResponseBodyEmitter emitter) {
|
||||
this.emitter = emitter;
|
||||
}
|
||||
|
||||
private static final ISysModelService sysModelService = SpringUtils.getBean(ISysModelService.class);
|
||||
private String modelName;
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
@@ -66,34 +69,34 @@ public class SSEEventSourceListener extends EventSourceListener {
|
||||
//成功响应
|
||||
emitter.complete();
|
||||
if(StringUtils.isNotEmpty(modelName)){
|
||||
IChatService IChatService = SpringUtils.context().getBean(IChatService.class);
|
||||
IChatCostService IChatCostService = SpringUtils.context().getBean(IChatCostService.class);
|
||||
IChatMessageService chatMessageService = SpringUtils.context().getBean(IChatMessageService.class);
|
||||
ChatMessageBo chatMessageBo = new ChatMessageBo();
|
||||
chatMessageBo.setModelName(modelName);
|
||||
chatMessageBo.setContent(stringBuffer.toString());
|
||||
Long userId = (Long)LocalCache.CACHE.get("userId");
|
||||
chatMessageBo.setUserId(userId);
|
||||
if(ChatCompletion.Model.GPT_4_ALL.getName().equals(modelName)
|
||||
|| modelName.startsWith(ChatCompletion.Model.GPT_4_GIZMO.getName())
|
||||
|| modelName.startsWith(ChatCompletion.Model.NET.getName())
|
||||
|| ChatCompletion.Model.GPT_4_VISION_PREVIEW.getName().equals(modelName)
|
||||
|| ChatCompletion.Model.CLAUDE_3_SONNET.getName().equals(modelName)
|
||||
|| ChatCompletion.Model.STABLE_DIFFUSION.getName().equals(modelName)
|
||||
|| ChatCompletion.Model.SUNO_V3.getName().equals(modelName)
|
||||
){
|
||||
chatMessageBo.setDeductCost(0.0);
|
||||
chatMessageBo.setTotalTokens(0);
|
||||
|
||||
//查询按次数扣费的模型
|
||||
SysModelBo sysModelBo = new SysModelBo();
|
||||
sysModelBo.setModelType("2");
|
||||
sysModelBo.setModelName(modelName);
|
||||
List<SysModelVo> sysModelList = sysModelService.queryList(sysModelBo);
|
||||
if (CollectionUtil.isNotEmpty(sysModelList)){
|
||||
chatMessageBo.setDeductCost(0d);
|
||||
chatMessageBo.setRemark("提问时扣费");
|
||||
// 保存消息记录
|
||||
chatMessageService.insertByBo(chatMessageBo);
|
||||
}else {
|
||||
// 扣除余额
|
||||
}else{
|
||||
int tokens = TikTokensUtil.tokens(modelName,stringBuffer.toString());
|
||||
chatMessageBo.setTotalTokens(tokens);
|
||||
IChatService.deductToken(chatMessageBo);
|
||||
// 按token扣费并且保存消息记录
|
||||
IChatCostService.deductToken(chatMessageBo);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
// 解析返回内容
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
ChatCompletionResponse completionResponse = mapper.readValue(data, ChatCompletionResponse.class);
|
||||
if(completionResponse == null || CollectionUtil.isEmpty(completionResponse.getChoices())){
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
package com.xmzs.system.mapper;
|
||||
|
||||
import com.xmzs.common.mybatis.core.mapper.BaseMapperPlus;
|
||||
import com.xmzs.system.domain.SysModel;
|
||||
import com.xmzs.system.domain.vo.SysModelVo;
|
||||
|
||||
/**
|
||||
* 系统模型Mapper接口
|
||||
*
|
||||
* @author Lion Li
|
||||
* @date 2024-04-04
|
||||
*/
|
||||
public interface SysModelMapper extends BaseMapperPlus<SysModel, SysModelVo> {
|
||||
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package com.xmzs.system.service;
|
||||
|
||||
import com.xmzs.system.domain.bo.ChatMessageBo;
|
||||
|
||||
public interface IChatCostService {
|
||||
|
||||
/**
|
||||
* 根据消耗的tokens扣除余额
|
||||
*
|
||||
* @param chatMessageBo
|
||||
* @return 结果
|
||||
*/
|
||||
|
||||
void deductToken(ChatMessageBo chatMessageBo);
|
||||
|
||||
/**
|
||||
* 扣除用户的余额
|
||||
*
|
||||
*/
|
||||
void deductUserBalance(Long userId, Double numberCost);
|
||||
|
||||
|
||||
/**
|
||||
* 扣除任务费用并且保存记录
|
||||
*
|
||||
* @param type 任务类型
|
||||
* @param prompt 任务描述
|
||||
* @param cost 扣除费用
|
||||
*/
|
||||
void taskDeduct(String type,String prompt, double cost);
|
||||
|
||||
|
||||
/**
|
||||
* 判断用户是否付费
|
||||
*/
|
||||
void checkUserGrade();
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package com.xmzs.system.service;
|
||||
|
||||
import com.xmzs.common.mybatis.core.page.PageQuery;
|
||||
import com.xmzs.common.mybatis.core.page.TableDataInfo;
|
||||
import com.xmzs.system.domain.bo.SysModelBo;
|
||||
import com.xmzs.system.domain.vo.SysModelVo;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 系统模型Service接口
|
||||
*
|
||||
* @author Lion Li
|
||||
* @date 2024-04-04
|
||||
*/
|
||||
public interface ISysModelService {
|
||||
|
||||
/**
|
||||
* 查询系统模型
|
||||
*/
|
||||
SysModelVo queryById(Long id);
|
||||
|
||||
/**
|
||||
* 查询系统模型列表
|
||||
*/
|
||||
TableDataInfo<SysModelVo> queryPageList(SysModelBo bo, PageQuery pageQuery);
|
||||
|
||||
/**
|
||||
* 查询系统模型列表
|
||||
*/
|
||||
List<SysModelVo> queryList(SysModelBo bo);
|
||||
|
||||
/**
|
||||
* 新增系统模型
|
||||
*/
|
||||
Boolean insertByBo(SysModelBo bo);
|
||||
|
||||
/**
|
||||
* 修改系统模型
|
||||
*/
|
||||
Boolean updateByBo(SysModelBo bo);
|
||||
|
||||
/**
|
||||
* 校验并批量删除系统模型信息
|
||||
*/
|
||||
Boolean deleteWithValidByIds(Collection<Long> ids, Boolean isValid);
|
||||
}
|
||||
@@ -0,0 +1,160 @@
|
||||
package com.xmzs.system.service.impl;
|
||||
|
||||
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
|
||||
import com.xmzs.common.core.domain.model.LoginUser;
|
||||
import com.xmzs.common.core.exception.ServiceException;
|
||||
import com.xmzs.common.core.exception.base.BaseException;
|
||||
import com.xmzs.common.satoken.utils.LoginHelper;
|
||||
import com.xmzs.system.domain.ChatToken;
|
||||
import com.xmzs.system.domain.SysUser;
|
||||
import com.xmzs.system.domain.bo.ChatMessageBo;
|
||||
import com.xmzs.system.domain.bo.SysModelBo;
|
||||
import com.xmzs.system.domain.vo.SysModelVo;
|
||||
import com.xmzs.system.mapper.SysUserMapper;
|
||||
import com.xmzs.system.service.IChatCostService;
|
||||
import com.xmzs.system.service.IChatMessageService;
|
||||
import com.xmzs.system.service.IChatTokenService;
|
||||
import com.xmzs.system.service.ISysModelService;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* @author hncboy
|
||||
* @date 2023/3/22 19:41
|
||||
* 聊天相关业务实现类
|
||||
*/
|
||||
@Slf4j
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class ChatCostServiceImpl implements IChatCostService {
|
||||
|
||||
private final SysUserMapper sysUserMapper;
|
||||
|
||||
private final IChatMessageService chatMessageService;
|
||||
|
||||
private final IChatTokenService chatTokenService;
|
||||
|
||||
private final ISysModelService sysModelService;
|
||||
|
||||
/**
|
||||
* 根据消耗的tokens扣除余额
|
||||
*
|
||||
* @param chatMessageBo
|
||||
*/
|
||||
public void deductToken(ChatMessageBo chatMessageBo) {
|
||||
// 计算总token数
|
||||
ChatToken chatToken = chatTokenService.queryByUserId(chatMessageBo.getUserId(), chatMessageBo.getModelName());
|
||||
if (chatToken == null) {
|
||||
chatToken = new ChatToken();
|
||||
chatToken.setToken(0);
|
||||
}
|
||||
int totalTokens = chatToken.getToken() + chatMessageBo.getTotalTokens();
|
||||
// 如果总token数大于等于1000,进行费用扣除
|
||||
if (totalTokens >= 1000) {
|
||||
// 计算费用
|
||||
int token1 = totalTokens / 1000;
|
||||
int token2 = totalTokens % 1000;
|
||||
if (token2 > 0) {
|
||||
// 保存剩余tokens
|
||||
chatToken.setToken(token2);
|
||||
chatTokenService.editToken(chatToken);
|
||||
} else {
|
||||
chatTokenService.resetToken(chatMessageBo.getUserId(), chatMessageBo.getModelName());
|
||||
}
|
||||
// 扣除用户余额
|
||||
|
||||
SysModelBo sysModelBo = new SysModelBo();
|
||||
sysModelBo.setModelName(chatMessageBo.getModelName());
|
||||
List<SysModelVo> sysModelList = sysModelService.queryList(sysModelBo);
|
||||
double modelPrice = sysModelList.get(0).getModelPrice();
|
||||
Double numberCost = token1 * modelPrice;
|
||||
deductUserBalance(chatMessageBo.getUserId(), numberCost);
|
||||
chatMessageBo.setDeductCost(numberCost);
|
||||
} else {
|
||||
// 扣除用户余额
|
||||
deductUserBalance(chatMessageBo.getUserId(), 0.0);
|
||||
chatMessageBo.setDeductCost(0d);
|
||||
chatMessageBo.setRemark("不满1kToken,计入下一次!");
|
||||
chatToken.setToken(totalTokens);
|
||||
chatToken.setModelName(chatMessageBo.getModelName());
|
||||
chatToken.setUserId(chatMessageBo.getUserId());
|
||||
chatTokenService.editToken(chatToken);
|
||||
}
|
||||
// 保存消息记录
|
||||
chatMessageService.insertByBo(chatMessageBo);
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* 从用户余额中扣除费用
|
||||
*
|
||||
* @param userId 用户ID
|
||||
* @param numberCost 要扣除的费用
|
||||
*/
|
||||
@Override
|
||||
public void deductUserBalance(Long userId, Double numberCost) {
|
||||
SysUser sysUser = sysUserMapper.selectById(userId);
|
||||
if (sysUser == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
Double userBalance = sysUser.getUserBalance();
|
||||
if (userBalance < numberCost || userBalance == 0) {
|
||||
throw new ServiceException("余额不足,请联系管理员充值!");
|
||||
}
|
||||
sysUserMapper.update(null,
|
||||
new LambdaUpdateWrapper<SysUser>()
|
||||
.set(SysUser::getUserBalance, Math.max(userBalance - numberCost, 0))
|
||||
.eq(SysUser::getUserId, userId));
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* 扣除任务费用
|
||||
*
|
||||
*/
|
||||
@Override
|
||||
public void taskDeduct(String type,String prompt, double cost) {
|
||||
// 判断用户是否付费
|
||||
checkUserGrade();
|
||||
// 扣除费用
|
||||
deductUserBalance(getUserId(), cost);
|
||||
// 保存消息记录
|
||||
ChatMessageBo chatMessageBo = new ChatMessageBo();
|
||||
chatMessageBo.setUserId(getUserId());
|
||||
chatMessageBo.setModelName(type);
|
||||
chatMessageBo.setContent(prompt);
|
||||
chatMessageBo.setDeductCost(cost);
|
||||
chatMessageBo.setTotalTokens(0);
|
||||
chatMessageService.insertByBo(chatMessageBo);
|
||||
}
|
||||
|
||||
/**
|
||||
* 判断用户是否付费
|
||||
*/
|
||||
@Override
|
||||
public void checkUserGrade() {
|
||||
SysUser sysUser = sysUserMapper.selectById(getUserId());
|
||||
if("0".equals(sysUser.getUserGrade())){
|
||||
throw new BaseException("免费用户暂时不支持此模型,请切换gpt-3.5-turbo模型或者点击《进入市场选购您的商品》充值后使用!");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取用户Id
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public Long getUserId() {
|
||||
LoginUser loginUser = LoginHelper.getLoginUser();
|
||||
if (loginUser == null) {
|
||||
throw new BaseException("用户未登录!");
|
||||
}
|
||||
return loginUser.getUserId();
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.xmzs.system.service.impl;
|
||||
|
||||
import cn.hutool.core.collection.CollectionUtil;
|
||||
import com.xmzs.common.chat.config.LocalCache;
|
||||
import com.xmzs.common.chat.constant.OpenAIConst;
|
||||
import com.xmzs.common.chat.domain.request.ChatRequest;
|
||||
@@ -15,14 +16,18 @@ import com.xmzs.common.chat.openai.OpenAiStreamClient;
|
||||
import com.xmzs.common.chat.utils.TikTokensUtil;
|
||||
import com.xmzs.common.core.domain.model.LoginUser;
|
||||
import com.xmzs.common.core.exception.base.BaseException;
|
||||
import com.xmzs.common.core.utils.StringUtils;
|
||||
import com.xmzs.common.satoken.utils.LoginHelper;
|
||||
import com.xmzs.system.domain.SysUser;
|
||||
import com.xmzs.system.domain.bo.ChatMessageBo;
|
||||
import com.xmzs.system.domain.bo.SysModelBo;
|
||||
import com.xmzs.system.domain.vo.SysModelVo;
|
||||
import com.xmzs.system.listener.SSEEventSourceListener;
|
||||
import com.xmzs.system.mapper.SysUserMapper;
|
||||
import com.xmzs.system.service.IChatCostService;
|
||||
import com.xmzs.system.service.IChatMessageService;
|
||||
import com.xmzs.system.service.IChatService;
|
||||
import com.xmzs.system.service.ISseService;
|
||||
import com.xmzs.system.service.ISysModelService;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.ResponseBody;
|
||||
@@ -58,13 +63,14 @@ public class SseServiceImpl implements ISseService {
|
||||
|
||||
private final OpenAiStreamClient openAiStreamClient;
|
||||
|
||||
|
||||
private final IChatService IChatService;
|
||||
private final IChatCostService chatService;
|
||||
|
||||
private final SysUserMapper sysUserMapper;
|
||||
|
||||
private final IChatMessageService chatMessageService;
|
||||
|
||||
private final ISysModelService sysModelService;
|
||||
|
||||
@Value("${chat.apiKey}")
|
||||
private String apiKey;
|
||||
@Value("${chat.apiHost}")
|
||||
@@ -91,33 +97,35 @@ public class SseServiceImpl implements ISseService {
|
||||
// 判断用户是否付费
|
||||
checkUserGrade();
|
||||
}
|
||||
// 按次数扣费
|
||||
if(ChatCompletion.Model.GPT_4_ALL.getName().equals(chatRequest.getModel())
|
||||
|| chatRequest.getModel().startsWith(ChatCompletion.Model.GPT_4_GIZMO.getName())
|
||||
|| chatRequest.getModel().startsWith(ChatCompletion.Model.NET.getName())
|
||||
|| ChatCompletion.Model.GPT_4_VISION_PREVIEW.getName().equals(chatRequest.getModel())
|
||||
|| ChatCompletion.Model.CLAUDE_3_SONNET.getName().equals(chatRequest.getModel())
|
||||
|| ChatCompletion.Model.STABLE_DIFFUSION.getName().equals(chatRequest.getModel())
|
||||
|| ChatCompletion.Model.SUNO_V3.getName().equals(chatRequest.getModel())
|
||||
){
|
||||
double cost = OpenAIConst.GPT4_COST;
|
||||
if(ChatCompletion.Model.STABLE_DIFFUSION.getName().equals(chatRequest.getModel())){
|
||||
cost = 0.1;
|
||||
//根据模型名称查询模型信息
|
||||
SysModelBo sysModelBo = new SysModelBo();
|
||||
sysModelBo.setModelName(chatRequest.getModel());
|
||||
List<SysModelVo> sysModelList = sysModelService.queryList(sysModelBo);
|
||||
if (CollectionUtil.isEmpty(sysModelList)) {
|
||||
// 如果模型不存在默认使用token扣费方式
|
||||
processByToken(chatRequest.getModel(), msgList, chatMessageBo);
|
||||
} else {
|
||||
// 模型设置默认提示词
|
||||
SysModelVo firstModel = sysModelList.get(0);
|
||||
if (StringUtils.isNotEmpty(firstModel.getSystemPrompt())) {
|
||||
Message sysMessage = Message.builder().content(firstModel.getSystemPrompt()).role(Message.Role.SYSTEM).build();
|
||||
// 假设 msgList 不为空并且至少有一个元素
|
||||
if (msgList.get(0).equals(sysMessage)) {
|
||||
// 如果第一个元素与sysMessage相等,替换第一个元素
|
||||
msgList.set(0, sysMessage);
|
||||
} else {
|
||||
// 如果不相等,将sysMessage插入到列表的第一个位置
|
||||
msgList.add(0, sysMessage);
|
||||
}
|
||||
}
|
||||
if(ChatCompletion.Model.SUNO_V3.getName().equals(chatRequest.getModel())){
|
||||
cost = 0.5;
|
||||
// 计费类型: 1 token扣费 2 次数扣费
|
||||
if ("2".equals(firstModel.getModelType())) {
|
||||
processByModelPrice(firstModel, chatMessageBo);
|
||||
} else {
|
||||
processByToken(chatRequest.getModel(), msgList, chatMessageBo);
|
||||
}
|
||||
IChatService.deductUserBalance(getUserId(), cost);
|
||||
chatMessageBo.setDeductCost(cost);
|
||||
// 保存消息记录
|
||||
chatMessageService.insertByBo(chatMessageBo);
|
||||
}else {
|
||||
int tokens = TikTokensUtil.tokens(chatRequest.getModel(), msgList);
|
||||
chatMessageBo.setTotalTokens(tokens);
|
||||
// 按token扣费并且保存消息记录
|
||||
IChatService.deductToken(chatMessageBo);
|
||||
}
|
||||
}catch (Exception e){
|
||||
} catch (Exception e) {
|
||||
sendErrorEvent(sseEmitter, e.getMessage());
|
||||
return sseEmitter;
|
||||
}
|
||||
@@ -147,6 +155,32 @@ public class SseServiceImpl implements ISseService {
|
||||
return sseEmitter;
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据次数扣除余额
|
||||
*
|
||||
* @param model 模型信息
|
||||
* @param chatMessageBo 对话信息
|
||||
*/
|
||||
private void processByModelPrice(SysModelVo model, ChatMessageBo chatMessageBo) {
|
||||
double cost = model.getModelPrice();
|
||||
chatService.deductUserBalance(getUserId(), cost);
|
||||
chatMessageBo.setDeductCost(cost);
|
||||
chatMessageService.insertByBo(chatMessageBo);
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据token扣除余额
|
||||
*
|
||||
* @param modelName 模型名称
|
||||
* @param msgList 消息列表
|
||||
* @param chatMessageBo 消息记录
|
||||
*/
|
||||
private void processByToken(String modelName, List<Message> msgList, ChatMessageBo chatMessageBo) {
|
||||
int tokens = TikTokensUtil.tokens(modelName, msgList);
|
||||
chatMessageBo.setTotalTokens(tokens);
|
||||
chatService.deductToken(chatMessageBo);
|
||||
}
|
||||
|
||||
/**
|
||||
* 文字转语音
|
||||
*
|
||||
@@ -225,9 +259,9 @@ public class SseServiceImpl implements ISseService {
|
||||
|
||||
// 扣除费用
|
||||
if(Objects.equals(request.getSize(), "1792x1024") || Objects.equals(request.getSize(), "1024x1792")){
|
||||
IChatService.deductUserBalance(getUserId(),OpenAIConst.DALL3_HD_COST);
|
||||
chatService.deductUserBalance(getUserId(),OpenAIConst.DALL3_HD_COST);
|
||||
}else {
|
||||
IChatService.deductUserBalance(getUserId(),OpenAIConst.DALL3_COST);
|
||||
chatService.deductUserBalance(getUserId(),OpenAIConst.DALL3_COST);
|
||||
}
|
||||
// 保存消息记录
|
||||
ChatMessageBo chatMessageBo = new ChatMessageBo();
|
||||
|
||||
@@ -96,7 +96,7 @@ public class SysDictTypeServiceImpl implements ISysDictTypeService, DictService
|
||||
* @param dictType 字典类型
|
||||
* @return 字典数据集合信息
|
||||
*/
|
||||
@Cacheable(cacheNames = CacheNames.SYS_DICT, key = "#dictType")
|
||||
// @Cacheable(cacheNames = CacheNames.SYS_DICT, key = "#dictType")
|
||||
@Override
|
||||
public List<SysDictDataVo> selectDictDataByType(String dictType) {
|
||||
List<SysDictDataVo> dictDatas = dictDataMapper.selectDictDataByType(dictType);
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
package com.xmzs.system.service.impl;
|
||||
|
||||
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
||||
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
|
||||
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
||||
import com.xmzs.common.core.utils.MapstructUtils;
|
||||
import com.xmzs.common.core.utils.StringUtils;
|
||||
import com.xmzs.common.mybatis.core.page.PageQuery;
|
||||
import com.xmzs.common.mybatis.core.page.TableDataInfo;
|
||||
import com.xmzs.system.domain.SysModel;
|
||||
import com.xmzs.system.domain.bo.SysModelBo;
|
||||
import com.xmzs.system.domain.vo.SysModelVo;
|
||||
import com.xmzs.system.mapper.SysModelMapper;
|
||||
import com.xmzs.system.service.ISysModelService;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 系统模型Service业务层处理
|
||||
*
|
||||
* @author Lion Li
|
||||
* @date 2024-04-04
|
||||
*/
|
||||
@RequiredArgsConstructor
|
||||
@Service
|
||||
public class SysModelServiceImpl implements ISysModelService {
|
||||
|
||||
private final SysModelMapper baseMapper;
|
||||
|
||||
/**
|
||||
* 查询系统模型
|
||||
*/
|
||||
@Override
|
||||
public SysModelVo queryById(Long id){
|
||||
return baseMapper.selectVoById(id);
|
||||
}
|
||||
|
||||
/**
|
||||
* 查询系统模型列表
|
||||
*/
|
||||
@Override
|
||||
public TableDataInfo<SysModelVo> queryPageList(SysModelBo bo, PageQuery pageQuery) {
|
||||
LambdaQueryWrapper<SysModel> lqw = buildQueryWrapper(bo);
|
||||
Page<SysModelVo> result = baseMapper.selectVoPage(pageQuery.build(), lqw);
|
||||
return TableDataInfo.build(result);
|
||||
}
|
||||
|
||||
/**
|
||||
* 查询系统模型列表
|
||||
*/
|
||||
@Override
|
||||
public List<SysModelVo> queryList(SysModelBo bo) {
|
||||
LambdaQueryWrapper<SysModel> lqw = buildQueryWrapper(bo);
|
||||
return baseMapper.selectVoList(lqw);
|
||||
}
|
||||
|
||||
private LambdaQueryWrapper<SysModel> buildQueryWrapper(SysModelBo bo) {
|
||||
LambdaQueryWrapper<SysModel> lqw = Wrappers.lambdaQuery();
|
||||
lqw.like(StringUtils.isNotBlank(bo.getModelName()), SysModel::getModelName, bo.getModelName());
|
||||
lqw.like(StringUtils.isNotBlank(bo.getModelShow()), SysModel::getModelShow, bo.getModelShow());
|
||||
lqw.eq(StringUtils.isNotBlank(bo.getModelDescribe()), SysModel::getModelDescribe, bo.getModelDescribe());
|
||||
lqw.eq(StringUtils.isNotBlank(bo.getModelType()), SysModel::getModelType, bo.getModelType());
|
||||
return lqw;
|
||||
}
|
||||
|
||||
/**
|
||||
* 新增系统模型
|
||||
*/
|
||||
@Override
|
||||
public Boolean insertByBo(SysModelBo bo) {
|
||||
SysModel add = MapstructUtils.convert(bo, SysModel.class);
|
||||
validEntityBeforeSave(add);
|
||||
boolean flag = baseMapper.insert(add) > 0;
|
||||
if (flag) {
|
||||
bo.setId(add.getId());
|
||||
}
|
||||
return flag;
|
||||
}
|
||||
|
||||
/**
|
||||
* 修改系统模型
|
||||
*/
|
||||
@Override
|
||||
public Boolean updateByBo(SysModelBo bo) {
|
||||
SysModel update = MapstructUtils.convert(bo, SysModel.class);
|
||||
validEntityBeforeSave(update);
|
||||
return baseMapper.updateById(update) > 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* 保存前的数据校验
|
||||
*/
|
||||
private void validEntityBeforeSave(SysModel entity){
|
||||
//TODO 做一些数据校验,如唯一约束
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量删除系统模型
|
||||
*/
|
||||
@Override
|
||||
public Boolean deleteWithValidByIds(Collection<Long> ids, Boolean isValid) {
|
||||
if(isValid){
|
||||
//TODO 做一些业务上的校验,判断是否需要校验
|
||||
}
|
||||
return baseMapper.deleteBatchIds(ids) > 0;
|
||||
}
|
||||
}
|
||||
@@ -78,6 +78,7 @@ public class SysUserServiceImpl implements ISysUserService, UserService {
|
||||
QueryWrapper<SysUser> wrapper = Wrappers.query();
|
||||
wrapper.eq("u.del_flag", UserConstants.USER_NORMAL)
|
||||
.eq(ObjectUtil.isNotNull(user.getUserId()), "u.user_id", user.getUserId())
|
||||
.eq(ObjectUtil.isNotNull(user.getUserGrade()), "u.user_grade", user.getUserGrade())
|
||||
.like(StringUtils.isNotBlank(user.getUserName()), "u.user_name", user.getUserName())
|
||||
.eq(StringUtils.isNotBlank(user.getStatus()), "u.status", user.getStatus())
|
||||
.like(StringUtils.isNotBlank(user.getPhonenumber()), "u.phonenumber", user.getPhonenumber())
|
||||
@@ -324,9 +325,9 @@ public class SysUserServiceImpl implements ISysUserService, UserService {
|
||||
@Transactional(rollbackFor = Exception.class)
|
||||
public int updateUser(SysUserBo user) {
|
||||
// 新增用户与角色管理
|
||||
insertUserRole(user, true);
|
||||
//insertUserRole(user, true);
|
||||
// 新增用户与岗位管理
|
||||
insertUserPost(user, true);
|
||||
//insertUserPost(user, true);
|
||||
SysUser sysUser = MapstructUtils.convert(user, SysUser.class);
|
||||
// 防止错误更新后导致的数据误删除
|
||||
int flag = baseMapper.updateById(sysUser);
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8" ?>
|
||||
<!DOCTYPE mapper
|
||||
PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
|
||||
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
|
||||
<mapper namespace="com.xmzs.system.mapper.SysModelMapper">
|
||||
|
||||
</mapper>
|
||||
@@ -70,7 +70,7 @@
|
||||
</update>
|
||||
|
||||
<select id="selectPageUserList" resultMap="SysUserResult">
|
||||
select u.user_id, u.dept_id, u.nick_name, u.user_name, u.email, u.avatar, u.phonenumber, u.sex,
|
||||
select u.user_id, u.dept_id, u.nick_name, u.user_name, u.email, u.avatar, u.phonenumber, u.sex,u.user_balance,u.user_grade,
|
||||
u.status, u.del_flag, u.login_ip, u.login_date, u.create_by, u.create_time, u.remark, d.dept_name, d.leader
|
||||
from sys_user u
|
||||
left join sys_dept d on u.dept_id = d.dept_id
|
||||
|
||||
Reference in New Issue
Block a user