fix(rag): 修复合并重复,重排模型新增硅基流动供应商

This commit is contained in:
RobustH
2026-04-21 22:41:00 +08:00
parent e7f53fd55f
commit 1b50c7f9f1
8 changed files with 254 additions and 354 deletions

View File

@@ -98,12 +98,19 @@ public class KnowledgeInfoBo extends BaseEntity {
private Double rerankScoreThreshold;
/**
* 是否启用混合检索0 否 1是
*/
private Integer enableHybrid;
/**
* 混合检索权重 (0.0-1.0)
*/
private Double hybridAlpha;
/**
* 备注
*/
private String remark;
}

View File

@@ -98,6 +98,16 @@ public class KnowledgeInfo extends BaseEntity {
*/
private Double rerankScoreThreshold;
/**
* 是否启用混合检索0 否 1是
*/
private Integer enableHybrid;
/**
* 混合检索权重 (0.0-1.0)
*/
private Double hybridAlpha;
/**
* 备注
*/

View File

@@ -118,6 +118,24 @@ public class KnowledgeInfoVo implements Serializable {
@ExcelProperty(value = "重排序分数阈值")
private Double rerankScoreThreshold;
/**
* 是否启用混合检索0 否 1是
*/
@ExcelProperty(value = "是否启用混合检索")
private Integer enableHybrid;
/**
* 混合检索权重 (0.0-1.0)
*/
@ExcelProperty(value = "混合检索权重")
private Double hybridAlpha;
/**
* 文档数量
*/
@ExcelProperty(value = "文档数量")
private Integer documentCount;
/**
* 备注
*/

View File

@@ -1,39 +1,36 @@
package org.ruoyi.service.knowledge.impl;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.scoring.ScoringModel;
import org.ruoyi.common.core.utils.MapstructUtils;
import org.ruoyi.common.core.utils.StringUtils;
import org.ruoyi.common.mybatis.core.page.TableDataInfo;
import org.ruoyi.common.mybatis.core.page.PageQuery;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo;
import org.ruoyi.common.chat.service.chat.IChatModelService;
import org.ruoyi.common.core.utils.MapstructUtils;
import org.ruoyi.common.core.utils.StringUtils;
import org.ruoyi.common.mybatis.core.page.PageQuery;
import org.ruoyi.common.mybatis.core.page.TableDataInfo;
import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo;
import org.ruoyi.domain.bo.rerank.RerankRequest;
import org.ruoyi.domain.bo.rerank.RerankResult;
import org.ruoyi.domain.bo.vector.QueryVectorBo;
import org.ruoyi.domain.entity.knowledge.KnowledgeFragment;
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
import org.ruoyi.mapper.knowledge.KnowledgeFragmentMapper;
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo;
import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo;
import org.ruoyi.domain.bo.vector.QueryVectorBo;
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
import org.ruoyi.factory.RerankModelFactory;
import org.ruoyi.mapper.knowledge.KnowledgeFragmentMapper;
import org.ruoyi.service.knowledge.IKnowledgeFragmentService;
import org.ruoyi.service.knowledge.IKnowledgeInfoService;
import org.ruoyi.common.chat.service.chat.IChatModelService;
import org.ruoyi.service.knowledge.rerank.ScoringModelFactory;
import org.ruoyi.service.rerank.RerankModelService;
import org.ruoyi.service.vector.VectorStoreService;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.stream.Collectors;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import java.util.List;
import java.util.Map;
import java.util.Collection;
/**
* 知识片段Service业务层处理
@@ -50,7 +47,7 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
private final IKnowledgeInfoService knowledgeInfoService;
private final IChatModelService chatModelService;
private final VectorStoreService vectorStoreService;
private final ScoringModelFactory scoringModelFactory;
private final RerankModelFactory rerankModelFactory;
/**
* 查询知识片段
@@ -231,37 +228,38 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
// 4. 执行重排逻辑 (如果请求启用重排且配置了重排模型)
if (Boolean.TRUE.equals(bo.getEnableRerank()) && StringUtils.isNotBlank(bo.getRerankModel())) {
log.info("开始重排配置检索测试,传入模型名称: [{}]", bo.getRerankModel());
ChatModelVo rerankModelConfig = chatModelService.selectModelByName(bo.getRerankModel());
log.info("开始重排精排,模型: [{}]", bo.getRerankModel());
try {
RerankModelService rerankModel = rerankModelFactory.createModel(bo.getRerankModel());
if (rerankModelConfig == null) {
log.warn("未能找到重排模型配置: [{}]", bo.getRerankModel());
} else {
ScoringModel scoringModel = scoringModelFactory.createScoringModel(rerankModelConfig);
if (scoringModel != null) {
log.info("执行重排精排,模型: {}, 供应商: {}", rerankModelConfig.getModelName(), rerankModelConfig.getProviderCode());
// 将 KnowledgeRetrievalVo 转换为 TextSegment 列表进行重排
List<TextSegment> segments = allResults.stream()
.map(res -> TextSegment.from(res.getContent()))
List<String> contents = allResults.stream()
.map(KnowledgeRetrievalVo::getContent)
.collect(Collectors.toList());
Response<List<Double>> scoresResponse = scoringModel.scoreAll(segments, bo.getQuery());
List<Double> scores = scoresResponse.content();
RerankRequest rerankRequest = RerankRequest.builder()
.query(bo.getQuery())
.documents(contents)
.topN(contents.size())
.returnDocuments(false)
.build();
// 更新分数并重新排序
for (int i = 0; i < allResults.size(); i++) {
KnowledgeRetrievalVo resultVo = allResults.get(i);
// 保存原始分数供前端展示对比
RerankResult rerankResult = rerankModel.rerank(rerankRequest);
// 将重排分数写回,并记录原始分数供前端对比
for (RerankResult.RerankDocument doc : rerankResult.getDocuments()) {
if (doc.getIndex() != null && doc.getIndex() < allResults.size()) {
KnowledgeRetrievalVo resultVo = allResults.get(doc.getIndex());
resultVo.setRawScore(resultVo.getScore());
if (i < scores.size()) {
resultVo.setScore(scores.get(i));
}
resultVo.setScore(doc.getRelevanceScore());
}
// 按重排后的分数从高到低排序
allResults.sort((a, b) -> b.getScore().compareTo(a.getScore()));
}
// 按重排后的分数从高到低排序
allResults.sort((a, b) -> b.getScore().compareTo(a.getScore()));
log.info("重排精排完成,结果数: {}", allResults.size());
} catch (Exception e) {
log.error("重排精排执行失败,已跳过重排步骤: {}", e.getMessage(), e);
}
}

View File

@@ -1,98 +0,0 @@
package org.ruoyi.service.knowledge.rerank;
import com.alibaba.dashscope.exception.ApiException;
import com.alibaba.dashscope.exception.InputRequiredException;
import com.alibaba.dashscope.exception.NoApiKeyException;
import com.alibaba.dashscope.rerank.TextReRank;
import com.alibaba.dashscope.rerank.TextReRankParam;
import com.alibaba.dashscope.rerank.TextReRankResult;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.scoring.ScoringModel;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
/**
* DashScope 重排模型实现 (GTE-Rerank)
* 包装了阿里云 DashScope 的 TextReRank API使其符合 LangChain4j 的 ScoringModel 标准。
*/
@Slf4j
public class DashScopeScoringModel implements ScoringModel {
private final String apiKey;
private final String modelName;
private final TextReRank rerank;
@Builder
public DashScopeScoringModel(String apiKey, String modelName) {
if (isNullOrEmpty(apiKey)) {
throw new IllegalArgumentException("DashScope API Key 不能为空");
}
this.apiKey = apiKey;
this.modelName = isNullOrEmpty(modelName) ? "gte-rerank" : modelName;
this.rerank = new TextReRank();
}
@Override
public Response<List<Double>> scoreAll(List<TextSegment> segments, String query) {
if (isNullOrEmpty(segments)) {
return Response.from(new ArrayList<>());
}
// 提取文本列表供阿里 SDK 使用
List<String> texts = segments.stream()
.map(TextSegment::text)
.collect(Collectors.toList());
try {
TextReRankParam param = TextReRankParam.builder()
.apiKey(apiKey)
.model(modelName)
.query(query)
.documents(texts)
.topN(texts.size())
.returnDocuments(false)
.build();
TextReRankResult result = rerank.call(param);
// 初始化分数组,默认值为 0.0
Double[] scores = new Double[texts.size()];
for (int i = 0; i < texts.size(); i++) {
scores[i] = 0.0;
}
// 根据返回结果填充对应的分数值(返回结果中包含原文索引)
result.getOutput().getResults().forEach(item -> {
if (item.getIndex() != null && item.getIndex() < texts.size()) {
scores[item.getIndex()] = item.getRelevanceScore();
}
});
List<Double> scoreList = new ArrayList<>();
for (Double s : scores) {
scoreList.add(s);
}
return Response.from(scoreList);
} catch (ApiException | NoApiKeyException | InputRequiredException e) {
log.error("DashScope 重排处理出错: {}", e.getMessage(), e);
throw new RuntimeException("调用 DashScope 重排服务失败", e);
}
}
@Override
public Response<Double> score(TextSegment segment, String query) {
List<TextSegment> segments = new ArrayList<>();
segments.add(segment);
Response<List<Double>> response = scoreAll(segments, query);
return Response.from(response.content().get(0), response.tokenUsage(), response.finishReason());
}
}

View File

@@ -1,54 +0,0 @@
package org.ruoyi.service.knowledge.rerank;
import dev.langchain4j.model.scoring.ScoringModel;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo;
import org.springframework.stereotype.Component;
/**
* 重排模型提供商工厂
* 用于将来无缝拓展硅基流动、百炼等支持重排的模型厂商
*
* @author RobustH
*/
@Slf4j
@Component
public class ScoringModelFactory {
/**
* 根据后台传递的模型配置创建具体的重排模型
*
* @param rerankModelConfig 重排模型的配置 (例如其 providerCode, apiUrl, apiKey 等)
* @return 标准的 LangChain4j ScoringModel
*/
public ScoringModel createScoringModel(ChatModelVo rerankModelConfig) {
if (rerankModelConfig == null) {
return null;
}
String providerCode = rerankModelConfig.getProviderCode();
log.info("初始化重排模型,供应商代码: {}, 模型名称: {}", providerCode, rerankModelConfig.getModelName());
try {
if ("alibailian".equalsIgnoreCase(providerCode)) {
return DashScopeScoringModel.builder()
.apiKey(rerankModelConfig.getApiKey())
.modelName(rerankModelConfig.getModelName())
.build();
}
if ("siliconflow".equalsIgnoreCase(providerCode)) {
return SiliconFlowScoringModel.builder()
.apiKey(rerankModelConfig.getApiKey())
.modelName(rerankModelConfig.getModelName())
// 如果后台配置了不同的 API Host可以在此传递否则使用默认值
.baseUrl(rerankModelConfig.getApiHost())
.build();
}
} catch (Exception e) {
log.error("创建重排模型失败: {}", e.getMessage(), e);
}
return null;
}
}

View File

@@ -1,155 +0,0 @@
package org.ruoyi.service.knowledge.rerank;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.scoring.ScoringModel;
import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.ruoyi.common.json.utils.JsonUtils;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
/**
* SiliconFlow 重排模型实现
* 适配硅基流动的 /v1/rerank 接口
*/
@Slf4j
public class SiliconFlowScoringModel implements ScoringModel {
private final String apiKey;
private final String modelName;
private final String baseUrl;
private final OkHttpClient client;
@Builder
public SiliconFlowScoringModel(String apiKey, String modelName, String baseUrl) {
if (isNullOrEmpty(apiKey)) {
throw new IllegalArgumentException("SiliconFlow API Key 不能为空");
}
this.apiKey = apiKey;
this.modelName = isNullOrEmpty(modelName) ? "BAAI/bge-reranker-v2-m3" : modelName;
// 鲁棒性处理:自动补全 /rerank 路径
String finalUrl = baseUrl;
if (isNullOrEmpty(finalUrl)) {
finalUrl = "https://api.siliconflow.cn/v1/rerank";
} else {
// 如果用户只填了基础路径 https://api.siliconflow.cn/v1自动补全成 https://api.siliconflow.cn/v1/rerank
if (finalUrl.endsWith("/v1")) {
finalUrl = finalUrl + "/rerank";
} else if (!finalUrl.endsWith("/rerank")) {
// 如果没有以 /rerank 结尾也不以斜杠结尾,尝试拼接
finalUrl = finalUrl.endsWith("/") ? finalUrl + "rerank" : finalUrl + "/rerank";
}
}
this.baseUrl = finalUrl;
log.info("初始化 SiliconFlow 重排模型: URL=[{}], Model=[{}]", this.baseUrl, this.modelName);
this.client = new OkHttpClient.Builder()
.connectTimeout(60, TimeUnit.SECONDS)
.readTimeout(60, TimeUnit.SECONDS)
.build();
}
@Override
public Response<List<Double>> scoreAll(List<TextSegment> segments, String query) {
if (isNullOrEmpty(segments)) {
return Response.from(new ArrayList<>());
}
List<String> texts = segments.stream()
.map(TextSegment::text)
.collect(Collectors.toList());
RerankRequest requestBody = new RerankRequest();
requestBody.setModel(modelName);
requestBody.setQuery(query);
requestBody.setDocuments(texts);
requestBody.setTop_n(texts.size());
requestBody.setReturn_documents(false);
String json = JsonUtils.toJsonString(requestBody);
RequestBody body = RequestBody.create(json, MediaType.parse("application/json; charset=utf-8"));
Request request = new Request.Builder()
.url(baseUrl)
.header("Authorization", "Bearer " + apiKey)
.post(body)
.build();
try (okhttp3.Response response = client.newCall(request).execute()) {
if (!response.isSuccessful()) {
String errorBody = response.body() != null ? response.body().string() : "unknown error";
log.error("SiliconFlow Rerank API 调用失败: code={}, body={}", response.code(), errorBody);
throw new RuntimeException("SiliconFlow Rerank API 调用失败: " + response.code());
}
String responseBody = response.body().string();
RerankResponse rerankResponse = JsonUtils.parseObject(responseBody, RerankResponse.class);
if (rerankResponse == null || rerankResponse.getResults() == null) {
return Response.from(new ArrayList<>());
}
// 初始化分数组,默认值为 0.0
Double[] scores = new Double[texts.size()];
for (int i = 0; i < texts.size(); i++) {
scores[i] = 0.0;
}
// 填充得分
rerankResponse.getResults().forEach(item -> {
if (item.getIndex() != null && item.getIndex() < texts.size()) {
scores[item.getIndex()] = item.getRelevance_score();
}
});
List<Double> scoreList = new ArrayList<>();
for (Double s : scores) {
scoreList.add(s);
}
return Response.from(scoreList);
} catch (IOException e) {
log.error("SiliconFlow Rerank 网络请求异常", e);
throw new RuntimeException("SiliconFlow Rerank 网络请求异常", e);
}
}
@Override
public Response<Double> score(TextSegment segment, String query) {
List<TextSegment> segments = new ArrayList<>();
segments.add(segment);
Response<List<Double>> response = scoreAll(segments, query);
return Response.from(response.content().get(0));
}
@Data
public static class RerankRequest {
private String model;
private String query;
private List<String> documents;
private Integer top_n;
private Boolean return_documents;
}
@Data
public static class RerankResponse {
private List<RerankResultItem> results;
}
@Data
public static class RerankResultItem {
private Integer index;
private Double relevance_score;
}
}

View File

@@ -0,0 +1,174 @@
package org.ruoyi.service.rerank.impl;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo;
import org.ruoyi.domain.bo.rerank.RerankRequest;
import org.ruoyi.domain.bo.rerank.RerankResult;
import org.ruoyi.service.rerank.RerankModelService;
import org.springframework.stereotype.Component;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
/**
* 硅基流动重排序模型实现
* 适配硅基流动的 /v1/rerank 接口
*
* @author RobustH
* @date 2026-04-21
*/
@Slf4j
@Component("siliconflowRerank")
public class SiliconFlowRerankModelService implements RerankModelService {
private static final String DEFAULT_BASE_URL = "https://api.siliconflow.cn/v1/rerank";
private final OkHttpClient okHttpClient;
private final ObjectMapper objectMapper = new ObjectMapper()
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
private ChatModelVo chatModelVo;
public SiliconFlowRerankModelService() {
this.okHttpClient = new OkHttpClient.Builder()
.connectTimeout(30, TimeUnit.SECONDS)
.readTimeout(60, TimeUnit.SECONDS)
.writeTimeout(30, TimeUnit.SECONDS)
.build();
}
@Override
public void configure(ChatModelVo config) {
this.chatModelVo = config;
}
@Override
public RerankResult rerank(RerankRequest rerankRequest) {
long startTime = System.currentTimeMillis();
try {
String url = buildUrl();
String requestJson = buildRequestJson(rerankRequest);
RequestBody body = RequestBody.create(requestJson, MediaType.get("application/json"));
Request httpRequest = new Request.Builder()
.url(url)
.addHeader("Authorization", "Bearer " + chatModelVo.getApiKey())
.addHeader("Content-Type", "application/json")
.post(body)
.build();
log.info("硅基流动重排序请求: model={}, url={}", chatModelVo.getModelName(), url);
try (Response response = okHttpClient.newCall(httpRequest).execute()) {
if (!response.isSuccessful()) {
String err = response.body() != null ? response.body().string() : "无错误信息";
throw new IllegalArgumentException("硅基流动 Rerank API 调用失败: " + response.code() + " - " + err);
}
ResponseBody responseBody = response.body();
if (responseBody == null) {
throw new IllegalArgumentException("响应体为空");
}
SiliconFlowRerankResponse rerankResponse = objectMapper.readValue(
responseBody.string(), SiliconFlowRerankResponse.class);
return buildRerankResult(rerankResponse, rerankRequest.getDocuments(),
System.currentTimeMillis() - startTime);
}
} catch (Exception e) {
log.error("硅基流动重排序失败: {}", e.getMessage(), e);
throw new RuntimeException("重排序服务调用失败: " + e.getMessage(), e);
}
}
/**
* 构建请求 URL鲁棒处理 API Host 末尾路径
*/
private String buildUrl() {
String apiHost = chatModelVo.getApiHost();
if (apiHost == null || apiHost.isBlank()) {
return DEFAULT_BASE_URL;
}
if (apiHost.endsWith("/rerank")) {
return apiHost;
}
if (apiHost.endsWith("/v1")) {
return apiHost + "/rerank";
}
return apiHost.endsWith("/") ? apiHost + "rerank" : apiHost + "/rerank";
}
/**
* 构建请求体 JSON
*/
private String buildRequestJson(RerankRequest rerankRequest) throws IOException {
SiliconFlowRerankRequest request = new SiliconFlowRerankRequest();
request.setModel(chatModelVo.getModelName());
request.setQuery(rerankRequest.getQuery());
request.setDocuments(rerankRequest.getDocuments());
request.setTop_n(rerankRequest.getTopN() != null ? rerankRequest.getTopN() : rerankRequest.getDocuments().size());
request.setReturn_documents(rerankRequest.getReturnDocuments() != null ? rerankRequest.getReturnDocuments() : false);
return objectMapper.writeValueAsString(request);
}
/**
* 构建标准 RerankResult
*/
private RerankResult buildRerankResult(SiliconFlowRerankResponse response,
List<String> originalDocuments, long durationMs) {
Double[] scores = new Double[originalDocuments.size()];
for (int i = 0; i < scores.length; i++) {
scores[i] = 0.0;
}
List<RerankResult.RerankDocument> docs = new ArrayList<>();
if (response != null && response.getResults() != null) {
response.getResults().forEach(item -> {
if (item.getIndex() != null && item.getIndex() < originalDocuments.size()) {
scores[item.getIndex()] = item.getRelevance_score();
docs.add(RerankResult.RerankDocument.builder()
.index(item.getIndex())
.relevanceScore(item.getRelevance_score())
.document(originalDocuments.get(item.getIndex()))
.build());
}
});
}
return RerankResult.builder()
.documents(docs)
.totalDocuments(originalDocuments.size())
.durationMs(durationMs)
.build();
}
// ==================== 内部 DTO ====================
@Data
static class SiliconFlowRerankRequest {
private String model;
private String query;
private List<String> documents;
private Integer top_n;
private Boolean return_documents;
}
@Data
static class SiliconFlowRerankResponse {
private List<SiliconFlowRerankResultItem> results;
}
@Data
static class SiliconFlowRerankResultItem {
private Integer index;
private Double relevance_score;
}
}