mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-03-27 19:43:42 +08:00
@@ -0,0 +1,38 @@
|
||||
package org.ruoyi.common.chat.entity.models;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* @program: RUOYIAI
|
||||
* @ClassName LocalModelsSearchRequest
|
||||
* @description:
|
||||
* @author: hejh
|
||||
* @create: 2025-03-15 17:22
|
||||
* @Version 1.0
|
||||
**/
|
||||
@Data
|
||||
public class LocalModelsSearchRequest {
|
||||
|
||||
private List<String> text;
|
||||
private String model_name;
|
||||
private String delimiter;
|
||||
private int k;
|
||||
private int block_size;
|
||||
private int overlap_chars;
|
||||
|
||||
// 构造函数、Getter 和 Setter
|
||||
public LocalModelsSearchRequest(List<String> text, String model_name, String delimiter, int k, int block_size, int overlap_chars) {
|
||||
this.text = text;
|
||||
this.model_name = model_name;
|
||||
this.delimiter = delimiter;
|
||||
this.k = k;
|
||||
this.block_size = block_size;
|
||||
this.overlap_chars = overlap_chars;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
package org.ruoyi.common.chat.entity.models;
|
||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@JsonIgnoreProperties(ignoreUnknown = true)
|
||||
public class LocalModelsSearchResponse {
|
||||
@JsonProperty("topKEmbeddings")
|
||||
|
||||
private List<List<List<Double>>> topKEmbeddings; // 处理三层嵌套数组
|
||||
|
||||
// 默认构造函数
|
||||
public LocalModelsSearchResponse() {}
|
||||
|
||||
|
||||
|
||||
}
|
||||
@@ -0,0 +1,198 @@
|
||||
package org.ruoyi.common.chat.localModels;
|
||||
|
||||
import io.micrometer.common.util.StringUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.OkHttpClient;
|
||||
import org.ruoyi.common.chat.entity.models.LocalModelsSearchRequest;
|
||||
import org.ruoyi.common.chat.entity.models.LocalModelsSearchResponse;
|
||||
import org.springframework.stereotype.Service;
|
||||
import retrofit2.Call;
|
||||
import retrofit2.Callback;
|
||||
import retrofit2.Response;
|
||||
import retrofit2.Retrofit;
|
||||
import retrofit2.converter.jackson.JacksonConverterFactory;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class LocalModelsofitClient {
|
||||
private static final String BASE_URL = "http://127.0.0.1:5000"; // Flask 服务的 URL
|
||||
private static Retrofit retrofit = null;
|
||||
|
||||
// 获取 Retrofit 实例
|
||||
public static Retrofit getRetrofitInstance() {
|
||||
if (retrofit == null) {
|
||||
OkHttpClient client = new OkHttpClient.Builder()
|
||||
.build();
|
||||
|
||||
retrofit = new Retrofit.Builder()
|
||||
.baseUrl(BASE_URL)
|
||||
.client(client)
|
||||
.addConverterFactory(JacksonConverterFactory.create()) // 使用 Jackson 处理 JSON 转换
|
||||
.build();
|
||||
}
|
||||
return retrofit;
|
||||
}
|
||||
|
||||
/**
|
||||
* 向 Flask 服务发送文本向量化请求
|
||||
*
|
||||
* @param queries 查询文本列表
|
||||
* @param modelName 模型名称
|
||||
* @param delimiter 文本分隔符
|
||||
* @param topK 返回的结果数
|
||||
* @param blockSize 文本块大小
|
||||
* @param overlapChars 重叠字符数
|
||||
* @return 返回计算得到的 Top K 嵌入向量列表
|
||||
*/
|
||||
|
||||
public static List<List<Double>> getTopKEmbeddings(
|
||||
List<String> queries,
|
||||
String modelName,
|
||||
String delimiter,
|
||||
int topK,
|
||||
int blockSize,
|
||||
int overlapChars) {
|
||||
|
||||
modelName = (!StringUtils.isEmpty(modelName)) ? modelName : "msmarco-distilbert-base-tas-b"; // 默认模型名称
|
||||
delimiter = (!StringUtils.isEmpty(delimiter) ) ? delimiter : "."; // 默认分隔符
|
||||
topK = (topK > 0) ? topK : 3; // 默认返回 3 个结果
|
||||
blockSize = (blockSize > 0) ? blockSize : 500; // 默认文本块大小为 500
|
||||
overlapChars = (overlapChars > 0) ? overlapChars : 50; // 默认重叠字符数为 50
|
||||
|
||||
// 创建 Retrofit 实例
|
||||
Retrofit retrofit = getRetrofitInstance();
|
||||
|
||||
// 创建 SearchService 接口
|
||||
SearchService service = retrofit.create(SearchService.class);
|
||||
|
||||
// 创建请求对象 LocalModelsSearchRequest
|
||||
LocalModelsSearchRequest request = new LocalModelsSearchRequest(
|
||||
queries, // 查询文本列表
|
||||
modelName, // 模型名称
|
||||
delimiter, // 文本分隔符
|
||||
topK, // 返回的结果数
|
||||
blockSize, // 文本块大小
|
||||
overlapChars // 重叠字符数
|
||||
);
|
||||
|
||||
final CountDownLatch latch = new CountDownLatch(1); // 创建一个 CountDownLatch
|
||||
final List<List<Double>>[] topKEmbeddings = new List[]{null}; // 使用数组来存储结果(因为 Java 不支持直接修改 List)
|
||||
|
||||
// 发起异步请求
|
||||
service.vectorize(request).enqueue(new Callback<LocalModelsSearchResponse>() {
|
||||
@Override
|
||||
public void onResponse(Call<LocalModelsSearchResponse> call, Response<LocalModelsSearchResponse> response) {
|
||||
if (response.isSuccessful()) {
|
||||
LocalModelsSearchResponse searchResponse = response.body();
|
||||
if (searchResponse != null) {
|
||||
topKEmbeddings[0] = searchResponse.getTopKEmbeddings().get(0); // 获取结果
|
||||
log.info("Successfully retrieved embeddings");
|
||||
} else {
|
||||
log.error("Response body is null");
|
||||
}
|
||||
} else {
|
||||
log.error("Request failed. HTTP error code: " + response.code());
|
||||
}
|
||||
latch.countDown(); // 请求完成,减少计数
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(Call<LocalModelsSearchResponse> call, Throwable t) {
|
||||
t.printStackTrace();
|
||||
log.error("Request failed: ", t);
|
||||
latch.countDown(); // 请求失败,减少计数
|
||||
}
|
||||
});
|
||||
|
||||
try {
|
||||
latch.await(); // 等待请求完成
|
||||
} catch (InterruptedException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
|
||||
return topKEmbeddings[0]; // 返回结果
|
||||
}
|
||||
|
||||
// public static void main(String[] args) {
|
||||
// // 示例调用
|
||||
// List<String> queries = Arrays.asList("What is artificial intelligence?", "AI is transforming industries.");
|
||||
// String modelName = "msmarco-distilbert-base-tas-b";
|
||||
// String delimiter = ".";
|
||||
// int topK = 3;
|
||||
// int blockSize = 500;
|
||||
// int overlapChars = 50;
|
||||
//
|
||||
// List<List<Double>> topKEmbeddings = getTopKEmbeddings(queries, modelName, delimiter, topK, blockSize, overlapChars);
|
||||
//
|
||||
// // 打印结果
|
||||
// if (topKEmbeddings != null) {
|
||||
// System.out.println("Top K embeddings: ");
|
||||
// for (List<Double> embedding : topKEmbeddings) {
|
||||
// System.out.println(embedding);
|
||||
// }
|
||||
// } else {
|
||||
// System.out.println("No embeddings returned.");
|
||||
// }
|
||||
// }
|
||||
|
||||
|
||||
// public static void main(String[] args) {
|
||||
// // 创建 Retrofit 实例
|
||||
// Retrofit retrofit = LocalModelsofitClient.getRetrofitInstance();
|
||||
//
|
||||
// // 创建 SearchService 接口
|
||||
// SearchService service = retrofit.create(SearchService.class);
|
||||
//
|
||||
// // 创建请求对象 LocalModelsSearchRequest
|
||||
// LocalModelsSearchRequest request = new LocalModelsSearchRequest(
|
||||
// Arrays.asList("What is artificial intelligence?", "AI is transforming industries."), // 查询文本列表
|
||||
// "msmarco-distilbert-base-tas-b", // 模型名称
|
||||
// ".", // 分隔符
|
||||
// 3, // 返回的结果数
|
||||
// 500, // 文本块大小
|
||||
// 50 // 重叠字符数
|
||||
// );
|
||||
//
|
||||
// // 发起请求
|
||||
// service.vectorize(request).enqueue(new Callback<LocalModelsSearchResponse>() {
|
||||
// @Override
|
||||
// public void onResponse(Call<LocalModelsSearchResponse> call, Response<LocalModelsSearchResponse> response) {
|
||||
// if (response.isSuccessful()) {
|
||||
// LocalModelsSearchResponse searchResponse = response.body();
|
||||
// System.out.println("Response Body: " + response.body()); // Print the whole response body for debugging
|
||||
//
|
||||
// if (searchResponse != null) {
|
||||
// // If the response is not null, process it.
|
||||
// // Example: Extract the embeddings and print them
|
||||
// List<List<List<Double>>> topKEmbeddings = searchResponse.getTopKEmbeddings();
|
||||
// if (topKEmbeddings != null) {
|
||||
// // Print the Top K embeddings
|
||||
//
|
||||
// } else {
|
||||
// System.err.println("Top K embeddings are null");
|
||||
// }
|
||||
//
|
||||
// // If there is more information you want to process, handle it here
|
||||
//
|
||||
// } else {
|
||||
// System.err.println("Response body is null");
|
||||
// }
|
||||
// } else {
|
||||
// System.err.println("Request failed. HTTP error code: " + response.code());
|
||||
// log.error("Failed to retrieve data. HTTP error code: " + response.code());
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// @Override
|
||||
// public void onFailure(Call<LocalModelsSearchResponse> call, Throwable t) {
|
||||
// // 请求失败,打印错误
|
||||
// t.printStackTrace();
|
||||
// log.error("Request failed: ", t);
|
||||
// }
|
||||
// });
|
||||
// }
|
||||
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package org.ruoyi.common.chat.localModels;
|
||||
|
||||
|
||||
|
||||
import org.ruoyi.common.chat.entity.models.LocalModelsSearchRequest;
|
||||
import org.ruoyi.common.chat.entity.models.LocalModelsSearchResponse;
|
||||
import retrofit2.Call;
|
||||
import retrofit2.http.Body;
|
||||
import retrofit2.http.POST;
|
||||
/**
|
||||
* @program: RUOYIAI
|
||||
* @ClassName SearchService
|
||||
* @description: 请求模型
|
||||
* @author: hejh
|
||||
* @create: 2025-03-15 17:27
|
||||
* @Version 1.0
|
||||
**/
|
||||
|
||||
|
||||
public interface SearchService {
|
||||
@POST("/vectorize") // 与 Flask 服务中的路由匹配
|
||||
Call<LocalModelsSearchResponse> vectorize(@Body LocalModelsSearchRequest request);
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
package org.ruoyi.knowledge.chain.vectorizer;
|
||||
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.Getter;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.common.chat.config.ChatConfig;
|
||||
import org.ruoyi.common.chat.localModels.LocalModelsofitClient;
|
||||
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
|
||||
import org.ruoyi.knowledge.domain.vo.KnowledgeInfoVo;
|
||||
import org.ruoyi.knowledge.service.IKnowledgeInfoService;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
public class LocalModelsVectorization {
|
||||
@Resource
|
||||
private IKnowledgeInfoService knowledgeInfoService;
|
||||
|
||||
@Resource
|
||||
private LocalModelsofitClient localModelsofitClient;
|
||||
|
||||
@Getter
|
||||
private OpenAiStreamClient openAiStreamClient;
|
||||
|
||||
private final ChatConfig chatConfig;
|
||||
|
||||
/**
|
||||
* 批量向量化
|
||||
*
|
||||
* @param chunkList 文本块列表
|
||||
* @param kid 知识 ID
|
||||
* @return 向量化结果
|
||||
*/
|
||||
|
||||
public List<List<Double>> batchVectorization(List<String> chunkList, String kid) {
|
||||
logVectorizationRequest(kid, chunkList); // 在向量化开始前记录日志
|
||||
openAiStreamClient = chatConfig.getOpenAiStreamClient(); // 获取 OpenAi 客户端
|
||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid)); // 查询知识信息
|
||||
// 调用 localModelsofitClient 获取 Top K 嵌入向量
|
||||
try {
|
||||
return localModelsofitClient.getTopKEmbeddings(
|
||||
chunkList,
|
||||
knowledgeInfoVo.getVector(),
|
||||
knowledgeInfoVo.getKnowledgeSeparator(),
|
||||
knowledgeInfoVo.getRetrieveLimit(),
|
||||
knowledgeInfoVo.getTextBlockSize(),
|
||||
knowledgeInfoVo.getOverlapChar()
|
||||
);
|
||||
} catch (Exception e) {
|
||||
log.error("Failed to perform batch vectorization for knowledgeId: {}", kid, e);
|
||||
throw new RuntimeException("Batch vectorization failed", e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 单一文本块向量化
|
||||
*
|
||||
* @param chunk 单一文本块
|
||||
* @param kid 知识 ID
|
||||
* @return 向量化结果
|
||||
*/
|
||||
|
||||
public List<Double> singleVectorization(String chunk, String kid) {
|
||||
List<String> chunkList = new ArrayList<>();
|
||||
chunkList.add(chunk);
|
||||
|
||||
// 调用批量向量化方法
|
||||
List<List<Double>> vectorList = batchVectorization(chunkList, kid);
|
||||
|
||||
if (vectorList.isEmpty()) {
|
||||
log.warn("Vectorization returned empty list for chunk: {}", chunk);
|
||||
return new ArrayList<>();
|
||||
}
|
||||
|
||||
return vectorList.get(0); // 返回第一个向量
|
||||
}
|
||||
|
||||
/**
|
||||
* 提供更简洁的日志记录方法
|
||||
*
|
||||
* @param kid 知识 ID
|
||||
* @param chunkList 文本块列表
|
||||
*/
|
||||
private void logVectorizationRequest(String kid, List<String> chunkList) {
|
||||
log.info("Starting vectorization for Knowledge ID: {} with {} chunks.", kid, chunkList.size());
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,7 @@ import org.springframework.stereotype.Component;
|
||||
import java.math.BigDecimal;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
@@ -27,6 +28,9 @@ public class OpenAiVectorization implements Vectorization {
|
||||
@Lazy
|
||||
@Resource
|
||||
private IKnowledgeInfoService knowledgeInfoService;
|
||||
@Lazy
|
||||
@Resource
|
||||
private LocalModelsVectorization localModelsVectorization;
|
||||
|
||||
@Getter
|
||||
private OpenAiStreamClient openAiStreamClient;
|
||||
@@ -35,25 +39,63 @@ public class OpenAiVectorization implements Vectorization {
|
||||
|
||||
@Override
|
||||
public List<List<Double>> batchVectorization(List<String> chunkList, String kid) {
|
||||
openAiStreamClient = chatConfig.getOpenAiStreamClient();
|
||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
|
||||
Embedding embedding = Embedding.builder()
|
||||
.input(chunkList)
|
||||
.model(knowledgeInfoVo.getVectorModel())
|
||||
.build();
|
||||
EmbeddingResponse embeddings = openAiStreamClient.embeddings(embedding);
|
||||
List<List<Double>> vectorList = new ArrayList<>();
|
||||
embeddings.getData().forEach(data -> {
|
||||
List<BigDecimal> vector = data.getEmbedding();
|
||||
List<Double> doubleVector = new ArrayList<>();
|
||||
for (BigDecimal bd : vector) {
|
||||
doubleVector.add(bd.doubleValue());
|
||||
}
|
||||
vectorList.add(doubleVector);
|
||||
});
|
||||
|
||||
// 获取知识库信息
|
||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
|
||||
|
||||
// 如果使用本地模型
|
||||
try {
|
||||
return localModelsVectorization.batchVectorization(chunkList, kid);
|
||||
} catch (Exception e) {
|
||||
log.error("Local models vectorization failed, falling back to OpenAI embeddings", e);
|
||||
}
|
||||
|
||||
// 如果本地模型失败,则调用 OpenAI 服务进行向量化
|
||||
Embedding embedding = buildEmbedding(chunkList, knowledgeInfoVo);
|
||||
EmbeddingResponse embeddings = openAiStreamClient.embeddings(embedding);
|
||||
|
||||
// 处理 OpenAI 返回的嵌入数据
|
||||
vectorList = processOpenAiEmbeddings(embeddings);
|
||||
|
||||
return vectorList;
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 Embedding 对象
|
||||
*/
|
||||
private Embedding buildEmbedding(List<String> chunkList, KnowledgeInfoVo knowledgeInfoVo) {
|
||||
return Embedding.builder()
|
||||
.input(chunkList)
|
||||
.model(knowledgeInfoVo.getVectorModel())
|
||||
.build();
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理 OpenAI 返回的嵌入数据
|
||||
*/
|
||||
private List<List<Double>> processOpenAiEmbeddings(EmbeddingResponse embeddings) {
|
||||
List<List<Double>> vectorList = new ArrayList<>();
|
||||
|
||||
embeddings.getData().forEach(data -> {
|
||||
List<BigDecimal> vector = data.getEmbedding();
|
||||
List<Double> doubleVector = convertToDoubleList(vector);
|
||||
vectorList.add(doubleVector);
|
||||
});
|
||||
|
||||
return vectorList;
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 BigDecimal 转换为 Double 列表
|
||||
*/
|
||||
private List<Double> convertToDoubleList(List<BigDecimal> vector) {
|
||||
return vector.stream()
|
||||
.map(BigDecimal::doubleValue)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public List<Double> singleVectorization(String chunk, String kid) {
|
||||
List<String> chunkList = new ArrayList<>();
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
package org.ruoyi.knowledge.chain.vectorizer;
|
||||
|
||||
/**
 * Supported vectorization back ends.
 */
public enum VectorizationType {
    /** OpenAI vectorization. */
    OPENAI,
    /** Local-model vectorization. */
    LOCAL;

    /**
     * Resolves a constant from its name, ignoring case.
     *
     * @param type constant name to look up; {@code null} matches nothing
     * @return the constant whose name equals {@code type} ignoring case
     * @throws IllegalArgumentException if no constant matches
     */
    public static VectorizationType fromString(String type) {
        VectorizationType[] candidates = values();
        for (int i = 0; i < candidates.length; i++) {
            VectorizationType candidate = candidates[i];
            if (!candidate.name().equalsIgnoreCase(type)) {
                continue;
            }
            return candidate;
        }
        throw new IllegalArgumentException("Unknown VectorizationType: " + type);
    }
}
|
||||
21
script/docker/localModels/Dockerfile
Normal file
21
script/docker/localModels/Dockerfile
Normal file
@@ -0,0 +1,21 @@
|
||||
# Use the official Python image as the base image
FROM python:3.8-slim

# Set the working directory to /app
WORKDIR /app

# Copy only the dependency manifest first so the dependency layer below stays
# cached when application code changes but requirements.txt does not
COPY requirements.txt /app/requirements.txt

# Install application dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the build context into /app
COPY . /app

# Expose the port used by the Flask application
EXPOSE 5000

# Environment variables read by the flask command
ENV FLASK_APP=app.py
ENV FLASK_RUN_HOST=0.0.0.0

# Start the Flask application
CMD ["flask", "run", "--host=0.0.0.0"]
|
||||
116
script/docker/localModels/app.py
Normal file
116
script/docker/localModels/app.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from flask import Flask, request, jsonify
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
import json
|
||||
|
||||
# Flask application object that the route handlers are registered on
app = Flask(__name__)

# Module-level cache of loaded models, keyed by model name
model_cache = dict()
|
||||
|
||||
# Split text into blocks
def split_text(text, block_size, overlap_chars, delimiter):
    """Split ``text`` into blocks and add overlapped variants of them.

    The text is split on ``delimiter``; consecutive pieces are joined with a single
    space while the running block (plus one separator character) stays within
    ``block_size``. A piece longer than ``block_size`` becomes a block of its own.

    For every block after the first, an extra block is emitted just before it that
    consists of the last ``overlap_chars`` characters of the previous block followed
    by the block itself. No extra block is emitted when ``overlap_chars`` is not
    positive.

    Args:
        text: Text to split.
        block_size: Maximum length used when packing pieces into a block.
        overlap_chars: Number of trailing characters of the previous block to prepend.
        delimiter: Separator passed to ``str.split``; must not be empty.

    Returns:
        List of block strings; empty when ``text`` yields no non-empty block.
    """
    text_blocks = []
    current_block = ""

    for chunk in text.split(delimiter):
        if len(current_block) + len(chunk) + 1 <= block_size:
            current_block = f"{current_block} {chunk}" if current_block else chunk
        else:
            # Only flush a non-empty block: when the very first piece is already
            # too long there is nothing accumulated yet, and an empty string must
            # not become a block.
            if current_block:
                text_blocks.append(current_block)
            current_block = chunk
    if current_block:
        text_blocks.append(current_block)

    overlap_blocks = []
    for i, block in enumerate(text_blocks):
        # Guard overlap_chars > 0: a slice of [-0:] would be the whole previous block.
        if i > 0 and overlap_chars > 0:
            overlap_blocks.append(text_blocks[i - 1][-overlap_chars:] + block)
        overlap_blocks.append(block)

    return overlap_blocks
|
||||
|
||||
# Vectorize text blocks
def vectorize_text_blocks(text_blocks, model):
    """Encode ``text_blocks`` with ``model.encode`` and return its result unchanged."""
    encoded = model.encode(text_blocks)
    return encoded
|
||||
|
||||
# Text retrieval
def retrieve_top_k(query, knowledge_base, k, block_size, overlap_chars, delimiter, model):
    """Return the ``k`` knowledge-base blocks most similar to ``query``.

    The knowledge base is split with ``split_text``, every block and the query are
    encoded with ``model``, and blocks are ranked by cosine similarity to the query.

    Returns:
        A ``(texts, embeddings)`` tuple of equal-length lists, ordered from most to
        least similar.
    """
    # Split the knowledge base into blocks and encode them
    blocks = split_text(knowledge_base, block_size, overlap_chars, delimiter)
    block_vectors = vectorize_text_blocks(blocks, model)

    # Encode the query as a single-row matrix
    query_matrix = model.encode([query]).reshape(1, -1)

    # Similarity of the query against every block
    scores = cosine_similarity(query_matrix, block_vectors)[0]

    # argsort is ascending: take the last k indices and reverse them so the best comes first
    ranked_indices = scores.argsort()[-k:][::-1]

    top_k_texts = []
    top_k_embeddings = []
    for index in ranked_indices:
        top_k_texts.append(blocks[index])
        top_k_embeddings.append(block_vectors[index])

    return top_k_texts, top_k_embeddings
|
||||
|
||||
@app.route('/vectorize', methods=['POST'])
def vectorize_text():
    """Handle ``POST /vectorize``.

    Expects a JSON object with ``text`` (list of strings, or a single string) and the
    optional keys ``model_name``, ``delimiter``, ``k``, ``block_size`` and
    ``overlap_chars``. The first text is used as the knowledge base; every text is
    used as a query against it.

    Returns:
        JSON ``{"topKEmbeddings": [...]}`` with one list of embeddings per text, or a
        JSON ``{"error": ...}`` with status 400 (bad input) or 500 (model load failure).
    """
    # silent=True yields None instead of raising when the body is missing or not valid JSON
    data = request.get_json(silent=True)
    print(f"Received request data: {data}")  # debug output of the request data

    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    text_list = data.get("text", [])
    if isinstance(text_list, str):
        # A bare string would otherwise be indexed and iterated character by character
        text_list = [text_list]
    model_name = data.get("model_name", "msmarco-distilbert-base-tas-b")  # default model
    delimiter = data.get("delimiter", "\n")  # default delimiter

    try:
        k = int(data.get("k", 3))  # default number of results
        block_size = int(data.get("block_size", 500))  # default block size
        overlap_chars = int(data.get("overlap_chars", 50))  # default overlap
    except (TypeError, ValueError):
        return jsonify({"error": "k, block_size and overlap_chars must be integers."}), 400

    if not text_list:
        return jsonify({"error": "Text is required."}), 400

    # Load the model on first use and keep it in the module-level cache
    if model_name not in model_cache:
        try:
            model_cache[model_name] = SentenceTransformer(model_name)
        except Exception as e:
            return jsonify({"error": f"Failed to load model: {e}"}), 500

    model = model_cache[model_name]

    top_k_texts_all = []
    top_k_embeddings_all = []

    # One loop covers both the single-text and the multi-text case: the first text
    # is always the knowledge base and each text is a query against it.
    for query in text_list:
        top_k_texts, top_k_embeddings = retrieve_top_k(
            query, text_list[0], k, block_size, overlap_chars, delimiter, model)
        top_k_texts_all.append(top_k_texts)
        top_k_embeddings_all.append(top_k_embeddings)

    # Convert the embedding vectors (ndarray) into JSON-serializable lists
    top_k_embeddings_all = [[embedding.tolist() for embedding in embeddings] for embeddings in top_k_embeddings_all]

    print(f"Top K texts: {top_k_texts_all}")  # print the retrieved texts
    print(f"Top K embeddings: {top_k_embeddings_all}")  # print the retrieved vectors

    # Return the data as JSON
    return jsonify({
        "topKEmbeddings": top_k_embeddings_all  # embedding vectors
    })
|
||||
|
||||
if __name__ == '__main__':
    import os

    # The server listens on all interfaces, so the interactive debugger must not be
    # on by default: it allows arbitrary code execution. Opt in with FLASK_DEBUG=1.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes", "on")
    app.run(host="0.0.0.0", port=5000, debug=debug_enabled)
|
||||
3
script/docker/localModels/requirements.txt
Normal file
3
script/docker/localModels/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
Flask==2.0.3
# Flask 2.0.x only requires Werkzeug>=2.0; pin it so a Werkzeug 3.x release, which
# removed APIs that Flask 2.0.x imports, is not pulled in
Werkzeug==2.0.3
sentence-transformers==2.2.0
scikit-learn==0.24.2
|
||||
Reference in New Issue
Block a user