mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-04-23 16:53:38 +00:00
fix(rag): 修复合并重复,重排模型新增硅基流动供应商
This commit is contained in:
@@ -98,12 +98,19 @@ public class KnowledgeInfoBo extends BaseEntity {
|
|||||||
private Double rerankScoreThreshold;
|
private Double rerankScoreThreshold;
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 是否启用混合检索(0 否 1是)
|
||||||
|
*/
|
||||||
|
private Integer enableHybrid;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 混合检索权重 (0.0-1.0)
|
||||||
|
*/
|
||||||
|
private Double hybridAlpha;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 备注
|
* 备注
|
||||||
*/
|
*/
|
||||||
private String remark;
|
private String remark;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -98,6 +98,16 @@ public class KnowledgeInfo extends BaseEntity {
|
|||||||
*/
|
*/
|
||||||
private Double rerankScoreThreshold;
|
private Double rerankScoreThreshold;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 是否启用混合检索(0 否 1是)
|
||||||
|
*/
|
||||||
|
private Integer enableHybrid;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 混合检索权重 (0.0-1.0)
|
||||||
|
*/
|
||||||
|
private Double hybridAlpha;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 备注
|
* 备注
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -118,6 +118,24 @@ public class KnowledgeInfoVo implements Serializable {
|
|||||||
@ExcelProperty(value = "重排序分数阈值")
|
@ExcelProperty(value = "重排序分数阈值")
|
||||||
private Double rerankScoreThreshold;
|
private Double rerankScoreThreshold;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 是否启用混合检索(0 否 1是)
|
||||||
|
*/
|
||||||
|
@ExcelProperty(value = "是否启用混合检索")
|
||||||
|
private Integer enableHybrid;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 混合检索权重 (0.0-1.0)
|
||||||
|
*/
|
||||||
|
@ExcelProperty(value = "混合检索权重")
|
||||||
|
private Double hybridAlpha;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 文档数量
|
||||||
|
*/
|
||||||
|
@ExcelProperty(value = "文档数量")
|
||||||
|
private Integer documentCount;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 备注
|
* 备注
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -1,39 +1,36 @@
|
|||||||
package org.ruoyi.service.knowledge.impl;
|
package org.ruoyi.service.knowledge.impl;
|
||||||
|
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
|
||||||
import dev.langchain4j.model.output.Response;
|
|
||||||
import dev.langchain4j.model.scoring.ScoringModel;
|
|
||||||
import org.ruoyi.common.core.utils.MapstructUtils;
|
|
||||||
import org.ruoyi.common.core.utils.StringUtils;
|
|
||||||
import org.ruoyi.common.mybatis.core.page.TableDataInfo;
|
|
||||||
import org.ruoyi.common.mybatis.core.page.PageQuery;
|
|
||||||
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
|
||||||
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
||||||
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
|
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
|
||||||
|
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo;
|
||||||
|
import org.ruoyi.common.chat.service.chat.IChatModelService;
|
||||||
|
import org.ruoyi.common.core.utils.MapstructUtils;
|
||||||
|
import org.ruoyi.common.core.utils.StringUtils;
|
||||||
|
import org.ruoyi.common.mybatis.core.page.PageQuery;
|
||||||
|
import org.ruoyi.common.mybatis.core.page.TableDataInfo;
|
||||||
import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo;
|
import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo;
|
||||||
|
import org.ruoyi.domain.bo.rerank.RerankRequest;
|
||||||
|
import org.ruoyi.domain.bo.rerank.RerankResult;
|
||||||
|
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||||
import org.ruoyi.domain.entity.knowledge.KnowledgeFragment;
|
import org.ruoyi.domain.entity.knowledge.KnowledgeFragment;
|
||||||
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
|
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
|
||||||
import org.ruoyi.mapper.knowledge.KnowledgeFragmentMapper;
|
|
||||||
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
|
||||||
import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo;
|
import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo;
|
||||||
import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo;
|
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||||
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
import org.ruoyi.factory.RerankModelFactory;
|
||||||
|
import org.ruoyi.mapper.knowledge.KnowledgeFragmentMapper;
|
||||||
import org.ruoyi.service.knowledge.IKnowledgeFragmentService;
|
import org.ruoyi.service.knowledge.IKnowledgeFragmentService;
|
||||||
import org.ruoyi.service.knowledge.IKnowledgeInfoService;
|
import org.ruoyi.service.knowledge.IKnowledgeInfoService;
|
||||||
import org.ruoyi.common.chat.service.chat.IChatModelService;
|
import org.ruoyi.service.rerank.RerankModelService;
|
||||||
import org.ruoyi.service.knowledge.rerank.ScoringModelFactory;
|
|
||||||
import org.ruoyi.service.vector.VectorStoreService;
|
import org.ruoyi.service.vector.VectorStoreService;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.concurrent.CompletableFuture;
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Collection;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 知识片段Service业务层处理
|
* 知识片段Service业务层处理
|
||||||
@@ -50,7 +47,7 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
|
|||||||
private final IKnowledgeInfoService knowledgeInfoService;
|
private final IKnowledgeInfoService knowledgeInfoService;
|
||||||
private final IChatModelService chatModelService;
|
private final IChatModelService chatModelService;
|
||||||
private final VectorStoreService vectorStoreService;
|
private final VectorStoreService vectorStoreService;
|
||||||
private final ScoringModelFactory scoringModelFactory;
|
private final RerankModelFactory rerankModelFactory;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 查询知识片段
|
* 查询知识片段
|
||||||
@@ -231,37 +228,38 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
|
|||||||
|
|
||||||
// 4. 执行重排逻辑 (如果请求启用重排且配置了重排模型)
|
// 4. 执行重排逻辑 (如果请求启用重排且配置了重排模型)
|
||||||
if (Boolean.TRUE.equals(bo.getEnableRerank()) && StringUtils.isNotBlank(bo.getRerankModel())) {
|
if (Boolean.TRUE.equals(bo.getEnableRerank()) && StringUtils.isNotBlank(bo.getRerankModel())) {
|
||||||
log.info("开始重排配置检索测试,传入模型名称: [{}]", bo.getRerankModel());
|
log.info("开始重排精排,模型: [{}]", bo.getRerankModel());
|
||||||
ChatModelVo rerankModelConfig = chatModelService.selectModelByName(bo.getRerankModel());
|
try {
|
||||||
|
RerankModelService rerankModel = rerankModelFactory.createModel(bo.getRerankModel());
|
||||||
if (rerankModelConfig == null) {
|
|
||||||
log.warn("未能找到重排模型配置: [{}]", bo.getRerankModel());
|
|
||||||
} else {
|
|
||||||
ScoringModel scoringModel = scoringModelFactory.createScoringModel(rerankModelConfig);
|
|
||||||
if (scoringModel != null) {
|
|
||||||
log.info("执行重排精排,模型: {}, 供应商: {}", rerankModelConfig.getModelName(), rerankModelConfig.getProviderCode());
|
|
||||||
|
|
||||||
// 将 KnowledgeRetrievalVo 转换为 TextSegment 列表进行重排
|
List<String> contents = allResults.stream()
|
||||||
List<TextSegment> segments = allResults.stream()
|
.map(KnowledgeRetrievalVo::getContent)
|
||||||
.map(res -> TextSegment.from(res.getContent()))
|
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
Response<List<Double>> scoresResponse = scoringModel.scoreAll(segments, bo.getQuery());
|
RerankRequest rerankRequest = RerankRequest.builder()
|
||||||
List<Double> scores = scoresResponse.content();
|
.query(bo.getQuery())
|
||||||
|
.documents(contents)
|
||||||
|
.topN(contents.size())
|
||||||
|
.returnDocuments(false)
|
||||||
|
.build();
|
||||||
|
|
||||||
// 更新分数并重新排序
|
RerankResult rerankResult = rerankModel.rerank(rerankRequest);
|
||||||
for (int i = 0; i < allResults.size(); i++) {
|
|
||||||
KnowledgeRetrievalVo resultVo = allResults.get(i);
|
// 将重排分数写回,并记录原始分数供前端对比
|
||||||
// 保存原始分数供前端展示对比
|
for (RerankResult.RerankDocument doc : rerankResult.getDocuments()) {
|
||||||
|
if (doc.getIndex() != null && doc.getIndex() < allResults.size()) {
|
||||||
|
KnowledgeRetrievalVo resultVo = allResults.get(doc.getIndex());
|
||||||
resultVo.setRawScore(resultVo.getScore());
|
resultVo.setRawScore(resultVo.getScore());
|
||||||
if (i < scores.size()) {
|
resultVo.setScore(doc.getRelevanceScore());
|
||||||
resultVo.setScore(scores.get(i));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 按重排后的分数从高到低排序
|
|
||||||
allResults.sort((a, b) -> b.getScore().compareTo(a.getScore()));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 按重排后的分数从高到低排序
|
||||||
|
allResults.sort((a, b) -> b.getScore().compareTo(a.getScore()));
|
||||||
|
log.info("重排精排完成,结果数: {}", allResults.size());
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("重排精排执行失败,已跳过重排步骤: {}", e.getMessage(), e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,98 +0,0 @@
|
|||||||
package org.ruoyi.service.knowledge.rerank;
|
|
||||||
|
|
||||||
import com.alibaba.dashscope.exception.ApiException;
|
|
||||||
import com.alibaba.dashscope.exception.InputRequiredException;
|
|
||||||
import com.alibaba.dashscope.exception.NoApiKeyException;
|
|
||||||
import com.alibaba.dashscope.rerank.TextReRank;
|
|
||||||
import com.alibaba.dashscope.rerank.TextReRankParam;
|
|
||||||
import com.alibaba.dashscope.rerank.TextReRankResult;
|
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
|
||||||
import dev.langchain4j.model.output.Response;
|
|
||||||
import dev.langchain4j.model.scoring.ScoringModel;
|
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* DashScope 重排模型实现 (GTE-Rerank)
|
|
||||||
* 包装了阿里云 DashScope 的 TextReRank API,使其符合 LangChain4j 的 ScoringModel 标准。
|
|
||||||
*/
|
|
||||||
@Slf4j
|
|
||||||
public class DashScopeScoringModel implements ScoringModel {
|
|
||||||
|
|
||||||
private final String apiKey;
|
|
||||||
private final String modelName;
|
|
||||||
private final TextReRank rerank;
|
|
||||||
|
|
||||||
@Builder
|
|
||||||
public DashScopeScoringModel(String apiKey, String modelName) {
|
|
||||||
if (isNullOrEmpty(apiKey)) {
|
|
||||||
throw new IllegalArgumentException("DashScope API Key 不能为空");
|
|
||||||
}
|
|
||||||
this.apiKey = apiKey;
|
|
||||||
this.modelName = isNullOrEmpty(modelName) ? "gte-rerank" : modelName;
|
|
||||||
this.rerank = new TextReRank();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Response<List<Double>> scoreAll(List<TextSegment> segments, String query) {
|
|
||||||
if (isNullOrEmpty(segments)) {
|
|
||||||
return Response.from(new ArrayList<>());
|
|
||||||
}
|
|
||||||
|
|
||||||
// 提取文本列表供阿里 SDK 使用
|
|
||||||
List<String> texts = segments.stream()
|
|
||||||
.map(TextSegment::text)
|
|
||||||
.collect(Collectors.toList());
|
|
||||||
|
|
||||||
try {
|
|
||||||
TextReRankParam param = TextReRankParam.builder()
|
|
||||||
.apiKey(apiKey)
|
|
||||||
.model(modelName)
|
|
||||||
.query(query)
|
|
||||||
.documents(texts)
|
|
||||||
.topN(texts.size())
|
|
||||||
.returnDocuments(false)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
TextReRankResult result = rerank.call(param);
|
|
||||||
|
|
||||||
// 初始化分数组,默认值为 0.0
|
|
||||||
Double[] scores = new Double[texts.size()];
|
|
||||||
for (int i = 0; i < texts.size(); i++) {
|
|
||||||
scores[i] = 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 根据返回结果填充对应的分数值(返回结果中包含原文索引)
|
|
||||||
result.getOutput().getResults().forEach(item -> {
|
|
||||||
if (item.getIndex() != null && item.getIndex() < texts.size()) {
|
|
||||||
scores[item.getIndex()] = item.getRelevanceScore();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
List<Double> scoreList = new ArrayList<>();
|
|
||||||
for (Double s : scores) {
|
|
||||||
scoreList.add(s);
|
|
||||||
}
|
|
||||||
|
|
||||||
return Response.from(scoreList);
|
|
||||||
|
|
||||||
} catch (ApiException | NoApiKeyException | InputRequiredException e) {
|
|
||||||
log.error("DashScope 重排处理出错: {}", e.getMessage(), e);
|
|
||||||
throw new RuntimeException("调用 DashScope 重排服务失败", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Response<Double> score(TextSegment segment, String query) {
|
|
||||||
List<TextSegment> segments = new ArrayList<>();
|
|
||||||
segments.add(segment);
|
|
||||||
Response<List<Double>> response = scoreAll(segments, query);
|
|
||||||
return Response.from(response.content().get(0), response.tokenUsage(), response.finishReason());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,54 +0,0 @@
|
|||||||
package org.ruoyi.service.knowledge.rerank;
|
|
||||||
|
|
||||||
import dev.langchain4j.model.scoring.ScoringModel;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo;
|
|
||||||
import org.springframework.stereotype.Component;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 重排模型提供商工厂
|
|
||||||
* 用于将来无缝拓展硅基流动、百炼等支持重排的模型厂商
|
|
||||||
*
|
|
||||||
* @author RobustH
|
|
||||||
*/
|
|
||||||
@Slf4j
|
|
||||||
@Component
|
|
||||||
public class ScoringModelFactory {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 根据后台传递的模型配置创建具体的重排模型
|
|
||||||
*
|
|
||||||
* @param rerankModelConfig 重排模型的配置 (例如其 providerCode, apiUrl, apiKey 等)
|
|
||||||
* @return 标准的 LangChain4j ScoringModel
|
|
||||||
*/
|
|
||||||
public ScoringModel createScoringModel(ChatModelVo rerankModelConfig) {
|
|
||||||
if (rerankModelConfig == null) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
String providerCode = rerankModelConfig.getProviderCode();
|
|
||||||
log.info("初始化重排模型,供应商代码: {}, 模型名称: {}", providerCode, rerankModelConfig.getModelName());
|
|
||||||
|
|
||||||
try {
|
|
||||||
if ("alibailian".equalsIgnoreCase(providerCode)) {
|
|
||||||
return DashScopeScoringModel.builder()
|
|
||||||
.apiKey(rerankModelConfig.getApiKey())
|
|
||||||
.modelName(rerankModelConfig.getModelName())
|
|
||||||
.build();
|
|
||||||
}
|
|
||||||
|
|
||||||
if ("siliconflow".equalsIgnoreCase(providerCode)) {
|
|
||||||
return SiliconFlowScoringModel.builder()
|
|
||||||
.apiKey(rerankModelConfig.getApiKey())
|
|
||||||
.modelName(rerankModelConfig.getModelName())
|
|
||||||
// 如果后台配置了不同的 API Host,可以在此传递,否则使用默认值
|
|
||||||
.baseUrl(rerankModelConfig.getApiHost())
|
|
||||||
.build();
|
|
||||||
}
|
|
||||||
} catch (Exception e) {
|
|
||||||
log.error("创建重排模型失败: {}", e.getMessage(), e);
|
|
||||||
}
|
|
||||||
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,155 +0,0 @@
|
|||||||
package org.ruoyi.service.knowledge.rerank;
|
|
||||||
|
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
|
||||||
import dev.langchain4j.model.output.Response;
|
|
||||||
import dev.langchain4j.model.scoring.ScoringModel;
|
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import okhttp3.*;
|
|
||||||
import org.ruoyi.common.json.utils.JsonUtils;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.concurrent.TimeUnit;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* SiliconFlow 重排模型实现
|
|
||||||
* 适配硅基流动的 /v1/rerank 接口
|
|
||||||
*/
|
|
||||||
@Slf4j
|
|
||||||
public class SiliconFlowScoringModel implements ScoringModel {
|
|
||||||
|
|
||||||
private final String apiKey;
|
|
||||||
private final String modelName;
|
|
||||||
private final String baseUrl;
|
|
||||||
private final OkHttpClient client;
|
|
||||||
|
|
||||||
@Builder
|
|
||||||
public SiliconFlowScoringModel(String apiKey, String modelName, String baseUrl) {
|
|
||||||
if (isNullOrEmpty(apiKey)) {
|
|
||||||
throw new IllegalArgumentException("SiliconFlow API Key 不能为空");
|
|
||||||
}
|
|
||||||
this.apiKey = apiKey;
|
|
||||||
this.modelName = isNullOrEmpty(modelName) ? "BAAI/bge-reranker-v2-m3" : modelName;
|
|
||||||
|
|
||||||
// 鲁棒性处理:自动补全 /rerank 路径
|
|
||||||
String finalUrl = baseUrl;
|
|
||||||
if (isNullOrEmpty(finalUrl)) {
|
|
||||||
finalUrl = "https://api.siliconflow.cn/v1/rerank";
|
|
||||||
} else {
|
|
||||||
// 如果用户只填了基础路径 https://api.siliconflow.cn/v1,自动补全成 https://api.siliconflow.cn/v1/rerank
|
|
||||||
if (finalUrl.endsWith("/v1")) {
|
|
||||||
finalUrl = finalUrl + "/rerank";
|
|
||||||
} else if (!finalUrl.endsWith("/rerank")) {
|
|
||||||
// 如果没有以 /rerank 结尾也不以斜杠结尾,尝试拼接
|
|
||||||
finalUrl = finalUrl.endsWith("/") ? finalUrl + "rerank" : finalUrl + "/rerank";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
this.baseUrl = finalUrl;
|
|
||||||
log.info("初始化 SiliconFlow 重排模型: URL=[{}], Model=[{}]", this.baseUrl, this.modelName);
|
|
||||||
|
|
||||||
this.client = new OkHttpClient.Builder()
|
|
||||||
.connectTimeout(60, TimeUnit.SECONDS)
|
|
||||||
.readTimeout(60, TimeUnit.SECONDS)
|
|
||||||
.build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Response<List<Double>> scoreAll(List<TextSegment> segments, String query) {
|
|
||||||
if (isNullOrEmpty(segments)) {
|
|
||||||
return Response.from(new ArrayList<>());
|
|
||||||
}
|
|
||||||
|
|
||||||
List<String> texts = segments.stream()
|
|
||||||
.map(TextSegment::text)
|
|
||||||
.collect(Collectors.toList());
|
|
||||||
|
|
||||||
RerankRequest requestBody = new RerankRequest();
|
|
||||||
requestBody.setModel(modelName);
|
|
||||||
requestBody.setQuery(query);
|
|
||||||
requestBody.setDocuments(texts);
|
|
||||||
requestBody.setTop_n(texts.size());
|
|
||||||
requestBody.setReturn_documents(false);
|
|
||||||
|
|
||||||
String json = JsonUtils.toJsonString(requestBody);
|
|
||||||
RequestBody body = RequestBody.create(json, MediaType.parse("application/json; charset=utf-8"));
|
|
||||||
|
|
||||||
Request request = new Request.Builder()
|
|
||||||
.url(baseUrl)
|
|
||||||
.header("Authorization", "Bearer " + apiKey)
|
|
||||||
.post(body)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
try (okhttp3.Response response = client.newCall(request).execute()) {
|
|
||||||
if (!response.isSuccessful()) {
|
|
||||||
String errorBody = response.body() != null ? response.body().string() : "unknown error";
|
|
||||||
log.error("SiliconFlow Rerank API 调用失败: code={}, body={}", response.code(), errorBody);
|
|
||||||
throw new RuntimeException("SiliconFlow Rerank API 调用失败: " + response.code());
|
|
||||||
}
|
|
||||||
|
|
||||||
String responseBody = response.body().string();
|
|
||||||
RerankResponse rerankResponse = JsonUtils.parseObject(responseBody, RerankResponse.class);
|
|
||||||
|
|
||||||
if (rerankResponse == null || rerankResponse.getResults() == null) {
|
|
||||||
return Response.from(new ArrayList<>());
|
|
||||||
}
|
|
||||||
|
|
||||||
// 初始化分数组,默认值为 0.0
|
|
||||||
Double[] scores = new Double[texts.size()];
|
|
||||||
for (int i = 0; i < texts.size(); i++) {
|
|
||||||
scores[i] = 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 填充得分
|
|
||||||
rerankResponse.getResults().forEach(item -> {
|
|
||||||
if (item.getIndex() != null && item.getIndex() < texts.size()) {
|
|
||||||
scores[item.getIndex()] = item.getRelevance_score();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
List<Double> scoreList = new ArrayList<>();
|
|
||||||
for (Double s : scores) {
|
|
||||||
scoreList.add(s);
|
|
||||||
}
|
|
||||||
|
|
||||||
return Response.from(scoreList);
|
|
||||||
|
|
||||||
} catch (IOException e) {
|
|
||||||
log.error("SiliconFlow Rerank 网络请求异常", e);
|
|
||||||
throw new RuntimeException("SiliconFlow Rerank 网络请求异常", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Response<Double> score(TextSegment segment, String query) {
|
|
||||||
List<TextSegment> segments = new ArrayList<>();
|
|
||||||
segments.add(segment);
|
|
||||||
Response<List<Double>> response = scoreAll(segments, query);
|
|
||||||
return Response.from(response.content().get(0));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public static class RerankRequest {
|
|
||||||
private String model;
|
|
||||||
private String query;
|
|
||||||
private List<String> documents;
|
|
||||||
private Integer top_n;
|
|
||||||
private Boolean return_documents;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public static class RerankResponse {
|
|
||||||
private List<RerankResultItem> results;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public static class RerankResultItem {
|
|
||||||
private Integer index;
|
|
||||||
private Double relevance_score;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,174 @@
|
|||||||
|
package org.ruoyi.service.rerank.impl;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.DeserializationFeature;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import okhttp3.*;
|
||||||
|
import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo;
|
||||||
|
import org.ruoyi.domain.bo.rerank.RerankRequest;
|
||||||
|
import org.ruoyi.domain.bo.rerank.RerankResult;
|
||||||
|
import org.ruoyi.service.rerank.RerankModelService;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 硅基流动重排序模型实现
|
||||||
|
* 适配硅基流动的 /v1/rerank 接口
|
||||||
|
*
|
||||||
|
* @author RobustH
|
||||||
|
* @date 2026-04-21
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
@Component("siliconflowRerank")
|
||||||
|
public class SiliconFlowRerankModelService implements RerankModelService {
|
||||||
|
|
||||||
|
private static final String DEFAULT_BASE_URL = "https://api.siliconflow.cn/v1/rerank";
|
||||||
|
|
||||||
|
private final OkHttpClient okHttpClient;
|
||||||
|
private final ObjectMapper objectMapper = new ObjectMapper()
|
||||||
|
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
|
||||||
|
private ChatModelVo chatModelVo;
|
||||||
|
|
||||||
|
public SiliconFlowRerankModelService() {
|
||||||
|
this.okHttpClient = new OkHttpClient.Builder()
|
||||||
|
.connectTimeout(30, TimeUnit.SECONDS)
|
||||||
|
.readTimeout(60, TimeUnit.SECONDS)
|
||||||
|
.writeTimeout(30, TimeUnit.SECONDS)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void configure(ChatModelVo config) {
|
||||||
|
this.chatModelVo = config;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public RerankResult rerank(RerankRequest rerankRequest) {
|
||||||
|
long startTime = System.currentTimeMillis();
|
||||||
|
|
||||||
|
try {
|
||||||
|
String url = buildUrl();
|
||||||
|
String requestJson = buildRequestJson(rerankRequest);
|
||||||
|
|
||||||
|
RequestBody body = RequestBody.create(requestJson, MediaType.get("application/json"));
|
||||||
|
Request httpRequest = new Request.Builder()
|
||||||
|
.url(url)
|
||||||
|
.addHeader("Authorization", "Bearer " + chatModelVo.getApiKey())
|
||||||
|
.addHeader("Content-Type", "application/json")
|
||||||
|
.post(body)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
log.info("硅基流动重排序请求: model={}, url={}", chatModelVo.getModelName(), url);
|
||||||
|
|
||||||
|
try (Response response = okHttpClient.newCall(httpRequest).execute()) {
|
||||||
|
if (!response.isSuccessful()) {
|
||||||
|
String err = response.body() != null ? response.body().string() : "无错误信息";
|
||||||
|
throw new IllegalArgumentException("硅基流动 Rerank API 调用失败: " + response.code() + " - " + err);
|
||||||
|
}
|
||||||
|
|
||||||
|
ResponseBody responseBody = response.body();
|
||||||
|
if (responseBody == null) {
|
||||||
|
throw new IllegalArgumentException("响应体为空");
|
||||||
|
}
|
||||||
|
|
||||||
|
SiliconFlowRerankResponse rerankResponse = objectMapper.readValue(
|
||||||
|
responseBody.string(), SiliconFlowRerankResponse.class);
|
||||||
|
|
||||||
|
return buildRerankResult(rerankResponse, rerankRequest.getDocuments(),
|
||||||
|
System.currentTimeMillis() - startTime);
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("硅基流动重排序失败: {}", e.getMessage(), e);
|
||||||
|
throw new RuntimeException("重排序服务调用失败: " + e.getMessage(), e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 构建请求 URL,鲁棒处理 API Host 末尾路径
|
||||||
|
*/
|
||||||
|
private String buildUrl() {
|
||||||
|
String apiHost = chatModelVo.getApiHost();
|
||||||
|
if (apiHost == null || apiHost.isBlank()) {
|
||||||
|
return DEFAULT_BASE_URL;
|
||||||
|
}
|
||||||
|
if (apiHost.endsWith("/rerank")) {
|
||||||
|
return apiHost;
|
||||||
|
}
|
||||||
|
if (apiHost.endsWith("/v1")) {
|
||||||
|
return apiHost + "/rerank";
|
||||||
|
}
|
||||||
|
return apiHost.endsWith("/") ? apiHost + "rerank" : apiHost + "/rerank";
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 构建请求体 JSON
|
||||||
|
*/
|
||||||
|
private String buildRequestJson(RerankRequest rerankRequest) throws IOException {
|
||||||
|
SiliconFlowRerankRequest request = new SiliconFlowRerankRequest();
|
||||||
|
request.setModel(chatModelVo.getModelName());
|
||||||
|
request.setQuery(rerankRequest.getQuery());
|
||||||
|
request.setDocuments(rerankRequest.getDocuments());
|
||||||
|
request.setTop_n(rerankRequest.getTopN() != null ? rerankRequest.getTopN() : rerankRequest.getDocuments().size());
|
||||||
|
request.setReturn_documents(rerankRequest.getReturnDocuments() != null ? rerankRequest.getReturnDocuments() : false);
|
||||||
|
return objectMapper.writeValueAsString(request);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 构建标准 RerankResult
|
||||||
|
*/
|
||||||
|
private RerankResult buildRerankResult(SiliconFlowRerankResponse response,
|
||||||
|
List<String> originalDocuments, long durationMs) {
|
||||||
|
Double[] scores = new Double[originalDocuments.size()];
|
||||||
|
for (int i = 0; i < scores.length; i++) {
|
||||||
|
scores[i] = 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
List<RerankResult.RerankDocument> docs = new ArrayList<>();
|
||||||
|
if (response != null && response.getResults() != null) {
|
||||||
|
response.getResults().forEach(item -> {
|
||||||
|
if (item.getIndex() != null && item.getIndex() < originalDocuments.size()) {
|
||||||
|
scores[item.getIndex()] = item.getRelevance_score();
|
||||||
|
docs.add(RerankResult.RerankDocument.builder()
|
||||||
|
.index(item.getIndex())
|
||||||
|
.relevanceScore(item.getRelevance_score())
|
||||||
|
.document(originalDocuments.get(item.getIndex()))
|
||||||
|
.build());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
return RerankResult.builder()
|
||||||
|
.documents(docs)
|
||||||
|
.totalDocuments(originalDocuments.size())
|
||||||
|
.durationMs(durationMs)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== 内部 DTO ====================
|
||||||
|
|
||||||
|
@Data
|
||||||
|
static class SiliconFlowRerankRequest {
|
||||||
|
private String model;
|
||||||
|
private String query;
|
||||||
|
private List<String> documents;
|
||||||
|
private Integer top_n;
|
||||||
|
private Boolean return_documents;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
static class SiliconFlowRerankResponse {
|
||||||
|
private List<SiliconFlowRerankResultItem> results;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
static class SiliconFlowRerankResultItem {
|
||||||
|
private Integer index;
|
||||||
|
private Double relevance_score;
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user