feat: 优化通过知识库获取模型配置逻辑,修改为通过模型id查找模型配置,避免多供应商同模型映射错误。

This commit is contained in:
Robust_H
2025-10-09 20:03:34 +08:00
parent 2cef4e17dc
commit 5475776caa
8 changed files with 49 additions and 35 deletions

View File

@@ -83,6 +83,11 @@ public class KnowledgeInfo extends BaseEntity {
*/
private String vectorModelName;
/**
* 向量化模型id
*/
private Long embeddingModelId;
/**
* 向量化模型名称
*/

View File

@@ -92,7 +92,11 @@ public class KnowledgeInfoBo extends BaseEntity {
/**
* 向量化模型名称
*/
@NotBlank(message = "向量模型不能为空", groups = { AddGroup.class, EditGroup.class })
private Long embeddingModelId;
/**
* 向量化模型名称
*/
private String embeddingModelName;

View File

@@ -31,7 +31,12 @@ public class QueryVectorBo {
private String vectorModelName;
/**
* 向量化模型名称
* 向量化模型ID
*/
private Long embeddingModelId;
/**
* 向量化模型ID
*/
private String embeddingModelName;

View File

@@ -36,6 +36,11 @@ public class StoreEmbeddingBo {
*/
private String vectorModelName;
/**
* 向量化模型id
*/
private Long embeddingModelId;
/**
* 向量化模型名称
*/

View File

@@ -101,6 +101,11 @@ public class KnowledgeInfoVo implements Serializable {
*/
private String vectorModelName;
/**
* 向量化模型id
*/
private Long embeddingModelId;
/**
* 向量化模型名称
*/

View File

@@ -28,6 +28,8 @@ import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.core.service.ConfigService;
import org.ruoyi.domain.bo.QueryVectorBo;
import org.ruoyi.domain.bo.StoreEmbeddingBo;
import org.ruoyi.embedding.BaseEmbedModelService;
import org.ruoyi.embedding.EmbeddingModelFactory;
import org.ruoyi.service.VectorStoreService;
import org.springframework.stereotype.Service;
import java.util.*;
@@ -48,6 +50,8 @@ public class VectorStoreServiceImpl implements VectorStoreService {
// private EmbeddingStore<TextSegment> embeddingStore;
private WeaviateClient client;
private final EmbeddingModelFactory embeddingModelFactory;
@Override
public void createSchema(String kid, String modelName) {
@@ -98,18 +102,16 @@ public class VectorStoreServiceImpl implements VectorStoreService {
@Override
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
createSchema(storeEmbeddingBo.getKid(), storeEmbeddingBo.getVectorModelName());
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(),
storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl());
BaseEmbedModelService model = embeddingModelFactory.createModel(storeEmbeddingBo.getEmbeddingModelId());
List<String> chunkList = storeEmbeddingBo.getChunkList();
List<String> fidList = storeEmbeddingBo.getFids();
String kid = storeEmbeddingBo.getKid();
String docId = storeEmbeddingBo.getDocId();
log.info("向量存储条数记录: " + chunkList.size());
long startTime = System.currentTimeMillis();
for (int i = 0; i < chunkList.size(); i++) {
String text = chunkList.get(i);
String fid = fidList.get(i);
Embedding embedding = embeddingModel.embed(text).content();
Embedding embedding = model.embed(text).content();
Map<String, Object> properties = Map.of(
"text", text,
"fid",fid,
@@ -137,9 +139,8 @@ public class VectorStoreServiceImpl implements VectorStoreService {
@Override
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
createSchema(queryVectorBo.getKid(), queryVectorBo.getVectorModelName());
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(),
queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl());
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
BaseEmbedModelService model = embeddingModelFactory.createModel(queryVectorBo.getEmbeddingModelId());
Embedding queryEmbedding = model.embed(queryVectorBo.getQuery()).content();
float[] vector = queryEmbedding.vector();
List<String> vectorStrings = new ArrayList<>();
for (float v : vector) {
@@ -246,28 +247,4 @@ public class VectorStoreServiceImpl implements VectorStoreService {
log.error("删除失败: {}", result.getError());
}
}
/**
* 获取向量模型
*/
@SneakyThrows
public EmbeddingModel getEmbeddingModel(String modelName, String apiKey, String baseUrl) {
EmbeddingModel embeddingModel;
if ("quentinz/bge-large-zh-v1.5".equals(modelName)) {
embeddingModel = OllamaEmbeddingModel.builder()
.baseUrl(baseUrl)
.modelName(modelName)
.build();
} else if ("baai/bge-m3".equals(modelName)) {
embeddingModel = OpenAiEmbeddingModel.builder()
.apiKey(apiKey)
.baseUrl(baseUrl)
.modelName(modelName)
.build();
} else {
throw new ServiceException("未找到对应向量化模型!");
}
return embeddingModel;
}
}