feat(rag): 知识库检索测试新增混合检索

This commit is contained in:
RobustH
2026-04-14 23:18:29 +08:00
parent 1208c46cca
commit ccbf5c9520
10 changed files with 178 additions and 3 deletions

View File

@@ -79,4 +79,14 @@ public class KnowledgeFragmentBo extends BaseEntity {
*/ */
private String rerankModel; private String rerankModel;
/**
* 是否启用混合检索
*/
private Boolean enableHybrid;
/**
* 混合检索权重 (0.0-1.0)
*/
private Double hybridAlpha;
} }

View File

@@ -92,5 +92,15 @@ public class KnowledgeInfoBo extends BaseEntity {
*/ */
private String remark; private String remark;
/**
* 是否启用混合检索0 否 1 是)
*/
private Integer enableHybrid;
/**
* 混合检索权重比例 (0.0-1.0)
*/
private Double hybridAlpha;
} }

View File

@@ -47,5 +47,10 @@ public class KnowledgeFragment extends BaseEntity {
*/ */
private String remark; private String remark;
/**
* 知识库ID
*/
private Long knowledgeId;
} }

View File

@@ -93,5 +93,15 @@ public class KnowledgeInfo extends BaseEntity {
*/ */
private String remark; private String remark;
/**
* 是否启用混合检索0 否 1 是)
*/
private Integer enableHybrid;
/**
* 混合检索权重比例 (0.0-1.0)
*/
private Double hybridAlpha;
} }

View File

@@ -39,7 +39,7 @@ public class KnowledgeFragmentVo implements Serializable {
* 片段索引下标 * 片段索引下标
*/ */
@ExcelProperty(value = "片段索引下标") @ExcelProperty(value = "片段索引下标")
private Long idx; private Integer idx;
/** /**
* 文档内容 * 文档内容
@@ -53,5 +53,10 @@ public class KnowledgeFragmentVo implements Serializable {
@ExcelProperty(value = "备注") @ExcelProperty(value = "备注")
private String remark; private String remark;
/**
* 知识库ID
*/
private Long knowledgeId;
} }

View File

@@ -113,6 +113,19 @@ public class KnowledgeInfoVo implements Serializable {
@ExcelProperty(value = "备注") @ExcelProperty(value = "备注")
private String remark; private String remark;
/**
* 是否启用混合检索0 否 1 是)
*/
@ExcelProperty(value = "是否启用混合检索", converter = ExcelDictConvert.class)
@ExcelDictFormat(readConverterExp = "0=否,1=是")
private Integer enableHybrid;
/**
* 混合检索权重比例 (0.0-1.0)
*/
@ExcelProperty(value = "混合检索权重比例")
private Double hybridAlpha;
/** /**
* 文档数(统计字段,非数据库列) * 文档数(统计字段,非数据库列)
*/ */

View File

@@ -1,7 +1,9 @@
package org.ruoyi.domain.vo.knowledge; package org.ruoyi.domain.vo.knowledge;
import lombok.AllArgsConstructor;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor;
import java.io.Serial; import java.io.Serial;
import java.io.Serializable; import java.io.Serializable;
@@ -13,11 +15,33 @@ import java.io.Serializable;
*/ */
@Data @Data
@Builder @Builder
@NoArgsConstructor
@AllArgsConstructor
public class KnowledgeRetrievalVo implements Serializable { public class KnowledgeRetrievalVo implements Serializable {
@Serial @Serial
private static final long serialVersionUID = 1L; private static final long serialVersionUID = 1L;
/**
* 片段ID
*/
private String id;
/**
* 文档ID
*/
private String docId;
/**
* 知识库ID
*/
private Long knowledgeId;
/**
* 分片索引
*/
private Integer idx;
/** /**
* 片段内容 * 片段内容
*/ */

View File

@@ -33,4 +33,13 @@ public interface KnowledgeFragmentMapper extends BaseMapperPlus<KnowledgeFragmen
"GROUP BY doc_id" + "GROUP BY doc_id" +
"</script>") "</script>")
List<DocFragmentCountVo> selectFragmentCountByDocIds(@Param("docIds") List<String> docIds); List<DocFragmentCountVo> selectFragmentCountByDocIds(@Param("docIds") List<String> docIds);
@Select("<script>" +
"SELECT id, doc_id AS docId, content, idx, knowledge_id AS knowledgeId " +
"FROM knowledge_fragment " +
"WHERE knowledge_id = #{knowledgeId} " +
"AND MATCH (content) AGAINST (#{query} IN NATURAL LANGUAGE MODE) " +
"ORDER BY MATCH (content) AGAINST (#{query} IN NATURAL LANGUAGE MODE) DESC " +
"LIMIT #{limit}" +
"</script>")
List<KnowledgeFragmentVo> searchByKeyword(@Param("knowledgeId") Long knowledgeId, @Param("query") String query, @Param("limit") Integer limit);
} }

View File

@@ -187,6 +187,7 @@ public class KnowledgeAttachServiceImpl implements IKnowledgeAttachService {
String fid = RandomUtil.randomString(10); String fid = RandomUtil.randomString(10);
fids.add(fid); fids.add(fid);
KnowledgeFragment knowledgeFragment = new KnowledgeFragment(); KnowledgeFragment knowledgeFragment = new KnowledgeFragment();
knowledgeFragment.setKnowledgeId(knowledgeId);
knowledgeFragment.setDocId(docId); knowledgeFragment.setDocId(docId);
knowledgeFragment.setIdx(i); knowledgeFragment.setIdx(i);
knowledgeFragment.setContent(chunkList.get(i)); knowledgeFragment.setContent(chunkList.get(i));

View File

@@ -28,6 +28,8 @@ import org.ruoyi.service.vector.VectorStoreService;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@@ -180,8 +182,47 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
queryVectorBo.setApiKey(chatModel.getApiKey()); queryVectorBo.setApiKey(chatModel.getApiKey());
queryVectorBo.setBaseUrl(chatModel.getApiHost()); queryVectorBo.setBaseUrl(chatModel.getApiHost());
// 3. 执行物理检索 // 3. 执行搜索 (向量搜索 + 关键词搜索)
List<KnowledgeRetrievalVo> allResults = vectorStoreService.search(queryVectorBo); List<KnowledgeRetrievalVo> allResults;
boolean hybridEnabled = Boolean.TRUE.equals(bo.getEnableHybrid()) ||
Integer.valueOf(1).equals(knowledgeInfoVo.getEnableHybrid());
if (hybridEnabled) {
log.info("执行混合检索: kid={}, query={}", bo.getKnowledgeId(), bo.getQuery());
try {
// 并行执行向量搜索
CompletableFuture<List<KnowledgeRetrievalVo>> vectorFuture = CompletableFuture.supplyAsync(() ->
vectorStoreService.search(queryVectorBo));
// 执行关键词搜索 (MySQL)
int limit = bo.getTopK() != null ? bo.getTopK() : 50;
List<KnowledgeFragmentVo> keywordFragments = baseMapper.searchByKeyword(bo.getKnowledgeId(), bo.getQuery(), limit);
List<KnowledgeRetrievalVo> keywordResults = keywordFragments.stream().map(f -> {
KnowledgeRetrievalVo vo = new KnowledgeRetrievalVo();
vo.setId(f.getId().toString());
vo.setContent(f.getContent());
vo.setDocId(f.getDocId());
vo.setIdx(f.getIdx());
vo.setKnowledgeId(f.getKnowledgeId());
vo.setScore(10.0); // 初始分,后续由 RRF 重新打分
return vo;
}).collect(Collectors.toList());
List<KnowledgeRetrievalVo> vectorResults = vectorFuture.get();
log.info("抽取混合结果成功: Vector命中={}条, Keyword命中={}条", vectorResults.size(), keywordResults.size());
double alpha = bo.getHybridAlpha() != null ? bo.getHybridAlpha() :
(knowledgeInfoVo.getHybridAlpha() != null ? knowledgeInfoVo.getHybridAlpha() : 0.5);
allResults = calculateRRF(vectorResults, keywordResults, alpha);
} catch (Exception e) {
log.error("混合检索执行或合并失败,已自动降级回退到纯向量检索", e);
allResults = vectorStoreService.search(queryVectorBo);
}
} else {
allResults = vectorStoreService.search(queryVectorBo);
}
// 初始化原始排名 // 初始化原始排名
for (int i = 0; i < allResults.size(); i++) { for (int i = 0; i < allResults.size(); i++) {
@@ -230,4 +271,51 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
.filter(res -> res.getScore() >= threshold) .filter(res -> res.getScore() >= threshold)
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
/**
* RRF (Reciprocal Rank Fusion) 融合算法
* 公式: Score = (1-alpha) * (1 / (k + rank_vector)) + alpha * (1 / (k + rank_keyword))
*/
private List<KnowledgeRetrievalVo> calculateRRF(List<KnowledgeRetrievalVo> vectorList, List<KnowledgeRetrievalVo> keywordList, double alpha) {
Map<String, KnowledgeRetrievalVo> allMap = new HashMap<>();
Map<String, Double> vectorScores = new HashMap<>();
Map<String, Double> keywordScores = new HashMap<>();
int k = 60; // 常用 RRF 常数
for (int i = 0; i < vectorList.size(); i++) {
KnowledgeRetrievalVo vo = vectorList.get(i);
allMap.put(vo.getId(), vo);
vectorScores.put(vo.getId(), 1.0 / (k + i + 1));
}
for (int i = 0; i < keywordList.size(); i++) {
KnowledgeRetrievalVo vo = keywordList.get(i);
if (!allMap.containsKey(vo.getId())) {
allMap.put(vo.getId(), vo);
}
keywordScores.put(vo.getId(), 1.0 / (k + i + 1));
}
// 重新计算得分
List<KnowledgeRetrievalVo> fusedResults = new ArrayList<>();
for (Map.Entry<String, KnowledgeRetrievalVo> entry : allMap.entrySet()) {
String id = entry.getKey();
double vScore = vectorScores.getOrDefault(id, 0.0);
double kScore = keywordScores.getOrDefault(id, 0.0);
// 混合分值
double finalScore = (1 - alpha) * vScore + alpha * kScore;
// 分值归一化/缩放:将 RRF 分值放大到 0-1 范围
// 理论单路最大得分为 1/61 ≈ 0.016,乘以 60 使其处于相似度常用区间
KnowledgeRetrievalVo vo = entry.getValue();
vo.setScore(finalScore * 60.0);
fusedResults.add(vo);
}
// 按融合分数从高到低排序
fusedResults.sort((a, b) -> b.getScore().compareTo(a.getScore()));
return fusedResults;
}
} }