mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-04-16 21:33:40 +00:00
fix: ollama兼容联网查询 知识库检索
This commit is contained in:
@@ -466,8 +466,8 @@ public class OpenAiStreamClient {
|
|||||||
* @since 1.1.3
|
* @since 1.1.3
|
||||||
*/
|
*/
|
||||||
public ResponseBody textToSpeech(TextToSpeech textToSpeech){
|
public ResponseBody textToSpeech(TextToSpeech textToSpeech){
|
||||||
Call<ResponseBody> responseBody = this.openAiApi.textToSpeech(textToSpeech);
|
|
||||||
try {
|
try {
|
||||||
|
Call<ResponseBody> responseBody = this.openAiApi.textToSpeech(textToSpeech);
|
||||||
return responseBody.execute().body();
|
return responseBody.execute().body();
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new BaseException("文本转语音(同步)失败: "+e.getMessage());
|
throw new BaseException("文本转语音(同步)失败: "+e.getMessage());
|
||||||
|
|||||||
@@ -26,6 +26,11 @@ public class ChatRequest {
|
|||||||
*/
|
*/
|
||||||
private String prompt;
|
private String prompt;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 系统提示词
|
||||||
|
*/
|
||||||
|
private String sysPrompt;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 是否开启流式对话
|
* 是否开启流式对话
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -84,77 +84,43 @@ public class SseServiceImpl implements ISseService {
|
|||||||
|
|
||||||
private final IChatCostService chatCostService;
|
private final IChatCostService chatCostService;
|
||||||
|
|
||||||
private static final String requestIdTemplate = "mycompany-%d";
|
private static final String requestIdTemplate = "company-%d";
|
||||||
|
|
||||||
private static final ObjectMapper mapper = new ObjectMapper();
|
private static final ObjectMapper mapper = new ObjectMapper();
|
||||||
|
|
||||||
private OpenAiStreamClient openAiModelStreamClient;
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) {
|
public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) {
|
||||||
SseEmitter sseEmitter = new SseEmitter(0L);
|
SseEmitter sseEmitter = new SseEmitter(0L);
|
||||||
SSEEventSourceListener openAIEventSourceListener = new SSEEventSourceListener(sseEmitter);
|
SSEEventSourceListener openAIEventSourceListener = new SSEEventSourceListener(sseEmitter);
|
||||||
// 获取对话消息列表
|
// 获取对话消息列表
|
||||||
List<Message> messages = chatRequest.getMessages();
|
List<Message> messages = chatRequest.getMessages();
|
||||||
// 用户对话内容
|
|
||||||
String chatString = null;
|
|
||||||
try {
|
try {
|
||||||
if (StpUtil.isLogin()) {
|
// 查询模型信息
|
||||||
// 通过模型名称查询模型信息
|
|
||||||
ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
|
ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
|
||||||
|
|
||||||
|
OpenAiStreamClient openAiModelStreamClient;
|
||||||
if(chatModelVo!=null){
|
if(chatModelVo!=null){
|
||||||
// 通过模型信息构建请求客户端
|
// 建请求客户端
|
||||||
openAiModelStreamClient = chatConfig.createOpenAiStreamClient(chatModelVo.getApiHost(), chatModelVo.getApiKey());
|
openAiModelStreamClient = chatConfig.createOpenAiStreamClient(chatModelVo.getApiHost(), chatModelVo.getApiKey());
|
||||||
|
// 设置默认提示词
|
||||||
|
chatRequest.setSysPrompt(chatModelVo.getSystemPrompt());
|
||||||
}else {
|
}else {
|
||||||
// 使用默认客户端
|
// 使用默认客户端
|
||||||
openAiModelStreamClient = openAiStreamClient;
|
openAiModelStreamClient = openAiStreamClient;
|
||||||
}
|
}
|
||||||
// 设置默认提示词
|
// 构建消息列表增加联网、知识库等内容
|
||||||
Message sysMessage = Message.builder().content(chatModelVo.getSystemPrompt()).role(Message.Role.SYSTEM).build();
|
buildChatMessageList(chatRequest);
|
||||||
messages.add(0,sysMessage);
|
|
||||||
|
|
||||||
// 查询向量库相关信息加入到上下文
|
// 根据模型名称前缀调用不同的处理逻辑
|
||||||
if(chatRequest.getKid()!=null){
|
switchModelAndHandle(chatRequest);
|
||||||
List<Message> knMessages = new ArrayList<>();
|
|
||||||
String content = messages.get(messages.size() - 1).getContent().toString();
|
|
||||||
List<String> nearestList;
|
|
||||||
List<Double> queryVector = embeddingService.getQueryVector(content, chatRequest.getKid());
|
|
||||||
nearestList = vectorStore.nearest(queryVector, chatRequest.getKid());
|
|
||||||
for (String prompt : nearestList) {
|
|
||||||
Message userMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
|
|
||||||
knMessages.add(userMessage);
|
|
||||||
}
|
|
||||||
Message userMessage = Message.builder().content(content + (!nearestList.isEmpty() ? "\n\n注意:回答问题时,须严格根据我给你的系统上下文内容原文进行回答,请不要自己发挥,回答时保持原来文本的段落层级" : "")).role(Message.Role.USER).build();
|
|
||||||
knMessages.add(userMessage);
|
|
||||||
messages.addAll(knMessages);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取用户对话信息
|
|
||||||
Object content = messages.get(messages.size() - 1).getContent();
|
|
||||||
if (content instanceof List<?> listContent) {
|
|
||||||
if (CollectionUtil.isNotEmpty(listContent)) {
|
|
||||||
chatString = listContent.get(0).toString();
|
|
||||||
}
|
|
||||||
} else if (content instanceof String) {
|
|
||||||
chatString = (String) content;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 加载联网信息
|
|
||||||
if(chatRequest.getSearch()){
|
|
||||||
Message message = Message.builder().role(Message.Role.ASSISTANT).content("联网信息:"+webSearch(chatString)).build();
|
|
||||||
messages.add(message);
|
|
||||||
}
|
|
||||||
}else {
|
|
||||||
// 未登录用户限制对话次数
|
// 未登录用户限制对话次数
|
||||||
|
if (!StpUtil.isLogin()) {
|
||||||
String clientIp = IpUtil.getClientIp(request);
|
String clientIp = IpUtil.getClientIp(request);
|
||||||
|
|
||||||
// 访客每天默认只能对话5次
|
// 访客每天默认只能对话5次
|
||||||
int timeWindowInSeconds = 5;
|
int timeWindowInSeconds = 5;
|
||||||
|
|
||||||
String redisKey = "clientIp:" + clientIp;
|
String redisKey = "clientIp:" + clientIp;
|
||||||
|
|
||||||
int count = 0;
|
int count = 0;
|
||||||
|
|
||||||
if (RedisUtils.getCacheObject(redisKey) == null) {
|
if (RedisUtils.getCacheObject(redisKey) == null) {
|
||||||
// 缓存有效时间1天
|
// 缓存有效时间1天
|
||||||
RedisUtils.setCacheObject(redisKey, count, Duration.ofSeconds(86400));
|
RedisUtils.setCacheObject(redisKey, count, Duration.ofSeconds(86400));
|
||||||
@@ -175,6 +141,7 @@ public class SseServiceImpl implements ISseService {
|
|||||||
.stream(chatRequest.getStream())
|
.stream(chatRequest.getStream())
|
||||||
.build();
|
.build();
|
||||||
openAiModelStreamClient.streamChatCompletion(completion, openAIEventSourceListener);
|
openAiModelStreamClient.streamChatCompletion(completion, openAIEventSourceListener);
|
||||||
|
|
||||||
// 保存消息记录 并扣除费用
|
// 保存消息记录 并扣除费用
|
||||||
chatCostService.deductToken(chatRequest);
|
chatCostService.deductToken(chatRequest);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
@@ -185,6 +152,69 @@ public class SseServiceImpl implements ISseService {
|
|||||||
return sseEmitter;
|
return sseEmitter;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 根据模型名称前缀调用不同的处理逻辑
|
||||||
|
*/
|
||||||
|
private void switchModelAndHandle(ChatRequest chatRequest) {
|
||||||
|
String model = chatRequest.getModel();
|
||||||
|
// 如果模型名称以ollama开头,则调用ollama中部署的本地模型
|
||||||
|
if (model.startsWith("ollama-")) {
|
||||||
|
String[] parts = chatRequest.getModel().split("ollama-", 2); // 限制分割次数为2
|
||||||
|
if (parts.length > 1) {
|
||||||
|
chatRequest.setModel(parts[1]);
|
||||||
|
ollamaChat(chatRequest);
|
||||||
|
} else {
|
||||||
|
throw new IllegalArgumentException("Invalid ollama model name: " + chatRequest.getModel());
|
||||||
|
}
|
||||||
|
} else if (model.startsWith("gpt-4-gizmo")) {
|
||||||
|
chatRequest.setModel("gpt-4-gizmo");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 构建消息列表
|
||||||
|
*/
|
||||||
|
private void buildChatMessageList(ChatRequest chatRequest){
|
||||||
|
// 获取对话消息列表
|
||||||
|
List<Message> messages = chatRequest.getMessages();
|
||||||
|
// 设置系统默认提示词
|
||||||
|
Message sysMessage = Message.builder().content(chatRequest.getSysPrompt()).role(Message.Role.SYSTEM).build();
|
||||||
|
messages.add(0,sysMessage);
|
||||||
|
|
||||||
|
// 查询向量库相关信息加入到上下文
|
||||||
|
if(chatRequest.getKid()!=null){
|
||||||
|
List<Message> knMessages = new ArrayList<>();
|
||||||
|
String content = messages.get(messages.size() - 1).getContent().toString();
|
||||||
|
List<String> nearestList;
|
||||||
|
List<Double> queryVector = embeddingService.getQueryVector(content, chatRequest.getKid());
|
||||||
|
nearestList = vectorStore.nearest(queryVector, chatRequest.getKid());
|
||||||
|
for (String prompt : nearestList) {
|
||||||
|
Message userMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
|
||||||
|
knMessages.add(userMessage);
|
||||||
|
}
|
||||||
|
Message userMessage = Message.builder().content(content + (!nearestList.isEmpty() ? "\n\n注意:回答问题时,须严格根据我给你的系统上下文内容原文进行回答,请不要自己发挥,回答时保持原来文本的段落层级" : "")).role(Message.Role.USER).build();
|
||||||
|
knMessages.add(userMessage);
|
||||||
|
messages.addAll(knMessages);
|
||||||
|
}
|
||||||
|
// 用户对话内容
|
||||||
|
String chatString = null;
|
||||||
|
// 获取用户对话信息
|
||||||
|
Object content = messages.get(messages.size() - 1).getContent();
|
||||||
|
if (content instanceof List<?> listContent) {
|
||||||
|
if (CollectionUtil.isNotEmpty(listContent)) {
|
||||||
|
chatString = listContent.get(0).toString();
|
||||||
|
}
|
||||||
|
} else if (content instanceof String) {
|
||||||
|
chatString = (String) content;
|
||||||
|
}
|
||||||
|
// 设置对话信息
|
||||||
|
chatRequest.setPrompt(chatString);
|
||||||
|
// 加载联网信息
|
||||||
|
if(chatRequest.getSearch()){
|
||||||
|
Message message = Message.builder().role(Message.Role.ASSISTANT).content("联网信息:"+webSearch(chatString)).build();
|
||||||
|
messages.add(message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 发送SSE错误事件的封装方法
|
* 发送SSE错误事件的封装方法
|
||||||
@@ -295,13 +325,13 @@ public class SseServiceImpl implements ISseService {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SseEmitter ollamaChat(ChatRequest chatRequest) {
|
public SseEmitter ollamaChat(ChatRequest chatRequest) {
|
||||||
String[] parts = chatRequest.getModel().split("ollama-");
|
|
||||||
ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
|
ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
|
||||||
final SseEmitter emitter = new SseEmitter();
|
final SseEmitter emitter = new SseEmitter();
|
||||||
String host = chatModelVo.getApiHost();
|
String host = chatModelVo.getApiHost();
|
||||||
List<Message> msgList = chatRequest.getMessages();
|
List<Message> msgList = chatRequest.getMessages();
|
||||||
List<OllamaChatMessage> messages = new ArrayList<>();
|
|
||||||
|
|
||||||
|
List<OllamaChatMessage> messages = new ArrayList<>();
|
||||||
for (Message message : msgList) {
|
for (Message message : msgList) {
|
||||||
OllamaChatMessage ollamaChatMessage = new OllamaChatMessage();
|
OllamaChatMessage ollamaChatMessage = new OllamaChatMessage();
|
||||||
ollamaChatMessage.setRole(OllamaChatMessageRole.USER);
|
ollamaChatMessage.setRole(OllamaChatMessageRole.USER);
|
||||||
@@ -310,7 +340,7 @@ public class SseServiceImpl implements ISseService {
|
|||||||
}
|
}
|
||||||
OllamaAPI api = new OllamaAPI(host);
|
OllamaAPI api = new OllamaAPI(host);
|
||||||
api.setRequestTimeoutSeconds(100);
|
api.setRequestTimeoutSeconds(100);
|
||||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(parts[1]);
|
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatRequest.getModel());
|
||||||
|
|
||||||
OllamaChatRequestModel requestModel = builder
|
OllamaChatRequestModel requestModel = builder
|
||||||
.withMessages(messages)
|
.withMessages(messages)
|
||||||
@@ -356,11 +386,11 @@ public class SseServiceImpl implements ISseService {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String webSearch (String prompt) {
|
public String webSearch (String prompt) {
|
||||||
String zhipuValue = configService.getConfigValue("zhipu", "key");
|
String zpValue = configService.getConfigValue("zhipu", "key");
|
||||||
if(StringUtils.isEmpty(zhipuValue)){
|
if(StringUtils.isEmpty(zpValue)){
|
||||||
throw new IllegalStateException("zhipu config value is empty,请在chat_config中配置zhipu key信息");
|
throw new IllegalStateException("请在chat_config中配置智谱key信息");
|
||||||
}else {
|
}else {
|
||||||
ClientV4 client = new ClientV4.Builder(zhipuValue)
|
ClientV4 client = new ClientV4.Builder(zpValue)
|
||||||
.networkConfig(300, 100, 100, 100, TimeUnit.SECONDS)
|
.networkConfig(300, 100, 100, 100, TimeUnit.SECONDS)
|
||||||
.connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS))
|
.connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS))
|
||||||
.build();
|
.build();
|
||||||
|
|||||||
Reference in New Issue
Block a user