feat: 功能优化

This commit is contained in:
lihao05
2025-10-20 10:12:44 +08:00
parent 9f6d363d55
commit e7d7de79fe
11 changed files with 136 additions and 112 deletions

View File

@@ -21,7 +21,8 @@ import java.util.concurrent.TimeUnit;
@Service
public class SSEEmitterHelper {
private static final Cache<SseEmitter, Boolean> COMPLETED_SSE = CacheBuilder.newBuilder().expireAfterWrite(10, TimeUnit.MINUTES).build();
private static final Cache<SseEmitter, Boolean> COMPLETED_SSE = CacheBuilder.newBuilder()
.expireAfterWrite(10, TimeUnit.MINUTES).build();
@Resource
private StringRedisTemplate stringRedisTemplate;
@@ -41,8 +42,6 @@ public class SSEEmitterHelper {
} else {
sendPartial(sseEmitter, name, " " + content);
}
// content = content.replaceAll("[\\r\\n]", "\ndata:");
// sendPartial(sseEmitter, name, " " + content);
}
public static void sendPartial(SseEmitter sseEmitter, String name, String msg) {
@@ -63,13 +62,6 @@ public class SSEEmitterHelper {
public boolean checkOrComplete(User user, SseEmitter sseEmitter) {
//Check: rate limit
String requestTimesKey = MessageFormat.format(RedisKeyConstant.USER_REQUEST_TEXT_TIMES, user.getId());
// if (!rateLimitHelper.checkRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG)) {
// sendErrorAndComplete(user.getId(), sseEmitter, "访问太过频繁");
// return false;
// }
//Check: If still waiting response
String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, user.getId());
String askingVal = stringRedisTemplate.opsForValue().get(askingKey);
@@ -125,7 +117,10 @@ public class SSEEmitterHelper {
return;
}
try {
sseEmitter.send(SseEmitter.event().name(AdiConstant.SSEEventName.ERROR).data(Objects.toString(errorMsg, "")));
SseEmitter.SseEventBuilder event = SseEmitter.event();
event.name(AdiConstant.SSEEventName.ERROR);
event.data(Objects.toString(errorMsg, ""));
sseEmitter.send(event);
} catch (IOException e) {
log.warn("sendErrorAndComplete userId:{},errorMsg:{}", userId, errorMsg);
throw new RuntimeException(e);

View File

@@ -35,10 +35,6 @@ public class JsonUtil {
objectMapper.registerModules(LocalDateTimeUtil.getSimpleModule(), new JavaTimeModule(), new Jdk8Module());
}
public static final ObjectMapper getObjectMapper() {
return objectMapper;
}
public static String toJson(Object obj) {
String resp = null;
try {
@@ -67,20 +63,6 @@ public class JsonUtil {
return null;
}
/**
* 创建JSON生成器的静态方法, 使用标准输出
*
* @return
*/
private static JsonGenerator getGenerator(StringWriter sw) {
try {
return objectMapper.getFactory().createGenerator(sw);
} catch (IOException e) {
log.error("JsonUtil getGenerator error", e);
}
return null;
}
/**
* JSON对象反序列化
*/
@@ -152,15 +134,6 @@ public class JsonUtil {
return result;
}
public static <T> List<T> toList(String json, Class<T> clazz) {
try {
return objectMapper.readValue(json, objectMapper.getTypeFactory().constructCollectionType(List.class, clazz));
} catch (JsonProcessingException e) {
log.error("反序列化失败", e);
}
return new ArrayList<>();
}
public static Map<String, Object> toMap(Object obj) {
try {
return objectMapper.convertValue(obj, new TypeReference<HashMap<String, Object>>() {
@@ -178,8 +151,4 @@ public class JsonUtil {
return objectMapper.createObjectNode();
}
public static ArrayNode createArrayNode() {
return objectMapper.createArrayNode();
}
}

View File

@@ -1,6 +1,7 @@
package org.ruoyi.workflow.workflow;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjectUtil;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import org.apache.commons.collections4.CollectionUtils;
@@ -97,7 +98,6 @@ public class WfNodeIODataUtil {
if (outputExist) {
return result;
}
if (null != defaultInputName) {
defaultInputName.setName(DEFAULT_OUTPUT_PARAM_NAME);
} else if (null != txtExist) {
@@ -105,7 +105,7 @@ public class WfNodeIODataUtil {
} else if (null != first) {
first.setName(DEFAULT_OUTPUT_PARAM_NAME);
}
result.add(inputs.get(0));
return result;
}

View File

@@ -1,6 +1,7 @@
package org.ruoyi.workflow.workflow;
import cn.hutool.core.collection.CollStreamUtil;
import cn.hutool.core.collection.CollUtil;
import com.fasterxml.jackson.databind.node.ObjectNode;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
@@ -25,6 +26,7 @@ import org.ruoyi.workflow.service.WorkflowRuntimeService;
import org.ruoyi.workflow.util.JsonUtil;
import org.ruoyi.workflow.workflow.data.NodeIOData;
import org.ruoyi.workflow.workflow.def.WfNodeIO;
import org.ruoyi.workflow.workflow.def.WfNodeParamRef;
import org.ruoyi.workflow.workflow.node.AbstractWfNode;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
@@ -187,12 +189,31 @@ public class WorkflowEngine {
NodeProcessResult processResult = abstractWfNode.process((is) -> {
workflowRuntimeNodeService.updateInput(runtimeNodeDto.getId(), nodeState);
for (NodeIOData input : nodeState.getInputs()) {
List<NodeIOData> nodeIODataList = nodeState.getInputs();
// if (!wfNode.getWorkflowComponentId().equals(1L)) {
// String inputConfig = wfNode.getInputConfig();
// WfNodeInputConfig nodeInputConfig = NodeInputConfigTypeHandler.fillNodeInputConfig(inputConfig);
// List<WfNodeParamRef> refInputs = nodeInputConfig.getRefInputs();
// Set<String> nameSet = CollStreamUtil.toSet(refInputs, WfNodeParamRef::getNodeParamName);
// if (CollUtil.isNotEmpty(nameSet)) {
// nodeIODataList = nodeIODataList.stream().filter(item -> nameSet.contains(item.getName()))
// .collect(Collectors.toList());
// } else {
// nodeIODataList = nodeIODataList.stream().filter(item -> item.getName().contains("input"))
// .collect(Collectors.toList());
// }
// }
for (NodeIOData input : nodeIODataList) {
String inputConfig = wfNode.getInputConfig();
WfNodeInputConfig nodeInputConfig = NodeInputConfigTypeHandler.fillNodeInputConfig(inputConfig);
List<WfNodeParamRef> refInputs = nodeInputConfig.getRefInputs();
if (CollUtil.isNotEmpty(refInputs) && "input".equals(input.getName())) {
continue;
}
SSEEmitterHelper.parseAndSendPartialMsg(sseEmitter, "[NODE_INPUT_" + wfNode.getUuid() + "]", JsonUtil.toJson(input));
}
}, (is) -> {
workflowRuntimeNodeService.updateOutput(runtimeNodeDto.getId(), nodeState);
//并行节点内部的节点执行结束后,需要主动向客户端发送输出结果
String nodeUuid = wfNode.getUuid();
List<NodeIOData> nodeOutputs = nodeState.getOutputs();
@@ -229,7 +250,7 @@ public class WorkflowEngine {
if (out instanceof StreamingOutput<WfNodeState> streamingOutput) {
String node = streamingOutput.node();
String chunk = streamingOutput.chunk();
log.info("node:{},chunk:{}", node, streamingOutput.chunk());
log.info("node:{},chunk:{}", node, chunk);
SSEEmitterHelper.parseAndSendPartialMsg(sseEmitter, "[NODE_CHUNK_" + node + "]", chunk);
} else {
AbstractWfNode abstractWfNode = wfState.getCompletedNodes().stream()

View File

@@ -1,5 +1,7 @@
package org.ruoyi.workflow.workflow;
import cn.hutool.core.collection.CollStreamUtil;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import dev.langchain4j.data.message.UserMessage;
import jakarta.annotation.Resource;
@@ -10,16 +12,19 @@ import org.ruoyi.chat.factory.ChatServiceFactory;
import org.ruoyi.chat.service.chat.IChatService;
import org.ruoyi.common.chat.entity.chat.Message;
import org.ruoyi.common.chat.request.ChatRequest;
import org.ruoyi.workflow.base.NodeInputConfigTypeHandler;
import org.ruoyi.workflow.entity.WorkflowNode;
import org.ruoyi.workflow.enums.WfIODataTypeEnum;
import org.ruoyi.workflow.util.JsonUtil;
import org.ruoyi.workflow.workflow.data.NodeIOData;
import org.ruoyi.workflow.workflow.data.NodeIODataContent;
import org.ruoyi.workflow.workflow.def.WfNodeParamRef;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.ruoyi.workflow.cosntant.AdiConstant.WorkflowConstant.DEFAULT_OUTPUT_PARAM_NAME;
@@ -65,12 +70,12 @@ public class WorkflowUtil {
return String.valueOf(tip);
}
public void streamingInvokeLLM(WfState wfState, WfNodeState state, WorkflowNode node, String modelPlatform,
String modelName, List<UserMessage> msgs) {
log.info("stream invoke, modelPlatform: {}, modelName: {}", modelPlatform, modelName);
public void streamingInvokeLLM(WfState wfState, WfNodeState state, WorkflowNode node, String category,
String modelName, List<UserMessage> systemMessage) {
log.info("stream invoke, category: {}, modelName: {}", category, modelName);
// 根据 modelPlatform 获取对应的 ChatService不使用计费代理工作流场景单独计费
IChatService chatService = chatServiceFactory.getOriginalService(modelPlatform);
// 根据 category 获取对应的 ChatService不使用计费代理工作流场景单独计费
IChatService chatService = chatServiceFactory.getOriginalService(category);
StreamingChatGenerator<AgentState> streamingGenerator = StreamingChatGenerator.builder()
.mapResult(response -> {
@@ -85,19 +90,73 @@ public class WorkflowUtil {
.build();
// 构建 ruoyi-ai 的 ChatRequest
List<Message> messages = new ArrayList<>();
addUserMessage(node, state.getInputs(), messages);
addSystemMessage(systemMessage, messages);
ChatRequest chatRequest = new ChatRequest();
chatRequest.setModel(modelName);
List<Message> messages = new ArrayList<>();
for (UserMessage userMsg : msgs) {
Message message = new Message();
message.setContent(userMsg.singleText());
message.setRole("user");
messages.add(message);
}
chatRequest.setMessages(messages);
// 使用工作流专用方法
chatService.chat(chatRequest, streamingGenerator.handler());
wfState.getNodeToStreamingGenerator().put(node.getUuid(), streamingGenerator);
}
/**
* 添加用户信息
*
* @param node
* @param messages
*/
private void addUserMessage(WorkflowNode node, List<NodeIOData> userMessage, List<Message> messages) {
if (CollUtil.isEmpty(userMessage)) {
return;
}
WfNodeInputConfig nodeInputConfig = NodeInputConfigTypeHandler.fillNodeInputConfig(node.getInputConfig());
List<WfNodeParamRef> refInputs = nodeInputConfig.getRefInputs();
Set<String> nameSet = CollStreamUtil.toSet(refInputs, WfNodeParamRef::getName);
userMessage.stream().filter(item -> nameSet.contains(item.getName()))
.map(item -> getMessage("role", item.getContent().getValue())).forEach(messages::add);
if (CollUtil.isNotEmpty(messages)) {
return;
}
userMessage.stream().filter(item -> "input".equals(item.getName()))
.map(item -> getMessage("role", item.getContent().getValue())).forEach(messages::add);
}
/**
* 组装message对象
*
* @param role
* @param value
* @return
*/
private Message getMessage(String role, Object value) {
Message message = new Message();
message.setContent(String.valueOf(value));
message.setRole(role);
return message;
}
/**
* 添加系统信息
*
* @param systemMessage
* @param messages
*/
private void addSystemMessage(List<UserMessage> systemMessage, List<Message> messages) {
if (CollUtil.isEmpty(systemMessage)) {
return;
}
systemMessage.stream().map(userMsg -> getMessage("system", userMsg.singleText())).forEach(messages::add);
}
}

View File

@@ -18,5 +18,6 @@ public class WfNodeParamRef implements Serializable {
private String nodeUuid;
@JsonProperty("node_param_name")
private String nodeParamName;
private String name;
}

View File

@@ -1,5 +1,6 @@
package org.ruoyi.workflow.workflow.node;
import cn.hutool.core.collection.CollUtil;
import com.fasterxml.jackson.databind.node.ObjectNode;
import jakarta.validation.ConstraintViolation;
import lombok.Data;

View File

@@ -45,8 +45,8 @@ public class LLMAnswerNode extends AbstractWfNode {
WorkflowUtil workflowUtil = SpringUtil.getBean(WorkflowUtil.class);
String modelName = nodeConfigObj.getModelName();
String category = nodeConfigObj.getCategory();
List<UserMessage> userMessageList = List.of(UserMessage.from(prompt));
workflowUtil.streamingInvokeLLM(wfState, state, node, category, modelName, userMessageList);
List<UserMessage> systemMessage = List.of(UserMessage.from(prompt));
workflowUtil.streamingInvokeLLM(wfState, state, node, category, modelName, systemMessage);
return new NodeProcessResult();
}
}

View File

@@ -1,9 +1,16 @@
package org.ruoyi.chat.service.chat;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
import org.ruoyi.common.chat.request.ChatRequest;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.ArrayList;
import java.util.List;
/**
* 对话Service接口
*
@@ -23,12 +30,27 @@ public interface IChatService {
* 工作流场景:支持 langchain4j 的 StreamingChatResponseHandler
*
* @param chatRequest ruoyi-ai 的请求对象
* @param handler langchain4j 的流式响应处理器
* @param handler langchain4j 的流式响应处理器
*/
default void chat(ChatRequest chatRequest, StreamingChatResponseHandler handler) {
throw new UnsupportedOperationException("此服务暂不支持工作流场景");
}
default dev.langchain4j.model.chat.request.ChatRequest convertToLangchainRequest(ChatRequest request) {
List<ChatMessage> messages = new ArrayList<>();
for (org.ruoyi.common.chat.entity.chat.Message msg : request.getMessages()) {
// 简单转换,您可以根据实际需求调整
if ("user".equals(msg.getRole())) {
messages.add(UserMessage.from(msg.getContent().toString()));
} else if ("system".equals(msg.getRole())) {
messages.add(SystemMessage.from(msg.getContent().toString()));
} else if ("assistant".equals(msg.getRole())) {
messages.add(AiMessage.from(msg.getContent().toString()));
}
}
return dev.langchain4j.model.chat.request.ChatRequest.builder().messages(messages).build();
}
/**
* 获取此服务支持的模型类别
*/

View File

@@ -92,37 +92,19 @@ public class DeepSeekChatImpl implements IChatService {
.modelName(chatModelVo.getModelName())
.logRequests(true)
.logResponses(true)
.temperature(0.8)
.temperature(0.7)
.build();
try {
// 将 ruoyi-ai 的 ChatRequest 转换为 langchain4j 的格式
dev.langchain4j.model.chat.request.ChatRequest langchainRequest = convertToLangchainRequest(request);
chatModel.chat(langchainRequest, handler);
dev.langchain4j.model.chat.request.ChatRequest chatRequest = convertToLangchainRequest(request);
chatModel.chat(chatRequest, handler);
} catch (Exception e) {
log.error("workflow deepseek请求失败{}", e.getMessage(), e);
throw new RuntimeException("DeepSeek workflow chat failed: " + e.getMessage(), e);
}
}
/**
* 转换请求格式
*/
private dev.langchain4j.model.chat.request.ChatRequest convertToLangchainRequest(ChatRequest request) {
List<ChatMessage> messages = new ArrayList<>();
for (org.ruoyi.common.chat.entity.chat.Message msg : request.getMessages()) {
// 简单转换,您可以根据实际需求调整
if ("user".equals(msg.getRole())) {
messages.add(UserMessage.from(msg.getContent().toString()));
} else if ("system".equals(msg.getRole())) {
messages.add(SystemMessage.from(msg.getContent().toString()));
} else if ("assistant".equals(msg.getRole())) {
messages.add(AiMessage.from(msg.getContent().toString()));
}
}
return dev.langchain4j.model.chat.request.ChatRequest.builder().messages(messages).build();
}
@Override
public String getCategory() {
return ChatModeType.DEEPSEEK.getCode();

View File

@@ -1,10 +1,6 @@
package org.ruoyi.chat.service.chat.impl;
import dev.langchain4j.community.model.dashscope.QwenStreamingChatModel;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.StreamingChatModel;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
@@ -20,9 +16,6 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.ArrayList;
import java.util.List;
/**
* 阿里通义千问
@@ -92,33 +85,14 @@ public class QianWenAiChatServiceImpl implements IChatService {
try {
// 将 ruoyi-ai 的 ChatRequest 转换为 langchain4j 的格式
dev.langchain4j.model.chat.request.ChatRequest langchainRequest = convertToLangchainRequest(request);
model.chat(langchainRequest, handler);
dev.langchain4j.model.chat.request.ChatRequest chatRequest = convertToLangchainRequest(request);
model.chat(chatRequest, handler);
} catch (Exception e) {
log.error("workflow 千问请求失败:{}", e.getMessage(), e);
throw new RuntimeException("QianWen workflow chat failed: " + e.getMessage(), e);
}
}
/**
* 转换请求格式
*/
private dev.langchain4j.model.chat.request.ChatRequest convertToLangchainRequest(ChatRequest request) {
List<ChatMessage> messages = new ArrayList<>();
for (org.ruoyi.common.chat.entity.chat.Message msg : request.getMessages()) {
if ("user".equals(msg.getRole())) {
messages.add(UserMessage.from(msg.getContent().toString()));
} else if ("system".equals(msg.getRole())) {
messages.add(SystemMessage.from(msg.getContent().toString()));
} else if ("assistant".equals(msg.getRole())) {
messages.add(AiMessage.from(msg.getContent().toString()));
}
}
return dev.langchain4j.model.chat.request.ChatRequest.builder()
.messages(messages)
.build();
}
@Override
public String getCategory() {
return ChatModeType.QIANWEN.getCode();