fix: ollama兼容联网查询 知识库检索

This commit is contained in:
ageerle
2025-04-11 12:05:13 +08:00
parent af33040117
commit efeb0bd6fb
3 changed files with 97 additions and 62 deletions

View File

@@ -466,8 +466,8 @@ public class OpenAiStreamClient {
* @since 1.1.3
*/
public ResponseBody textToSpeech(TextToSpeech textToSpeech){
Call<ResponseBody> responseBody = this.openAiApi.textToSpeech(textToSpeech);
try {
Call<ResponseBody> responseBody = this.openAiApi.textToSpeech(textToSpeech);
return responseBody.execute().body();
} catch (IOException e) {
throw new BaseException("文本转语音(同步)失败: "+e.getMessage());

View File

@@ -26,6 +26,11 @@ public class ChatRequest {
*/
private String prompt;
/**
* 系统提示词
*/
private String sysPrompt;
/**
* 是否开启流式对话
*/

View File

@@ -84,77 +84,43 @@ public class SseServiceImpl implements ISseService {
private final IChatCostService chatCostService;
private static final String requestIdTemplate = "mycompany-%d";
private static final String requestIdTemplate = "company-%d";
private static final ObjectMapper mapper = new ObjectMapper();
private OpenAiStreamClient openAiModelStreamClient;
@Override
public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) {
SseEmitter sseEmitter = new SseEmitter(0L);
SSEEventSourceListener openAIEventSourceListener = new SSEEventSourceListener(sseEmitter);
// 获取对话消息列表
List<Message> messages = chatRequest.getMessages();
// 用户对话内容
String chatString = null;
try {
if (StpUtil.isLogin()) {
// 通过模型名称查询模型信息
ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
if(chatModelVo!=null){
// 通过模型信息构建请求客户端
openAiModelStreamClient = chatConfig.createOpenAiStreamClient(chatModelVo.getApiHost(), chatModelVo.getApiKey());
}else {
// 使用默认客户端
openAiModelStreamClient = openAiStreamClient;
}
// 查询模型信息
ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
OpenAiStreamClient openAiModelStreamClient;
if(chatModelVo!=null){
// 建请求客户端
openAiModelStreamClient = chatConfig.createOpenAiStreamClient(chatModelVo.getApiHost(), chatModelVo.getApiKey());
// 设置默认提示词
Message sysMessage = Message.builder().content(chatModelVo.getSystemPrompt()).role(Message.Role.SYSTEM).build();
messages.add(0,sysMessage);
// 查询向量库相关信息加入到上下文
if(chatRequest.getKid()!=null){
List<Message> knMessages = new ArrayList<>();
String content = messages.get(messages.size() - 1).getContent().toString();
List<String> nearestList;
List<Double> queryVector = embeddingService.getQueryVector(content, chatRequest.getKid());
nearestList = vectorStore.nearest(queryVector, chatRequest.getKid());
for (String prompt : nearestList) {
Message userMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
knMessages.add(userMessage);
}
Message userMessage = Message.builder().content(content + (!nearestList.isEmpty() ? "\n\n注意回答问题时须严格根据我给你的系统上下文内容原文进行回答请不要自己发挥,回答时保持原来文本的段落层级" : "")).role(Message.Role.USER).build();
knMessages.add(userMessage);
messages.addAll(knMessages);
}
// 获取用户对话信息
Object content = messages.get(messages.size() - 1).getContent();
if (content instanceof List<?> listContent) {
if (CollectionUtil.isNotEmpty(listContent)) {
chatString = listContent.get(0).toString();
}
} else if (content instanceof String) {
chatString = (String) content;
}
// 加载联网信息
if(chatRequest.getSearch()){
Message message = Message.builder().role(Message.Role.ASSISTANT).content("联网信息:"+webSearch(chatString)).build();
messages.add(message);
}
chatRequest.setSysPrompt(chatModelVo.getSystemPrompt());
}else {
// 未登录用户限制对话次数
String clientIp = IpUtil.getClientIp(request);
// 使用默认客户端
openAiModelStreamClient = openAiStreamClient;
}
// 构建消息列表增加联网、知识库等内容
buildChatMessageList(chatRequest);
// 根据模型名称前缀调用不同的处理逻辑
switchModelAndHandle(chatRequest);
// 未登录用户限制对话次数
if (!StpUtil.isLogin()) {
String clientIp = IpUtil.getClientIp(request);
// 访客每天默认只能对话5次
int timeWindowInSeconds = 5;
String redisKey = "clientIp:" + clientIp;
int count = 0;
if (RedisUtils.getCacheObject(redisKey) == null) {
// 缓存有效时间1天
RedisUtils.setCacheObject(redisKey, count, Duration.ofSeconds(86400));
@@ -175,6 +141,7 @@ public class SseServiceImpl implements ISseService {
.stream(chatRequest.getStream())
.build();
openAiModelStreamClient.streamChatCompletion(completion, openAIEventSourceListener);
// 保存消息记录 并扣除费用
chatCostService.deductToken(chatRequest);
} catch (Exception e) {
@@ -185,6 +152,69 @@ public class SseServiceImpl implements ISseService {
return sseEmitter;
}
/**
* 根据模型名称前缀调用不同的处理逻辑
*/
private void switchModelAndHandle(ChatRequest chatRequest) {
String model = chatRequest.getModel();
// 如果模型名称以ollama开头则调用ollama中部署的本地模型
if (model.startsWith("ollama-")) {
String[] parts = chatRequest.getModel().split("ollama-", 2); // 限制分割次数为2
if (parts.length > 1) {
chatRequest.setModel(parts[1]);
ollamaChat(chatRequest);
} else {
throw new IllegalArgumentException("Invalid ollama model name: " + chatRequest.getModel());
}
} else if (model.startsWith("gpt-4-gizmo")) {
chatRequest.setModel("gpt-4-gizmo");
}
}
/**
* 构建消息列表
*/
private void buildChatMessageList(ChatRequest chatRequest){
// 获取对话消息列表
List<Message> messages = chatRequest.getMessages();
// 设置系统默认提示词
Message sysMessage = Message.builder().content(chatRequest.getSysPrompt()).role(Message.Role.SYSTEM).build();
messages.add(0,sysMessage);
// 查询向量库相关信息加入到上下文
if(chatRequest.getKid()!=null){
List<Message> knMessages = new ArrayList<>();
String content = messages.get(messages.size() - 1).getContent().toString();
List<String> nearestList;
List<Double> queryVector = embeddingService.getQueryVector(content, chatRequest.getKid());
nearestList = vectorStore.nearest(queryVector, chatRequest.getKid());
for (String prompt : nearestList) {
Message userMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
knMessages.add(userMessage);
}
Message userMessage = Message.builder().content(content + (!nearestList.isEmpty() ? "\n\n注意回答问题时须严格根据我给你的系统上下文内容原文进行回答请不要自己发挥,回答时保持原来文本的段落层级" : "")).role(Message.Role.USER).build();
knMessages.add(userMessage);
messages.addAll(knMessages);
}
// 用户对话内容
String chatString = null;
// 获取用户对话信息
Object content = messages.get(messages.size() - 1).getContent();
if (content instanceof List<?> listContent) {
if (CollectionUtil.isNotEmpty(listContent)) {
chatString = listContent.get(0).toString();
}
} else if (content instanceof String) {
chatString = (String) content;
}
// 设置对话信息
chatRequest.setPrompt(chatString);
// 加载联网信息
if(chatRequest.getSearch()){
Message message = Message.builder().role(Message.Role.ASSISTANT).content("联网信息:"+webSearch(chatString)).build();
messages.add(message);
}
}
/**
* 发送SSE错误事件的封装方法
@@ -295,13 +325,13 @@ public class SseServiceImpl implements ISseService {
@Override
public SseEmitter ollamaChat(ChatRequest chatRequest) {
String[] parts = chatRequest.getModel().split("ollama-");
ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
final SseEmitter emitter = new SseEmitter();
String host = chatModelVo.getApiHost();
List<Message> msgList = chatRequest.getMessages();
List<OllamaChatMessage> messages = new ArrayList<>();
List<OllamaChatMessage> messages = new ArrayList<>();
for (Message message : msgList) {
OllamaChatMessage ollamaChatMessage = new OllamaChatMessage();
ollamaChatMessage.setRole(OllamaChatMessageRole.USER);
@@ -310,7 +340,7 @@ public class SseServiceImpl implements ISseService {
}
OllamaAPI api = new OllamaAPI(host);
api.setRequestTimeoutSeconds(100);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(parts[1]);
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(chatRequest.getModel());
OllamaChatRequestModel requestModel = builder
.withMessages(messages)
@@ -356,11 +386,11 @@ public class SseServiceImpl implements ISseService {
@Override
public String webSearch (String prompt) {
String zhipuValue = configService.getConfigValue("zhipu", "key");
if(StringUtils.isEmpty(zhipuValue)){
throw new IllegalStateException("zhipu config value is empty,请在chat_config中配置zhipu key信息");
String zpValue = configService.getConfigValue("zhipu", "key");
if(StringUtils.isEmpty(zpValue)){
throw new IllegalStateException("请在chat_config中配置智谱key信息");
}else {
ClientV4 client = new ClientV4.Builder(zhipuValue)
ClientV4 client = new ClientV4.Builder(zpValue)
.networkConfig(300, 100, 100, 100, TimeUnit.SECONDS)
.connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS))
.build();