mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-03-13 20:53:42 +08:00
feat: 功能优化
This commit is contained in:
@@ -21,7 +21,8 @@ import java.util.concurrent.TimeUnit;
|
|||||||
@Service
|
@Service
|
||||||
public class SSEEmitterHelper {
|
public class SSEEmitterHelper {
|
||||||
|
|
||||||
private static final Cache<SseEmitter, Boolean> COMPLETED_SSE = CacheBuilder.newBuilder().expireAfterWrite(10, TimeUnit.MINUTES).build();
|
private static final Cache<SseEmitter, Boolean> COMPLETED_SSE = CacheBuilder.newBuilder()
|
||||||
|
.expireAfterWrite(10, TimeUnit.MINUTES).build();
|
||||||
|
|
||||||
@Resource
|
@Resource
|
||||||
private StringRedisTemplate stringRedisTemplate;
|
private StringRedisTemplate stringRedisTemplate;
|
||||||
@@ -41,8 +42,6 @@ public class SSEEmitterHelper {
|
|||||||
} else {
|
} else {
|
||||||
sendPartial(sseEmitter, name, " " + content);
|
sendPartial(sseEmitter, name, " " + content);
|
||||||
}
|
}
|
||||||
// content = content.replaceAll("[\\r\\n]", "\ndata:");
|
|
||||||
// sendPartial(sseEmitter, name, " " + content);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void sendPartial(SseEmitter sseEmitter, String name, String msg) {
|
public static void sendPartial(SseEmitter sseEmitter, String name, String msg) {
|
||||||
@@ -63,13 +62,6 @@ public class SSEEmitterHelper {
|
|||||||
|
|
||||||
|
|
||||||
public boolean checkOrComplete(User user, SseEmitter sseEmitter) {
|
public boolean checkOrComplete(User user, SseEmitter sseEmitter) {
|
||||||
//Check: rate limit
|
|
||||||
String requestTimesKey = MessageFormat.format(RedisKeyConstant.USER_REQUEST_TEXT_TIMES, user.getId());
|
|
||||||
// if (!rateLimitHelper.checkRequestTimes(requestTimesKey, LocalCache.TEXT_RATE_LIMIT_CONFIG)) {
|
|
||||||
// sendErrorAndComplete(user.getId(), sseEmitter, "访问太过频繁");
|
|
||||||
// return false;
|
|
||||||
// }
|
|
||||||
|
|
||||||
//Check: If still waiting response
|
//Check: If still waiting response
|
||||||
String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, user.getId());
|
String askingKey = MessageFormat.format(RedisKeyConstant.USER_ASKING, user.getId());
|
||||||
String askingVal = stringRedisTemplate.opsForValue().get(askingKey);
|
String askingVal = stringRedisTemplate.opsForValue().get(askingKey);
|
||||||
@@ -125,7 +117,10 @@ public class SSEEmitterHelper {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
sseEmitter.send(SseEmitter.event().name(AdiConstant.SSEEventName.ERROR).data(Objects.toString(errorMsg, "")));
|
SseEmitter.SseEventBuilder event = SseEmitter.event();
|
||||||
|
event.name(AdiConstant.SSEEventName.ERROR);
|
||||||
|
event.data(Objects.toString(errorMsg, ""));
|
||||||
|
sseEmitter.send(event);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
log.warn("sendErrorAndComplete userId:{},errorMsg:{}", userId, errorMsg);
|
log.warn("sendErrorAndComplete userId:{},errorMsg:{}", userId, errorMsg);
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
|
|||||||
@@ -35,10 +35,6 @@ public class JsonUtil {
|
|||||||
objectMapper.registerModules(LocalDateTimeUtil.getSimpleModule(), new JavaTimeModule(), new Jdk8Module());
|
objectMapper.registerModules(LocalDateTimeUtil.getSimpleModule(), new JavaTimeModule(), new Jdk8Module());
|
||||||
}
|
}
|
||||||
|
|
||||||
public static final ObjectMapper getObjectMapper() {
|
|
||||||
return objectMapper;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static String toJson(Object obj) {
|
public static String toJson(Object obj) {
|
||||||
String resp = null;
|
String resp = null;
|
||||||
try {
|
try {
|
||||||
@@ -67,20 +63,6 @@ public class JsonUtil {
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* 创建JSON生成器的静态方法, 使用标准输出
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
private static JsonGenerator getGenerator(StringWriter sw) {
|
|
||||||
try {
|
|
||||||
return objectMapper.getFactory().createGenerator(sw);
|
|
||||||
} catch (IOException e) {
|
|
||||||
log.error("JsonUtil getGenerator error", e);
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* JSON对象反序列化
|
* JSON对象反序列化
|
||||||
*/
|
*/
|
||||||
@@ -152,15 +134,6 @@ public class JsonUtil {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static <T> List<T> toList(String json, Class<T> clazz) {
|
|
||||||
try {
|
|
||||||
return objectMapper.readValue(json, objectMapper.getTypeFactory().constructCollectionType(List.class, clazz));
|
|
||||||
} catch (JsonProcessingException e) {
|
|
||||||
log.error("反序列化失败", e);
|
|
||||||
}
|
|
||||||
return new ArrayList<>();
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Map<String, Object> toMap(Object obj) {
|
public static Map<String, Object> toMap(Object obj) {
|
||||||
try {
|
try {
|
||||||
return objectMapper.convertValue(obj, new TypeReference<HashMap<String, Object>>() {
|
return objectMapper.convertValue(obj, new TypeReference<HashMap<String, Object>>() {
|
||||||
@@ -178,8 +151,4 @@ public class JsonUtil {
|
|||||||
return objectMapper.createObjectNode();
|
return objectMapper.createObjectNode();
|
||||||
}
|
}
|
||||||
|
|
||||||
public static ArrayNode createArrayNode() {
|
|
||||||
return objectMapper.createArrayNode();
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package org.ruoyi.workflow.workflow;
|
package org.ruoyi.workflow.workflow;
|
||||||
|
|
||||||
import cn.hutool.core.collection.CollUtil;
|
import cn.hutool.core.collection.CollUtil;
|
||||||
|
import cn.hutool.core.util.ObjectUtil;
|
||||||
import com.fasterxml.jackson.databind.JsonNode;
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
import org.apache.commons.collections4.CollectionUtils;
|
import org.apache.commons.collections4.CollectionUtils;
|
||||||
@@ -97,7 +98,6 @@ public class WfNodeIODataUtil {
|
|||||||
if (outputExist) {
|
if (outputExist) {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (null != defaultInputName) {
|
if (null != defaultInputName) {
|
||||||
defaultInputName.setName(DEFAULT_OUTPUT_PARAM_NAME);
|
defaultInputName.setName(DEFAULT_OUTPUT_PARAM_NAME);
|
||||||
} else if (null != txtExist) {
|
} else if (null != txtExist) {
|
||||||
@@ -105,7 +105,7 @@ public class WfNodeIODataUtil {
|
|||||||
} else if (null != first) {
|
} else if (null != first) {
|
||||||
first.setName(DEFAULT_OUTPUT_PARAM_NAME);
|
first.setName(DEFAULT_OUTPUT_PARAM_NAME);
|
||||||
}
|
}
|
||||||
|
result.add(inputs.get(0));
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package org.ruoyi.workflow.workflow;
|
package org.ruoyi.workflow.workflow;
|
||||||
|
|
||||||
import cn.hutool.core.collection.CollStreamUtil;
|
import cn.hutool.core.collection.CollStreamUtil;
|
||||||
|
import cn.hutool.core.collection.CollUtil;
|
||||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
@@ -25,6 +26,7 @@ import org.ruoyi.workflow.service.WorkflowRuntimeService;
|
|||||||
import org.ruoyi.workflow.util.JsonUtil;
|
import org.ruoyi.workflow.util.JsonUtil;
|
||||||
import org.ruoyi.workflow.workflow.data.NodeIOData;
|
import org.ruoyi.workflow.workflow.data.NodeIOData;
|
||||||
import org.ruoyi.workflow.workflow.def.WfNodeIO;
|
import org.ruoyi.workflow.workflow.def.WfNodeIO;
|
||||||
|
import org.ruoyi.workflow.workflow.def.WfNodeParamRef;
|
||||||
import org.ruoyi.workflow.workflow.node.AbstractWfNode;
|
import org.ruoyi.workflow.workflow.node.AbstractWfNode;
|
||||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||||
|
|
||||||
@@ -187,12 +189,31 @@ public class WorkflowEngine {
|
|||||||
|
|
||||||
NodeProcessResult processResult = abstractWfNode.process((is) -> {
|
NodeProcessResult processResult = abstractWfNode.process((is) -> {
|
||||||
workflowRuntimeNodeService.updateInput(runtimeNodeDto.getId(), nodeState);
|
workflowRuntimeNodeService.updateInput(runtimeNodeDto.getId(), nodeState);
|
||||||
for (NodeIOData input : nodeState.getInputs()) {
|
List<NodeIOData> nodeIODataList = nodeState.getInputs();
|
||||||
|
// if (!wfNode.getWorkflowComponentId().equals(1L)) {
|
||||||
|
// String inputConfig = wfNode.getInputConfig();
|
||||||
|
// WfNodeInputConfig nodeInputConfig = NodeInputConfigTypeHandler.fillNodeInputConfig(inputConfig);
|
||||||
|
// List<WfNodeParamRef> refInputs = nodeInputConfig.getRefInputs();
|
||||||
|
// Set<String> nameSet = CollStreamUtil.toSet(refInputs, WfNodeParamRef::getNodeParamName);
|
||||||
|
// if (CollUtil.isNotEmpty(nameSet)) {
|
||||||
|
// nodeIODataList = nodeIODataList.stream().filter(item -> nameSet.contains(item.getName()))
|
||||||
|
// .collect(Collectors.toList());
|
||||||
|
// } else {
|
||||||
|
// nodeIODataList = nodeIODataList.stream().filter(item -> item.getName().contains("input"))
|
||||||
|
// .collect(Collectors.toList());
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
for (NodeIOData input : nodeIODataList) {
|
||||||
|
String inputConfig = wfNode.getInputConfig();
|
||||||
|
WfNodeInputConfig nodeInputConfig = NodeInputConfigTypeHandler.fillNodeInputConfig(inputConfig);
|
||||||
|
List<WfNodeParamRef> refInputs = nodeInputConfig.getRefInputs();
|
||||||
|
if (CollUtil.isNotEmpty(refInputs) && "input".equals(input.getName())) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
SSEEmitterHelper.parseAndSendPartialMsg(sseEmitter, "[NODE_INPUT_" + wfNode.getUuid() + "]", JsonUtil.toJson(input));
|
SSEEmitterHelper.parseAndSendPartialMsg(sseEmitter, "[NODE_INPUT_" + wfNode.getUuid() + "]", JsonUtil.toJson(input));
|
||||||
}
|
}
|
||||||
}, (is) -> {
|
}, (is) -> {
|
||||||
workflowRuntimeNodeService.updateOutput(runtimeNodeDto.getId(), nodeState);
|
workflowRuntimeNodeService.updateOutput(runtimeNodeDto.getId(), nodeState);
|
||||||
|
|
||||||
//并行节点内部的节点执行结束后,需要主动向客户端发送输出结果
|
//并行节点内部的节点执行结束后,需要主动向客户端发送输出结果
|
||||||
String nodeUuid = wfNode.getUuid();
|
String nodeUuid = wfNode.getUuid();
|
||||||
List<NodeIOData> nodeOutputs = nodeState.getOutputs();
|
List<NodeIOData> nodeOutputs = nodeState.getOutputs();
|
||||||
@@ -229,7 +250,7 @@ public class WorkflowEngine {
|
|||||||
if (out instanceof StreamingOutput<WfNodeState> streamingOutput) {
|
if (out instanceof StreamingOutput<WfNodeState> streamingOutput) {
|
||||||
String node = streamingOutput.node();
|
String node = streamingOutput.node();
|
||||||
String chunk = streamingOutput.chunk();
|
String chunk = streamingOutput.chunk();
|
||||||
log.info("node:{},chunk:{}", node, streamingOutput.chunk());
|
log.info("node:{},chunk:{}", node, chunk);
|
||||||
SSEEmitterHelper.parseAndSendPartialMsg(sseEmitter, "[NODE_CHUNK_" + node + "]", chunk);
|
SSEEmitterHelper.parseAndSendPartialMsg(sseEmitter, "[NODE_CHUNK_" + node + "]", chunk);
|
||||||
} else {
|
} else {
|
||||||
AbstractWfNode abstractWfNode = wfState.getCompletedNodes().stream()
|
AbstractWfNode abstractWfNode = wfState.getCompletedNodes().stream()
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
package org.ruoyi.workflow.workflow;
|
package org.ruoyi.workflow.workflow;
|
||||||
|
|
||||||
|
import cn.hutool.core.collection.CollStreamUtil;
|
||||||
|
import cn.hutool.core.collection.CollUtil;
|
||||||
import cn.hutool.core.util.StrUtil;
|
import cn.hutool.core.util.StrUtil;
|
||||||
import dev.langchain4j.data.message.UserMessage;
|
import dev.langchain4j.data.message.UserMessage;
|
||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
@@ -10,16 +12,19 @@ import org.ruoyi.chat.factory.ChatServiceFactory;
|
|||||||
import org.ruoyi.chat.service.chat.IChatService;
|
import org.ruoyi.chat.service.chat.IChatService;
|
||||||
import org.ruoyi.common.chat.entity.chat.Message;
|
import org.ruoyi.common.chat.entity.chat.Message;
|
||||||
import org.ruoyi.common.chat.request.ChatRequest;
|
import org.ruoyi.common.chat.request.ChatRequest;
|
||||||
|
import org.ruoyi.workflow.base.NodeInputConfigTypeHandler;
|
||||||
import org.ruoyi.workflow.entity.WorkflowNode;
|
import org.ruoyi.workflow.entity.WorkflowNode;
|
||||||
import org.ruoyi.workflow.enums.WfIODataTypeEnum;
|
import org.ruoyi.workflow.enums.WfIODataTypeEnum;
|
||||||
import org.ruoyi.workflow.util.JsonUtil;
|
import org.ruoyi.workflow.util.JsonUtil;
|
||||||
import org.ruoyi.workflow.workflow.data.NodeIOData;
|
import org.ruoyi.workflow.workflow.data.NodeIOData;
|
||||||
import org.ruoyi.workflow.workflow.data.NodeIODataContent;
|
import org.ruoyi.workflow.workflow.data.NodeIODataContent;
|
||||||
|
import org.ruoyi.workflow.workflow.def.WfNodeParamRef;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
import static org.ruoyi.workflow.cosntant.AdiConstant.WorkflowConstant.DEFAULT_OUTPUT_PARAM_NAME;
|
import static org.ruoyi.workflow.cosntant.AdiConstant.WorkflowConstant.DEFAULT_OUTPUT_PARAM_NAME;
|
||||||
|
|
||||||
@@ -65,12 +70,12 @@ public class WorkflowUtil {
|
|||||||
return String.valueOf(tip);
|
return String.valueOf(tip);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void streamingInvokeLLM(WfState wfState, WfNodeState state, WorkflowNode node, String modelPlatform,
|
public void streamingInvokeLLM(WfState wfState, WfNodeState state, WorkflowNode node, String category,
|
||||||
String modelName, List<UserMessage> msgs) {
|
String modelName, List<UserMessage> systemMessage) {
|
||||||
log.info("stream invoke, modelPlatform: {}, modelName: {}", modelPlatform, modelName);
|
log.info("stream invoke, category: {}, modelName: {}", category, modelName);
|
||||||
|
|
||||||
// 根据 modelPlatform 获取对应的 ChatService(不使用计费代理,工作流场景单独计费)
|
// 根据 category 获取对应的 ChatService(不使用计费代理,工作流场景单独计费)
|
||||||
IChatService chatService = chatServiceFactory.getOriginalService(modelPlatform);
|
IChatService chatService = chatServiceFactory.getOriginalService(category);
|
||||||
|
|
||||||
StreamingChatGenerator<AgentState> streamingGenerator = StreamingChatGenerator.builder()
|
StreamingChatGenerator<AgentState> streamingGenerator = StreamingChatGenerator.builder()
|
||||||
.mapResult(response -> {
|
.mapResult(response -> {
|
||||||
@@ -85,19 +90,73 @@ public class WorkflowUtil {
|
|||||||
.build();
|
.build();
|
||||||
|
|
||||||
// 构建 ruoyi-ai 的 ChatRequest
|
// 构建 ruoyi-ai 的 ChatRequest
|
||||||
|
List<Message> messages = new ArrayList<>();
|
||||||
|
|
||||||
|
addUserMessage(node, state.getInputs(), messages);
|
||||||
|
|
||||||
|
addSystemMessage(systemMessage, messages);
|
||||||
|
|
||||||
ChatRequest chatRequest = new ChatRequest();
|
ChatRequest chatRequest = new ChatRequest();
|
||||||
chatRequest.setModel(modelName);
|
chatRequest.setModel(modelName);
|
||||||
|
|
||||||
List<Message> messages = new ArrayList<>();
|
|
||||||
for (UserMessage userMsg : msgs) {
|
|
||||||
Message message = new Message();
|
|
||||||
message.setContent(userMsg.singleText());
|
|
||||||
message.setRole("user");
|
|
||||||
messages.add(message);
|
|
||||||
}
|
|
||||||
chatRequest.setMessages(messages);
|
chatRequest.setMessages(messages);
|
||||||
|
|
||||||
// 使用工作流专用方法
|
// 使用工作流专用方法
|
||||||
chatService.chat(chatRequest, streamingGenerator.handler());
|
chatService.chat(chatRequest, streamingGenerator.handler());
|
||||||
wfState.getNodeToStreamingGenerator().put(node.getUuid(), streamingGenerator);
|
wfState.getNodeToStreamingGenerator().put(node.getUuid(), streamingGenerator);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 添加用户信息
|
||||||
|
*
|
||||||
|
* @param node
|
||||||
|
* @param messages
|
||||||
|
*/
|
||||||
|
private void addUserMessage(WorkflowNode node, List<NodeIOData> userMessage, List<Message> messages) {
|
||||||
|
if (CollUtil.isEmpty(userMessage)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
WfNodeInputConfig nodeInputConfig = NodeInputConfigTypeHandler.fillNodeInputConfig(node.getInputConfig());
|
||||||
|
|
||||||
|
List<WfNodeParamRef> refInputs = nodeInputConfig.getRefInputs();
|
||||||
|
|
||||||
|
Set<String> nameSet = CollStreamUtil.toSet(refInputs, WfNodeParamRef::getName);
|
||||||
|
|
||||||
|
userMessage.stream().filter(item -> nameSet.contains(item.getName()))
|
||||||
|
.map(item -> getMessage("role", item.getContent().getValue())).forEach(messages::add);
|
||||||
|
|
||||||
|
if (CollUtil.isNotEmpty(messages)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
userMessage.stream().filter(item -> "input".equals(item.getName()))
|
||||||
|
.map(item -> getMessage("role", item.getContent().getValue())).forEach(messages::add);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 组装message对象
|
||||||
|
*
|
||||||
|
* @param role
|
||||||
|
* @param value
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
private Message getMessage(String role, Object value) {
|
||||||
|
Message message = new Message();
|
||||||
|
message.setContent(String.valueOf(value));
|
||||||
|
message.setRole(role);
|
||||||
|
return message;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 添加系统信息
|
||||||
|
*
|
||||||
|
* @param systemMessage
|
||||||
|
* @param messages
|
||||||
|
*/
|
||||||
|
private void addSystemMessage(List<UserMessage> systemMessage, List<Message> messages) {
|
||||||
|
if (CollUtil.isEmpty(systemMessage)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
systemMessage.stream().map(userMsg -> getMessage("system", userMsg.singleText())).forEach(messages::add);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,5 +18,6 @@ public class WfNodeParamRef implements Serializable {
|
|||||||
private String nodeUuid;
|
private String nodeUuid;
|
||||||
@JsonProperty("node_param_name")
|
@JsonProperty("node_param_name")
|
||||||
private String nodeParamName;
|
private String nodeParamName;
|
||||||
|
|
||||||
private String name;
|
private String name;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
package org.ruoyi.workflow.workflow.node;
|
package org.ruoyi.workflow.workflow.node;
|
||||||
|
|
||||||
|
import cn.hutool.core.collection.CollUtil;
|
||||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
import jakarta.validation.ConstraintViolation;
|
import jakarta.validation.ConstraintViolation;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|||||||
@@ -45,8 +45,8 @@ public class LLMAnswerNode extends AbstractWfNode {
|
|||||||
WorkflowUtil workflowUtil = SpringUtil.getBean(WorkflowUtil.class);
|
WorkflowUtil workflowUtil = SpringUtil.getBean(WorkflowUtil.class);
|
||||||
String modelName = nodeConfigObj.getModelName();
|
String modelName = nodeConfigObj.getModelName();
|
||||||
String category = nodeConfigObj.getCategory();
|
String category = nodeConfigObj.getCategory();
|
||||||
List<UserMessage> userMessageList = List.of(UserMessage.from(prompt));
|
List<UserMessage> systemMessage = List.of(UserMessage.from(prompt));
|
||||||
workflowUtil.streamingInvokeLLM(wfState, state, node, category, modelName, userMessageList);
|
workflowUtil.streamingInvokeLLM(wfState, state, node, category, modelName, systemMessage);
|
||||||
return new NodeProcessResult();
|
return new NodeProcessResult();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,16 @@
|
|||||||
package org.ruoyi.chat.service.chat;
|
package org.ruoyi.chat.service.chat;
|
||||||
|
|
||||||
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
|
import dev.langchain4j.data.message.ChatMessage;
|
||||||
|
import dev.langchain4j.data.message.SystemMessage;
|
||||||
|
import dev.langchain4j.data.message.UserMessage;
|
||||||
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
|
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
|
||||||
import org.ruoyi.common.chat.request.ChatRequest;
|
import org.ruoyi.common.chat.request.ChatRequest;
|
||||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 对话Service接口
|
* 对话Service接口
|
||||||
*
|
*
|
||||||
@@ -23,12 +30,27 @@ public interface IChatService {
|
|||||||
* 工作流场景:支持 langchain4j 的 StreamingChatResponseHandler
|
* 工作流场景:支持 langchain4j 的 StreamingChatResponseHandler
|
||||||
*
|
*
|
||||||
* @param chatRequest ruoyi-ai 的请求对象
|
* @param chatRequest ruoyi-ai 的请求对象
|
||||||
* @param handler langchain4j 的流式响应处理器
|
* @param handler langchain4j 的流式响应处理器
|
||||||
*/
|
*/
|
||||||
default void chat(ChatRequest chatRequest, StreamingChatResponseHandler handler) {
|
default void chat(ChatRequest chatRequest, StreamingChatResponseHandler handler) {
|
||||||
throw new UnsupportedOperationException("此服务暂不支持工作流场景");
|
throw new UnsupportedOperationException("此服务暂不支持工作流场景");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
default dev.langchain4j.model.chat.request.ChatRequest convertToLangchainRequest(ChatRequest request) {
|
||||||
|
List<ChatMessage> messages = new ArrayList<>();
|
||||||
|
for (org.ruoyi.common.chat.entity.chat.Message msg : request.getMessages()) {
|
||||||
|
// 简单转换,您可以根据实际需求调整
|
||||||
|
if ("user".equals(msg.getRole())) {
|
||||||
|
messages.add(UserMessage.from(msg.getContent().toString()));
|
||||||
|
} else if ("system".equals(msg.getRole())) {
|
||||||
|
messages.add(SystemMessage.from(msg.getContent().toString()));
|
||||||
|
} else if ("assistant".equals(msg.getRole())) {
|
||||||
|
messages.add(AiMessage.from(msg.getContent().toString()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return dev.langchain4j.model.chat.request.ChatRequest.builder().messages(messages).build();
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取此服务支持的模型类别
|
* 获取此服务支持的模型类别
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -92,37 +92,19 @@ public class DeepSeekChatImpl implements IChatService {
|
|||||||
.modelName(chatModelVo.getModelName())
|
.modelName(chatModelVo.getModelName())
|
||||||
.logRequests(true)
|
.logRequests(true)
|
||||||
.logResponses(true)
|
.logResponses(true)
|
||||||
.temperature(0.8)
|
.temperature(0.7)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// 将 ruoyi-ai 的 ChatRequest 转换为 langchain4j 的格式
|
// 将 ruoyi-ai 的 ChatRequest 转换为 langchain4j 的格式
|
||||||
dev.langchain4j.model.chat.request.ChatRequest langchainRequest = convertToLangchainRequest(request);
|
dev.langchain4j.model.chat.request.ChatRequest chatRequest = convertToLangchainRequest(request);
|
||||||
chatModel.chat(langchainRequest, handler);
|
chatModel.chat(chatRequest, handler);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("workflow deepseek请求失败:{}", e.getMessage(), e);
|
log.error("workflow deepseek请求失败:{}", e.getMessage(), e);
|
||||||
throw new RuntimeException("DeepSeek workflow chat failed: " + e.getMessage(), e);
|
throw new RuntimeException("DeepSeek workflow chat failed: " + e.getMessage(), e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* 转换请求格式
|
|
||||||
*/
|
|
||||||
private dev.langchain4j.model.chat.request.ChatRequest convertToLangchainRequest(ChatRequest request) {
|
|
||||||
List<ChatMessage> messages = new ArrayList<>();
|
|
||||||
for (org.ruoyi.common.chat.entity.chat.Message msg : request.getMessages()) {
|
|
||||||
// 简单转换,您可以根据实际需求调整
|
|
||||||
if ("user".equals(msg.getRole())) {
|
|
||||||
messages.add(UserMessage.from(msg.getContent().toString()));
|
|
||||||
} else if ("system".equals(msg.getRole())) {
|
|
||||||
messages.add(SystemMessage.from(msg.getContent().toString()));
|
|
||||||
} else if ("assistant".equals(msg.getRole())) {
|
|
||||||
messages.add(AiMessage.from(msg.getContent().toString()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return dev.langchain4j.model.chat.request.ChatRequest.builder().messages(messages).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String getCategory() {
|
public String getCategory() {
|
||||||
return ChatModeType.DEEPSEEK.getCode();
|
return ChatModeType.DEEPSEEK.getCode();
|
||||||
|
|||||||
@@ -1,10 +1,6 @@
|
|||||||
package org.ruoyi.chat.service.chat.impl;
|
package org.ruoyi.chat.service.chat.impl;
|
||||||
|
|
||||||
import dev.langchain4j.community.model.dashscope.QwenStreamingChatModel;
|
import dev.langchain4j.community.model.dashscope.QwenStreamingChatModel;
|
||||||
import dev.langchain4j.data.message.AiMessage;
|
|
||||||
import dev.langchain4j.data.message.ChatMessage;
|
|
||||||
import dev.langchain4j.data.message.SystemMessage;
|
|
||||||
import dev.langchain4j.data.message.UserMessage;
|
|
||||||
import dev.langchain4j.model.chat.StreamingChatModel;
|
import dev.langchain4j.model.chat.StreamingChatModel;
|
||||||
import dev.langchain4j.model.chat.response.ChatResponse;
|
import dev.langchain4j.model.chat.response.ChatResponse;
|
||||||
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
|
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
|
||||||
@@ -20,9 +16,6 @@ import org.springframework.beans.factory.annotation.Autowired;
|
|||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 阿里通义千问
|
* 阿里通义千问
|
||||||
@@ -92,33 +85,14 @@ public class QianWenAiChatServiceImpl implements IChatService {
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
// 将 ruoyi-ai 的 ChatRequest 转换为 langchain4j 的格式
|
// 将 ruoyi-ai 的 ChatRequest 转换为 langchain4j 的格式
|
||||||
dev.langchain4j.model.chat.request.ChatRequest langchainRequest = convertToLangchainRequest(request);
|
dev.langchain4j.model.chat.request.ChatRequest chatRequest = convertToLangchainRequest(request);
|
||||||
model.chat(langchainRequest, handler);
|
model.chat(chatRequest, handler);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("workflow 千问请求失败:{}", e.getMessage(), e);
|
log.error("workflow 千问请求失败:{}", e.getMessage(), e);
|
||||||
throw new RuntimeException("QianWen workflow chat failed: " + e.getMessage(), e);
|
throw new RuntimeException("QianWen workflow chat failed: " + e.getMessage(), e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* 转换请求格式
|
|
||||||
*/
|
|
||||||
private dev.langchain4j.model.chat.request.ChatRequest convertToLangchainRequest(ChatRequest request) {
|
|
||||||
List<ChatMessage> messages = new ArrayList<>();
|
|
||||||
for (org.ruoyi.common.chat.entity.chat.Message msg : request.getMessages()) {
|
|
||||||
if ("user".equals(msg.getRole())) {
|
|
||||||
messages.add(UserMessage.from(msg.getContent().toString()));
|
|
||||||
} else if ("system".equals(msg.getRole())) {
|
|
||||||
messages.add(SystemMessage.from(msg.getContent().toString()));
|
|
||||||
} else if ("assistant".equals(msg.getRole())) {
|
|
||||||
messages.add(AiMessage.from(msg.getContent().toString()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return dev.langchain4j.model.chat.request.ChatRequest.builder()
|
|
||||||
.messages(messages)
|
|
||||||
.build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String getCategory() {
|
public String getCategory() {
|
||||||
return ChatModeType.QIANWEN.getCode();
|
return ChatModeType.QIANWEN.getCode();
|
||||||
|
|||||||
Reference in New Issue
Block a user