mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-04-23 16:53:38 +00:00
feat(rag): 集成硅基流动、阿里百炼重排模型并全方位增强检索测试体验
This commit is contained in:
@@ -69,4 +69,14 @@ public class KnowledgeFragmentBo extends BaseEntity {
|
|||||||
*/
|
*/
|
||||||
private Double threshold;
|
private Double threshold;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 是否启用重排
|
||||||
|
*/
|
||||||
|
private Boolean enableRerank;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 重排模型名称
|
||||||
|
*/
|
||||||
|
private String rerankModel;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -77,6 +77,16 @@ public class KnowledgeInfoBo extends BaseEntity {
|
|||||||
*/
|
*/
|
||||||
private String embeddingModel;
|
private String embeddingModel;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 重排模型
|
||||||
|
*/
|
||||||
|
private String rerankModel;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 是否启用重排(0 否 1 是)
|
||||||
|
*/
|
||||||
|
private Integer enableRerank;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 备注
|
* 备注
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -78,6 +78,16 @@ public class KnowledgeInfo extends BaseEntity {
|
|||||||
*/
|
*/
|
||||||
private String embeddingModel;
|
private String embeddingModel;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 重排模型
|
||||||
|
*/
|
||||||
|
private String rerankModel;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 是否启用重排(0 否 1 是)
|
||||||
|
*/
|
||||||
|
private Integer enableRerank;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 备注
|
* 备注
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -94,6 +94,19 @@ public class KnowledgeInfoVo implements Serializable {
|
|||||||
@ExcelProperty(value = "向量模型")
|
@ExcelProperty(value = "向量模型")
|
||||||
private String embeddingModel;
|
private String embeddingModel;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 重排模型
|
||||||
|
*/
|
||||||
|
@ExcelProperty(value = "重排模型")
|
||||||
|
private String rerankModel;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 是否启用重排(0 否 1 是)
|
||||||
|
*/
|
||||||
|
@ExcelProperty(value = "是否启用重排", converter = ExcelDictConvert.class)
|
||||||
|
@ExcelDictFormat(readConverterExp = "0=否,1=是")
|
||||||
|
private Integer enableRerank;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 备注
|
* 备注
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -28,6 +28,16 @@ public class KnowledgeRetrievalVo implements Serializable {
|
|||||||
*/
|
*/
|
||||||
private Double score;
|
private Double score;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 原始检索排名 (重排前)
|
||||||
|
*/
|
||||||
|
private Integer originalIndex;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 原始检索得分 (重排前)
|
||||||
|
*/
|
||||||
|
private Double rawScore;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 来源文档名称
|
* 来源文档名称
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -377,20 +377,24 @@ public class ChatServiceFacade implements IChatService {
|
|||||||
CustomVectorRetriever retriever = new CustomVectorRetriever(
|
CustomVectorRetriever retriever = new CustomVectorRetriever(
|
||||||
vectorStoreService, knowledgeInfoVo, chatModel);
|
vectorStoreService, knowledgeInfoVo, chatModel);
|
||||||
|
|
||||||
// 2. 获取和构建重排模型聚合器(Aggregator)
|
// 2. 构建重排聚合器 (Aggregator)
|
||||||
// 假设已在 KnowledgeInfoVo 等加入 getRerankModelConfig/getRerankModel 等,这里演示通用逻辑
|
|
||||||
// 若无重排需求,使用 DefaultContentAggregator 或无 ScoringModel 的聚合器
|
|
||||||
ContentAggregator contentAggregator;
|
ContentAggregator contentAggregator;
|
||||||
// TODO: 一旦实体类实现了重排模型的支持,此处可以从数据库读出:
|
if (knowledgeInfoVo.getEnableRerank() != null && knowledgeInfoVo.getEnableRerank() == 1
|
||||||
// ChatModelVo scoringModelConfig = chatModelService.selectModelByName(knowledgeInfoVo.getRerankModel());
|
&& knowledgeInfoVo.getRerankModel() != null) {
|
||||||
ChatModelVo scoringModelConfig = null; // 当前暂无对应配置字段
|
|
||||||
|
|
||||||
|
ChatModelVo scoringModelConfig = chatModelService.selectModelByName(knowledgeInfoVo.getRerankModel());
|
||||||
ScoringModel scoringModel = scoringModelFactory.createScoringModel(scoringModelConfig);
|
ScoringModel scoringModel = scoringModelFactory.createScoringModel(scoringModelConfig);
|
||||||
|
|
||||||
if (scoringModel != null) {
|
if (scoringModel != null) {
|
||||||
contentAggregator = ReRankingContentAggregator.builder()
|
contentAggregator = ReRankingContentAggregator.builder()
|
||||||
.scoringModel(scoringModel)
|
.scoringModel(scoringModel)
|
||||||
// .maxResults(3) 这个数字将来从配置取
|
// 默认重排后只留前 5 条,避免上下文过长
|
||||||
|
.maxResults(5)
|
||||||
.build();
|
.build();
|
||||||
|
log.info("启用重排模型: {}", knowledgeInfoVo.getRerankModel());
|
||||||
|
} else {
|
||||||
|
contentAggregator = new DefaultContentAggregator();
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
contentAggregator = new DefaultContentAggregator();
|
contentAggregator = new DefaultContentAggregator();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
package org.ruoyi.service.knowledge.impl;
|
package org.ruoyi.service.knowledge.impl;
|
||||||
|
|
||||||
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
|
import dev.langchain4j.model.output.Response;
|
||||||
|
import dev.langchain4j.model.scoring.ScoringModel;
|
||||||
import org.ruoyi.common.core.utils.MapstructUtils;
|
import org.ruoyi.common.core.utils.MapstructUtils;
|
||||||
import org.ruoyi.common.core.utils.StringUtils;
|
import org.ruoyi.common.core.utils.StringUtils;
|
||||||
import org.ruoyi.common.mybatis.core.page.TableDataInfo;
|
import org.ruoyi.common.mybatis.core.page.TableDataInfo;
|
||||||
@@ -20,6 +23,7 @@ import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
|||||||
import org.ruoyi.service.knowledge.IKnowledgeFragmentService;
|
import org.ruoyi.service.knowledge.IKnowledgeFragmentService;
|
||||||
import org.ruoyi.service.knowledge.IKnowledgeInfoService;
|
import org.ruoyi.service.knowledge.IKnowledgeInfoService;
|
||||||
import org.ruoyi.common.chat.service.chat.IChatModelService;
|
import org.ruoyi.common.chat.service.chat.IChatModelService;
|
||||||
|
import org.ruoyi.service.knowledge.rerank.ScoringModelFactory;
|
||||||
import org.ruoyi.service.vector.VectorStoreService;
|
import org.ruoyi.service.vector.VectorStoreService;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -44,6 +48,7 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
|
|||||||
private final IKnowledgeInfoService knowledgeInfoService;
|
private final IKnowledgeInfoService knowledgeInfoService;
|
||||||
private final IChatModelService chatModelService;
|
private final IChatModelService chatModelService;
|
||||||
private final VectorStoreService vectorStoreService;
|
private final VectorStoreService vectorStoreService;
|
||||||
|
private final ScoringModelFactory scoringModelFactory;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 查询知识片段
|
* 查询知识片段
|
||||||
@@ -178,7 +183,48 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
|
|||||||
// 3. 执行物理检索
|
// 3. 执行物理检索
|
||||||
List<KnowledgeRetrievalVo> allResults = vectorStoreService.search(queryVectorBo);
|
List<KnowledgeRetrievalVo> allResults = vectorStoreService.search(queryVectorBo);
|
||||||
|
|
||||||
// 4. 根据阈值过滤 (LangChain4j 结果 score 通常 0-1)
|
// 初始化原始排名
|
||||||
|
for (int i = 0; i < allResults.size(); i++) {
|
||||||
|
allResults.get(i).setOriginalIndex(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. 执行重排逻辑 (如果请求启用重排且配置了重排模型)
|
||||||
|
if (Boolean.TRUE.equals(bo.getEnableRerank()) && StringUtils.isNotBlank(bo.getRerankModel())) {
|
||||||
|
log.info("开始重排配置检索测试,传入模型名称: [{}]", bo.getRerankModel());
|
||||||
|
ChatModelVo rerankModelConfig = chatModelService.selectModelByName(bo.getRerankModel());
|
||||||
|
|
||||||
|
if (rerankModelConfig == null) {
|
||||||
|
log.warn("未能找到重排模型配置: [{}]", bo.getRerankModel());
|
||||||
|
} else {
|
||||||
|
ScoringModel scoringModel = scoringModelFactory.createScoringModel(rerankModelConfig);
|
||||||
|
if (scoringModel != null) {
|
||||||
|
log.info("执行重排精排,模型: {}, 供应商: {}", rerankModelConfig.getModelName(), rerankModelConfig.getProviderCode());
|
||||||
|
|
||||||
|
// 将 KnowledgeRetrievalVo 转换为 TextSegment 列表进行重排
|
||||||
|
List<TextSegment> segments = allResults.stream()
|
||||||
|
.map(res -> TextSegment.from(res.getContent()))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
Response<List<Double>> scoresResponse = scoringModel.scoreAll(segments, bo.getQuery());
|
||||||
|
List<Double> scores = scoresResponse.content();
|
||||||
|
|
||||||
|
// 更新分数并重新排序
|
||||||
|
for (int i = 0; i < allResults.size(); i++) {
|
||||||
|
KnowledgeRetrievalVo resultVo = allResults.get(i);
|
||||||
|
// 保存原始分数供前端展示对比
|
||||||
|
resultVo.setRawScore(resultVo.getScore());
|
||||||
|
if (i < scores.size()) {
|
||||||
|
resultVo.setScore(scores.get(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 按重排后的分数从高到低排序
|
||||||
|
allResults.sort((a, b) -> b.getScore().compareTo(a.getScore()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. 根据阈值过滤
|
||||||
double threshold = bo.getThreshold() != null ? bo.getThreshold() : 0.0;
|
double threshold = bo.getThreshold() != null ? bo.getThreshold() : 0.0;
|
||||||
return allResults.stream()
|
return allResults.stream()
|
||||||
.filter(res -> res.getScore() >= threshold)
|
.filter(res -> res.getScore() >= threshold)
|
||||||
|
|||||||
@@ -0,0 +1,98 @@
|
|||||||
|
package org.ruoyi.service.knowledge.rerank;
|
||||||
|
|
||||||
|
import com.alibaba.dashscope.exception.ApiException;
|
||||||
|
import com.alibaba.dashscope.exception.InputRequiredException;
|
||||||
|
import com.alibaba.dashscope.exception.NoApiKeyException;
|
||||||
|
import com.alibaba.dashscope.rerank.TextReRank;
|
||||||
|
import com.alibaba.dashscope.rerank.TextReRankParam;
|
||||||
|
import com.alibaba.dashscope.rerank.TextReRankResult;
|
||||||
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
|
import dev.langchain4j.model.output.Response;
|
||||||
|
import dev.langchain4j.model.scoring.ScoringModel;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* DashScope 重排模型实现 (GTE-Rerank)
|
||||||
|
* 包装了阿里云 DashScope 的 TextReRank API,使其符合 LangChain4j 的 ScoringModel 标准。
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class DashScopeScoringModel implements ScoringModel {
|
||||||
|
|
||||||
|
private final String apiKey;
|
||||||
|
private final String modelName;
|
||||||
|
private final TextReRank rerank;
|
||||||
|
|
||||||
|
@Builder
|
||||||
|
public DashScopeScoringModel(String apiKey, String modelName) {
|
||||||
|
if (isNullOrEmpty(apiKey)) {
|
||||||
|
throw new IllegalArgumentException("DashScope API Key 不能为空");
|
||||||
|
}
|
||||||
|
this.apiKey = apiKey;
|
||||||
|
this.modelName = isNullOrEmpty(modelName) ? "gte-rerank" : modelName;
|
||||||
|
this.rerank = new TextReRank();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Response<List<Double>> scoreAll(List<TextSegment> segments, String query) {
|
||||||
|
if (isNullOrEmpty(segments)) {
|
||||||
|
return Response.from(new ArrayList<>());
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提取文本列表供阿里 SDK 使用
|
||||||
|
List<String> texts = segments.stream()
|
||||||
|
.map(TextSegment::text)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
try {
|
||||||
|
TextReRankParam param = TextReRankParam.builder()
|
||||||
|
.apiKey(apiKey)
|
||||||
|
.model(modelName)
|
||||||
|
.query(query)
|
||||||
|
.documents(texts)
|
||||||
|
.topN(texts.size())
|
||||||
|
.returnDocuments(false)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
TextReRankResult result = rerank.call(param);
|
||||||
|
|
||||||
|
// 初始化分数组,默认值为 0.0
|
||||||
|
Double[] scores = new Double[texts.size()];
|
||||||
|
for (int i = 0; i < texts.size(); i++) {
|
||||||
|
scores[i] = 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据返回结果填充对应的分数值(返回结果中包含原文索引)
|
||||||
|
result.getOutput().getResults().forEach(item -> {
|
||||||
|
if (item.getIndex() != null && item.getIndex() < texts.size()) {
|
||||||
|
scores[item.getIndex()] = item.getRelevanceScore();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
List<Double> scoreList = new ArrayList<>();
|
||||||
|
for (Double s : scores) {
|
||||||
|
scoreList.add(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Response.from(scoreList);
|
||||||
|
|
||||||
|
} catch (ApiException | NoApiKeyException | InputRequiredException e) {
|
||||||
|
log.error("DashScope 重排处理出错: {}", e.getMessage(), e);
|
||||||
|
throw new RuntimeException("调用 DashScope 重排服务失败", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Response<Double> score(TextSegment segment, String query) {
|
||||||
|
List<TextSegment> segments = new ArrayList<>();
|
||||||
|
segments.add(segment);
|
||||||
|
Response<List<Double>> response = scoreAll(segments, query);
|
||||||
|
return Response.from(response.content().get(0), response.tokenUsage(), response.finishReason());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -27,10 +27,27 @@ public class ScoringModelFactory {
|
|||||||
}
|
}
|
||||||
|
|
||||||
String providerCode = rerankModelConfig.getProviderCode();
|
String providerCode = rerankModelConfig.getProviderCode();
|
||||||
log.info("初始化重排模型,供应商代码: {}", providerCode);
|
log.info("初始化重排模型,供应商代码: {}, 模型名称: {}", providerCode, rerankModelConfig.getModelName());
|
||||||
|
|
||||||
// TODO: 在这里通过 switch 或反射具体实例化支持的各种 ScoringModel (例如 CohereScoringModel, DascScope 等)
|
try {
|
||||||
// 目前返回 null 代表暂时没有加载特定的重排底座,这不会影响流程,Aggregator 会忽略它返回原样结果
|
if ("alibailian".equalsIgnoreCase(providerCode)) {
|
||||||
|
return DashScopeScoringModel.builder()
|
||||||
|
.apiKey(rerankModelConfig.getApiKey())
|
||||||
|
.modelName(rerankModelConfig.getModelName())
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
if ("siliconflow".equalsIgnoreCase(providerCode)) {
|
||||||
|
return SiliconFlowScoringModel.builder()
|
||||||
|
.apiKey(rerankModelConfig.getApiKey())
|
||||||
|
.modelName(rerankModelConfig.getModelName())
|
||||||
|
// 如果后台配置了不同的 API Host,可以在此传递,否则使用默认值
|
||||||
|
.baseUrl(rerankModelConfig.getApiHost())
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("创建重排模型失败: {}", e.getMessage(), e);
|
||||||
|
}
|
||||||
|
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,155 @@
|
|||||||
|
package org.ruoyi.service.knowledge.rerank;
|
||||||
|
|
||||||
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
|
import dev.langchain4j.model.output.Response;
|
||||||
|
import dev.langchain4j.model.scoring.ScoringModel;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import okhttp3.*;
|
||||||
|
import org.ruoyi.common.json.utils.JsonUtils;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* SiliconFlow 重排模型实现
|
||||||
|
* 适配硅基流动的 /v1/rerank 接口
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class SiliconFlowScoringModel implements ScoringModel {
|
||||||
|
|
||||||
|
private final String apiKey;
|
||||||
|
private final String modelName;
|
||||||
|
private final String baseUrl;
|
||||||
|
private final OkHttpClient client;
|
||||||
|
|
||||||
|
@Builder
|
||||||
|
public SiliconFlowScoringModel(String apiKey, String modelName, String baseUrl) {
|
||||||
|
if (isNullOrEmpty(apiKey)) {
|
||||||
|
throw new IllegalArgumentException("SiliconFlow API Key 不能为空");
|
||||||
|
}
|
||||||
|
this.apiKey = apiKey;
|
||||||
|
this.modelName = isNullOrEmpty(modelName) ? "BAAI/bge-reranker-v2-m3" : modelName;
|
||||||
|
|
||||||
|
// 鲁棒性处理:自动补全 /rerank 路径
|
||||||
|
String finalUrl = baseUrl;
|
||||||
|
if (isNullOrEmpty(finalUrl)) {
|
||||||
|
finalUrl = "https://api.siliconflow.cn/v1/rerank";
|
||||||
|
} else {
|
||||||
|
// 如果用户只填了基础路径 https://api.siliconflow.cn/v1,自动补全成 https://api.siliconflow.cn/v1/rerank
|
||||||
|
if (finalUrl.endsWith("/v1")) {
|
||||||
|
finalUrl = finalUrl + "/rerank";
|
||||||
|
} else if (!finalUrl.endsWith("/rerank")) {
|
||||||
|
// 如果没有以 /rerank 结尾也不以斜杠结尾,尝试拼接
|
||||||
|
finalUrl = finalUrl.endsWith("/") ? finalUrl + "rerank" : finalUrl + "/rerank";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
this.baseUrl = finalUrl;
|
||||||
|
log.info("初始化 SiliconFlow 重排模型: URL=[{}], Model=[{}]", this.baseUrl, this.modelName);
|
||||||
|
|
||||||
|
this.client = new OkHttpClient.Builder()
|
||||||
|
.connectTimeout(60, TimeUnit.SECONDS)
|
||||||
|
.readTimeout(60, TimeUnit.SECONDS)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Response<List<Double>> scoreAll(List<TextSegment> segments, String query) {
|
||||||
|
if (isNullOrEmpty(segments)) {
|
||||||
|
return Response.from(new ArrayList<>());
|
||||||
|
}
|
||||||
|
|
||||||
|
List<String> texts = segments.stream()
|
||||||
|
.map(TextSegment::text)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
RerankRequest requestBody = new RerankRequest();
|
||||||
|
requestBody.setModel(modelName);
|
||||||
|
requestBody.setQuery(query);
|
||||||
|
requestBody.setDocuments(texts);
|
||||||
|
requestBody.setTop_n(texts.size());
|
||||||
|
requestBody.setReturn_documents(false);
|
||||||
|
|
||||||
|
String json = JsonUtils.toJsonString(requestBody);
|
||||||
|
RequestBody body = RequestBody.create(json, MediaType.parse("application/json; charset=utf-8"));
|
||||||
|
|
||||||
|
Request request = new Request.Builder()
|
||||||
|
.url(baseUrl)
|
||||||
|
.header("Authorization", "Bearer " + apiKey)
|
||||||
|
.post(body)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
try (okhttp3.Response response = client.newCall(request).execute()) {
|
||||||
|
if (!response.isSuccessful()) {
|
||||||
|
String errorBody = response.body() != null ? response.body().string() : "unknown error";
|
||||||
|
log.error("SiliconFlow Rerank API 调用失败: code={}, body={}", response.code(), errorBody);
|
||||||
|
throw new RuntimeException("SiliconFlow Rerank API 调用失败: " + response.code());
|
||||||
|
}
|
||||||
|
|
||||||
|
String responseBody = response.body().string();
|
||||||
|
RerankResponse rerankResponse = JsonUtils.parseObject(responseBody, RerankResponse.class);
|
||||||
|
|
||||||
|
if (rerankResponse == null || rerankResponse.getResults() == null) {
|
||||||
|
return Response.from(new ArrayList<>());
|
||||||
|
}
|
||||||
|
|
||||||
|
// 初始化分数组,默认值为 0.0
|
||||||
|
Double[] scores = new Double[texts.size()];
|
||||||
|
for (int i = 0; i < texts.size(); i++) {
|
||||||
|
scores[i] = 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 填充得分
|
||||||
|
rerankResponse.getResults().forEach(item -> {
|
||||||
|
if (item.getIndex() != null && item.getIndex() < texts.size()) {
|
||||||
|
scores[item.getIndex()] = item.getRelevance_score();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
List<Double> scoreList = new ArrayList<>();
|
||||||
|
for (Double s : scores) {
|
||||||
|
scoreList.add(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Response.from(scoreList);
|
||||||
|
|
||||||
|
} catch (IOException e) {
|
||||||
|
log.error("SiliconFlow Rerank 网络请求异常", e);
|
||||||
|
throw new RuntimeException("SiliconFlow Rerank 网络请求异常", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Response<Double> score(TextSegment segment, String query) {
|
||||||
|
List<TextSegment> segments = new ArrayList<>();
|
||||||
|
segments.add(segment);
|
||||||
|
Response<List<Double>> response = scoreAll(segments, query);
|
||||||
|
return Response.from(response.content().get(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public static class RerankRequest {
|
||||||
|
private String model;
|
||||||
|
private String query;
|
||||||
|
private List<String> documents;
|
||||||
|
private Integer top_n;
|
||||||
|
private Boolean return_documents;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public static class RerankResponse {
|
||||||
|
private List<RerankResultItem> results;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public static class RerankResultItem {
|
||||||
|
private Integer index;
|
||||||
|
private Double relevance_score;
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user