feat(rag): 集成硅基流动、阿里百炼重排模型并全方位增强检索测试体验

This commit is contained in:
RobustH
2026-04-14 01:40:28 +08:00
parent 06a63c377e
commit 1208c46cca
10 changed files with 389 additions and 16 deletions

View File

@@ -69,4 +69,14 @@ public class KnowledgeFragmentBo extends BaseEntity {
*/ */
private Double threshold; private Double threshold;
/**
* 是否启用重排
*/
private Boolean enableRerank;
/**
* 重排模型名称
*/
private String rerankModel;
} }

View File

@@ -77,6 +77,16 @@ public class KnowledgeInfoBo extends BaseEntity {
*/ */
private String embeddingModel; private String embeddingModel;
/**
* 重排模型
*/
private String rerankModel;
/**
* 是否启用重排(0 否 1 是)
*/
private Integer enableRerank;
/** /**
* 备注 * 备注
*/ */

View File

@@ -78,6 +78,16 @@ public class KnowledgeInfo extends BaseEntity {
*/ */
private String embeddingModel; private String embeddingModel;
/**
* 重排模型
*/
private String rerankModel;
/**
* 是否启用重排(0 否 1 是)
*/
private Integer enableRerank;
/** /**
* 备注 * 备注
*/ */

View File

@@ -94,6 +94,19 @@ public class KnowledgeInfoVo implements Serializable {
@ExcelProperty(value = "向量模型") @ExcelProperty(value = "向量模型")
private String embeddingModel; private String embeddingModel;
/**
* 重排模型
*/
@ExcelProperty(value = "重排模型")
private String rerankModel;
/**
* 是否启用重排(0 否 1 是)
*/
@ExcelProperty(value = "是否启用重排", converter = ExcelDictConvert.class)
@ExcelDictFormat(readConverterExp = "0=否,1=是")
private Integer enableRerank;
/** /**
* 备注 * 备注
*/ */

View File

@@ -28,6 +28,16 @@ public class KnowledgeRetrievalVo implements Serializable {
*/ */
private Double score; private Double score;
/**
* 原始检索排名 (重排前)
*/
private Integer originalIndex;
/**
* 原始检索得分 (重排前)
*/
private Double rawScore;
/** /**
* 来源文档名称 * 来源文档名称
*/ */

View File

@@ -377,20 +377,24 @@ public class ChatServiceFacade implements IChatService {
CustomVectorRetriever retriever = new CustomVectorRetriever( CustomVectorRetriever retriever = new CustomVectorRetriever(
vectorStoreService, knowledgeInfoVo, chatModel); vectorStoreService, knowledgeInfoVo, chatModel);
// 2. 获取和构建重排模型聚合器Aggregator // 2. 构建重排聚合器 (Aggregator)
// 假设已在 KnowledgeInfoVo 等加入 getRerankModelConfig/getRerankModel 等,这里演示通用逻辑
// 若无重排需求,使用 DefaultContentAggregator 或无 ScoringModel 的聚合器
ContentAggregator contentAggregator; ContentAggregator contentAggregator;
// TODO: 一旦实体类实现了重排模型的支持,此处可以从数据库读出: if (knowledgeInfoVo.getEnableRerank() != null && knowledgeInfoVo.getEnableRerank() == 1
// ChatModelVo scoringModelConfig = chatModelService.selectModelByName(knowledgeInfoVo.getRerankModel()); && knowledgeInfoVo.getRerankModel() != null) {
ChatModelVo scoringModelConfig = null; // 当前暂无对应配置字段
ChatModelVo scoringModelConfig = chatModelService.selectModelByName(knowledgeInfoVo.getRerankModel());
ScoringModel scoringModel = scoringModelFactory.createScoringModel(scoringModelConfig); ScoringModel scoringModel = scoringModelFactory.createScoringModel(scoringModelConfig);
if (scoringModel != null) { if (scoringModel != null) {
contentAggregator = ReRankingContentAggregator.builder() contentAggregator = ReRankingContentAggregator.builder()
.scoringModel(scoringModel) .scoringModel(scoringModel)
// .maxResults(3) 这个数字将来从配置取 // 默认重排后只留前 5 条,避免上下文过长
.maxResults(5)
.build(); .build();
log.info("启用重排模型: {}", knowledgeInfoVo.getRerankModel());
} else {
contentAggregator = new DefaultContentAggregator();
}
} else { } else {
contentAggregator = new DefaultContentAggregator(); contentAggregator = new DefaultContentAggregator();
} }

View File

@@ -1,5 +1,8 @@
package org.ruoyi.service.knowledge.impl; package org.ruoyi.service.knowledge.impl;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.scoring.ScoringModel;
import org.ruoyi.common.core.utils.MapstructUtils; import org.ruoyi.common.core.utils.MapstructUtils;
import org.ruoyi.common.core.utils.StringUtils; import org.ruoyi.common.core.utils.StringUtils;
import org.ruoyi.common.mybatis.core.page.TableDataInfo; import org.ruoyi.common.mybatis.core.page.TableDataInfo;
@@ -20,6 +23,7 @@ import org.ruoyi.domain.bo.vector.QueryVectorBo;
import org.ruoyi.service.knowledge.IKnowledgeFragmentService; import org.ruoyi.service.knowledge.IKnowledgeFragmentService;
import org.ruoyi.service.knowledge.IKnowledgeInfoService; import org.ruoyi.service.knowledge.IKnowledgeInfoService;
import org.ruoyi.common.chat.service.chat.IChatModelService; import org.ruoyi.common.chat.service.chat.IChatModelService;
import org.ruoyi.service.knowledge.rerank.ScoringModelFactory;
import org.ruoyi.service.vector.VectorStoreService; import org.ruoyi.service.vector.VectorStoreService;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.ArrayList; import java.util.ArrayList;
@@ -44,6 +48,7 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
private final IKnowledgeInfoService knowledgeInfoService; private final IKnowledgeInfoService knowledgeInfoService;
private final IChatModelService chatModelService; private final IChatModelService chatModelService;
private final VectorStoreService vectorStoreService; private final VectorStoreService vectorStoreService;
private final ScoringModelFactory scoringModelFactory;
/** /**
* 查询知识片段 * 查询知识片段
@@ -178,7 +183,48 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
// 3. 执行物理检索 // 3. 执行物理检索
List<KnowledgeRetrievalVo> allResults = vectorStoreService.search(queryVectorBo); List<KnowledgeRetrievalVo> allResults = vectorStoreService.search(queryVectorBo);
// 4. 根据阈值过滤 (LangChain4j 结果 score 通常 0-1) // 初始化原始排名
for (int i = 0; i < allResults.size(); i++) {
allResults.get(i).setOriginalIndex(i);
}
// 4. 执行重排逻辑 (如果请求启用重排且配置了重排模型)
if (Boolean.TRUE.equals(bo.getEnableRerank()) && StringUtils.isNotBlank(bo.getRerankModel())) {
log.info("开始重排配置检索测试,传入模型名称: [{}]", bo.getRerankModel());
ChatModelVo rerankModelConfig = chatModelService.selectModelByName(bo.getRerankModel());
if (rerankModelConfig == null) {
log.warn("未能找到重排模型配置: [{}]", bo.getRerankModel());
} else {
ScoringModel scoringModel = scoringModelFactory.createScoringModel(rerankModelConfig);
if (scoringModel != null) {
log.info("执行重排精排,模型: {}, 供应商: {}", rerankModelConfig.getModelName(), rerankModelConfig.getProviderCode());
// 将 KnowledgeRetrievalVo 转换为 TextSegment 列表进行重排
List<TextSegment> segments = allResults.stream()
.map(res -> TextSegment.from(res.getContent()))
.collect(Collectors.toList());
Response<List<Double>> scoresResponse = scoringModel.scoreAll(segments, bo.getQuery());
List<Double> scores = scoresResponse.content();
// 更新分数并重新排序
for (int i = 0; i < allResults.size(); i++) {
KnowledgeRetrievalVo resultVo = allResults.get(i);
// 保存原始分数供前端展示对比
resultVo.setRawScore(resultVo.getScore());
if (i < scores.size()) {
resultVo.setScore(scores.get(i));
}
}
// 按重排后的分数从高到低排序
allResults.sort((a, b) -> b.getScore().compareTo(a.getScore()));
}
}
}
// 5. 根据阈值过滤
double threshold = bo.getThreshold() != null ? bo.getThreshold() : 0.0; double threshold = bo.getThreshold() != null ? bo.getThreshold() : 0.0;
return allResults.stream() return allResults.stream()
.filter(res -> res.getScore() >= threshold) .filter(res -> res.getScore() >= threshold)

View File

@@ -0,0 +1,98 @@
package org.ruoyi.service.knowledge.rerank;
import com.alibaba.dashscope.exception.ApiException;
import com.alibaba.dashscope.exception.InputRequiredException;
import com.alibaba.dashscope.exception.NoApiKeyException;
import com.alibaba.dashscope.rerank.TextReRank;
import com.alibaba.dashscope.rerank.TextReRankParam;
import com.alibaba.dashscope.rerank.TextReRankResult;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.scoring.ScoringModel;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
/**
 * {@link ScoringModel} implementation backed by the Alibaba Cloud DashScope text re-rank API.
 *
 * <p>Wraps the DashScope SDK's {@link TextReRank} client so that it can be used wherever
 * LangChain4j expects a {@code ScoringModel}. When no model name is configured the model
 * falls back to {@code gte-rerank}.
 */
@Slf4j
public class DashScopeScoringModel implements ScoringModel {

    /** Model name used when the caller does not supply one. */
    private static final String DEFAULT_MODEL_NAME = "gte-rerank";

    private final String apiKey;
    private final String modelName;
    private final TextReRank rerank;

    /**
     * Creates a scoring model bound to one API key and one re-rank model.
     *
     * @param apiKey    DashScope API key; must not be null or empty
     * @param modelName re-rank model name; {@code gte-rerank} is used when null or empty
     * @throws IllegalArgumentException if {@code apiKey} is null or empty
     */
    @Builder
    public DashScopeScoringModel(String apiKey, String modelName) {
        if (isNullOrEmpty(apiKey)) {
            throw new IllegalArgumentException("DashScope API Key 不能为空");
        }
        this.apiKey = apiKey;
        this.modelName = isNullOrEmpty(modelName) ? DEFAULT_MODEL_NAME : modelName;
        this.rerank = new TextReRank();
    }

    /**
     * Scores every segment against the query with a single re-rank call.
     *
     * <p>The returned list always has exactly one entry per input segment, in input order.
     * Segments for which the service reports no usable score keep the default {@code 0.0}.
     *
     * @param segments segments to score; a null or empty list yields an empty score list
     * @param query    the query the segments are ranked against
     * @return scores aligned by position with {@code segments}
     * @throws RuntimeException if the DashScope call fails
     */
    @Override
    public Response<List<Double>> scoreAll(List<TextSegment> segments, String query) {
        if (isNullOrEmpty(segments)) {
            return Response.from(new ArrayList<>());
        }

        // The SDK works on plain strings, not on TextSegment objects.
        List<String> texts = segments.stream()
                .map(TextSegment::text)
                .collect(Collectors.toList());

        TextReRankResult result;
        try {
            TextReRankParam param = TextReRankParam.builder()
                    .apiKey(apiKey)
                    .model(modelName)
                    .query(query)
                    .documents(texts)
                    // Ask for a score for every document so none is dropped by the service.
                    .topN(texts.size())
                    .returnDocuments(false)
                    .build();
            result = rerank.call(param);
        } catch (ApiException | NoApiKeyException | InputRequiredException e) {
            log.error("DashScope 重排处理出错: {}", e.getMessage(), e);
            throw new RuntimeException("调用 DashScope 重排服务失败", e);
        }

        // Start from 0.0 everywhere; only positions the service reports on get overwritten.
        List<Double> scores = new ArrayList<>(texts.size());
        for (int i = 0; i < texts.size(); i++) {
            scores.add(0.0);
        }

        if (result == null || result.getOutput() == null || result.getOutput().getResults() == null) {
            // Previously this dereference chain threw a NullPointerException.
            log.warn("DashScope 重排响应中没有结果,所有片段得分按 0.0 处理");
            return Response.from(scores);
        }

        // Results come back in relevance order; each item carries the index of its source document.
        result.getOutput().getResults().forEach(item -> {
            if (item == null) {
                return;
            }
            Integer index = item.getIndex();
            Double relevance = item.getRelevanceScore();
            // Skip out-of-range indexes and missing scores so the list never contains null.
            if (index != null && relevance != null && index >= 0 && index < scores.size()) {
                scores.set(index, relevance);
            }
        });

        return Response.from(scores);
    }

    /**
     * Scores a single segment by delegating to {@link #scoreAll(List, String)}.
     *
     * @param segment the segment to score
     * @param query   the query the segment is ranked against
     * @return the segment's score, carrying over token usage and finish reason from the batch call
     */
    @Override
    public Response<Double> score(TextSegment segment, String query) {
        List<TextSegment> segments = new ArrayList<>();
        segments.add(segment);
        Response<List<Double>> response = scoreAll(segments, query);
        return Response.from(response.content().get(0), response.tokenUsage(), response.finishReason());
    }
}

View File

@@ -27,10 +27,27 @@ public class ScoringModelFactory {
} }
String providerCode = rerankModelConfig.getProviderCode(); String providerCode = rerankModelConfig.getProviderCode();
log.info("初始化重排模型,供应商代码: {}", providerCode); log.info("初始化重排模型,供应商代码: {}, 模型名称: {}", providerCode, rerankModelConfig.getModelName());
// TODO: 在这里通过 switch 或反射具体实例化支持的各种 ScoringModel (例如 CohereScoringModel, DascScope 等) try {
// 目前返回 null 代表暂时没有加载特定的重排底座这不会影响流程Aggregator 会忽略它返回原样结果 if ("alibailian".equalsIgnoreCase(providerCode)) {
return DashScopeScoringModel.builder()
.apiKey(rerankModelConfig.getApiKey())
.modelName(rerankModelConfig.getModelName())
.build();
}
if ("siliconflow".equalsIgnoreCase(providerCode)) {
return SiliconFlowScoringModel.builder()
.apiKey(rerankModelConfig.getApiKey())
.modelName(rerankModelConfig.getModelName())
// 如果后台配置了不同的 API Host可以在此传递否则使用默认值
.baseUrl(rerankModelConfig.getApiHost())
.build();
}
} catch (Exception e) {
log.error("创建重排模型失败: {}", e.getMessage(), e);
}
return null; return null;
} }

View File

@@ -0,0 +1,155 @@
package org.ruoyi.service.knowledge.rerank;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.scoring.ScoringModel;
import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.ruoyi.common.json.utils.JsonUtils;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
/**
 * {@link ScoringModel} implementation backed by SiliconFlow's HTTP re-rank endpoint
 * ({@code /v1/rerank}).
 *
 * <p>Requests are sent with OkHttp and (de)serialized through {@link JsonUtils}. When no model
 * name is configured the model falls back to {@code BAAI/bge-reranker-v2-m3}.
 */
@Slf4j
public class SiliconFlowScoringModel implements ScoringModel {

    /** Endpoint used when the caller does not configure a base URL. */
    private static final String DEFAULT_ENDPOINT = "https://api.siliconflow.cn/v1/rerank";

    /** Model name used when the caller does not supply one. */
    private static final String DEFAULT_MODEL_NAME = "BAAI/bge-reranker-v2-m3";

    /** Path segment every resolved endpoint must end with. */
    private static final String RERANK_SUFFIX = "/rerank";

    private final String apiKey;
    private final String modelName;
    private final String baseUrl;
    private final OkHttpClient client;

    /**
     * Creates a scoring model bound to one API key, model and endpoint.
     *
     * @param apiKey    SiliconFlow API key; must not be null or empty
     * @param modelName re-rank model name; a default is used when null or empty
     * @param baseUrl   endpoint or base URL; {@code /rerank} is appended when missing, and the
     *                  public SiliconFlow endpoint is used when null or empty
     * @throws IllegalArgumentException if {@code apiKey} is null or empty
     */
    @Builder
    public SiliconFlowScoringModel(String apiKey, String modelName, String baseUrl) {
        if (isNullOrEmpty(apiKey)) {
            throw new IllegalArgumentException("SiliconFlow API Key 不能为空");
        }
        this.apiKey = apiKey;
        this.modelName = isNullOrEmpty(modelName) ? DEFAULT_MODEL_NAME : modelName;
        this.baseUrl = resolveEndpoint(baseUrl);

        log.info("初始化 SiliconFlow 重排模型: URL=[{}], Model=[{}]", this.baseUrl, this.modelName);

        this.client = new OkHttpClient.Builder()
                .connectTimeout(60, TimeUnit.SECONDS)
                .readTimeout(60, TimeUnit.SECONDS)
                .build();
    }

    /**
     * Turns the configured URL into the full re-rank endpoint.
     *
     * <p>A URL already ending in {@code /rerank} is used as is; anything else (for example a
     * base path such as {@code https://api.siliconflow.cn/v1}, with or without a trailing
     * slash) gets {@code /rerank} appended.
     */
    private static String resolveEndpoint(String configuredUrl) {
        if (isNullOrEmpty(configuredUrl)) {
            return DEFAULT_ENDPOINT;
        }
        if (configuredUrl.endsWith(RERANK_SUFFIX)) {
            return configuredUrl;
        }
        return configuredUrl.endsWith("/") ? configuredUrl + "rerank" : configuredUrl + RERANK_SUFFIX;
    }

    /** Builds a list of {@code size} scores, all initialised to {@code 0.0}. */
    private static List<Double> zeroScores(int size) {
        List<Double> scores = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            scores.add(0.0);
        }
        return scores;
    }

    /**
     * Scores every segment against the query with a single HTTP re-rank call.
     *
     * <p>The returned list always has exactly one entry per input segment, in input order.
     * Segments for which the service reports no usable score keep the default {@code 0.0};
     * this includes the case where the response body is missing or carries no results.
     *
     * @param segments segments to score; a null or empty list yields an empty score list
     * @param query    the query the segments are ranked against
     * @return scores aligned by position with {@code segments}
     * @throws RuntimeException if the HTTP call fails or returns a non-success status
     */
    @Override
    public Response<List<Double>> scoreAll(List<TextSegment> segments, String query) {
        if (isNullOrEmpty(segments)) {
            return Response.from(new ArrayList<>());
        }

        List<String> texts = segments.stream()
                .map(TextSegment::text)
                .collect(Collectors.toList());

        RerankRequest requestBody = new RerankRequest();
        requestBody.setModel(modelName);
        requestBody.setQuery(query);
        requestBody.setDocuments(texts);
        // Ask for a score for every document so none is dropped by the service.
        requestBody.setTop_n(texts.size());
        requestBody.setReturn_documents(false);

        String json = JsonUtils.toJsonString(requestBody);
        RequestBody body = RequestBody.create(json, MediaType.parse("application/json; charset=utf-8"));
        Request request = new Request.Builder()
                .url(baseUrl)
                .header("Authorization", "Bearer " + apiKey)
                .post(body)
                .build();

        // okhttp3.Response is fully qualified because Response already refers to the LangChain4j type.
        try (okhttp3.Response response = client.newCall(request).execute()) {
            if (!response.isSuccessful()) {
                String errorBody = response.body() != null ? response.body().string() : "unknown error";
                log.error("SiliconFlow Rerank API 调用失败: code={}, body={}", response.code(), errorBody);
                throw new RuntimeException("SiliconFlow Rerank API 调用失败: " + response.code());
            }

            ResponseBody payload = response.body();
            RerankResponse rerankResponse = payload == null
                    ? null
                    : JsonUtils.parseObject(payload.string(), RerankResponse.class);

            // Start from 0.0 everywhere; only positions the service reports on get overwritten.
            List<Double> scores = zeroScores(texts.size());

            if (rerankResponse == null || rerankResponse.getResults() == null) {
                // Returning an empty list here would break callers that index scores by segment
                // position (including score() below), so keep the list aligned with the input.
                log.warn("SiliconFlow Rerank 响应中没有结果,所有片段得分按 0.0 处理");
                return Response.from(scores);
            }

            rerankResponse.getResults().forEach(item -> {
                if (item == null) {
                    return;
                }
                Integer index = item.getIndex();
                Double relevance = item.getRelevance_score();
                // Skip out-of-range indexes and missing scores so the list never contains null.
                if (index != null && relevance != null && index >= 0 && index < scores.size()) {
                    scores.set(index, relevance);
                }
            });

            return Response.from(scores);
        } catch (IOException e) {
            log.error("SiliconFlow Rerank 网络请求异常", e);
            throw new RuntimeException("SiliconFlow Rerank 网络请求异常", e);
        }
    }

    /**
     * Scores a single segment by delegating to {@link #scoreAll(List, String)}.
     *
     * @param segment the segment to score
     * @param query   the query the segment is ranked against
     * @return the segment's score
     */
    @Override
    public Response<Double> score(TextSegment segment, String query) {
        List<TextSegment> segments = new ArrayList<>();
        segments.add(segment);
        Response<List<Double>> response = scoreAll(segments, query);
        return Response.from(response.content().get(0));
    }

    /**
     * Request payload for the re-rank endpoint.
     *
     * <p>NOTE(review): the snake_case field names appear to be intentional so that the
     * serialized JSON keys match the HTTP API; confirm before renaming.
     */
    @Data
    public static class RerankRequest {
        private String model;
        private String query;
        private List<String> documents;
        private Integer top_n;
        private Boolean return_documents;
    }

    /** Response payload of the re-rank endpoint; only the result list is read. */
    @Data
    public static class RerankResponse {
        private List<RerankResultItem> results;
    }

    /** One scored document: the index of the source document and its relevance score. */
    @Data
    public static class RerankResultItem {
        private Integer index;
        private Double relevance_score;
    }
}