mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-03-13 20:53:42 +08:00
feat: 优化通过知识库获取模型配置逻辑,修改为通过模型id查找模型配置,避免多供应商同模型映射错误。
This commit is contained in:
@@ -83,6 +83,11 @@ public class KnowledgeInfo extends BaseEntity {
|
||||
*/
|
||||
private String vectorModelName;
|
||||
|
||||
/**
|
||||
* 向量化模型id
|
||||
*/
|
||||
private Long embeddingModelId;
|
||||
|
||||
/**
|
||||
* 向量化模型名称
|
||||
*/
|
||||
|
||||
@@ -92,7 +92,11 @@ public class KnowledgeInfoBo extends BaseEntity {
|
||||
/**
|
||||
* 向量化模型名称
|
||||
*/
|
||||
@NotBlank(message = "向量模型不能为空", groups = { AddGroup.class, EditGroup.class })
|
||||
private Long embeddingModelId;
|
||||
|
||||
/**
|
||||
* 向量化模型名称
|
||||
*/
|
||||
private String embeddingModelName;
|
||||
|
||||
|
||||
|
||||
@@ -31,7 +31,12 @@ public class QueryVectorBo {
|
||||
private String vectorModelName;
|
||||
|
||||
/**
|
||||
* 向量化模型名称
|
||||
* 向量化模型ID
|
||||
*/
|
||||
private Long embeddingModelId;
|
||||
|
||||
/**
|
||||
* 向量化模型ID
|
||||
*/
|
||||
private String embeddingModelName;
|
||||
|
||||
|
||||
@@ -36,6 +36,11 @@ public class StoreEmbeddingBo {
|
||||
*/
|
||||
private String vectorModelName;
|
||||
|
||||
/**
|
||||
* 向量化模型id
|
||||
*/
|
||||
private Long embeddingModelId;
|
||||
|
||||
/**
|
||||
* 向量化模型名称
|
||||
*/
|
||||
|
||||
@@ -101,6 +101,11 @@ public class KnowledgeInfoVo implements Serializable {
|
||||
*/
|
||||
private String vectorModelName;
|
||||
|
||||
/**
|
||||
* 向量化模型id
|
||||
*/
|
||||
private Long embeddingModelId;
|
||||
|
||||
/**
|
||||
* 向量化模型名称
|
||||
*/
|
||||
|
||||
@@ -28,6 +28,8 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.common.core.service.ConfigService;
|
||||
import org.ruoyi.domain.bo.QueryVectorBo;
|
||||
import org.ruoyi.domain.bo.StoreEmbeddingBo;
|
||||
import org.ruoyi.embedding.BaseEmbedModelService;
|
||||
import org.ruoyi.embedding.EmbeddingModelFactory;
|
||||
import org.ruoyi.service.VectorStoreService;
|
||||
import org.springframework.stereotype.Service;
|
||||
import java.util.*;
|
||||
@@ -48,6 +50,8 @@ public class VectorStoreServiceImpl implements VectorStoreService {
|
||||
// private EmbeddingStore<TextSegment> embeddingStore;
|
||||
private WeaviateClient client;
|
||||
|
||||
private final EmbeddingModelFactory embeddingModelFactory;
|
||||
|
||||
|
||||
@Override
|
||||
public void createSchema(String kid, String modelName) {
|
||||
@@ -98,18 +102,16 @@ public class VectorStoreServiceImpl implements VectorStoreService {
|
||||
@Override
|
||||
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
|
||||
createSchema(storeEmbeddingBo.getKid(), storeEmbeddingBo.getVectorModelName());
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(),
|
||||
storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl());
|
||||
BaseEmbedModelService model = embeddingModelFactory.createModel(storeEmbeddingBo.getEmbeddingModelId());
|
||||
List<String> chunkList = storeEmbeddingBo.getChunkList();
|
||||
List<String> fidList = storeEmbeddingBo.getFids();
|
||||
String kid = storeEmbeddingBo.getKid();
|
||||
String docId = storeEmbeddingBo.getDocId();
|
||||
log.info("向量存储条数记录: " + chunkList.size());
|
||||
long startTime = System.currentTimeMillis();
|
||||
for (int i = 0; i < chunkList.size(); i++) {
|
||||
String text = chunkList.get(i);
|
||||
String fid = fidList.get(i);
|
||||
Embedding embedding = embeddingModel.embed(text).content();
|
||||
Embedding embedding = model.embed(text).content();
|
||||
Map<String, Object> properties = Map.of(
|
||||
"text", text,
|
||||
"fid",fid,
|
||||
@@ -137,9 +139,8 @@ public class VectorStoreServiceImpl implements VectorStoreService {
|
||||
@Override
|
||||
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
|
||||
createSchema(queryVectorBo.getKid(), queryVectorBo.getVectorModelName());
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(),
|
||||
queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl());
|
||||
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||
BaseEmbedModelService model = embeddingModelFactory.createModel(queryVectorBo.getEmbeddingModelId());
|
||||
Embedding queryEmbedding = model.embed(queryVectorBo.getQuery()).content();
|
||||
float[] vector = queryEmbedding.vector();
|
||||
List<String> vectorStrings = new ArrayList<>();
|
||||
for (float v : vector) {
|
||||
@@ -246,28 +247,4 @@ public class VectorStoreServiceImpl implements VectorStoreService {
|
||||
log.error("删除失败: {}", result.getError());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取向量模型
|
||||
*/
|
||||
@SneakyThrows
|
||||
public EmbeddingModel getEmbeddingModel(String modelName, String apiKey, String baseUrl) {
|
||||
EmbeddingModel embeddingModel;
|
||||
if ("quentinz/bge-large-zh-v1.5".equals(modelName)) {
|
||||
embeddingModel = OllamaEmbeddingModel.builder()
|
||||
.baseUrl(baseUrl)
|
||||
.modelName(modelName)
|
||||
.build();
|
||||
} else if ("baai/bge-m3".equals(modelName)) {
|
||||
embeddingModel = OpenAiEmbeddingModel.builder()
|
||||
.apiKey(apiKey)
|
||||
.baseUrl(baseUrl)
|
||||
.modelName(modelName)
|
||||
.build();
|
||||
} else {
|
||||
throw new ServiceException("未找到对应向量化模型!");
|
||||
}
|
||||
return embeddingModel;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user