diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/MilvusVectorStoreStrategy.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/MilvusVectorStoreStrategy.java index 58e5bd87..5240d0a9 100644 --- a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/MilvusVectorStoreStrategy.java +++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/MilvusVectorStoreStrategy.java @@ -31,22 +31,23 @@ import java.util.stream.IntStream; @Component public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { - - private final Integer DIMENSION = 2048; - // 缓存不同集合与 autoFlush 配置的 Milvus 连接 - private final Map> storeCache = new ConcurrentHashMap<>(); - public MilvusVectorStoreStrategy(VectorStoreProperties vectorStoreProperties, EmbeddingModelFactory embeddingModelFactory) { super(vectorStoreProperties, embeddingModelFactory); } - private EmbeddingStore getMilvusStore(String collectionName, boolean autoFlushOnInsert) { - String key = collectionName + "|" + autoFlushOnInsert; + // 缓存不同集合与 autoFlush 配置的 Milvus 连接 + private final Map> storeCache = new ConcurrentHashMap<>(); + + /** + * 获取 Milvus Store,支持动态维度 + */ + private EmbeddingStore getMilvusStore(String collectionName, int dimension, boolean autoFlushOnInsert) { + String key = collectionName + "|" + dimension + "|" + autoFlushOnInsert; return storeCache.computeIfAbsent(key, k -> MilvusEmbeddingStore.builder() .uri(vectorStoreProperties.getMilvus().getUrl()) .collectionName(collectionName) - .dimension(DIMENSION) + .dimension(dimension) .indexType(IndexType.IVF_FLAT) .metricType(MetricType.L2) .autoFlushOnInsert(autoFlushOnInsert) @@ -57,18 +58,37 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { .build() ); } + + /** + * 获取 embedding 模型的实际维度 + */ + private int getModelDimension(String modelName) { + try { + EmbeddingModel model = getEmbeddingModel(modelName, null); + // 使用一个测试文本获取向量维度 + Embedding testEmbedding = model.embed("test").content(); + int dimension = testEmbedding.dimension(); + log.info("Detected embedding model dimension: {} for model: {}", dimension, modelName); + return dimension; + } catch (Exception e) { + log.warn("Failed to detect model dimension for: {}, using default 1024", modelName, e); + return 1024; // 默认使用 1024 (bge-m3 的维度) + } + } @Override public void createSchema(String kid, String modelName) { String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid; + int dimension = getModelDimension(modelName); // 使用缓存获取连接以确保只初始化一次 - EmbeddingStore store = getMilvusStore(collectionName, true); - log.info("Milvus集合初始化完成: {}", collectionName); + EmbeddingStore store = getMilvusStore(collectionName, dimension, true); + log.info("Milvus集合初始化完成: {}, dimension: {}", collectionName, dimension); } @Override public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) { - EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(), DIMENSION); + int dimension = getModelDimension(storeEmbeddingBo.getEmbeddingModelName()); + EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(), dimension); List chunkList = storeEmbeddingBo.getChunkList(); List fidList = storeEmbeddingBo.getFids(); @@ -80,7 +100,7 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { long startTime = System.currentTimeMillis(); // 复用连接,写入场景使用 autoFlush=false 以提升批量插入性能 - EmbeddingStore embeddingStore = getMilvusStore(collectionName, false); + EmbeddingStore embeddingStore = getMilvusStore(collectionName, dimension, false); IntStream.range(0, chunkList.size()).forEach(i -> { String text = chunkList.get(i); @@ -100,13 +120,14 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { @Override public List getQueryVector(QueryVectorBo queryVectorBo) { - EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(), DIMENSION); + int dimension = getModelDimension(queryVectorBo.getEmbeddingModelName()); + EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(), dimension); Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content(); String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + queryVectorBo.getKid(); // 查询复用连接,autoFlush 对查询无影响,此处保持 true - EmbeddingStore embeddingStore = getMilvusStore(collectionName, true); + EmbeddingStore embeddingStore = getMilvusStore(collectionName, dimension, true); List resultList = new ArrayList<>(); EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() @@ -127,14 +148,16 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { @SneakyThrows public void removeById(String id, String modelName) { // 注意:此处原逻辑使用 collectionname + id,保持现状 - EmbeddingStore embeddingStore = getMilvusStore(vectorStoreProperties.getMilvus().getCollectionname() + id, false); + int dimension = getModelDimension(modelName); + EmbeddingStore embeddingStore = getMilvusStore(vectorStoreProperties.getMilvus().getCollectionname() + id, dimension, false); embeddingStore.remove(id); } @Override public void removeByDocId(String docId, String kid) { String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid; - EmbeddingStore embeddingStore = getMilvusStore(collectionName, false); + // 使用默认维度,因为删除操作不需要精确的维度信息 + EmbeddingStore embeddingStore = getMilvusStore(collectionName, 1024, false); Filter filter = MetadataFilterBuilder.metadataKey("docId").isEqualTo(docId); embeddingStore.removeAll(filter); log.info("Milvus成功删除 docId={} 的所有向量数据", docId); @@ -143,7 +166,8 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { @Override public void removeByFid(String fid, String kid) { String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid; - EmbeddingStore embeddingStore = getMilvusStore(collectionName, false); + // 使用默认维度,因为删除操作不需要精确的维度信息 + EmbeddingStore embeddingStore = getMilvusStore(collectionName, 1024, false); Filter filter = MetadataFilterBuilder.metadataKey("fid").isEqualTo(fid); embeddingStore.removeAll(filter); log.info("Milvus成功删除 fid={} 的所有向量数据", fid); diff --git a/ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/WfNodeFactory.java b/ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/WfNodeFactory.java index 61c15093..59b135e7 100644 --- a/ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/WfNodeFactory.java +++ b/ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/WfNodeFactory.java @@ -5,7 +5,9 @@ import org.ruoyi.workflow.entity.WorkflowNode; import org.ruoyi.workflow.workflow.node.AbstractWfNode; import org.ruoyi.workflow.workflow.node.EndNode; import org.ruoyi.workflow.workflow.node.answer.LLMAnswerNode; +import org.ruoyi.workflow.workflow.node.httpRequest.HttpRequestNode; import org.ruoyi.workflow.workflow.node.keywordExtractor.KeywordExtractorNode; +import org.ruoyi.workflow.workflow.node.knowledgeRetrieval.KnowledgeRetrievalNode; import org.ruoyi.workflow.workflow.node.mailSend.MailSendNode; import org.ruoyi.workflow.workflow.node.start.StartNode; import org.ruoyi.workflow.workflow.node.switcher.SwitcherNode; @@ -17,10 +19,11 @@ public class WfNodeFactory { switch (WfComponentNameEnum.getByName(wfComponent.getName())) { case START -> wfNode = new StartNode(wfComponent, nodeDefinition, wfState, nodeState); case LLM_ANSWER -> wfNode = new LLMAnswerNode(wfComponent, nodeDefinition, wfState, nodeState); - case KEYWORD_EXTRACTOR -> - wfNode = new KeywordExtractorNode(wfComponent, nodeDefinition, wfState, nodeState); + case KEYWORD_EXTRACTOR -> wfNode = new KeywordExtractorNode(wfComponent, nodeDefinition, wfState, nodeState); + case KNOWLEDGE_RETRIEVER -> wfNode = new KnowledgeRetrievalNode(wfComponent, nodeDefinition, wfState, nodeState); case END -> wfNode = new EndNode(wfComponent, nodeDefinition, wfState, nodeState); case MAIL_SEND -> wfNode = new MailSendNode(wfComponent, nodeDefinition, wfState, nodeState); + case HTTP_REQUEST -> wfNode = new HttpRequestNode(wfComponent, nodeDefinition, wfState, nodeState); case SWITCHER -> wfNode = new SwitcherNode(wfComponent, nodeDefinition, wfState, nodeState); default -> { } diff --git a/ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/node/httpRequest/HttpRequestNode.java b/ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/node/httpRequest/HttpRequestNode.java new file mode 100644 index 00000000..0df35b9b --- /dev/null +++ b/ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/node/httpRequest/HttpRequestNode.java @@ -0,0 +1,437 @@ +package org.ruoyi.workflow.workflow.node.httpRequest; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.jsoup.Jsoup; +import org.ruoyi.workflow.entity.WorkflowComponent; +import org.ruoyi.workflow.entity.WorkflowNode; +import org.ruoyi.workflow.workflow.NodeProcessResult; +import org.ruoyi.workflow.workflow.WfNodeState; +import org.ruoyi.workflow.workflow.WfState; +import org.ruoyi.workflow.workflow.data.NodeIOData; +import org.ruoyi.workflow.workflow.node.AbstractWfNode; +import org.springframework.http.*; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.client.RestTemplate; + +import java.util.*; +import java.util.concurrent.TimeUnit; + +/** + * HTTP 请求节点 + */ +@Slf4j +public class HttpRequestNode extends AbstractWfNode { + + public HttpRequestNode(WorkflowComponent wfComponent, WorkflowNode nodeDef, WfState wfState, WfNodeState nodeState) { + super(wfComponent, nodeDef, wfState, nodeState); + } + + @Override + public NodeProcessResult onProcess() { + try { + HttpRequestNodeConfig config = checkAndGetConfig(HttpRequestNodeConfig.class); + List inputs = state.getInputs(); + + // 渲染 URL(支持变量替换) + String url = renderTemplate(config.getUrl(), inputs); + if (StringUtils.isBlank(url)) { + throw new IllegalArgumentException("请求 URL 不能为空"); + } + + // 添加 Query 参数 + url = buildUrlWithParams(url, config.getParams(), inputs); + + // 构建请求头 + HttpHeaders headers = buildHeaders(config.getHeaders(), inputs); + + // 构建请求体 + Object requestBody = buildRequestBody(config, inputs); + + // 执行 HTTP 请求(支持重试) + String response = executeHttpRequest(url, config.getMethod(), headers, requestBody, config); + + // 清除 HTML 标签(如果需要) + if (Boolean.TRUE.equals(config.getClearHtml()) && StringUtils.isNotBlank(response)) { + response = Jsoup.parse(response).text(); + } + + // 构造输出 + List outputs = new ArrayList<>(); + outputs.add(NodeIOData.createByText("output", "HTTP响应", response)); + + return NodeProcessResult.builder().content(outputs).build(); + + } catch (Exception e) { + log.error("HTTP 请求失败 in node: {}", node.getId(), e); + + // 异常时返回错误信息 + List errorOutputs = new ArrayList<>(); + errorOutputs.add(NodeIOData.createByText("output", "错误", "")); + errorOutputs.add(NodeIOData.createByText("error", "HTTP请求错误", e.getMessage())); + + return NodeProcessResult.builder().content(errorOutputs).build(); + } + } + + /** + * 渲染模板(支持变量替换) + * 支持格式: + * 1. {var_01} - 直接替换整个变量值 + * 2. {var_01.name} - 从 JSON 中提取 name 字段 + * 3. {var_01.user.email} - 支持嵌套路径 + */ + private String renderTemplate(String template, List inputs) { + if (StringUtils.isBlank(template)) { + return ""; + } + return renderTemplateWithJsonPath(template, inputs); + } + + /** + * 增强的模板渲染,支持 JSON 路径提取 + */ + private String renderTemplateWithJsonPath(String template, List inputs) { + String result = template; + ObjectMapper mapper = new ObjectMapper(); + + for (NodeIOData input : inputs) { + if (input == null || input.getName() == null) { + continue; + } + + String varName = input.getName(); + String varValue = input.valueToString(); + + // 1. 处理简单变量替换 {var_01} + result = result.replace("{" + varName + "}", varValue != null ? varValue : ""); + + // 2. 处理 JSON 路径提取 {var_01.field} 或 {var_01.user.name} + // 尝试解析为 JSON + Map jsonMap = tryParseJson(varValue, mapper); + if (jsonMap != null) { + result = replaceJsonPaths(result, varName, jsonMap); + } + } + + return result; + } + + /** + * 尝试将字符串解析为 JSON Map + */ + private Map tryParseJson(String value, ObjectMapper mapper) { + if (StringUtils.isBlank(value)) { + return null; + } + + value = value.trim(); + if (!value.startsWith("{") && !value.startsWith("[")) { + return null; + } + + try { + return mapper.readValue(value, new TypeReference>() {}); + } catch (Exception e) { + log.debug("无法解析为 JSON: {}", value); + return null; + } + } + + /** + * 替换 JSON 路径变量,如 {var_01.name} 或 {var_01.user.email} + */ + private String replaceJsonPaths(String template, String varName, Map jsonMap) { + String result = template; + + // 查找所有 {varName.xxx} 格式的占位符 + String pattern = "\\{" + varName + "\\.([\\.\\w]+)\\}"; + java.util.regex.Pattern p = java.util.regex.Pattern.compile(pattern); + java.util.regex.Matcher m = p.matcher(template); + + while (m.find()) { + String fullMatch = m.group(0); // 如 {var_01.name} + String jsonPath = m.group(1); // 如 name 或 user.email + + Object value = extractJsonValue(jsonMap, jsonPath); + String replacement = value != null ? value.toString() : ""; + + result = result.replace(fullMatch, replacement); + } + + return result; + } + + /** + * 从 JSON Map 中提取嵌套路径的值 + * 例如:path = "user.email" 会提取 map.get("user").get("email") + */ + @SuppressWarnings("unchecked") + private Object extractJsonValue(Map map, String path) { + String[] parts = path.split("\\."); + Object current = map; + + for (String part : parts) { + if (current instanceof Map) { + current = ((Map) current).get(part); + } else { + return null; + } + } + + return current; + } + + /** + * 构建带参数的 URL + */ + private String buildUrlWithParams(String baseUrl, List params, List inputs) { + if (params == null || params.isEmpty()) { + return baseUrl; + } + + StringBuilder urlBuilder = new StringBuilder(baseUrl); + boolean hasQuery = baseUrl.contains("?"); + + for (HttpRequestNodeConfig.ParamItem param : params) { + if (StringUtils.isBlank(param.getName())) { + continue; + } + + String name = renderTemplate(param.getName(), inputs); + String value = renderTemplate(param.getValue(), inputs); + + if (hasQuery) { + urlBuilder.append("&"); + } else { + urlBuilder.append("?"); + hasQuery = true; + } + + urlBuilder.append(name).append("=").append(value); + } + + return urlBuilder.toString(); + } + + /** + * 构建请求头 + */ + private HttpHeaders buildHeaders(List headerItems, List inputs) { + HttpHeaders headers = new HttpHeaders(); + + if (headerItems != null) { + for (HttpRequestNodeConfig.HeaderItem item : headerItems) { + if (StringUtils.isNotBlank(item.getName())) { + String name = renderTemplate(item.getName(), inputs); + String value = renderTemplate(item.getValue(), inputs); + headers.add(name, value); + } + } + } + + return headers; + } + + /** + * 构建请求体 + */ + private Object buildRequestBody(HttpRequestNodeConfig config, List inputs) { + String method = config.getMethod(); + if ("GET".equalsIgnoreCase(method) || "DELETE".equalsIgnoreCase(method) || "HEAD".equalsIgnoreCase(method)) { + return null; + } + + String contentType = config.getContentType(); + + // JSON Body + if ("application/json".equalsIgnoreCase(contentType)) { + if (config.getJsonBody() != null && !config.getJsonBody().isEmpty()) { + return renderJsonBody(config.getJsonBody(), inputs); + } + } + + // Form Data + if ("multipart/form-data".equalsIgnoreCase(contentType)) { + return buildFormData(config.getFormDataBody(), inputs); + } + + // Form URL Encoded + if ("application/x-www-form-urlencoded".equalsIgnoreCase(contentType)) { + return buildFormUrlEncoded(config.getFormUrlencodedBody(), inputs); + } + + // Text Body + if (StringUtils.isNotBlank(config.getTextBody())) { + return renderTemplate(config.getTextBody(), inputs); + } + + return null; + } + + /** + * 渲染 JSON 请求体 + * 支持三种模式: + * 1. 普通字段替换:{"name": "{var_01.name}"} + * 2. 整体 JSON 合并:{"$merge": "{var_01}"} - 将整个 JSON 对象合并进来 + * 3. 智能合并:如果值是 {var_01} 且是有效 JSON,自动展开合并 + */ + private Map renderJsonBody(Map jsonBody, List inputs) { + Map rendered = new HashMap<>(); + ObjectMapper mapper = new ObjectMapper(); + + for (Map.Entry entry : jsonBody.entrySet()) { + String key = entry.getKey(); + Object value = entry.getValue(); + + // 处理特殊的 $merge 指令 + if ("$merge".equals(key) && value instanceof String) { + String varRef = (String) value; + Map mergeData = resolveVariableAsJson(varRef, inputs, mapper); + if (mergeData != null) { + rendered.putAll(mergeData); + } + continue; + } + + if (value instanceof String) { + String strValue = (String) value; + + // 检查是否是单纯的变量引用(如 {var_01}) + if (strValue.matches("^\\{\\w+\\}$")) { + // 尝试解析为 JSON 对象 + Map jsonValue = resolveVariableAsJson(strValue, inputs, mapper); + if (jsonValue != null) { + // 如果是 JSON 对象,合并所有字段 + rendered.putAll(jsonValue); + } else { + // 否则作为普通字符串处理 + rendered.put(key, renderTemplate(strValue, inputs)); + } + } else { + // 普通字符串或包含多个变量的模板 + rendered.put(key, renderTemplate(strValue, inputs)); + } + } else if (value instanceof Map) { + // 递归处理嵌套的 Map + @SuppressWarnings("unchecked") + Map nestedMap = (Map) value; + rendered.put(key, renderJsonBody(nestedMap, inputs)); + } else { + // 其他类型直接保留 + rendered.put(key, value); + } + } + return rendered; + } + + /** + * 解析变量引用为 JSON 对象 + * 例如:{var_01} -> 尝试解析 var_01 的值为 JSON Map + */ + private Map resolveVariableAsJson(String varRef, List inputs, ObjectMapper mapper) { + // 提取变量名(去掉 {}) + String varName = varRef.replaceAll("[{}]", ""); + + // 查找对应的输入变量 + for (NodeIOData input : inputs) { + if (input != null && varName.equals(input.getName())) { + String varValue = input.valueToString(); + return tryParseJson(varValue, mapper); + } + } + + return null; + } + + /** + * 构建 Form Data + */ + private MultiValueMap buildFormData(List formItems, List inputs) { + MultiValueMap formData = new LinkedMultiValueMap<>(); + if (formItems != null) { + for (HttpRequestNodeConfig.FormItem item : formItems) { + if (StringUtils.isNotBlank(item.getName())) { + String name = renderTemplate(item.getName(), inputs); + String value = renderTemplate(item.getValue(), inputs); + formData.add(name, value); + } + } + } + return formData; + } + + /** + * 构建 Form URL Encoded + */ + private MultiValueMap buildFormUrlEncoded(List formItems, List inputs) { + return buildFormData(formItems, inputs); + } + + /** + * 执行 HTTP 请求(支持重试) + */ + private String executeHttpRequest(String url, String method, HttpHeaders headers, Object body, HttpRequestNodeConfig config) { + RestTemplate restTemplate = createRestTemplate(config.getTimeout()); + + int maxRetries = config.getRetryTimes() != null ? config.getRetryTimes() : 0; + int attempt = 0; + Exception lastException = null; + + while (attempt <= maxRetries) { + try { + // 设置 Content-Type + if (StringUtils.isNotBlank(config.getContentType())) { + headers.setContentType(MediaType.parseMediaType(config.getContentType())); + } + + HttpEntity requestEntity = new HttpEntity<>(body, headers); + HttpMethod httpMethod = HttpMethod.valueOf(method.toUpperCase()); + + ResponseEntity response = restTemplate.exchange(url, httpMethod, requestEntity, String.class); + + if (response.getStatusCode().is2xxSuccessful()) { + return response.getBody(); + } else { + throw new RuntimeException("HTTP 请求失败,状态码: " + response.getStatusCode()); + } + + } catch (Exception e) { + lastException = e; + attempt++; + + if (attempt <= maxRetries) { + log.warn("HTTP 请求失败,正在重试 ({}/{}): {}", attempt, maxRetries, e.getMessage()); + try { + TimeUnit.SECONDS.sleep(1); // 重试前等待 1 秒 + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new RuntimeException("重试等待被中断", ie); + } + } + } + } + + throw new RuntimeException("HTTP 请求失败,已重试 " + maxRetries + " 次", lastException); + } + + /** + * 创建 RestTemplate(设置超时) + */ + private RestTemplate createRestTemplate(Integer timeoutSeconds) { + RestTemplate restTemplate = new RestTemplate(); + + // 设置超时时间 + int timeout = (timeoutSeconds != null ? timeoutSeconds : 10) * 1000; + org.springframework.http.client.SimpleClientHttpRequestFactory requestFactory = + new org.springframework.http.client.SimpleClientHttpRequestFactory(); + requestFactory.setConnectTimeout(timeout); + requestFactory.setReadTimeout(timeout); + restTemplate.setRequestFactory(requestFactory); + + return restTemplate; + } +} diff --git a/ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/node/httpRequest/HttpRequestNodeConfig.java b/ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/node/httpRequest/HttpRequestNodeConfig.java new file mode 100644 index 00000000..15bcd9a1 --- /dev/null +++ b/ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/node/httpRequest/HttpRequestNodeConfig.java @@ -0,0 +1,113 @@ +package org.ruoyi.workflow.workflow.node.httpRequest; + +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Data; + +import java.util.List; +import java.util.Map; + +/** + * HTTP 请求节点配置 + */ +@Data +public class HttpRequestNodeConfig { + + /** + * HTTP 请求方法 + */ + private String method = "GET"; + + /** + * 请求 URL + */ + private String url; + + /** + * Content-Type + */ + @JsonProperty("content_type") + private String contentType = "text/plain"; + + /** + * 请求头列表 + */ + private List headers; + + /** + * Query 参数列表 + */ + private List params; + + /** + * 纯文本请求体 + */ + @JsonProperty("text_body") + private String textBody; + + /** + * JSON 请求体 + */ + @JsonProperty("json_body") + private Map jsonBody; + + /** + * Form Data 请求体 + */ + @JsonProperty("form_data_body") + private List formDataBody; + + /** + * Form URL Encoded 请求体 + */ + @JsonProperty("form_urlencoded_body") + private List formUrlencodedBody; + + /** + * 请求体(通用) + */ + private Map body; + + /** + * 超时时间(秒) + */ + private Integer timeout = 10; + + /** + * 重试次数 + */ + @JsonProperty("retry_times") + private Integer retryTimes = 0; + + /** + * 是否清除 HTML 标签 + */ + @JsonProperty("clear_html") + private Boolean clearHtml = false; + + /** + * 请求头项 + */ + @Data + public static class HeaderItem { + private String name; + private String value; + } + + /** + * Query 参数项 + */ + @Data + public static class ParamItem { + private String name; + private String value; + } + + /** + * Form 表单项 + */ + @Data + public static class FormItem { + private String name; + private String value; + } +} diff --git a/ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/node/knowledgeRetrieval/KnowledgeRetrievalNode.java b/ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/node/knowledgeRetrieval/KnowledgeRetrievalNode.java new file mode 100644 index 00000000..fcc0a5a3 --- /dev/null +++ b/ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/node/knowledgeRetrieval/KnowledgeRetrievalNode.java @@ -0,0 +1,274 @@ +package org.ruoyi.workflow.workflow.node.knowledgeRetrieval; + +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.ruoyi.workflow.entity.WorkflowComponent; +import org.ruoyi.workflow.entity.WorkflowNode; +import org.ruoyi.workflow.util.SpringUtil; +import org.ruoyi.workflow.workflow.NodeProcessResult; +import org.ruoyi.workflow.workflow.WfNodeState; +import org.ruoyi.workflow.workflow.WfState; +import org.ruoyi.workflow.workflow.WorkflowUtil; +import org.ruoyi.workflow.workflow.data.NodeIOData; +import org.ruoyi.workflow.workflow.node.AbstractWfNode; +import org.ruoyi.service.VectorStoreService; +import org.ruoyi.service.IKnowledgeInfoService; +import org.ruoyi.domain.bo.QueryVectorBo; +import org.ruoyi.domain.vo.KnowledgeInfoVo; + +import java.util.ArrayList; +import java.util.List; + +import static org.ruoyi.workflow.cosntant.AdiConstant.WorkflowConstant.DEFAULT_OUTPUT_PARAM_NAME; + +/** + * 【节点】知识库检索节点 + * 从知识库中检索相关内容 + */ +@Slf4j +public class KnowledgeRetrievalNode extends AbstractWfNode { + + public KnowledgeRetrievalNode(WorkflowComponent wfComponent, WorkflowNode nodeDef, WfState wfState, WfNodeState nodeState) { + super(wfComponent, nodeDef, wfState, nodeState); + } + + /** + * 处理知识库检索 + * nodeConfig 格式: + * { + * "knowledge_id": "kb_123", + * "top_k": 5, + * "similarity_threshold": 0.7, + * "retrieval_mode": "vector", + * "embedding_model": "text-embedding-3-small", + * "return_source": true, + * "prompt": "额外的查询改写提示词" + * } + * + * @return 检索结果 + */ + @Override + public NodeProcessResult onProcess() { + KnowledgeRetrievalNodeConfig config = checkAndGetConfig(KnowledgeRetrievalNodeConfig.class); + + // 验证知识库ID + if (StringUtils.isBlank(config.getKnowledgeId())) { + log.error("Knowledge base ID is required but not provided"); + List outputs = new ArrayList<>(); + outputs.add(NodeIOData.createByText(DEFAULT_OUTPUT_PARAM_NAME, "", "错误:未配置知识库ID")); + return NodeProcessResult.builder().content(outputs).build(); + } + + // 获取查询文本 + String queryText = getFirstInputText(); + if (StringUtils.isBlank(queryText)) { + log.warn("Knowledge retrieval node has no input query, node: {}", state.getUuid()); + // 返回空结果 + List outputs = new ArrayList<>(); + outputs.add(NodeIOData.createByText(DEFAULT_OUTPUT_PARAM_NAME, "", "")); + return NodeProcessResult.builder().content(outputs).build(); + } + + log.info("Knowledge retrieval node config: {}", config); + log.info("Query text: {}", queryText); + + // 如果有自定义提示词,对查询进行改写 + String finalQuery = queryText; + if (StringUtils.isNotBlank(config.getPrompt())) { + finalQuery = rewriteQuery(config, queryText); + log.info("Rewritten query: {}", finalQuery); + } + + // 根据检索模式执行不同的检索策略 + String retrievalResult; + String mode = config.getRetrievalMode() != null ? config.getRetrievalMode().toLowerCase() : "vector"; + + // 目前只支持向量检索,图谱检索需要依赖graph模块 + if ("graph".equals(mode) || "hybrid".equals(mode)) { + log.warn("Graph retrieval mode is not supported in workflow-api module, falling back to vector retrieval"); + } + + retrievalResult = retrieveFromVector(config, finalQuery); + + log.info("Retrieval result length: {}", retrievalResult.length()); + + // 构建输出 + List outputs = new ArrayList<>(); + outputs.add(NodeIOData.createByText(DEFAULT_OUTPUT_PARAM_NAME, "", retrievalResult)); + + // 如果需要返回原始查询 + outputs.add(NodeIOData.createByText("query", "", finalQuery)); + + return NodeProcessResult.builder().content(outputs).build(); + } + + /** + * 使用LLM改写查询 + */ + private String rewriteQuery(KnowledgeRetrievalNodeConfig config, String originalQuery) { + try { + // 构建改写提示词 + String prompt = WorkflowUtil.renderTemplate(config.getPrompt(), state.getInputs()); + prompt = prompt.replace("{query}", originalQuery); + + log.info("Query rewrite prompt: {}", prompt); + + // 调用LLM进行查询改写 + String rewrittenQuery = invokeLLMSync(config, prompt); + + if (StringUtils.isNotBlank(rewrittenQuery)) { + log.info("Query rewritten from '{}' to '{}'", originalQuery, rewrittenQuery); + return rewrittenQuery.trim(); + } + + // 如果改写失败,返回原查询 + return originalQuery; + } catch (Exception e) { + log.error("Failed to rewrite query, using original query", e); + return originalQuery; + } + } + + /** + * 同步调用LLM + * 使用一个临时的流式处理器来收集完整响应 + */ + private String invokeLLMSync(KnowledgeRetrievalNodeConfig config, String prompt) { + try { + // 创建一个StringBuilder来收集LLM响应 + StringBuilder responseBuilder = new StringBuilder(); + Object lock = new Object(); + boolean[] completed = {false}; + + // 创建临时节点状态用于LLM调用 + WfNodeState tempState = new WfNodeState(); + tempState.setUuid(state.getUuid() + "_rewrite"); + List tempInputs = new ArrayList<>(); + tempInputs.add(NodeIOData.createByText("input", "", prompt)); + tempState.setInputs(tempInputs); + + // 创建临时工作流节点定义 + WorkflowNode tempNode = new WorkflowNode(); + tempNode.setUuid(tempState.getUuid()); + tempNode.setInputConfig(node.getInputConfig()); + + // 使用WorkflowUtil调用LLM(流式) + WorkflowUtil workflowUtil = SpringUtil.getBean(WorkflowUtil.class); + List systemMessage = + List.of(dev.langchain4j.data.message.UserMessage.from(prompt)); + + // 调用流式LLM + String category = StringUtils.isNotBlank(config.getCategory()) ? config.getCategory() : "llm"; + String modelName = StringUtils.isNotBlank(config.getModelName()) ? config.getModelName() : "deepseek-chat"; + + workflowUtil.streamingInvokeLLM( + wfState, + tempState, + tempNode, + category, + modelName, + systemMessage + ); + + // 等待LLM响应完成(最多等待30秒) + long startTime = System.currentTimeMillis(); + long timeout = 30000; // 30秒超时 + + while (!completed[0] && (System.currentTimeMillis() - startTime) < timeout) { + synchronized (lock) { + // 检查是否有输出 + if (!tempState.getOutputs().isEmpty()) { + for (NodeIOData output : tempState.getOutputs()) { + if ("output".equals(output.getName())) { + String text = output.valueToString(); + if (StringUtils.isNotBlank(text)) { + responseBuilder.append(text); + completed[0] = true; + break; + } + } + } + } + } + + if (!completed[0]) { + Thread.sleep(100); // 等待100ms后重试 + } + } + + String result = responseBuilder.toString().trim(); + if (StringUtils.isBlank(result)) { + log.warn("LLM sync call returned empty response"); + } + + return result; + } catch (Exception e) { + log.error("Failed to invoke LLM synchronously", e); + return ""; + } + } + + /** + * 从向量库检索 + */ + private String retrieveFromVector(KnowledgeRetrievalNodeConfig config, String query) { + try { + VectorStoreService vectorStoreService = SpringUtil.getBean(VectorStoreService.class); + IKnowledgeInfoService knowledgeInfoService = SpringUtil.getBean(IKnowledgeInfoService.class); + + // 获取知识库信息以获取embedding模型配置 + Long knowledgeId = Long.parseLong(config.getKnowledgeId()); + KnowledgeInfoVo knowledgeInfo = knowledgeInfoService.queryById(knowledgeId); + + if (knowledgeInfo == null) { + log.error("Knowledge base not found: {}", config.getKnowledgeId()); + return "错误:知识库不存在"; + } + + // 构建查询参数 + QueryVectorBo queryBo = new QueryVectorBo(); + queryBo.setKid(config.getKnowledgeId()); + queryBo.setQuery(query); + queryBo.setMaxResults(config.getTopK()); + + // 优先使用配置中的embedding模型,否则使用知识库的默认模型 + String embeddingModel = StringUtils.isNotBlank(config.getEmbeddingModel()) + ? config.getEmbeddingModel() + : knowledgeInfo.getEmbeddingModelName(); + + // 验证embedding模型配置 + if (StringUtils.isBlank(embeddingModel)) { + log.error("Embedding model not configured for knowledge base: {}", config.getKnowledgeId()); + return "错误:知识库未配置向量化模型"; + } + + queryBo.setEmbeddingModelName(embeddingModel); + + log.info("Querying knowledge base: kid={}, query='{}', embedding model: {}, topK: {}, threshold: {}", + config.getKnowledgeId(), query, embeddingModel, config.getTopK(), config.getSimilarityThreshold()); + + // 执行检索 + List results = vectorStoreService.getQueryVector(queryBo); + + log.info("Vector store query completed, results count: {}", results != null ? results.size() : 0); + + if (results == null || results.isEmpty()) { + log.warn("No results found from vector store for knowledge: {}, query: '{}'", config.getKnowledgeId(), query); + return ""; + } + + // 合并结果 + String mergedResult = String.join("\n\n---\n\n", results); + log.info("Retrieved {} documents from vector store", results.size()); + + return mergedResult; + } catch (NumberFormatException e) { + log.error("Invalid knowledge base ID format: {}", config.getKnowledgeId(), e); + return "错误:知识库ID格式无效"; + } catch (Exception e) { + log.error("Failed to retrieve from vector store", e); + return ""; + } + } + +} diff --git a/ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/node/knowledgeRetrieval/KnowledgeRetrievalNodeConfig.java b/ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/node/knowledgeRetrieval/KnowledgeRetrievalNodeConfig.java new file mode 100644 index 00000000..e8697f7d --- /dev/null +++ b/ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/node/knowledgeRetrieval/KnowledgeRetrievalNodeConfig.java @@ -0,0 +1,111 @@ +package org.ruoyi.workflow.workflow.node.knowledgeRetrieval; + +import com.fasterxml.jackson.annotation.JsonProperty; +import jakarta.validation.constraints.Max; +import jakarta.validation.constraints.Min; +import lombok.Data; +import lombok.EqualsAndHashCode; + +/** + * 知识库检索节点配置 + */ +@EqualsAndHashCode +@Data +public class KnowledgeRetrievalNodeConfig { + + /** + * 知识库UUID(主要字段) + */ + @JsonProperty("knowledge_base_uuid") + private String knowledgeBaseUuid; + + /** + * 知识库ID(兼容字段) + */ + @JsonProperty("knowledge_id") + private String knowledgeId; + + /** + * 获取知识库ID(优先使用knowledgeBaseUuid) + */ + public String getKnowledgeId() { + return knowledgeBaseUuid != null ? knowledgeBaseUuid : knowledgeId; + } + + /** + * 检索的最大结果数 + */ + @Min(1) + @Max(100) + @JsonProperty("top_k") + private Integer topK = 5; + + /** + * 检索的最大结果数(兼容字段,前端使用top_n) + */ + @JsonProperty("top_n") + private Integer topN; + + /** + * 获取topK值(优先使用topN) + */ + public Integer getTopK() { + return topN != null ? topN : topK; + } + + /** + * 相似度阈值(0-1之间) + */ + @Min(0) + @Max(1) + @JsonProperty("similarity_threshold") + private Double similarityThreshold = 0.7; + + /** + * 相似度阈值(兼容字段,前端使用score) + */ + @JsonProperty("score") + private Double score; + + /** + * 获取相似度阈值(优先使用score) + */ + public Double getSimilarityThreshold() { + return score != null ? score : similarityThreshold; + } + + /** + * 检索模式:vector(向量检索)、graph(图谱检索)、hybrid(混合检索) + */ + @JsonProperty("retrieval_mode") + private String retrievalMode = "vector"; + + /** + * 模型分类(用于LLM查询改写) + */ + private String category; + + /** + * LLM模型名称(用于查询改写) + */ + @JsonProperty("model_name") + private String modelName; + + /** + * Embedding模型名称(用于向量检索) + */ + @JsonProperty("embedding_model") + private String embeddingModel; + + /** + * 是否返回原文 + */ + @JsonProperty("return_source") + private Boolean returnSource = true; + + /** + * 自定义查询提示词(可选) + * 用于对查询进行预处理或改写 + */ + private String prompt; +} diff --git a/ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/node/switcher/SwitcherNode.java b/ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/node/switcher/SwitcherNode.java index 4f5a2d9f..646e79eb 100644 --- a/ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/node/switcher/SwitcherNode.java +++ b/ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/node/switcher/SwitcherNode.java @@ -233,16 +233,22 @@ public class SwitcherNode extends AbstractWfNode { */ private String getValueFromInputs(String nodeUuid, String paramName, List inputs) { log.debug("从节点UUID '{}' 搜索参数 '{}'", nodeUuid, paramName); + + String result = null; // 首先尝试从当前输入中查找 log.debug("检查当前输入 (数量: {})", inputs.size()); for (NodeIOData input : inputs) { log.debug(" - 输入: 名称='{}', 值='{}'", input.getName(), input.valueToString()); if (paramName.equals(input.getName())) { - log.info("在当前输入中找到参数 '{}': '{}'", paramName, input.valueToString()); - return input.valueToString(); + result = input.valueToString(); } } + + if (result != null) { + log.info("在当前输入中找到参数 '{}': '{}'", paramName, result); + return result; + } // 如果当前输入中没有,尝试从工作流状态中查找指定节点的输出 if (StringUtils.isNotBlank(nodeUuid)) { @@ -251,14 +257,27 @@ public class SwitcherNode extends AbstractWfNode { for (NodeIOData output : nodeOutputs) { log.debug(" - 输出: 名称='{}', 值='{}'", output.getName(), output.valueToString()); if (paramName.equals(output.getName())) { - log.info("在节点 '{}' 的输出中找到参数 '{}': '{}'", nodeUuid, paramName, output.valueToString()); - return output.valueToString(); + result = output.valueToString(); } } + + if (result != null) { + log.info("在节点 '{}' 的输出中找到参数 '{}': '{}'", nodeUuid, paramName, result); + return result; + } } else { log.debug("节点UUID为空,跳过工作流状态搜索"); } + // 特殊处理:如果找的是 'output' 但没找到,尝试找 'input' + if ("output".equals(paramName)) { + log.debug("未找到参数 'output',尝试查找 'input'"); + String inputValue = getValueFromInputs(nodeUuid, "input", inputs); + if (inputValue != null) { + return inputValue; + } + } + log.warn("在输入或节点 '{}' 的输出中未找到参数 '{}'", nodeUuid, paramName); return null; } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/OllamaServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/OllamaServiceImpl.java index 4048fbd7..2956e555 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/OllamaServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/OllamaServiceImpl.java @@ -1,5 +1,8 @@ package org.ruoyi.chat.service.chat.impl; +import dev.langchain4j.model.chat.StreamingChatModel; +import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; +import dev.langchain4j.model.ollama.OllamaStreamingChatModel; import io.github.ollama4j.OllamaAPI; import io.github.ollama4j.models.chat.OllamaChatMessage; import io.github.ollama4j.models.chat.OllamaChatMessageRole; @@ -80,6 +83,30 @@ public class OllamaServiceImpl implements IChatService { return emitter; } + /** + * 工作流场景:支持 langchain4j handler + */ + @Override + public void chat(ChatRequest request, StreamingChatResponseHandler handler) { + log.info("workflow chat, model: {}", request.getModel()); + + ChatModelVo chatModelVo = chatModelService.selectModelByName(request.getModel()); + + StreamingChatModel model = OllamaStreamingChatModel.builder() + .baseUrl(chatModelVo.getApiHost() != null ? chatModelVo.getApiHost() : "http://localhost:11434") + .modelName(chatModelVo.getModelName()) + .build(); + + try { + // 将 ruoyi-ai 的 ChatRequest 转换为 langchain4j 的格式 + dev.langchain4j.model.chat.request.ChatRequest chatRequest = convertToLangchainRequest(request); + model.chat(chatRequest, handler); + } catch (Exception e) { + log.error("workflow ollama请求失败:{}", e.getMessage(), e); + throw new RuntimeException("ollama workflow chat failed: " + e.getMessage(), e); + } + } + @Override public String getCategory() { return ChatModeType.OLLAMA.getCode();