mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-04-23 08:43:40 +00:00
feat(rag): 对接用户端用户知识库对话,集成知识库配置应用
This commit is contained in:
14
docs/script/sql/update/updat-0423.sql
Normal file
14
docs/script/sql/update/updat-0423.sql
Normal file
@@ -0,0 +1,14 @@
|
||||
-- knowledge_info: add the retrieval tuning columns
-- (the rerank columns already exist on this table and are intentionally not re-added here).
ALTER TABLE knowledge_info
    ADD COLUMN similarity_threshold DOUBLE DEFAULT 0.5 COMMENT '相似度阈值' AFTER retrieve_limit;

ALTER TABLE knowledge_info
    ADD COLUMN enable_hybrid TINYINT(1) DEFAULT 0 COMMENT '是否启用混合检索';
ALTER TABLE knowledge_info
    ADD COLUMN hybrid_alpha DOUBLE DEFAULT 0.5 COMMENT '混合检索权重比例 (0.0=纯向量, 1.0=纯关键词)';

-- knowledge_fragment: add the owning knowledge-base id and an ngram full-text index on the content.
ALTER TABLE knowledge_fragment
    ADD COLUMN knowledge_id BIGINT COMMENT '知识库ID';
ALTER TABLE knowledge_fragment
    ADD FULLTEXT INDEX ft_content (content) WITH PARSER ngram;

-- knowledge_attach: add the parse-status column.
ALTER TABLE `knowledge_attach`
    ADD COLUMN `status` TINYINT DEFAULT 0 COMMENT '解析状态: 0待解析, 1解析中, 2已解析, 3解析失败';
|
||||
@@ -77,4 +77,22 @@ public class QueryVectorBo {
|
||||
*/
|
||||
private Double rerankScoreThreshold;
|
||||
|
||||
// ========== 混合检索与阈值相关参数 ==========
|
||||
|
||||
/**
|
||||
* 相似度阈值 (0.0-1.0)
|
||||
* 应用于向量搜索阶段
|
||||
*/
|
||||
private Double similarityThreshold;
|
||||
|
||||
/**
|
||||
* 是否启用混合检索
|
||||
*/
|
||||
private Boolean enableHybrid = false;
|
||||
|
||||
/**
|
||||
* 混合检索权重 (0.0-1.0)
|
||||
*/
|
||||
private Double hybridAlpha;
|
||||
|
||||
}
|
||||
|
||||
@@ -20,6 +20,11 @@ import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
|
||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||
import dev.langchain4j.service.tool.ToolProvider;
|
||||
import dev.langchain4j.skills.shell.ShellSkills;
|
||||
import dev.langchain4j.rag.AugmentationRequest;
|
||||
import dev.langchain4j.rag.AugmentationResult;
|
||||
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
|
||||
import dev.langchain4j.rag.RetrievalAugmentor;
|
||||
import dev.langchain4j.rag.query.Metadata;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -55,6 +60,7 @@ import org.ruoyi.service.chat.IChatMessageService;
|
||||
import org.ruoyi.service.chat.impl.memory.PersistentChatMemoryStore;
|
||||
import org.ruoyi.service.knowledge.IKnowledgeInfoService;
|
||||
import org.ruoyi.service.retrieval.KnowledgeRetrievalService;
|
||||
import org.ruoyi.service.knowledge.retriever.CustomVectorRetriever;
|
||||
import org.ruoyi.service.vector.VectorStoreService;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
@@ -412,7 +418,6 @@ public class ChatServiceFacade implements IChatService {
|
||||
|
||||
/**
|
||||
* 构建上下文消息列表
|
||||
|
||||
* 消息顺序:历史消息 → 当前用户消息(确保 AI 正确理解对话上下文)
|
||||
*
|
||||
* @param chatRequest 聊天请求
|
||||
@@ -421,7 +426,41 @@ public class ChatServiceFacade implements IChatService {
|
||||
private List<ChatMessage> buildContextMessages(ChatRequest chatRequest) {
|
||||
List<ChatMessage> messages = new ArrayList<>();
|
||||
|
||||
// 从数据库查询历史对话消息(放在前面)
|
||||
// 1. 初始化当前用户消息
|
||||
UserMessage userMessage = UserMessage.userMessage(chatRequest.getContent());
|
||||
|
||||
// 2. 知识库检索增强 (RAG)
|
||||
if (chatRequest.getKnowledgeId() != null) {
|
||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(chatRequest.getKnowledgeId()));
|
||||
if (knowledgeInfoVo != null) {
|
||||
ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel());
|
||||
if (chatModel != null) {
|
||||
log.info("执行高级 RAG 流程: kid={}", chatRequest.getKnowledgeId());
|
||||
|
||||
// 构建自定义检索器
|
||||
CustomVectorRetriever retriever = new CustomVectorRetriever(
|
||||
knowledgeRetrievalService, knowledgeInfoVo, chatModel);
|
||||
|
||||
// 构建增强流水线
|
||||
RetrievalAugmentor augmentor = DefaultRetrievalAugmentor.builder()
|
||||
.contentRetriever(retriever)
|
||||
.build();
|
||||
|
||||
// 执行增强:编织上下文到 UserMessage
|
||||
Metadata metadata = Metadata.from(userMessage, chatRequest.getSessionId(), new ArrayList<>());
|
||||
AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata);
|
||||
AugmentationResult result = augmentor.augment(augmentationRequest);
|
||||
|
||||
ChatMessage augmented = result.chatMessage();
|
||||
if (augmented instanceof UserMessage) {
|
||||
userMessage = (UserMessage) augmented;
|
||||
log.debug("RAG 增强完成,UserMessage 已注入背景知识");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 从数据库查询历史对话消息(放在前面)
|
||||
if (chatRequest.getSessionId() != null) {
|
||||
MessageWindowChatMemory memory = createChatMemory(chatRequest.getSessionId());
|
||||
if (memory != null) {
|
||||
@@ -433,38 +472,7 @@ public class ChatServiceFacade implements IChatService {
|
||||
}
|
||||
}
|
||||
|
||||
// 从向量库查询相关历史消息(知识库内容作为上下文)
|
||||
if (chatRequest.getKnowledgeId() != null) {
|
||||
// 查询知识库信息
|
||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(chatRequest.getKnowledgeId()));
|
||||
if (knowledgeInfoVo == null) {
|
||||
log.warn("知识库信息不存在,kid: {}", chatRequest.getKnowledgeId());
|
||||
// 继续添加当前用户消息
|
||||
messages.add(UserMessage.userMessage(chatRequest.getContent()));
|
||||
return messages;
|
||||
}
|
||||
|
||||
// 查询向量模型配置信息
|
||||
ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel());
|
||||
if (chatModel == null) {
|
||||
log.warn("向量模型配置不存在,模型名称: {}", knowledgeInfoVo.getEmbeddingModel());
|
||||
messages.add(UserMessage.userMessage(chatRequest.getContent()));
|
||||
return messages;
|
||||
}
|
||||
|
||||
// 构建向量查询参数
|
||||
QueryVectorBo queryVectorBo = buildQueryVectorBo(chatRequest, knowledgeInfoVo, chatModel);
|
||||
|
||||
// 使用知识库检索服务(支持重排序)
|
||||
List<String> nearestList = knowledgeRetrievalService.retrieveTexts(queryVectorBo);
|
||||
for (String prompt : nearestList) {
|
||||
// 知识库内容作为系统上下文添加
|
||||
messages.add(new AiMessage(prompt));
|
||||
}
|
||||
}
|
||||
|
||||
// 构建当前用户消息(放在最后)
|
||||
UserMessage userMessage = UserMessage.userMessage(chatRequest.getContent());
|
||||
// 4. 添加经过增强的用户消息(放在最后)
|
||||
messages.add(userMessage);
|
||||
|
||||
return messages;
|
||||
|
||||
@@ -12,25 +12,18 @@ import org.ruoyi.common.core.utils.StringUtils;
|
||||
import org.ruoyi.common.mybatis.core.page.PageQuery;
|
||||
import org.ruoyi.common.mybatis.core.page.TableDataInfo;
|
||||
import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo;
|
||||
import org.ruoyi.domain.bo.rerank.RerankRequest;
|
||||
import org.ruoyi.domain.bo.rerank.RerankResult;
|
||||
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||
import org.ruoyi.domain.entity.knowledge.KnowledgeFragment;
|
||||
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
|
||||
import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo;
|
||||
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||
import org.ruoyi.factory.RerankModelFactory;
|
||||
import org.ruoyi.mapper.knowledge.KnowledgeFragmentMapper;
|
||||
import org.ruoyi.service.knowledge.IKnowledgeFragmentService;
|
||||
import org.ruoyi.service.knowledge.IKnowledgeInfoService;
|
||||
import org.ruoyi.service.rerank.RerankModelService;
|
||||
import org.ruoyi.service.vector.VectorStoreService;
|
||||
import org.ruoyi.service.retrieval.KnowledgeRetrievalService;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
|
||||
/**
|
||||
* 知识片段Service业务层处理
|
||||
@@ -46,8 +39,7 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
|
||||
private final KnowledgeFragmentMapper baseMapper;
|
||||
private final IKnowledgeInfoService knowledgeInfoService;
|
||||
private final IChatModelService chatModelService;
|
||||
private final VectorStoreService vectorStoreService;
|
||||
private final RerankModelFactory rerankModelFactory;
|
||||
private final KnowledgeRetrievalService knowledgeRetrievalService;
|
||||
|
||||
/**
|
||||
* 查询知识片段
|
||||
@@ -87,7 +79,6 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
|
||||
}
|
||||
|
||||
private LambdaQueryWrapper<KnowledgeFragment> buildQueryWrapper(KnowledgeFragmentBo bo) {
|
||||
Map<String, Object> params = bo.getParams();
|
||||
LambdaQueryWrapper<KnowledgeFragment> lqw = Wrappers.lambdaQuery();
|
||||
lqw.orderByAsc(KnowledgeFragment::getId);
|
||||
lqw.eq(bo.getDocId() != null, KnowledgeFragment::getDocId, bo.getDocId());
|
||||
@@ -149,7 +140,7 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
|
||||
}
|
||||
|
||||
/**
|
||||
* 检索测试核心实现
|
||||
* 检索测试核心实现 - 委托给统一的 KnowledgeRetrievalService
|
||||
*/
|
||||
@Override
|
||||
public List<KnowledgeRetrievalVo> retrieval(KnowledgeFragmentBo bo) {
|
||||
@@ -157,7 +148,7 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
|
||||
// 1. 获取知识库及模型配置
|
||||
// 1. 获取知识库及模型配置(为了获取 API Key/Host 等模型参数)
|
||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(bo.getKnowledgeId());
|
||||
if (knowledgeInfoVo == null) {
|
||||
return new ArrayList<>();
|
||||
@@ -169,151 +160,28 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
|
||||
// 2. 构造向量检索参数
|
||||
// 2. 构造通用的参数对象
|
||||
QueryVectorBo queryVectorBo = new QueryVectorBo();
|
||||
queryVectorBo.setQuery(bo.getQuery());
|
||||
queryVectorBo.setKid(String.valueOf(bo.getKnowledgeId()));
|
||||
queryVectorBo.setMaxResults(bo.getTopK() != null ? bo.getTopK() : knowledgeInfoVo.getRetrieveLimit());
|
||||
queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel());
|
||||
queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModel());
|
||||
queryVectorBo.setApiKey(chatModel.getApiKey());
|
||||
queryVectorBo.setBaseUrl(chatModel.getApiHost());
|
||||
queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel());
|
||||
queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModel());
|
||||
|
||||
// 3. 执行搜索 (向量搜索 + 关键词搜索)
|
||||
List<KnowledgeRetrievalVo> allResults;
|
||||
// 使用前端传入的实时测试参数,若无则使用知识库默认参数
|
||||
queryVectorBo.setMaxResults(bo.getTopK() != null ? bo.getTopK() : knowledgeInfoVo.getRetrieveLimit());
|
||||
queryVectorBo.setSimilarityThreshold(bo.getThreshold() != null ? bo.getThreshold() : knowledgeInfoVo.getSimilarityThreshold());
|
||||
|
||||
boolean hybridEnabled = Boolean.TRUE.equals(bo.getEnableHybrid()) ||
|
||||
Integer.valueOf(1).equals(knowledgeInfoVo.getEnableHybrid());
|
||||
queryVectorBo.setEnableHybrid(bo.getEnableHybrid() != null ? bo.getEnableHybrid() : Objects.equals(knowledgeInfoVo.getEnableHybrid(), 1));
|
||||
queryVectorBo.setHybridAlpha(bo.getHybridAlpha() != null ? bo.getHybridAlpha() : knowledgeInfoVo.getHybridAlpha());
|
||||
|
||||
if (hybridEnabled) {
|
||||
log.info("执行混合检索: kid={}, query={}", bo.getKnowledgeId(), bo.getQuery());
|
||||
try {
|
||||
// 并行执行向量搜索
|
||||
CompletableFuture<List<KnowledgeRetrievalVo>> vectorFuture = CompletableFuture.supplyAsync(() ->
|
||||
vectorStoreService.search(queryVectorBo));
|
||||
queryVectorBo.setEnableRerank(bo.getEnableRerank() != null ? bo.getEnableRerank() : Objects.equals(knowledgeInfoVo.getEnableRerank(), 1));
|
||||
queryVectorBo.setRerankModelName(StringUtils.isNotBlank(bo.getRerankModel()) ? bo.getRerankModel() : knowledgeInfoVo.getRerankModel());
|
||||
queryVectorBo.setRerankTopN(bo.getTopK() != null ? bo.getTopK() : knowledgeInfoVo.getRerankTopN());
|
||||
queryVectorBo.setRerankScoreThreshold(bo.getThreshold() != null ? bo.getThreshold() : knowledgeInfoVo.getRerankScoreThreshold());
|
||||
|
||||
// 执行关键词搜索 (MySQL)
|
||||
int limit = bo.getTopK() != null ? bo.getTopK() : 50;
|
||||
List<KnowledgeFragmentVo> keywordFragments = baseMapper.searchByKeyword(bo.getKnowledgeId(), bo.getQuery(), limit);
|
||||
List<KnowledgeRetrievalVo> keywordResults = keywordFragments.stream().map(f -> {
|
||||
KnowledgeRetrievalVo vo = new KnowledgeRetrievalVo();
|
||||
vo.setId(f.getId().toString());
|
||||
vo.setContent(f.getContent());
|
||||
vo.setDocId(f.getDocId());
|
||||
vo.setIdx(f.getIdx());
|
||||
vo.setKnowledgeId(f.getKnowledgeId());
|
||||
vo.setScore(10.0); // 初始分,后续由 RRF 重新打分
|
||||
return vo;
|
||||
}).collect(Collectors.toList());
|
||||
|
||||
List<KnowledgeRetrievalVo> vectorResults = vectorFuture.get();
|
||||
log.info("抽取混合结果成功: Vector命中={}条, Keyword命中={}条", vectorResults.size(), keywordResults.size());
|
||||
|
||||
double alpha = bo.getHybridAlpha() != null ? bo.getHybridAlpha() :
|
||||
(knowledgeInfoVo.getHybridAlpha() != null ? knowledgeInfoVo.getHybridAlpha() : 0.5);
|
||||
|
||||
allResults = calculateRRF(vectorResults, keywordResults, alpha);
|
||||
} catch (Exception e) {
|
||||
log.error("混合检索执行或合并失败,已自动降级回退到纯向量检索", e);
|
||||
allResults = vectorStoreService.search(queryVectorBo);
|
||||
}
|
||||
} else {
|
||||
allResults = vectorStoreService.search(queryVectorBo);
|
||||
}
|
||||
|
||||
// 初始化原始排名
|
||||
for (int i = 0; i < allResults.size(); i++) {
|
||||
allResults.get(i).setOriginalIndex(i);
|
||||
}
|
||||
|
||||
// 4. 执行重排逻辑 (如果请求启用重排且配置了重排模型)
|
||||
if (Boolean.TRUE.equals(bo.getEnableRerank()) && StringUtils.isNotBlank(bo.getRerankModel())) {
|
||||
log.info("开始重排精排,模型: [{}]", bo.getRerankModel());
|
||||
try {
|
||||
RerankModelService rerankModel = rerankModelFactory.createModel(bo.getRerankModel());
|
||||
|
||||
List<String> contents = allResults.stream()
|
||||
.map(KnowledgeRetrievalVo::getContent)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
RerankRequest rerankRequest = RerankRequest.builder()
|
||||
.query(bo.getQuery())
|
||||
.documents(contents)
|
||||
.topN(contents.size())
|
||||
.returnDocuments(false)
|
||||
.build();
|
||||
|
||||
RerankResult rerankResult = rerankModel.rerank(rerankRequest);
|
||||
|
||||
// 将重排分数写回,并记录原始分数供前端对比
|
||||
for (RerankResult.RerankDocument doc : rerankResult.getDocuments()) {
|
||||
if (doc.getIndex() != null && doc.getIndex() < allResults.size()) {
|
||||
KnowledgeRetrievalVo resultVo = allResults.get(doc.getIndex());
|
||||
resultVo.setRawScore(resultVo.getScore());
|
||||
resultVo.setScore(doc.getRelevanceScore());
|
||||
}
|
||||
}
|
||||
|
||||
// 按重排后的分数从高到低排序
|
||||
allResults.sort((a, b) -> b.getScore().compareTo(a.getScore()));
|
||||
log.info("重排精排完成,结果数: {}", allResults.size());
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("重排精排执行失败,已跳过重排步骤: {}", e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 根据阈值过滤
|
||||
double threshold = bo.getThreshold() != null ? bo.getThreshold() : 0.0;
|
||||
return allResults.stream()
|
||||
.filter(res -> res.getScore() >= threshold)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
/**
|
||||
* RRF (Reciprocal Rank Fusion) 融合算法
|
||||
* 公式: Score = (1-alpha) * (1 / (k + rank_vector)) + alpha * (1 / (k + rank_keyword))
|
||||
*/
|
||||
private List<KnowledgeRetrievalVo> calculateRRF(List<KnowledgeRetrievalVo> vectorList, List<KnowledgeRetrievalVo> keywordList, double alpha) {
|
||||
Map<String, KnowledgeRetrievalVo> allMap = new HashMap<>();
|
||||
Map<String, Double> vectorScores = new HashMap<>();
|
||||
Map<String, Double> keywordScores = new HashMap<>();
|
||||
|
||||
int k = 60; // 常用 RRF 常数
|
||||
|
||||
for (int i = 0; i < vectorList.size(); i++) {
|
||||
KnowledgeRetrievalVo vo = vectorList.get(i);
|
||||
allMap.put(vo.getId(), vo);
|
||||
vectorScores.put(vo.getId(), 1.0 / (k + i + 1));
|
||||
}
|
||||
|
||||
for (int i = 0; i < keywordList.size(); i++) {
|
||||
KnowledgeRetrievalVo vo = keywordList.get(i);
|
||||
if (!allMap.containsKey(vo.getId())) {
|
||||
allMap.put(vo.getId(), vo);
|
||||
}
|
||||
keywordScores.put(vo.getId(), 1.0 / (k + i + 1));
|
||||
}
|
||||
|
||||
// 重新计算得分
|
||||
List<KnowledgeRetrievalVo> fusedResults = new ArrayList<>();
|
||||
for (Map.Entry<String, KnowledgeRetrievalVo> entry : allMap.entrySet()) {
|
||||
String id = entry.getKey();
|
||||
double vScore = vectorScores.getOrDefault(id, 0.0);
|
||||
double kScore = keywordScores.getOrDefault(id, 0.0);
|
||||
|
||||
// 混合分值
|
||||
double finalScore = (1 - alpha) * vScore + alpha * kScore;
|
||||
|
||||
// 分值归一化/缩放:将 RRF 分值放大到 0-1 范围
|
||||
// 理论单路最大得分为 1/61 ≈ 0.016,乘以 60 使其处于相似度常用区间
|
||||
KnowledgeRetrievalVo vo = entry.getValue();
|
||||
vo.setScore(finalScore * 60.0);
|
||||
fusedResults.add(vo);
|
||||
}
|
||||
|
||||
// 按融合分数从高到低排序
|
||||
fusedResults.sort((a, b) -> b.getScore().compareTo(a.getScore()));
|
||||
return fusedResults;
|
||||
// 3. 执行统一检索
|
||||
return knowledgeRetrievalService.retrieve(queryVectorBo);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,14 +9,15 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo;
|
||||
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||
import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo;
|
||||
import org.ruoyi.service.vector.VectorStoreService;
|
||||
import org.ruoyi.service.retrieval.KnowledgeRetrievalService;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* 自定义向量检索器:适配 LangChain4j ContentRetriever 接口
|
||||
* 桥接现有的 VectorStoreService 获取检索结果
|
||||
* 自定义检索器:适配 LangChain4j ContentRetriever 接口
|
||||
* 桥接统一的 KnowledgeRetrievalService,支持配置化的混合检索、阈值过滤等功能
|
||||
*
|
||||
* @author RobustH
|
||||
*/
|
||||
@@ -24,15 +25,15 @@ import java.util.stream.Collectors;
|
||||
@RequiredArgsConstructor
|
||||
public class CustomVectorRetriever implements ContentRetriever {
|
||||
|
||||
private final VectorStoreService vectorStoreService;
|
||||
private final KnowledgeRetrievalService knowledgeRetrievalService;
|
||||
private final KnowledgeInfoVo knowledgeInfoVo;
|
||||
private final ChatModelVo chatModelVo;
|
||||
|
||||
@Override
|
||||
public List<Content> retrieve(Query query) {
|
||||
log.info("执行自定义向量检索,关键字: {}", query.text());
|
||||
log.info("执行自定义检索,关键字: {}", query.text());
|
||||
|
||||
// 构建内部查询参数
|
||||
// 构建增强后的查询参数
|
||||
QueryVectorBo queryVectorBo = new QueryVectorBo();
|
||||
queryVectorBo.setQuery(query.text());
|
||||
queryVectorBo.setKid(String.valueOf(knowledgeInfoVo.getId()));
|
||||
@@ -40,11 +41,21 @@ public class CustomVectorRetriever implements ContentRetriever {
|
||||
queryVectorBo.setBaseUrl(chatModelVo.getApiHost());
|
||||
queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModel());
|
||||
queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel());
|
||||
// 如果接入了重排,这里的 retrieveLimit 也就是 MaxResults 应当被放大,后续留给 Aggregator 截断
|
||||
queryVectorBo.setMaxResults(knowledgeInfoVo.getRetrieveLimit());
|
||||
|
||||
// 执行底层的多种向量库策略检索
|
||||
List<String> nearestList = vectorStoreService.getQueryVector(queryVectorBo);
|
||||
// 应用知识库配置参数
|
||||
queryVectorBo.setMaxResults(knowledgeInfoVo.getRetrieveLimit());
|
||||
queryVectorBo.setSimilarityThreshold(knowledgeInfoVo.getSimilarityThreshold());
|
||||
queryVectorBo.setEnableHybrid(Objects.equals(knowledgeInfoVo.getEnableHybrid(), 1));
|
||||
queryVectorBo.setHybridAlpha(knowledgeInfoVo.getHybridAlpha());
|
||||
|
||||
// 设置重排序参数 (如果 retriever 阶段也想做初步重排,可以在此设置)
|
||||
queryVectorBo.setEnableRerank(Objects.equals(knowledgeInfoVo.getEnableRerank(), 1));
|
||||
queryVectorBo.setRerankModelName(knowledgeInfoVo.getRerankModel());
|
||||
queryVectorBo.setRerankTopN(knowledgeInfoVo.getRerankTopN());
|
||||
queryVectorBo.setRerankScoreThreshold(knowledgeInfoVo.getRerankScoreThreshold());
|
||||
|
||||
// 通过统一服务执行检索
|
||||
List<String> nearestList = knowledgeRetrievalService.retrieveTexts(queryVectorBo);
|
||||
|
||||
// 将结果包装为标准的 Content 返回
|
||||
return nearestList.stream()
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
package org.ruoyi.service.retrieval;
|
||||
|
||||
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 知识库检索服务接口
|
||||
* 整合粗召回(向量检索)和重排序流程
|
||||
* 整合粗召回(向量检索/关键词检索)和重排序流程
|
||||
*
|
||||
* @author yang
|
||||
* @date 2026-04-19
|
||||
@@ -21,4 +22,13 @@ public interface KnowledgeRetrievalService {
|
||||
* @return 文本内容列表
|
||||
*/
|
||||
List<String> retrieveTexts(QueryVectorBo queryVectorBo);
|
||||
|
||||
/**
|
||||
* Runs a knowledge-base retrieval and returns detailed result objects (including score, document ID, etc.).
|
||||
* Supports hybrid retrieval and reranking.
|
||||
*
|
||||
* @param queryVectorBo query parameters
|
||||
* @return list of retrieval results
|
||||
*/
|
||||
List<KnowledgeRetrievalVo> retrieve(QueryVectorBo queryVectorBo);
|
||||
}
|
||||
|
||||
@@ -2,21 +2,26 @@ package org.ruoyi.service.retrieval.impl;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.common.core.utils.StringUtils;
|
||||
import org.ruoyi.domain.bo.rerank.RerankRequest;
|
||||
import org.ruoyi.domain.bo.rerank.RerankResult;
|
||||
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
|
||||
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||
import org.ruoyi.factory.RerankModelFactory;
|
||||
import org.ruoyi.mapper.knowledge.KnowledgeFragmentMapper;
|
||||
import org.ruoyi.service.rerank.RerankModelService;
|
||||
import org.ruoyi.service.retrieval.KnowledgeRetrievalService;
|
||||
import org.ruoyi.service.vector.VectorStoreService;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* 知识库检索服务实现
|
||||
* 整合粗召回(向量检索)和重排序流程
|
||||
* 整合粗召回(向量检索/关键词检索)、RRF融合和重排序流程
|
||||
*
|
||||
* @author yang
|
||||
* @date 2026-04-19
|
||||
@@ -28,6 +33,7 @@ public class KnowledgeRetrievalServiceImpl implements KnowledgeRetrievalService
|
||||
|
||||
private final VectorStoreService vectorStoreService;
|
||||
private final RerankModelFactory rerankModelFactory;
|
||||
private final KnowledgeFragmentMapper fragmentMapper;
|
||||
|
||||
/**
|
||||
* 粗召回默认扩大倍数
|
||||
@@ -37,99 +43,214 @@ public class KnowledgeRetrievalServiceImpl implements KnowledgeRetrievalService
|
||||
|
||||
@Override
|
||||
public List<String> retrieveTexts(QueryVectorBo queryVectorBo) {
|
||||
List<KnowledgeRetrievalVo> results = retrieve(queryVectorBo);
|
||||
return results.stream()
|
||||
.map(KnowledgeRetrievalVo::getContent)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<KnowledgeRetrievalVo> retrieve(QueryVectorBo queryVectorBo) {
|
||||
log.info("开始知识库检索, kid={}, query={}", queryVectorBo.getKid(), queryVectorBo.getQuery());
|
||||
|
||||
// 1. 粗召回阶段 - 向量检索
|
||||
List<String> coarseResults = coarseRetrieval(queryVectorBo);
|
||||
// 1. 粗召回阶段 (向量检索 + 关键词搜索)
|
||||
List<KnowledgeRetrievalVo> coarseResults = performCoarseRetrieval(queryVectorBo);
|
||||
log.debug("粗召回返回 {} 条结果", coarseResults.size());
|
||||
|
||||
if (coarseResults.isEmpty()) {
|
||||
return coarseResults;
|
||||
}
|
||||
|
||||
// 2. 重排序阶段(可选)
|
||||
if (Boolean.TRUE.equals(queryVectorBo.getEnableRerank()) &&
|
||||
queryVectorBo.getRerankModelName() != null) {
|
||||
return rerank(queryVectorBo, coarseResults);
|
||||
// 2. 初始化原始索引
|
||||
for (int i = 0; i < coarseResults.size(); i++) {
|
||||
coarseResults.get(i).setOriginalIndex(i);
|
||||
}
|
||||
|
||||
return coarseResults;
|
||||
// 3. 重排序阶段 (可选)
|
||||
List<KnowledgeRetrievalVo> finalResults = coarseResults;
|
||||
if (Boolean.TRUE.equals(queryVectorBo.getEnableRerank()) &&
|
||||
StringUtils.isNotBlank(queryVectorBo.getRerankModelName())) {
|
||||
finalResults = performRerank(queryVectorBo, coarseResults);
|
||||
}
|
||||
|
||||
// 4. 应用分值阈值过滤 (重排分值或 RRF 分值)
|
||||
double threshold = queryVectorBo.getRerankScoreThreshold() != null ?
|
||||
queryVectorBo.getRerankScoreThreshold() : 0.0;
|
||||
|
||||
return finalResults.stream()
|
||||
.filter(res -> res.getScore() >= threshold)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
/**
|
||||
* 粗召回阶段 - 向量检索
|
||||
* 粗召回阶段:根据配置执行向量搜索或混合搜索
|
||||
*/
|
||||
private List<String> coarseRetrieval(QueryVectorBo queryVectorBo) {
|
||||
// 如果启用重排序,扩大粗召回数量
|
||||
int originalMaxResults = queryVectorBo.getMaxResults();
|
||||
int expandedResults = originalMaxResults;
|
||||
private List<KnowledgeRetrievalVo> performCoarseRetrieval(QueryVectorBo queryVectorBo) {
|
||||
// 如果启用重排序,适当扩大召回数量
|
||||
int originalMaxResults = queryVectorBo.getMaxResults() != null ? queryVectorBo.getMaxResults() : 10;
|
||||
int targetMaxResults = originalMaxResults;
|
||||
if (Boolean.TRUE.equals(queryVectorBo.getEnableRerank()) &&
|
||||
queryVectorBo.getRerankModelName() != null) {
|
||||
expandedResults = originalMaxResults * RERANK_EXPANSION_FACTOR;
|
||||
log.debug("启用重排序,粗召回数量从 {} 扩大到 {}", originalMaxResults, expandedResults);
|
||||
StringUtils.isNotBlank(queryVectorBo.getRerankModelName())) {
|
||||
targetMaxResults = originalMaxResults * RERANK_EXPANSION_FACTOR;
|
||||
}
|
||||
|
||||
// 临时修改查询数量
|
||||
queryVectorBo.setMaxResults(expandedResults);
|
||||
// 如果未启用混合检索,直接走向量搜索
|
||||
if (!Boolean.TRUE.equals(queryVectorBo.getEnableHybrid())) {
|
||||
QueryVectorBo vectorQuery = copyOf(queryVectorBo, targetMaxResults);
|
||||
List<KnowledgeRetrievalVo> results = vectorStoreService.search(vectorQuery);
|
||||
|
||||
// 应用基础相似度阈值过滤(如果有)
|
||||
if (queryVectorBo.getSimilarityThreshold() != null) {
|
||||
results = results.stream()
|
||||
.filter(r -> r.getScore() >= queryVectorBo.getSimilarityThreshold())
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
// 混合检索逻辑
|
||||
log.info("执行混合检索: kid={}, query={}", queryVectorBo.getKid(), queryVectorBo.getQuery());
|
||||
try {
|
||||
return vectorStoreService.getQueryVector(queryVectorBo);
|
||||
} finally {
|
||||
// 恢复原始值
|
||||
queryVectorBo.setMaxResults(originalMaxResults);
|
||||
// A. 并行执行向量搜索
|
||||
int finalTargetMaxResults = targetMaxResults;
|
||||
CompletableFuture<List<KnowledgeRetrievalVo>> vectorFuture = CompletableFuture.supplyAsync(() -> {
|
||||
QueryVectorBo vectorQuery = copyOf(queryVectorBo, finalTargetMaxResults);
|
||||
List<KnowledgeRetrievalVo> results = vectorStoreService.search(vectorQuery);
|
||||
// 向量层初步过滤
|
||||
if (queryVectorBo.getSimilarityThreshold() != null) {
|
||||
return results.stream()
|
||||
.filter(r -> r.getScore() >= queryVectorBo.getSimilarityThreshold())
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
return results;
|
||||
});
|
||||
|
||||
// B. 并行执行关键词搜索 (MySQL Fulltext)
|
||||
CompletableFuture<List<KnowledgeRetrievalVo>> keywordFuture = CompletableFuture.supplyAsync(() -> {
|
||||
try {
|
||||
Long kid = Long.valueOf(queryVectorBo.getKid());
|
||||
List<KnowledgeFragmentVo> fragments = fragmentMapper.searchByKeyword(kid, queryVectorBo.getQuery(), finalTargetMaxResults);
|
||||
return fragments.stream().map(f -> {
|
||||
KnowledgeRetrievalVo vo = new KnowledgeRetrievalVo();
|
||||
vo.setId(f.getId().toString());
|
||||
vo.setContent(f.getContent());
|
||||
vo.setDocId(f.getDocId());
|
||||
vo.setIdx(f.getIdx());
|
||||
vo.setKnowledgeId(f.getKnowledgeId());
|
||||
vo.setScore(10.0); // RRF 初始占位分
|
||||
return vo;
|
||||
}).collect(Collectors.toList());
|
||||
} catch (Exception e) {
|
||||
log.error("关键词检索失败: {}", e.getMessage());
|
||||
return new ArrayList<>();
|
||||
}
|
||||
});
|
||||
|
||||
List<KnowledgeRetrievalVo> vectorResults = vectorFuture.get();
|
||||
List<KnowledgeRetrievalVo> keywordResults = keywordFuture.get();
|
||||
|
||||
// C. RRF 融合
|
||||
double alpha = queryVectorBo.getHybridAlpha() != null ? queryVectorBo.getHybridAlpha() : 0.5;
|
||||
return calculateRRF(vectorResults, keywordResults, alpha);
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("混合检索执行失败,回退到纯向量检索: {}", e.getMessage(), e);
|
||||
return vectorStoreService.search(copyOf(queryVectorBo, targetMaxResults));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 重排序阶段
|
||||
*/
|
||||
private List<String> rerank(QueryVectorBo queryVectorBo, List<String> coarseResults) {
|
||||
long startTime = System.currentTimeMillis();
|
||||
|
||||
private List<KnowledgeRetrievalVo> performRerank(QueryVectorBo queryVectorBo, List<KnowledgeRetrievalVo> coarseResults) {
|
||||
try {
|
||||
// 1. 通过工厂获取重排序模型
|
||||
RerankModelService rerankModel = rerankModelFactory.createModel(queryVectorBo.getRerankModelName());
|
||||
|
||||
// 2. 构建重排序请求
|
||||
int topN = queryVectorBo.getRerankTopN() != null ?
|
||||
queryVectorBo.getRerankTopN() : queryVectorBo.getMaxResults();
|
||||
List<String> contents = coarseResults.stream()
|
||||
.map(KnowledgeRetrievalVo::getContent)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
// topN 默认为 maxResults
|
||||
int topN = queryVectorBo.getRerankTopN() != null ? queryVectorBo.getRerankTopN() : queryVectorBo.getMaxResults();
|
||||
|
||||
RerankRequest rerankRequest = RerankRequest.builder()
|
||||
.query(queryVectorBo.getQuery())
|
||||
.documents(coarseResults)
|
||||
.documents(contents)
|
||||
.topN(topN)
|
||||
.returnDocuments(true)
|
||||
.build();
|
||||
|
||||
log.info("执行重排序, model={}, documents={}, topN={}",
|
||||
queryVectorBo.getRerankModelName(), coarseResults.size(), topN);
|
||||
|
||||
// 3. 执行重排序
|
||||
RerankResult rerankResult = rerankModel.rerank(rerankRequest);
|
||||
|
||||
// 4. 转换重排序结果
|
||||
List<String> finalResults = new ArrayList<>();
|
||||
// 写回分数并记录原始分
|
||||
for (RerankResult.RerankDocument doc : rerankResult.getDocuments()) {
|
||||
// 应用分数阈值过滤
|
||||
if (queryVectorBo.getRerankScoreThreshold() != null &&
|
||||
doc.getRelevanceScore() < queryVectorBo.getRerankScoreThreshold()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (doc.getDocument() != null) {
|
||||
finalResults.add(doc.getDocument());
|
||||
if (doc.getIndex() != null && doc.getIndex() < coarseResults.size()) {
|
||||
KnowledgeRetrievalVo vo = coarseResults.get(doc.getIndex());
|
||||
vo.setRawScore(vo.getScore());
|
||||
vo.setScore(doc.getRelevanceScore());
|
||||
}
|
||||
}
|
||||
|
||||
long duration = System.currentTimeMillis() - startTime;
|
||||
log.info("重排序完成, 返回 {} 条结果, 耗时 {}ms", finalResults.size(), duration);
|
||||
// 按新分排序
|
||||
coarseResults.sort((a, b) -> b.getScore().compareTo(a.getScore()));
|
||||
|
||||
return finalResults;
|
||||
// 截断到 topN
|
||||
return coarseResults.subList(0, Math.min(topN, coarseResults.size()));
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("重排序失败: {}", e.getMessage(), e);
|
||||
// 重排序失败时返回原始粗召回结果(截取到期望数量)
|
||||
int limit = Math.min(queryVectorBo.getMaxResults(), coarseResults.size());
|
||||
return new ArrayList<>(coarseResults.subList(0, limit));
|
||||
log.error("重排序流程失败: {}", e.getMessage());
|
||||
int limit = queryVectorBo.getMaxResults() != null ? queryVectorBo.getMaxResults() : 10;
|
||||
return coarseResults.subList(0, Math.min(limit, coarseResults.size()));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* RRF (Reciprocal Rank Fusion) 融合计算
|
||||
*/
|
||||
private List<KnowledgeRetrievalVo> calculateRRF(List<KnowledgeRetrievalVo> vectorList, List<KnowledgeRetrievalVo> keywordList, double alpha) {
|
||||
Map<String, KnowledgeRetrievalVo> allMap = new LinkedHashMap<>();
|
||||
Map<String, Double> vectorScores = new HashMap<>();
|
||||
Map<String, Double> keywordScores = new HashMap<>();
|
||||
|
||||
int k = 60; // RRF 常数
|
||||
|
||||
for (int i = 0; i < vectorList.size(); i++) {
|
||||
KnowledgeRetrievalVo vo = vectorList.get(i);
|
||||
allMap.put(vo.getId(), vo);
|
||||
vectorScores.put(vo.getId(), 1.0 / (k + i + 1));
|
||||
}
|
||||
|
||||
for (int i = 0; i < keywordList.size(); i++) {
|
||||
KnowledgeRetrievalVo vo = keywordList.get(i);
|
||||
if (!allMap.containsKey(vo.getId())) {
|
||||
allMap.put(vo.getId(), vo);
|
||||
}
|
||||
keywordScores.put(vo.getId(), 1.0 / (k + i + 1));
|
||||
}
|
||||
|
||||
List<KnowledgeRetrievalVo> fusedResults = new ArrayList<>();
|
||||
for (Map.Entry<String, KnowledgeRetrievalVo> entry : allMap.entrySet()) {
|
||||
String id = entry.getKey();
|
||||
double finalScore = (1 - alpha) * vectorScores.getOrDefault(id, 0.0) +
|
||||
alpha * keywordScores.getOrDefault(id, 0.0);
|
||||
|
||||
KnowledgeRetrievalVo vo = entry.getValue();
|
||||
vo.setScore(finalScore * 60.0); // 归一化缩放
|
||||
fusedResults.add(vo);
|
||||
}
|
||||
|
||||
fusedResults.sort((a, b) -> b.getScore().compareTo(a.getScore()));
|
||||
return fusedResults;
|
||||
}
|
||||
|
||||
private QueryVectorBo copyOf(QueryVectorBo original, int maxResults) {
|
||||
QueryVectorBo copy = new QueryVectorBo();
|
||||
copy.setQuery(original.getQuery());
|
||||
copy.setKid(original.getKid());
|
||||
copy.setMaxResults(maxResults);
|
||||
copy.setVectorModelName(original.getVectorModelName());
|
||||
copy.setEmbeddingModelName(original.getEmbeddingModelName());
|
||||
copy.setApiKey(original.getApiKey());
|
||||
copy.setBaseUrl(original.getBaseUrl());
|
||||
return copy;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user