Merge pull request #241 from stageluo/main

新增知识库、http分支工作流节点
This commit is contained in:
ageerle
2025-12-12 11:46:29 +08:00
committed by GitHub
8 changed files with 1031 additions and 23 deletions

View File

@@ -31,22 +31,23 @@ import java.util.stream.IntStream;
@Component
public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
private final Integer DIMENSION = 2048;
// 缓存不同集合与 autoFlush 配置的 Milvus 连接
private final Map<String, EmbeddingStore<TextSegment>> storeCache = new ConcurrentHashMap<>();
public MilvusVectorStoreStrategy(VectorStoreProperties vectorStoreProperties, EmbeddingModelFactory embeddingModelFactory) {
super(vectorStoreProperties, embeddingModelFactory);
}
private EmbeddingStore<TextSegment> getMilvusStore(String collectionName, boolean autoFlushOnInsert) {
String key = collectionName + "|" + autoFlushOnInsert;
// 缓存不同集合与 autoFlush 配置的 Milvus 连接
private final Map<String, EmbeddingStore<TextSegment>> storeCache = new ConcurrentHashMap<>();
/**
* 获取 Milvus Store支持动态维度
*/
private EmbeddingStore<TextSegment> getMilvusStore(String collectionName, int dimension, boolean autoFlushOnInsert) {
String key = collectionName + "|" + dimension + "|" + autoFlushOnInsert;
return storeCache.computeIfAbsent(key, k ->
MilvusEmbeddingStore.builder()
.uri(vectorStoreProperties.getMilvus().getUrl())
.collectionName(collectionName)
.dimension(DIMENSION)
.dimension(dimension)
.indexType(IndexType.IVF_FLAT)
.metricType(MetricType.L2)
.autoFlushOnInsert(autoFlushOnInsert)
@@ -58,17 +59,36 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
);
}
/**
* 获取 embedding 模型的实际维度
*/
private int getModelDimension(String modelName) {
try {
EmbeddingModel model = getEmbeddingModel(modelName, null);
// 使用一个测试文本获取向量维度
Embedding testEmbedding = model.embed("test").content();
int dimension = testEmbedding.dimension();
log.info("Detected embedding model dimension: {} for model: {}", dimension, modelName);
return dimension;
} catch (Exception e) {
log.warn("Failed to detect model dimension for: {}, using default 1024", modelName, e);
return 1024; // 默认使用 1024 (bge-m3 的维度)
}
}
@Override
public void createSchema(String kid, String modelName) {
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid;
int dimension = getModelDimension(modelName);
// 使用缓存获取连接以确保只初始化一次
EmbeddingStore<TextSegment> store = getMilvusStore(collectionName, true);
log.info("Milvus集合初始化完成: {}", collectionName);
EmbeddingStore<TextSegment> store = getMilvusStore(collectionName, dimension, true);
log.info("Milvus集合初始化完成: {}, dimension: {}", collectionName, dimension);
}
@Override
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(), DIMENSION);
int dimension = getModelDimension(storeEmbeddingBo.getEmbeddingModelName());
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(), dimension);
List<String> chunkList = storeEmbeddingBo.getChunkList();
List<String> fidList = storeEmbeddingBo.getFids();
@@ -80,7 +100,7 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
long startTime = System.currentTimeMillis();
// 复用连接,写入场景使用 autoFlush=false 以提升批量插入性能
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, false);
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, dimension, false);
IntStream.range(0, chunkList.size()).forEach(i -> {
String text = chunkList.get(i);
@@ -100,13 +120,14 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
@Override
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(), DIMENSION);
int dimension = getModelDimension(queryVectorBo.getEmbeddingModelName());
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(), dimension);
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + queryVectorBo.getKid();
// 查询复用连接autoFlush 对查询无影响,此处保持 true
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, true);
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, dimension, true);
List<String> resultList = new ArrayList<>();
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
@@ -127,14 +148,16 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
@SneakyThrows
public void removeById(String id, String modelName) {
// 注意:此处原逻辑使用 collectionname + id保持现状
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(vectorStoreProperties.getMilvus().getCollectionname() + id, false);
int dimension = getModelDimension(modelName);
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(vectorStoreProperties.getMilvus().getCollectionname() + id, dimension, false);
embeddingStore.remove(id);
}
@Override
public void removeByDocId(String docId, String kid) {
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid;
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, false);
// 使用默认维度,因为删除操作不需要精确的维度信息
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, 1024, false);
Filter filter = MetadataFilterBuilder.metadataKey("docId").isEqualTo(docId);
embeddingStore.removeAll(filter);
log.info("Milvus成功删除 docId={} 的所有向量数据", docId);
@@ -143,7 +166,8 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
@Override
public void removeByFid(String fid, String kid) {
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid;
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, false);
// 使用默认维度,因为删除操作不需要精确的维度信息
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, 1024, false);
Filter filter = MetadataFilterBuilder.metadataKey("fid").isEqualTo(fid);
embeddingStore.removeAll(filter);
log.info("Milvus成功删除 fid={} 的所有向量数据", fid);

View File

@@ -5,7 +5,9 @@ import org.ruoyi.workflow.entity.WorkflowNode;
import org.ruoyi.workflow.workflow.node.AbstractWfNode;
import org.ruoyi.workflow.workflow.node.EndNode;
import org.ruoyi.workflow.workflow.node.answer.LLMAnswerNode;
import org.ruoyi.workflow.workflow.node.httpRequest.HttpRequestNode;
import org.ruoyi.workflow.workflow.node.keywordExtractor.KeywordExtractorNode;
import org.ruoyi.workflow.workflow.node.knowledgeRetrieval.KnowledgeRetrievalNode;
import org.ruoyi.workflow.workflow.node.mailSend.MailSendNode;
import org.ruoyi.workflow.workflow.node.start.StartNode;
import org.ruoyi.workflow.workflow.node.switcher.SwitcherNode;
@@ -17,10 +19,11 @@ public class WfNodeFactory {
switch (WfComponentNameEnum.getByName(wfComponent.getName())) {
case START -> wfNode = new StartNode(wfComponent, nodeDefinition, wfState, nodeState);
case LLM_ANSWER -> wfNode = new LLMAnswerNode(wfComponent, nodeDefinition, wfState, nodeState);
case KEYWORD_EXTRACTOR ->
wfNode = new KeywordExtractorNode(wfComponent, nodeDefinition, wfState, nodeState);
case KEYWORD_EXTRACTOR -> wfNode = new KeywordExtractorNode(wfComponent, nodeDefinition, wfState, nodeState);
case KNOWLEDGE_RETRIEVER -> wfNode = new KnowledgeRetrievalNode(wfComponent, nodeDefinition, wfState, nodeState);
case END -> wfNode = new EndNode(wfComponent, nodeDefinition, wfState, nodeState);
case MAIL_SEND -> wfNode = new MailSendNode(wfComponent, nodeDefinition, wfState, nodeState);
case HTTP_REQUEST -> wfNode = new HttpRequestNode(wfComponent, nodeDefinition, wfState, nodeState);
case SWITCHER -> wfNode = new SwitcherNode(wfComponent, nodeDefinition, wfState, nodeState);
default -> {
}

View File

@@ -0,0 +1,437 @@
package org.ruoyi.workflow.workflow.node.httpRequest;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.jsoup.Jsoup;
import org.ruoyi.workflow.entity.WorkflowComponent;
import org.ruoyi.workflow.entity.WorkflowNode;
import org.ruoyi.workflow.workflow.NodeProcessResult;
import org.ruoyi.workflow.workflow.WfNodeState;
import org.ruoyi.workflow.workflow.WfState;
import org.ruoyi.workflow.workflow.data.NodeIOData;
import org.ruoyi.workflow.workflow.node.AbstractWfNode;
import org.springframework.http.*;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestTemplate;
import java.util.*;
import java.util.concurrent.TimeUnit;
/**
* HTTP 请求节点
*/
@Slf4j
public class HttpRequestNode extends AbstractWfNode {
public HttpRequestNode(WorkflowComponent wfComponent, WorkflowNode nodeDef, WfState wfState, WfNodeState nodeState) {
super(wfComponent, nodeDef, wfState, nodeState);
}
@Override
public NodeProcessResult onProcess() {
try {
HttpRequestNodeConfig config = checkAndGetConfig(HttpRequestNodeConfig.class);
List<NodeIOData> inputs = state.getInputs();
// 渲染 URL支持变量替换
String url = renderTemplate(config.getUrl(), inputs);
if (StringUtils.isBlank(url)) {
throw new IllegalArgumentException("请求 URL 不能为空");
}
// 添加 Query 参数
url = buildUrlWithParams(url, config.getParams(), inputs);
// 构建请求头
HttpHeaders headers = buildHeaders(config.getHeaders(), inputs);
// 构建请求体
Object requestBody = buildRequestBody(config, inputs);
// 执行 HTTP 请求(支持重试)
String response = executeHttpRequest(url, config.getMethod(), headers, requestBody, config);
// 清除 HTML 标签(如果需要)
if (Boolean.TRUE.equals(config.getClearHtml()) && StringUtils.isNotBlank(response)) {
response = Jsoup.parse(response).text();
}
// 构造输出
List<NodeIOData> outputs = new ArrayList<>();
outputs.add(NodeIOData.createByText("output", "HTTP响应", response));
return NodeProcessResult.builder().content(outputs).build();
} catch (Exception e) {
log.error("HTTP 请求失败 in node: {}", node.getId(), e);
// 异常时返回错误信息
List<NodeIOData> errorOutputs = new ArrayList<>();
errorOutputs.add(NodeIOData.createByText("output", "错误", ""));
errorOutputs.add(NodeIOData.createByText("error", "HTTP请求错误", e.getMessage()));
return NodeProcessResult.builder().content(errorOutputs).build();
}
}
/**
* 渲染模板(支持变量替换)
* 支持格式:
* 1. {var_01} - 直接替换整个变量值
* 2. {var_01.name} - 从 JSON 中提取 name 字段
* 3. {var_01.user.email} - 支持嵌套路径
*/
private String renderTemplate(String template, List<NodeIOData> inputs) {
if (StringUtils.isBlank(template)) {
return "";
}
return renderTemplateWithJsonPath(template, inputs);
}
/**
* 增强的模板渲染,支持 JSON 路径提取
*/
private String renderTemplateWithJsonPath(String template, List<NodeIOData> inputs) {
String result = template;
ObjectMapper mapper = new ObjectMapper();
for (NodeIOData input : inputs) {
if (input == null || input.getName() == null) {
continue;
}
String varName = input.getName();
String varValue = input.valueToString();
// 1. 处理简单变量替换 {var_01}
result = result.replace("{" + varName + "}", varValue != null ? varValue : "");
// 2. 处理 JSON 路径提取 {var_01.field} 或 {var_01.user.name}
// 尝试解析为 JSON
Map<String, Object> jsonMap = tryParseJson(varValue, mapper);
if (jsonMap != null) {
result = replaceJsonPaths(result, varName, jsonMap);
}
}
return result;
}
/**
* 尝试将字符串解析为 JSON Map
*/
private Map<String, Object> tryParseJson(String value, ObjectMapper mapper) {
if (StringUtils.isBlank(value)) {
return null;
}
value = value.trim();
if (!value.startsWith("{") && !value.startsWith("[")) {
return null;
}
try {
return mapper.readValue(value, new TypeReference<Map<String, Object>>() {});
} catch (Exception e) {
log.debug("无法解析为 JSON: {}", value);
return null;
}
}
/**
* 替换 JSON 路径变量,如 {var_01.name} 或 {var_01.user.email}
*/
private String replaceJsonPaths(String template, String varName, Map<String, Object> jsonMap) {
String result = template;
// 查找所有 {varName.xxx} 格式的占位符
String pattern = "\\{" + varName + "\\.([\\.\\w]+)\\}";
java.util.regex.Pattern p = java.util.regex.Pattern.compile(pattern);
java.util.regex.Matcher m = p.matcher(template);
while (m.find()) {
String fullMatch = m.group(0); // 如 {var_01.name}
String jsonPath = m.group(1); // 如 name 或 user.email
Object value = extractJsonValue(jsonMap, jsonPath);
String replacement = value != null ? value.toString() : "";
result = result.replace(fullMatch, replacement);
}
return result;
}
/**
* 从 JSON Map 中提取嵌套路径的值
* 例如path = "user.email" 会提取 map.get("user").get("email")
*/
@SuppressWarnings("unchecked")
private Object extractJsonValue(Map<String, Object> map, String path) {
String[] parts = path.split("\\.");
Object current = map;
for (String part : parts) {
if (current instanceof Map) {
current = ((Map<String, Object>) current).get(part);
} else {
return null;
}
}
return current;
}
/**
* 构建带参数的 URL
*/
private String buildUrlWithParams(String baseUrl, List<HttpRequestNodeConfig.ParamItem> params, List<NodeIOData> inputs) {
if (params == null || params.isEmpty()) {
return baseUrl;
}
StringBuilder urlBuilder = new StringBuilder(baseUrl);
boolean hasQuery = baseUrl.contains("?");
for (HttpRequestNodeConfig.ParamItem param : params) {
if (StringUtils.isBlank(param.getName())) {
continue;
}
String name = renderTemplate(param.getName(), inputs);
String value = renderTemplate(param.getValue(), inputs);
if (hasQuery) {
urlBuilder.append("&");
} else {
urlBuilder.append("?");
hasQuery = true;
}
urlBuilder.append(name).append("=").append(value);
}
return urlBuilder.toString();
}
/**
* 构建请求头
*/
private HttpHeaders buildHeaders(List<HttpRequestNodeConfig.HeaderItem> headerItems, List<NodeIOData> inputs) {
HttpHeaders headers = new HttpHeaders();
if (headerItems != null) {
for (HttpRequestNodeConfig.HeaderItem item : headerItems) {
if (StringUtils.isNotBlank(item.getName())) {
String name = renderTemplate(item.getName(), inputs);
String value = renderTemplate(item.getValue(), inputs);
headers.add(name, value);
}
}
}
return headers;
}
/**
* 构建请求体
*/
private Object buildRequestBody(HttpRequestNodeConfig config, List<NodeIOData> inputs) {
String method = config.getMethod();
if ("GET".equalsIgnoreCase(method) || "DELETE".equalsIgnoreCase(method) || "HEAD".equalsIgnoreCase(method)) {
return null;
}
String contentType = config.getContentType();
// JSON Body
if ("application/json".equalsIgnoreCase(contentType)) {
if (config.getJsonBody() != null && !config.getJsonBody().isEmpty()) {
return renderJsonBody(config.getJsonBody(), inputs);
}
}
// Form Data
if ("multipart/form-data".equalsIgnoreCase(contentType)) {
return buildFormData(config.getFormDataBody(), inputs);
}
// Form URL Encoded
if ("application/x-www-form-urlencoded".equalsIgnoreCase(contentType)) {
return buildFormUrlEncoded(config.getFormUrlencodedBody(), inputs);
}
// Text Body
if (StringUtils.isNotBlank(config.getTextBody())) {
return renderTemplate(config.getTextBody(), inputs);
}
return null;
}
/**
* 渲染 JSON 请求体
* 支持三种模式:
* 1. 普通字段替换:{"name": "{var_01.name}"}
* 2. 整体 JSON 合并:{"$merge": "{var_01}"} - 将整个 JSON 对象合并进来
* 3. 智能合并:如果值是 {var_01} 且是有效 JSON自动展开合并
*/
private Map<String, Object> renderJsonBody(Map<String, Object> jsonBody, List<NodeIOData> inputs) {
Map<String, Object> rendered = new HashMap<>();
ObjectMapper mapper = new ObjectMapper();
for (Map.Entry<String, Object> entry : jsonBody.entrySet()) {
String key = entry.getKey();
Object value = entry.getValue();
// 处理特殊的 $merge 指令
if ("$merge".equals(key) && value instanceof String) {
String varRef = (String) value;
Map<String, Object> mergeData = resolveVariableAsJson(varRef, inputs, mapper);
if (mergeData != null) {
rendered.putAll(mergeData);
}
continue;
}
if (value instanceof String) {
String strValue = (String) value;
// 检查是否是单纯的变量引用(如 {var_01}
if (strValue.matches("^\\{\\w+\\}$")) {
// 尝试解析为 JSON 对象
Map<String, Object> jsonValue = resolveVariableAsJson(strValue, inputs, mapper);
if (jsonValue != null) {
// 如果是 JSON 对象,合并所有字段
rendered.putAll(jsonValue);
} else {
// 否则作为普通字符串处理
rendered.put(key, renderTemplate(strValue, inputs));
}
} else {
// 普通字符串或包含多个变量的模板
rendered.put(key, renderTemplate(strValue, inputs));
}
} else if (value instanceof Map) {
// 递归处理嵌套的 Map
@SuppressWarnings("unchecked")
Map<String, Object> nestedMap = (Map<String, Object>) value;
rendered.put(key, renderJsonBody(nestedMap, inputs));
} else {
// 其他类型直接保留
rendered.put(key, value);
}
}
return rendered;
}
/**
* 解析变量引用为 JSON 对象
* 例如:{var_01} -> 尝试解析 var_01 的值为 JSON Map
*/
private Map<String, Object> resolveVariableAsJson(String varRef, List<NodeIOData> inputs, ObjectMapper mapper) {
// 提取变量名(去掉 {}
String varName = varRef.replaceAll("[{}]", "");
// 查找对应的输入变量
for (NodeIOData input : inputs) {
if (input != null && varName.equals(input.getName())) {
String varValue = input.valueToString();
return tryParseJson(varValue, mapper);
}
}
return null;
}
/**
* 构建 Form Data
*/
private MultiValueMap<String, String> buildFormData(List<HttpRequestNodeConfig.FormItem> formItems, List<NodeIOData> inputs) {
MultiValueMap<String, String> formData = new LinkedMultiValueMap<>();
if (formItems != null) {
for (HttpRequestNodeConfig.FormItem item : formItems) {
if (StringUtils.isNotBlank(item.getName())) {
String name = renderTemplate(item.getName(), inputs);
String value = renderTemplate(item.getValue(), inputs);
formData.add(name, value);
}
}
}
return formData;
}
/**
* 构建 Form URL Encoded
*/
private MultiValueMap<String, String> buildFormUrlEncoded(List<HttpRequestNodeConfig.FormItem> formItems, List<NodeIOData> inputs) {
return buildFormData(formItems, inputs);
}
/**
* 执行 HTTP 请求(支持重试)
*/
private String executeHttpRequest(String url, String method, HttpHeaders headers, Object body, HttpRequestNodeConfig config) {
RestTemplate restTemplate = createRestTemplate(config.getTimeout());
int maxRetries = config.getRetryTimes() != null ? config.getRetryTimes() : 0;
int attempt = 0;
Exception lastException = null;
while (attempt <= maxRetries) {
try {
// 设置 Content-Type
if (StringUtils.isNotBlank(config.getContentType())) {
headers.setContentType(MediaType.parseMediaType(config.getContentType()));
}
HttpEntity<?> requestEntity = new HttpEntity<>(body, headers);
HttpMethod httpMethod = HttpMethod.valueOf(method.toUpperCase());
ResponseEntity<String> response = restTemplate.exchange(url, httpMethod, requestEntity, String.class);
if (response.getStatusCode().is2xxSuccessful()) {
return response.getBody();
} else {
throw new RuntimeException("HTTP 请求失败,状态码: " + response.getStatusCode());
}
} catch (Exception e) {
lastException = e;
attempt++;
if (attempt <= maxRetries) {
log.warn("HTTP 请求失败,正在重试 ({}/{}): {}", attempt, maxRetries, e.getMessage());
try {
TimeUnit.SECONDS.sleep(1); // 重试前等待 1 秒
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
throw new RuntimeException("重试等待被中断", ie);
}
}
}
}
throw new RuntimeException("HTTP 请求失败,已重试 " + maxRetries + "", lastException);
}
/**
* 创建 RestTemplate设置超时
*/
private RestTemplate createRestTemplate(Integer timeoutSeconds) {
RestTemplate restTemplate = new RestTemplate();
// 设置超时时间
int timeout = (timeoutSeconds != null ? timeoutSeconds : 10) * 1000;
org.springframework.http.client.SimpleClientHttpRequestFactory requestFactory =
new org.springframework.http.client.SimpleClientHttpRequestFactory();
requestFactory.setConnectTimeout(timeout);
requestFactory.setReadTimeout(timeout);
restTemplate.setRequestFactory(requestFactory);
return restTemplate;
}
}

View File

@@ -0,0 +1,113 @@
package org.ruoyi.workflow.workflow.node.httpRequest;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
import java.util.List;
import java.util.Map;
/**
* HTTP 请求节点配置
*/
@Data
public class HttpRequestNodeConfig {
/**
* HTTP 请求方法
*/
private String method = "GET";
/**
* 请求 URL
*/
private String url;
/**
* Content-Type
*/
@JsonProperty("content_type")
private String contentType = "text/plain";
/**
* 请求头列表
*/
private List<HeaderItem> headers;
/**
* Query 参数列表
*/
private List<ParamItem> params;
/**
* 纯文本请求体
*/
@JsonProperty("text_body")
private String textBody;
/**
* JSON 请求体
*/
@JsonProperty("json_body")
private Map<String, Object> jsonBody;
/**
* Form Data 请求体
*/
@JsonProperty("form_data_body")
private List<FormItem> formDataBody;
/**
* Form URL Encoded 请求体
*/
@JsonProperty("form_urlencoded_body")
private List<FormItem> formUrlencodedBody;
/**
* 请求体(通用)
*/
private Map<String, Object> body;
/**
* 超时时间(秒)
*/
private Integer timeout = 10;
/**
* 重试次数
*/
@JsonProperty("retry_times")
private Integer retryTimes = 0;
/**
* 是否清除 HTML 标签
*/
@JsonProperty("clear_html")
private Boolean clearHtml = false;
/**
* 请求头项
*/
@Data
public static class HeaderItem {
private String name;
private String value;
}
/**
* Query 参数项
*/
@Data
public static class ParamItem {
private String name;
private String value;
}
/**
* Form 表单项
*/
@Data
public static class FormItem {
private String name;
private String value;
}
}

View File

@@ -0,0 +1,274 @@
package org.ruoyi.workflow.workflow.node.knowledgeRetrieval;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.ruoyi.workflow.entity.WorkflowComponent;
import org.ruoyi.workflow.entity.WorkflowNode;
import org.ruoyi.workflow.util.SpringUtil;
import org.ruoyi.workflow.workflow.NodeProcessResult;
import org.ruoyi.workflow.workflow.WfNodeState;
import org.ruoyi.workflow.workflow.WfState;
import org.ruoyi.workflow.workflow.WorkflowUtil;
import org.ruoyi.workflow.workflow.data.NodeIOData;
import org.ruoyi.workflow.workflow.node.AbstractWfNode;
import org.ruoyi.service.VectorStoreService;
import org.ruoyi.service.IKnowledgeInfoService;
import org.ruoyi.domain.bo.QueryVectorBo;
import org.ruoyi.domain.vo.KnowledgeInfoVo;
import java.util.ArrayList;
import java.util.List;
import static org.ruoyi.workflow.cosntant.AdiConstant.WorkflowConstant.DEFAULT_OUTPUT_PARAM_NAME;
/**
* 【节点】知识库检索节点
* 从知识库中检索相关内容
*/
@Slf4j
public class KnowledgeRetrievalNode extends AbstractWfNode {
public KnowledgeRetrievalNode(WorkflowComponent wfComponent, WorkflowNode nodeDef, WfState wfState, WfNodeState nodeState) {
super(wfComponent, nodeDef, wfState, nodeState);
}
/**
* 处理知识库检索
* nodeConfig 格式:
* {
* "knowledge_id": "kb_123",
* "top_k": 5,
* "similarity_threshold": 0.7,
* "retrieval_mode": "vector",
* "embedding_model": "text-embedding-3-small",
* "return_source": true,
* "prompt": "额外的查询改写提示词"
* }
*
* @return 检索结果
*/
@Override
public NodeProcessResult onProcess() {
KnowledgeRetrievalNodeConfig config = checkAndGetConfig(KnowledgeRetrievalNodeConfig.class);
// 验证知识库ID
if (StringUtils.isBlank(config.getKnowledgeId())) {
log.error("Knowledge base ID is required but not provided");
List<NodeIOData> outputs = new ArrayList<>();
outputs.add(NodeIOData.createByText(DEFAULT_OUTPUT_PARAM_NAME, "", "错误未配置知识库ID"));
return NodeProcessResult.builder().content(outputs).build();
}
// 获取查询文本
String queryText = getFirstInputText();
if (StringUtils.isBlank(queryText)) {
log.warn("Knowledge retrieval node has no input query, node: {}", state.getUuid());
// 返回空结果
List<NodeIOData> outputs = new ArrayList<>();
outputs.add(NodeIOData.createByText(DEFAULT_OUTPUT_PARAM_NAME, "", ""));
return NodeProcessResult.builder().content(outputs).build();
}
log.info("Knowledge retrieval node config: {}", config);
log.info("Query text: {}", queryText);
// 如果有自定义提示词,对查询进行改写
String finalQuery = queryText;
if (StringUtils.isNotBlank(config.getPrompt())) {
finalQuery = rewriteQuery(config, queryText);
log.info("Rewritten query: {}", finalQuery);
}
// 根据检索模式执行不同的检索策略
String retrievalResult;
String mode = config.getRetrievalMode() != null ? config.getRetrievalMode().toLowerCase() : "vector";
// 目前只支持向量检索图谱检索需要依赖graph模块
if ("graph".equals(mode) || "hybrid".equals(mode)) {
log.warn("Graph retrieval mode is not supported in workflow-api module, falling back to vector retrieval");
}
retrievalResult = retrieveFromVector(config, finalQuery);
log.info("Retrieval result length: {}", retrievalResult.length());
// 构建输出
List<NodeIOData> outputs = new ArrayList<>();
outputs.add(NodeIOData.createByText(DEFAULT_OUTPUT_PARAM_NAME, "", retrievalResult));
// 如果需要返回原始查询
outputs.add(NodeIOData.createByText("query", "", finalQuery));
return NodeProcessResult.builder().content(outputs).build();
}
/**
* 使用LLM改写查询
*/
private String rewriteQuery(KnowledgeRetrievalNodeConfig config, String originalQuery) {
try {
// 构建改写提示词
String prompt = WorkflowUtil.renderTemplate(config.getPrompt(), state.getInputs());
prompt = prompt.replace("{query}", originalQuery);
log.info("Query rewrite prompt: {}", prompt);
// 调用LLM进行查询改写
String rewrittenQuery = invokeLLMSync(config, prompt);
if (StringUtils.isNotBlank(rewrittenQuery)) {
log.info("Query rewritten from '{}' to '{}'", originalQuery, rewrittenQuery);
return rewrittenQuery.trim();
}
// 如果改写失败,返回原查询
return originalQuery;
} catch (Exception e) {
log.error("Failed to rewrite query, using original query", e);
return originalQuery;
}
}
/**
* 同步调用LLM
* 使用一个临时的流式处理器来收集完整响应
*/
private String invokeLLMSync(KnowledgeRetrievalNodeConfig config, String prompt) {
try {
// 创建一个StringBuilder来收集LLM响应
StringBuilder responseBuilder = new StringBuilder();
Object lock = new Object();
boolean[] completed = {false};
// 创建临时节点状态用于LLM调用
WfNodeState tempState = new WfNodeState();
tempState.setUuid(state.getUuid() + "_rewrite");
List<NodeIOData> tempInputs = new ArrayList<>();
tempInputs.add(NodeIOData.createByText("input", "", prompt));
tempState.setInputs(tempInputs);
// 创建临时工作流节点定义
WorkflowNode tempNode = new WorkflowNode();
tempNode.setUuid(tempState.getUuid());
tempNode.setInputConfig(node.getInputConfig());
// 使用WorkflowUtil调用LLM流式
WorkflowUtil workflowUtil = SpringUtil.getBean(WorkflowUtil.class);
List<dev.langchain4j.data.message.UserMessage> systemMessage =
List.of(dev.langchain4j.data.message.UserMessage.from(prompt));
// 调用流式LLM
String category = StringUtils.isNotBlank(config.getCategory()) ? config.getCategory() : "llm";
String modelName = StringUtils.isNotBlank(config.getModelName()) ? config.getModelName() : "deepseek-chat";
workflowUtil.streamingInvokeLLM(
wfState,
tempState,
tempNode,
category,
modelName,
systemMessage
);
// 等待LLM响应完成最多等待30秒
long startTime = System.currentTimeMillis();
long timeout = 30000; // 30秒超时
while (!completed[0] && (System.currentTimeMillis() - startTime) < timeout) {
synchronized (lock) {
// 检查是否有输出
if (!tempState.getOutputs().isEmpty()) {
for (NodeIOData output : tempState.getOutputs()) {
if ("output".equals(output.getName())) {
String text = output.valueToString();
if (StringUtils.isNotBlank(text)) {
responseBuilder.append(text);
completed[0] = true;
break;
}
}
}
}
}
if (!completed[0]) {
Thread.sleep(100); // 等待100ms后重试
}
}
String result = responseBuilder.toString().trim();
if (StringUtils.isBlank(result)) {
log.warn("LLM sync call returned empty response");
}
return result;
} catch (Exception e) {
log.error("Failed to invoke LLM synchronously", e);
return "";
}
}
/**
* 从向量库检索
*/
private String retrieveFromVector(KnowledgeRetrievalNodeConfig config, String query) {
try {
VectorStoreService vectorStoreService = SpringUtil.getBean(VectorStoreService.class);
IKnowledgeInfoService knowledgeInfoService = SpringUtil.getBean(IKnowledgeInfoService.class);
// 获取知识库信息以获取embedding模型配置
Long knowledgeId = Long.parseLong(config.getKnowledgeId());
KnowledgeInfoVo knowledgeInfo = knowledgeInfoService.queryById(knowledgeId);
if (knowledgeInfo == null) {
log.error("Knowledge base not found: {}", config.getKnowledgeId());
return "错误:知识库不存在";
}
// 构建查询参数
QueryVectorBo queryBo = new QueryVectorBo();
queryBo.setKid(config.getKnowledgeId());
queryBo.setQuery(query);
queryBo.setMaxResults(config.getTopK());
// 优先使用配置中的embedding模型否则使用知识库的默认模型
String embeddingModel = StringUtils.isNotBlank(config.getEmbeddingModel())
? config.getEmbeddingModel()
: knowledgeInfo.getEmbeddingModelName();
// 验证embedding模型配置
if (StringUtils.isBlank(embeddingModel)) {
log.error("Embedding model not configured for knowledge base: {}", config.getKnowledgeId());
return "错误:知识库未配置向量化模型";
}
queryBo.setEmbeddingModelName(embeddingModel);
log.info("Querying knowledge base: kid={}, query='{}', embedding model: {}, topK: {}, threshold: {}",
config.getKnowledgeId(), query, embeddingModel, config.getTopK(), config.getSimilarityThreshold());
// 执行检索
List<String> results = vectorStoreService.getQueryVector(queryBo);
log.info("Vector store query completed, results count: {}", results != null ? results.size() : 0);
if (results == null || results.isEmpty()) {
log.warn("No results found from vector store for knowledge: {}, query: '{}'", config.getKnowledgeId(), query);
return "";
}
// 合并结果
String mergedResult = String.join("\n\n---\n\n", results);
log.info("Retrieved {} documents from vector store", results.size());
return mergedResult;
} catch (NumberFormatException e) {
log.error("Invalid knowledge base ID format: {}", config.getKnowledgeId(), e);
return "错误知识库ID格式无效";
} catch (Exception e) {
log.error("Failed to retrieve from vector store", e);
return "";
}
}
}

View File

@@ -0,0 +1,111 @@
package org.ruoyi.workflow.workflow.node.knowledgeRetrieval;
import com.fasterxml.jackson.annotation.JsonProperty;
import jakarta.validation.constraints.Max;
import jakarta.validation.constraints.Min;
import lombok.Data;
import lombok.EqualsAndHashCode;
/**
* 知识库检索节点配置
*/
@EqualsAndHashCode
@Data
public class KnowledgeRetrievalNodeConfig {
/**
* 知识库UUID主要字段
*/
@JsonProperty("knowledge_base_uuid")
private String knowledgeBaseUuid;
/**
* 知识库ID兼容字段
*/
@JsonProperty("knowledge_id")
private String knowledgeId;
/**
* 获取知识库ID优先使用knowledgeBaseUuid
*/
public String getKnowledgeId() {
return knowledgeBaseUuid != null ? knowledgeBaseUuid : knowledgeId;
}
/**
* 检索的最大结果数
*/
@Min(1)
@Max(100)
@JsonProperty("top_k")
private Integer topK = 5;
/**
* 检索的最大结果数兼容字段前端使用top_n
*/
@JsonProperty("top_n")
private Integer topN;
/**
* 获取topK值优先使用topN
*/
public Integer getTopK() {
return topN != null ? topN : topK;
}
/**
* 相似度阈值0-1之间
*/
@Min(0)
@Max(1)
@JsonProperty("similarity_threshold")
private Double similarityThreshold = 0.7;
/**
* 相似度阈值兼容字段前端使用score
*/
@JsonProperty("score")
private Double score;
/**
* 获取相似度阈值优先使用score
*/
public Double getSimilarityThreshold() {
return score != null ? score : similarityThreshold;
}
/**
* 检索模式vector向量检索、graph图谱检索、hybrid混合检索
*/
@JsonProperty("retrieval_mode")
private String retrievalMode = "vector";
/**
* 模型分类用于LLM查询改写
*/
private String category;
/**
* LLM模型名称用于查询改写
*/
@JsonProperty("model_name")
private String modelName;
/**
* Embedding模型名称用于向量检索
*/
@JsonProperty("embedding_model")
private String embeddingModel;
/**
* 是否返回原文
*/
@JsonProperty("return_source")
private Boolean returnSource = true;
/**
* 自定义查询提示词(可选)
* 用于对查询进行预处理或改写
*/
private String prompt;
}

View File

@@ -234,16 +234,22 @@ public class SwitcherNode extends AbstractWfNode {
private String getValueFromInputs(String nodeUuid, String paramName, List<NodeIOData> inputs) {
log.debug("从节点UUID '{}' 搜索参数 '{}'", nodeUuid, paramName);
String result = null;
// 首先尝试从当前输入中查找
log.debug("检查当前输入 (数量: {})", inputs.size());
for (NodeIOData input : inputs) {
log.debug(" - 输入: 名称='{}', 值='{}'", input.getName(), input.valueToString());
if (paramName.equals(input.getName())) {
log.info("在当前输入中找到参数 '{}': '{}'", paramName, input.valueToString());
return input.valueToString();
result = input.valueToString();
}
}
if (result != null) {
log.info("在当前输入中找到参数 '{}': '{}'", paramName, result);
return result;
}
// 如果当前输入中没有,尝试从工作流状态中查找指定节点的输出
if (StringUtils.isNotBlank(nodeUuid)) {
List<NodeIOData> nodeOutputs = wfState.getIOByNodeUuid(nodeUuid);
@@ -251,14 +257,27 @@ public class SwitcherNode extends AbstractWfNode {
for (NodeIOData output : nodeOutputs) {
log.debug(" - 输出: 名称='{}', 值='{}'", output.getName(), output.valueToString());
if (paramName.equals(output.getName())) {
log.info("在节点 '{}' 的输出中找到参数 '{}': '{}'", nodeUuid, paramName, output.valueToString());
return output.valueToString();
result = output.valueToString();
}
}
if (result != null) {
log.info("在节点 '{}' 的输出中找到参数 '{}': '{}'", nodeUuid, paramName, result);
return result;
}
} else {
log.debug("节点UUID为空跳过工作流状态搜索");
}
// 特殊处理:如果找的是 'output' 但没找到,尝试找 'input'
if ("output".equals(paramName)) {
log.debug("未找到参数 'output',尝试查找 'input'");
String inputValue = getValueFromInputs(nodeUuid, "input", inputs);
if (inputValue != null) {
return inputValue;
}
}
log.warn("在输入或节点 '{}' 的输出中未找到参数 '{}'", nodeUuid, paramName);
return null;
}

View File

@@ -1,5 +1,8 @@
package org.ruoyi.chat.service.chat.impl;
import dev.langchain4j.model.chat.StreamingChatModel;
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
import dev.langchain4j.model.ollama.OllamaStreamingChatModel;
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.models.chat.OllamaChatMessage;
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
@@ -80,6 +83,30 @@ public class OllamaServiceImpl implements IChatService {
return emitter;
}
/**
* 工作流场景:支持 langchain4j handler
*/
@Override
public void chat(ChatRequest request, StreamingChatResponseHandler handler) {
log.info("workflow chat, model: {}", request.getModel());
ChatModelVo chatModelVo = chatModelService.selectModelByName(request.getModel());
StreamingChatModel model = OllamaStreamingChatModel.builder()
.baseUrl(chatModelVo.getApiHost() != null ? chatModelVo.getApiHost() : "http://localhost:11434")
.modelName(chatModelVo.getModelName())
.build();
try {
// 将 ruoyi-ai 的 ChatRequest 转换为 langchain4j 的格式
dev.langchain4j.model.chat.request.ChatRequest chatRequest = convertToLangchainRequest(request);
model.chat(chatRequest, handler);
} catch (Exception e) {
log.error("workflow ollama请求失败{}", e.getMessage(), e);
throw new RuntimeException("ollama workflow chat failed: " + e.getMessage(), e);
}
}
@Override
public String getCategory() {
return ChatModeType.OLLAMA.getCode();