mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-04-05 15:57:32 +00:00
feat: 调整知识库问答接入提示词模板
This commit is contained in:
@@ -0,0 +1,22 @@
|
||||
package org.ruoyi.chat.enums;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
|
||||
/**
|
||||
* 提示词模板分类
|
||||
*
|
||||
* @author evo
|
||||
*/
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public enum promptTemplateEnum {
|
||||
CHAT(1, "chat"),
|
||||
VECTOR(2, "vector"),
|
||||
;
|
||||
|
||||
private final Integer code;
|
||||
private final String desc;
|
||||
|
||||
}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package org.ruoyi.chat.service.chat.impl;
|
||||
|
||||
import cn.dev33.satoken.stp.StpUtil;
|
||||
import cn.hutool.core.collection.CollectionUtil;
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.ResponseBody;
|
||||
import org.ruoyi.chat.enums.promptTemplateEnum;
|
||||
import org.ruoyi.chat.factory.ChatServiceFactory;
|
||||
import org.ruoyi.chat.service.chat.IChatCostService;
|
||||
import org.ruoyi.chat.service.chat.IChatService;
|
||||
@@ -27,9 +29,11 @@ import org.ruoyi.domain.bo.ChatSessionBo;
|
||||
import org.ruoyi.domain.bo.QueryVectorBo;
|
||||
import org.ruoyi.domain.vo.ChatModelVo;
|
||||
import org.ruoyi.domain.vo.KnowledgeInfoVo;
|
||||
import org.ruoyi.domain.vo.PromptTemplateVo;
|
||||
import org.ruoyi.service.IChatModelService;
|
||||
import org.ruoyi.service.IChatSessionService;
|
||||
import org.ruoyi.service.IKnowledgeInfoService;
|
||||
import org.ruoyi.service.IPromptTemplateService;
|
||||
import org.ruoyi.service.VectorStoreService;
|
||||
import org.springframework.core.io.InputStreamResource;
|
||||
import org.springframework.core.io.Resource;
|
||||
@@ -45,9 +49,8 @@ import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import cn.dev33.satoken.stp.StpUtil;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* @author ageer
|
||||
@@ -73,6 +76,9 @@ public class SseServiceImpl implements ISseService {
|
||||
|
||||
private ChatModelVo chatModelVo;
|
||||
|
||||
// 提示词模板服务
|
||||
private final IPromptTemplateService promptTemplateService;
|
||||
|
||||
|
||||
@Override
|
||||
public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) {
|
||||
@@ -89,9 +95,9 @@ public class SseServiceImpl implements ISseService {
|
||||
// 设置对话角色
|
||||
chatRequest.setRole(Message.Role.USER.getName());
|
||||
|
||||
if(LoginHelper.isLogin()){
|
||||
if (LoginHelper.isLogin()) {
|
||||
|
||||
// 设置用户id
|
||||
// 设置用户id
|
||||
chatRequest.setUserId(LoginHelper.getUserId());
|
||||
|
||||
|
||||
@@ -100,10 +106,10 @@ public class SseServiceImpl implements ISseService {
|
||||
//待优化的地方 (这里请前端提交send的时候传递uuid进来或者sessionId)
|
||||
{
|
||||
// 设置会话id
|
||||
if (chatRequest.getUuid() == null){
|
||||
if (chatRequest.getUuid() == null) {
|
||||
//暂时随机生成会话id
|
||||
chatRequest.setSessionId(System.currentTimeMillis());
|
||||
}else{
|
||||
} else {
|
||||
//这里或许需要修改一下,这里应该用uuid 或者 前端传递 sessionId
|
||||
chatRequest.setSessionId(chatRequest.getUuid());
|
||||
}
|
||||
@@ -113,7 +119,7 @@ public class SseServiceImpl implements ISseService {
|
||||
// 保存消息记录 并扣除费用
|
||||
chatCostService.deductToken(chatRequest);
|
||||
chatRequest.setUserId(chatCostService.getUserId());
|
||||
if(chatRequest.getSessionId()==null){
|
||||
if (chatRequest.getSessionId() == null) {
|
||||
ChatSessionBo chatSessionBo = new ChatSessionBo();
|
||||
chatSessionBo.setUserId(chatCostService.getUserId());
|
||||
chatSessionBo.setSessionTitle(getFirst10Characters(chatRequest.getPrompt()));
|
||||
@@ -130,29 +136,30 @@ public class SseServiceImpl implements ISseService {
|
||||
ChatModelVo currentModel = this.chatModelVo;
|
||||
String currentCategory = currentModel.getCategory();
|
||||
ChatRetryHelper.executeWithRetry(
|
||||
currentModel,
|
||||
currentCategory,
|
||||
chatModelService,
|
||||
sseEmitter,
|
||||
(modelForTry, onFailure) -> {
|
||||
// 替换请求中的模型名称
|
||||
chatRequest.setModel(modelForTry.getModelName());
|
||||
// 以 emitter 实例为唯一键注册失败回调
|
||||
RetryNotifier.setFailureCallback(sseEmitter, onFailure);
|
||||
try {
|
||||
autoSelectServiceByCategoryAndInvoke(chatRequest, sseEmitter, modelForTry.getCategory());
|
||||
} finally {
|
||||
// 不在此处清理,待下游结束/失败时清理
|
||||
currentModel,
|
||||
currentCategory,
|
||||
chatModelService,
|
||||
sseEmitter,
|
||||
(modelForTry, onFailure) -> {
|
||||
// 替换请求中的模型名称
|
||||
chatRequest.setModel(modelForTry.getModelName());
|
||||
// 以 emitter 实例为唯一键注册失败回调
|
||||
RetryNotifier.setFailureCallback(sseEmitter, onFailure);
|
||||
try {
|
||||
autoSelectServiceByCategoryAndInvoke(chatRequest, sseEmitter,
|
||||
modelForTry.getCategory());
|
||||
} finally {
|
||||
// 不在此处清理,待下游结束/失败时清理
|
||||
}
|
||||
}
|
||||
}
|
||||
);
|
||||
} else {
|
||||
// 不重试不降级,直接调用
|
||||
chatService.chat(chatRequest, sseEmitter);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error(e.getMessage(),e);
|
||||
SSEUtil.sendErrorEvent(sseEmitter,e.getMessage());
|
||||
log.error(e.getMessage(), e);
|
||||
SSEUtil.sendErrorEvent(sseEmitter, e.getMessage());
|
||||
}
|
||||
return sseEmitter;
|
||||
}
|
||||
@@ -169,7 +176,7 @@ public class SseServiceImpl implements ISseService {
|
||||
} else {
|
||||
chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
|
||||
}
|
||||
|
||||
|
||||
if (chatModelVo == null) {
|
||||
throw new IllegalStateException("未找到模型名称:" + chatRequest.getModel());
|
||||
}
|
||||
@@ -190,7 +197,7 @@ public class SseServiceImpl implements ISseService {
|
||||
IChatService service = chatServiceFactory.getChatService(category);
|
||||
service.chat(chatRequest, sseEmitter);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 根据分类选择优先级最高的模型
|
||||
*/
|
||||
@@ -220,23 +227,23 @@ public class SseServiceImpl implements ISseService {
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建消息列表
|
||||
* 构建消息列表
|
||||
*/
|
||||
private void buildChatMessageList(ChatRequest chatRequest){
|
||||
private void buildChatMessageList(ChatRequest chatRequest) {
|
||||
List<Message> messages = chatRequest.getMessages();
|
||||
|
||||
|
||||
// 处理知识库相关逻辑
|
||||
String sysPrompt = processKnowledgeBase(chatRequest, messages);
|
||||
|
||||
|
||||
// 设置系统提示词
|
||||
Message sysMessage = Message.builder()
|
||||
.content(sysPrompt)
|
||||
.role(Message.Role.SYSTEM)
|
||||
.build();
|
||||
messages.add(0, sysMessage);
|
||||
|
||||
|
||||
chatRequest.setSysPrompt(sysPrompt);
|
||||
|
||||
|
||||
// 用户对话内容
|
||||
String chatString = null;
|
||||
// 获取用户对话信息
|
||||
@@ -250,54 +257,55 @@ public class SseServiceImpl implements ISseService {
|
||||
}
|
||||
chatRequest.setPrompt(chatString);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 处理知识库相关逻辑
|
||||
*/
|
||||
private String processKnowledgeBase(ChatRequest chatRequest, List<Message> messages) {
|
||||
if (StringUtils.isEmpty(chatRequest.getKid())) {
|
||||
return getDefaultSystemPrompt();
|
||||
return getPromptTemplatePrompt(promptTemplateEnum.VECTOR.getDesc());
|
||||
}
|
||||
|
||||
|
||||
try {
|
||||
// 查询知识库信息
|
||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(chatRequest.getKid()));
|
||||
if (knowledgeInfoVo == null) {
|
||||
log.warn("知识库信息不存在,kid: {}", chatRequest.getKid());
|
||||
return getDefaultSystemPrompt();
|
||||
return getPromptTemplatePrompt(promptTemplateEnum.VECTOR.getDesc());
|
||||
}
|
||||
|
||||
|
||||
// 查询向量模型配置信息
|
||||
ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModelName());
|
||||
if (chatModel == null) {
|
||||
log.warn("向量模型配置不存在,模型名称: {}", knowledgeInfoVo.getEmbeddingModelName());
|
||||
return getDefaultSystemPrompt();
|
||||
return getPromptTemplatePrompt(promptTemplateEnum.VECTOR.getDesc());
|
||||
}
|
||||
|
||||
|
||||
// 构建向量查询参数
|
||||
QueryVectorBo queryVectorBo = buildQueryVectorBo(chatRequest, knowledgeInfoVo, chatModel);
|
||||
|
||||
|
||||
// 获取向量查询结果
|
||||
List<String> nearestList = vectorStoreService.getQueryVector(queryVectorBo);
|
||||
|
||||
|
||||
// 添加知识库消息到上下文
|
||||
addKnowledgeMessages(messages, nearestList);
|
||||
|
||||
|
||||
// 返回知识库系统提示词
|
||||
return getKnowledgeSystemPrompt(knowledgeInfoVo);
|
||||
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("处理知识库信息失败: {}", e.getMessage(), e);
|
||||
return getDefaultSystemPrompt();
|
||||
return getPromptTemplatePrompt(promptTemplateEnum.VECTOR.getDesc());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 构建向量查询参数
|
||||
*/
|
||||
private QueryVectorBo buildQueryVectorBo(ChatRequest chatRequest, KnowledgeInfoVo knowledgeInfoVo, ChatModelVo chatModel) {
|
||||
private QueryVectorBo buildQueryVectorBo(ChatRequest chatRequest, KnowledgeInfoVo knowledgeInfoVo,
|
||||
ChatModelVo chatModel) {
|
||||
String content = chatRequest.getMessages().get(chatRequest.getMessages().size() - 1).getContent().toString();
|
||||
|
||||
|
||||
QueryVectorBo queryVectorBo = new QueryVectorBo();
|
||||
queryVectorBo.setQuery(content);
|
||||
queryVectorBo.setKid(chatRequest.getKid());
|
||||
@@ -306,10 +314,10 @@ public class SseServiceImpl implements ISseService {
|
||||
queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModelName());
|
||||
queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModelName());
|
||||
queryVectorBo.setMaxResults(knowledgeInfoVo.getRetrieveLimit());
|
||||
|
||||
|
||||
return queryVectorBo;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 添加知识库消息到上下文
|
||||
*/
|
||||
@@ -322,7 +330,7 @@ public class SseServiceImpl implements ISseService {
|
||||
messages.add(userMessage);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 获取知识库系统提示词
|
||||
*/
|
||||
@@ -339,16 +347,29 @@ public class SseServiceImpl implements ISseService {
|
||||
}
|
||||
return sysPrompt;
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* 获取提示词模板提示词
|
||||
*/
|
||||
private String getPromptTemplatePrompt(String category) {
|
||||
PromptTemplateVo promptTemplateVo = promptTemplateService.queryByCategory(category);
|
||||
if (Objects.isNull(promptTemplateVo) || StringUtils.isEmpty(promptTemplateVo.getTemplateContent())) {
|
||||
return getDefaultSystemPrompt();
|
||||
}
|
||||
return promptTemplateVo.getTemplateContent();
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取默认系统提示词
|
||||
*/
|
||||
private String getDefaultSystemPrompt() {
|
||||
String sysPrompt = chatModelVo != null ? chatModelVo.getSystemPrompt() : null;
|
||||
if (StringUtils.isEmpty(sysPrompt)) {
|
||||
sysPrompt = "你是一个由RuoYI-AI开发的人工智能助手,名字叫熊猫助手。你擅长中英文对话,能够理解并处理各种问题,提供安全、有帮助、准确的回答。" +
|
||||
"当前时间:" + DateUtils.getDate() +
|
||||
"#注意:回复之前注意结合上下文和工具返回内容进行回复。";
|
||||
sysPrompt = "你是一个由RuoYI-AI开发的人工智能助手,名字叫RuoYI人工智能助手。"
|
||||
+ "你擅长中英文对话,能够理解并处理各种问题,提供安全、有帮助、准确的回答。"
|
||||
+ "当前时间:" + DateUtils.getDate()
|
||||
+ "#注意:回复之前注意结合上下文和工具返回内容进行回复。";
|
||||
}
|
||||
return sysPrompt;
|
||||
}
|
||||
@@ -365,8 +386,8 @@ public class SseServiceImpl implements ISseService {
|
||||
InputStreamResource resource = new InputStreamResource(body.byteStream());
|
||||
// 创建并返回ResponseEntity
|
||||
return ResponseEntity.ok()
|
||||
.contentType(MediaType.parseMediaType("audio/mpeg"))
|
||||
.body(resource);
|
||||
.contentType(MediaType.parseMediaType("audio/mpeg"))
|
||||
.body(resource);
|
||||
} else {
|
||||
// 如果ResponseBody为空,返回404状态码
|
||||
return ResponseEntity.notFound().build();
|
||||
|
||||
Reference in New Issue
Block a user