feat(rag): 对接用户端用户知识库对话,集成知识库配置应用

This commit is contained in:
RobustH
2026-04-23 00:52:53 +08:00
parent 058a4aee2a
commit b8d16b7669
7 changed files with 301 additions and 251 deletions

View File

@@ -0,0 +1,14 @@
-- Add retrieval-configuration columns to the knowledge base info table
-- (the rerank columns already exist, so they are not added here).
ALTER TABLE `knowledge_info`
    ADD COLUMN `similarity_threshold` DOUBLE DEFAULT 0.5 COMMENT '相似度阈值' AFTER `retrieve_limit`;
ALTER TABLE `knowledge_info`
    ADD COLUMN `enable_hybrid` TINYINT(1) DEFAULT 0 COMMENT '是否启用混合检索';
ALTER TABLE `knowledge_info`
    ADD COLUMN `hybrid_alpha` DOUBLE DEFAULT 0.5 COMMENT '混合检索权重比例 (0.0=纯向量, 1.0=纯关键词)';

-- Give the knowledge fragment table a knowledge-base ID column and an ngram full-text index on content.
ALTER TABLE `knowledge_fragment`
    ADD COLUMN `knowledge_id` BIGINT COMMENT '知识库ID';
ALTER TABLE `knowledge_fragment`
    ADD FULLTEXT INDEX `ft_content` (`content`) WITH PARSER ngram;

-- Add a parsing-status column to the knowledge base attachment table.
ALTER TABLE `knowledge_attach`
    ADD COLUMN `status` TINYINT DEFAULT 0 COMMENT '解析状态: 0待解析, 1解析中, 2已解析, 3解析失败';

View File

@@ -77,4 +77,22 @@ public class QueryVectorBo {
*/
private Double rerankScoreThreshold;
// ========== Hybrid retrieval and threshold parameters ==========
/**
 * Similarity threshold (0.0-1.0).
 * Applied during the vector search stage.
 */
private Double similarityThreshold;
/**
 * Whether hybrid retrieval is enabled.
 */
private Boolean enableHybrid = false;
/**
 * Hybrid retrieval weight (0.0-1.0).
 */
private Double hybridAlpha;
}

View File

@@ -20,6 +20,11 @@ import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.service.tool.ToolProvider;
import dev.langchain4j.skills.shell.ShellSkills;
import dev.langchain4j.rag.AugmentationRequest;
import dev.langchain4j.rag.AugmentationResult;
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.query.Metadata;
import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
@@ -55,6 +60,7 @@ import org.ruoyi.service.chat.IChatMessageService;
import org.ruoyi.service.chat.impl.memory.PersistentChatMemoryStore;
import org.ruoyi.service.knowledge.IKnowledgeInfoService;
import org.ruoyi.service.retrieval.KnowledgeRetrievalService;
import org.ruoyi.service.knowledge.retriever.CustomVectorRetriever;
import org.ruoyi.service.vector.VectorStoreService;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
@@ -412,16 +418,49 @@ public class ChatServiceFacade implements IChatService {
/**
* 构建上下文消息列表
* 消息顺序:历史消息 → 当前用户消息(确保 AI 正确理解对话上下文)
*
* @param chatRequest 聊天请求
* @return 上下文消息列表
*/
private List<ChatMessage> buildContextMessages(ChatRequest chatRequest) {
List<ChatMessage> messages = new ArrayList<>();
List<ChatMessage> messages = new ArrayList<>();
// 从数据库查询历史对话消息(放在前面)
// 1. 初始化当前用户消息
UserMessage userMessage = UserMessage.userMessage(chatRequest.getContent());
// 2. 知识库检索增强 (RAG)
if (chatRequest.getKnowledgeId() != null) {
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(chatRequest.getKnowledgeId()));
if (knowledgeInfoVo != null) {
ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel());
if (chatModel != null) {
log.info("执行高级 RAG 流程: kid={}", chatRequest.getKnowledgeId());
// 构建自定义检索器
CustomVectorRetriever retriever = new CustomVectorRetriever(
knowledgeRetrievalService, knowledgeInfoVo, chatModel);
// 构建增强流水线
RetrievalAugmentor augmentor = DefaultRetrievalAugmentor.builder()
.contentRetriever(retriever)
.build();
// 执行增强:编织上下文到 UserMessage
Metadata metadata = Metadata.from(userMessage, chatRequest.getSessionId(), new ArrayList<>());
AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata);
AugmentationResult result = augmentor.augment(augmentationRequest);
ChatMessage augmented = result.chatMessage();
if (augmented instanceof UserMessage) {
userMessage = (UserMessage) augmented;
log.debug("RAG 增强完成UserMessage 已注入背景知识");
}
}
}
}
// 3. 从数据库查询历史对话消息(放在前面)
if (chatRequest.getSessionId() != null) {
MessageWindowChatMemory memory = createChatMemory(chatRequest.getSessionId());
if (memory != null) {
@@ -433,38 +472,7 @@ public class ChatServiceFacade implements IChatService {
}
}
// 从向量库查询相关历史消息(知识库内容作为上下文
if (chatRequest.getKnowledgeId() != null) {
// 查询知识库信息
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(chatRequest.getKnowledgeId()));
if (knowledgeInfoVo == null) {
log.warn("知识库信息不存在kid: {}", chatRequest.getKnowledgeId());
// 继续添加当前用户消息
messages.add(UserMessage.userMessage(chatRequest.getContent()));
return messages;
}
// 查询向量模型配置信息
ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel());
if (chatModel == null) {
log.warn("向量模型配置不存在,模型名称: {}", knowledgeInfoVo.getEmbeddingModel());
messages.add(UserMessage.userMessage(chatRequest.getContent()));
return messages;
}
// 构建向量查询参数
QueryVectorBo queryVectorBo = buildQueryVectorBo(chatRequest, knowledgeInfoVo, chatModel);
// 使用知识库检索服务(支持重排序)
List<String> nearestList = knowledgeRetrievalService.retrieveTexts(queryVectorBo);
for (String prompt : nearestList) {
// 知识库内容作为系统上下文添加
messages.add(new AiMessage(prompt));
}
}
// 构建当前用户消息(放在最后)
UserMessage userMessage = UserMessage.userMessage(chatRequest.getContent());
// 4. 添加经过增强的用户消息(放在最后)
messages.add(userMessage);
return messages;

View File

@@ -12,25 +12,18 @@ import org.ruoyi.common.core.utils.StringUtils;
import org.ruoyi.common.mybatis.core.page.PageQuery;
import org.ruoyi.common.mybatis.core.page.TableDataInfo;
import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo;
import org.ruoyi.domain.bo.rerank.RerankRequest;
import org.ruoyi.domain.bo.rerank.RerankResult;
import org.ruoyi.domain.bo.vector.QueryVectorBo;
import org.ruoyi.domain.entity.knowledge.KnowledgeFragment;
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo;
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
import org.ruoyi.factory.RerankModelFactory;
import org.ruoyi.mapper.knowledge.KnowledgeFragmentMapper;
import org.ruoyi.service.knowledge.IKnowledgeFragmentService;
import org.ruoyi.service.knowledge.IKnowledgeInfoService;
import org.ruoyi.service.rerank.RerankModelService;
import org.ruoyi.service.vector.VectorStoreService;
import org.ruoyi.service.retrieval.KnowledgeRetrievalService;
import org.springframework.stereotype.Service;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
/**
* 知识片段Service业务层处理
@@ -46,8 +39,7 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
private final KnowledgeFragmentMapper baseMapper;
private final IKnowledgeInfoService knowledgeInfoService;
private final IChatModelService chatModelService;
private final VectorStoreService vectorStoreService;
private final RerankModelFactory rerankModelFactory;
private final KnowledgeRetrievalService knowledgeRetrievalService;
/**
* 查询知识片段
@@ -87,7 +79,6 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
}
private LambdaQueryWrapper<KnowledgeFragment> buildQueryWrapper(KnowledgeFragmentBo bo) {
Map<String, Object> params = bo.getParams();
LambdaQueryWrapper<KnowledgeFragment> lqw = Wrappers.lambdaQuery();
lqw.orderByAsc(KnowledgeFragment::getId);
lqw.eq(bo.getDocId() != null, KnowledgeFragment::getDocId, bo.getDocId());
@@ -149,7 +140,7 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
}
/**
* 检索测试核心实现
* 检索测试核心实现 - 委托给统一的 KnowledgeRetrievalService
*/
@Override
public List<KnowledgeRetrievalVo> retrieval(KnowledgeFragmentBo bo) {
@@ -157,7 +148,7 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
return new ArrayList<>();
}
// 1. 获取知识库及模型配置
// 1. 获取知识库及模型配置(为了获取 API Key/Host 等模型参数)
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(bo.getKnowledgeId());
if (knowledgeInfoVo == null) {
return new ArrayList<>();
@@ -169,151 +160,28 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
return new ArrayList<>();
}
// 2. 构造向量检索参数
// 2. 构造通用的参数对象
QueryVectorBo queryVectorBo = new QueryVectorBo();
queryVectorBo.setQuery(bo.getQuery());
queryVectorBo.setKid(String.valueOf(bo.getKnowledgeId()));
queryVectorBo.setMaxResults(bo.getTopK() != null ? bo.getTopK() : knowledgeInfoVo.getRetrieveLimit());
queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel());
queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModel());
queryVectorBo.setApiKey(chatModel.getApiKey());
queryVectorBo.setBaseUrl(chatModel.getApiHost());
queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel());
queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModel());
// 3. 执行搜索 (向量搜索 + 关键词搜索)
List<KnowledgeRetrievalVo> allResults;
// 使用前端传入的实时测试参数,若无则使用知识库默认参数
queryVectorBo.setMaxResults(bo.getTopK() != null ? bo.getTopK() : knowledgeInfoVo.getRetrieveLimit());
queryVectorBo.setSimilarityThreshold(bo.getThreshold() != null ? bo.getThreshold() : knowledgeInfoVo.getSimilarityThreshold());
boolean hybridEnabled = Boolean.TRUE.equals(bo.getEnableHybrid()) ||
Integer.valueOf(1).equals(knowledgeInfoVo.getEnableHybrid());
if (hybridEnabled) {
log.info("执行混合检索: kid={}, query={}", bo.getKnowledgeId(), bo.getQuery());
try {
// 并行执行向量搜索
CompletableFuture<List<KnowledgeRetrievalVo>> vectorFuture = CompletableFuture.supplyAsync(() ->
vectorStoreService.search(queryVectorBo));
// 执行关键词搜索 (MySQL)
int limit = bo.getTopK() != null ? bo.getTopK() : 50;
List<KnowledgeFragmentVo> keywordFragments = baseMapper.searchByKeyword(bo.getKnowledgeId(), bo.getQuery(), limit);
List<KnowledgeRetrievalVo> keywordResults = keywordFragments.stream().map(f -> {
KnowledgeRetrievalVo vo = new KnowledgeRetrievalVo();
vo.setId(f.getId().toString());
vo.setContent(f.getContent());
vo.setDocId(f.getDocId());
vo.setIdx(f.getIdx());
vo.setKnowledgeId(f.getKnowledgeId());
vo.setScore(10.0); // 初始分,后续由 RRF 重新打分
return vo;
}).collect(Collectors.toList());
List<KnowledgeRetrievalVo> vectorResults = vectorFuture.get();
log.info("抽取混合结果成功: Vector命中={}条, Keyword命中={}条", vectorResults.size(), keywordResults.size());
queryVectorBo.setEnableHybrid(bo.getEnableHybrid() != null ? bo.getEnableHybrid() : Objects.equals(knowledgeInfoVo.getEnableHybrid(), 1));
queryVectorBo.setHybridAlpha(bo.getHybridAlpha() != null ? bo.getHybridAlpha() : knowledgeInfoVo.getHybridAlpha());
double alpha = bo.getHybridAlpha() != null ? bo.getHybridAlpha() :
(knowledgeInfoVo.getHybridAlpha() != null ? knowledgeInfoVo.getHybridAlpha() : 0.5);
allResults = calculateRRF(vectorResults, keywordResults, alpha);
} catch (Exception e) {
log.error("混合检索执行或合并失败,已自动降级回退到纯向量检索", e);
allResults = vectorStoreService.search(queryVectorBo);
}
} else {
allResults = vectorStoreService.search(queryVectorBo);
}
queryVectorBo.setEnableRerank(bo.getEnableRerank() != null ? bo.getEnableRerank() : Objects.equals(knowledgeInfoVo.getEnableRerank(), 1));
queryVectorBo.setRerankModelName(StringUtils.isNotBlank(bo.getRerankModel()) ? bo.getRerankModel() : knowledgeInfoVo.getRerankModel());
queryVectorBo.setRerankTopN(bo.getTopK() != null ? bo.getTopK() : knowledgeInfoVo.getRerankTopN());
queryVectorBo.setRerankScoreThreshold(bo.getThreshold() != null ? bo.getThreshold() : knowledgeInfoVo.getRerankScoreThreshold());
// 初始化原始排名
for (int i = 0; i < allResults.size(); i++) {
allResults.get(i).setOriginalIndex(i);
}
// 4. 执行重排逻辑 (如果请求启用重排且配置了重排模型)
if (Boolean.TRUE.equals(bo.getEnableRerank()) && StringUtils.isNotBlank(bo.getRerankModel())) {
log.info("开始重排精排,模型: [{}]", bo.getRerankModel());
try {
RerankModelService rerankModel = rerankModelFactory.createModel(bo.getRerankModel());
List<String> contents = allResults.stream()
.map(KnowledgeRetrievalVo::getContent)
.collect(Collectors.toList());
RerankRequest rerankRequest = RerankRequest.builder()
.query(bo.getQuery())
.documents(contents)
.topN(contents.size())
.returnDocuments(false)
.build();
RerankResult rerankResult = rerankModel.rerank(rerankRequest);
// 将重排分数写回,并记录原始分数供前端对比
for (RerankResult.RerankDocument doc : rerankResult.getDocuments()) {
if (doc.getIndex() != null && doc.getIndex() < allResults.size()) {
KnowledgeRetrievalVo resultVo = allResults.get(doc.getIndex());
resultVo.setRawScore(resultVo.getScore());
resultVo.setScore(doc.getRelevanceScore());
}
}
// 按重排后的分数从高到低排序
allResults.sort((a, b) -> b.getScore().compareTo(a.getScore()));
log.info("重排精排完成,结果数: {}", allResults.size());
} catch (Exception e) {
log.error("重排精排执行失败,已跳过重排步骤: {}", e.getMessage(), e);
}
}
// 5. 根据阈值过滤
double threshold = bo.getThreshold() != null ? bo.getThreshold() : 0.0;
return allResults.stream()
.filter(res -> res.getScore() >= threshold)
.collect(Collectors.toList());
}
/**
* RRF (Reciprocal Rank Fusion) 融合算法
* 公式: Score = (1-alpha) * (1 / (k + rank_vector)) + alpha * (1 / (k + rank_keyword))
*/
private List<KnowledgeRetrievalVo> calculateRRF(List<KnowledgeRetrievalVo> vectorList, List<KnowledgeRetrievalVo> keywordList, double alpha) {
Map<String, KnowledgeRetrievalVo> allMap = new HashMap<>();
Map<String, Double> vectorScores = new HashMap<>();
Map<String, Double> keywordScores = new HashMap<>();
int k = 60; // 常用 RRF 常数
for (int i = 0; i < vectorList.size(); i++) {
KnowledgeRetrievalVo vo = vectorList.get(i);
allMap.put(vo.getId(), vo);
vectorScores.put(vo.getId(), 1.0 / (k + i + 1));
}
for (int i = 0; i < keywordList.size(); i++) {
KnowledgeRetrievalVo vo = keywordList.get(i);
if (!allMap.containsKey(vo.getId())) {
allMap.put(vo.getId(), vo);
}
keywordScores.put(vo.getId(), 1.0 / (k + i + 1));
}
// 重新计算得分
List<KnowledgeRetrievalVo> fusedResults = new ArrayList<>();
for (Map.Entry<String, KnowledgeRetrievalVo> entry : allMap.entrySet()) {
String id = entry.getKey();
double vScore = vectorScores.getOrDefault(id, 0.0);
double kScore = keywordScores.getOrDefault(id, 0.0);
// 混合分值
double finalScore = (1 - alpha) * vScore + alpha * kScore;
// 分值归一化/缩放:将 RRF 分值放大到 0-1 范围
// 理论单路最大得分为 1/61 ≈ 0.016,乘以 60 使其处于相似度常用区间
KnowledgeRetrievalVo vo = entry.getValue();
vo.setScore(finalScore * 60.0);
fusedResults.add(vo);
}
// 按融合分数从高到低排序
fusedResults.sort((a, b) -> b.getScore().compareTo(a.getScore()));
return fusedResults;
// 3. 执行统一检索
return knowledgeRetrievalService.retrieve(queryVectorBo);
}
}

View File

@@ -9,14 +9,15 @@ import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo;
import org.ruoyi.domain.bo.vector.QueryVectorBo;
import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo;
import org.ruoyi.service.vector.VectorStoreService;
import org.ruoyi.service.retrieval.KnowledgeRetrievalService;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/**
* 自定义向量检索器:适配 LangChain4j ContentRetriever 接口
* 桥接现有的 VectorStoreService 获取检索结果
* 自定义检索器:适配 LangChain4j ContentRetriever 接口
* 桥接统一的 KnowledgeRetrievalService,支持配置化的混合检索、阈值过滤等功能
*
* @author RobustH
*/
@@ -24,15 +25,15 @@ import java.util.stream.Collectors;
@RequiredArgsConstructor
public class CustomVectorRetriever implements ContentRetriever {
private final VectorStoreService vectorStoreService;
private final KnowledgeRetrievalService knowledgeRetrievalService;
private final KnowledgeInfoVo knowledgeInfoVo;
private final ChatModelVo chatModelVo;
@Override
public List<Content> retrieve(Query query) {
log.info("执行自定义向量检索,关键字: {}", query.text());
log.info("执行自定义检索,关键字: {}", query.text());
// 构建内部查询参数
// 构建增强后的查询参数
QueryVectorBo queryVectorBo = new QueryVectorBo();
queryVectorBo.setQuery(query.text());
queryVectorBo.setKid(String.valueOf(knowledgeInfoVo.getId()));
@@ -40,11 +41,21 @@ public class CustomVectorRetriever implements ContentRetriever {
queryVectorBo.setBaseUrl(chatModelVo.getApiHost());
queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModel());
queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel());
// 如果接入了重排,这里的 retrieveLimit 也就是 MaxResults 应当被放大,后续留给 Aggregator 截断
// 应用知识库配置参数
queryVectorBo.setMaxResults(knowledgeInfoVo.getRetrieveLimit());
queryVectorBo.setSimilarityThreshold(knowledgeInfoVo.getSimilarityThreshold());
queryVectorBo.setEnableHybrid(Objects.equals(knowledgeInfoVo.getEnableHybrid(), 1));
queryVectorBo.setHybridAlpha(knowledgeInfoVo.getHybridAlpha());
// 执行底层的多种向量库策略检索
List<String> nearestList = vectorStoreService.getQueryVector(queryVectorBo);
// 设置重排序参数 (如果 retriever 阶段也想做初步重排,可以在此设置)
queryVectorBo.setEnableRerank(Objects.equals(knowledgeInfoVo.getEnableRerank(), 1));
queryVectorBo.setRerankModelName(knowledgeInfoVo.getRerankModel());
queryVectorBo.setRerankTopN(knowledgeInfoVo.getRerankTopN());
queryVectorBo.setRerankScoreThreshold(knowledgeInfoVo.getRerankScoreThreshold());
// 通过统一服务执行检索
List<String> nearestList = knowledgeRetrievalService.retrieveTexts(queryVectorBo);
// 将结果包装为标准的 Content 返回
return nearestList.stream()

View File

@@ -1,12 +1,13 @@
package org.ruoyi.service.retrieval;
import org.ruoyi.domain.bo.vector.QueryVectorBo;
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
import java.util.List;
/**
* 知识库检索服务接口
* 整合粗召回(向量检索)和重排序流程
* 整合粗召回(向量检索/关键词检索)和重排序流程
*
* @author yang
* @date 2026-04-19
@@ -21,4 +22,13 @@ public interface KnowledgeRetrievalService {
* @return 文本内容列表
*/
List<String> retrieveTexts(QueryVectorBo queryVectorBo);
/**
 * Runs a knowledge base retrieval and returns detailed result objects
 * (including score, document ID, etc.).
 * Supports hybrid retrieval and reranking.
 *
 * @param queryVectorBo query parameters
 * @return list of retrieval results
 */
List<KnowledgeRetrievalVo> retrieve(QueryVectorBo queryVectorBo);
}

View File

@@ -2,21 +2,26 @@ package org.ruoyi.service.retrieval.impl;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.core.utils.StringUtils;
import org.ruoyi.domain.bo.rerank.RerankRequest;
import org.ruoyi.domain.bo.rerank.RerankResult;
import org.ruoyi.domain.bo.vector.QueryVectorBo;
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
import org.ruoyi.factory.RerankModelFactory;
import org.ruoyi.mapper.knowledge.KnowledgeFragmentMapper;
import org.ruoyi.service.rerank.RerankModelService;
import org.ruoyi.service.retrieval.KnowledgeRetrievalService;
import org.ruoyi.service.vector.VectorStoreService;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
/**
* 知识库检索服务实现
* 整合粗召回(向量检索和重排序流程
* 整合粗召回(向量检索/关键词检索)、RRF融合和重排序流程
*
* @author yang
* @date 2026-04-19
@@ -28,6 +33,7 @@ public class KnowledgeRetrievalServiceImpl implements KnowledgeRetrievalService
private final VectorStoreService vectorStoreService;
private final RerankModelFactory rerankModelFactory;
private final KnowledgeFragmentMapper fragmentMapper;
/**
* 粗召回默认扩大倍数
@@ -37,99 +43,214 @@ public class KnowledgeRetrievalServiceImpl implements KnowledgeRetrievalService
@Override
public List<String> retrieveTexts(QueryVectorBo queryVectorBo) {
List<KnowledgeRetrievalVo> results = retrieve(queryVectorBo);
return results.stream()
.map(KnowledgeRetrievalVo::getContent)
.collect(Collectors.toList());
}
@Override
public List<KnowledgeRetrievalVo> retrieve(QueryVectorBo queryVectorBo) {
log.info("开始知识库检索, kid={}, query={}", queryVectorBo.getKid(), queryVectorBo.getQuery());
// 1. 粗召回阶段 - 向量检索
List<String> coarseResults = coarseRetrieval(queryVectorBo);
// 1. 粗召回阶段 (向量检索 + 关键词搜索)
List<KnowledgeRetrievalVo> coarseResults = performCoarseRetrieval(queryVectorBo);
log.debug("粗召回返回 {} 条结果", coarseResults.size());
if (coarseResults.isEmpty()) {
return coarseResults;
}
// 2. 重排序阶段(可选)
if (Boolean.TRUE.equals(queryVectorBo.getEnableRerank()) &&
queryVectorBo.getRerankModelName() != null) {
return rerank(queryVectorBo, coarseResults);
// 2. 初始化原始索引
for (int i = 0; i < coarseResults.size(); i++) {
coarseResults.get(i).setOriginalIndex(i);
}
return coarseResults;
// 3. 重排序阶段 (可选)
List<KnowledgeRetrievalVo> finalResults = coarseResults;
if (Boolean.TRUE.equals(queryVectorBo.getEnableRerank()) &&
StringUtils.isNotBlank(queryVectorBo.getRerankModelName())) {
finalResults = performRerank(queryVectorBo, coarseResults);
}
// 4. 应用分值阈值过滤 (重排分值或 RRF 分值)
double threshold = queryVectorBo.getRerankScoreThreshold() != null ?
queryVectorBo.getRerankScoreThreshold() : 0.0;
return finalResults.stream()
.filter(res -> res.getScore() >= threshold)
.collect(Collectors.toList());
}
/**
* 粗召回阶段 - 向量检
* 粗召回阶段:根据配置执行向量搜索或混合搜索
*/
private List<String> coarseRetrieval(QueryVectorBo queryVectorBo) {
// 如果启用重排序,扩大召回数量
int originalMaxResults = queryVectorBo.getMaxResults();
int expandedResults = originalMaxResults;
if (Boolean.TRUE.equals(queryVectorBo.getEnableRerank()) &&
queryVectorBo.getRerankModelName() != null) {
expandedResults = originalMaxResults * RERANK_EXPANSION_FACTOR;
log.debug("启用重排序,粗召回数量从 {} 扩大到 {}", originalMaxResults, expandedResults);
private List<KnowledgeRetrievalVo> performCoarseRetrieval(QueryVectorBo queryVectorBo) {
// 如果启用重排序,适当扩大召回数量
int originalMaxResults = queryVectorBo.getMaxResults() != null ? queryVectorBo.getMaxResults() : 10;
int targetMaxResults = originalMaxResults;
if (Boolean.TRUE.equals(queryVectorBo.getEnableRerank()) &&
StringUtils.isNotBlank(queryVectorBo.getRerankModelName())) {
targetMaxResults = originalMaxResults * RERANK_EXPANSION_FACTOR;
}
// 临时修改查询数量
queryVectorBo.setMaxResults(expandedResults);
// 如果未启用混合检索,直接走向量搜索
if (!Boolean.TRUE.equals(queryVectorBo.getEnableHybrid())) {
QueryVectorBo vectorQuery = copyOf(queryVectorBo, targetMaxResults);
List<KnowledgeRetrievalVo> results = vectorStoreService.search(vectorQuery);
// 应用基础相似度阈值过滤(如果有)
if (queryVectorBo.getSimilarityThreshold() != null) {
results = results.stream()
.filter(r -> r.getScore() >= queryVectorBo.getSimilarityThreshold())
.collect(Collectors.toList());
}
return results;
}
// 混合检索逻辑
log.info("执行混合检索: kid={}, query={}", queryVectorBo.getKid(), queryVectorBo.getQuery());
try {
return vectorStoreService.getQueryVector(queryVectorBo);
} finally {
// 恢复原始值
queryVectorBo.setMaxResults(originalMaxResults);
// A. 并行执行向量搜索
int finalTargetMaxResults = targetMaxResults;
CompletableFuture<List<KnowledgeRetrievalVo>> vectorFuture = CompletableFuture.supplyAsync(() -> {
QueryVectorBo vectorQuery = copyOf(queryVectorBo, finalTargetMaxResults);
List<KnowledgeRetrievalVo> results = vectorStoreService.search(vectorQuery);
// 向量层初步过滤
if (queryVectorBo.getSimilarityThreshold() != null) {
return results.stream()
.filter(r -> r.getScore() >= queryVectorBo.getSimilarityThreshold())
.collect(Collectors.toList());
}
return results;
});
// B. 并行执行关键词搜索 (MySQL Fulltext)
CompletableFuture<List<KnowledgeRetrievalVo>> keywordFuture = CompletableFuture.supplyAsync(() -> {
try {
Long kid = Long.valueOf(queryVectorBo.getKid());
List<KnowledgeFragmentVo> fragments = fragmentMapper.searchByKeyword(kid, queryVectorBo.getQuery(), finalTargetMaxResults);
return fragments.stream().map(f -> {
KnowledgeRetrievalVo vo = new KnowledgeRetrievalVo();
vo.setId(f.getId().toString());
vo.setContent(f.getContent());
vo.setDocId(f.getDocId());
vo.setIdx(f.getIdx());
vo.setKnowledgeId(f.getKnowledgeId());
vo.setScore(10.0); // RRF 初始占位分
return vo;
}).collect(Collectors.toList());
} catch (Exception e) {
log.error("关键词检索失败: {}", e.getMessage());
return new ArrayList<>();
}
});
List<KnowledgeRetrievalVo> vectorResults = vectorFuture.get();
List<KnowledgeRetrievalVo> keywordResults = keywordFuture.get();
// C. RRF 融合
double alpha = queryVectorBo.getHybridAlpha() != null ? queryVectorBo.getHybridAlpha() : 0.5;
return calculateRRF(vectorResults, keywordResults, alpha);
} catch (Exception e) {
log.error("混合检索执行失败,回退到纯向量检索: {}", e.getMessage(), e);
return vectorStoreService.search(copyOf(queryVectorBo, targetMaxResults));
}
}
/**
* 重排序阶段
*/
private List<String> rerank(QueryVectorBo queryVectorBo, List<String> coarseResults) {
long startTime = System.currentTimeMillis();
private List<KnowledgeRetrievalVo> performRerank(QueryVectorBo queryVectorBo, List<KnowledgeRetrievalVo> coarseResults) {
try {
// 1. 通过工厂获取重排序模型
RerankModelService rerankModel = rerankModelFactory.createModel(queryVectorBo.getRerankModelName());
List<String> contents = coarseResults.stream()
.map(KnowledgeRetrievalVo::getContent)
.collect(Collectors.toList());
// 2. 构建重排序请求
int topN = queryVectorBo.getRerankTopN() != null ?
queryVectorBo.getRerankTopN() : queryVectorBo.getMaxResults();
// topN 默认为 maxResults
int topN = queryVectorBo.getRerankTopN() != null ? queryVectorBo.getRerankTopN() : queryVectorBo.getMaxResults();
RerankRequest rerankRequest = RerankRequest.builder()
.query(queryVectorBo.getQuery())
.documents(coarseResults)
.documents(contents)
.topN(topN)
.returnDocuments(true)
.build();
log.info("执行重排序, model={}, documents={}, topN={}",
queryVectorBo.getRerankModelName(), coarseResults.size(), topN);
// 3. 执行重排序
RerankResult rerankResult = rerankModel.rerank(rerankRequest);
// 4. 转换重排序结果
List<String> finalResults = new ArrayList<>();
// 写回分数并记录原始分
for (RerankResult.RerankDocument doc : rerankResult.getDocuments()) {
// 应用分数阈值过滤
if (queryVectorBo.getRerankScoreThreshold() != null &&
doc.getRelevanceScore() < queryVectorBo.getRerankScoreThreshold()) {
continue;
}
if (doc.getDocument() != null) {
finalResults.add(doc.getDocument());
if (doc.getIndex() != null && doc.getIndex() < coarseResults.size()) {
KnowledgeRetrievalVo vo = coarseResults.get(doc.getIndex());
vo.setRawScore(vo.getScore());
vo.setScore(doc.getRelevanceScore());
}
}
long duration = System.currentTimeMillis() - startTime;
log.info("重排序完成, 返回 {} 条结果, 耗时 {}ms", finalResults.size(), duration);
return finalResults;
// 按新分排序
coarseResults.sort((a, b) -> b.getScore().compareTo(a.getScore()));
// 截断到 topN
return coarseResults.subList(0, Math.min(topN, coarseResults.size()));
} catch (Exception e) {
log.error("重排序失败: {}", e.getMessage(), e);
// 重排序失败时返回原始粗召回结果(截取到期望数量)
int limit = Math.min(queryVectorBo.getMaxResults(), coarseResults.size());
return new ArrayList<>(coarseResults.subList(0, limit));
log.error("重排序流程失败: {}", e.getMessage());
int limit = queryVectorBo.getMaxResults() != null ? queryVectorBo.getMaxResults() : 10;
return coarseResults.subList(0, Math.min(limit, coarseResults.size()));
}
}
/**
 * Merges vector hits and keyword hits with a weighted Reciprocal Rank Fusion (RRF).
 *
 * <p>Each list contributes {@code 1 / (k + rank)} per hit, where {@code rank} is the
 * 1-based position in that list and {@code k = 60}. The two contributions are blended as
 * {@code (1 - alpha) * vector + alpha * keyword}, then multiplied by 60.0 so the fused
 * value lands in a range comparable to similarity scores. The fused value is written
 * back onto the hit object via {@code setScore}, replacing its previous score.
 *
 * <p>When the same ID appears in both lists, the object from the vector list is kept.
 *
 * @param vectorList  hits from the vector search, best first
 * @param keywordList hits from the keyword search, best first
 * @param alpha       weight of the keyword contribution; {@code 1 - alpha} weights the vector contribution
 * @return all distinct hits, sorted by fused score in descending order
 */
private List<KnowledgeRetrievalVo> calculateRRF(List<KnowledgeRetrievalVo> vectorList, List<KnowledgeRetrievalVo> keywordList, double alpha) {
    final int k = 60; // RRF smoothing constant

    // Insertion-ordered so that hits with equal fused scores keep a deterministic order.
    Map<String, KnowledgeRetrievalVo> candidates = new LinkedHashMap<>();
    Map<String, Double> vectorContribution = new HashMap<>();
    Map<String, Double> keywordContribution = new HashMap<>();

    int vectorRank = 0;
    for (KnowledgeRetrievalVo hit : vectorList) {
        vectorRank++;
        candidates.put(hit.getId(), hit);
        vectorContribution.put(hit.getId(), 1.0 / (k + vectorRank));
    }

    int keywordRank = 0;
    for (KnowledgeRetrievalVo hit : keywordList) {
        keywordRank++;
        // Keep the vector-side object when the fragment was found by both searches.
        candidates.putIfAbsent(hit.getId(), hit);
        keywordContribution.put(hit.getId(), 1.0 / (k + keywordRank));
    }

    List<KnowledgeRetrievalVo> fused = new ArrayList<>();
    candidates.forEach((id, hit) -> {
        double blended = (1 - alpha) * vectorContribution.getOrDefault(id, 0.0) +
                alpha * keywordContribution.getOrDefault(id, 0.0);
        hit.setScore(blended * 60.0); // scale the small RRF value up
        fused.add(hit);
    });

    // Highest fused score first.
    fused.sort(Comparator.comparing(KnowledgeRetrievalVo::getScore).reversed());
    return fused;
}
/**
 * Creates a new query object for a vector store search with the given result limit.
 *
 * <p>Only the query text, knowledge base ID, model names, API key and base URL are
 * carried over from {@code original}; every other field keeps the value a freshly
 * constructed {@link QueryVectorBo} has. The original object is not modified.
 *
 * @param original   query to take the search and connection parameters from
 * @param maxResults result limit to set on the copy
 * @return the new query object
 */
private QueryVectorBo copyOf(QueryVectorBo original, int maxResults) {
    QueryVectorBo vectorQuery = new QueryVectorBo();

    // Model endpoint settings
    vectorQuery.setApiKey(original.getApiKey());
    vectorQuery.setBaseUrl(original.getBaseUrl());
    vectorQuery.setEmbeddingModelName(original.getEmbeddingModelName());
    vectorQuery.setVectorModelName(original.getVectorModelName());

    // What to search for, where, and how many hits to fetch
    vectorQuery.setKid(original.getKid());
    vectorQuery.setQuery(original.getQuery());
    vectorQuery.setMaxResults(maxResults);

    return vectorQuery;
}
}