mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-04-06 16:27:32 +00:00
feat: 全局格式化代码
This commit is contained in:
@@ -23,33 +23,32 @@ import java.util.concurrent.TimeUnit;
|
||||
@RequiredArgsConstructor
|
||||
public class ChatConfig {
|
||||
|
||||
private final ConfigService configService;
|
||||
@Getter
|
||||
private OpenAiStreamClient openAiStreamClient;
|
||||
|
||||
private final ConfigService configService;
|
||||
|
||||
@Bean
|
||||
public OpenAiStreamClient openAiStreamClient() {
|
||||
String apiHost = configService.getConfigValue("chat", "apiHost");
|
||||
String apiKey = configService.getConfigValue("chat", "apiKey");
|
||||
openAiStreamClient = createOpenAiStreamClient(apiHost,apiKey);
|
||||
return openAiStreamClient;
|
||||
}
|
||||
|
||||
public static OpenAiStreamClient createOpenAiStreamClient(String apiHost, String apiKey) {
|
||||
HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor(new OpenAILogger());
|
||||
httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS);
|
||||
OkHttpClient okHttpClient = new OkHttpClient.Builder()
|
||||
.addInterceptor(httpLoggingInterceptor)
|
||||
.connectTimeout(30, TimeUnit.SECONDS)
|
||||
.writeTimeout(600, TimeUnit.SECONDS)
|
||||
.readTimeout(600, TimeUnit.SECONDS)
|
||||
.build();
|
||||
.addInterceptor(httpLoggingInterceptor)
|
||||
.connectTimeout(30, TimeUnit.SECONDS)
|
||||
.writeTimeout(600, TimeUnit.SECONDS)
|
||||
.readTimeout(600, TimeUnit.SECONDS)
|
||||
.build();
|
||||
return OpenAiStreamClient.builder()
|
||||
.apiHost(apiHost)
|
||||
.apiKey(Collections.singletonList(apiKey))
|
||||
.keyStrategy(new KeyRandomStrategy())
|
||||
.okHttpClient(okHttpClient)
|
||||
.build();
|
||||
.apiHost(apiHost)
|
||||
.apiKey(Collections.singletonList(apiKey))
|
||||
.keyStrategy(new KeyRandomStrategy())
|
||||
.okHttpClient(okHttpClient)
|
||||
.build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
public OpenAiStreamClient openAiStreamClient() {
|
||||
String apiHost = configService.getConfigValue("chat", "apiHost");
|
||||
String apiKey = configService.getConfigValue("chat", "apiKey");
|
||||
openAiStreamClient = createOpenAiStreamClient(apiHost, apiKey);
|
||||
return openAiStreamClient;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ public class OkHttpConfig {
|
||||
|
||||
private void initializeOkHttpUtil(String modelName) {
|
||||
ChatModelVo chatModelVo = chatModelService.selectModelByName(modelName);
|
||||
if(chatModelVo==null){
|
||||
if (chatModelVo == null) {
|
||||
return;
|
||||
}
|
||||
OkHttpUtil okHttpUtil = new OkHttpUtil();
|
||||
|
||||
@@ -68,7 +68,7 @@ public class ChatConfigController extends BaseController {
|
||||
@SaCheckPermission("system:config:query")
|
||||
@GetMapping("/{id}")
|
||||
public R<ChatConfigVo> getInfo(@NotNull(message = "主键不能为空")
|
||||
@PathVariable Long id) {
|
||||
@PathVariable Long id) {
|
||||
return R.ok(chatConfigService.queryById(id));
|
||||
}
|
||||
|
||||
@@ -81,9 +81,9 @@ public class ChatConfigController extends BaseController {
|
||||
@PostMapping("/saveOrUpdate")
|
||||
public R<Void> saveOrUpdate(@RequestBody List<ChatConfigBo> boList) {
|
||||
for (ChatConfigBo chatConfigBo : boList) {
|
||||
if(chatConfigBo.getId() == null){
|
||||
if (chatConfigBo.getId() == null) {
|
||||
chatConfigService.insertByBo(chatConfigBo);
|
||||
}else {
|
||||
} else {
|
||||
chatConfigService.updateByBo(chatConfigBo);
|
||||
}
|
||||
}
|
||||
@@ -121,12 +121,11 @@ public class ChatConfigController extends BaseController {
|
||||
*/
|
||||
@GetMapping(value = "/configKey/{configKey}")
|
||||
public R<String> getConfigKey(@PathVariable String configKey) {
|
||||
return R.ok(configService.getConfigValue("sys",configKey));
|
||||
return R.ok(configService.getConfigValue("sys", configKey));
|
||||
}
|
||||
|
||||
/**
|
||||
* 查询系统参数
|
||||
*
|
||||
*/
|
||||
@GetMapping(value = "/sysConfigKey")
|
||||
public R<List<ChatConfigVo>> getSysConfigKey() {
|
||||
|
||||
@@ -19,7 +19,7 @@ import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
|
||||
/**
|
||||
* 聊天管理
|
||||
* 聊天管理
|
||||
*
|
||||
* @author ageerle@163.com
|
||||
* @date 2023-03-01
|
||||
@@ -38,7 +38,7 @@ public class ChatController {
|
||||
@PostMapping("/send")
|
||||
@ResponseBody
|
||||
public SseEmitter sseChat(@RequestBody @Valid ChatRequest chatRequest, HttpServletRequest request) {
|
||||
return sseService.sseChat(chatRequest,request);
|
||||
return sseService.sseChat(chatRequest, request);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -74,12 +74,11 @@ public class ChatMessageController extends BaseController {
|
||||
*/
|
||||
@GetMapping("/{id}")
|
||||
public R<ChatMessageVo> getInfo(@NotNull(message = "主键不能为空")
|
||||
@PathVariable Long id) {
|
||||
@PathVariable Long id) {
|
||||
return R.ok(chatMessageService.queryById(id));
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* 查询聊天消息列表 uniapp
|
||||
*/
|
||||
|
||||
@@ -82,7 +82,7 @@ public class ChatModelController extends BaseController {
|
||||
*/
|
||||
@GetMapping("/{id}")
|
||||
public R<ChatModelVo> getInfo(@NotNull(message = "主键不能为空")
|
||||
@PathVariable Long id) {
|
||||
@PathVariable Long id) {
|
||||
return R.ok(chatModelService.queryById(id));
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package org.ruoyi.chat.controller.chat;
|
||||
|
||||
import cn.dev33.satoken.annotation.SaCheckPermission;
|
||||
import jakarta.servlet.http.HttpServletResponse;
|
||||
import jakarta.validation.constraints.NotEmpty;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
@@ -43,9 +42,9 @@ public class ChatSessionController extends BaseController {
|
||||
*/
|
||||
@GetMapping("/list")
|
||||
public TableDataInfo<ChatSessionVo> list(ChatSessionBo bo, PageQuery pageQuery) {
|
||||
if(!LoginHelper.isLogin()){
|
||||
// 如果用户没有登录,返回空会话列表
|
||||
return TableDataInfo.build();
|
||||
if (!LoginHelper.isLogin()) {
|
||||
// 如果用户没有登录,返回空会话列表
|
||||
return TableDataInfo.build();
|
||||
}
|
||||
// 默认查询当前用户会话
|
||||
bo.setUserId(LoginHelper.getUserId());
|
||||
@@ -69,7 +68,7 @@ public class ChatSessionController extends BaseController {
|
||||
*/
|
||||
@GetMapping("/{id}")
|
||||
public R<ChatSessionVo> getInfo(@NotNull(message = "主键不能为空")
|
||||
@PathVariable Long id) {
|
||||
@PathVariable Long id) {
|
||||
return R.ok(chatSessionService.queryById(id));
|
||||
}
|
||||
|
||||
|
||||
@@ -19,14 +19,7 @@ import org.ruoyi.domain.bo.PromptTemplateBo;
|
||||
import org.ruoyi.domain.vo.PromptTemplateVo;
|
||||
import org.ruoyi.service.IPromptTemplateService;
|
||||
import org.springframework.validation.annotation.Validated;
|
||||
import org.springframework.web.bind.annotation.DeleteMapping;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.PutMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package org.ruoyi.chat.controller.knowledge;
|
||||
|
||||
import cn.dev33.satoken.stp.StpUtil;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import jakarta.servlet.http.HttpServletResponse;
|
||||
import jakarta.validation.constraints.NotEmpty;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
@@ -30,7 +29,6 @@ import org.springframework.validation.annotation.Validated;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ public class FaceController {
|
||||
@PostMapping("/insight-face/swap")
|
||||
public String insightFace(@RequestBody InsightFace insightFace) {
|
||||
// 扣除接口费用并且保存消息记录
|
||||
chatCostService.taskDeduct("mj","Face Changing", 0.0);
|
||||
chatCostService.taskDeduct("mj", "Face Changing", 0.0);
|
||||
// 创建请求体(这里使用JSON作为媒体类型)
|
||||
String insightFaceJson = JSONUtil.toJsonStr(insightFace);
|
||||
String url = "mj/insight-face/swap";
|
||||
|
||||
@@ -47,20 +47,20 @@ public class SubmitController {
|
||||
public String action(@RequestBody SubmitActionDTO changeDTO) {
|
||||
ActionType actionType = ActionType.fromCustomId(getAction(changeDTO.getCustomId()));
|
||||
Optional.ofNullable(actionType).ifPresentOrElse(
|
||||
type -> {
|
||||
switch (type) {
|
||||
case UP_SAMPLE:
|
||||
chatCostService.taskDeduct("mj","enlarge", 0.0);
|
||||
break;
|
||||
case IN_PAINT:
|
||||
// 局部重绘已经扣费,不执行任何操作
|
||||
break;
|
||||
default:
|
||||
chatCostService.taskDeduct("mj","change", 0.0);
|
||||
break;
|
||||
}
|
||||
},
|
||||
() -> chatCostService.taskDeduct("mj","change", 0.0)
|
||||
type -> {
|
||||
switch (type) {
|
||||
case UP_SAMPLE:
|
||||
chatCostService.taskDeduct("mj", "enlarge", 0.0);
|
||||
break;
|
||||
case IN_PAINT:
|
||||
// 局部重绘已经扣费,不执行任何操作
|
||||
break;
|
||||
default:
|
||||
chatCostService.taskDeduct("mj", "change", 0.0);
|
||||
break;
|
||||
}
|
||||
},
|
||||
() -> chatCostService.taskDeduct("mj", "change", 0.0)
|
||||
);
|
||||
|
||||
String jsonStr = JSONUtil.toJsonStr(changeDTO);
|
||||
@@ -81,7 +81,7 @@ public class SubmitController {
|
||||
@Operation(summary = "提交图生图、混图任务")
|
||||
@PostMapping("/blend")
|
||||
public String blend(@RequestBody SubmitBlendDTO blendDTO) {
|
||||
chatCostService.taskDeduct("mj","blend", 0.0);
|
||||
chatCostService.taskDeduct("mj", "blend", 0.0);
|
||||
String jsonStr = JSONUtil.toJsonStr(blendDTO);
|
||||
String url = "mj/submit/blend";
|
||||
Request request = mjOkHttpUtil.createPostRequest(url, jsonStr);
|
||||
@@ -91,7 +91,7 @@ public class SubmitController {
|
||||
@Operation(summary = "提交图生文任务")
|
||||
@PostMapping("/describe")
|
||||
public String describe(@RequestBody SubmitDescribeDTO describeDTO) {
|
||||
chatCostService.taskDeduct("mj","describe",0.0);
|
||||
chatCostService.taskDeduct("mj", "describe", 0.0);
|
||||
String jsonStr = JSONUtil.toJsonStr(describeDTO);
|
||||
String url = "mj/submit/describe";
|
||||
Request request = mjOkHttpUtil.createPostRequest(url, jsonStr);
|
||||
@@ -101,7 +101,7 @@ public class SubmitController {
|
||||
@Operation(summary = "提交文生图任务")
|
||||
@PostMapping("/imagine")
|
||||
public String imagine(@RequestBody SubmitImagineDTO imagineDTO) {
|
||||
chatCostService.taskDeduct("mj",imagineDTO.getPrompt(), 0.0);
|
||||
chatCostService.taskDeduct("mj", imagineDTO.getPrompt(), 0.0);
|
||||
String jsonStr = JSONUtil.toJsonStr(imagineDTO);
|
||||
String url = "mj/submit/imagine";
|
||||
Request request = mjOkHttpUtil.createPostRequest(url, jsonStr);
|
||||
@@ -111,7 +111,7 @@ public class SubmitController {
|
||||
@Operation(summary = "提交局部重绘任务")
|
||||
@PostMapping("/modal")
|
||||
public String modal(@RequestBody SubmitModalDTO submitModalDTO) {
|
||||
chatCostService.taskDeduct("mj","repaint ", 0.0);
|
||||
chatCostService.taskDeduct("mj", "repaint ", 0.0);
|
||||
String jsonStr = JSONUtil.toJsonStr(submitModalDTO);
|
||||
String url = "mj/submit/modal";
|
||||
Request request = mjOkHttpUtil.createPostRequest(url, jsonStr);
|
||||
@@ -121,7 +121,7 @@ public class SubmitController {
|
||||
@Operation(summary = "提交提示词分析任务")
|
||||
@PostMapping("/shorten")
|
||||
public String shorten(@RequestBody SubmitShortenDTO submitShortenDTO) {
|
||||
chatCostService.taskDeduct("mj","shorten", 0.0);
|
||||
chatCostService.taskDeduct("mj", "shorten", 0.0);
|
||||
String jsonStr = JSONUtil.toJsonStr(submitShortenDTO);
|
||||
String url = "mj/submit/shorten";
|
||||
Request request = mjOkHttpUtil.createPostRequest(url, jsonStr);
|
||||
|
||||
@@ -34,7 +34,7 @@ public class SunoController {
|
||||
public String generate(@RequestBody GenerateSuno generateSuno) {
|
||||
OkHttpUtil okHttpUtil = okHttpConfig.getOkHttpUtil("suno");
|
||||
// 扣除接口费用并且保存消息记录
|
||||
chatCostService.taskDeduct("suno","文生歌曲", NumberUtils.toDouble(okHttpConfig.getGenerate(), 0.3));
|
||||
chatCostService.taskDeduct("suno", "文生歌曲", NumberUtils.toDouble(okHttpConfig.getGenerate(), 0.3));
|
||||
// 创建请求体(这里使用JSON作为媒体类型)
|
||||
String generateJson = JSONUtil.toJsonStr(generateSuno);
|
||||
String url = "suno/generate";
|
||||
@@ -57,7 +57,7 @@ public class SunoController {
|
||||
@GetMapping("/lyrics/{taskId}")
|
||||
public String lyrics(@PathVariable String taskId) {
|
||||
OkHttpUtil okHttpUtil = okHttpConfig.getOkHttpUtil("suno");
|
||||
String url = "task/suno/v1/fetch/"+taskId;
|
||||
String url = "task/suno/v1/fetch/" + taskId;
|
||||
Request request = okHttpUtil.createGetRequest(url);
|
||||
return okHttpUtil.executeRequest(request);
|
||||
}
|
||||
@@ -67,7 +67,7 @@ public class SunoController {
|
||||
@GetMapping("/feed/{taskId}")
|
||||
public String feed(@PathVariable String taskId) {
|
||||
OkHttpUtil okHttpUtil = okHttpConfig.getOkHttpUtil("suno");
|
||||
String url = "suno/feed/"+taskId;
|
||||
String url = "suno/feed/" + taskId;
|
||||
Request request = okHttpUtil.createGetRequest(url);
|
||||
return okHttpUtil.executeRequest(request);
|
||||
}
|
||||
|
||||
@@ -27,22 +27,22 @@ public class TaskController {
|
||||
|
||||
private final MjOkHttpUtil mjOkHttpUtil;
|
||||
|
||||
@Operation(summary = "指定ID获取任务")
|
||||
@GetMapping("/{id}/fetch")
|
||||
@Operation(summary = "指定ID获取任务")
|
||||
@GetMapping("/{id}/fetch")
|
||||
public String fetch(@Parameter(description = "任务ID") @PathVariable String id) {
|
||||
String url = "mj/task/" + id + "/fetch";
|
||||
Request request = mjOkHttpUtil.createGetRequest(url);
|
||||
return mjOkHttpUtil.executeRequest(request);
|
||||
}
|
||||
|
||||
@Operation(summary = "根据ID列表查询任务")
|
||||
@PostMapping("/list-by-condition")
|
||||
public String listByIds(@RequestBody TaskConditionDTO conditionDTO) {
|
||||
@Operation(summary = "根据ID列表查询任务")
|
||||
@PostMapping("/list-by-condition")
|
||||
public String listByIds(@RequestBody TaskConditionDTO conditionDTO) {
|
||||
String url = "mj/task/list-by-condition";
|
||||
String conditionJson = JSONUtil.toJsonStr(conditionDTO);
|
||||
Request request = mjOkHttpUtil.createPostRequest(url,conditionJson);
|
||||
Request request = mjOkHttpUtil.createPostRequest(url, conditionJson);
|
||||
return mjOkHttpUtil.executeRequest(request);
|
||||
}
|
||||
}
|
||||
|
||||
@Operation(summary = "获取任务图片的seed")
|
||||
@GetMapping("/{id}/image-seed")
|
||||
|
||||
@@ -12,61 +12,59 @@ import java.util.Map;
|
||||
|
||||
|
||||
public class DomainObject implements Serializable {
|
||||
@Getter
|
||||
@Setter
|
||||
@Schema(description = "ID")
|
||||
protected String id;
|
||||
@JsonIgnore
|
||||
private final transient Object lock = new Object();
|
||||
@Getter
|
||||
@Setter
|
||||
@Schema(description = "ID")
|
||||
protected String id;
|
||||
@Setter
|
||||
protected Map<String, Object> properties; // 扩展属性,仅支持基本类型
|
||||
|
||||
@Setter
|
||||
protected Map<String, Object> properties; // 扩展属性,仅支持基本类型
|
||||
public void sleep() throws InterruptedException {
|
||||
synchronized (this.lock) {
|
||||
this.lock.wait();
|
||||
}
|
||||
}
|
||||
|
||||
@JsonIgnore
|
||||
private final transient Object lock = new Object();
|
||||
public void awake() {
|
||||
synchronized (this.lock) {
|
||||
this.lock.notifyAll();
|
||||
}
|
||||
}
|
||||
|
||||
public void sleep() throws InterruptedException {
|
||||
synchronized (this.lock) {
|
||||
this.lock.wait();
|
||||
}
|
||||
}
|
||||
public DomainObject setProperty(String name, Object value) {
|
||||
getProperties().put(name, value);
|
||||
return this;
|
||||
}
|
||||
|
||||
public void awake() {
|
||||
synchronized (this.lock) {
|
||||
this.lock.notifyAll();
|
||||
}
|
||||
}
|
||||
public DomainObject removeProperty(String name) {
|
||||
getProperties().remove(name);
|
||||
return this;
|
||||
}
|
||||
|
||||
public DomainObject setProperty(String name, Object value) {
|
||||
getProperties().put(name, value);
|
||||
return this;
|
||||
}
|
||||
public Object getProperty(String name) {
|
||||
return getProperties().get(name);
|
||||
}
|
||||
|
||||
public DomainObject removeProperty(String name) {
|
||||
getProperties().remove(name);
|
||||
return this;
|
||||
}
|
||||
@SuppressWarnings("unchecked")
|
||||
public <T> T getPropertyGeneric(String name) {
|
||||
return (T) getProperty(name);
|
||||
}
|
||||
|
||||
public Object getProperty(String name) {
|
||||
return getProperties().get(name);
|
||||
}
|
||||
public <T> T getProperty(String name, Class<T> clz) {
|
||||
return getProperty(name, clz, null);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public <T> T getPropertyGeneric(String name) {
|
||||
return (T) getProperty(name);
|
||||
}
|
||||
public <T> T getProperty(String name, Class<T> clz, T defaultValue) {
|
||||
Object value = getProperty(name);
|
||||
return value == null ? defaultValue : clz.cast(value);
|
||||
}
|
||||
|
||||
public <T> T getProperty(String name, Class<T> clz) {
|
||||
return getProperty(name, clz, null);
|
||||
}
|
||||
|
||||
public <T> T getProperty(String name, Class<T> clz, T defaultValue) {
|
||||
Object value = getProperty(name);
|
||||
return value == null ? defaultValue : clz.cast(value);
|
||||
}
|
||||
|
||||
public Map<String, Object> getProperties() {
|
||||
if (this.properties == null) {
|
||||
this.properties = new HashMap<>();
|
||||
}
|
||||
return this.properties;
|
||||
}
|
||||
public Map<String, Object> getProperties() {
|
||||
if (this.properties == null) {
|
||||
this.properties = new HashMap<>();
|
||||
}
|
||||
return this.properties;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,11 +11,15 @@ import java.io.Serializable;
|
||||
@Data
|
||||
@Schema(name = "Discord账号")
|
||||
public class InsightFace implements Serializable {
|
||||
/**本人头像json*/
|
||||
/**
|
||||
* 本人头像json
|
||||
*/
|
||||
@Schema(description = "本人头像json")
|
||||
private String sourceBase64;
|
||||
|
||||
/**明星头像json*/
|
||||
/**
|
||||
* 明星头像json
|
||||
*/
|
||||
@Schema(description = "明星头像json")
|
||||
private String targetBase64;
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package org.ruoyi.chat.domain.bo;
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* 文生视频请求对象
|
||||
* 文生视频请求对象
|
||||
*
|
||||
* @author ageerle@163.com
|
||||
* date 2024/6/27
|
||||
|
||||
@@ -3,7 +3,7 @@ package org.ruoyi.chat.domain.bo;
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* 生成歌词
|
||||
* 生成歌词
|
||||
*
|
||||
* @author ageerle@163.com
|
||||
* date 2024/6/27
|
||||
|
||||
@@ -8,9 +8,9 @@ import lombok.Setter;
|
||||
@Setter
|
||||
public abstract class BaseSubmitDTO {
|
||||
|
||||
@Schema(description = "自定义参数")
|
||||
protected String state;
|
||||
@Schema(description = "自定义参数")
|
||||
protected String state;
|
||||
|
||||
@Schema(description = "回调地址, 为空时使用全局notifyHook")
|
||||
protected String notifyHook;
|
||||
@Schema(description = "回调地址, 为空时使用全局notifyHook")
|
||||
protected String notifyHook;
|
||||
}
|
||||
|
||||
@@ -8,9 +8,9 @@ import lombok.Data;
|
||||
@Schema(name = "变化任务提交参数")
|
||||
public class SubmitActionDTO {
|
||||
|
||||
private String customId;
|
||||
private String customId;
|
||||
|
||||
private String taskId;
|
||||
private String taskId;
|
||||
|
||||
private String state;
|
||||
|
||||
|
||||
@@ -13,9 +13,9 @@ import java.util.List;
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
public class SubmitBlendDTO extends BaseSubmitDTO {
|
||||
|
||||
@ArraySchema(arraySchema = @Schema(description = "图片base64数组", requiredMode = Schema.RequiredMode.REQUIRED), schema = @Schema(example = "data:image/png;base64,xxx1"))
|
||||
private List<String> base64Array;
|
||||
@ArraySchema(arraySchema = @Schema(description = "图片base64数组", requiredMode = Schema.RequiredMode.REQUIRED), schema = @Schema(example = "data:image/png;base64,xxx1"))
|
||||
private List<String> base64Array;
|
||||
|
||||
@Schema(description = "比例: PORTRAIT(2:3); SQUARE(1:1); LANDSCAPE(3:2)", example = "SQUARE")
|
||||
private BlendDimensions dimensions = BlendDimensions.SQUARE;
|
||||
@Schema(description = "比例: PORTRAIT(2:3); SQUARE(1:1); LANDSCAPE(3:2)", example = "SQUARE")
|
||||
private BlendDimensions dimensions = BlendDimensions.SQUARE;
|
||||
}
|
||||
|
||||
@@ -11,13 +11,13 @@ import org.ruoyi.chat.enums.TaskAction;
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
public class SubmitChangeDTO extends BaseSubmitDTO {
|
||||
|
||||
@Schema(description = "任务ID", requiredMode = Schema.RequiredMode.REQUIRED, example = "\"1320098173412546\"")
|
||||
private String taskId;
|
||||
@Schema(description = "任务ID", requiredMode = Schema.RequiredMode.REQUIRED, example = "\"1320098173412546\"")
|
||||
private String taskId;
|
||||
|
||||
@Schema(description = "UPSCALE(放大); VARIATION(变换); REROLL(重新生成)", requiredMode = Schema.RequiredMode.REQUIRED, allowableValues = {"UPSCALE", "VARIATION", "REROLL"}, example = "UPSCALE")
|
||||
private TaskAction action;
|
||||
@Schema(description = "UPSCALE(放大); VARIATION(变换); REROLL(重新生成)", requiredMode = Schema.RequiredMode.REQUIRED, allowableValues = {"UPSCALE", "VARIATION", "REROLL"}, example = "UPSCALE")
|
||||
private TaskAction action;
|
||||
|
||||
@Schema(description = "序号(1~4), action为UPSCALE,VARIATION时必传", minimum = "1", maximum = "4", example = "1")
|
||||
private Integer index;
|
||||
@Schema(description = "序号(1~4), action为UPSCALE,VARIATION时必传", minimum = "1", maximum = "4", example = "1")
|
||||
private Integer index;
|
||||
|
||||
}
|
||||
|
||||
@@ -9,6 +9,6 @@ import lombok.EqualsAndHashCode;
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
public class SubmitDescribeDTO extends BaseSubmitDTO {
|
||||
|
||||
@Schema(description = "图片base64", requiredMode = Schema.RequiredMode.REQUIRED, example = "data:image/png;base64,xxx")
|
||||
private String base64;
|
||||
@Schema(description = "图片base64", requiredMode = Schema.RequiredMode.REQUIRED, example = "data:image/png;base64,xxx")
|
||||
private String base64;
|
||||
}
|
||||
|
||||
@@ -12,14 +12,14 @@ import java.util.List;
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
public class SubmitImagineDTO extends BaseSubmitDTO {
|
||||
|
||||
@Schema(description = "提示词", requiredMode = Schema.RequiredMode.REQUIRED, example = "Cat")
|
||||
private String prompt;
|
||||
@Schema(description = "提示词", requiredMode = Schema.RequiredMode.REQUIRED, example = "Cat")
|
||||
private String prompt;
|
||||
|
||||
@Schema(description = "垫图base64数组")
|
||||
private List<String> base64Array;
|
||||
@Schema(description = "垫图base64数组")
|
||||
private List<String> base64Array;
|
||||
|
||||
@Schema(hidden = true)
|
||||
@Deprecated(since = "3.0", forRemoval = true)
|
||||
private String base64;
|
||||
@Schema(hidden = true)
|
||||
@Deprecated(since = "3.0", forRemoval = true)
|
||||
private String base64;
|
||||
|
||||
}
|
||||
|
||||
@@ -8,11 +8,11 @@ import lombok.EqualsAndHashCode;
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Schema(name = "局部重绘提交参数")
|
||||
public class SubmitModalDTO extends BaseSubmitDTO{
|
||||
public class SubmitModalDTO extends BaseSubmitDTO {
|
||||
|
||||
private String maskBase64;
|
||||
private String maskBase64;
|
||||
|
||||
private String taskId;
|
||||
private String taskId;
|
||||
|
||||
private String prompt;
|
||||
|
||||
|
||||
@@ -8,9 +8,9 @@ import lombok.EqualsAndHashCode;
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@Schema(name = "prompt分析提交参数")
|
||||
public class SubmitShortenDTO extends BaseSubmitDTO{
|
||||
public class SubmitShortenDTO extends BaseSubmitDTO {
|
||||
|
||||
private String botType;
|
||||
private String botType;
|
||||
|
||||
private String prompt;
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import lombok.EqualsAndHashCode;
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
public class SubmitSimpleChangeDTO extends BaseSubmitDTO {
|
||||
|
||||
@Schema(description = "变化描述: ID $action$index", requiredMode = Schema.RequiredMode.REQUIRED, example = "1320098173412546 U2")
|
||||
private String content;
|
||||
@Schema(description = "变化描述: ID $action$index", requiredMode = Schema.RequiredMode.REQUIRED, example = "1320098173412546 U2")
|
||||
private String content;
|
||||
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import java.util.List;
|
||||
@Schema(name = "任务查询参数")
|
||||
public class TaskConditionDTO {
|
||||
|
||||
@ArraySchema(arraySchema = @Schema(description = "任务ID列表"), schema = @Schema(example = "1320098173412546"))
|
||||
private List<String> ids;
|
||||
@ArraySchema(arraySchema = @Schema(description = "任务ID列表"), schema = @Schema(example = "1320098173412546"))
|
||||
private List<String> ids;
|
||||
|
||||
}
|
||||
|
||||
@@ -6,16 +6,16 @@ import lombok.Getter;
|
||||
@Getter
|
||||
public enum BlendDimensions {
|
||||
|
||||
PORTRAIT("2:3"),
|
||||
PORTRAIT("2:3"),
|
||||
|
||||
SQUARE("1:1"),
|
||||
SQUARE("1:1"),
|
||||
|
||||
LANDSCAPE("3:2");
|
||||
LANDSCAPE("3:2");
|
||||
|
||||
private final String value;
|
||||
private final String value;
|
||||
|
||||
BlendDimensions(String value) {
|
||||
this.value = value;
|
||||
}
|
||||
BlendDimensions(String value) {
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package org.ruoyi.chat.enums;
|
||||
import lombok.Getter;
|
||||
|
||||
/**
|
||||
* 是否显示
|
||||
* 是否显示
|
||||
*
|
||||
* @author ageerle@163.com
|
||||
* date 2025/4/10
|
||||
|
||||
@@ -2,29 +2,29 @@ package org.ruoyi.chat.enums;
|
||||
|
||||
|
||||
public enum TaskAction {
|
||||
/**
|
||||
* 生成图片.
|
||||
*/
|
||||
IMAGINE,
|
||||
/**
|
||||
* 选中放大.
|
||||
*/
|
||||
UPSCALE,
|
||||
/**
|
||||
* 选中其中的一张图,生成四张相似的.
|
||||
*/
|
||||
VARIATION,
|
||||
/**
|
||||
* 重新执行.
|
||||
*/
|
||||
REROLL,
|
||||
/**
|
||||
* 图转prompt.
|
||||
*/
|
||||
DESCRIBE,
|
||||
/**
|
||||
* 多图混合.
|
||||
*/
|
||||
BLEND
|
||||
/**
|
||||
* 生成图片.
|
||||
*/
|
||||
IMAGINE,
|
||||
/**
|
||||
* 选中放大.
|
||||
*/
|
||||
UPSCALE,
|
||||
/**
|
||||
* 选中其中的一张图,生成四张相似的.
|
||||
*/
|
||||
VARIATION,
|
||||
/**
|
||||
* 重新执行.
|
||||
*/
|
||||
REROLL,
|
||||
/**
|
||||
* 图转prompt.
|
||||
*/
|
||||
DESCRIBE,
|
||||
/**
|
||||
* 多图混合.
|
||||
*/
|
||||
BLEND
|
||||
|
||||
}
|
||||
|
||||
@@ -24,11 +24,28 @@ public class ChatMessageCreatedEvent extends ApplicationEvent {
|
||||
this.messageId = messageId;
|
||||
}
|
||||
|
||||
public Long getUserId() { return userId; }
|
||||
public Long getSessionId() { return sessionId; }
|
||||
public String getModelName() { return modelName; }
|
||||
public String getRole() { return role; }
|
||||
public String getContent() { return content; }
|
||||
public Long getMessageId() { return messageId; }
|
||||
public Long getUserId() {
|
||||
return userId;
|
||||
}
|
||||
|
||||
public Long getSessionId() {
|
||||
return sessionId;
|
||||
}
|
||||
|
||||
public String getModelName() {
|
||||
return modelName;
|
||||
}
|
||||
|
||||
public String getRole() {
|
||||
return role;
|
||||
}
|
||||
|
||||
public String getContent() {
|
||||
return content;
|
||||
}
|
||||
|
||||
public Long getMessageId() {
|
||||
return messageId;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ import java.util.concurrent.ConcurrentHashMap;
|
||||
* date 2025/5/10
|
||||
*/
|
||||
@Component
|
||||
public class ChatServiceFactory implements ApplicationContextAware {
|
||||
public class ChatServiceFactory implements ApplicationContextAware {
|
||||
private final Map<String, IChatService> chatServiceMap = new ConcurrentHashMap<>();
|
||||
private IChatCostService chatCostService;
|
||||
|
||||
@@ -26,7 +26,7 @@ public class ChatServiceFactory implements ApplicationContextAware {
|
||||
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
|
||||
// 获取计费服务
|
||||
this.chatCostService = applicationContext.getBean(IChatCostService.class);
|
||||
|
||||
|
||||
// 初始化时收集所有IChatService的实现
|
||||
Map<String, IChatService> serviceMap = applicationContext.getBeansOfType(IChatService.class);
|
||||
for (IChatService service : serviceMap.values()) {
|
||||
|
||||
@@ -5,11 +5,9 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.chat.event.ChatMessageCreatedEvent;
|
||||
import org.ruoyi.chat.service.chat.IChatCostService;
|
||||
import org.ruoyi.common.chat.request.ChatRequest;
|
||||
import org.springframework.context.event.EventListener;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.context.event.EventListener;
|
||||
import org.springframework.transaction.event.TransactionPhase;
|
||||
import org.springframework.transaction.event.TransactionalEventListener;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
@@ -21,8 +19,8 @@ public class BillingEventListener {
|
||||
@Async
|
||||
@EventListener
|
||||
public void onChatMessageCreated(ChatMessageCreatedEvent event) {
|
||||
log.debug("BillingEventListener->接收到计费事件,用户ID: {},会话ID: {},模型: {}",
|
||||
event.getUserId(), event.getSessionId(), event.getModelName());
|
||||
log.debug("BillingEventListener->接收到计费事件,用户ID: {},会话ID: {},模型: {}",
|
||||
event.getUserId(), event.getSessionId(), event.getModelName());
|
||||
try {
|
||||
ChatRequest chatRequest = new ChatRequest();
|
||||
chatRequest.setUserId(event.getUserId());
|
||||
@@ -38,9 +36,9 @@ public class BillingEventListener {
|
||||
} catch (Exception ex) {
|
||||
// 由于已有预检查,这里的异常主要是系统异常(数据库连接等)
|
||||
// 记录错误但不中断异步线程
|
||||
log.error("BillingEventListener->异步计费异常,用户ID: {},模型: {},错误: {}",
|
||||
event.getUserId(), event.getModelName(), ex.getMessage(), ex);
|
||||
|
||||
log.error("BillingEventListener->异步计费异常,用户ID: {},模型: {},错误: {}",
|
||||
event.getUserId(), event.getModelName(), ex.getMessage(), ex);
|
||||
|
||||
// TODO: 可以考虑加入重试机制或者错误通知机制
|
||||
// 例如:发送到死信队列,或者通知运维人员
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package org.ruoyi.chat.listener;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -9,13 +8,13 @@ import okhttp3.ResponseBody;
|
||||
import okhttp3.sse.EventSource;
|
||||
import okhttp3.sse.EventSourceListener;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.ruoyi.chat.support.RetryNotifier;
|
||||
import org.ruoyi.chat.util.SSEUtil;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import java.util.Objects;
|
||||
import org.ruoyi.chat.support.RetryNotifier;
|
||||
import org.ruoyi.chat.util.SSEUtil;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
@@ -45,7 +44,7 @@ public class FastGPTSSEEventSourceListener extends EventSourceListener {
|
||||
try {
|
||||
log.debug("事件类型为: {}", type);
|
||||
log.debug("事件数据为: {}", data);
|
||||
if ("flowResponses".equals(type)){
|
||||
if ("flowResponses".equals(type)) {
|
||||
emitter.send(data);
|
||||
emitter.complete();
|
||||
RetryNotifier.clear(emitter);
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package org.ruoyi.chat.listener;
|
||||
|
||||
|
||||
import cn.dev33.satoken.stp.StpUtil;
|
||||
import cn.hutool.core.collection.CollectionUtil;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
@@ -13,22 +12,19 @@ import okhttp3.sse.EventSource;
|
||||
import okhttp3.sse.EventSourceListener;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.ruoyi.chat.service.chat.IChatCostService;
|
||||
import org.ruoyi.chat.support.RetryNotifier;
|
||||
import org.ruoyi.chat.util.SSEUtil;
|
||||
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
|
||||
import org.ruoyi.common.chat.entity.chat.Message;
|
||||
import org.ruoyi.common.chat.request.ChatRequest;
|
||||
import org.ruoyi.common.core.service.BaseContext;
|
||||
import org.ruoyi.common.core.utils.SpringUtils;
|
||||
import org.ruoyi.common.core.utils.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.ruoyi.chat.util.SSEUtil;
|
||||
import org.ruoyi.chat.support.RetryNotifier;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* OpenAIEventSourceListener
|
||||
* OpenAIEventSourceListener
|
||||
*
|
||||
* @author https:www.unfbx.com
|
||||
* @date 2023-02-22
|
||||
@@ -38,18 +34,18 @@ import java.util.Objects;
|
||||
@Component
|
||||
public class SSEEventSourceListener extends EventSourceListener {
|
||||
|
||||
private static final IChatCostService chatCostService = SpringUtils.getBean(IChatCostService.class);
|
||||
private SseEmitter emitter;
|
||||
|
||||
private Long userId;
|
||||
|
||||
private Long sessionId;
|
||||
|
||||
private String token;
|
||||
|
||||
private boolean retryEnabled;
|
||||
private StringBuilder stringBuffer = new StringBuilder();
|
||||
|
||||
private String modelName;
|
||||
|
||||
@Autowired(required = false)
|
||||
public SSEEventSourceListener(SseEmitter emitter,Long userId,Long sessionId, String token, boolean retryEnabled) {
|
||||
public SSEEventSourceListener(SseEmitter emitter, Long userId, Long sessionId, String token, boolean retryEnabled) {
|
||||
this.emitter = emitter;
|
||||
this.userId = userId;
|
||||
this.sessionId = sessionId;
|
||||
@@ -57,13 +53,6 @@ public class SSEEventSourceListener extends EventSourceListener {
|
||||
this.retryEnabled = retryEnabled;
|
||||
}
|
||||
|
||||
|
||||
private StringBuilder stringBuffer = new StringBuilder();
|
||||
|
||||
private String modelName;
|
||||
|
||||
private static final IChatCostService chatCostService = SpringUtils.getBean(IChatCostService.class);
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@@ -105,13 +94,13 @@ public class SSEEventSourceListener extends EventSourceListener {
|
||||
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
ChatCompletionResponse completionResponse = mapper.readValue(data, ChatCompletionResponse.class);
|
||||
if(completionResponse == null || CollectionUtil.isEmpty(completionResponse.getChoices())){
|
||||
if (completionResponse == null || CollectionUtil.isEmpty(completionResponse.getChoices())) {
|
||||
return;
|
||||
}
|
||||
Object content = completionResponse.getChoices().get(0).getDelta().getContent();
|
||||
|
||||
if(content != null ){
|
||||
if(StringUtils.isEmpty(modelName)){
|
||||
if (content != null) {
|
||||
if (StringUtils.isEmpty(modelName)) {
|
||||
modelName = completionResponse.getModel();
|
||||
}
|
||||
stringBuffer.append(content);
|
||||
|
||||
@@ -27,7 +27,6 @@ public interface IChatCostService {
|
||||
void saveMessage(ChatRequest chatRequest);
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* 仅发布异步计费事件(不做入库)
|
||||
*
|
||||
@@ -37,7 +36,6 @@ public interface IChatCostService {
|
||||
|
||||
/**
|
||||
* 直接扣除用户的余额
|
||||
*
|
||||
*/
|
||||
void deductUserBalance(Long userId, Double numberCost);
|
||||
|
||||
@@ -45,11 +43,11 @@ public interface IChatCostService {
|
||||
/**
|
||||
* 扣除任务费用并且保存记录
|
||||
*
|
||||
* @param type 任务类型
|
||||
* @param type 任务类型
|
||||
* @param prompt 任务描述
|
||||
* @param cost 扣除费用
|
||||
* @param cost 扣除费用
|
||||
*/
|
||||
void taskDeduct(String type,String prompt, double cost);
|
||||
void taskDeduct(String type, String prompt, double cost);
|
||||
|
||||
|
||||
/**
|
||||
@@ -64,7 +62,7 @@ public interface IChatCostService {
|
||||
|
||||
/**
|
||||
* 检查用户余额是否足够支付预估费用
|
||||
*
|
||||
*
|
||||
* @param chatRequest 对话信息
|
||||
* @return true=余额充足,false=余额不足
|
||||
*/
|
||||
|
||||
@@ -21,15 +21,17 @@ public interface ISseService {
|
||||
|
||||
/**
|
||||
* 客户端发送消息到服务端
|
||||
*
|
||||
* @param chatRequest 请求对象
|
||||
*/
|
||||
SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request);
|
||||
|
||||
/**
|
||||
* 语音转文字
|
||||
*
|
||||
* @param file 语音文件
|
||||
*/
|
||||
WhisperResponse speechToTextTranscriptionsV2(MultipartFile file);
|
||||
WhisperResponse speechToTextTranscriptionsV2(MultipartFile file);
|
||||
|
||||
/**
|
||||
* 文字转语音
|
||||
|
||||
@@ -3,11 +3,9 @@ package org.ruoyi.chat.service.chat.impl;
|
||||
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import java.math.BigDecimal;
|
||||
import java.math.RoundingMode;
|
||||
import org.ruoyi.chat.enums.BillingType;
|
||||
import org.ruoyi.chat.event.ChatMessageCreatedEvent;
|
||||
import org.ruoyi.chat.enums.UserGradeType;
|
||||
import org.ruoyi.chat.event.ChatMessageCreatedEvent;
|
||||
import org.ruoyi.chat.service.chat.IChatCostService;
|
||||
import org.ruoyi.common.chat.request.ChatRequest;
|
||||
import org.ruoyi.common.chat.utils.TikTokensUtil;
|
||||
@@ -26,6 +24,9 @@ import org.ruoyi.system.mapper.SysUserMapper;
|
||||
import org.springframework.context.ApplicationEventPublisher;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.math.RoundingMode;
|
||||
|
||||
|
||||
/**
|
||||
* 计费管理Service实现
|
||||
@@ -108,10 +109,10 @@ public class ChatCostServiceImpl implements IChatCostService {
|
||||
// 计算批次数:每1000个Token为一批,每批扣费单价
|
||||
int batches = billable / threshold;
|
||||
BigDecimal numberCost = unitPrice
|
||||
.multiply(BigDecimal.valueOf(batches))
|
||||
.setScale(2, RoundingMode.HALF_UP);
|
||||
.multiply(BigDecimal.valueOf(batches))
|
||||
.setScale(2, RoundingMode.HALF_UP);
|
||||
log.debug("deductToken->按token扣费,结算token数量: {},批次数: {},单价: {},费用: {}",
|
||||
billable, batches, unitPrice, numberCost);
|
||||
billable, batches, unitPrice, numberCost);
|
||||
|
||||
try {
|
||||
// 先尝试扣费
|
||||
@@ -172,7 +173,7 @@ public class ChatCostServiceImpl implements IChatCostService {
|
||||
// 保存成功后,将生成的消息ID设置到ChatRequest中
|
||||
chatRequest.setMessageId(chatMessageBo.getId());
|
||||
log.debug("saveMessage->成功保存消息,消息ID: {}, 用户ID: {}, 会话ID: {}",
|
||||
chatMessageBo.getId(), chatRequest.getUserId(), chatRequest.getSessionId());
|
||||
chatMessageBo.getId(), chatRequest.getUserId(), chatRequest.getSessionId());
|
||||
} catch (Exception e) {
|
||||
log.error("saveMessage->保存消息失败", e);
|
||||
throw new ServiceException("保存消息失败");
|
||||
@@ -180,28 +181,27 @@ public class ChatCostServiceImpl implements IChatCostService {
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public void publishBillingEvent(ChatRequest chatRequest) {
|
||||
log.debug("publishBillingEvent->发布计费事件,用户ID: {},会话ID: {},模型: {}",
|
||||
chatRequest.getUserId(), chatRequest.getSessionId(), chatRequest.getModel());
|
||||
chatRequest.getUserId(), chatRequest.getSessionId(), chatRequest.getModel());
|
||||
|
||||
// 预检查:评估可能的扣费金额,如果余额不足则直接抛异常
|
||||
try {
|
||||
preCheckBalance(chatRequest);
|
||||
} catch (ServiceException e) {
|
||||
log.warn("publishBillingEvent->预检查余额不足,用户ID: {},模型: {}",
|
||||
chatRequest.getUserId(), chatRequest.getModel());
|
||||
chatRequest.getUserId(), chatRequest.getModel());
|
||||
throw e; // 直接抛出,阻止消息保存和对话继续
|
||||
}
|
||||
|
||||
eventPublisher.publishEvent(new ChatMessageCreatedEvent(
|
||||
chatRequest.getUserId(),
|
||||
chatRequest.getSessionId(),
|
||||
chatRequest.getModel(),
|
||||
chatRequest.getRole(),
|
||||
chatRequest.getPrompt(),
|
||||
chatRequest.getMessageId()
|
||||
chatRequest.getUserId(),
|
||||
chatRequest.getSessionId(),
|
||||
chatRequest.getModel(),
|
||||
chatRequest.getRole(),
|
||||
chatRequest.getPrompt(),
|
||||
chatRequest.getMessageId()
|
||||
));
|
||||
log.debug("publishBillingEvent->计费事件发布完成");
|
||||
}
|
||||
@@ -237,8 +237,8 @@ public class ChatCostServiceImpl implements IChatCostService {
|
||||
// 计算批次数:每1000个Token为一批,每批扣费单价
|
||||
int batches = billable / threshold;
|
||||
BigDecimal numberCost = unitPrice
|
||||
.multiply(BigDecimal.valueOf(batches))
|
||||
.setScale(2, RoundingMode.HALF_UP);
|
||||
.multiply(BigDecimal.valueOf(batches))
|
||||
.setScale(2, RoundingMode.HALF_UP);
|
||||
checkUserBalanceWithoutDeduct(chatRequest.getUserId(), numberCost.doubleValue());
|
||||
}
|
||||
}
|
||||
@@ -253,9 +253,9 @@ public class ChatCostServiceImpl implements IChatCostService {
|
||||
}
|
||||
|
||||
BigDecimal userBalance = BigDecimal.valueOf(sysUser.getUserBalance() == null ? 0D : sysUser.getUserBalance())
|
||||
.setScale(2, RoundingMode.HALF_UP);
|
||||
.setScale(2, RoundingMode.HALF_UP);
|
||||
BigDecimal cost = BigDecimal.valueOf(numberCost == null ? 0D : numberCost)
|
||||
.setScale(2, RoundingMode.HALF_UP);
|
||||
.setScale(2, RoundingMode.HALF_UP);
|
||||
|
||||
if (userBalance.compareTo(cost) < 0 || userBalance.compareTo(BigDecimal.ZERO) == 0) {
|
||||
throw new ServiceException("余额不足, 请充值。当前余额: " + userBalance + ",需要: " + cost);
|
||||
@@ -283,7 +283,7 @@ public class ChatCostServiceImpl implements IChatCostService {
|
||||
// 更新消息
|
||||
chatMessageService.updateByBo(updateMessage);
|
||||
log.debug("updateMessageWithoutBilling->更新消息基本信息成功,消息ID: {}, 实际tokens: {}, 计费类型: {}",
|
||||
chatRequest.getMessageId(), actualTokens, billingTypeCode);
|
||||
chatRequest.getMessageId(), actualTokens, billingTypeCode);
|
||||
} catch (Exception e) {
|
||||
log.error("updateMessageWithoutBilling->更新消息基本信息失败,消息ID: {}", chatRequest.getMessageId(), e);
|
||||
// 更新失败不影响主流程,只记录错误日志
|
||||
@@ -318,7 +318,7 @@ public class ChatCostServiceImpl implements IChatCostService {
|
||||
// 更新消息
|
||||
chatMessageService.updateByBo(updateMessage);
|
||||
log.debug("updateMessageBilling->更新消息计费信息成功,消息ID: {}, 实际tokens: {}, 计费tokens: {}, 费用: {}",
|
||||
chatRequest.getMessageId(), actualTokens, billedTokens, cost);
|
||||
chatRequest.getMessageId(), actualTokens, billedTokens, cost);
|
||||
} catch (Exception e) {
|
||||
log.error("updateMessageBilling->更新消息计费信息失败,消息ID: {}", chatRequest.getMessageId(), e);
|
||||
// 更新失败不影响主流程,只记录错误日志
|
||||
@@ -333,8 +333,10 @@ public class ChatCostServiceImpl implements IChatCostService {
|
||||
BillingType billingType = BillingType.fromCode(billingTypeCode);
|
||||
if (billingType != null) {
|
||||
return switch (billingType) {
|
||||
case TIMES -> String.format("%s:消耗 %d tokens,扣费 %.2f 元", billingType.getDescription(), billedTokens, cost);
|
||||
case TOKEN -> String.format("%s:结算 %d tokens,扣费 %.2f 元", billingType.getDescription(), billedTokens, cost);
|
||||
case TIMES ->
|
||||
String.format("%s:消耗 %d tokens,扣费 %.2f 元", billingType.getDescription(), billedTokens, cost);
|
||||
case TOKEN ->
|
||||
String.format("%s:结算 %d tokens,扣费 %.2f 元", billingType.getDescription(), billedTokens, cost);
|
||||
};
|
||||
} else {
|
||||
return String.format("系统计费:处理 %d tokens,扣费 %.2f 元", billedTokens, cost);
|
||||
@@ -342,7 +344,6 @@ public class ChatCostServiceImpl implements IChatCostService {
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* 从用户余额中扣除费用
|
||||
*
|
||||
@@ -358,9 +359,9 @@ public class ChatCostServiceImpl implements IChatCostService {
|
||||
}
|
||||
|
||||
BigDecimal userBalance = BigDecimal.valueOf(sysUser.getUserBalance() == null ? 0D : sysUser.getUserBalance())
|
||||
.setScale(2, RoundingMode.HALF_UP);
|
||||
.setScale(2, RoundingMode.HALF_UP);
|
||||
BigDecimal cost = BigDecimal.valueOf(numberCost == null ? 0D : numberCost)
|
||||
.setScale(2, RoundingMode.HALF_UP);
|
||||
.setScale(2, RoundingMode.HALF_UP);
|
||||
|
||||
log.debug("deductUserBalance->准备扣除: {},当前余额: {}", cost, userBalance);
|
||||
|
||||
@@ -375,16 +376,16 @@ public class ChatCostServiceImpl implements IChatCostService {
|
||||
newBalance = newBalance.setScale(2, RoundingMode.HALF_UP);
|
||||
|
||||
sysUserMapper.update(null,
|
||||
new LambdaUpdateWrapper<SysUser>()
|
||||
.set(SysUser::getUserBalance, newBalance.doubleValue())
|
||||
.eq(SysUser::getUserId, userId));
|
||||
new LambdaUpdateWrapper<SysUser>()
|
||||
.set(SysUser::getUserBalance, newBalance.doubleValue())
|
||||
.eq(SysUser::getUserId, userId));
|
||||
}
|
||||
|
||||
/**
|
||||
* 扣除任务费用
|
||||
*/
|
||||
@Override
|
||||
public void taskDeduct(String type,String prompt, double cost) {
|
||||
public void taskDeduct(String type, String prompt, double cost) {
|
||||
// 判断用户是否付费
|
||||
checkUserGrade();
|
||||
// 扣除费用
|
||||
@@ -406,7 +407,7 @@ public class ChatCostServiceImpl implements IChatCostService {
|
||||
@Override
|
||||
public void checkUserGrade() {
|
||||
SysUser sysUser = sysUserMapper.selectById(getUserId());
|
||||
if(UserGradeType.UNPAID.getCode().equals(sysUser.getUserGrade())){
|
||||
if (UserGradeType.UNPAID.getCode().equals(sysUser.getUserGrade())) {
|
||||
throw new BaseException("该模型仅限付费用户使用。请升级套餐,开启高效体验之旅!");
|
||||
}
|
||||
}
|
||||
@@ -439,11 +440,11 @@ public class ChatCostServiceImpl implements IChatCostService {
|
||||
return true; // 预检查通过,余额充足
|
||||
} catch (ServiceException e) {
|
||||
log.debug("checkBalanceSufficient->余额不足,用户ID: {}, 模型: {}, 错误: {}",
|
||||
chatRequest.getUserId(), chatRequest.getModel(), e.getMessage());
|
||||
chatRequest.getUserId(), chatRequest.getModel(), e.getMessage());
|
||||
return false; // 预检查失败,余额不足
|
||||
} catch (Exception e) {
|
||||
log.error("checkBalanceSufficient->检查余额时发生异常,用户ID: {}, 模型: {}",
|
||||
chatRequest.getUserId(), chatRequest.getModel(), e);
|
||||
chatRequest.getUserId(), chatRequest.getModel(), e);
|
||||
return false; // 异常情况视为余额不足,保守处理
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,8 @@ import io.reactivex.Flowable;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.chat.enums.ChatModeType;
|
||||
import org.ruoyi.chat.service.chat.IChatService;
|
||||
import org.ruoyi.chat.support.ChatServiceHelper;
|
||||
import org.ruoyi.chat.support.RetryNotifier;
|
||||
import org.ruoyi.common.chat.request.ChatRequest;
|
||||
import org.ruoyi.domain.vo.ChatModelVo;
|
||||
import org.ruoyi.service.IChatModelService;
|
||||
@@ -20,8 +22,6 @@ import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
import java.util.Collections;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import org.ruoyi.chat.support.RetryNotifier;
|
||||
import org.ruoyi.chat.support.ChatServiceHelper;
|
||||
|
||||
/**
|
||||
* 扣子聊天管理
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
package org.ruoyi.chat.service.chat.impl;
|
||||
|
||||
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.data.message.SystemMessage;
|
||||
import dev.langchain4j.data.message.UserMessage;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.data.message.SystemMessage;
|
||||
import dev.langchain4j.data.message.UserMessage;
|
||||
@@ -15,7 +12,6 @@ import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.*;
|
||||
import okhttp3.Response;
|
||||
import org.ruoyi.chat.enums.ChatModeType;
|
||||
import org.ruoyi.chat.service.chat.IChatService;
|
||||
import org.ruoyi.chat.support.ChatServiceHelper;
|
||||
@@ -27,6 +23,7 @@ import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
@@ -34,31 +31,22 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.InputStreamReader;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
/**
|
||||
* deepseek
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
public class DeepSeekChatImpl implements IChatService {
|
||||
|
||||
@Autowired
|
||||
private IChatModelService chatModelService;
|
||||
public class DeepSeekChatImpl implements IChatService {
|
||||
|
||||
private static final MediaType JSON = MediaType.get("application/json; charset=utf-8");
|
||||
|
||||
// 创建一个用于直接API调用的OkHttpClient
|
||||
private final OkHttpClient client = new OkHttpClient.Builder()
|
||||
.connectTimeout(30, TimeUnit.SECONDS)
|
||||
.readTimeout(60, TimeUnit.SECONDS)
|
||||
.writeTimeout(30, TimeUnit.SECONDS)
|
||||
.build();
|
||||
@Autowired
|
||||
private IChatModelService chatModelService;
|
||||
|
||||
@Override
|
||||
public SseEmitter chat(ChatRequest chatRequest, SseEmitter emitter) {
|
||||
|
||||
@@ -15,7 +15,8 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.chat.enums.ChatModeType;
|
||||
import org.ruoyi.chat.service.chat.IChatCostService;
|
||||
import org.ruoyi.chat.service.chat.IChatService;
|
||||
import org.ruoyi.common.chat.entity.chat.Message;
|
||||
import org.ruoyi.chat.support.ChatServiceHelper;
|
||||
import org.ruoyi.chat.support.RetryNotifier;
|
||||
import org.ruoyi.common.chat.request.ChatRequest;
|
||||
import org.ruoyi.domain.bo.ChatSessionBo;
|
||||
import org.ruoyi.domain.vo.ChatModelVo;
|
||||
@@ -25,10 +26,8 @@ import org.ruoyi.service.IChatSessionService;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
import org.ruoyi.chat.support.ChatServiceHelper;
|
||||
|
||||
import java.util.Objects;
|
||||
import org.ruoyi.chat.support.RetryNotifier;
|
||||
|
||||
/**
|
||||
* dify 聊天管理
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
package org.ruoyi.chat.service.chat.impl;
|
||||
|
||||
import cn.dev33.satoken.stp.StpUtil;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.chat.config.ChatConfig;
|
||||
import org.ruoyi.chat.enums.ChatModeType;
|
||||
import org.ruoyi.chat.listener.SSEEventSourceListener;
|
||||
import org.ruoyi.chat.service.chat.IChatService;
|
||||
import org.ruoyi.chat.support.ChatServiceHelper;
|
||||
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
|
||||
import org.ruoyi.common.chat.entity.chat.Message;
|
||||
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
|
||||
@@ -17,8 +17,7 @@ import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import java.util.*;
|
||||
import org.ruoyi.chat.support.ChatServiceHelper;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 图片识别模型
|
||||
|
||||
@@ -9,7 +9,8 @@ import io.github.ollama4j.models.generate.OllamaStreamHandler;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.chat.enums.ChatModeType;
|
||||
import org.ruoyi.chat.service.chat.IChatService;
|
||||
import org.ruoyi.chat.util.SSEUtil;
|
||||
import org.ruoyi.chat.support.ChatServiceHelper;
|
||||
import org.ruoyi.chat.support.RetryNotifier;
|
||||
import org.ruoyi.common.chat.entity.chat.Message;
|
||||
import org.ruoyi.common.chat.request.ChatRequest;
|
||||
import org.ruoyi.domain.vo.ChatModelVo;
|
||||
@@ -22,8 +23,6 @@ import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import org.ruoyi.chat.support.RetryNotifier;
|
||||
import org.ruoyi.chat.support.ChatServiceHelper;
|
||||
|
||||
|
||||
/**
|
||||
|
||||
@@ -31,14 +31,12 @@ import java.util.List;
|
||||
@Slf4j
|
||||
public class OpenAIServiceImpl implements IChatService {
|
||||
|
||||
private final ChatClient chatClient;
|
||||
@Autowired
|
||||
private IChatModelService chatModelService;
|
||||
|
||||
@Value("${spring.ai.mcp.client.enabled}")
|
||||
private Boolean enabled;
|
||||
|
||||
private final ChatClient chatClient;
|
||||
|
||||
public OpenAIServiceImpl(ChatClient.Builder chatClientBuilder, List<McpSyncClient> mcpSyncClients) {
|
||||
this.chatClient = chatClientBuilder
|
||||
.defaultOptions(
|
||||
@@ -48,13 +46,13 @@ public class OpenAIServiceImpl implements IChatService {
|
||||
}
|
||||
|
||||
@Override
|
||||
public SseEmitter chat(ChatRequest chatRequest,SseEmitter emitter) {
|
||||
public SseEmitter chat(ChatRequest chatRequest, SseEmitter emitter) {
|
||||
ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
|
||||
OpenAiStreamClient openAiStreamClient = ChatConfig.createOpenAiStreamClient(chatModelVo.getApiHost(), chatModelVo.getApiKey());
|
||||
List<Message> messages = chatRequest.getMessages();
|
||||
if (enabled) {
|
||||
String toolString = mcpChat(chatRequest.getPrompt());
|
||||
Message userMessage = Message.builder().content("工具返回信息:"+toolString).role(Message.Role.USER).build();
|
||||
Message userMessage = Message.builder().content("工具返回信息:" + toolString).role(Message.Role.USER).build();
|
||||
messages.add(userMessage);
|
||||
}
|
||||
SSEEventSourceListener listener = ChatServiceHelper.createOpenAiListener(emitter, chatRequest);
|
||||
@@ -73,7 +71,7 @@ public class OpenAIServiceImpl implements IChatService {
|
||||
return emitter;
|
||||
}
|
||||
|
||||
public String mcpChat(String prompt){
|
||||
public String mcpChat(String prompt) {
|
||||
return this.chatClient.prompt(prompt).call().content();
|
||||
}
|
||||
|
||||
|
||||
@@ -30,11 +30,7 @@ import org.ruoyi.domain.bo.QueryVectorBo;
|
||||
import org.ruoyi.domain.vo.ChatModelVo;
|
||||
import org.ruoyi.domain.vo.KnowledgeInfoVo;
|
||||
import org.ruoyi.domain.vo.PromptTemplateVo;
|
||||
import org.ruoyi.service.IChatModelService;
|
||||
import org.ruoyi.service.IChatSessionService;
|
||||
import org.ruoyi.service.IKnowledgeInfoService;
|
||||
import org.ruoyi.service.IPromptTemplateService;
|
||||
import org.ruoyi.service.VectorStoreService;
|
||||
import org.ruoyi.service.*;
|
||||
import org.springframework.core.io.InputStreamResource;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.http.MediaType;
|
||||
@@ -73,12 +69,26 @@ public class SseServiceImpl implements ISseService {
|
||||
private final IChatSessionService chatSessionService;
|
||||
|
||||
private final IKnowledgeInfoService knowledgeInfoService;
|
||||
|
||||
private ChatModelVo chatModelVo;
|
||||
|
||||
// 提示词模板服务
|
||||
private final IPromptTemplateService promptTemplateService;
|
||||
private ChatModelVo chatModelVo;
|
||||
|
||||
/**
|
||||
* 获取对话标题
|
||||
*
|
||||
* @param str 原字符
|
||||
* @return 截取后的字符
|
||||
*/
|
||||
public static String getFirst10Characters(String str) {
|
||||
// 判断字符串长度
|
||||
if (str.length() > 10) {
|
||||
// 如果长度大于10,截取前10个字符
|
||||
return str.substring(0, 10);
|
||||
} else {
|
||||
// 如果长度不足10,返回整个字符串
|
||||
return str;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) {
|
||||
@@ -109,7 +119,7 @@ public class SseServiceImpl implements ISseService {
|
||||
chatSessionService.insertByBo(chatSessionBo);
|
||||
chatRequest.setSessionId(chatSessionBo.getId());
|
||||
}
|
||||
|
||||
|
||||
// 保存用户消息
|
||||
chatCostService.saveMessage(chatRequest);
|
||||
}
|
||||
@@ -195,23 +205,6 @@ public class SseServiceImpl implements ISseService {
|
||||
return model;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取对话标题
|
||||
*
|
||||
* @param str 原字符
|
||||
* @return 截取后的字符
|
||||
*/
|
||||
public static String getFirst10Characters(String str) {
|
||||
// 判断字符串长度
|
||||
if (str.length() > 10) {
|
||||
// 如果长度大于10,截取前10个字符
|
||||
return str.substring(0, 10);
|
||||
} else {
|
||||
// 如果长度不足10,返回整个字符串
|
||||
return str;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建消息列表
|
||||
*/
|
||||
|
||||
@@ -9,14 +9,13 @@ import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.chat.enums.ChatModeType;
|
||||
import org.ruoyi.chat.service.chat.IChatService;
|
||||
import org.ruoyi.chat.support.ChatServiceHelper;
|
||||
import org.ruoyi.common.chat.request.ChatRequest;
|
||||
import org.ruoyi.domain.vo.ChatModelVo;
|
||||
import org.ruoyi.service.IChatModelService;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
import org.ruoyi.chat.support.ChatServiceHelper;
|
||||
|
||||
|
||||
|
||||
/**
|
||||
@@ -24,20 +23,17 @@ import org.ruoyi.chat.support.ChatServiceHelper;
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
public class ZhipuAiChatServiceImpl implements IChatService {
|
||||
|
||||
@Autowired
|
||||
private IChatModelService chatModelService;
|
||||
|
||||
public class ZhipuAiChatServiceImpl implements IChatService {
|
||||
|
||||
ToolSpecification currentTime = ToolSpecification.builder()
|
||||
.name("currentTime")
|
||||
.description("currentTime")
|
||||
.build();
|
||||
|
||||
@Autowired
|
||||
private IChatModelService chatModelService;
|
||||
|
||||
@Override
|
||||
public SseEmitter chat(ChatRequest chatRequest, SseEmitter emitter){
|
||||
public SseEmitter chat(ChatRequest chatRequest, SseEmitter emitter) {
|
||||
ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
|
||||
// 发送流式消息
|
||||
try {
|
||||
|
||||
@@ -6,9 +6,7 @@ import org.ruoyi.chat.service.chat.IChatCostService;
|
||||
import org.ruoyi.chat.service.chat.IChatService;
|
||||
import org.ruoyi.common.chat.entity.chat.Message;
|
||||
import org.ruoyi.common.chat.request.ChatRequest;
|
||||
import org.ruoyi.common.chat.utils.TikTokensUtil;
|
||||
import org.ruoyi.common.core.service.BaseContext;
|
||||
import org.ruoyi.domain.bo.ChatMessageBo;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import java.io.IOException;
|
||||
@@ -18,7 +16,6 @@ import java.util.function.Consumer;
|
||||
/**
|
||||
* 统一计费代理类
|
||||
* 自动处理所有ChatService的AI回复保存和计费逻辑
|
||||
*
|
||||
*/
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
@@ -33,7 +30,7 @@ public class BillingChatServiceProxy implements IChatService {
|
||||
if (!chatCostService.checkBalanceSufficient(chatRequest)) {
|
||||
String errorMsg = "余额不足,无法使用AI服务,请充值后再试";
|
||||
log.warn("余额不足阻止AI回复,用户ID: {}, 模型: {}",
|
||||
chatRequest.getUserId(), chatRequest.getModel());
|
||||
chatRequest.getUserId(), chatRequest.getModel());
|
||||
try {
|
||||
emitter.send(errorMsg);
|
||||
emitter.complete();
|
||||
@@ -47,7 +44,7 @@ public class BillingChatServiceProxy implements IChatService {
|
||||
}
|
||||
|
||||
log.debug("余额检查通过,开始AI回复,用户ID: {}, 模型: {}",
|
||||
chatRequest.getUserId(), chatRequest.getModel());
|
||||
chatRequest.getUserId(), chatRequest.getModel());
|
||||
|
||||
// 创建增强的SseEmitter,自动收集AI回复
|
||||
BillingSseEmitter billingEmitter = new BillingSseEmitter(emitter, chatRequest, chatCostService);
|
||||
@@ -150,11 +147,11 @@ public class BillingChatServiceProxy implements IChatService {
|
||||
chatCostService.publishBillingEvent(aiChatRequest);
|
||||
|
||||
log.debug("AI回复保存和计费完成,用户ID: {}, 会话ID: {}, 回复长度: {}",
|
||||
chatRequest.getUserId(), chatRequest.getSessionId(), aiResponse.length());
|
||||
chatRequest.getUserId(), chatRequest.getSessionId(), aiResponse.length());
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("保存AI回复和计费失败,用户ID: {}, 会话ID: {}",
|
||||
chatRequest.getUserId(), chatRequest.getSessionId(), e);
|
||||
chatRequest.getUserId(), chatRequest.getSessionId(), e);
|
||||
// 不抛出异常,避免影响用户体验
|
||||
}
|
||||
}
|
||||
@@ -212,10 +209,10 @@ public class BillingChatServiceProxy implements IChatService {
|
||||
|
||||
String trimmed = data.trim();
|
||||
return "[DONE]".equals(trimmed)
|
||||
|| "null".equals(trimmed)
|
||||
|| trimmed.startsWith("event:")
|
||||
|| trimmed.startsWith("id:")
|
||||
|| trimmed.startsWith("retry:");
|
||||
|| "null".equals(trimmed)
|
||||
|| trimmed.startsWith("event:")
|
||||
|| trimmed.startsWith("id:")
|
||||
|| trimmed.startsWith("retry:");
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -260,11 +257,11 @@ public class BillingChatServiceProxy implements IChatService {
|
||||
*/
|
||||
private boolean isPureTextContent(String data) {
|
||||
return data != null
|
||||
&& !data.trim().isEmpty()
|
||||
&& !data.contains("{")
|
||||
&& !data.contains("[")
|
||||
&& !data.contains("data:")
|
||||
&& data.length() < 500; // 合理的文本长度
|
||||
&& !data.trim().isEmpty()
|
||||
&& !data.contains("{")
|
||||
&& !data.contains("[")
|
||||
&& !data.contains("data:")
|
||||
&& data.length() < 500; // 合理的文本长度
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -92,23 +92,23 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
|
||||
public TableDataInfo<KnowledgeInfoVo> queryPageListByRole(KnowledgeInfoBo bo, PageQuery pageQuery) {
|
||||
// 查询用户关联角色
|
||||
LoginUser loginUser = LoginHelper.getLoginUser();
|
||||
|
||||
|
||||
// 构建查询条件
|
||||
LambdaQueryWrapper<KnowledgeInfo> lqw = buildQueryWrapper(bo);
|
||||
|
||||
|
||||
// 管理员用户直接查询所有数据
|
||||
if (Objects.equals(loginUser.getUserId(), 1L)) {
|
||||
Page<KnowledgeInfoVo> result = baseMapper.selectVoPage(pageQuery.build(), lqw);
|
||||
return TableDataInfo.build(result);
|
||||
}
|
||||
|
||||
|
||||
// 检查用户是否配置了角色信息
|
||||
if (StringUtils.isNotEmpty(loginUser.getKroleGroupIds()) && StringUtils.isNotEmpty(loginUser.getKroleGroupType())) {
|
||||
// 角色/角色组id列表
|
||||
List<String> groupIdList = Arrays.stream(loginUser.getKroleGroupIds().split(","))
|
||||
.filter(StringUtils::isNotEmpty)
|
||||
.toList();
|
||||
|
||||
|
||||
// 查询用户关联的角色
|
||||
List<KnowledgeRole> knowledgeRoles = new ArrayList<>();
|
||||
LambdaQueryWrapper<KnowledgeRole> roleLqw = Wrappers.lambdaQuery();
|
||||
@@ -123,8 +123,8 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
|
||||
if (!CollectionUtils.isEmpty(knowledgeRoles)) {
|
||||
// 查询这些角色关联的知识库
|
||||
LambdaQueryWrapper<KnowledgeRoleRelation> relationLqw = Wrappers.lambdaQuery();
|
||||
relationLqw.in(KnowledgeRoleRelation::getKnowledgeRoleId,
|
||||
knowledgeRoles.stream().map(KnowledgeRole::getId).filter(Objects::nonNull).collect(Collectors.toList()));
|
||||
relationLqw.in(KnowledgeRoleRelation::getKnowledgeRoleId,
|
||||
knowledgeRoles.stream().map(KnowledgeRole::getId).filter(Objects::nonNull).collect(Collectors.toList()));
|
||||
List<KnowledgeRoleRelation> knowledgeRoleRelations = knowledgeRoleRelationMapper.selectList(relationLqw);
|
||||
|
||||
// 如果角色关联了知识库
|
||||
@@ -151,7 +151,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
|
||||
// 用户没有配置角色信息,只显示自己的
|
||||
lqw.eq(KnowledgeInfo::getUid, loginUser.getUserId());
|
||||
}
|
||||
|
||||
|
||||
Page<KnowledgeInfoVo> result = baseMapper.selectVoPage(pageQuery.build(), lqw);
|
||||
return TableDataInfo.build(result);
|
||||
}
|
||||
|
||||
@@ -9,24 +9,19 @@ import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.function.BiConsumer;
|
||||
|
||||
/**
|
||||
* 统一的聊天重试与降级调度器。
|
||||
*
|
||||
* <p>
|
||||
* 策略:
|
||||
* - 当前模型最多重试 3 次;仍失败则降级到同分类内、优先级小于当前的最高优先级模型。
|
||||
* - 降级模型同样最多重试 3 次;仍失败则向前端返回失败信息并停止。
|
||||
*
|
||||
* <p>
|
||||
* 注意:实现依赖调用方在底层异步失败时执行 onFailure.run() 通知本调度器。
|
||||
*/
|
||||
@Slf4j
|
||||
public class ChatRetryHelper {
|
||||
|
||||
public interface AttemptStarter {
|
||||
void start(ChatModelVo model, Runnable onFailure) throws Exception;
|
||||
}
|
||||
|
||||
public static void executeWithRetry(
|
||||
ChatModelVo primaryModel,
|
||||
String category,
|
||||
@@ -110,6 +105,10 @@ public class ChatRetryHelper {
|
||||
|
||||
new Scheduler().startAttempt();
|
||||
}
|
||||
|
||||
public interface AttemptStarter {
|
||||
void start(ChatModelVo model, Runnable onFailure) throws Exception;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package org.ruoyi.chat.support;
|
||||
|
||||
import org.ruoyi.chat.listener.SSEEventSourceListener;
|
||||
import org.ruoyi.common.chat.request.ChatRequest;
|
||||
import org.ruoyi.chat.util.SSEUtil;
|
||||
import org.ruoyi.common.chat.request.ChatRequest;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
/**
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user