feat: 优化通过知识库获取模型配置逻辑,修改为通过模型id查找模型配置,避免多供应商同模型映射错误。

This commit is contained in:
Robust_H
2025-10-09 20:03:34 +08:00
parent 2cef4e17dc
commit 5475776caa
8 changed files with 49 additions and 35 deletions

View File

@@ -83,6 +83,11 @@ public class KnowledgeInfo extends BaseEntity {
*/
private String vectorModelName;
/**
* 向量化模型id
*/
private Long embeddingModelId;
/**
* 向量化模型名称
*/

View File

@@ -92,7 +92,11 @@ public class KnowledgeInfoBo extends BaseEntity {
/**
* 向量化模型名称
*/
@NotBlank(message = "向量模型不能为空", groups = { AddGroup.class, EditGroup.class })
private Long embeddingModelId;
/**
* 向量化模型名称
*/
private String embeddingModelName;

View File

@@ -31,7 +31,12 @@ public class QueryVectorBo {
private String vectorModelName;
/**
* 向量化模型名称
* 向量化模型ID
*/
private Long embeddingModelId;
/**
* 向量化模型ID
*/
private String embeddingModelName;

View File

@@ -36,6 +36,11 @@ public class StoreEmbeddingBo {
*/
private String vectorModelName;
/**
* 向量化模型id
*/
private Long embeddingModelId;
/**
* 向量化模型名称
*/

View File

@@ -101,6 +101,11 @@ public class KnowledgeInfoVo implements Serializable {
*/
private String vectorModelName;
/**
* 向量化模型id
*/
private Long embeddingModelId;
/**
* 向量化模型名称
*/

View File

@@ -28,6 +28,8 @@ import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.core.service.ConfigService;
import org.ruoyi.domain.bo.QueryVectorBo;
import org.ruoyi.domain.bo.StoreEmbeddingBo;
import org.ruoyi.embedding.BaseEmbedModelService;
import org.ruoyi.embedding.EmbeddingModelFactory;
import org.ruoyi.service.VectorStoreService;
import org.springframework.stereotype.Service;
import java.util.*;
@@ -48,6 +50,8 @@ public class VectorStoreServiceImpl implements VectorStoreService {
// private EmbeddingStore<TextSegment> embeddingStore;
private WeaviateClient client;
private final EmbeddingModelFactory embeddingModelFactory;
@Override
public void createSchema(String kid, String modelName) {
@@ -98,18 +102,16 @@ public class VectorStoreServiceImpl implements VectorStoreService {
@Override
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
createSchema(storeEmbeddingBo.getKid(), storeEmbeddingBo.getVectorModelName());
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(),
storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl());
BaseEmbedModelService model = embeddingModelFactory.createModel(storeEmbeddingBo.getEmbeddingModelId());
List<String> chunkList = storeEmbeddingBo.getChunkList();
List<String> fidList = storeEmbeddingBo.getFids();
String kid = storeEmbeddingBo.getKid();
String docId = storeEmbeddingBo.getDocId();
log.info("向量存储条数记录: " + chunkList.size());
long startTime = System.currentTimeMillis();
for (int i = 0; i < chunkList.size(); i++) {
String text = chunkList.get(i);
String fid = fidList.get(i);
Embedding embedding = embeddingModel.embed(text).content();
Embedding embedding = model.embed(text).content();
Map<String, Object> properties = Map.of(
"text", text,
"fid",fid,
@@ -137,9 +139,8 @@ public class VectorStoreServiceImpl implements VectorStoreService {
@Override
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
createSchema(queryVectorBo.getKid(), queryVectorBo.getVectorModelName());
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(),
queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl());
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
BaseEmbedModelService model = embeddingModelFactory.createModel(queryVectorBo.getEmbeddingModelId());
Embedding queryEmbedding = model.embed(queryVectorBo.getQuery()).content();
float[] vector = queryEmbedding.vector();
List<String> vectorStrings = new ArrayList<>();
for (float v : vector) {
@@ -246,28 +247,4 @@ public class VectorStoreServiceImpl implements VectorStoreService {
log.error("删除失败: {}", result.getError());
}
}
/**
* 获取向量模型
*/
@SneakyThrows
public EmbeddingModel getEmbeddingModel(String modelName, String apiKey, String baseUrl) {
EmbeddingModel embeddingModel;
if ("quentinz/bge-large-zh-v1.5".equals(modelName)) {
embeddingModel = OllamaEmbeddingModel.builder()
.baseUrl(baseUrl)
.modelName(modelName)
.build();
} else if ("baai/bge-m3".equals(modelName)) {
embeddingModel = OpenAiEmbeddingModel.builder()
.apiKey(apiKey)
.baseUrl(baseUrl)
.modelName(modelName)
.build();
} else {
throw new ServiceException("未找到对应向量化模型!");
}
return embeddingModel;
}
}

View File

@@ -290,7 +290,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
.eq(KnowledgeInfo::getId, kid));
// 通过向量模型查询模型信息
ChatModelVo chatModelVo = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModelName());
ChatModelVo chatModelVo = chatModelService.queryById(knowledgeInfoVo.getEmbeddingModelId());
StoreEmbeddingBo storeEmbeddingBo = new StoreEmbeddingBo();
storeEmbeddingBo.setKid(kid);
@@ -298,7 +298,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
storeEmbeddingBo.setFids(fids);
storeEmbeddingBo.setChunkList(chunkList);
storeEmbeddingBo.setVectorModelName(knowledgeInfoVo.getVectorModelName());
storeEmbeddingBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModelName());
storeEmbeddingBo.setEmbeddingModelId(knowledgeInfoVo.getEmbeddingModelId());
storeEmbeddingBo.setApiKey(chatModelVo.getApiKey());
storeEmbeddingBo.setBaseUrl(chatModelVo.getApiHost());
vectorStoreService.storeEmbeddings(storeEmbeddingBo);

View File

@@ -0,0 +1,13 @@
-- 为 chat_model 表添加 provider_name 字段
-- 变更日期: 2025-10-04
-- 负责人: Robust_H
-- 说明: 嵌入模型供应商 (用于实现动态选择嵌入模型实现类)
ALTER TABLE `ruoyi-ai`.chat_model
ADD COLUMN `provider_name` varchar(20) DEFAULT NULL COMMENT '模型供应商' AFTER `model_name`;
-- 修改 knowledge_info 中的 embedding_model_nameembedding_model_id
-- 变更日期: 2025-10-04
-- 负责人: Robust_H
-- 说明: 用于区分多个供应商实现同一嵌入模型的情况
ALTER TABLE `ruoyi-ai`.knowledge_info
ADD COLUMN `embedding_model_id` bigint DEFAULT NULL COMMENT '模型id' AFTER `embedding_model_name`;