mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-03-19 23:53:42 +08:00
feat: 优化通过知识库获取模型配置逻辑,修改为通过模型id查找模型配置,避免多供应商同模型映射错误。
This commit is contained in:
@@ -83,6 +83,11 @@ public class KnowledgeInfo extends BaseEntity {
|
||||
*/
|
||||
private String vectorModelName;
|
||||
|
||||
/**
|
||||
* 向量化模型id
|
||||
*/
|
||||
private Long embeddingModelId;
|
||||
|
||||
/**
|
||||
* 向量化模型名称
|
||||
*/
|
||||
|
||||
@@ -92,7 +92,11 @@ public class KnowledgeInfoBo extends BaseEntity {
|
||||
/**
|
||||
* 向量化模型名称
|
||||
*/
|
||||
@NotBlank(message = "向量模型不能为空", groups = { AddGroup.class, EditGroup.class })
|
||||
private Long embeddingModelId;
|
||||
|
||||
/**
|
||||
* 向量化模型名称
|
||||
*/
|
||||
private String embeddingModelName;
|
||||
|
||||
|
||||
|
||||
@@ -31,7 +31,12 @@ public class QueryVectorBo {
|
||||
private String vectorModelName;
|
||||
|
||||
/**
|
||||
* 向量化模型名称
|
||||
* 向量化模型ID
|
||||
*/
|
||||
private Long embeddingModelId;
|
||||
|
||||
/**
|
||||
* 向量化模型ID
|
||||
*/
|
||||
private String embeddingModelName;
|
||||
|
||||
|
||||
@@ -36,6 +36,11 @@ public class StoreEmbeddingBo {
|
||||
*/
|
||||
private String vectorModelName;
|
||||
|
||||
/**
|
||||
* 向量化模型id
|
||||
*/
|
||||
private Long embeddingModelId;
|
||||
|
||||
/**
|
||||
* 向量化模型名称
|
||||
*/
|
||||
|
||||
@@ -101,6 +101,11 @@ public class KnowledgeInfoVo implements Serializable {
|
||||
*/
|
||||
private String vectorModelName;
|
||||
|
||||
/**
|
||||
* 向量化模型id
|
||||
*/
|
||||
private Long embeddingModelId;
|
||||
|
||||
/**
|
||||
* 向量化模型名称
|
||||
*/
|
||||
|
||||
@@ -28,6 +28,8 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.common.core.service.ConfigService;
|
||||
import org.ruoyi.domain.bo.QueryVectorBo;
|
||||
import org.ruoyi.domain.bo.StoreEmbeddingBo;
|
||||
import org.ruoyi.embedding.BaseEmbedModelService;
|
||||
import org.ruoyi.embedding.EmbeddingModelFactory;
|
||||
import org.ruoyi.service.VectorStoreService;
|
||||
import org.springframework.stereotype.Service;
|
||||
import java.util.*;
|
||||
@@ -48,6 +50,8 @@ public class VectorStoreServiceImpl implements VectorStoreService {
|
||||
// private EmbeddingStore<TextSegment> embeddingStore;
|
||||
private WeaviateClient client;
|
||||
|
||||
private final EmbeddingModelFactory embeddingModelFactory;
|
||||
|
||||
|
||||
@Override
|
||||
public void createSchema(String kid, String modelName) {
|
||||
@@ -98,18 +102,16 @@ public class VectorStoreServiceImpl implements VectorStoreService {
|
||||
@Override
|
||||
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
|
||||
createSchema(storeEmbeddingBo.getKid(), storeEmbeddingBo.getVectorModelName());
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(),
|
||||
storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl());
|
||||
BaseEmbedModelService model = embeddingModelFactory.createModel(storeEmbeddingBo.getEmbeddingModelId());
|
||||
List<String> chunkList = storeEmbeddingBo.getChunkList();
|
||||
List<String> fidList = storeEmbeddingBo.getFids();
|
||||
String kid = storeEmbeddingBo.getKid();
|
||||
String docId = storeEmbeddingBo.getDocId();
|
||||
log.info("向量存储条数记录: " + chunkList.size());
|
||||
long startTime = System.currentTimeMillis();
|
||||
for (int i = 0; i < chunkList.size(); i++) {
|
||||
String text = chunkList.get(i);
|
||||
String fid = fidList.get(i);
|
||||
Embedding embedding = embeddingModel.embed(text).content();
|
||||
Embedding embedding = model.embed(text).content();
|
||||
Map<String, Object> properties = Map.of(
|
||||
"text", text,
|
||||
"fid",fid,
|
||||
@@ -137,9 +139,8 @@ public class VectorStoreServiceImpl implements VectorStoreService {
|
||||
@Override
|
||||
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
|
||||
createSchema(queryVectorBo.getKid(), queryVectorBo.getVectorModelName());
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(),
|
||||
queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl());
|
||||
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||
BaseEmbedModelService model = embeddingModelFactory.createModel(queryVectorBo.getEmbeddingModelId());
|
||||
Embedding queryEmbedding = model.embed(queryVectorBo.getQuery()).content();
|
||||
float[] vector = queryEmbedding.vector();
|
||||
List<String> vectorStrings = new ArrayList<>();
|
||||
for (float v : vector) {
|
||||
@@ -246,28 +247,4 @@ public class VectorStoreServiceImpl implements VectorStoreService {
|
||||
log.error("删除失败: {}", result.getError());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取向量模型
|
||||
*/
|
||||
@SneakyThrows
|
||||
public EmbeddingModel getEmbeddingModel(String modelName, String apiKey, String baseUrl) {
|
||||
EmbeddingModel embeddingModel;
|
||||
if ("quentinz/bge-large-zh-v1.5".equals(modelName)) {
|
||||
embeddingModel = OllamaEmbeddingModel.builder()
|
||||
.baseUrl(baseUrl)
|
||||
.modelName(modelName)
|
||||
.build();
|
||||
} else if ("baai/bge-m3".equals(modelName)) {
|
||||
embeddingModel = OpenAiEmbeddingModel.builder()
|
||||
.apiKey(apiKey)
|
||||
.baseUrl(baseUrl)
|
||||
.modelName(modelName)
|
||||
.build();
|
||||
} else {
|
||||
throw new ServiceException("未找到对应向量化模型!");
|
||||
}
|
||||
return embeddingModel;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -290,7 +290,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
|
||||
.eq(KnowledgeInfo::getId, kid));
|
||||
|
||||
// 通过向量模型查询模型信息
|
||||
ChatModelVo chatModelVo = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModelName());
|
||||
ChatModelVo chatModelVo = chatModelService.queryById(knowledgeInfoVo.getEmbeddingModelId());
|
||||
|
||||
StoreEmbeddingBo storeEmbeddingBo = new StoreEmbeddingBo();
|
||||
storeEmbeddingBo.setKid(kid);
|
||||
@@ -298,7 +298,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
|
||||
storeEmbeddingBo.setFids(fids);
|
||||
storeEmbeddingBo.setChunkList(chunkList);
|
||||
storeEmbeddingBo.setVectorModelName(knowledgeInfoVo.getVectorModelName());
|
||||
storeEmbeddingBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModelName());
|
||||
storeEmbeddingBo.setEmbeddingModelId(knowledgeInfoVo.getEmbeddingModelId());
|
||||
storeEmbeddingBo.setApiKey(chatModelVo.getApiKey());
|
||||
storeEmbeddingBo.setBaseUrl(chatModelVo.getApiHost());
|
||||
vectorStoreService.storeEmbeddings(storeEmbeddingBo);
|
||||
|
||||
13
script/sql/update/2025-10-4-多供应商嵌入模型集成.sql
Normal file
13
script/sql/update/2025-10-4-多供应商嵌入模型集成.sql
Normal file
@@ -0,0 +1,13 @@
|
||||
-- 为 chat_model 表添加 provider_name 字段
|
||||
-- 变更日期: 2025-10-04
|
||||
-- 负责人: Robust_H
|
||||
-- 说明: 嵌入模型供应商 (用于实现动态选择嵌入模型实现类)
|
||||
ALTER TABLE `ruoyi-ai`.chat_model
|
||||
ADD COLUMN `provider_name` varchar(20) DEFAULT NULL COMMENT '模型供应商' AFTER `model_name`;
|
||||
|
||||
-- 修改 knowledge_info 中的 ‘embedding_model_name’ 为 ‘embedding_model_id’
|
||||
-- 变更日期: 2025-10-04
|
||||
-- 负责人: Robust_H
|
||||
-- 说明: 用于区分多个供应商实现同一嵌入模型的情况
|
||||
ALTER TABLE `ruoyi-ai`.knowledge_info
|
||||
ADD COLUMN `embedding_model_id` bigint DEFAULT NULL COMMENT '模型id' AFTER `embedding_model_name`;
|
||||
Reference in New Issue
Block a user