From b8d16b7669b3c575e651ff4d4ba6dc4e79802ca0 Mon Sep 17 00:00:00 2001 From: RobustH <1511209518@qq.com> Date: Thu, 23 Apr 2026 00:52:53 +0800 Subject: [PATCH] =?UTF-8?q?feat(rag):=20=E5=AF=B9=E6=8E=A5=E7=94=A8?= =?UTF-8?q?=E6=88=B7=E7=AB=AF=E7=94=A8=E6=88=B7=E7=9F=A5=E8=AF=86=E5=BA=93?= =?UTF-8?q?=E5=AF=B9=E8=AF=9D=EF=BC=8C=E9=9B=86=E6=88=90=E7=9F=A5=E8=AF=86?= =?UTF-8?q?=E5=BA=93=E9=85=8D=E7=BD=AE=E5=BA=94=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/script/sql/update/updat-0423.sql | 14 ++ .../ruoyi/domain/bo/vector/QueryVectorBo.java | 18 ++ .../service/chat/impl/ChatServiceFacade.java | 78 +++--- .../impl/KnowledgeFragmentServiceImpl.java | 168 ++----------- .../retriever/CustomVectorRetriever.java | 29 ++- .../retrieval/KnowledgeRetrievalService.java | 12 +- .../impl/KnowledgeRetrievalServiceImpl.java | 233 +++++++++++++----- 7 files changed, 301 insertions(+), 251 deletions(-) create mode 100644 docs/script/sql/update/updat-0423.sql diff --git a/docs/script/sql/update/updat-0423.sql b/docs/script/sql/update/updat-0423.sql new file mode 100644 index 00000000..4ed14430 --- /dev/null +++ b/docs/script/sql/update/updat-0423.sql @@ -0,0 +1,14 @@ +-- 为知识库信息表新增检索配置字段 (剔除了已存在的重排字段) +ALTER TABLE knowledge_info + ADD COLUMN similarity_threshold DOUBLE DEFAULT 0.5 COMMENT '相似度阈值' +AFTER retrieve_limit; + +ALTER TABLE knowledge_info ADD COLUMN enable_hybrid tinyint(1) DEFAULT 0 COMMENT '是否启用混合检索'; +ALTER TABLE knowledge_info ADD COLUMN hybrid_alpha double DEFAULT 0.5 COMMENT '混合检索权重比例 (0.0=纯向量, 1.0=纯关键词)'; + +-- 为知识片段表增加全文索引及关联ID +ALTER TABLE knowledge_fragment ADD COLUMN knowledge_id bigint COMMENT '知识库ID'; +ALTER TABLE knowledge_fragment ADD FULLTEXT INDEX ft_content (content) WITH PARSER ngram; + +-- 为知识库附件表增加解析状态字段 +ALTER TABLE `knowledge_attach` ADD COLUMN `status` TINYINT DEFAULT 0 COMMENT '解析状态: 0待解析, 1解析中, 2已解析, 3解析失败'; diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/vector/QueryVectorBo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/vector/QueryVectorBo.java index bb5634a3..6f0b9352 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/vector/QueryVectorBo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/vector/QueryVectorBo.java @@ -77,4 +77,22 @@ public class QueryVectorBo { */ private Double rerankScoreThreshold; + // ========== 混合检索与阈值相关参数 ========== + + /** + * 相似度阈值 (0.0-1.0) + * 应用于向量搜索阶段 + */ + private Double similarityThreshold; + + /** + * 是否启用混合检索 + */ + private Boolean enableHybrid = false; + + /** + * 混合检索权重 (0.0-1.0) + */ + private Double hybridAlpha; + } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java index 3bd0876e..16e0750d 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java @@ -20,6 +20,11 @@ import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; import dev.langchain4j.model.openai.OpenAiChatModel; import dev.langchain4j.service.tool.ToolProvider; import dev.langchain4j.skills.shell.ShellSkills; +import dev.langchain4j.rag.AugmentationRequest; +import dev.langchain4j.rag.AugmentationResult; +import dev.langchain4j.rag.DefaultRetrievalAugmentor; +import dev.langchain4j.rag.RetrievalAugmentor; +import dev.langchain4j.rag.query.Metadata; import lombok.RequiredArgsConstructor; import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; @@ -55,6 +60,7 @@ import org.ruoyi.service.chat.IChatMessageService; import org.ruoyi.service.chat.impl.memory.PersistentChatMemoryStore; import org.ruoyi.service.knowledge.IKnowledgeInfoService; import org.ruoyi.service.retrieval.KnowledgeRetrievalService; +import org.ruoyi.service.knowledge.retriever.CustomVectorRetriever; import org.ruoyi.service.vector.VectorStoreService; import org.springframework.stereotype.Service; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; @@ -412,16 +418,49 @@ public class ChatServiceFacade implements IChatService { /** * 构建上下文消息列表 - * 消息顺序:历史消息 → 当前用户消息(确保 AI 正确理解对话上下文) * * @param chatRequest 聊天请求 * @return 上下文消息列表 */ private List buildContextMessages(ChatRequest chatRequest) { - List messages = new ArrayList<>(); + List messages = new ArrayList<>(); - // 从数据库查询历史对话消息(放在前面) + // 1. 初始化当前用户消息 + UserMessage userMessage = UserMessage.userMessage(chatRequest.getContent()); + + // 2. 知识库检索增强 (RAG) + if (chatRequest.getKnowledgeId() != null) { + KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(chatRequest.getKnowledgeId())); + if (knowledgeInfoVo != null) { + ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel()); + if (chatModel != null) { + log.info("执行高级 RAG 流程: kid={}", chatRequest.getKnowledgeId()); + + // 构建自定义检索器 + CustomVectorRetriever retriever = new CustomVectorRetriever( + knowledgeRetrievalService, knowledgeInfoVo, chatModel); + + // 构建增强流水线 + RetrievalAugmentor augmentor = DefaultRetrievalAugmentor.builder() + .contentRetriever(retriever) + .build(); + + // 执行增强:编织上下文到 UserMessage + Metadata metadata = Metadata.from(userMessage, chatRequest.getSessionId(), new ArrayList<>()); + AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata); + AugmentationResult result = augmentor.augment(augmentationRequest); + + ChatMessage augmented = result.chatMessage(); + if (augmented instanceof UserMessage) { + userMessage = (UserMessage) augmented; + log.debug("RAG 增强完成,UserMessage 已注入背景知识"); + } + } + } + } + + // 3. 从数据库查询历史对话消息(放在前面) if (chatRequest.getSessionId() != null) { MessageWindowChatMemory memory = createChatMemory(chatRequest.getSessionId()); if (memory != null) { @@ -433,38 +472,7 @@ public class ChatServiceFacade implements IChatService { } } - // 从向量库查询相关历史消息(知识库内容作为上下文) - if (chatRequest.getKnowledgeId() != null) { - // 查询知识库信息 - KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(chatRequest.getKnowledgeId())); - if (knowledgeInfoVo == null) { - log.warn("知识库信息不存在,kid: {}", chatRequest.getKnowledgeId()); - // 继续添加当前用户消息 - messages.add(UserMessage.userMessage(chatRequest.getContent())); - return messages; - } - - // 查询向量模型配置信息 - ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel()); - if (chatModel == null) { - log.warn("向量模型配置不存在,模型名称: {}", knowledgeInfoVo.getEmbeddingModel()); - messages.add(UserMessage.userMessage(chatRequest.getContent())); - return messages; - } - - // 构建向量查询参数 - QueryVectorBo queryVectorBo = buildQueryVectorBo(chatRequest, knowledgeInfoVo, chatModel); - - // 使用知识库检索服务(支持重排序) - List nearestList = knowledgeRetrievalService.retrieveTexts(queryVectorBo); - for (String prompt : nearestList) { - // 知识库内容作为系统上下文添加 - messages.add(new AiMessage(prompt)); - } - } - - // 构建当前用户消息(放在最后) - UserMessage userMessage = UserMessage.userMessage(chatRequest.getContent()); + // 4. 添加经过增强的用户消息(放在最后) messages.add(userMessage); return messages; diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java index 8ca9e647..da17f9b9 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java @@ -12,25 +12,18 @@ import org.ruoyi.common.core.utils.StringUtils; import org.ruoyi.common.mybatis.core.page.PageQuery; import org.ruoyi.common.mybatis.core.page.TableDataInfo; import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo; -import org.ruoyi.domain.bo.rerank.RerankRequest; -import org.ruoyi.domain.bo.rerank.RerankResult; import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.domain.entity.knowledge.KnowledgeFragment; import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo; import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo; import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; -import org.ruoyi.factory.RerankModelFactory; import org.ruoyi.mapper.knowledge.KnowledgeFragmentMapper; import org.ruoyi.service.knowledge.IKnowledgeFragmentService; import org.ruoyi.service.knowledge.IKnowledgeInfoService; -import org.ruoyi.service.rerank.RerankModelService; -import org.ruoyi.service.vector.VectorStoreService; +import org.ruoyi.service.retrieval.KnowledgeRetrievalService; import org.springframework.stereotype.Service; import java.util.*; -import java.util.concurrent.CompletableFuture; -import java.util.stream.Collectors; - /** * 知识片段Service业务层处理 @@ -46,8 +39,7 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService { private final KnowledgeFragmentMapper baseMapper; private final IKnowledgeInfoService knowledgeInfoService; private final IChatModelService chatModelService; - private final VectorStoreService vectorStoreService; - private final RerankModelFactory rerankModelFactory; + private final KnowledgeRetrievalService knowledgeRetrievalService; /** * 查询知识片段 @@ -87,7 +79,6 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService { } private LambdaQueryWrapper buildQueryWrapper(KnowledgeFragmentBo bo) { - Map params = bo.getParams(); LambdaQueryWrapper lqw = Wrappers.lambdaQuery(); lqw.orderByAsc(KnowledgeFragment::getId); lqw.eq(bo.getDocId() != null, KnowledgeFragment::getDocId, bo.getDocId()); @@ -149,7 +140,7 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService { } /** - * 检索测试核心实现 + * 检索测试核心实现 - 委托给统一的 KnowledgeRetrievalService */ @Override public List retrieval(KnowledgeFragmentBo bo) { @@ -157,7 +148,7 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService { return new ArrayList<>(); } - // 1. 获取知识库及模型配置 + // 1. 获取知识库及模型配置(为了获取 API Key/Host 等模型参数) KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(bo.getKnowledgeId()); if (knowledgeInfoVo == null) { return new ArrayList<>(); @@ -169,151 +160,28 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService { return new ArrayList<>(); } - // 2. 构造向量检索参数 + // 2. 构造通用的参数对象 QueryVectorBo queryVectorBo = new QueryVectorBo(); queryVectorBo.setQuery(bo.getQuery()); queryVectorBo.setKid(String.valueOf(bo.getKnowledgeId())); - queryVectorBo.setMaxResults(bo.getTopK() != null ? bo.getTopK() : knowledgeInfoVo.getRetrieveLimit()); - queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel()); - queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModel()); queryVectorBo.setApiKey(chatModel.getApiKey()); queryVectorBo.setBaseUrl(chatModel.getApiHost()); + queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel()); + queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModel()); - // 3. 执行搜索 (向量搜索 + 关键词搜索) - List allResults; + // 使用前端传入的实时测试参数,若无则使用知识库默认参数 + queryVectorBo.setMaxResults(bo.getTopK() != null ? bo.getTopK() : knowledgeInfoVo.getRetrieveLimit()); + queryVectorBo.setSimilarityThreshold(bo.getThreshold() != null ? bo.getThreshold() : knowledgeInfoVo.getSimilarityThreshold()); - boolean hybridEnabled = Boolean.TRUE.equals(bo.getEnableHybrid()) || - Integer.valueOf(1).equals(knowledgeInfoVo.getEnableHybrid()); - - if (hybridEnabled) { - log.info("执行混合检索: kid={}, query={}", bo.getKnowledgeId(), bo.getQuery()); - try { - // 并行执行向量搜索 - CompletableFuture> vectorFuture = CompletableFuture.supplyAsync(() -> - vectorStoreService.search(queryVectorBo)); - - // 执行关键词搜索 (MySQL) - int limit = bo.getTopK() != null ? bo.getTopK() : 50; - List keywordFragments = baseMapper.searchByKeyword(bo.getKnowledgeId(), bo.getQuery(), limit); - List keywordResults = keywordFragments.stream().map(f -> { - KnowledgeRetrievalVo vo = new KnowledgeRetrievalVo(); - vo.setId(f.getId().toString()); - vo.setContent(f.getContent()); - vo.setDocId(f.getDocId()); - vo.setIdx(f.getIdx()); - vo.setKnowledgeId(f.getKnowledgeId()); - vo.setScore(10.0); // 初始分,后续由 RRF 重新打分 - return vo; - }).collect(Collectors.toList()); - - List vectorResults = vectorFuture.get(); - log.info("抽取混合结果成功: Vector命中={}条, Keyword命中={}条", vectorResults.size(), keywordResults.size()); + queryVectorBo.setEnableHybrid(bo.getEnableHybrid() != null ? bo.getEnableHybrid() : Objects.equals(knowledgeInfoVo.getEnableHybrid(), 1)); + queryVectorBo.setHybridAlpha(bo.getHybridAlpha() != null ? bo.getHybridAlpha() : knowledgeInfoVo.getHybridAlpha()); - double alpha = bo.getHybridAlpha() != null ? bo.getHybridAlpha() : - (knowledgeInfoVo.getHybridAlpha() != null ? knowledgeInfoVo.getHybridAlpha() : 0.5); - - allResults = calculateRRF(vectorResults, keywordResults, alpha); - } catch (Exception e) { - log.error("混合检索执行或合并失败,已自动降级回退到纯向量检索", e); - allResults = vectorStoreService.search(queryVectorBo); - } - } else { - allResults = vectorStoreService.search(queryVectorBo); - } + queryVectorBo.setEnableRerank(bo.getEnableRerank() != null ? bo.getEnableRerank() : Objects.equals(knowledgeInfoVo.getEnableRerank(), 1)); + queryVectorBo.setRerankModelName(StringUtils.isNotBlank(bo.getRerankModel()) ? bo.getRerankModel() : knowledgeInfoVo.getRerankModel()); + queryVectorBo.setRerankTopN(bo.getTopK() != null ? bo.getTopK() : knowledgeInfoVo.getRerankTopN()); + queryVectorBo.setRerankScoreThreshold(bo.getThreshold() != null ? bo.getThreshold() : knowledgeInfoVo.getRerankScoreThreshold()); - // 初始化原始排名 - for (int i = 0; i < allResults.size(); i++) { - allResults.get(i).setOriginalIndex(i); - } - - // 4. 执行重排逻辑 (如果请求启用重排且配置了重排模型) - if (Boolean.TRUE.equals(bo.getEnableRerank()) && StringUtils.isNotBlank(bo.getRerankModel())) { - log.info("开始重排精排,模型: [{}]", bo.getRerankModel()); - try { - RerankModelService rerankModel = rerankModelFactory.createModel(bo.getRerankModel()); - - List contents = allResults.stream() - .map(KnowledgeRetrievalVo::getContent) - .collect(Collectors.toList()); - - RerankRequest rerankRequest = RerankRequest.builder() - .query(bo.getQuery()) - .documents(contents) - .topN(contents.size()) - .returnDocuments(false) - .build(); - - RerankResult rerankResult = rerankModel.rerank(rerankRequest); - - // 将重排分数写回,并记录原始分数供前端对比 - for (RerankResult.RerankDocument doc : rerankResult.getDocuments()) { - if (doc.getIndex() != null && doc.getIndex() < allResults.size()) { - KnowledgeRetrievalVo resultVo = allResults.get(doc.getIndex()); - resultVo.setRawScore(resultVo.getScore()); - resultVo.setScore(doc.getRelevanceScore()); - } - } - - // 按重排后的分数从高到低排序 - allResults.sort((a, b) -> b.getScore().compareTo(a.getScore())); - log.info("重排精排完成,结果数: {}", allResults.size()); - - } catch (Exception e) { - log.error("重排精排执行失败,已跳过重排步骤: {}", e.getMessage(), e); - } - } - - // 5. 根据阈值过滤 - double threshold = bo.getThreshold() != null ? bo.getThreshold() : 0.0; - return allResults.stream() - .filter(res -> res.getScore() >= threshold) - .collect(Collectors.toList()); - } - - /** - * RRF (Reciprocal Rank Fusion) 融合算法 - * 公式: Score = (1-alpha) * (1 / (k + rank_vector)) + alpha * (1 / (k + rank_keyword)) - */ - private List calculateRRF(List vectorList, List keywordList, double alpha) { - Map allMap = new HashMap<>(); - Map vectorScores = new HashMap<>(); - Map keywordScores = new HashMap<>(); - - int k = 60; // 常用 RRF 常数 - - for (int i = 0; i < vectorList.size(); i++) { - KnowledgeRetrievalVo vo = vectorList.get(i); - allMap.put(vo.getId(), vo); - vectorScores.put(vo.getId(), 1.0 / (k + i + 1)); - } - - for (int i = 0; i < keywordList.size(); i++) { - KnowledgeRetrievalVo vo = keywordList.get(i); - if (!allMap.containsKey(vo.getId())) { - allMap.put(vo.getId(), vo); - } - keywordScores.put(vo.getId(), 1.0 / (k + i + 1)); - } - - // 重新计算得分 - List fusedResults = new ArrayList<>(); - for (Map.Entry entry : allMap.entrySet()) { - String id = entry.getKey(); - double vScore = vectorScores.getOrDefault(id, 0.0); - double kScore = keywordScores.getOrDefault(id, 0.0); - - // 混合分值 - double finalScore = (1 - alpha) * vScore + alpha * kScore; - - // 分值归一化/缩放:将 RRF 分值放大到 0-1 范围 - // 理论单路最大得分为 1/61 ≈ 0.016,乘以 60 使其处于相似度常用区间 - KnowledgeRetrievalVo vo = entry.getValue(); - vo.setScore(finalScore * 60.0); - fusedResults.add(vo); - } - - // 按融合分数从高到低排序 - fusedResults.sort((a, b) -> b.getScore().compareTo(a.getScore())); - return fusedResults; + // 3. 执行统一检索 + return knowledgeRetrievalService.retrieve(queryVectorBo); } } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/retriever/CustomVectorRetriever.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/retriever/CustomVectorRetriever.java index 6d876710..f79206bc 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/retriever/CustomVectorRetriever.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/retriever/CustomVectorRetriever.java @@ -9,14 +9,15 @@ import lombok.extern.slf4j.Slf4j; import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo; import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo; -import org.ruoyi.service.vector.VectorStoreService; +import org.ruoyi.service.retrieval.KnowledgeRetrievalService; import java.util.List; +import java.util.Objects; import java.util.stream.Collectors; /** - * 自定义向量检索器:适配 LangChain4j ContentRetriever 接口 - * 桥接现有的 VectorStoreService 获取检索结果 + * 自定义检索器:适配 LangChain4j ContentRetriever 接口 + * 桥接统一的 KnowledgeRetrievalService,支持配置化的混合检索、阈值过滤等功能 * * @author RobustH */ @@ -24,15 +25,15 @@ import java.util.stream.Collectors; @RequiredArgsConstructor public class CustomVectorRetriever implements ContentRetriever { - private final VectorStoreService vectorStoreService; + private final KnowledgeRetrievalService knowledgeRetrievalService; private final KnowledgeInfoVo knowledgeInfoVo; private final ChatModelVo chatModelVo; @Override public List retrieve(Query query) { - log.info("执行自定义向量检索,关键字: {}", query.text()); + log.info("执行自定义检索,关键字: {}", query.text()); - // 构建内部查询参数 + // 构建增强后的查询参数 QueryVectorBo queryVectorBo = new QueryVectorBo(); queryVectorBo.setQuery(query.text()); queryVectorBo.setKid(String.valueOf(knowledgeInfoVo.getId())); @@ -40,11 +41,21 @@ public class CustomVectorRetriever implements ContentRetriever { queryVectorBo.setBaseUrl(chatModelVo.getApiHost()); queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModel()); queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel()); - // 如果接入了重排,这里的 retrieveLimit 也就是 MaxResults 应当被放大,后续留给 Aggregator 截断 + + // 应用知识库配置参数 queryVectorBo.setMaxResults(knowledgeInfoVo.getRetrieveLimit()); + queryVectorBo.setSimilarityThreshold(knowledgeInfoVo.getSimilarityThreshold()); + queryVectorBo.setEnableHybrid(Objects.equals(knowledgeInfoVo.getEnableHybrid(), 1)); + queryVectorBo.setHybridAlpha(knowledgeInfoVo.getHybridAlpha()); - // 执行底层的多种向量库策略检索 - List nearestList = vectorStoreService.getQueryVector(queryVectorBo); + // 设置重排序参数 (如果 retriever 阶段也想做初步重排,可以在此设置) + queryVectorBo.setEnableRerank(Objects.equals(knowledgeInfoVo.getEnableRerank(), 1)); + queryVectorBo.setRerankModelName(knowledgeInfoVo.getRerankModel()); + queryVectorBo.setRerankTopN(knowledgeInfoVo.getRerankTopN()); + queryVectorBo.setRerankScoreThreshold(knowledgeInfoVo.getRerankScoreThreshold()); + + // 通过统一服务执行检索 + List nearestList = knowledgeRetrievalService.retrieveTexts(queryVectorBo); // 将结果包装为标准的 Content 返回 return nearestList.stream() diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/retrieval/KnowledgeRetrievalService.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/retrieval/KnowledgeRetrievalService.java index 9c42dd1a..3e0a6cab 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/retrieval/KnowledgeRetrievalService.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/retrieval/KnowledgeRetrievalService.java @@ -1,12 +1,13 @@ package org.ruoyi.service.retrieval; import org.ruoyi.domain.bo.vector.QueryVectorBo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import java.util.List; /** * 知识库检索服务接口 - * 整合粗召回(向量检索)和重排序流程 + * 整合粗召回(向量检索/关键词检索)和重排序流程 * * @author yang * @date 2026-04-19 @@ -21,4 +22,13 @@ public interface KnowledgeRetrievalService { * @return 文本内容列表 */ List retrieveTexts(QueryVectorBo queryVectorBo); + + /** + * 执行知识库检索,返回详细结果对象(包含分数、文档ID等) + * 支持混合检索和重排序 + * + * @param queryVectorBo 查询参数 + * @return 检索结果列表 + */ + List retrieve(QueryVectorBo queryVectorBo); } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/retrieval/impl/KnowledgeRetrievalServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/retrieval/impl/KnowledgeRetrievalServiceImpl.java index 42f6cf68..b6841ef1 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/retrieval/impl/KnowledgeRetrievalServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/retrieval/impl/KnowledgeRetrievalServiceImpl.java @@ -2,21 +2,26 @@ package org.ruoyi.service.retrieval.impl; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.ruoyi.common.core.utils.StringUtils; import org.ruoyi.domain.bo.rerank.RerankRequest; import org.ruoyi.domain.bo.rerank.RerankResult; import org.ruoyi.domain.bo.vector.QueryVectorBo; +import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import org.ruoyi.factory.RerankModelFactory; +import org.ruoyi.mapper.knowledge.KnowledgeFragmentMapper; import org.ruoyi.service.rerank.RerankModelService; import org.ruoyi.service.retrieval.KnowledgeRetrievalService; import org.ruoyi.service.vector.VectorStoreService; import org.springframework.stereotype.Service; -import java.util.ArrayList; -import java.util.List; +import java.util.*; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; /** * 知识库检索服务实现 - * 整合粗召回(向量检索)和重排序流程 + * 整合粗召回(向量检索/关键词检索)、RRF融合和重排序流程 * * @author yang * @date 2026-04-19 @@ -28,6 +33,7 @@ public class KnowledgeRetrievalServiceImpl implements KnowledgeRetrievalService private final VectorStoreService vectorStoreService; private final RerankModelFactory rerankModelFactory; + private final KnowledgeFragmentMapper fragmentMapper; /** * 粗召回默认扩大倍数 @@ -37,99 +43,214 @@ public class KnowledgeRetrievalServiceImpl implements KnowledgeRetrievalService @Override public List retrieveTexts(QueryVectorBo queryVectorBo) { + List results = retrieve(queryVectorBo); + return results.stream() + .map(KnowledgeRetrievalVo::getContent) + .collect(Collectors.toList()); + } + + @Override + public List retrieve(QueryVectorBo queryVectorBo) { log.info("开始知识库检索, kid={}, query={}", queryVectorBo.getKid(), queryVectorBo.getQuery()); - // 1. 粗召回阶段 - 向量检索 - List coarseResults = coarseRetrieval(queryVectorBo); + // 1. 粗召回阶段 (向量检索 + 关键词搜索) + List coarseResults = performCoarseRetrieval(queryVectorBo); log.debug("粗召回返回 {} 条结果", coarseResults.size()); if (coarseResults.isEmpty()) { return coarseResults; } - // 2. 重排序阶段(可选) - if (Boolean.TRUE.equals(queryVectorBo.getEnableRerank()) && - queryVectorBo.getRerankModelName() != null) { - return rerank(queryVectorBo, coarseResults); + // 2. 初始化原始索引 + for (int i = 0; i < coarseResults.size(); i++) { + coarseResults.get(i).setOriginalIndex(i); } - return coarseResults; + // 3. 重排序阶段 (可选) + List finalResults = coarseResults; + if (Boolean.TRUE.equals(queryVectorBo.getEnableRerank()) && + StringUtils.isNotBlank(queryVectorBo.getRerankModelName())) { + finalResults = performRerank(queryVectorBo, coarseResults); + } + + // 4. 应用分值阈值过滤 (重排分值或 RRF 分值) + double threshold = queryVectorBo.getRerankScoreThreshold() != null ? + queryVectorBo.getRerankScoreThreshold() : 0.0; + + return finalResults.stream() + .filter(res -> res.getScore() >= threshold) + .collect(Collectors.toList()); } /** - * 粗召回阶段 - 向量检索 + * 粗召回阶段:根据配置执行向量搜索或混合搜索 */ - private List coarseRetrieval(QueryVectorBo queryVectorBo) { - // 如果启用重排序,扩大粗召回数量 - int originalMaxResults = queryVectorBo.getMaxResults(); - int expandedResults = originalMaxResults; - if (Boolean.TRUE.equals(queryVectorBo.getEnableRerank()) && - queryVectorBo.getRerankModelName() != null) { - expandedResults = originalMaxResults * RERANK_EXPANSION_FACTOR; - log.debug("启用重排序,粗召回数量从 {} 扩大到 {}", originalMaxResults, expandedResults); + private List performCoarseRetrieval(QueryVectorBo queryVectorBo) { + // 如果启用重排序,适当扩大召回数量 + int originalMaxResults = queryVectorBo.getMaxResults() != null ? queryVectorBo.getMaxResults() : 10; + int targetMaxResults = originalMaxResults; + if (Boolean.TRUE.equals(queryVectorBo.getEnableRerank()) && + StringUtils.isNotBlank(queryVectorBo.getRerankModelName())) { + targetMaxResults = originalMaxResults * RERANK_EXPANSION_FACTOR; } - // 临时修改查询数量 - queryVectorBo.setMaxResults(expandedResults); + // 如果未启用混合检索,直接走向量搜索 + if (!Boolean.TRUE.equals(queryVectorBo.getEnableHybrid())) { + QueryVectorBo vectorQuery = copyOf(queryVectorBo, targetMaxResults); + List results = vectorStoreService.search(vectorQuery); + + // 应用基础相似度阈值过滤(如果有) + if (queryVectorBo.getSimilarityThreshold() != null) { + results = results.stream() + .filter(r -> r.getScore() >= queryVectorBo.getSimilarityThreshold()) + .collect(Collectors.toList()); + } + return results; + } + + // 混合检索逻辑 + log.info("执行混合检索: kid={}, query={}", queryVectorBo.getKid(), queryVectorBo.getQuery()); try { - return vectorStoreService.getQueryVector(queryVectorBo); - } finally { - // 恢复原始值 - queryVectorBo.setMaxResults(originalMaxResults); + // A. 并行执行向量搜索 + int finalTargetMaxResults = targetMaxResults; + CompletableFuture> vectorFuture = CompletableFuture.supplyAsync(() -> { + QueryVectorBo vectorQuery = copyOf(queryVectorBo, finalTargetMaxResults); + List results = vectorStoreService.search(vectorQuery); + // 向量层初步过滤 + if (queryVectorBo.getSimilarityThreshold() != null) { + return results.stream() + .filter(r -> r.getScore() >= queryVectorBo.getSimilarityThreshold()) + .collect(Collectors.toList()); + } + return results; + }); + + // B. 并行执行关键词搜索 (MySQL Fulltext) + CompletableFuture> keywordFuture = CompletableFuture.supplyAsync(() -> { + try { + Long kid = Long.valueOf(queryVectorBo.getKid()); + List fragments = fragmentMapper.searchByKeyword(kid, queryVectorBo.getQuery(), finalTargetMaxResults); + return fragments.stream().map(f -> { + KnowledgeRetrievalVo vo = new KnowledgeRetrievalVo(); + vo.setId(f.getId().toString()); + vo.setContent(f.getContent()); + vo.setDocId(f.getDocId()); + vo.setIdx(f.getIdx()); + vo.setKnowledgeId(f.getKnowledgeId()); + vo.setScore(10.0); // RRF 初始占位分 + return vo; + }).collect(Collectors.toList()); + } catch (Exception e) { + log.error("关键词检索失败: {}", e.getMessage()); + return new ArrayList<>(); + } + }); + + List vectorResults = vectorFuture.get(); + List keywordResults = keywordFuture.get(); + + // C. RRF 融合 + double alpha = queryVectorBo.getHybridAlpha() != null ? queryVectorBo.getHybridAlpha() : 0.5; + return calculateRRF(vectorResults, keywordResults, alpha); + + } catch (Exception e) { + log.error("混合检索执行失败,回退到纯向量检索: {}", e.getMessage(), e); + return vectorStoreService.search(copyOf(queryVectorBo, targetMaxResults)); } } /** * 重排序阶段 */ - private List rerank(QueryVectorBo queryVectorBo, List coarseResults) { - long startTime = System.currentTimeMillis(); - + private List performRerank(QueryVectorBo queryVectorBo, List coarseResults) { try { - // 1. 通过工厂获取重排序模型 RerankModelService rerankModel = rerankModelFactory.createModel(queryVectorBo.getRerankModelName()); + + List contents = coarseResults.stream() + .map(KnowledgeRetrievalVo::getContent) + .collect(Collectors.toList()); - // 2. 构建重排序请求 - int topN = queryVectorBo.getRerankTopN() != null ? - queryVectorBo.getRerankTopN() : queryVectorBo.getMaxResults(); + // topN 默认为 maxResults + int topN = queryVectorBo.getRerankTopN() != null ? queryVectorBo.getRerankTopN() : queryVectorBo.getMaxResults(); RerankRequest rerankRequest = RerankRequest.builder() .query(queryVectorBo.getQuery()) - .documents(coarseResults) + .documents(contents) .topN(topN) - .returnDocuments(true) .build(); - log.info("执行重排序, model={}, documents={}, topN={}", - queryVectorBo.getRerankModelName(), coarseResults.size(), topN); - - // 3. 执行重排序 RerankResult rerankResult = rerankModel.rerank(rerankRequest); - // 4. 转换重排序结果 - List finalResults = new ArrayList<>(); + // 写回分数并记录原始分 for (RerankResult.RerankDocument doc : rerankResult.getDocuments()) { - // 应用分数阈值过滤 - if (queryVectorBo.getRerankScoreThreshold() != null && - doc.getRelevanceScore() < queryVectorBo.getRerankScoreThreshold()) { - continue; - } - - if (doc.getDocument() != null) { - finalResults.add(doc.getDocument()); + if (doc.getIndex() != null && doc.getIndex() < coarseResults.size()) { + KnowledgeRetrievalVo vo = coarseResults.get(doc.getIndex()); + vo.setRawScore(vo.getScore()); + vo.setScore(doc.getRelevanceScore()); } } - long duration = System.currentTimeMillis() - startTime; - log.info("重排序完成, 返回 {} 条结果, 耗时 {}ms", finalResults.size(), duration); - - return finalResults; + // 按新分排序 + coarseResults.sort((a, b) -> b.getScore().compareTo(a.getScore())); + + // 截断到 topN + return coarseResults.subList(0, Math.min(topN, coarseResults.size())); } catch (Exception e) { - log.error("重排序失败: {}", e.getMessage(), e); - // 重排序失败时返回原始粗召回结果(截取到期望数量) - int limit = Math.min(queryVectorBo.getMaxResults(), coarseResults.size()); - return new ArrayList<>(coarseResults.subList(0, limit)); + log.error("重排序流程失败: {}", e.getMessage()); + int limit = queryVectorBo.getMaxResults() != null ? queryVectorBo.getMaxResults() : 10; + return coarseResults.subList(0, Math.min(limit, coarseResults.size())); } } + + /** + * RRF (Reciprocal Rank Fusion) 融合计算 + */ + private List calculateRRF(List vectorList, List keywordList, double alpha) { + Map allMap = new LinkedHashMap<>(); + Map vectorScores = new HashMap<>(); + Map keywordScores = new HashMap<>(); + + int k = 60; // RRF 常数 + + for (int i = 0; i < vectorList.size(); i++) { + KnowledgeRetrievalVo vo = vectorList.get(i); + allMap.put(vo.getId(), vo); + vectorScores.put(vo.getId(), 1.0 / (k + i + 1)); + } + + for (int i = 0; i < keywordList.size(); i++) { + KnowledgeRetrievalVo vo = keywordList.get(i); + if (!allMap.containsKey(vo.getId())) { + allMap.put(vo.getId(), vo); + } + keywordScores.put(vo.getId(), 1.0 / (k + i + 1)); + } + + List fusedResults = new ArrayList<>(); + for (Map.Entry entry : allMap.entrySet()) { + String id = entry.getKey(); + double finalScore = (1 - alpha) * vectorScores.getOrDefault(id, 0.0) + + alpha * keywordScores.getOrDefault(id, 0.0); + + KnowledgeRetrievalVo vo = entry.getValue(); + vo.setScore(finalScore * 60.0); // 归一化缩放 + fusedResults.add(vo); + } + + fusedResults.sort((a, b) -> b.getScore().compareTo(a.getScore())); + return fusedResults; + } + + private QueryVectorBo copyOf(QueryVectorBo original, int maxResults) { + QueryVectorBo copy = new QueryVectorBo(); + copy.setQuery(original.getQuery()); + copy.setKid(original.getKid()); + copy.setMaxResults(maxResults); + copy.setVectorModelName(original.getVectorModelName()); + copy.setEmbeddingModelName(original.getEmbeddingModelName()); + copy.setApiKey(original.getApiKey()); + copy.setBaseUrl(original.getBaseUrl()); + return copy; + } }