mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-03-14 05:03:43 +08:00
feat: 调整知识库模块
This commit is contained in:
@@ -1,26 +0,0 @@
|
||||
# 使用 Python 3.8 slim 作为基础镜像(轻量稳定)
|
||||
FROM python:3.10
|
||||
|
||||
# 设置工作目录
|
||||
WORKDIR /app
|
||||
|
||||
# 复制所有文件到容器
|
||||
COPY . /app
|
||||
|
||||
# 更新 pip 并安装 wheel(避免依赖问题)
|
||||
RUN pip install --upgrade pip wheel
|
||||
|
||||
# 使用阿里云 PyPI 镜像安装依赖,并添加 trusted-host
|
||||
RUN pip install -i https://mirrors.aliyun.com/pypi/simple \
|
||||
--trusted-host mirrors.aliyun.com \
|
||||
--no-cache-dir -r requirements.txt
|
||||
|
||||
# 暴露 Flask 端口
|
||||
EXPOSE 5000
|
||||
|
||||
# 设置环境变量
|
||||
ENV FLASK_APP=app.py
|
||||
ENV FLASK_RUN_HOST=0.0.0.0
|
||||
|
||||
# 运行 Flask
|
||||
CMD ["flask", "run", "--host=0.0.0.0"]
|
||||
@@ -1,116 +0,0 @@
|
||||
from flask import Flask, request, jsonify
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
import json
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
# 创建一个全局的模型缓存字典
|
||||
model_cache = {}
|
||||
|
||||
# 分割文本块
|
||||
def split_text(text, block_size, overlap_chars, delimiter):
|
||||
chunks = text.split(delimiter)
|
||||
text_blocks = []
|
||||
current_block = ""
|
||||
|
||||
for chunk in chunks:
|
||||
if len(current_block) + len(chunk) + 1 <= block_size:
|
||||
if current_block:
|
||||
current_block += " " + chunk
|
||||
else:
|
||||
current_block = chunk
|
||||
else:
|
||||
text_blocks.append(current_block)
|
||||
current_block = chunk
|
||||
if current_block:
|
||||
text_blocks.append(current_block)
|
||||
|
||||
overlap_blocks = []
|
||||
for i in range(len(text_blocks)):
|
||||
if i > 0:
|
||||
overlap_block = text_blocks[i - 1][-overlap_chars:] + text_blocks[i]
|
||||
overlap_blocks.append(overlap_block)
|
||||
overlap_blocks.append(text_blocks[i])
|
||||
|
||||
return overlap_blocks
|
||||
|
||||
# 文本向量化
|
||||
def vectorize_text_blocks(text_blocks, model):
|
||||
return model.encode(text_blocks)
|
||||
|
||||
# 文本检索
|
||||
def retrieve_top_k(query, knowledge_base, k, block_size, overlap_chars, delimiter, model):
|
||||
# 将知识库拆分为文本块
|
||||
text_blocks = split_text(knowledge_base, block_size, overlap_chars, delimiter)
|
||||
# 向量化文本块
|
||||
knowledge_vectors = vectorize_text_blocks(text_blocks, model)
|
||||
# 向量化查询文本
|
||||
query_vector = model.encode([query]).reshape(1, -1)
|
||||
# 计算相似度
|
||||
similarities = cosine_similarity(query_vector, knowledge_vectors)
|
||||
# 获取相似度最高的 k 个文本块的索引
|
||||
top_k_indices = similarities[0].argsort()[-k:][::-1]
|
||||
|
||||
# 返回文本块和它们的向量
|
||||
top_k_texts = [text_blocks[i] for i in top_k_indices]
|
||||
top_k_embeddings = [knowledge_vectors[i] for i in top_k_indices]
|
||||
|
||||
return top_k_texts, top_k_embeddings
|
||||
|
||||
@app.route('/vectorize', methods=['POST'])
|
||||
def vectorize_text():
|
||||
# 从请求中获取 JSON 数据
|
||||
data = request.json
|
||||
print(f"Received request data: {data}") # 调试输出请求数据
|
||||
|
||||
text_list = data.get("text", [])
|
||||
model_name = data.get("model_name", "msmarco-distilbert-base-tas-b") # 默认模型
|
||||
|
||||
delimiter = data.get("delimiter", "\n") # 默认分隔符
|
||||
k = int(data.get("k", 3)) # 默认检索条数
|
||||
block_size = int(data.get("block_size", 500)) # 默认文本块大小
|
||||
overlap_chars = int(data.get("overlap_chars", 50)) # 默认重叠字符数
|
||||
|
||||
if not text_list:
|
||||
return jsonify({"error": "Text is required."}), 400
|
||||
|
||||
# 检查模型是否已经加载
|
||||
if model_name not in model_cache:
|
||||
try:
|
||||
model = SentenceTransformer(model_name)
|
||||
model_cache[model_name] = model # 缓存模型
|
||||
except Exception as e:
|
||||
return jsonify({"error": f"Failed to load model: {e}"}), 500
|
||||
|
||||
model = model_cache[model_name]
|
||||
|
||||
top_k_texts_all = []
|
||||
top_k_embeddings_all = []
|
||||
|
||||
# 如果只有一个查询文本
|
||||
if len(text_list) == 1:
|
||||
top_k_texts, top_k_embeddings = retrieve_top_k(text_list[0], text_list[0], k, block_size, overlap_chars, delimiter, model)
|
||||
top_k_texts_all.append(top_k_texts)
|
||||
top_k_embeddings_all.append(top_k_embeddings)
|
||||
elif len(text_list) > 1:
|
||||
# 如果多个查询文本,依次处理
|
||||
for query in text_list:
|
||||
top_k_texts, top_k_embeddings = retrieve_top_k(query, text_list[0], k, block_size, overlap_chars, delimiter, model)
|
||||
top_k_texts_all.append(top_k_texts)
|
||||
top_k_embeddings_all.append(top_k_embeddings)
|
||||
|
||||
# 将嵌入向量(ndarray)转换为可序列化的列表
|
||||
top_k_embeddings_all = [[embedding.tolist() for embedding in embeddings] for embeddings in top_k_embeddings_all]
|
||||
|
||||
print(f"Top K texts: {top_k_texts_all}") # 打印检索到的文本
|
||||
print(f"Top K embeddings: {top_k_embeddings_all}") # 打印检索到的向量
|
||||
|
||||
# 返回 JSON 格式的数据
|
||||
return jsonify({
|
||||
|
||||
"topKEmbeddings": top_k_embeddings_all # 返回嵌入向量
|
||||
})
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(host="0.0.0.0", port=5000, debug=True)
|
||||
@@ -1,4 +0,0 @@
|
||||
1.下载镜像
|
||||
docker pull registry.cn-hangzhou.aliyuncs.com/hejh-docker/localmodel:0.1.1
|
||||
2. 启动
|
||||
docker run -p 5000:5000 <imageId>
|
||||
@@ -1,5 +0,0 @@
|
||||
Flask==3.1.0
|
||||
sentence-transformers==3.4.1
|
||||
scikit-learn==1.6.1
|
||||
werkzeug>=3.1
|
||||
|
||||
Reference in New Issue
Block a user