feat: mcp 测试版本

This commit is contained in:
ageerle
2025-04-15 17:34:21 +08:00
parent a4314dbbde
commit f24ff5bbdd
2 changed files with 39 additions and 113 deletions

View File

@@ -0,0 +1,26 @@
package org.ruoyi.chat.enums;
import lombok.Getter;
@Getter
public enum ChatModeType {
OLLAMA("ollama", "本地部署模型"), // token扣费
CHAT("chat", "中转模型"), // 次数扣费
VECTOR("vector", "知识库向量模型"); // 次数扣费
private final String code;
private final String description;
ChatModeType(String code, String description) {
this.code = code;
this.description = description;
}
public String getCode() {
return code;
}
public String getDescription() {
return description;
}
}

View File

@@ -2,39 +2,29 @@ package org.ruoyi.chat.service.chat.impl;
import cn.dev33.satoken.stp.StpUtil;
import cn.hutool.core.collection.CollectionUtil;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.protobuf.ServiceException;
import com.zhipu.oapi.ClientV4;
import com.zhipu.oapi.service.v4.tools.*;
import jakarta.servlet.http.HttpServletRequest;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.ruoyi.chat.config.ChatConfig;
import org.ruoyi.chat.listener.SSEEventSourceListener;
import okhttp3.ResponseBody;
import org.ruoyi.chat.enums.ChatModeType;
import org.ruoyi.chat.service.chat.IChatCostService;
import org.ruoyi.chat.service.chat.IChatService;
import org.ruoyi.chat.service.chat.ISseService;
import org.ruoyi.chat.util.IpUtil;
import org.ruoyi.common.chat.config.LocalCache;
import org.ruoyi.common.chat.request.ChatRequest;
import org.ruoyi.common.chat.entity.Tts.TextToSpeech;
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
import org.ruoyi.common.chat.entity.chat.Message;
import org.ruoyi.common.chat.entity.files.UploadFileResponse;
import org.ruoyi.common.chat.entity.whisper.WhisperResponse;
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
import org.ruoyi.common.core.service.ConfigService;
import org.ruoyi.common.chat.request.ChatRequest;
import org.ruoyi.common.core.utils.DateUtils;
import org.ruoyi.common.core.utils.StringUtils;
import org.ruoyi.common.core.utils.file.FileUtils;
import org.ruoyi.common.core.utils.file.MimeTypeUtils;
import org.ruoyi.common.redis.utils.RedisUtils;
import org.ruoyi.domain.vo.ChatModelVo;
import org.ruoyi.service.EmbeddingService;
import org.ruoyi.service.IChatModelService;
@@ -55,11 +45,7 @@ import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
@Service
@Slf4j
@@ -72,19 +58,16 @@ public class SseServiceImpl implements ISseService {
private final VectorStoreService vectorStore;
private final ConfigService configService;
private final IChatCostService chatCostService;
private final IChatService chatService;
private final IChatModelService chatModelService;
private static final String requestIdTemplate = "company-%d";
private final OpenAIServiceImpl openAIService;
private static final ObjectMapper mapper = new ObjectMapper();
private final OllamaServiceImpl ollamaService;
private ChatModelVo chatModelVo;
private final ChatConfig chatConfig;
@Override
public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) {
@@ -100,7 +83,7 @@ public class SseServiceImpl implements ISseService {
chatRequest.setUserId(chatCostService.getUserId());
// 保存消息记录 并扣除费用
// chatCostService.deductToken(chatRequest);
chatCostService.deductToken(chatRequest);
}
// 根据模型名称前缀调用不同的处理逻辑
switchModelAndHandle(chatRequest,sseEmitter);
@@ -143,35 +126,11 @@ public class SseServiceImpl implements ISseService {
* 根据模型名称前缀调用不同的处理逻辑
*/
private void switchModelAndHandle(ChatRequest chatRequest,SseEmitter emitter) {
SSEEventSourceListener openAIEventSourceListener = new SSEEventSourceListener(emitter);
String model = chatRequest.getModel();
// 如果模型名称以ollama开头则调用ollama中部署的本地模型
if (model.startsWith("ollama-")) {
String[] parts = chatRequest.getModel().split("ollama-", 2);
if (parts.length > 1) {
chatRequest.setModel(parts[1]);
chatService.mcpChat(chatRequest,emitter);
} else {
throw new IllegalArgumentException("Invalid ollama model name: " + chatRequest.getModel());
}
// 调用ollama中部署的本地模型
if (ChatModeType.OLLAMA.getCode().equals(chatModelVo.getCategory())) {
ollamaService.chat(chatRequest,emitter);
} else {
if (model.startsWith("gpt-4-gizmo")) {
chatRequest.setModel("gpt-4-gizmo");
}
ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
//openAiStreamClient = chatConfig.createOpenAiStreamClient(chatModelVo.getApiHost(), chatModelVo.getApiKey());
ChatCompletion completion = ChatCompletion
.builder()
.messages(chatRequest.getMessages())
.model(chatRequest.getModel())
.temperature(0.2)
.topP(1.0)
.stream(true)
.build();
openAiStreamClient.streamChatCompletion(completion, openAIEventSourceListener);
openAIService.chat(chatRequest,emitter);
}
}
@@ -179,7 +138,7 @@ public class SseServiceImpl implements ISseService {
* 构建消息列表
*/
private void buildChatMessageList(ChatRequest chatRequest){
ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
// 获取对话消息列表
List<Message> messages = chatRequest.getMessages();
String sysPrompt = chatModelVo.getSystemPrompt();
@@ -220,11 +179,6 @@ public class SseServiceImpl implements ISseService {
}
// 设置对话信息
chatRequest.setPrompt(chatString);
// 加载联网信息
if(chatRequest.getSearch()){
Message message = Message.builder().role(Message.Role.ASSISTANT).content("联网信息:"+webSearch(chatString)).build();
messages.add(message);
}
}
@@ -333,58 +287,4 @@ public class SseServiceImpl implements ISseService {
return chatCompletionResponse.getChoices().get(0).getMessage().getContent().toString();
}
@Override
public String webSearch (String prompt) {
String zpValue = configService.getConfigValue("zhipu", "key");
if(StringUtils.isEmpty(zpValue)){
throw new IllegalStateException("请在chat_config中配置智谱key信息");
}else {
ClientV4 client = new ClientV4.Builder(zpValue)
.networkConfig(300, 100, 100, 100, TimeUnit.SECONDS)
.connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS))
.build();
SearchChatMessage jsonNodes = new SearchChatMessage();
jsonNodes.setRole(Message.Role.USER.getName());
jsonNodes.setContent(prompt);
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
WebSearchParamsRequest chatCompletionRequest = WebSearchParamsRequest.builder()
.model("web-search-pro")
.stream(Boolean.TRUE)
.messages(Collections.singletonList(jsonNodes))
.requestId(requestId)
.build();
WebSearchApiResponse webSearchApiResponse = client.webSearchProStreamingInvoke(chatCompletionRequest);
List<ChoiceDelta> choices = new ArrayList<>();
if (webSearchApiResponse.isSuccess()) {
AtomicBoolean isFirst = new AtomicBoolean(true);
AtomicReference<WebSearchPro> lastAccumulator = new AtomicReference<>();
webSearchApiResponse.getFlowable().map(result -> result)
.doOnNext(accumulator -> {
{
if (isFirst.getAndSet(false)) {
log.info("Response: ");
}
ChoiceDelta delta = accumulator.getChoices().get(0).getDelta();
if (delta != null && delta.getToolCalls() != null) {
log.info("tool_calls: {}", mapper.writeValueAsString(delta.getToolCalls()));
}
choices.add(delta);
}
})
.doOnComplete(() -> System.out.println("Stream completed."))
.doOnError(throwable -> System.err.println("Error: " + throwable))
.blockingSubscribe();
WebSearchPro chatMessageAccumulator = lastAccumulator.get();
webSearchApiResponse.setFlowable(null);
webSearchApiResponse.setData(chatMessageAccumulator);
}
return choices.get(1).getToolCalls().toString();
}
}
}