From 4967c3f906b32184aa2e617fea034900bb564aa2 Mon Sep 17 00:00:00 2001 From: "jiahao.he@vtradex.com" <794629435@qq.com> Date: Sun, 16 Mar 2025 20:01:34 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9C=AC=E5=9C=B0=E5=90=91=E9=87=8F=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../models/LocalModelsSearchRequest.java | 38 ++++ .../models/LocalModelsSearchResponse.java | 20 ++ .../localModels/LocalModelsofitClient.java | 198 ++++++++++++++++++ .../chat/localModels/SearchService.java | 25 +++ .../vectorizer/LocalModelsVectorization.java | 92 ++++++++ .../chain/vectorizer/OpenAiVectorization.java | 72 +++++-- .../chain/vectorizer/VectorizationType.java | 15 ++ 7 files changed, 445 insertions(+), 15 deletions(-) create mode 100644 ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchRequest.java create mode 100644 ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchResponse.java create mode 100644 ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/LocalModelsofitClient.java create mode 100644 ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/SearchService.java create mode 100644 ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/LocalModelsVectorization.java create mode 100644 ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/VectorizationType.java diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchRequest.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchRequest.java new file mode 100644 index 00000000..4ca71bab --- /dev/null +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchRequest.java @@ -0,0 +1,38 @@ +package org.ruoyi.common.chat.entity.models; + +import lombok.Data; + +import java.util.List; + +/** + * @program: RUOYIAI + * @ClassName LocalModelsSearchRequest + * @description: + * @author: hejh + * @create: 2025-03-15 17:22 + * @Version 1.0 + **/ +@Data +public class LocalModelsSearchRequest { + + private List text; + private String model_name; + private String delimiter; + private int k; + private int block_size; + private int overlap_chars; + + // 构造函数、Getter 和 Setter + public LocalModelsSearchRequest(List text, String model_name, String delimiter, int k, int block_size, int overlap_chars) { + this.text = text; + this.model_name = model_name; + this.delimiter = delimiter; + this.k = k; + this.block_size = block_size; + this.overlap_chars = overlap_chars; + } + + +} + + diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchResponse.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchResponse.java new file mode 100644 index 00000000..12025d5c --- /dev/null +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchResponse.java @@ -0,0 +1,20 @@ +package org.ruoyi.common.chat.entity.models; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Data; + +import java.util.List; + +@Data +@JsonIgnoreProperties(ignoreUnknown = true) +public class LocalModelsSearchResponse { + @JsonProperty("topKEmbeddings") + + private List>> topKEmbeddings; // 处理三层嵌套数组 + + // 默认构造函数 + public LocalModelsSearchResponse() {} + + + +} diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/LocalModelsofitClient.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/LocalModelsofitClient.java new file mode 100644 index 00000000..606a7c25 --- /dev/null +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/LocalModelsofitClient.java @@ -0,0 +1,198 @@ +package org.ruoyi.common.chat.localModels; + +import io.micrometer.common.util.StringUtils; +import lombok.extern.slf4j.Slf4j; +import okhttp3.OkHttpClient; +import org.ruoyi.common.chat.entity.models.LocalModelsSearchRequest; +import org.ruoyi.common.chat.entity.models.LocalModelsSearchResponse; +import org.springframework.stereotype.Service; +import retrofit2.Call; +import retrofit2.Callback; +import retrofit2.Response; +import retrofit2.Retrofit; +import retrofit2.converter.jackson.JacksonConverterFactory; + +import java.util.List; +import java.util.concurrent.CountDownLatch; + +@Slf4j +@Service +public class LocalModelsofitClient { + private static final String BASE_URL = "http://127.0.0.1:5000"; // Flask 服务的 URL + private static Retrofit retrofit = null; + + // 获取 Retrofit 实例 + public static Retrofit getRetrofitInstance() { + if (retrofit == null) { + OkHttpClient client = new OkHttpClient.Builder() + .build(); + + retrofit = new Retrofit.Builder() + .baseUrl(BASE_URL) + .client(client) + .addConverterFactory(JacksonConverterFactory.create()) // 使用 Jackson 处理 JSON 转换 + .build(); + } + return retrofit; + } + + /** + * 向 Flask 服务发送文本向量化请求 + * + * @param queries 查询文本列表 + * @param modelName 模型名称 + * @param delimiter 文本分隔符 + * @param topK 返回的结果数 + * @param blockSize 文本块大小 + * @param overlapChars 重叠字符数 + * @return 返回计算得到的 Top K 嵌入向量列表 + */ + + public static List> getTopKEmbeddings( + List queries, + String modelName, + String delimiter, + int topK, + int blockSize, + int overlapChars) { + + modelName = (!StringUtils.isEmpty(modelName)) ? modelName : "msmarco-distilbert-base-tas-b"; // 默认模型名称 + delimiter = (!StringUtils.isEmpty(delimiter) ) ? delimiter : "."; // 默认分隔符 + topK = (topK > 0) ? topK : 3; // 默认返回 3 个结果 + blockSize = (blockSize > 0) ? blockSize : 500; // 默认文本块大小为 500 + overlapChars = (overlapChars > 0) ? overlapChars : 50; // 默认重叠字符数为 50 + + // 创建 Retrofit 实例 + Retrofit retrofit = getRetrofitInstance(); + + // 创建 SearchService 接口 + SearchService service = retrofit.create(SearchService.class); + + // 创建请求对象 LocalModelsSearchRequest + LocalModelsSearchRequest request = new LocalModelsSearchRequest( + queries, // 查询文本列表 + modelName, // 模型名称 + delimiter, // 文本分隔符 + topK, // 返回的结果数 + blockSize, // 文本块大小 + overlapChars // 重叠字符数 + ); + + final CountDownLatch latch = new CountDownLatch(1); // 创建一个 CountDownLatch + final List>[] topKEmbeddings = new List[]{null}; // 使用数组来存储结果(因为 Java 不支持直接修改 List) + + // 发起异步请求 + service.vectorize(request).enqueue(new Callback() { + @Override + public void onResponse(Call call, Response response) { + if (response.isSuccessful()) { + LocalModelsSearchResponse searchResponse = response.body(); + if (searchResponse != null) { + topKEmbeddings[0] = searchResponse.getTopKEmbeddings().get(0); // 获取结果 + log.info("Successfully retrieved embeddings"); + } else { + log.error("Response body is null"); + } + } else { + log.error("Request failed. HTTP error code: " + response.code()); + } + latch.countDown(); // 请求完成,减少计数 + } + + @Override + public void onFailure(Call call, Throwable t) { + t.printStackTrace(); + log.error("Request failed: ", t); + latch.countDown(); // 请求失败,减少计数 + } + }); + + try { + latch.await(); // 等待请求完成 + } catch (InterruptedException e) { + e.printStackTrace(); + } + + return topKEmbeddings[0]; // 返回结果 + } + +// public static void main(String[] args) { +// // 示例调用 +// List queries = Arrays.asList("What is artificial intelligence?", "AI is transforming industries."); +// String modelName = "msmarco-distilbert-base-tas-b"; +// String delimiter = "."; +// int topK = 3; +// int blockSize = 500; +// int overlapChars = 50; +// +// List> topKEmbeddings = getTopKEmbeddings(queries, modelName, delimiter, topK, blockSize, overlapChars); +// +// // 打印结果 +// if (topKEmbeddings != null) { +// System.out.println("Top K embeddings: "); +// for (List embedding : topKEmbeddings) { +// System.out.println(embedding); +// } +// } else { +// System.out.println("No embeddings returned."); +// } +// } + + +// public static void main(String[] args) { +// // 创建 Retrofit 实例 +// Retrofit retrofit = LocalModelsofitClient.getRetrofitInstance(); +// +// // 创建 SearchService 接口 +// SearchService service = retrofit.create(SearchService.class); +// +// // 创建请求对象 LocalModelsSearchRequest +// LocalModelsSearchRequest request = new LocalModelsSearchRequest( +// Arrays.asList("What is artificial intelligence?", "AI is transforming industries."), // 查询文本列表 +// "msmarco-distilbert-base-tas-b", // 模型名称 +// ".", // 分隔符 +// 3, // 返回的结果数 +// 500, // 文本块大小 +// 50 // 重叠字符数 +// ); +// +// // 发起请求 +// service.vectorize(request).enqueue(new Callback() { +// @Override +// public void onResponse(Call call, Response response) { +// if (response.isSuccessful()) { +// LocalModelsSearchResponse searchResponse = response.body(); +// System.out.println("Response Body: " + response.body()); // Print the whole response body for debugging +// +// if (searchResponse != null) { +// // If the response is not null, process it. +// // Example: Extract the embeddings and print them +// List>> topKEmbeddings = searchResponse.getTopKEmbeddings(); +// if (topKEmbeddings != null) { +// // Print the Top K embeddings +// +// } else { +// System.err.println("Top K embeddings are null"); +// } +// +// // If there is more information you want to process, handle it here +// +// } else { +// System.err.println("Response body is null"); +// } +// } else { +// System.err.println("Request failed. HTTP error code: " + response.code()); +// log.error("Failed to retrieve data. HTTP error code: " + response.code()); +// } +// } +// +// @Override +// public void onFailure(Call call, Throwable t) { +// // 请求失败,打印错误 +// t.printStackTrace(); +// log.error("Request failed: ", t); +// } +// }); +// } + +} diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/SearchService.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/SearchService.java new file mode 100644 index 00000000..3fa131e5 --- /dev/null +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/SearchService.java @@ -0,0 +1,25 @@ +package org.ruoyi.common.chat.localModels; + + + +import org.ruoyi.common.chat.entity.models.LocalModelsSearchRequest; +import org.ruoyi.common.chat.entity.models.LocalModelsSearchResponse; +import retrofit2.Call; +import retrofit2.http.Body; +import retrofit2.http.POST; +/** + * @program: RUOYIAI + * @ClassName SearchService + * @description: 请求模型 + * @author: hejh + * @create: 2025-03-15 17:27 + * @Version 1.0 + **/ + + +public interface SearchService { + @POST("/vectorize") // 与 Flask 服务中的路由匹配 + Call vectorize(@Body LocalModelsSearchRequest request); +} + + diff --git a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/LocalModelsVectorization.java b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/LocalModelsVectorization.java new file mode 100644 index 00000000..d7dff252 --- /dev/null +++ b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/LocalModelsVectorization.java @@ -0,0 +1,92 @@ +package org.ruoyi.knowledge.chain.vectorizer; + +import jakarta.annotation.Resource; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.ruoyi.common.chat.config.ChatConfig; +import org.ruoyi.common.chat.localModels.LocalModelsofitClient; +import org.ruoyi.common.chat.openai.OpenAiStreamClient; +import org.ruoyi.knowledge.domain.vo.KnowledgeInfoVo; +import org.ruoyi.knowledge.service.IKnowledgeInfoService; +import org.springframework.stereotype.Component; + +import java.util.ArrayList; +import java.util.List; + +@Component +@Slf4j +@RequiredArgsConstructor +public class LocalModelsVectorization { + @Resource + private IKnowledgeInfoService knowledgeInfoService; + + @Resource + private LocalModelsofitClient localModelsofitClient; + + @Getter + private OpenAiStreamClient openAiStreamClient; + + private final ChatConfig chatConfig; + + /** + * 批量向量化 + * + * @param chunkList 文本块列表 + * @param kid 知识 ID + * @return 向量化结果 + */ + + public List> batchVectorization(List chunkList, String kid) { + logVectorizationRequest(kid, chunkList); // 在向量化开始前记录日志 + openAiStreamClient = chatConfig.getOpenAiStreamClient(); // 获取 OpenAi 客户端 + KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid)); // 查询知识信息 + // 调用 localModelsofitClient 获取 Top K 嵌入向量 + try { + return localModelsofitClient.getTopKEmbeddings( + chunkList, + knowledgeInfoVo.getVector(), + knowledgeInfoVo.getKnowledgeSeparator(), + knowledgeInfoVo.getRetrieveLimit(), + knowledgeInfoVo.getTextBlockSize(), + knowledgeInfoVo.getOverlapChar() + ); + } catch (Exception e) { + log.error("Failed to perform batch vectorization for knowledgeId: {}", kid, e); + throw new RuntimeException("Batch vectorization failed", e); + } + } + + /** + * 单一文本块向量化 + * + * @param chunk 单一文本块 + * @param kid 知识 ID + * @return 向量化结果 + */ + + public List singleVectorization(String chunk, String kid) { + List chunkList = new ArrayList<>(); + chunkList.add(chunk); + + // 调用批量向量化方法 + List> vectorList = batchVectorization(chunkList, kid); + + if (vectorList.isEmpty()) { + log.warn("Vectorization returned empty list for chunk: {}", chunk); + return new ArrayList<>(); + } + + return vectorList.get(0); // 返回第一个向量 + } + + /** + * 提供更简洁的日志记录方法 + * + * @param kid 知识 ID + * @param chunkList 文本块列表 + */ + private void logVectorizationRequest(String kid, List chunkList) { + log.info("Starting vectorization for Knowledge ID: {} with {} chunks.", kid, chunkList.size()); + } +} diff --git a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/OpenAiVectorization.java b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/OpenAiVectorization.java index 0f2d0ba5..764c2c16 100644 --- a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/OpenAiVectorization.java +++ b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/OpenAiVectorization.java @@ -18,6 +18,7 @@ import org.springframework.stereotype.Component; import java.math.BigDecimal; import java.util.ArrayList; import java.util.List; +import java.util.stream.Collectors; @Component @Slf4j @@ -27,6 +28,9 @@ public class OpenAiVectorization implements Vectorization { @Lazy @Resource private IKnowledgeInfoService knowledgeInfoService; + @Lazy + @Resource + private LocalModelsVectorization localModelsVectorization; @Getter private OpenAiStreamClient openAiStreamClient; @@ -35,25 +39,63 @@ public class OpenAiVectorization implements Vectorization { @Override public List> batchVectorization(List chunkList, String kid) { - openAiStreamClient = chatConfig.getOpenAiStreamClient(); - KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid)); - Embedding embedding = Embedding.builder() - .input(chunkList) - .model(knowledgeInfoVo.getVectorModel()) - .build(); - EmbeddingResponse embeddings = openAiStreamClient.embeddings(embedding); List> vectorList = new ArrayList<>(); - embeddings.getData().forEach(data -> { - List vector = data.getEmbedding(); - List doubleVector = new ArrayList<>(); - for (BigDecimal bd : vector) { - doubleVector.add(bd.doubleValue()); - } - vectorList.add(doubleVector); - }); + + // 获取知识库信息 + KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid)); + + // 如果使用本地模型 + try { + return localModelsVectorization.batchVectorization(chunkList, kid); + } catch (Exception e) { + log.error("Local models vectorization failed, falling back to OpenAI embeddings", e); + } + + // 如果本地模型失败,则调用 OpenAI 服务进行向量化 + Embedding embedding = buildEmbedding(chunkList, knowledgeInfoVo); + EmbeddingResponse embeddings = openAiStreamClient.embeddings(embedding); + + // 处理 OpenAI 返回的嵌入数据 + vectorList = processOpenAiEmbeddings(embeddings); + return vectorList; } + /** + * 构建 Embedding 对象 + */ + private Embedding buildEmbedding(List chunkList, KnowledgeInfoVo knowledgeInfoVo) { + return Embedding.builder() + .input(chunkList) + .model(knowledgeInfoVo.getVectorModel()) + .build(); + } + + /** + * 处理 OpenAI 返回的嵌入数据 + */ + private List> processOpenAiEmbeddings(EmbeddingResponse embeddings) { + List> vectorList = new ArrayList<>(); + + embeddings.getData().forEach(data -> { + List vector = data.getEmbedding(); + List doubleVector = convertToDoubleList(vector); + vectorList.add(doubleVector); + }); + + return vectorList; + } + + /** + * 将 BigDecimal 转换为 Double 列表 + */ + private List convertToDoubleList(List vector) { + return vector.stream() + .map(BigDecimal::doubleValue) + .collect(Collectors.toList()); + } + + @Override public List singleVectorization(String chunk, String kid) { List chunkList = new ArrayList<>(); diff --git a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/VectorizationType.java b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/VectorizationType.java new file mode 100644 index 00000000..a9d370d5 --- /dev/null +++ b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/VectorizationType.java @@ -0,0 +1,15 @@ +package org.ruoyi.knowledge.chain.vectorizer; + +public enum VectorizationType { + OPENAI, // OpenAI 向量化 + LOCAL; // 本地模型向量化 + + public static VectorizationType fromString(String type) { + for (VectorizationType v : values()) { + if (v.name().equalsIgnoreCase(type)) { + return v; + } + } + throw new IllegalArgumentException("Unknown VectorizationType: " + type); + } +}