From 1208c46cca20c1a795dab0c250a32caecb9c9cd8 Mon Sep 17 00:00:00 2001 From: RobustH <1511209518@qq.com> Date: Tue, 14 Apr 2026 01:40:28 +0800 Subject: [PATCH] =?UTF-8?q?feat(rag):=20=E9=9B=86=E6=88=90=E7=A1=85?= =?UTF-8?q?=E5=9F=BA=E6=B5=81=E5=8A=A8=E3=80=81=E9=98=BF=E9=87=8C=E7=99=BE?= =?UTF-8?q?=E7=82=BC=E9=87=8D=E6=8E=92=E6=A8=A1=E5=9E=8B=E5=B9=B6=E5=85=A8?= =?UTF-8?q?=E6=96=B9=E4=BD=8D=E5=A2=9E=E5=BC=BA=E6=A3=80=E7=B4=A2=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E4=BD=93=E9=AA=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../bo/knowledge/KnowledgeFragmentBo.java | 10 ++ .../domain/bo/knowledge/KnowledgeInfoBo.java | 10 ++ .../entity/knowledge/KnowledgeInfo.java | 10 ++ .../domain/vo/knowledge/KnowledgeInfoVo.java | 13 ++ .../vo/knowledge/KnowledgeRetrievalVo.java | 10 ++ .../service/chat/impl/ChatServiceFacade.java | 28 ++-- .../impl/KnowledgeFragmentServiceImpl.java | 48 +++++- .../rerank/DashScopeScoringModel.java | 98 +++++++++++ .../knowledge/rerank/ScoringModelFactory.java | 23 ++- .../rerank/SiliconFlowScoringModel.java | 155 ++++++++++++++++++ 10 files changed, 389 insertions(+), 16 deletions(-) create mode 100644 ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/DashScopeScoringModel.java create mode 100644 ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/SiliconFlowScoringModel.java diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeFragmentBo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeFragmentBo.java index e1925028..1508462f 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeFragmentBo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeFragmentBo.java @@ -69,4 +69,14 @@ public class KnowledgeFragmentBo extends BaseEntity { */ private Double threshold; + /** + * 是否启用重排 + */ + private Boolean enableRerank; + + /** + * 重排模型名称 + */ + private String rerankModel; + } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoBo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoBo.java index 113a2847..8629018a 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoBo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoBo.java @@ -77,6 +77,16 @@ public class KnowledgeInfoBo extends BaseEntity { */ private String embeddingModel; + /** + * 重排模型 + */ + private String rerankModel; + + /** + * 是否启用重排(0 否 1 是) + */ + private Integer enableRerank; + /** * 备注 */ diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeInfo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeInfo.java index a51cf7da..a5211e69 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeInfo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeInfo.java @@ -78,6 +78,16 @@ public class KnowledgeInfo extends BaseEntity { */ private String embeddingModel; + /** + * 重排模型 + */ + private String rerankModel; + + /** + * 是否启用重排(0 否 1 是) + */ + private Integer enableRerank; + /** * 备注 */ diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java index 53e136dd..e65444e7 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java @@ -94,6 +94,19 @@ public class KnowledgeInfoVo implements Serializable { @ExcelProperty(value = "向量模型") private String embeddingModel; + /** + * 重排模型 + */ + @ExcelProperty(value = "重排模型") + private String rerankModel; + + /** + * 是否启用重排(0 否 1 是) + */ + @ExcelProperty(value = "是否启用重排", converter = ExcelDictConvert.class) + @ExcelDictFormat(readConverterExp = "0=否,1=是") + private Integer enableRerank; + /** * 备注 */ diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeRetrievalVo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeRetrievalVo.java index daeaae59..95c8e4cf 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeRetrievalVo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeRetrievalVo.java @@ -28,6 +28,16 @@ public class KnowledgeRetrievalVo implements Serializable { */ private Double score; + /** + * 原始检索排名 (重排前) + */ + private Integer originalIndex; + + /** + * 原始检索得分 (重排前) + */ + private Double rawScore; + /** * 来源文档名称 */ diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java index 2332ca84..9ca93d10 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java @@ -377,20 +377,24 @@ public class ChatServiceFacade implements IChatService { CustomVectorRetriever retriever = new CustomVectorRetriever( vectorStoreService, knowledgeInfoVo, chatModel); - // 2. 获取和构建重排模型聚合器(Aggregator) - // 假设已在 KnowledgeInfoVo 等加入 getRerankModelConfig/getRerankModel 等,这里演示通用逻辑 - // 若无重排需求,使用 DefaultContentAggregator 或无 ScoringModel 的聚合器 + // 2. 构建重排聚合器 (Aggregator) ContentAggregator contentAggregator; - // TODO: 一旦实体类实现了重排模型的支持,此处可以从数据库读出: - // ChatModelVo scoringModelConfig = chatModelService.selectModelByName(knowledgeInfoVo.getRerankModel()); - ChatModelVo scoringModelConfig = null; // 当前暂无对应配置字段 + if (knowledgeInfoVo.getEnableRerank() != null && knowledgeInfoVo.getEnableRerank() == 1 + && knowledgeInfoVo.getRerankModel() != null) { - ScoringModel scoringModel = scoringModelFactory.createScoringModel(scoringModelConfig); - if (scoringModel != null) { - contentAggregator = ReRankingContentAggregator.builder() - .scoringModel(scoringModel) - // .maxResults(3) 这个数字将来从配置取 - .build(); + ChatModelVo scoringModelConfig = chatModelService.selectModelByName(knowledgeInfoVo.getRerankModel()); + ScoringModel scoringModel = scoringModelFactory.createScoringModel(scoringModelConfig); + + if (scoringModel != null) { + contentAggregator = ReRankingContentAggregator.builder() + .scoringModel(scoringModel) + // 默认重排后只留前 5 条,避免上下文过长 + .maxResults(5) + .build(); + log.info("启用重排模型: {}", knowledgeInfoVo.getRerankModel()); + } else { + contentAggregator = new DefaultContentAggregator(); + } } else { contentAggregator = new DefaultContentAggregator(); } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java index 9c2477df..68ccf0bf 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java @@ -1,5 +1,8 @@ package org.ruoyi.service.knowledge.impl; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.scoring.ScoringModel; import org.ruoyi.common.core.utils.MapstructUtils; import org.ruoyi.common.core.utils.StringUtils; import org.ruoyi.common.mybatis.core.page.TableDataInfo; @@ -20,6 +23,7 @@ import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.service.knowledge.IKnowledgeFragmentService; import org.ruoyi.service.knowledge.IKnowledgeInfoService; import org.ruoyi.common.chat.service.chat.IChatModelService; +import org.ruoyi.service.knowledge.rerank.ScoringModelFactory; import org.ruoyi.service.vector.VectorStoreService; import org.springframework.stereotype.Service; import java.util.ArrayList; @@ -44,6 +48,7 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService { private final IKnowledgeInfoService knowledgeInfoService; private final IChatModelService chatModelService; private final VectorStoreService vectorStoreService; + private final ScoringModelFactory scoringModelFactory; /** * 查询知识片段 @@ -178,7 +183,48 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService { // 3. 执行物理检索 List allResults = vectorStoreService.search(queryVectorBo); - // 4. 根据阈值过滤 (LangChain4j 结果 score 通常 0-1) + // 初始化原始排名 + for (int i = 0; i < allResults.size(); i++) { + allResults.get(i).setOriginalIndex(i); + } + + // 4. 执行重排逻辑 (如果请求启用重排且配置了重排模型) + if (Boolean.TRUE.equals(bo.getEnableRerank()) && StringUtils.isNotBlank(bo.getRerankModel())) { + log.info("开始重排配置检索测试,传入模型名称: [{}]", bo.getRerankModel()); + ChatModelVo rerankModelConfig = chatModelService.selectModelByName(bo.getRerankModel()); + + if (rerankModelConfig == null) { + log.warn("未能找到重排模型配置: [{}]", bo.getRerankModel()); + } else { + ScoringModel scoringModel = scoringModelFactory.createScoringModel(rerankModelConfig); + if (scoringModel != null) { + log.info("执行重排精排,模型: {}, 供应商: {}", rerankModelConfig.getModelName(), rerankModelConfig.getProviderCode()); + + // 将 KnowledgeRetrievalVo 转换为 TextSegment 列表进行重排 + List segments = allResults.stream() + .map(res -> TextSegment.from(res.getContent())) + .collect(Collectors.toList()); + + Response> scoresResponse = scoringModel.scoreAll(segments, bo.getQuery()); + List scores = scoresResponse.content(); + + // 更新分数并重新排序 + for (int i = 0; i < allResults.size(); i++) { + KnowledgeRetrievalVo resultVo = allResults.get(i); + // 保存原始分数供前端展示对比 + resultVo.setRawScore(resultVo.getScore()); + if (i < scores.size()) { + resultVo.setScore(scores.get(i)); + } + } + + // 按重排后的分数从高到低排序 + allResults.sort((a, b) -> b.getScore().compareTo(a.getScore())); + } + } + } + + // 5. 根据阈值过滤 double threshold = bo.getThreshold() != null ? bo.getThreshold() : 0.0; return allResults.stream() .filter(res -> res.getScore() >= threshold) diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/DashScopeScoringModel.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/DashScopeScoringModel.java new file mode 100644 index 00000000..1086df8e --- /dev/null +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/DashScopeScoringModel.java @@ -0,0 +1,98 @@ +package org.ruoyi.service.knowledge.rerank; + +import com.alibaba.dashscope.exception.ApiException; +import com.alibaba.dashscope.exception.InputRequiredException; +import com.alibaba.dashscope.exception.NoApiKeyException; +import com.alibaba.dashscope.rerank.TextReRank; +import com.alibaba.dashscope.rerank.TextReRankParam; +import com.alibaba.dashscope.rerank.TextReRankResult; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.scoring.ScoringModel; +import lombok.Builder; +import lombok.extern.slf4j.Slf4j; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import static dev.langchain4j.internal.Utils.isNullOrEmpty; + +/** + * DashScope 重排模型实现 (GTE-Rerank) + * 包装了阿里云 DashScope 的 TextReRank API,使其符合 LangChain4j 的 ScoringModel 标准。 + */ +@Slf4j +public class DashScopeScoringModel implements ScoringModel { + + private final String apiKey; + private final String modelName; + private final TextReRank rerank; + + @Builder + public DashScopeScoringModel(String apiKey, String modelName) { + if (isNullOrEmpty(apiKey)) { + throw new IllegalArgumentException("DashScope API Key 不能为空"); + } + this.apiKey = apiKey; + this.modelName = isNullOrEmpty(modelName) ? "gte-rerank" : modelName; + this.rerank = new TextReRank(); + } + + @Override + public Response> scoreAll(List segments, String query) { + if (isNullOrEmpty(segments)) { + return Response.from(new ArrayList<>()); + } + + // 提取文本列表供阿里 SDK 使用 + List texts = segments.stream() + .map(TextSegment::text) + .collect(Collectors.toList()); + + try { + TextReRankParam param = TextReRankParam.builder() + .apiKey(apiKey) + .model(modelName) + .query(query) + .documents(texts) + .topN(texts.size()) + .returnDocuments(false) + .build(); + + TextReRankResult result = rerank.call(param); + + // 初始化分数组,默认值为 0.0 + Double[] scores = new Double[texts.size()]; + for (int i = 0; i < texts.size(); i++) { + scores[i] = 0.0; + } + + // 根据返回结果填充对应的分数值(返回结果中包含原文索引) + result.getOutput().getResults().forEach(item -> { + if (item.getIndex() != null && item.getIndex() < texts.size()) { + scores[item.getIndex()] = item.getRelevanceScore(); + } + }); + + List scoreList = new ArrayList<>(); + for (Double s : scores) { + scoreList.add(s); + } + + return Response.from(scoreList); + + } catch (ApiException | NoApiKeyException | InputRequiredException e) { + log.error("DashScope 重排处理出错: {}", e.getMessage(), e); + throw new RuntimeException("调用 DashScope 重排服务失败", e); + } + } + + @Override + public Response score(TextSegment segment, String query) { + List segments = new ArrayList<>(); + segments.add(segment); + Response> response = scoreAll(segments, query); + return Response.from(response.content().get(0), response.tokenUsage(), response.finishReason()); + } +} diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/ScoringModelFactory.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/ScoringModelFactory.java index 5f28b9c2..a3844d66 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/ScoringModelFactory.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/ScoringModelFactory.java @@ -27,10 +27,27 @@ public class ScoringModelFactory { } String providerCode = rerankModelConfig.getProviderCode(); - log.info("初始化重排模型,供应商代码: {}", providerCode); + log.info("初始化重排模型,供应商代码: {}, 模型名称: {}", providerCode, rerankModelConfig.getModelName()); - // TODO: 在这里通过 switch 或反射具体实例化支持的各种 ScoringModel (例如 CohereScoringModel, DascScope 等) - // 目前返回 null 代表暂时没有加载特定的重排底座,这不会影响流程,Aggregator 会忽略它返回原样结果 + try { + if ("alibailian".equalsIgnoreCase(providerCode)) { + return DashScopeScoringModel.builder() + .apiKey(rerankModelConfig.getApiKey()) + .modelName(rerankModelConfig.getModelName()) + .build(); + } + + if ("siliconflow".equalsIgnoreCase(providerCode)) { + return SiliconFlowScoringModel.builder() + .apiKey(rerankModelConfig.getApiKey()) + .modelName(rerankModelConfig.getModelName()) + // 如果后台配置了不同的 API Host,可以在此传递,否则使用默认值 + .baseUrl(rerankModelConfig.getApiHost()) + .build(); + } + } catch (Exception e) { + log.error("创建重排模型失败: {}", e.getMessage(), e); + } return null; } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/SiliconFlowScoringModel.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/SiliconFlowScoringModel.java new file mode 100644 index 00000000..ceae578e --- /dev/null +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/SiliconFlowScoringModel.java @@ -0,0 +1,155 @@ +package org.ruoyi.service.knowledge.rerank; + +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.scoring.ScoringModel; +import lombok.Builder; +import lombok.Data; +import lombok.extern.slf4j.Slf4j; +import okhttp3.*; +import org.ruoyi.common.json.utils.JsonUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static dev.langchain4j.internal.Utils.isNullOrEmpty; + +/** + * SiliconFlow 重排模型实现 + * 适配硅基流动的 /v1/rerank 接口 + */ +@Slf4j +public class SiliconFlowScoringModel implements ScoringModel { + + private final String apiKey; + private final String modelName; + private final String baseUrl; + private final OkHttpClient client; + + @Builder + public SiliconFlowScoringModel(String apiKey, String modelName, String baseUrl) { + if (isNullOrEmpty(apiKey)) { + throw new IllegalArgumentException("SiliconFlow API Key 不能为空"); + } + this.apiKey = apiKey; + this.modelName = isNullOrEmpty(modelName) ? "BAAI/bge-reranker-v2-m3" : modelName; + + // 鲁棒性处理:自动补全 /rerank 路径 + String finalUrl = baseUrl; + if (isNullOrEmpty(finalUrl)) { + finalUrl = "https://api.siliconflow.cn/v1/rerank"; + } else { + // 如果用户只填了基础路径 https://api.siliconflow.cn/v1,自动补全成 https://api.siliconflow.cn/v1/rerank + if (finalUrl.endsWith("/v1")) { + finalUrl = finalUrl + "/rerank"; + } else if (!finalUrl.endsWith("/rerank")) { + // 如果没有以 /rerank 结尾也不以斜杠结尾,尝试拼接 + finalUrl = finalUrl.endsWith("/") ? finalUrl + "rerank" : finalUrl + "/rerank"; + } + } + this.baseUrl = finalUrl; + log.info("初始化 SiliconFlow 重排模型: URL=[{}], Model=[{}]", this.baseUrl, this.modelName); + + this.client = new OkHttpClient.Builder() + .connectTimeout(60, TimeUnit.SECONDS) + .readTimeout(60, TimeUnit.SECONDS) + .build(); + } + + @Override + public Response> scoreAll(List segments, String query) { + if (isNullOrEmpty(segments)) { + return Response.from(new ArrayList<>()); + } + + List texts = segments.stream() + .map(TextSegment::text) + .collect(Collectors.toList()); + + RerankRequest requestBody = new RerankRequest(); + requestBody.setModel(modelName); + requestBody.setQuery(query); + requestBody.setDocuments(texts); + requestBody.setTop_n(texts.size()); + requestBody.setReturn_documents(false); + + String json = JsonUtils.toJsonString(requestBody); + RequestBody body = RequestBody.create(json, MediaType.parse("application/json; charset=utf-8")); + + Request request = new Request.Builder() + .url(baseUrl) + .header("Authorization", "Bearer " + apiKey) + .post(body) + .build(); + + try (okhttp3.Response response = client.newCall(request).execute()) { + if (!response.isSuccessful()) { + String errorBody = response.body() != null ? response.body().string() : "unknown error"; + log.error("SiliconFlow Rerank API 调用失败: code={}, body={}", response.code(), errorBody); + throw new RuntimeException("SiliconFlow Rerank API 调用失败: " + response.code()); + } + + String responseBody = response.body().string(); + RerankResponse rerankResponse = JsonUtils.parseObject(responseBody, RerankResponse.class); + + if (rerankResponse == null || rerankResponse.getResults() == null) { + return Response.from(new ArrayList<>()); + } + + // 初始化分数组,默认值为 0.0 + Double[] scores = new Double[texts.size()]; + for (int i = 0; i < texts.size(); i++) { + scores[i] = 0.0; + } + + // 填充得分 + rerankResponse.getResults().forEach(item -> { + if (item.getIndex() != null && item.getIndex() < texts.size()) { + scores[item.getIndex()] = item.getRelevance_score(); + } + }); + + List scoreList = new ArrayList<>(); + for (Double s : scores) { + scoreList.add(s); + } + + return Response.from(scoreList); + + } catch (IOException e) { + log.error("SiliconFlow Rerank 网络请求异常", e); + throw new RuntimeException("SiliconFlow Rerank 网络请求异常", e); + } + } + + @Override + public Response score(TextSegment segment, String query) { + List segments = new ArrayList<>(); + segments.add(segment); + Response> response = scoreAll(segments, query); + return Response.from(response.content().get(0)); + } + + @Data + public static class RerankRequest { + private String model; + private String query; + private List documents; + private Integer top_n; + private Boolean return_documents; + } + + @Data + public static class RerankResponse { + private List results; + } + + @Data + public static class RerankResultItem { + private Integer index; + private Double relevance_score; + } +}