本地向量化

This commit is contained in:
jiahao.he@vtradex.com
2025-03-16 20:01:34 +08:00
parent 0e6c1c47d5
commit 4967c3f906
7 changed files with 445 additions and 15 deletions

View File

@@ -0,0 +1,38 @@
package org.ruoyi.common.chat.entity.models;
import lombok.Data;
import java.util.List;
/**
 * Request payload passed to {@code SearchService#vectorize} for the local model service.
 *
 * <p>The fields deliberately keep their snake_case names; presumably these are the JSON keys
 * the Flask service expects (confirm against the service before renaming anything).
 *
 * @author hejh
 */
@Data
public class LocalModelsSearchRequest {

    /** Texts to send to the service. */
    private List<String> text;

    /** Name of the model. */
    private String model_name;

    /** Text delimiter. */
    private String delimiter;

    /** Number of results to return. */
    private int k;

    /** Size of a text block. */
    private int block_size;

    /** Number of overlapping characters. */
    private int overlap_chars;

    /**
     * Creates a fully populated request.
     *
     * @param text          texts to send to the service
     * @param model_name    name of the model
     * @param delimiter     text delimiter
     * @param k             number of results to return
     * @param block_size    size of a text block
     * @param overlap_chars number of overlapping characters
     */
    public LocalModelsSearchRequest(
            List<String> text,
            String model_name,
            String delimiter,
            int k,
            int block_size,
            int overlap_chars) {
        this.text = text;
        this.model_name = model_name;
        this.delimiter = delimiter;
        this.k = k;
        this.block_size = block_size;
        this.overlap_chars = overlap_chars;
    }
}

View File

@@ -0,0 +1,20 @@
package org.ruoyi.common.chat.entity.models;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
import java.util.List;
/**
 * Response payload returned by the local model service.
 *
 * <p>JSON properties that are not declared here are ignored during deserialization.
 */
@JsonIgnoreProperties(ignoreUnknown = true)
@Data
public class LocalModelsSearchResponse {

    /** Content of the {@code topKEmbeddings} JSON property: a three-level nested array. */
    @JsonProperty("topKEmbeddings")
    private List<List<List<Double>>> topKEmbeddings;

    /** Explicit no-argument constructor. */
    public LocalModelsSearchResponse() {
    }
}

View File

@@ -0,0 +1,198 @@
package org.ruoyi.common.chat.localModels;
import io.micrometer.common.util.StringUtils;
import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient;
import org.ruoyi.common.chat.entity.models.LocalModelsSearchRequest;
import org.ruoyi.common.chat.entity.models.LocalModelsSearchResponse;
import org.springframework.stereotype.Service;
import retrofit2.Call;
import retrofit2.Callback;
import retrofit2.Response;
import retrofit2.Retrofit;
import retrofit2.converter.jackson.JacksonConverterFactory;
import java.util.List;
import java.util.concurrent.CountDownLatch;
@Slf4j
@Service
public class LocalModelsofitClient {
private static final String BASE_URL = "http://127.0.0.1:5000"; // Flask 服务的 URL
private static Retrofit retrofit = null;
// 获取 Retrofit 实例
public static Retrofit getRetrofitInstance() {
if (retrofit == null) {
OkHttpClient client = new OkHttpClient.Builder()
.build();
retrofit = new Retrofit.Builder()
.baseUrl(BASE_URL)
.client(client)
.addConverterFactory(JacksonConverterFactory.create()) // 使用 Jackson 处理 JSON 转换
.build();
}
return retrofit;
}
/**
* 向 Flask 服务发送文本向量化请求
*
* @param queries 查询文本列表
* @param modelName 模型名称
* @param delimiter 文本分隔符
* @param topK 返回的结果数
* @param blockSize 文本块大小
* @param overlapChars 重叠字符数
* @return 返回计算得到的 Top K 嵌入向量列表
*/
public static List<List<Double>> getTopKEmbeddings(
List<String> queries,
String modelName,
String delimiter,
int topK,
int blockSize,
int overlapChars) {
modelName = (!StringUtils.isEmpty(modelName)) ? modelName : "msmarco-distilbert-base-tas-b"; // 默认模型名称
delimiter = (!StringUtils.isEmpty(delimiter) ) ? delimiter : "."; // 默认分隔符
topK = (topK > 0) ? topK : 3; // 默认返回 3 个结果
blockSize = (blockSize > 0) ? blockSize : 500; // 默认文本块大小为 500
overlapChars = (overlapChars > 0) ? overlapChars : 50; // 默认重叠字符数为 50
// 创建 Retrofit 实例
Retrofit retrofit = getRetrofitInstance();
// 创建 SearchService 接口
SearchService service = retrofit.create(SearchService.class);
// 创建请求对象 LocalModelsSearchRequest
LocalModelsSearchRequest request = new LocalModelsSearchRequest(
queries, // 查询文本列表
modelName, // 模型名称
delimiter, // 文本分隔符
topK, // 返回的结果数
blockSize, // 文本块大小
overlapChars // 重叠字符数
);
final CountDownLatch latch = new CountDownLatch(1); // 创建一个 CountDownLatch
final List<List<Double>>[] topKEmbeddings = new List[]{null}; // 使用数组来存储结果(因为 Java 不支持直接修改 List
// 发起异步请求
service.vectorize(request).enqueue(new Callback<LocalModelsSearchResponse>() {
@Override
public void onResponse(Call<LocalModelsSearchResponse> call, Response<LocalModelsSearchResponse> response) {
if (response.isSuccessful()) {
LocalModelsSearchResponse searchResponse = response.body();
if (searchResponse != null) {
topKEmbeddings[0] = searchResponse.getTopKEmbeddings().get(0); // 获取结果
log.info("Successfully retrieved embeddings");
} else {
log.error("Response body is null");
}
} else {
log.error("Request failed. HTTP error code: " + response.code());
}
latch.countDown(); // 请求完成,减少计数
}
@Override
public void onFailure(Call<LocalModelsSearchResponse> call, Throwable t) {
t.printStackTrace();
log.error("Request failed: ", t);
latch.countDown(); // 请求失败,减少计数
}
});
try {
latch.await(); // 等待请求完成
} catch (InterruptedException e) {
e.printStackTrace();
}
return topKEmbeddings[0]; // 返回结果
}
// public static void main(String[] args) {
// // 示例调用
// List<String> queries = Arrays.asList("What is artificial intelligence?", "AI is transforming industries.");
// String modelName = "msmarco-distilbert-base-tas-b";
// String delimiter = ".";
// int topK = 3;
// int blockSize = 500;
// int overlapChars = 50;
//
// List<List<Double>> topKEmbeddings = getTopKEmbeddings(queries, modelName, delimiter, topK, blockSize, overlapChars);
//
// // 打印结果
// if (topKEmbeddings != null) {
// System.out.println("Top K embeddings: ");
// for (List<Double> embedding : topKEmbeddings) {
// System.out.println(embedding);
// }
// } else {
// System.out.println("No embeddings returned.");
// }
// }
// public static void main(String[] args) {
// // 创建 Retrofit 实例
// Retrofit retrofit = LocalModelsofitClient.getRetrofitInstance();
//
// // 创建 SearchService 接口
// SearchService service = retrofit.create(SearchService.class);
//
// // 创建请求对象 LocalModelsSearchRequest
// LocalModelsSearchRequest request = new LocalModelsSearchRequest(
// Arrays.asList("What is artificial intelligence?", "AI is transforming industries."), // 查询文本列表
// "msmarco-distilbert-base-tas-b", // 模型名称
// ".", // 分隔符
// 3, // 返回的结果数
// 500, // 文本块大小
// 50 // 重叠字符数
// );
//
// // 发起请求
// service.vectorize(request).enqueue(new Callback<LocalModelsSearchResponse>() {
// @Override
// public void onResponse(Call<LocalModelsSearchResponse> call, Response<LocalModelsSearchResponse> response) {
// if (response.isSuccessful()) {
// LocalModelsSearchResponse searchResponse = response.body();
// System.out.println("Response Body: " + response.body()); // Print the whole response body for debugging
//
// if (searchResponse != null) {
// // If the response is not null, process it.
// // Example: Extract the embeddings and print them
// List<List<List<Double>>> topKEmbeddings = searchResponse.getTopKEmbeddings();
// if (topKEmbeddings != null) {
// // Print the Top K embeddings
//
// } else {
// System.err.println("Top K embeddings are null");
// }
//
// // If there is more information you want to process, handle it here
//
// } else {
// System.err.println("Response body is null");
// }
// } else {
// System.err.println("Request failed. HTTP error code: " + response.code());
// log.error("Failed to retrieve data. HTTP error code: " + response.code());
// }
// }
//
// @Override
// public void onFailure(Call<LocalModelsSearchResponse> call, Throwable t) {
// // 请求失败,打印错误
// t.printStackTrace();
// log.error("Request failed: ", t);
// }
// });
// }
}

View File

@@ -0,0 +1,25 @@
package org.ruoyi.common.chat.localModels;
import org.ruoyi.common.chat.entity.models.LocalModelsSearchRequest;
import org.ruoyi.common.chat.entity.models.LocalModelsSearchResponse;
import retrofit2.Call;
import retrofit2.http.Body;
import retrofit2.http.POST;
/**
 * Retrofit service definition used to call the local model service.
 *
 * @author hejh
 */
public interface SearchService {

    /**
     * Posts the given request as the body of a call to {@code /vectorize}; the path is meant
     * to match the route exposed by the Flask service.
     *
     * @param request the request payload
     * @return a call that yields the parsed response
     */
    @POST("/vectorize")
    Call<LocalModelsSearchResponse> vectorize(@Body LocalModelsSearchRequest request);
}

View File

@@ -0,0 +1,92 @@
package org.ruoyi.knowledge.chain.vectorizer;
import jakarta.annotation.Resource;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.chat.config.ChatConfig;
import org.ruoyi.common.chat.localModels.LocalModelsofitClient;
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
import org.ruoyi.knowledge.domain.vo.KnowledgeInfoVo;
import org.ruoyi.knowledge.service.IKnowledgeInfoService;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
/**
 * Vectorizes text through the local model service, using the settings stored on the
 * knowledge record identified by the given knowledge ID.
 */
@Component
@Slf4j
@RequiredArgsConstructor
public class LocalModelsVectorization {

    @Resource
    private IKnowledgeInfoService knowledgeInfoService;

    // Kept for existing wiring; the client's operations are static and are called on the class.
    @Resource
    private LocalModelsofitClient localModelsofitClient;

    @Getter
    private OpenAiStreamClient openAiStreamClient;

    private final ChatConfig chatConfig;

    /**
     * Vectorizes a batch of text chunks.
     *
     * @param chunkList text chunks to vectorize
     * @param kid       knowledge ID
     * @return the embeddings returned by the local model service; never {@code null}
     * @throws RuntimeException if the knowledge record is missing, the request fails, or the
     *                          service returns no embeddings
     */
    public List<List<Double>> batchVectorization(List<String> chunkList, String kid) {
        logVectorizationRequest(kid, chunkList);
        // Refreshes the client exposed through the generated getter.
        openAiStreamClient = chatConfig.getOpenAiStreamClient();
        KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
        try {
            if (knowledgeInfoVo == null) {
                throw new IllegalStateException("Knowledge info not found for id: " + kid);
            }
            List<List<Double>> vectors = LocalModelsofitClient.getTopKEmbeddings(
                    chunkList,
                    knowledgeInfoVo.getVector(),
                    knowledgeInfoVo.getKnowledgeSeparator(),
                    knowledgeInfoVo.getRetrieveLimit(),
                    knowledgeInfoVo.getTextBlockSize(),
                    knowledgeInfoVo.getOverlapChar());
            if (vectors == null) {
                // The client signals failure with null; raise it so callers can react
                // (for example by falling back to another provider) instead of hitting an
                // NPE later on.
                throw new IllegalStateException("Local model service returned no embeddings");
            }
            return vectors;
        } catch (Exception e) {
            log.error("Failed to perform batch vectorization for knowledgeId: {}", kid, e);
            throw new RuntimeException("Batch vectorization failed", e);
        }
    }

    /**
     * Vectorizes a single text chunk by delegating to {@link #batchVectorization}.
     *
     * @param chunk the text chunk
     * @param kid   knowledge ID
     * @return the first returned vector, or an empty list when none was returned
     */
    public List<Double> singleVectorization(String chunk, String kid) {
        List<String> chunkList = new ArrayList<>();
        chunkList.add(chunk);
        List<List<Double>> vectorList = batchVectorization(chunkList, kid);
        if (vectorList.isEmpty()) {
            log.warn("Vectorization returned empty list for chunk: {}", chunk);
            return new ArrayList<>();
        }
        return vectorList.get(0);
    }

    /**
     * Logs the start of a vectorization request.
     *
     * @param kid       knowledge ID
     * @param chunkList text chunks about to be vectorized
     */
    private void logVectorizationRequest(String kid, List<String> chunkList) {
        log.info("Starting vectorization for Knowledge ID: {} with {} chunks.", kid, chunkList.size());
    }
}

View File

@@ -18,6 +18,7 @@ import org.springframework.stereotype.Component;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
@Component
@Slf4j
@@ -27,6 +28,9 @@ public class OpenAiVectorization implements Vectorization {
// Service used to look up the knowledge record by ID; marked @Lazy.
@Lazy
@Resource
private IKnowledgeInfoService knowledgeInfoService;
// Delegate that vectorizes through the local model service; marked @Lazy.
@Lazy
@Resource
private LocalModelsVectorization localModelsVectorization;
// OpenAI client; assigned from chatConfig inside batchVectorization and exposed via @Getter.
@Getter
private OpenAiStreamClient openAiStreamClient;
@@ -35,25 +39,63 @@ public class OpenAiVectorization implements Vectorization {
/**
 * Vectorizes a batch of text chunks, trying the local model service first and falling back
 * to OpenAI embeddings when the local attempt throws or yields no result.
 *
 * @param chunkList text chunks to vectorize
 * @param kid       knowledge ID
 * @return one vector per embedding returned by the provider that served the request
 */
@Override
public List<List<Double>> batchVectorization(List<String> chunkList, String kid) {
    openAiStreamClient = chatConfig.getOpenAiStreamClient();
    // Knowledge record; supplies the embedding model name for the OpenAI fallback.
    KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));

    try {
        List<List<Double>> localVectors = localModelsVectorization.batchVectorization(chunkList, kid);
        if (localVectors != null) {
            return localVectors;
        }
        // A null result means the local service produced nothing; treat it like a failure.
        log.warn("Local models vectorization returned no result, falling back to OpenAI embeddings");
    } catch (Exception e) {
        log.error("Local models vectorization failed, falling back to OpenAI embeddings", e);
    }

    // Fallback: request the embeddings from OpenAI exactly once.
    Embedding embedding = buildEmbedding(chunkList, knowledgeInfoVo);
    EmbeddingResponse embeddings = openAiStreamClient.embeddings(embedding);
    return processOpenAiEmbeddings(embeddings);
}
/**
 * Builds the embedding request for the given chunks, using the vector model configured on
 * the knowledge record.
 *
 * @param chunkList       text chunks used as the request input
 * @param knowledgeInfoVo knowledge record that supplies the model name
 * @return the embedding request
 */
private Embedding buildEmbedding(List<String> chunkList, KnowledgeInfoVo knowledgeInfoVo) {
    return Embedding.builder().input(chunkList).model(knowledgeInfoVo.getVectorModel()).build();
}
/**
 * Converts every entry of an OpenAI embedding response into a list of doubles.
 *
 * @param embeddings the embedding response
 * @return one double vector per entry in the response data, in the same order
 */
private List<List<Double>> processOpenAiEmbeddings(EmbeddingResponse embeddings) {
    List<List<Double>> converted = new ArrayList<>();
    embeddings.getData().forEach(item -> converted.add(convertToDoubleList(item.getEmbedding())));
    return converted;
}
/**
 * Converts a list of {@link BigDecimal} values into a list of doubles.
 *
 * @param vector the values to convert
 * @return a new list holding the {@code doubleValue()} of each element, in the same order
 */
private List<Double> convertToDoubleList(List<BigDecimal> vector) {
    List<Double> doubles = new ArrayList<>();
    for (BigDecimal component : vector) {
        doubles.add(component.doubleValue());
    }
    return doubles;
}
@Override
public List<Double> singleVectorization(String chunk, String kid) {
List<String> chunkList = new ArrayList<>();

View File

@@ -0,0 +1,15 @@
package org.ruoyi.knowledge.chain.vectorizer;
/**
 * Available vectorization providers.
 */
public enum VectorizationType {

    /** OpenAI vectorization. */
    OPENAI,

    /** Local model vectorization. */
    LOCAL;

    /**
     * Resolves a constant from its name, ignoring case.
     *
     * @param type the constant name in any letter case
     * @return the matching constant
     * @throws IllegalArgumentException if {@code type} is {@code null} or matches no constant
     */
    public static VectorizationType fromString(String type) {
        if (type != null) {
            VectorizationType[] candidates = values();
            for (int i = 0; i < candidates.length; i++) {
                if (type.equalsIgnoreCase(candidates[i].name())) {
                    return candidates[i];
                }
            }
        }
        throw new IllegalArgumentException("Unknown VectorizationType: " + type);
    }
}