mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-03-14 05:03:43 +08:00
@@ -53,9 +53,13 @@ public interface IChatModelService {
|
||||
* 通过模型名称获取模型信息
|
||||
*/
|
||||
ChatModelVo selectModelByName(String modelName);
|
||||
|
||||
/**
|
||||
* 通过模型分类获取模型信息
|
||||
*/
|
||||
ChatModelVo selectModelByCategory(String image);
|
||||
/**
|
||||
* 获取ppt模型信息
|
||||
*/
|
||||
ChatModel getPPT();
|
||||
|
||||
}
|
||||
|
||||
@@ -129,6 +129,13 @@ public class ChatModelServiceImpl implements IChatModelService {
|
||||
public ChatModelVo selectModelByName(String modelName) {
|
||||
return baseMapper.selectVoOne(Wrappers.<ChatModel>lambdaQuery().eq(ChatModel::getModelName, modelName));
|
||||
}
|
||||
/**
|
||||
* 通过模型分类获取模型信息
|
||||
*/
|
||||
@Override
|
||||
public ChatModelVo selectModelByCategory(String category) {
|
||||
return baseMapper.selectVoOne(Wrappers.<ChatModel>lambdaQuery().eq(ChatModel::getCategory, category));
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatModel getPPT() {
|
||||
|
||||
@@ -15,7 +15,9 @@ public enum ChatModeType {
|
||||
|
||||
QIANWEN("qianwen", "通义千问"),
|
||||
|
||||
VECTOR("vector", "知识库向量模型");
|
||||
VECTOR("vector", "知识库向量模型"),
|
||||
|
||||
IMAGE("image", "图片识别模型");
|
||||
|
||||
private final String code;
|
||||
private final String description;
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
package org.ruoyi.chat.service.chat.impl;
|
||||
|
||||
import io.modelcontextprotocol.client.McpSyncClient;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.chat.config.ChatConfig;
|
||||
import org.ruoyi.chat.enums.ChatModeType;
|
||||
import org.ruoyi.chat.listener.SSEEventSourceListener;
|
||||
import org.ruoyi.chat.service.chat.IChatService;
|
||||
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
|
||||
import org.ruoyi.common.chat.entity.chat.Message;
|
||||
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
|
||||
import org.ruoyi.common.chat.request.ChatRequest;
|
||||
import org.ruoyi.domain.vo.ChatModelVo;
|
||||
import org.ruoyi.service.IChatModelService;
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 图片识别模型
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
public class ImageOpenAiServiceImpl implements IChatService {
|
||||
|
||||
@Autowired
|
||||
private IChatModelService chatModelService;
|
||||
|
||||
private final ChatClient chatClient;
|
||||
|
||||
public ImageOpenAiServiceImpl(ChatClient.Builder chatClientBuilder) {
|
||||
this.chatClient = chatClientBuilder.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public SseEmitter chat(ChatRequest chatRequest, SseEmitter emitter) {
|
||||
// 从数据库获取 image 类型的模型配置
|
||||
ChatModelVo chatModelVo = chatModelService.selectModelByCategory(ChatModeType.IMAGE.getCode());
|
||||
if (chatModelVo == null) {
|
||||
log.error("未找到 image 类型的模型配置");
|
||||
emitter.completeWithError(new IllegalStateException("未找到 image 类型的模型配置"));
|
||||
return emitter;
|
||||
}
|
||||
|
||||
// 创建 OpenAI 流客户端
|
||||
OpenAiStreamClient openAiStreamClient = ChatConfig.createOpenAiStreamClient(chatModelVo.getApiHost(), chatModelVo.getApiKey());
|
||||
List<Message> messages = chatRequest.getMessages();
|
||||
|
||||
// 创建 SSE 事件源监听器
|
||||
SSEEventSourceListener listener = new SSEEventSourceListener(emitter, chatRequest.getUserId(), chatRequest.getSessionId());
|
||||
|
||||
// 构建聊天完成请求
|
||||
ChatCompletion completion = ChatCompletion
|
||||
.builder()
|
||||
.messages(messages)
|
||||
.model(chatModelVo.getModelName()) // 使用数据库中配置的模型名称
|
||||
.stream(true)
|
||||
.build();
|
||||
|
||||
// 发起流式聊天完成请求
|
||||
openAiStreamClient.streamChatCompletion(completion, listener);
|
||||
|
||||
return emitter;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String getCategory() {
|
||||
return ChatModeType.IMAGE.getCode();
|
||||
}
|
||||
}
|
||||
@@ -125,7 +125,16 @@ public class SseServiceImpl implements ISseService {
|
||||
*/
|
||||
private void buildChatMessageList(ChatRequest chatRequest){
|
||||
String sysPrompt;
|
||||
chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
|
||||
// 矫正模型名称 如果是gpt-image 则查询image类型模型 获取模型名称
|
||||
if(chatRequest.getModel().equals("gpt-image")) {
|
||||
chatModelVo = chatModelService.selectModelByCategory("image");
|
||||
if (chatModelVo == null) {
|
||||
log.error("未找到image类型的模型配置");
|
||||
throw new IllegalStateException("未找到image类型的模型配置");
|
||||
}// chatRequest.setModel(chatModelVo.getModelName());
|
||||
}else{
|
||||
chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
|
||||
}
|
||||
// 获取对话消息列表
|
||||
List<Message> messages = chatRequest.getMessages();
|
||||
// 查询向量库相关信息加入到上下文
|
||||
|
||||
Reference in New Issue
Block a user