From c22c5eac7f6a079c5747502021fd49e69efcc011 Mon Sep 17 00:00:00 2001 From: stageluo <979175267@qq.com> Date: Tue, 25 Nov 2025 09:26:39 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E7=9F=A5=E8=AF=86=E5=BA=93?= =?UTF-8?q?=E5=88=86=E6=94=AF=E5=B7=A5=E4=BD=9C=E6=B5=81=E8=8A=82=E7=82=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../impl/MilvusVectorStoreStrategy.java | 53 +++- .../KnowledgeRetrievalNode.java | 274 ++++++++++++++++++ .../KnowledgeRetrievalNodeConfig.java | 111 +++++++ 3 files changed, 423 insertions(+), 15 deletions(-) create mode 100644 ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/node/knowledgeRetrieval/KnowledgeRetrievalNode.java create mode 100644 ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/node/knowledgeRetrieval/KnowledgeRetrievalNodeConfig.java diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/MilvusVectorStoreStrategy.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/MilvusVectorStoreStrategy.java index 8d3d50fa..5240d0a9 100644 --- a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/MilvusVectorStoreStrategy.java +++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/MilvusVectorStoreStrategy.java @@ -31,9 +31,6 @@ import java.util.stream.IntStream; @Component public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { - - private final Integer DIMENSION = 2048; - public MilvusVectorStoreStrategy(VectorStoreProperties vectorStoreProperties, EmbeddingModelFactory embeddingModelFactory) { super(vectorStoreProperties, embeddingModelFactory); } @@ -41,13 +38,16 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { // 缓存不同集合与 autoFlush 配置的 Milvus 连接 private final Map> storeCache = new ConcurrentHashMap<>(); - private 
EmbeddingStore getMilvusStore(String collectionName, boolean autoFlushOnInsert) { - String key = collectionName + "|" + autoFlushOnInsert; + /** + * 获取 Milvus Store,支持动态维度 + */ + private EmbeddingStore getMilvusStore(String collectionName, int dimension, boolean autoFlushOnInsert) { + String key = collectionName + "|" + dimension + "|" + autoFlushOnInsert; return storeCache.computeIfAbsent(key, k -> MilvusEmbeddingStore.builder() .uri(vectorStoreProperties.getMilvus().getUrl()) .collectionName(collectionName) - .dimension(DIMENSION) + .dimension(dimension) .indexType(IndexType.IVF_FLAT) .metricType(MetricType.L2) .autoFlushOnInsert(autoFlushOnInsert) @@ -58,18 +58,37 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { .build() ); } + + /** + * 获取 embedding 模型的实际维度 + */ + private int getModelDimension(String modelName) { + try { + EmbeddingModel model = getEmbeddingModel(modelName, null); + // 使用一个测试文本获取向量维度 + Embedding testEmbedding = model.embed("test").content(); + int dimension = testEmbedding.dimension(); + log.info("Detected embedding model dimension: {} for model: {}", dimension, modelName); + return dimension; + } catch (Exception e) { + log.warn("Failed to detect model dimension for: {}, using default 1024", modelName, e); + return 1024; // 默认使用 1024 (bge-m3 的维度) + } + } @Override public void createSchema(String kid, String modelName) { String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid; + int dimension = getModelDimension(modelName); // 使用缓存获取连接以确保只初始化一次 - EmbeddingStore store = getMilvusStore(collectionName, true); - log.info("Milvus集合初始化完成: {}", collectionName); + EmbeddingStore store = getMilvusStore(collectionName, dimension, true); + log.info("Milvus集合初始化完成: {}, dimension: {}", collectionName, dimension); } @Override public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) { - EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(), DIMENSION); + int 
dimension = getModelDimension(storeEmbeddingBo.getEmbeddingModelName()); + EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(), dimension); List chunkList = storeEmbeddingBo.getChunkList(); List fidList = storeEmbeddingBo.getFids(); @@ -81,7 +100,7 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { long startTime = System.currentTimeMillis(); // 复用连接,写入场景使用 autoFlush=false 以提升批量插入性能 - EmbeddingStore embeddingStore = getMilvusStore(collectionName, false); + EmbeddingStore embeddingStore = getMilvusStore(collectionName, dimension, false); IntStream.range(0, chunkList.size()).forEach(i -> { String text = chunkList.get(i); @@ -101,13 +120,14 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { @Override public List getQueryVector(QueryVectorBo queryVectorBo) { - EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(), DIMENSION); + int dimension = getModelDimension(queryVectorBo.getEmbeddingModelName()); + EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(), dimension); Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content(); String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + queryVectorBo.getKid(); // 查询复用连接,autoFlush 对查询无影响,此处保持 true - EmbeddingStore embeddingStore = getMilvusStore(collectionName, true); + EmbeddingStore embeddingStore = getMilvusStore(collectionName, dimension, true); List resultList = new ArrayList<>(); EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() @@ -128,14 +148,16 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { @SneakyThrows public void removeById(String id, String modelName) { // 注意:此处原逻辑使用 collectionname + id,保持现状 - EmbeddingStore embeddingStore = getMilvusStore(vectorStoreProperties.getMilvus().getCollectionname() + id, false); + int dimension = getModelDimension(modelName); 
+ EmbeddingStore embeddingStore = getMilvusStore(vectorStoreProperties.getMilvus().getCollectionname() + id, dimension, false); embeddingStore.remove(id); } @Override public void removeByDocId(String docId, String kid) { String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid; - EmbeddingStore embeddingStore = getMilvusStore(collectionName, false); + // 使用默认维度,因为删除操作不需要精确的维度信息 + EmbeddingStore embeddingStore = getMilvusStore(collectionName, 1024, false); Filter filter = MetadataFilterBuilder.metadataKey("docId").isEqualTo(docId); embeddingStore.removeAll(filter); log.info("Milvus成功删除 docId={} 的所有向量数据", docId); @@ -144,7 +166,8 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { @Override public void removeByFid(String fid, String kid) { String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid; - EmbeddingStore embeddingStore = getMilvusStore(collectionName, false); + // 使用默认维度,因为删除操作不需要精确的维度信息 + EmbeddingStore embeddingStore = getMilvusStore(collectionName, 1024, false); Filter filter = MetadataFilterBuilder.metadataKey("fid").isEqualTo(fid); embeddingStore.removeAll(filter); log.info("Milvus成功删除 fid={} 的所有向量数据", fid); diff --git a/ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/node/knowledgeRetrieval/KnowledgeRetrievalNode.java b/ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/node/knowledgeRetrieval/KnowledgeRetrievalNode.java new file mode 100644 index 00000000..fcc0a5a3 --- /dev/null +++ b/ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/node/knowledgeRetrieval/KnowledgeRetrievalNode.java @@ -0,0 +1,274 @@ +package org.ruoyi.workflow.workflow.node.knowledgeRetrieval; + +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.ruoyi.workflow.entity.WorkflowComponent; +import org.ruoyi.workflow.entity.WorkflowNode; +import 
org.ruoyi.workflow.util.SpringUtil; +import org.ruoyi.workflow.workflow.NodeProcessResult; +import org.ruoyi.workflow.workflow.WfNodeState; +import org.ruoyi.workflow.workflow.WfState; +import org.ruoyi.workflow.workflow.WorkflowUtil; +import org.ruoyi.workflow.workflow.data.NodeIOData; +import org.ruoyi.workflow.workflow.node.AbstractWfNode; +import org.ruoyi.service.VectorStoreService; +import org.ruoyi.service.IKnowledgeInfoService; +import org.ruoyi.domain.bo.QueryVectorBo; +import org.ruoyi.domain.vo.KnowledgeInfoVo; + +import java.util.ArrayList; +import java.util.List; + +import static org.ruoyi.workflow.cosntant.AdiConstant.WorkflowConstant.DEFAULT_OUTPUT_PARAM_NAME; + +/** + * 【节点】知识库检索节点 + * 从知识库中检索相关内容 + */ +@Slf4j +public class KnowledgeRetrievalNode extends AbstractWfNode { + + public KnowledgeRetrievalNode(WorkflowComponent wfComponent, WorkflowNode nodeDef, WfState wfState, WfNodeState nodeState) { + super(wfComponent, nodeDef, wfState, nodeState); + } + + /** + * 处理知识库检索 + * nodeConfig 格式: + * { + * "knowledge_id": "kb_123", + * "top_k": 5, + * "similarity_threshold": 0.7, + * "retrieval_mode": "vector", + * "embedding_model": "text-embedding-3-small", + * "return_source": true, + * "prompt": "额外的查询改写提示词" + * } + * + * @return 检索结果 + */ + @Override + public NodeProcessResult onProcess() { + KnowledgeRetrievalNodeConfig config = checkAndGetConfig(KnowledgeRetrievalNodeConfig.class); + + // 验证知识库ID + if (StringUtils.isBlank(config.getKnowledgeId())) { + log.error("Knowledge base ID is required but not provided"); + List outputs = new ArrayList<>(); + outputs.add(NodeIOData.createByText(DEFAULT_OUTPUT_PARAM_NAME, "", "错误:未配置知识库ID")); + return NodeProcessResult.builder().content(outputs).build(); + } + + // 获取查询文本 + String queryText = getFirstInputText(); + if (StringUtils.isBlank(queryText)) { + log.warn("Knowledge retrieval node has no input query, node: {}", state.getUuid()); + // 返回空结果 + List outputs = new ArrayList<>(); + 
outputs.add(NodeIOData.createByText(DEFAULT_OUTPUT_PARAM_NAME, "", "")); + return NodeProcessResult.builder().content(outputs).build(); + } + + log.info("Knowledge retrieval node config: {}", config); + log.info("Query text: {}", queryText); + + // 如果有自定义提示词,对查询进行改写 + String finalQuery = queryText; + if (StringUtils.isNotBlank(config.getPrompt())) { + finalQuery = rewriteQuery(config, queryText); + log.info("Rewritten query: {}", finalQuery); + } + + // 根据检索模式执行不同的检索策略 + String retrievalResult; + String mode = config.getRetrievalMode() != null ? config.getRetrievalMode().toLowerCase() : "vector"; + + // 目前只支持向量检索,图谱检索需要依赖graph模块 + if ("graph".equals(mode) || "hybrid".equals(mode)) { + log.warn("Graph retrieval mode is not supported in workflow-api module, falling back to vector retrieval"); + } + + retrievalResult = retrieveFromVector(config, finalQuery); + + log.info("Retrieval result length: {}", retrievalResult.length()); + + // 构建输出 + List outputs = new ArrayList<>(); + outputs.add(NodeIOData.createByText(DEFAULT_OUTPUT_PARAM_NAME, "", retrievalResult)); + + // 如果需要返回原始查询 + outputs.add(NodeIOData.createByText("query", "", finalQuery)); + + return NodeProcessResult.builder().content(outputs).build(); + } + + /** + * 使用LLM改写查询 + */ + private String rewriteQuery(KnowledgeRetrievalNodeConfig config, String originalQuery) { + try { + // 构建改写提示词 + String prompt = WorkflowUtil.renderTemplate(config.getPrompt(), state.getInputs()); + prompt = prompt.replace("{query}", originalQuery); + + log.info("Query rewrite prompt: {}", prompt); + + // 调用LLM进行查询改写 + String rewrittenQuery = invokeLLMSync(config, prompt); + + if (StringUtils.isNotBlank(rewrittenQuery)) { + log.info("Query rewritten from '{}' to '{}'", originalQuery, rewrittenQuery); + return rewrittenQuery.trim(); + } + + // 如果改写失败,返回原查询 + return originalQuery; + } catch (Exception e) { + log.error("Failed to rewrite query, using original query", e); + return originalQuery; + } + } + + /** + * 同步调用LLM + * 
使用一个临时的流式处理器来收集完整响应 + */ + private String invokeLLMSync(KnowledgeRetrievalNodeConfig config, String prompt) { + try { + // 创建一个StringBuilder来收集LLM响应 + StringBuilder responseBuilder = new StringBuilder(); + Object lock = new Object(); + boolean[] completed = {false}; + + // 创建临时节点状态用于LLM调用 + WfNodeState tempState = new WfNodeState(); + tempState.setUuid(state.getUuid() + "_rewrite"); + List tempInputs = new ArrayList<>(); + tempInputs.add(NodeIOData.createByText("input", "", prompt)); + tempState.setInputs(tempInputs); + + // 创建临时工作流节点定义 + WorkflowNode tempNode = new WorkflowNode(); + tempNode.setUuid(tempState.getUuid()); + tempNode.setInputConfig(node.getInputConfig()); + + // 使用WorkflowUtil调用LLM(流式) + WorkflowUtil workflowUtil = SpringUtil.getBean(WorkflowUtil.class); + List systemMessage = + List.of(dev.langchain4j.data.message.UserMessage.from(prompt)); + + // 调用流式LLM + String category = StringUtils.isNotBlank(config.getCategory()) ? config.getCategory() : "llm"; + String modelName = StringUtils.isNotBlank(config.getModelName()) ? 
config.getModelName() : "deepseek-chat"; + + workflowUtil.streamingInvokeLLM( + wfState, + tempState, + tempNode, + category, + modelName, + systemMessage + ); + + // 等待LLM响应完成(最多等待30秒) + long startTime = System.currentTimeMillis(); + long timeout = 30000; // 30秒超时 + + while (!completed[0] && (System.currentTimeMillis() - startTime) < timeout) { + synchronized (lock) { + // 检查是否有输出 + if (!tempState.getOutputs().isEmpty()) { + for (NodeIOData output : tempState.getOutputs()) { + if ("output".equals(output.getName())) { + String text = output.valueToString(); + if (StringUtils.isNotBlank(text)) { + responseBuilder.append(text); + completed[0] = true; + break; + } + } + } + } + } + + if (!completed[0]) { + Thread.sleep(100); // 等待100ms后重试 + } + } + + String result = responseBuilder.toString().trim(); + if (StringUtils.isBlank(result)) { + log.warn("LLM sync call returned empty response"); + } + + return result; + } catch (Exception e) { + log.error("Failed to invoke LLM synchronously", e); + return ""; + } + } + + /** + * 从向量库检索 + */ + private String retrieveFromVector(KnowledgeRetrievalNodeConfig config, String query) { + try { + VectorStoreService vectorStoreService = SpringUtil.getBean(VectorStoreService.class); + IKnowledgeInfoService knowledgeInfoService = SpringUtil.getBean(IKnowledgeInfoService.class); + + // 获取知识库信息以获取embedding模型配置 + Long knowledgeId = Long.parseLong(config.getKnowledgeId()); + KnowledgeInfoVo knowledgeInfo = knowledgeInfoService.queryById(knowledgeId); + + if (knowledgeInfo == null) { + log.error("Knowledge base not found: {}", config.getKnowledgeId()); + return "错误:知识库不存在"; + } + + // 构建查询参数 + QueryVectorBo queryBo = new QueryVectorBo(); + queryBo.setKid(config.getKnowledgeId()); + queryBo.setQuery(query); + queryBo.setMaxResults(config.getTopK()); + + // 优先使用配置中的embedding模型,否则使用知识库的默认模型 + String embeddingModel = StringUtils.isNotBlank(config.getEmbeddingModel()) + ? 
config.getEmbeddingModel() + : knowledgeInfo.getEmbeddingModelName(); + + // 验证embedding模型配置 + if (StringUtils.isBlank(embeddingModel)) { + log.error("Embedding model not configured for knowledge base: {}", config.getKnowledgeId()); + return "错误:知识库未配置向量化模型"; + } + + queryBo.setEmbeddingModelName(embeddingModel); + + log.info("Querying knowledge base: kid={}, query='{}', embedding model: {}, topK: {}, threshold: {}", + config.getKnowledgeId(), query, embeddingModel, config.getTopK(), config.getSimilarityThreshold()); + + // 执行检索 + List results = vectorStoreService.getQueryVector(queryBo); + + log.info("Vector store query completed, results count: {}", results != null ? results.size() : 0); + + if (results == null || results.isEmpty()) { + log.warn("No results found from vector store for knowledge: {}, query: '{}'", config.getKnowledgeId(), query); + return ""; + } + + // 合并结果 + String mergedResult = String.join("\n\n---\n\n", results); + log.info("Retrieved {} documents from vector store", results.size()); + + return mergedResult; + } catch (NumberFormatException e) { + log.error("Invalid knowledge base ID format: {}", config.getKnowledgeId(), e); + return "错误:知识库ID格式无效"; + } catch (Exception e) { + log.error("Failed to retrieve from vector store", e); + return ""; + } + } + +} diff --git a/ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/node/knowledgeRetrieval/KnowledgeRetrievalNodeConfig.java b/ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/node/knowledgeRetrieval/KnowledgeRetrievalNodeConfig.java new file mode 100644 index 00000000..e8697f7d --- /dev/null +++ b/ruoyi-modules-api/ruoyi-workflow-api/src/main/java/org/ruoyi/workflow/workflow/node/knowledgeRetrieval/KnowledgeRetrievalNodeConfig.java @@ -0,0 +1,111 @@ +package org.ruoyi.workflow.workflow.node.knowledgeRetrieval; + +import com.fasterxml.jackson.annotation.JsonProperty; +import jakarta.validation.constraints.Max; +import 
jakarta.validation.constraints.Min;
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+
+/**
+ * Configuration of the knowledge-base retrieval workflow node.
+ * Some settings also have a front-end alias; the explicit getters pick alias, primary value, then default.
+ */
+@EqualsAndHashCode
+@Data
+public class KnowledgeRetrievalNodeConfig {
+
+    /** Fallback for getTopK() when neither top_n nor top_k carries a value. */
+    private static final int DEFAULT_TOP_K = 5;
+
+    /** Fallback for getSimilarityThreshold() when neither score nor similarity_threshold carries a value. */
+    private static final double DEFAULT_SIMILARITY_THRESHOLD = 0.7;
+
+    /** Knowledge base UUID (primary field). */
+    @JsonProperty("knowledge_base_uuid")
+    private String knowledgeBaseUuid;
+
+    /** Knowledge base ID (compatibility field). */
+    @JsonProperty("knowledge_id")
+    private String knowledgeId;
+
+    /**
+     * Returns the knowledge base identifier: knowledgeBaseUuid when it has text, otherwise knowledgeId.
+     * A blank knowledgeBaseUuid is skipped so that it cannot mask a populated knowledgeId.
+     */
+    public String getKnowledgeId() {
+        boolean hasUuid = knowledgeBaseUuid != null && !knowledgeBaseUuid.trim().isEmpty();
+        return hasUuid ? knowledgeBaseUuid : knowledgeId;
+    }
+
+    /** Maximum number of results to retrieve (1-100). */
+    @Min(1)
+    @Max(100)
+    @JsonProperty("top_k")
+    private Integer topK = DEFAULT_TOP_K;
+
+    /** Alias of topK (the front end sends top_n); bounded like topK because getTopK() prefers it. */
+    @Min(1)
+    @Max(100)
+    @JsonProperty("top_n")
+    private Integer topN;
+
+    /** Returns the effective result limit: topN, else topK, else DEFAULT_TOP_K (never null). */
+    public Integer getTopK() {
+        Integer limit = topN != null ? topN : topK;
+        return limit != null ? limit : DEFAULT_TOP_K;
+    }
+
+    /** Similarity threshold, expected in the range 0-1. */
+    @Min(0)
+    @Max(1)
+    @JsonProperty("similarity_threshold")
+    private Double similarityThreshold = DEFAULT_SIMILARITY_THRESHOLD;
+
+    /** Alias of similarityThreshold (the front end sends score); bounded alike because the getter prefers it. */
+    @Min(0)
+    @Max(1)
+    @JsonProperty("score")
+    private Double score;
+
+    /** Returns the effective threshold: score, else similarityThreshold, else the default (never null). */
+    public Double getSimilarityThreshold() {
+        Double threshold = score != null ? score : similarityThreshold;
+        return threshold != null ? threshold : DEFAULT_SIMILARITY_THRESHOLD;
+    }
+
+    /**
+     * Retrieval mode: vector (vector search), graph (graph search) or hybrid (mixed search).
+     * Defaults to vector.
+     */
+    @JsonProperty("retrieval_mode")
+    private String retrievalMode = "vector";
+
+    /**
+     * Model category (used for LLM query rewriting).
+     */
+    private String category;
+
+    /**
+     * LLM model name (used for query rewriting).
+     */
+    @JsonProperty("model_name")
+    private String modelName;
+
+    /**
+     * Embedding model name (used for vector retrieval).
+     */
+    @JsonProperty("embedding_model")
+    private String embeddingModel;
+
+    /**
+     * Whether to return the source text; defaults to true.
+     */
+    @JsonProperty("return_source")
+    private Boolean returnSource = true;
+
+    /**
+     * Custom query prompt (optional).
+     * Used to pre-process or rewrite the query.
+     */
+    private String prompt;
+}