Merge pull request #175 from LM20230311/feat-model-priority

Feat model priority:支持自动选择模型;支持模型的重试;
This commit is contained in:
evo
2025-08-20 14:04:17 +08:00
committed by GitHub
23 changed files with 618 additions and 95 deletions

View File

@@ -67,4 +67,19 @@ public class ChatRequest {
*/
private Long uuid;
/**
* 是否有附件
*/
private Boolean hasAttachment;
/**
* 是否自动切换模型
*/
private Boolean autoSelectModel;
/**
* 会话令牌为避免在非Web线程中获取Request入口处注入
*/
private String token;
}

View File

@@ -76,6 +76,11 @@ public class ChatModel extends BaseEntity {
*/
private String apiKey;
/**
* 优先级
*/
private Integer priority;
/**
* 备注
*/

View File

@@ -74,6 +74,11 @@ public class ChatModelBo extends BaseEntity {
@NotBlank(message = "请求地址不能为空", groups = { AddGroup.class, EditGroup.class })
private String apiHost;
/**
* 优先级
*/
private Integer priority;
/**
* 密钥
*/

View File

@@ -90,6 +90,12 @@ public class ChatModelVo implements Serializable {
@ExcelProperty(value = "密钥")
private String apiKey;
/**
* 优先级(值越大优先级越高)
*/
@ExcelProperty(value = "优先级")
private Integer priority;
/**
* 备注
*/

View File

@@ -57,6 +57,17 @@ public interface IChatModelService {
* 通过模型分类获取模型信息
*/
ChatModelVo selectModelByCategory(String image);
/**
* 通过模型分类获取优先级最高的模型信息
*/
ChatModelVo selectModelByCategoryWithHighestPriority(String category);
/**
* 在同一分类下,查找优先级小于当前优先级的最高优先级模型(用于降级)。
*/
ChatModelVo selectFallbackModelByCategoryAndLessPriority(String category, Integer currentPriority);
/**
* 获取ppt模型信息
*/

View File

@@ -137,6 +137,33 @@ public class ChatModelServiceImpl implements IChatModelService {
return baseMapper.selectVoOne(Wrappers.<ChatModel>lambdaQuery().eq(ChatModel::getCategory, category));
}
/**
* 通过模型分类获取优先级最高的模型信息
*/
@Override
public ChatModelVo selectModelByCategoryWithHighestPriority(String category) {
return baseMapper.selectVoOne(
Wrappers.<ChatModel>lambdaQuery()
.eq(ChatModel::getCategory, category)
.orderByDesc(ChatModel::getPriority)
.last("LIMIT 1")
);
}
/**
* 在同一分类下,查找优先级小于当前优先级的最高优先级模型(用于降级)。
*/
@Override
public ChatModelVo selectFallbackModelByCategoryAndLessPriority(String category, Integer currentPriority) {
return baseMapper.selectVoOne(
Wrappers.<ChatModel>lambdaQuery()
.eq(ChatModel::getCategory, category)
.lt(ChatModel::getPriority, currentPriority)
.orderByDesc(ChatModel::getPriority)
.last("LIMIT 1")
);
}
@Override
public ChatModel getPPT() {
return baseMapper.selectOne(Wrappers.<ChatModel>lambdaQuery().eq(ChatModel::getModelName, "ppt"));

View File

@@ -14,6 +14,8 @@ import org.springframework.stereotype.Component;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.Objects;
import org.ruoyi.chat.support.RetryNotifier;
import org.ruoyi.chat.util.SSEUtil;
@Slf4j
@Component
@@ -21,12 +23,18 @@ import java.util.Objects;
public class FastGPTSSEEventSourceListener extends EventSourceListener {
private SseEmitter emitter;
private Long sessionId;
@Autowired(required = false)
public FastGPTSSEEventSourceListener(SseEmitter emitter) {
this.emitter = emitter;
}
public FastGPTSSEEventSourceListener(SseEmitter emitter, Long sessionId) {
this.emitter = emitter;
this.sessionId = sessionId;
}
@Override
public void onOpen(EventSource eventSource, Response response) {
log.info("FastGPT sse连接成功");
@@ -40,6 +48,7 @@ public class FastGPTSSEEventSourceListener extends EventSourceListener {
if ("flowResponses".equals(type)){
emitter.send(data);
emitter.complete();
RetryNotifier.clear(emitter);
} else {
emitter.send(data);
}
@@ -57,13 +66,20 @@ public class FastGPTSSEEventSourceListener extends EventSourceListener {
@SneakyThrows
public void onFailure(EventSource eventSource, Throwable t, Response response) {
if (Objects.isNull(response)) {
SSEUtil.sendErrorEvent(emitter, t != null ? t.getMessage() : "SSE连接失败");
RetryNotifier.notifyFailure(emitter);
return;
}
ResponseBody body = response.body();
if (Objects.nonNull(body)) {
log.error("FastGPT sse连接异常data{},异常:{}", body.string(), t);
String msg = body.string();
log.error("FastGPT sse连接异常data{},异常:{}", msg, t);
SSEUtil.sendErrorEvent(emitter, msg);
RetryNotifier.notifyFailure(emitter);
} else {
log.error("FastGPT sse连接异常data{},异常:{}", response, t);
SSEUtil.sendErrorEvent(emitter, String.valueOf(response));
RetryNotifier.notifyFailure(emitter);
}
eventSource.cancel();
}

View File

@@ -21,6 +21,8 @@ import org.ruoyi.common.core.utils.SpringUtils;
import org.ruoyi.common.core.utils.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.ruoyi.chat.util.SSEUtil;
import org.ruoyi.chat.support.RetryNotifier;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.Objects;
@@ -44,12 +46,15 @@ public class SSEEventSourceListener extends EventSourceListener {
private String token;
private boolean retryEnabled;
@Autowired(required = false)
public SSEEventSourceListener(SseEmitter emitter,Long userId,Long sessionId, String token) {
public SSEEventSourceListener(SseEmitter emitter,Long userId,Long sessionId, String token, boolean retryEnabled) {
this.emitter = emitter;
this.userId = userId;
this.sessionId = sessionId;
this.token = token;
this.retryEnabled = retryEnabled;
}
@@ -77,6 +82,8 @@ public class SSEEventSourceListener extends EventSourceListener {
if ("[DONE]".equals(data)) {
//成功响应
emitter.complete();
// 清理失败回调(以 emitter 为键)
RetryNotifier.clear(emitter);
// 扣除费用
ChatRequest chatRequest = new ChatRequest();
// 设置对话角色
@@ -113,19 +120,38 @@ public class SSEEventSourceListener extends EventSourceListener {
@Override
public void onClosed(EventSource eventSource) {
log.info("OpenAI关闭sse连接...");
// 清理失败回调
RetryNotifier.clear(emitter);
}
@SneakyThrows
@Override
public void onFailure(EventSource eventSource, Throwable t, Response response) {
if (Objects.isNull(response)) {
// 透传错误到前端
SSEUtil.sendErrorEvent(emitter, t != null ? t.getMessage() : "SSE连接失败");
if (retryEnabled) {
// 通知重试(以 emitter 为键)
RetryNotifier.notifyFailure(emitter);
} else {
emitter.complete();
}
return;
}
ResponseBody body = response.body();
if (Objects.nonNull(body)) {
log.error("OpenAI sse连接异常data{},异常:{}", body.string(), t);
String msg = body.string();
log.error("OpenAI sse连接异常data{},异常:{}", msg, t);
SSEUtil.sendErrorEvent(emitter, msg);
} else {
log.error("OpenAI sse连接异常data{},异常:{}", response, t);
SSEUtil.sendErrorEvent(emitter, String.valueOf(response));
}
if (retryEnabled) {
// 通知重试
RetryNotifier.notifyFailure(emitter);
} else {
emitter.complete();
}
eventSource.cancel();
}

View File

@@ -20,6 +20,8 @@ import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.Collections;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.ruoyi.chat.support.RetryNotifier;
import org.ruoyi.chat.support.ChatServiceHelper;
/**
* 扣子聊天管理
@@ -53,6 +55,7 @@ public class CozeServiceImpl implements IChatService {
Flowable<ChatEvent> resp = coze.chat().stream(req);
ExecutorService executor = Executors.newFixedThreadPool(10);
executor.submit(() -> {
try {
resp.blockingForEach(
event -> {
if (ChatEventType.CONVERSATION_MESSAGE_DELTA.equals(event.getEvent())) {
@@ -62,10 +65,15 @@ public class CozeServiceImpl implements IChatService {
if (ChatEventType.CONVERSATION_CHAT_COMPLETED.equals(event.getEvent())) {
emitter.complete();
log.info("Token usage: {}", event.getChat().getUsage().getTokenCount());
RetryNotifier.clear(emitter);
}
}
);
} catch (Exception ex) {
ChatServiceHelper.onStreamError(emitter, ex.getMessage());
} finally {
coze.shutdownExecutor();
}
});

View File

@@ -9,13 +9,13 @@ import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.chat.enums.ChatModeType;
import org.ruoyi.chat.service.chat.IChatService;
import org.ruoyi.chat.support.ChatServiceHelper;
import org.ruoyi.common.chat.request.ChatRequest;
import org.ruoyi.domain.vo.ChatModelVo;
import org.ruoyi.service.IChatModelService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
/**
* deepseek
*/
@@ -57,11 +57,14 @@ public class DeepSeekChatImpl implements IChatService {
@Override
public void onError(Throwable error) {
System.err.println("错误: " + error.getMessage());
ChatServiceHelper.onStreamError(emitter, error.getMessage());
}
});
} catch (Exception e) {
log.error("deepseek请求失败{}", e.getMessage());
// 同步异常直接通知失败
ChatServiceHelper.onStreamError(emitter, e.getMessage());
}
return emitter;

View File

@@ -25,8 +25,10 @@ import org.ruoyi.service.IChatSessionService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import org.ruoyi.chat.support.ChatServiceHelper;
import java.util.Objects;
import org.ruoyi.chat.support.RetryNotifier;
/**
* dify 聊天管理
@@ -112,20 +114,24 @@ public class DifyServiceImpl implements IChatService {
chatRequestResponse.setSessionId(chatRequest.getSessionId());
chatRequestResponse.setPrompt(respMessage.toString());
chatCostService.deductToken(chatRequestResponse);
RetryNotifier.clear(emitter);
}
@Override
public void onError(ErrorEvent event) {
System.err.println("错误: " + event.getMessage());
ChatServiceHelper.onStreamError(emitter, event.getMessage());
}
@Override
public void onException(Throwable throwable) {
System.err.println("异常: " + throwable.getMessage());
ChatServiceHelper.onStreamError(emitter, throwable.getMessage());
}
});
} catch (Exception e) {
log.error("dify请求失败{}", e.getMessage());
ChatServiceHelper.onStreamError(emitter, e.getMessage());
}
return emitter;

View File

@@ -33,7 +33,7 @@ public class FastGPTServiceImpl implements IChatService {
ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
OpenAiStreamClient openAiStreamClient = ChatConfig.createOpenAiStreamClient(chatModelVo.getApiHost(), chatModelVo.getApiKey());
List<Message> messages = chatRequest.getMessages();
FastGPTSSEEventSourceListener listener = new FastGPTSSEEventSourceListener(emitter);
FastGPTSSEEventSourceListener listener = new FastGPTSSEEventSourceListener(emitter, chatRequest.getSessionId());
FastGPTChatCompletion completion = FastGPTChatCompletion
.builder()
.messages(messages)
@@ -41,7 +41,12 @@ public class FastGPTServiceImpl implements IChatService {
.detail(true)
.stream(true)
.build();
try {
openAiStreamClient.streamChatCompletion(completion, listener);
} catch (Exception ex) {
org.ruoyi.chat.support.RetryNotifier.notifyFailure(chatRequest.getSessionId());
throw ex;
}
return emitter;
}

View File

@@ -18,6 +18,7 @@ import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.*;
import org.ruoyi.chat.support.ChatServiceHelper;
/**
* 图片识别模型
@@ -128,10 +129,10 @@ public class ImageServiceImpl implements IChatService {
OpenAiStreamClient openAiStreamClient = ChatConfig.createOpenAiStreamClient(chatModelVo.getApiHost(), chatModelVo.getApiKey());
List<Message> messages = chatRequest.getMessages();
// 获取会话token
String token = StpUtil.getTokenValue();
// 获取会话token从入口透传避免非Web线程取值报错
String token = chatRequest.getToken();
// 创建 SSE 事件源监听器
SSEEventSourceListener listener = new SSEEventSourceListener(emitter, chatRequest.getUserId(), chatRequest.getSessionId(), token);
SSEEventSourceListener listener = ChatServiceHelper.createOpenAiListener(emitter, chatRequest);
// 构建聊天完成请求
ChatCompletion completion = ChatCompletion
@@ -142,7 +143,12 @@ public class ImageServiceImpl implements IChatService {
.build();
// 发起流式聊天完成请求
try {
openAiStreamClient.streamChatCompletion(completion, listener);
} catch (Exception ex) {
ChatServiceHelper.onStreamError(emitter, ex.getMessage());
throw ex;
}
return emitter;
}

View File

@@ -22,6 +22,8 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import org.ruoyi.chat.support.RetryNotifier;
import org.ruoyi.chat.support.ChatServiceHelper;
/**
@@ -65,13 +67,14 @@ public class OllamaServiceImpl implements IChatService {
try {
emitter.send(substr);
} catch (IOException e) {
SSEUtil.sendErrorEvent(emitter, e.getMessage());
ChatServiceHelper.onStreamError(emitter, e.getMessage());
}
};
api.chat(requestModel, streamHandler);
emitter.complete();
RetryNotifier.clear(emitter);
} catch (Exception e) {
SSEUtil.sendErrorEvent(emitter, e.getMessage());
ChatServiceHelper.onStreamError(emitter, e.getMessage());
}
});

View File

@@ -1,12 +1,12 @@
package org.ruoyi.chat.service.chat.impl;
import cn.dev33.satoken.stp.StpUtil;
import io.modelcontextprotocol.client.McpSyncClient;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.chat.config.ChatConfig;
import org.ruoyi.chat.enums.ChatModeType;
import org.ruoyi.chat.listener.SSEEventSourceListener;
import org.ruoyi.chat.service.chat.IChatService;
import org.ruoyi.chat.support.ChatServiceHelper;
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
import org.ruoyi.common.chat.entity.chat.Message;
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
@@ -57,15 +57,19 @@ public class OpenAIServiceImpl implements IChatService {
Message userMessage = Message.builder().content("工具返回信息:"+toolString).role(Message.Role.USER).build();
messages.add(userMessage);
}
String token = StpUtil.getTokenValue();
SSEEventSourceListener listener = new SSEEventSourceListener(emitter,chatRequest.getUserId(),chatRequest.getSessionId(), token);
SSEEventSourceListener listener = ChatServiceHelper.createOpenAiListener(emitter, chatRequest);
ChatCompletion completion = ChatCompletion
.builder()
.messages(messages)
.model(chatRequest.getModel())
.stream(true)
.build();
try {
openAiStreamClient.streamChatCompletion(completion, listener);
} catch (Exception ex) {
ChatServiceHelper.onStreamError(emitter, ex.getMessage());
throw ex;
}
return emitter;
}

View File

@@ -14,6 +14,7 @@ import org.ruoyi.service.IChatModelService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import org.ruoyi.chat.support.ChatServiceHelper;
/**
@@ -51,15 +52,18 @@ public class QianWenAiChatServiceImpl implements IChatService {
public void onCompleteResponse(ChatResponse completeResponse) {
emitter.complete();
log.info("消息结束完整消息ID: {}", completeResponse);
org.ruoyi.chat.support.RetryNotifier.clear(emitter);
}
@Override
public void onError(Throwable error) {
error.printStackTrace();
ChatServiceHelper.onStreamError(emitter, error.getMessage());
}
});
} catch (Exception e) {
log.error("千问请求失败:{}", e.getMessage());
ChatServiceHelper.onStreamError(emitter, e.getMessage());
}
return emitter;

View File

@@ -9,6 +9,8 @@ import org.ruoyi.chat.factory.ChatServiceFactory;
import org.ruoyi.chat.service.chat.IChatCostService;
import org.ruoyi.chat.service.chat.IChatService;
import org.ruoyi.chat.service.chat.ISseService;
import org.ruoyi.chat.support.ChatRetryHelper;
import org.ruoyi.chat.support.RetryNotifier;
import org.ruoyi.chat.util.SSEUtil;
import org.ruoyi.common.chat.entity.Tts.TextToSpeech;
import org.ruoyi.common.chat.entity.chat.Message;
@@ -45,6 +47,7 @@ import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import cn.dev33.satoken.stp.StpUtil;
/**
* @author ageer
@@ -75,6 +78,12 @@ public class SseServiceImpl implements ISseService {
public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) {
SseEmitter sseEmitter = new SseEmitter(0L);
try {
// 记录当前会话令牌,供异步线程使用
try {
chatRequest.setToken(StpUtil.getTokenValue());
} catch (Exception ignore) {
// 保底无token场景下忽略
}
// 构建消息列表
buildChatMessageList(chatRequest);
// 设置对话角色
@@ -113,9 +122,34 @@ public class SseServiceImpl implements ISseService {
chatRequest.setSessionId(chatSessionBo.getId());
}
}
// 根据模型分类调用不同的处理逻辑
IChatService chatService = chatServiceFactory.getChatService(chatModelVo.getCategory());
// 自动选择模型并获取对应的聊天服务
IChatService chatService = autoSelectModelAndGetService(chatRequest);
// 仅当 autoSelectModel = true 时,才启用重试与降级
if (Boolean.TRUE.equals(chatRequest.getAutoSelectModel())) {
ChatModelVo currentModel = this.chatModelVo;
String currentCategory = currentModel.getCategory();
ChatRetryHelper.executeWithRetry(
currentModel,
currentCategory,
chatModelService,
sseEmitter,
(modelForTry, onFailure) -> {
// 替换请求中的模型名称
chatRequest.setModel(modelForTry.getModelName());
// 以 emitter 实例为唯一键注册失败回调
RetryNotifier.setFailureCallback(sseEmitter, onFailure);
try {
autoSelectServiceByCategoryAndInvoke(chatRequest, sseEmitter, modelForTry.getCategory());
} finally {
// 不在此处清理,待下游结束/失败时清理
}
}
);
} else {
// 不重试不降级,直接调用
chatService.chat(chatRequest, sseEmitter);
}
} catch (Exception e) {
log.error(e.getMessage(),e);
SSEUtil.sendErrorEvent(sseEmitter,e.getMessage());
@@ -123,6 +157,51 @@ public class SseServiceImpl implements ISseService {
return sseEmitter;
}
/**
* 自动选择模型并获取对应的聊天服务
*/
private IChatService autoSelectModelAndGetService(ChatRequest chatRequest) {
try {
if (Boolean.TRUE.equals(chatRequest.getHasAttachment())) {
chatModelVo = selectModelByCategory("image");
} else if (Boolean.TRUE.equals(chatRequest.getAutoSelectModel())) {
chatModelVo = selectModelByCategory("chat");
} else {
chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
}
if (chatModelVo == null) {
throw new IllegalStateException("未找到模型名称:" + chatRequest.getModel());
}
// 自动设置请求参数中的模型名称
chatRequest.setModel(chatModelVo.getModelName());
// 直接返回对应的聊天服务
return chatServiceFactory.getChatService(chatModelVo.getCategory());
} catch (Exception e) {
log.error("模型选择和服务获取失败: {}", e.getMessage(), e);
throw new IllegalStateException("模型选择和服务获取失败: " + e.getMessage());
}
}
/**
* 根据给定分类获取服务并发起调用(避免在降级时重复选择模型)
*/
private void autoSelectServiceByCategoryAndInvoke(ChatRequest chatRequest, SseEmitter sseEmitter, String category) {
IChatService service = chatServiceFactory.getChatService(category);
service.chat(chatRequest, sseEmitter);
}
/**
* 根据分类选择优先级最高的模型
*/
private ChatModelVo selectModelByCategory(String category) {
ChatModelVo model = chatModelService.selectModelByCategoryWithHighestPriority(category);
if (model == null) {
throw new IllegalStateException("未找到" + category + "分类的模型配置");
}
return model;
}
/**
* 获取对话标题
*
@@ -144,27 +223,80 @@ public class SseServiceImpl implements ISseService {
* 构建消息列表
*/
private void buildChatMessageList(ChatRequest chatRequest){
String sysPrompt;
// 矫正模型名称 如果是gpt-image 则查询image类型模型 获取模型名称
if(chatRequest.getModel().equals("gpt-image")) {
chatModelVo = chatModelService.selectModelByCategory("image");
if (chatModelVo == null) {
log.error("未找到image类型的模型配置");
throw new IllegalStateException("未找到image类型的模型配置");
List<Message> messages = chatRequest.getMessages();
// 处理知识库相关逻辑
String sysPrompt = processKnowledgeBase(chatRequest, messages);
// 设置系统提示词
Message sysMessage = Message.builder()
.content(sysPrompt)
.role(Message.Role.SYSTEM)
.build();
messages.add(0, sysMessage);
chatRequest.setSysPrompt(sysPrompt);
// 用户对话内容
String chatString = null;
// 获取用户对话信息
Object content = messages.get(messages.size() - 1).getContent();
if (content instanceof List<?> listContent) {
if (CollectionUtil.isNotEmpty(listContent)) {
chatString = listContent.get(0).toString();
}
} else {
chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
chatString = content.toString();
}
// 获取对话消息列表
List<Message> messages = chatRequest.getMessages();
// 查询向量库相关信息加入到上下文
if(StringUtils.isNotEmpty(chatRequest.getKid())){
List<Message> knMessages = new ArrayList<>();
String content = messages.get(messages.size() - 1).getContent().toString();
// 通过kid查询知识库信息
chatRequest.setPrompt(chatString);
}
/**
* 处理知识库相关逻辑
*/
private String processKnowledgeBase(ChatRequest chatRequest, List<Message> messages) {
if (StringUtils.isEmpty(chatRequest.getKid())) {
return getDefaultSystemPrompt();
}
try {
// 查询知识库信息
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(chatRequest.getKid()));
if (knowledgeInfoVo == null) {
log.warn("知识库信息不存在kid: {}", chatRequest.getKid());
return getDefaultSystemPrompt();
}
// 查询向量模型配置信息
ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModelName());
if (chatModel == null) {
log.warn("向量模型配置不存在,模型名称: {}", knowledgeInfoVo.getEmbeddingModelName());
return getDefaultSystemPrompt();
}
// 构建向量查询参数
QueryVectorBo queryVectorBo = buildQueryVectorBo(chatRequest, knowledgeInfoVo, chatModel);
// 获取向量查询结果
List<String> nearestList = vectorStoreService.getQueryVector(queryVectorBo);
// 添加知识库消息到上下文
addKnowledgeMessages(messages, nearestList);
// 返回知识库系统提示词
return getKnowledgeSystemPrompt(knowledgeInfoVo);
} catch (Exception e) {
log.error("处理知识库信息失败: {}", e.getMessage(), e);
return getDefaultSystemPrompt();
}
}
/**
* 构建向量查询参数
*/
private QueryVectorBo buildQueryVectorBo(ChatRequest chatRequest, KnowledgeInfoVo knowledgeInfoVo, ChatModelVo chatModel) {
String content = chatRequest.getMessages().get(chatRequest.getMessages().size() - 1).getContent().toString();
QueryVectorBo queryVectorBo = new QueryVectorBo();
queryVectorBo.setQuery(content);
@@ -174,49 +306,51 @@ public class SseServiceImpl implements ISseService {
queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModelName());
queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModelName());
queryVectorBo.setMaxResults(knowledgeInfoVo.getRetrieveLimit());
List<String> nearestList = vectorStoreService.getQueryVector(queryVectorBo);
for (String prompt : nearestList) {
Message userMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
knMessages.add(userMessage);
return queryVectorBo;
}
messages.addAll(knMessages);
// 设置知识库系统提示词
sysPrompt = knowledgeInfoVo.getSystemPrompt();
/**
* 添加知识库消息到上下文
*/
private void addKnowledgeMessages(List<Message> messages, List<String> nearestList) {
for (String prompt : nearestList) {
Message userMessage = Message.builder()
.content(prompt)
.role(Message.Role.USER)
.build();
messages.add(userMessage);
}
}
/**
* 获取知识库系统提示词
*/
private String getKnowledgeSystemPrompt(KnowledgeInfoVo knowledgeInfoVo) {
String sysPrompt = knowledgeInfoVo.getSystemPrompt();
if (StringUtils.isEmpty(sysPrompt)) {
sysPrompt = "###角色设定\n" +
"你是一个智能知识助手,专注于利用上下文中的信息来提供准确和相关的回答。\n" +
"###指令\n" +
"当用户的问题与上下文知识匹配时,利用上下文信息进行回答。如果问题与上下文不匹配,运用自身的推理能力生成合适的回答。\n" +
"###限制\n" +
"确保回答清晰简洁,避免提供不必要的细节。始终保持语气友好" +
"确保回答清晰简洁,避免提供不必要的细节。始终保持语气友好\n" +
"当前时间:" + DateUtils.getDate();
}
}else {
sysPrompt = chatModelVo.getSystemPrompt();
return sysPrompt;
}
/**
* 获取默认系统提示词
*/
private String getDefaultSystemPrompt() {
String sysPrompt = chatModelVo != null ? chatModelVo.getSystemPrompt() : null;
if (StringUtils.isEmpty(sysPrompt)) {
sysPrompt = "你是一个由RuoYI-AI开发的人工智能助手名字叫熊猫助手。你擅长中英文对话能够理解并处理各种问题提供安全、有帮助、准确的回答。" +
"当前时间:" + DateUtils.getDate() +
"#注意:回复之前注意结合上下文和工具返回内容进行回复。";
}
}
// 设置系统默认提示词
Message sysMessage = Message.builder().content(sysPrompt).role(Message.Role.SYSTEM).build();
messages.add(0,sysMessage);
chatRequest.setSysPrompt(sysPrompt);
// 用户对话内容
String chatString = null;
// 获取用户对话信息
Object content = messages.get(messages.size() - 1).getContent();
if (content instanceof List<?> listContent) {
if (CollectionUtil.isNotEmpty(listContent)) {
chatString = listContent.get(0).toString();
}
} else if (content instanceof String) {
chatString = (String) content;
}
// 设置对话信息
chatRequest.setPrompt(chatString);
return sysPrompt;
}

View File

@@ -15,6 +15,7 @@ import org.ruoyi.service.IChatModelService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import org.ruoyi.chat.support.ChatServiceHelper;
@@ -51,14 +52,14 @@ public class ZhipuAiChatServiceImpl implements IChatService {
@SneakyThrows
@Override
public void onError(Throwable error) {
// System.out.println(error.getMessage());
emitter.send(error.getMessage());
ChatServiceHelper.onStreamError(emitter, error.getMessage());
}
@Override
public void onCompleteResponse(ChatResponse response) {
emitter.complete();
log.info("消息结束完整消息ID: {}", response.aiMessage());
org.ruoyi.chat.support.RetryNotifier.clear(emitter);
}
};
@@ -71,6 +72,7 @@ public class ZhipuAiChatServiceImpl implements IChatService {
model.chat(chatRequest.getPrompt(), handler);
} catch (Exception e) {
log.error("智谱清言请求失败:{}", e.getMessage());
ChatServiceHelper.onStreamError(emitter, e.getMessage());
}
return emitter;

View File

@@ -0,0 +1,115 @@
package org.ruoyi.chat.support;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.chat.util.SSEUtil;
import org.ruoyi.domain.vo.ChatModelVo;
import org.ruoyi.service.IChatModelService;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
/**
* 统一的聊天重试与降级调度器。
*
* 策略:
* - 当前模型最多重试 3 次;仍失败则降级到同分类内、优先级小于当前的最高优先级模型。
* - 降级模型同样最多重试 3 次;仍失败则向前端返回失败信息并停止。
*
* 注意:实现依赖调用方在底层异步失败时执行 onFailure.run() 通知本调度器。
*/
@Slf4j
public class ChatRetryHelper {
public interface AttemptStarter {
void start(ChatModelVo model, Runnable onFailure) throws Exception;
}
public static void executeWithRetry(
ChatModelVo primaryModel,
String category,
IChatModelService chatModelService,
SseEmitter emitter,
AttemptStarter attemptStarter
) {
Objects.requireNonNull(primaryModel, "primaryModel must not be null");
Objects.requireNonNull(category, "category must not be null");
Objects.requireNonNull(chatModelService, "chatModelService must not be null");
Objects.requireNonNull(emitter, "emitter must not be null");
Objects.requireNonNull(attemptStarter, "attemptStarter must not be null");
AtomicInteger mainAttempts = new AtomicInteger(0);
AtomicInteger fallbackAttempts = new AtomicInteger(0);
AtomicBoolean inFallback = new AtomicBoolean(false);
AtomicBoolean scheduling = new AtomicBoolean(false);
class Scheduler {
volatile ChatModelVo current = primaryModel;
volatile ChatModelVo fallback = null;
void startAttempt() {
try {
if (!inFallback.get()) {
if (mainAttempts.incrementAndGet() > 3) {
// 进入降级
inFallback.set(true);
if (fallback == null) {
Integer curPriority = primaryModel.getPriority();
if (curPriority == null) {
curPriority = Integer.MAX_VALUE;
}
fallback = chatModelService.selectFallbackModelByCategoryAndLessPriority(category, curPriority);
}
if (fallback == null) {
SSEUtil.sendErrorEvent(emitter, "当前模型重试3次均失败且无可用降级模型");
emitter.complete();
return;
}
current = fallback;
mainAttempts.set(3); // 锁定
fallbackAttempts.set(0);
}
} else {
if (fallbackAttempts.incrementAndGet() > 3) {
SSEUtil.sendErrorEvent(emitter, "降级模型重试3次仍失败");
emitter.complete();
return;
}
}
Runnable onFailure = () -> {
// 去抖:避免同一次失败触发多次重试
if (scheduling.compareAndSet(false, true)) {
try {
SSEUtil.sendErrorEvent(emitter, (inFallback.get() ? "降级模型" : "当前模型") + "调用失败,准备重试...");
// 立即发起下一次尝试
startAttempt();
} finally {
scheduling.set(false);
}
}
};
attemptStarter.start(current, onFailure);
} catch (Exception ex) {
log.error("启动聊天尝试失败: {}", ex.getMessage(), ex);
SSEUtil.sendErrorEvent(emitter, "启动聊天尝试失败: " + ex.getMessage());
// 直接按失败处理,继续重试/降级
if (scheduling.compareAndSet(false, true)) {
try {
startAttempt();
} finally {
scheduling.set(false);
}
}
}
}
}
new Scheduler().startAttempt();
}
}

View File

@@ -0,0 +1,45 @@
package org.ruoyi.chat.support;
import org.ruoyi.chat.listener.SSEEventSourceListener;
import org.ruoyi.common.chat.request.ChatRequest;
import org.ruoyi.chat.util.SSEUtil;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
/**
* 抽取各聊天实现类的通用逻辑:
* - 创建带开关的 SSE 监听器
* - 统一的流错误处理(根据是否在重试场景决定通知或直接结束)
* - 统一的完成处理(清理回调并 complete
*/
public class ChatServiceHelper {
public static SSEEventSourceListener createOpenAiListener(SseEmitter emitter, ChatRequest chatRequest) {
boolean retryEnabled = Boolean.TRUE.equals(chatRequest.getAutoSelectModel());
return new SSEEventSourceListener(
emitter,
chatRequest.getUserId(),
chatRequest.getSessionId(),
chatRequest.getToken(),
retryEnabled
);
}
public static void onStreamError(SseEmitter emitter, String errorMessage) {
SSEUtil.sendErrorEvent(emitter, errorMessage);
if (RetryNotifier.hasCallback(emitter)) {
RetryNotifier.notifyFailure(emitter);
} else {
emitter.complete();
}
}
public static void onStreamComplete(SseEmitter emitter) {
try {
emitter.complete();
} finally {
RetryNotifier.clear(emitter);
}
}
}

View File

@@ -0,0 +1,51 @@
package org.ruoyi.chat.support;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
/**
* 失败回调通知器基于发射器实例SseEmitter 等对象地址)绑定回调,
* 避免与业务标识绑定,且能跨线程正确关联。
*/
public class RetryNotifier {
private static final Map<Integer, Runnable> FAILURE_CALLBACKS = new ConcurrentHashMap<>();
private static int keyOf(Object obj) {
return System.identityHashCode(obj);
}
public static void setFailureCallback(Object emitterLike, Runnable callback) {
if (emitterLike == null || callback == null) {
return;
}
FAILURE_CALLBACKS.put(keyOf(emitterLike), callback);
}
public static void clear(Object emitterLike) {
if (emitterLike == null) {
return;
}
FAILURE_CALLBACKS.remove(keyOf(emitterLike));
}
public static void notifyFailure(Object emitterLike) {
if (emitterLike == null) {
return;
}
Runnable cb = FAILURE_CALLBACKS.get(keyOf(emitterLike));
if (Objects.nonNull(cb)) {
cb.run();
}
}
public static boolean hasCallback(Object emitterLike) {
if (emitterLike == null) {
return false;
}
return FAILURE_CALLBACKS.containsKey(keyOf(emitterLike));
}
}

View File

@@ -25,6 +25,6 @@ public class SSEUtil {
} catch (IOException e) {
log.error("SSE发送失败: {}", e.getMessage());
}
sseEmitter.complete();
// 不立即关闭,由上层策略决定是否继续重试或降级
}
}

View File

@@ -0,0 +1,26 @@
alter table chat_model
add priority int default 1 null comment '模型优先级(值越大优先级越高)';
UPDATE `ruoyi-ai`.chat_model t
SET t.priority = 3
WHERE t.id = 1782792839548735492;
UPDATE `ruoyi-ai`.chat_model t
SET t.priority = 6
WHERE t.id = 1859570229117022212;
UPDATE `ruoyi-ai`.chat_model t
SET t.priority = 5
WHERE t.id = 1859570229117022211;
UPDATE `ruoyi-ai`.chat_model t
SET t.priority = 4
WHERE t.id = 1782792839548735493;
UPDATE `ruoyi-ai`.chat_model t
SET t.priority = 2
WHERE t.id = 1828324413241466881;
UPDATE `ruoyi-ai`.chat_model t
SET t.priority = 2
WHERE t.id = 1782792839548735491;