feat: 功能优化

This commit is contained in:
lihao05
2025-10-20 10:12:44 +08:00
parent 9f6d363d55
commit e7d7de79fe
11 changed files with 136 additions and 112 deletions

View File

@@ -21,7 +21,8 @@ import java.util.concurrent.TimeUnit;
@Service @Service
public class SSEEmitterHelper { public class SSEEmitterHelper {
private static final Cache<SseEmitter, Boolean> COMPLETED_SSE = CacheBuilder.newBuilder().expireAfterWrite(10, TimeUnit.MINUTES).build(); private static final Cache<SseEmitter, Boolean> COMPLETED_SSE = CacheBuilder.newBuilder()
.expireAfterWrite(10, TimeUnit.MINUTES).build();
@Resource @Resource
private StringRedisTemplate stringRedisTemplate; private StringRedisTemplate stringRedisTemplate;
@@ -41,8 +42,6 @@ public class SSEEmitterHelper {
} else { } else {
sendPartial(sseEmitter, name, " " + content); sendPartial(sseEmitter, name, " " + content);
} }
// content = content.replaceAll("[\\r\\n]", "\ndata:");
// sendPartial(sseEmitter, name, " " + content);
} }
public static void sendPartial(SseEmitter sseEmitter, String name, String msg) { public static void sendPartial(SseEmitter sseEmitter, String name, String msg) {
@@ -63,13 +62,6 @@ public class SSEEmitterHelper {
public boolean checkOrComplete(User user, SseEmitter sseEmitter) { public boolean checkOrComplete(User user, SseEmitter sseEmitter) {
//Check: rate limit
String requestTimesKey = MessageFormat.format(RedisKeyConstant.USER_REQUEST_TEXT_TIMES, user.getId());
// if (!rateLimitHelper.checkRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG)) {
// sendErrorAndComplete(user.getId(), sseEmitter, "访问太过频繁");
// return false;
// }
//Check: If still waiting response //Check: If still waiting response
String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, user.getId()); String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, user.getId());
String askingVal = stringRedisTemplate.opsForValue().get(askingKey); String askingVal = stringRedisTemplate.opsForValue().get(askingKey);
@@ -125,7 +117,10 @@ public class SSEEmitterHelper {
return; return;
} }
try { try {
sseEmitter.send(SseEmitter.event().name(AdiConstant.SSEEventName.ERROR).data(Objects.toString(errorMsg, ""))); SseEmitter.SseEventBuilder event = SseEmitter.event();
event.name(AdiConstant.SSEEventName.ERROR);
event.data(Objects.toString(errorMsg, ""));
sseEmitter.send(event);
} catch (IOException e) { } catch (IOException e) {
log.warn("sendErrorAndComplete userId:{},errorMsg:{}", userId, errorMsg); log.warn("sendErrorAndComplete userId:{},errorMsg:{}", userId, errorMsg);
throw new RuntimeException(e); throw new RuntimeException(e);

View File

@@ -35,10 +35,6 @@ public class JsonUtil {
objectMapper.registerModules(LocalDateTimeUtil.getSimpleModule(), new JavaTimeModule(), new Jdk8Module()); objectMapper.registerModules(LocalDateTimeUtil.getSimpleModule(), new JavaTimeModule(), new Jdk8Module());
} }
public static final ObjectMapper getObjectMapper() {
return objectMapper;
}
public static String toJson(Object obj) { public static String toJson(Object obj) {
String resp = null; String resp = null;
try { try {
@@ -67,20 +63,6 @@ public class JsonUtil {
return null; return null;
} }
/**
* 创建JSON生成器的静态方法, 使用标准输出
*
* @return
*/
private static JsonGenerator getGenerator(StringWriter sw) {
try {
return objectMapper.getFactory().createGenerator(sw);
} catch (IOException e) {
log.error("JsonUtil getGenerator error", e);
}
return null;
}
/** /**
* JSON对象反序列化 * JSON对象反序列化
*/ */
@@ -152,15 +134,6 @@ public class JsonUtil {
return result; return result;
} }
public static <T> List<T> toList(String json, Class<T> clazz) {
try {
return objectMapper.readValue(json, objectMapper.getTypeFactory().constructCollectionType(List.class, clazz));
} catch (JsonProcessingException e) {
log.error("反序列化失败", e);
}
return new ArrayList<>();
}
public static Map<String, Object> toMap(Object obj) { public static Map<String, Object> toMap(Object obj) {
try { try {
return objectMapper.convertValue(obj, new TypeReference<HashMap<String, Object>>() { return objectMapper.convertValue(obj, new TypeReference<HashMap<String, Object>>() {
@@ -178,8 +151,4 @@ public class JsonUtil {
return objectMapper.createObjectNode(); return objectMapper.createObjectNode();
} }
public static ArrayNode createArrayNode() {
return objectMapper.createArrayNode();
}
} }

View File

@@ -1,6 +1,7 @@
package org.ruoyi.workflow.workflow; package org.ruoyi.workflow.workflow;
import cn.hutool.core.collection.CollUtil; import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjectUtil;
import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.ObjectNode;
import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.collections4.CollectionUtils;
@@ -97,7 +98,6 @@ public class WfNodeIODataUtil {
if (outputExist) { if (outputExist) {
return result; return result;
} }
if (null != defaultInputName) { if (null != defaultInputName) {
defaultInputName.setName(DEFAULT_OUTPUT_PARAM_NAME); defaultInputName.setName(DEFAULT_OUTPUT_PARAM_NAME);
} else if (null != txtExist) { } else if (null != txtExist) {
@@ -105,7 +105,7 @@ public class WfNodeIODataUtil {
} else if (null != first) { } else if (null != first) {
first.setName(DEFAULT_OUTPUT_PARAM_NAME); first.setName(DEFAULT_OUTPUT_PARAM_NAME);
} }
result.add(inputs.get(0));
return result; return result;
} }

View File

@@ -1,6 +1,7 @@
package org.ruoyi.workflow.workflow; package org.ruoyi.workflow.workflow;
import cn.hutool.core.collection.CollStreamUtil; import cn.hutool.core.collection.CollStreamUtil;
import cn.hutool.core.collection.CollUtil;
import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.ObjectNode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
@@ -25,6 +26,7 @@ import org.ruoyi.workflow.service.WorkflowRuntimeService;
import org.ruoyi.workflow.util.JsonUtil; import org.ruoyi.workflow.util.JsonUtil;
import org.ruoyi.workflow.workflow.data.NodeIOData; import org.ruoyi.workflow.workflow.data.NodeIOData;
import org.ruoyi.workflow.workflow.def.WfNodeIO; import org.ruoyi.workflow.workflow.def.WfNodeIO;
import org.ruoyi.workflow.workflow.def.WfNodeParamRef;
import org.ruoyi.workflow.workflow.node.AbstractWfNode; import org.ruoyi.workflow.workflow.node.AbstractWfNode;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
@@ -187,12 +189,31 @@ public class WorkflowEngine {
NodeProcessResult processResult = abstractWfNode.process((is) -> { NodeProcessResult processResult = abstractWfNode.process((is) -> {
workflowRuntimeNodeService.updateInput(runtimeNodeDto.getId(), nodeState); workflowRuntimeNodeService.updateInput(runtimeNodeDto.getId(), nodeState);
for (NodeIOData input : nodeState.getInputs()) { List<NodeIOData> nodeIODataList = nodeState.getInputs();
// if (!wfNode.getWorkflowComponentId().equals(1L)) {
// String inputConfig = wfNode.getInputConfig();
// WfNodeInputConfig nodeInputConfig = NodeInputConfigTypeHandler.fillNodeInputConfig(inputConfig);
// List<WfNodeParamRef> refInputs = nodeInputConfig.getRefInputs();
// Set<String> nameSet = CollStreamUtil.toSet(refInputs, WfNodeParamRef::getNodeParamName);
// if (CollUtil.isNotEmpty(nameSet)) {
// nodeIODataList = nodeIODataList.stream().filter(item -> nameSet.contains(item.getName()))
// .collect(Collectors.toList());
// } else {
// nodeIODataList = nodeIODataList.stream().filter(item -> item.getName().contains("input"))
// .collect(Collectors.toList());
// }
// }
for (NodeIOData input : nodeIODataList) {
String inputConfig = wfNode.getInputConfig();
WfNodeInputConfig nodeInputConfig = NodeInputConfigTypeHandler.fillNodeInputConfig(inputConfig);
List<WfNodeParamRef> refInputs = nodeInputConfig.getRefInputs();
if (CollUtil.isNotEmpty(refInputs) && "input".equals(input.getName())) {
continue;
}
SSEEmitterHelper.parseAndSendPartialMsg(sseEmitter, "[NODE_INPUT_" + wfNode.getUuid() + "]", JsonUtil.toJson(input)); SSEEmitterHelper.parseAndSendPartialMsg(sseEmitter, "[NODE_INPUT_" + wfNode.getUuid() + "]", JsonUtil.toJson(input));
} }
}, (is) -> { }, (is) -> {
workflowRuntimeNodeService.updateOutput(runtimeNodeDto.getId(), nodeState); workflowRuntimeNodeService.updateOutput(runtimeNodeDto.getId(), nodeState);
//并行节点内部的节点执行结束后,需要主动向客户端发送输出结果 //并行节点内部的节点执行结束后,需要主动向客户端发送输出结果
String nodeUuid = wfNode.getUuid(); String nodeUuid = wfNode.getUuid();
List<NodeIOData> nodeOutputs = nodeState.getOutputs(); List<NodeIOData> nodeOutputs = nodeState.getOutputs();
@@ -229,7 +250,7 @@ public class WorkflowEngine {
if (out instanceof StreamingOutput<WfNodeState> streamingOutput) { if (out instanceof StreamingOutput<WfNodeState> streamingOutput) {
String node = streamingOutput.node(); String node = streamingOutput.node();
String chunk = streamingOutput.chunk(); String chunk = streamingOutput.chunk();
log.info("node:{},chunk:{}", node, streamingOutput.chunk()); log.info("node:{},chunk:{}", node, chunk);
SSEEmitterHelper.parseAndSendPartialMsg(sseEmitter, "[NODE_CHUNK_" + node + "]", chunk); SSEEmitterHelper.parseAndSendPartialMsg(sseEmitter, "[NODE_CHUNK_" + node + "]", chunk);
} else { } else {
AbstractWfNode abstractWfNode = wfState.getCompletedNodes().stream() AbstractWfNode abstractWfNode = wfState.getCompletedNodes().stream()

View File

@@ -1,5 +1,7 @@
package org.ruoyi.workflow.workflow; package org.ruoyi.workflow.workflow;
import cn.hutool.core.collection.CollStreamUtil;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.data.message.UserMessage;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
@@ -10,16 +12,19 @@ import org.ruoyi.chat.factory.ChatServiceFactory;
import org.ruoyi.chat.service.chat.IChatService; import org.ruoyi.chat.service.chat.IChatService;
import org.ruoyi.common.chat.entity.chat.Message; import org.ruoyi.common.chat.entity.chat.Message;
import org.ruoyi.common.chat.request.ChatRequest; import org.ruoyi.common.chat.request.ChatRequest;
import org.ruoyi.workflow.base.NodeInputConfigTypeHandler;
import org.ruoyi.workflow.entity.WorkflowNode; import org.ruoyi.workflow.entity.WorkflowNode;
import org.ruoyi.workflow.enums.WfIODataTypeEnum; import org.ruoyi.workflow.enums.WfIODataTypeEnum;
import org.ruoyi.workflow.util.JsonUtil; import org.ruoyi.workflow.util.JsonUtil;
import org.ruoyi.workflow.workflow.data.NodeIOData; import org.ruoyi.workflow.workflow.data.NodeIOData;
import org.ruoyi.workflow.workflow.data.NodeIODataContent; import org.ruoyi.workflow.workflow.data.NodeIODataContent;
import org.ruoyi.workflow.workflow.def.WfNodeParamRef;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set;
import static org.ruoyi.workflow.cosntant.AdiConstant.WorkflowConstant.DEFAULT_OUTPUT_PARAM_NAME; import static org.ruoyi.workflow.cosntant.AdiConstant.WorkflowConstant.DEFAULT_OUTPUT_PARAM_NAME;
@@ -65,12 +70,12 @@ public class WorkflowUtil {
return String.valueOf(tip); return String.valueOf(tip);
} }
public void streamingInvokeLLM(WfState wfState, WfNodeState state, WorkflowNode node, String modelPlatform, public void streamingInvokeLLM(WfState wfState, WfNodeState state, WorkflowNode node, String category,
String modelName, List<UserMessage> msgs) { String modelName, List<UserMessage> systemMessage) {
log.info("stream invoke, modelPlatform: {}, modelName: {}", modelPlatform, modelName); log.info("stream invoke, category: {}, modelName: {}", category, modelName);
// 根据 modelPlatform 获取对应的 ChatService不使用计费代理工作流场景单独计费 // 根据 category 获取对应的 ChatService不使用计费代理工作流场景单独计费
IChatService chatService = chatServiceFactory.getOriginalService(modelPlatform); IChatService chatService = chatServiceFactory.getOriginalService(category);
StreamingChatGenerator<AgentState> streamingGenerator = StreamingChatGenerator.builder() StreamingChatGenerator<AgentState> streamingGenerator = StreamingChatGenerator.builder()
.mapResult(response -> { .mapResult(response -> {
@@ -85,19 +90,73 @@ public class WorkflowUtil {
.build(); .build();
// 构建 ruoyi-ai 的 ChatRequest // 构建 ruoyi-ai 的 ChatRequest
List<Message> messages = new ArrayList<>();
addUserMessage(node, state.getInputs(), messages);
addSystemMessage(systemMessage, messages);
ChatRequest chatRequest = new ChatRequest(); ChatRequest chatRequest = new ChatRequest();
chatRequest.setModel(modelName); chatRequest.setModel(modelName);
List<Message> messages = new ArrayList<>();
for (UserMessage userMsg : msgs) {
Message message = new Message();
message.setContent(userMsg.singleText());
message.setRole("user");
messages.add(message);
}
chatRequest.setMessages(messages); chatRequest.setMessages(messages);
// 使用工作流专用方法 // 使用工作流专用方法
chatService.chat(chatRequest, streamingGenerator.handler()); chatService.chat(chatRequest, streamingGenerator.handler());
wfState.getNodeToStreamingGenerator().put(node.getUuid(), streamingGenerator); wfState.getNodeToStreamingGenerator().put(node.getUuid(), streamingGenerator);
} }
/**
* 添加用户信息
*
* @param node
* @param messages
*/
private void addUserMessage(WorkflowNode node, List<NodeIOData> userMessage, List<Message> messages) {
if (CollUtil.isEmpty(userMessage)) {
return;
}
WfNodeInputConfig nodeInputConfig = NodeInputConfigTypeHandler.fillNodeInputConfig(node.getInputConfig());
List<WfNodeParamRef> refInputs = nodeInputConfig.getRefInputs();
Set<String> nameSet = CollStreamUtil.toSet(refInputs, WfNodeParamRef::getName);
userMessage.stream().filter(item -> nameSet.contains(item.getName()))
.map(item -> getMessage("role", item.getContent().getValue())).forEach(messages::add);
if (CollUtil.isNotEmpty(messages)) {
return;
}
userMessage.stream().filter(item -> "input".equals(item.getName()))
.map(item -> getMessage("role", item.getContent().getValue())).forEach(messages::add);
}
/**
* 组装message对象
*
* @param role
* @param value
* @return
*/
private Message getMessage(String role, Object value) {
Message message = new Message();
message.setContent(String.valueOf(value));
message.setRole(role);
return message;
}
/**
* 添加系统信息
*
* @param systemMessage
* @param messages
*/
private void addSystemMessage(List<UserMessage> systemMessage, List<Message> messages) {
if (CollUtil.isEmpty(systemMessage)) {
return;
}
systemMessage.stream().map(userMsg -> getMessage("system", userMsg.singleText())).forEach(messages::add);
}
} }

View File

@@ -18,5 +18,6 @@ public class WfNodeParamRef implements Serializable {
private String nodeUuid; private String nodeUuid;
@JsonProperty("node_param_name") @JsonProperty("node_param_name")
private String nodeParamName; private String nodeParamName;
private String name; private String name;
} }

View File

@@ -1,5 +1,6 @@
package org.ruoyi.workflow.workflow.node; package org.ruoyi.workflow.workflow.node;
import cn.hutool.core.collection.CollUtil;
import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.ObjectNode;
import jakarta.validation.ConstraintViolation; import jakarta.validation.ConstraintViolation;
import lombok.Data; import lombok.Data;

View File

@@ -45,8 +45,8 @@ public class LLMAnswerNode extends AbstractWfNode {
WorkflowUtil workflowUtil = SpringUtil.getBean(WorkflowUtil.class); WorkflowUtil workflowUtil = SpringUtil.getBean(WorkflowUtil.class);
String modelName = nodeConfigObj.getModelName(); String modelName = nodeConfigObj.getModelName();
String category = nodeConfigObj.getCategory(); String category = nodeConfigObj.getCategory();
List<UserMessage> userMessageList = List.of(UserMessage.from(prompt)); List<UserMessage> systemMessage = List.of(UserMessage.from(prompt));
workflowUtil.streamingInvokeLLM(wfState, state, node, category, modelName, userMessageList); workflowUtil.streamingInvokeLLM(wfState, state, node, category, modelName, systemMessage);
return new NodeProcessResult(); return new NodeProcessResult();
} }
} }

View File

@@ -1,9 +1,16 @@
package org.ruoyi.chat.service.chat; package org.ruoyi.chat.service.chat;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
import org.ruoyi.common.chat.request.ChatRequest; import org.ruoyi.common.chat.request.ChatRequest;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.ArrayList;
import java.util.List;
/** /**
* 对话Service接口 * 对话Service接口
* *
@@ -23,12 +30,27 @@ public interface IChatService {
* 工作流场景:支持 langchain4j 的 StreamingChatResponseHandler * 工作流场景:支持 langchain4j 的 StreamingChatResponseHandler
* *
* @param chatRequest ruoyi-ai 的请求对象 * @param chatRequest ruoyi-ai 的请求对象
* @param handler langchain4j 的流式响应处理器 * @param handler langchain4j 的流式响应处理器
*/ */
default void chat(ChatRequest chatRequest, StreamingChatResponseHandler handler) { default void chat(ChatRequest chatRequest, StreamingChatResponseHandler handler) {
throw new UnsupportedOperationException("此服务暂不支持工作流场景"); throw new UnsupportedOperationException("此服务暂不支持工作流场景");
} }
default dev.langchain4j.model.chat.request.ChatRequest convertToLangchainRequest(ChatRequest request) {
List<ChatMessage> messages = new ArrayList<>();
for (org.ruoyi.common.chat.entity.chat.Message msg : request.getMessages()) {
// 简单转换,您可以根据实际需求调整
if ("user".equals(msg.getRole())) {
messages.add(UserMessage.from(msg.getContent().toString()));
} else if ("system".equals(msg.getRole())) {
messages.add(SystemMessage.from(msg.getContent().toString()));
} else if ("assistant".equals(msg.getRole())) {
messages.add(AiMessage.from(msg.getContent().toString()));
}
}
return dev.langchain4j.model.chat.request.ChatRequest.builder().messages(messages).build();
}
/** /**
* 获取此服务支持的模型类别 * 获取此服务支持的模型类别
*/ */

View File

@@ -92,37 +92,19 @@ public class DeepSeekChatImpl implements IChatService {
.modelName(chatModelVo.getModelName()) .modelName(chatModelVo.getModelName())
.logRequests(true) .logRequests(true)
.logResponses(true) .logResponses(true)
.temperature(0.8) .temperature(0.7)
.build(); .build();
try { try {
// 将 ruoyi-ai 的 ChatRequest 转换为 langchain4j 的格式 // 将 ruoyi-ai 的 ChatRequest 转换为 langchain4j 的格式
dev.langchain4j.model.chat.request.ChatRequest langchainRequest = convertToLangchainRequest(request); dev.langchain4j.model.chat.request.ChatRequest chatRequest = convertToLangchainRequest(request);
chatModel.chat(langchainRequest, handler); chatModel.chat(chatRequest, handler);
} catch (Exception e) { } catch (Exception e) {
log.error("workflow deepseek请求失败{}", e.getMessage(), e); log.error("workflow deepseek请求失败{}", e.getMessage(), e);
throw new RuntimeException("DeepSeek workflow chat failed: " + e.getMessage(), e); throw new RuntimeException("DeepSeek workflow chat failed: " + e.getMessage(), e);
} }
} }
/**
* 转换请求格式
*/
private dev.langchain4j.model.chat.request.ChatRequest convertToLangchainRequest(ChatRequest request) {
List<ChatMessage> messages = new ArrayList<>();
for (org.ruoyi.common.chat.entity.chat.Message msg : request.getMessages()) {
// 简单转换,您可以根据实际需求调整
if ("user".equals(msg.getRole())) {
messages.add(UserMessage.from(msg.getContent().toString()));
} else if ("system".equals(msg.getRole())) {
messages.add(SystemMessage.from(msg.getContent().toString()));
} else if ("assistant".equals(msg.getRole())) {
messages.add(AiMessage.from(msg.getContent().toString()));
}
}
return dev.langchain4j.model.chat.request.ChatRequest.builder().messages(messages).build();
}
@Override @Override
public String getCategory() { public String getCategory() {
return ChatModeType.DEEPSEEK.getCode(); return ChatModeType.DEEPSEEK.getCode();

View File

@@ -1,10 +1,6 @@
package org.ruoyi.chat.service.chat.impl; package org.ruoyi.chat.service.chat.impl;
import dev.langchain4j.community.model.dashscope.QwenStreamingChatModel; import dev.langchain4j.community.model.dashscope.QwenStreamingChatModel;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.StreamingChatModel; import dev.langchain4j.model.chat.StreamingChatModel;
import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
@@ -20,9 +16,6 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.ArrayList;
import java.util.List;
/** /**
* 阿里通义千问 * 阿里通义千问
@@ -92,33 +85,14 @@ public class QianWenAiChatServiceImpl implements IChatService {
try { try {
// 将 ruoyi-ai 的 ChatRequest 转换为 langchain4j 的格式 // 将 ruoyi-ai 的 ChatRequest 转换为 langchain4j 的格式
dev.langchain4j.model.chat.request.ChatRequest langchainRequest = convertToLangchainRequest(request); dev.langchain4j.model.chat.request.ChatRequest chatRequest = convertToLangchainRequest(request);
model.chat(langchainRequest, handler); model.chat(chatRequest, handler);
} catch (Exception e) { } catch (Exception e) {
log.error("workflow 千问请求失败:{}", e.getMessage(), e); log.error("workflow 千问请求失败:{}", e.getMessage(), e);
throw new RuntimeException("QianWen workflow chat failed: " + e.getMessage(), e); throw new RuntimeException("QianWen workflow chat failed: " + e.getMessage(), e);
} }
} }
/**
* 转换请求格式
*/
private dev.langchain4j.model.chat.request.ChatRequest convertToLangchainRequest(ChatRequest request) {
List<ChatMessage> messages = new ArrayList<>();
for (org.ruoyi.common.chat.entity.chat.Message msg : request.getMessages()) {
if ("user".equals(msg.getRole())) {
messages.add(UserMessage.from(msg.getContent().toString()));
} else if ("system".equals(msg.getRole())) {
messages.add(SystemMessage.from(msg.getContent().toString()));
} else if ("assistant".equals(msg.getRole())) {
messages.add(AiMessage.from(msg.getContent().toString()));
}
}
return dev.langchain4j.model.chat.request.ChatRequest.builder()
.messages(messages)
.build();
}
@Override @Override
public String getCategory() { public String getCategory() {
return ChatModeType.QIANWEN.getCode(); return ChatModeType.QIANWEN.getCode();