mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-04-23 08:43:40 +00:00
feat(rag): 知识库检索测试新增混合检索
This commit is contained in:
@@ -79,4 +79,14 @@ public class KnowledgeFragmentBo extends BaseEntity {
|
|||||||
*/
|
*/
|
||||||
private String rerankModel;
|
private String rerankModel;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 是否启用混合检索
|
||||||
|
*/
|
||||||
|
private Boolean enableHybrid;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 混合检索权重 (0.0-1.0)
|
||||||
|
*/
|
||||||
|
private Double hybridAlpha;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -92,5 +92,15 @@ public class KnowledgeInfoBo extends BaseEntity {
|
|||||||
*/
|
*/
|
||||||
private String remark;
|
private String remark;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 是否启用混合检索(0 否 1 是)
|
||||||
|
*/
|
||||||
|
private Integer enableHybrid;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 混合检索权重比例 (0.0-1.0)
|
||||||
|
*/
|
||||||
|
private Double hybridAlpha;
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -47,5 +47,10 @@ public class KnowledgeFragment extends BaseEntity {
|
|||||||
*/
|
*/
|
||||||
private String remark;
|
private String remark;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 知识库ID
|
||||||
|
*/
|
||||||
|
private Long knowledgeId;
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -93,5 +93,15 @@ public class KnowledgeInfo extends BaseEntity {
|
|||||||
*/
|
*/
|
||||||
private String remark;
|
private String remark;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 是否启用混合检索(0 否 1 是)
|
||||||
|
*/
|
||||||
|
private Integer enableHybrid;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 混合检索权重比例 (0.0-1.0)
|
||||||
|
*/
|
||||||
|
private Double hybridAlpha;
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ public class KnowledgeFragmentVo implements Serializable {
|
|||||||
* 片段索引下标
|
* 片段索引下标
|
||||||
*/
|
*/
|
||||||
@ExcelProperty(value = "片段索引下标")
|
@ExcelProperty(value = "片段索引下标")
|
||||||
private Long idx;
|
private Integer idx;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 文档内容
|
* 文档内容
|
||||||
@@ -53,5 +53,10 @@ public class KnowledgeFragmentVo implements Serializable {
|
|||||||
@ExcelProperty(value = "备注")
|
@ExcelProperty(value = "备注")
|
||||||
private String remark;
|
private String remark;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 知识库ID
|
||||||
|
*/
|
||||||
|
private Long knowledgeId;
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -113,6 +113,19 @@ public class KnowledgeInfoVo implements Serializable {
|
|||||||
@ExcelProperty(value = "备注")
|
@ExcelProperty(value = "备注")
|
||||||
private String remark;
|
private String remark;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 是否启用混合检索(0 否 1 是)
|
||||||
|
*/
|
||||||
|
@ExcelProperty(value = "是否启用混合检索", converter = ExcelDictConvert.class)
|
||||||
|
@ExcelDictFormat(readConverterExp = "0=否,1=是")
|
||||||
|
private Integer enableHybrid;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 混合检索权重比例 (0.0-1.0)
|
||||||
|
*/
|
||||||
|
@ExcelProperty(value = "混合检索权重比例")
|
||||||
|
private Double hybridAlpha;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 文档数(统计字段,非数据库列)
|
* 文档数(统计字段,非数据库列)
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
package org.ruoyi.domain.vo.knowledge;
|
package org.ruoyi.domain.vo.knowledge;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
import java.io.Serial;
|
import java.io.Serial;
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
@@ -13,11 +15,33 @@ import java.io.Serializable;
|
|||||||
*/
|
*/
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
public class KnowledgeRetrievalVo implements Serializable {
|
public class KnowledgeRetrievalVo implements Serializable {
|
||||||
|
|
||||||
@Serial
|
@Serial
|
||||||
private static final long serialVersionUID = 1L;
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 片段ID
|
||||||
|
*/
|
||||||
|
private String id;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 文档ID
|
||||||
|
*/
|
||||||
|
private String docId;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 知识库ID
|
||||||
|
*/
|
||||||
|
private Long knowledgeId;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 分片索引
|
||||||
|
*/
|
||||||
|
private Integer idx;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 片段内容
|
* 片段内容
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -33,4 +33,13 @@ public interface KnowledgeFragmentMapper extends BaseMapperPlus<KnowledgeFragmen
|
|||||||
"GROUP BY doc_id" +
|
"GROUP BY doc_id" +
|
||||||
"</script>")
|
"</script>")
|
||||||
List<DocFragmentCountVo> selectFragmentCountByDocIds(@Param("docIds") List<String> docIds);
|
List<DocFragmentCountVo> selectFragmentCountByDocIds(@Param("docIds") List<String> docIds);
|
||||||
|
@Select("<script>" +
|
||||||
|
"SELECT id, doc_id AS docId, content, idx, knowledge_id AS knowledgeId " +
|
||||||
|
"FROM knowledge_fragment " +
|
||||||
|
"WHERE knowledge_id = #{knowledgeId} " +
|
||||||
|
"AND MATCH (content) AGAINST (#{query} IN NATURAL LANGUAGE MODE) " +
|
||||||
|
"ORDER BY MATCH (content) AGAINST (#{query} IN NATURAL LANGUAGE MODE) DESC " +
|
||||||
|
"LIMIT #{limit}" +
|
||||||
|
"</script>")
|
||||||
|
List<KnowledgeFragmentVo> searchByKeyword(@Param("knowledgeId") Long knowledgeId, @Param("query") String query, @Param("limit") Integer limit);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -187,6 +187,7 @@ public class KnowledgeAttachServiceImpl implements IKnowledgeAttachService {
|
|||||||
String fid = RandomUtil.randomString(10);
|
String fid = RandomUtil.randomString(10);
|
||||||
fids.add(fid);
|
fids.add(fid);
|
||||||
KnowledgeFragment knowledgeFragment = new KnowledgeFragment();
|
KnowledgeFragment knowledgeFragment = new KnowledgeFragment();
|
||||||
|
knowledgeFragment.setKnowledgeId(knowledgeId);
|
||||||
knowledgeFragment.setDocId(docId);
|
knowledgeFragment.setDocId(docId);
|
||||||
knowledgeFragment.setIdx(i);
|
knowledgeFragment.setIdx(i);
|
||||||
knowledgeFragment.setContent(chunkList.get(i));
|
knowledgeFragment.setContent(chunkList.get(i));
|
||||||
|
|||||||
@@ -28,6 +28,8 @@ import org.ruoyi.service.vector.VectorStoreService;
|
|||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
@@ -180,8 +182,47 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
|
|||||||
queryVectorBo.setApiKey(chatModel.getApiKey());
|
queryVectorBo.setApiKey(chatModel.getApiKey());
|
||||||
queryVectorBo.setBaseUrl(chatModel.getApiHost());
|
queryVectorBo.setBaseUrl(chatModel.getApiHost());
|
||||||
|
|
||||||
// 3. 执行物理检索
|
// 3. 执行搜索 (向量搜索 + 关键词搜索)
|
||||||
List<KnowledgeRetrievalVo> allResults = vectorStoreService.search(queryVectorBo);
|
List<KnowledgeRetrievalVo> allResults;
|
||||||
|
|
||||||
|
boolean hybridEnabled = Boolean.TRUE.equals(bo.getEnableHybrid()) ||
|
||||||
|
Integer.valueOf(1).equals(knowledgeInfoVo.getEnableHybrid());
|
||||||
|
|
||||||
|
if (hybridEnabled) {
|
||||||
|
log.info("执行混合检索: kid={}, query={}", bo.getKnowledgeId(), bo.getQuery());
|
||||||
|
try {
|
||||||
|
// 并行执行向量搜索
|
||||||
|
CompletableFuture<List<KnowledgeRetrievalVo>> vectorFuture = CompletableFuture.supplyAsync(() ->
|
||||||
|
vectorStoreService.search(queryVectorBo));
|
||||||
|
|
||||||
|
// 执行关键词搜索 (MySQL)
|
||||||
|
int limit = bo.getTopK() != null ? bo.getTopK() : 50;
|
||||||
|
List<KnowledgeFragmentVo> keywordFragments = baseMapper.searchByKeyword(bo.getKnowledgeId(), bo.getQuery(), limit);
|
||||||
|
List<KnowledgeRetrievalVo> keywordResults = keywordFragments.stream().map(f -> {
|
||||||
|
KnowledgeRetrievalVo vo = new KnowledgeRetrievalVo();
|
||||||
|
vo.setId(f.getId().toString());
|
||||||
|
vo.setContent(f.getContent());
|
||||||
|
vo.setDocId(f.getDocId());
|
||||||
|
vo.setIdx(f.getIdx());
|
||||||
|
vo.setKnowledgeId(f.getKnowledgeId());
|
||||||
|
vo.setScore(10.0); // 初始分,后续由 RRF 重新打分
|
||||||
|
return vo;
|
||||||
|
}).collect(Collectors.toList());
|
||||||
|
|
||||||
|
List<KnowledgeRetrievalVo> vectorResults = vectorFuture.get();
|
||||||
|
log.info("抽取混合结果成功: Vector命中={}条, Keyword命中={}条", vectorResults.size(), keywordResults.size());
|
||||||
|
|
||||||
|
double alpha = bo.getHybridAlpha() != null ? bo.getHybridAlpha() :
|
||||||
|
(knowledgeInfoVo.getHybridAlpha() != null ? knowledgeInfoVo.getHybridAlpha() : 0.5);
|
||||||
|
|
||||||
|
allResults = calculateRRF(vectorResults, keywordResults, alpha);
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("混合检索执行或合并失败,已自动降级回退到纯向量检索", e);
|
||||||
|
allResults = vectorStoreService.search(queryVectorBo);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
allResults = vectorStoreService.search(queryVectorBo);
|
||||||
|
}
|
||||||
|
|
||||||
// 初始化原始排名
|
// 初始化原始排名
|
||||||
for (int i = 0; i < allResults.size(); i++) {
|
for (int i = 0; i < allResults.size(); i++) {
|
||||||
@@ -230,4 +271,51 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
|
|||||||
.filter(res -> res.getScore() >= threshold)
|
.filter(res -> res.getScore() >= threshold)
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* RRF (Reciprocal Rank Fusion) 融合算法
|
||||||
|
* 公式: Score = (1-alpha) * (1 / (k + rank_vector)) + alpha * (1 / (k + rank_keyword))
|
||||||
|
*/
|
||||||
|
private List<KnowledgeRetrievalVo> calculateRRF(List<KnowledgeRetrievalVo> vectorList, List<KnowledgeRetrievalVo> keywordList, double alpha) {
|
||||||
|
Map<String, KnowledgeRetrievalVo> allMap = new HashMap<>();
|
||||||
|
Map<String, Double> vectorScores = new HashMap<>();
|
||||||
|
Map<String, Double> keywordScores = new HashMap<>();
|
||||||
|
|
||||||
|
int k = 60; // 常用 RRF 常数
|
||||||
|
|
||||||
|
for (int i = 0; i < vectorList.size(); i++) {
|
||||||
|
KnowledgeRetrievalVo vo = vectorList.get(i);
|
||||||
|
allMap.put(vo.getId(), vo);
|
||||||
|
vectorScores.put(vo.getId(), 1.0 / (k + i + 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < keywordList.size(); i++) {
|
||||||
|
KnowledgeRetrievalVo vo = keywordList.get(i);
|
||||||
|
if (!allMap.containsKey(vo.getId())) {
|
||||||
|
allMap.put(vo.getId(), vo);
|
||||||
|
}
|
||||||
|
keywordScores.put(vo.getId(), 1.0 / (k + i + 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 重新计算得分
|
||||||
|
List<KnowledgeRetrievalVo> fusedResults = new ArrayList<>();
|
||||||
|
for (Map.Entry<String, KnowledgeRetrievalVo> entry : allMap.entrySet()) {
|
||||||
|
String id = entry.getKey();
|
||||||
|
double vScore = vectorScores.getOrDefault(id, 0.0);
|
||||||
|
double kScore = keywordScores.getOrDefault(id, 0.0);
|
||||||
|
|
||||||
|
// 混合分值
|
||||||
|
double finalScore = (1 - alpha) * vScore + alpha * kScore;
|
||||||
|
|
||||||
|
// 分值归一化/缩放:将 RRF 分值放大到 0-1 范围
|
||||||
|
// 理论单路最大得分为 1/61 ≈ 0.016,乘以 60 使其处于相似度常用区间
|
||||||
|
KnowledgeRetrievalVo vo = entry.getValue();
|
||||||
|
vo.setScore(finalScore * 60.0);
|
||||||
|
fusedResults.add(vo);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 按融合分数从高到低排序
|
||||||
|
fusedResults.sort((a, b) -> b.getScore().compareTo(a.getScore()));
|
||||||
|
return fusedResults;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user