mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-04-05 15:57:32 +00:00
feat: mcp 测试版本
This commit is contained in:
@@ -18,7 +18,7 @@ import org.ruoyi.common.core.utils.SpringUtils;
|
||||
import org.ruoyi.common.core.utils.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
@@ -34,13 +34,13 @@ import java.util.Objects;
|
||||
public class SSEEventSourceListener extends EventSourceListener {
|
||||
|
||||
@Autowired(required = false)
|
||||
public SSEEventSourceListener(ResponseBodyEmitter emitter) {
|
||||
public SSEEventSourceListener(SseEmitter emitter) {
|
||||
this.emitter = emitter;
|
||||
}
|
||||
|
||||
private ResponseBodyEmitter emitter;
|
||||
private SseEmitter emitter;
|
||||
|
||||
private StringBuilder stringBuffer;
|
||||
private StringBuilder stringBuffer = new StringBuilder();
|
||||
|
||||
private String modelName;
|
||||
|
||||
@@ -61,7 +61,6 @@ public class SSEEventSourceListener extends EventSourceListener {
|
||||
@Override
|
||||
public void onEvent(@NotNull EventSource eventSource, String id, String type, String data) {
|
||||
try {
|
||||
|
||||
if ("[DONE]".equals(data)) {
|
||||
//成功响应
|
||||
emitter.complete();
|
||||
@@ -72,25 +71,23 @@ public class SSEEventSourceListener extends EventSourceListener {
|
||||
chatCostService.deductToken(chatRequest);
|
||||
return;
|
||||
}
|
||||
// 解析返回内容
|
||||
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
ChatCompletionResponse completionResponse = mapper.readValue(data, ChatCompletionResponse.class);
|
||||
if(completionResponse == null || CollectionUtil.isEmpty(completionResponse.getChoices())){
|
||||
return;
|
||||
}
|
||||
Object content = completionResponse.getChoices().get(0).getDelta().getContent();
|
||||
if(content == null){
|
||||
content = completionResponse.getChoices().get(0).getDelta().getReasoningContent();
|
||||
if(content == null) return;
|
||||
|
||||
if(content != null ){
|
||||
if(StringUtils.isEmpty(modelName)){
|
||||
modelName = completionResponse.getModel();
|
||||
}
|
||||
stringBuffer.append(content);
|
||||
emitter.send(content);
|
||||
}
|
||||
if(StringUtils.isEmpty(modelName)){
|
||||
modelName = completionResponse.getModel();
|
||||
}
|
||||
stringBuffer.append(content);
|
||||
emitter.send(data);
|
||||
} catch (Exception e) {
|
||||
log.error("sse信息推送失败{}内容:{}",e.getMessage(),data);
|
||||
eventSource.cancel();
|
||||
emitter.completeWithError(e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -40,4 +40,9 @@ public interface IChatCostService {
|
||||
* 判断用户是否付费
|
||||
*/
|
||||
void checkUserGrade();
|
||||
|
||||
/**
|
||||
* 获取登录用户id
|
||||
*/
|
||||
Long getUserId();
|
||||
}
|
||||
|
||||
@@ -22,5 +22,5 @@ public interface IChatService {
|
||||
* 客户端发送消息到服务端
|
||||
* @param chatRequest 请求对象
|
||||
*/
|
||||
SseEmitter mcpChat(ChatRequest chatRequest,SseEmitter emitter);
|
||||
void mcpChat(ChatRequest chatRequest,SseEmitter emitter);
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.chat.enums.BillingType;
|
||||
import org.ruoyi.chat.enums.UserGradeType;
|
||||
import org.ruoyi.chat.service.chat.IChatCostService;
|
||||
import org.ruoyi.common.chat.config.LocalCache;
|
||||
import org.ruoyi.common.chat.request.ChatRequest;
|
||||
import org.ruoyi.common.chat.utils.TikTokensUtil;
|
||||
import org.ruoyi.common.core.domain.model.LoginUser;
|
||||
@@ -96,6 +97,12 @@ public class ChatCostServiceImpl implements IChatCostService {
|
||||
chatToken.setUserId(chatMessageBo.getUserId());
|
||||
chatTokenService.editToken(chatToken);
|
||||
}
|
||||
Object userId = LocalCache.CACHE.get("userId");
|
||||
if(userId!=null){
|
||||
chatMessageBo.setUserId((Long) userId);
|
||||
}else {
|
||||
chatMessageBo.setUserId(getUserId());
|
||||
}
|
||||
// 保存消息记录
|
||||
chatMessageService.insertByBo(chatMessageBo);
|
||||
}
|
||||
|
||||
@@ -7,24 +7,13 @@ import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
|
||||
import io.github.ollama4j.models.chat.OllamaChatRequestModel;
|
||||
import io.github.ollama4j.models.generate.OllamaStreamHandler;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.chat.service.chat.IChatService;
|
||||
import org.ruoyi.chat.util.SSEUtil;
|
||||
import org.ruoyi.common.chat.entity.chat.Message;
|
||||
import org.ruoyi.common.chat.request.ChatRequest;
|
||||
import org.ruoyi.domain.vo.ChatModelVo;
|
||||
import org.ruoyi.service.IChatModelService;
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
|
||||
import org.springframework.ai.chat.memory.ChatMemory;
|
||||
import org.springframework.ai.chat.memory.InMemoryChatMemory;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.ollama.api.OllamaModel;
|
||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
import org.springframework.ai.tool.ToolCallbackProvider;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import java.io.IOException;
|
||||
@@ -35,27 +24,11 @@ import java.util.concurrent.CompletableFuture;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class OllamaServiceImpl implements IChatService {
|
||||
public class OllamaServiceImpl {
|
||||
|
||||
@Autowired
|
||||
private IChatModelService chatModelService;
|
||||
@Autowired
|
||||
private IChatModelService chatModelService;
|
||||
|
||||
private final ChatClient chatClient;
|
||||
|
||||
private final ChatMemory chatMemory = new InMemoryChatMemory();
|
||||
|
||||
public OllamaServiceImpl(ChatClient.Builder chatClientBuilder,ToolCallbackProvider tools) {
|
||||
this.chatClient = chatClientBuilder
|
||||
.defaultTools(tools)
|
||||
.defaultOptions(
|
||||
OllamaOptions.builder()
|
||||
.model(OllamaModel.QWEN_2_5_7B)
|
||||
.temperature(0.4)
|
||||
.build())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public SseEmitter chat(ChatRequest chatRequest,SseEmitter emitter) {
|
||||
ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
|
||||
String host = chatModelVo.getApiHost();
|
||||
@@ -100,44 +73,4 @@ public class OllamaServiceImpl implements IChatService {
|
||||
return emitter;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SseEmitter mcpChat(ChatRequest chatRequest, SseEmitter emitter) {
|
||||
List<Message> msgList = chatRequest.getMessages();
|
||||
// 添加记忆
|
||||
for (int i = 0; i < msgList.size(); i++) {
|
||||
org.springframework.ai.chat.messages.Message springAiMessage = new UserMessage(msgList.get(i).getContent().toString());
|
||||
chatMemory.add(String.valueOf(i),springAiMessage);
|
||||
}
|
||||
var messageChatMemoryAdvisor = new MessageChatMemoryAdvisor(chatMemory, chatRequest.getUserId(), 10);
|
||||
|
||||
this.chatClient.prompt(chatRequest.getPrompt())
|
||||
.advisors(messageChatMemoryAdvisor)
|
||||
.stream()
|
||||
.chatResponse()
|
||||
.subscribe(
|
||||
chatResponse -> {
|
||||
try {
|
||||
emitter.send(chatResponse, MediaType.APPLICATION_JSON);
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
},
|
||||
error -> {
|
||||
try {
|
||||
emitter.completeWithError(error);
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
},
|
||||
() -> {
|
||||
try {
|
||||
emitter.complete();
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
return emitter;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
package org.ruoyi.chat.service.chat.impl;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.chat.service.chat.IChatService;
|
||||
import org.ruoyi.common.chat.entity.chat.Message;
|
||||
import org.ruoyi.common.chat.request.ChatRequest;
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
|
||||
import org.springframework.ai.chat.memory.ChatMemory;
|
||||
import org.springframework.ai.chat.memory.InMemoryChatMemory;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.tool.ToolCallbackProvider;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.scheduler.Schedulers;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class OpenAIServiceImpl implements IChatService {
|
||||
|
||||
private final ChatClient chatClient;
|
||||
|
||||
private final ChatMemory chatMemory = new InMemoryChatMemory();
|
||||
|
||||
|
||||
public OpenAIServiceImpl(ChatClient.Builder chatClientBuilder, ToolCallbackProvider tools) {
|
||||
this.chatClient = chatClientBuilder
|
||||
.defaultTools(tools)
|
||||
.defaultOptions(
|
||||
OpenAiChatOptions.builder()
|
||||
.model("gpt-4o-mini")
|
||||
.temperature(0.4)
|
||||
.build())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public SseEmitter chat(ChatRequest chatRequest,SseEmitter emitter) {
|
||||
return emitter;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void mcpChat(ChatRequest chatRequest, SseEmitter emitter) {
|
||||
List<Message> msgList = chatRequest.getMessages();
|
||||
// 添加记忆
|
||||
for (int i = 0; i < msgList.size(); i++) {
|
||||
org.springframework.ai.chat.messages.Message springAiMessage = new UserMessage(msgList.get(i).getContent().toString());
|
||||
chatMemory.add(String.valueOf(i), springAiMessage);
|
||||
}
|
||||
var messageChatMemoryAdvisor = new MessageChatMemoryAdvisor(chatMemory, chatRequest.getUserId().toString(), 10);
|
||||
|
||||
Flux<String> content = chatClient
|
||||
.prompt(chatRequest.getPrompt())
|
||||
.advisors(messageChatMemoryAdvisor)
|
||||
.stream().content();
|
||||
|
||||
content.publishOn(Schedulers.boundedElastic())
|
||||
.doOnNext(text -> {
|
||||
try {
|
||||
emitter.send(text);
|
||||
} catch (IOException e) {
|
||||
emitter.completeWithError(e);
|
||||
}
|
||||
})
|
||||
.doOnError(error -> {
|
||||
log.error("Error in SSE stream: ", error);
|
||||
emitter.completeWithError(error);
|
||||
})
|
||||
.doOnComplete(emitter::complete)
|
||||
.subscribe();
|
||||
}
|
||||
}
|
||||
@@ -11,11 +11,13 @@ import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.*;
|
||||
|
||||
import org.ruoyi.chat.config.ChatConfig;
|
||||
import org.ruoyi.chat.listener.SSEEventSourceListener;
|
||||
import org.ruoyi.chat.service.chat.IChatCostService;
|
||||
import org.ruoyi.chat.service.chat.IChatService;
|
||||
import org.ruoyi.chat.service.chat.ISseService;
|
||||
import org.ruoyi.chat.util.IpUtil;
|
||||
import org.ruoyi.chat.util.SSEUtil;
|
||||
import org.ruoyi.common.chat.config.LocalCache;
|
||||
import org.ruoyi.common.chat.request.ChatRequest;
|
||||
import org.ruoyi.common.chat.entity.Tts.TextToSpeech;
|
||||
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
|
||||
@@ -33,7 +35,9 @@ import org.ruoyi.common.core.utils.file.MimeTypeUtils;
|
||||
|
||||
import org.ruoyi.common.redis.utils.RedisUtils;
|
||||
|
||||
import org.ruoyi.domain.vo.ChatModelVo;
|
||||
import org.ruoyi.service.EmbeddingService;
|
||||
import org.ruoyi.service.IChatModelService;
|
||||
import org.ruoyi.service.VectorStoreService;
|
||||
import org.springframework.core.io.InputStreamResource;
|
||||
import org.springframework.core.io.Resource;
|
||||
@@ -74,27 +78,35 @@ public class SseServiceImpl implements ISseService {
|
||||
|
||||
private final IChatService chatService;
|
||||
|
||||
private final IChatModelService chatModelService;
|
||||
|
||||
private static final String requestIdTemplate = "company-%d";
|
||||
|
||||
private static final ObjectMapper mapper = new ObjectMapper();
|
||||
|
||||
private final ChatConfig chatConfig;
|
||||
|
||||
@Override
|
||||
public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) {
|
||||
SseEmitter sseEmitter = new SseEmitter(0L);
|
||||
SseEmitter sseEmitter = new SseEmitter();
|
||||
try {
|
||||
// 构建消息列表增加联网、知识库等内容
|
||||
buildChatMessageList(chatRequest);
|
||||
if (!StpUtil.isLogin()) {
|
||||
// 未登录用户限制对话次数
|
||||
checkUnauthenticatedUserChatLimit(request);
|
||||
}else {
|
||||
LocalCache.CACHE.put("userId", chatCostService.getUserId());
|
||||
|
||||
chatRequest.setUserId(chatCostService.getUserId());
|
||||
// 保存消息记录 并扣除费用
|
||||
// chatCostService.deductToken(chatRequest);
|
||||
}
|
||||
// 根据模型名称前缀调用不同的处理逻辑
|
||||
switchModelAndHandle(chatRequest,sseEmitter);
|
||||
// 未登录用户限制对话次数
|
||||
checkUnauthenticatedUserChatLimit(request);
|
||||
// 保存消息记录 并扣除费用
|
||||
chatCostService.deductToken(chatRequest);
|
||||
} catch (Exception e) {
|
||||
String message = e.getMessage();
|
||||
SSEUtil.sendErrorEvent(sseEmitter, message);
|
||||
return sseEmitter;
|
||||
log.error(e.getMessage(),e);
|
||||
sseEmitter.completeWithError(e);
|
||||
}
|
||||
return sseEmitter;
|
||||
}
|
||||
@@ -106,8 +118,7 @@ public class SseServiceImpl implements ISseService {
|
||||
* @throws ServiceException 如果当日免费次数已用完
|
||||
*/
|
||||
public void checkUnauthenticatedUserChatLimit(HttpServletRequest request) throws ServiceException {
|
||||
// 未登录用户限制对话次数
|
||||
if (!StpUtil.isLogin()) {
|
||||
|
||||
String clientIp = IpUtil.getClientIp(request);
|
||||
// 访客每天默认只能对话5次
|
||||
int timeWindowInSeconds = 5;
|
||||
@@ -125,13 +136,14 @@ public class SseServiceImpl implements ISseService {
|
||||
count++;
|
||||
RedisUtils.setCacheObject(redisKey, count);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据模型名称前缀调用不同的处理逻辑
|
||||
*/
|
||||
private void switchModelAndHandle(ChatRequest chatRequest,SseEmitter emitter) {
|
||||
SSEEventSourceListener openAIEventSourceListener = new SSEEventSourceListener(emitter);
|
||||
String model = chatRequest.getModel();
|
||||
// 如果模型名称以ollama开头,则调用ollama中部署的本地模型
|
||||
if (model.startsWith("ollama-")) {
|
||||
@@ -142,8 +154,24 @@ public class SseServiceImpl implements ISseService {
|
||||
} else {
|
||||
throw new IllegalArgumentException("Invalid ollama model name: " + chatRequest.getModel());
|
||||
}
|
||||
} else if (model.startsWith("gpt-4-gizmo")) {
|
||||
chatRequest.setModel("gpt-4-gizmo");
|
||||
} else {
|
||||
|
||||
if (model.startsWith("gpt-4-gizmo")) {
|
||||
chatRequest.setModel("gpt-4-gizmo");
|
||||
}
|
||||
ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
|
||||
//openAiStreamClient = chatConfig.createOpenAiStreamClient(chatModelVo.getApiHost(), chatModelVo.getApiKey());
|
||||
|
||||
ChatCompletion completion = ChatCompletion
|
||||
.builder()
|
||||
.messages(chatRequest.getMessages())
|
||||
.model(chatRequest.getModel())
|
||||
.temperature(0.2)
|
||||
.topP(1.0)
|
||||
.stream(true)
|
||||
.build();
|
||||
openAiStreamClient.streamChatCompletion(completion, openAIEventSourceListener);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,9 +179,10 @@ public class SseServiceImpl implements ISseService {
|
||||
* 构建消息列表
|
||||
*/
|
||||
private void buildChatMessageList(ChatRequest chatRequest){
|
||||
ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
|
||||
// 获取对话消息列表
|
||||
List<Message> messages = chatRequest.getMessages();
|
||||
String sysPrompt = chatRequest.getSysPrompt();
|
||||
String sysPrompt = chatModelVo.getSystemPrompt();
|
||||
if(StringUtils.isEmpty(sysPrompt)){
|
||||
sysPrompt ="你是一个由RuoYI-AI开发的人工智能助手,名字叫熊猫助手。你擅长中英文对话,能够理解并处理各种问题,提供安全、有帮助、准确的回答。" +
|
||||
"当前时间:"+ DateUtils.getDate();
|
||||
@@ -162,8 +191,9 @@ public class SseServiceImpl implements ISseService {
|
||||
Message sysMessage = Message.builder().content(sysPrompt).role(Message.Role.SYSTEM).build();
|
||||
messages.add(0,sysMessage);
|
||||
|
||||
chatRequest.setSysPrompt(sysPrompt);
|
||||
// 查询向量库相关信息加入到上下文
|
||||
if(chatRequest.getKid()!=null){
|
||||
if(StringUtils.isNotEmpty(chatRequest.getKid())){
|
||||
List<Message> knMessages = new ArrayList<>();
|
||||
String content = messages.get(messages.size() - 1).getContent().toString();
|
||||
List<String> nearestList;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package org.ruoyi.chat.util;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import java.io.IOException;
|
||||
@@ -19,7 +20,7 @@ public class SSEUtil {
|
||||
* @param sseEmitter sse事件对象
|
||||
* @param errorMessage 错误信息
|
||||
*/
|
||||
public static void sendErrorEvent(SseEmitter sseEmitter, String errorMessage) {
|
||||
public static void sendErrorEvent(ResponseBodyEmitter sseEmitter, String errorMessage) {
|
||||
SseEmitter.SseEventBuilder event = SseEmitter.event()
|
||||
.name("error")
|
||||
.data(errorMessage);
|
||||
|
||||
Reference in New Issue
Block a user