mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-03-13 20:53:42 +08:00
feat(embedding): 添加模型维度支持并重构向量存储策略
- 在ChatModelVo中添加dimension字段用于存储模型维度
- 重构EmbeddingModelFactory以支持按模型名称和维度创建实例
- 修改向量存储策略接口参数顺序并统一维度处理
- 为OpenAI和ZhiPuAI嵌入提供者添加维度配置支持
- 优化知识库服务中模型选择逻辑,添加回退机制
This commit is contained in:
@@ -70,6 +70,11 @@ public class ChatModelVo implements Serializable {
|
||||
@ExcelProperty(value = "是否显示")
|
||||
private String modelShow;
|
||||
|
||||
/**
|
||||
* 模型维度
|
||||
*/
|
||||
private Integer dimension;
|
||||
|
||||
/**
|
||||
* 系统提示词
|
||||
*/
|
||||
|
||||
@@ -27,20 +27,23 @@ public class EmbeddingModelFactory {
|
||||
private final IChatModelService chatModelService;
|
||||
|
||||
// 模型缓存,使用ConcurrentHashMap保证线程安全
|
||||
private final Map<Long, BaseEmbedModelService> modelCache = new ConcurrentHashMap<>();
|
||||
private final Map<String, BaseEmbedModelService> modelCache = new ConcurrentHashMap<>();
|
||||
|
||||
/**
|
||||
* 创建嵌入模型实例
|
||||
* 如果模型已存在于缓存中,则直接返回;否则创建新的实例
|
||||
*
|
||||
* @param embeddingModelId 嵌入模型的唯一标识ID
|
||||
* @return BaseEmbedModelService 嵌入模型服务实例
|
||||
* @param embeddingModelName 嵌入模型名称
|
||||
* @param dimension 模型维度大小
|
||||
*/
|
||||
public BaseEmbedModelService createModel(Long embeddingModelId) {
|
||||
return modelCache.computeIfAbsent(embeddingModelId, id -> {
|
||||
ChatModelVo modelConfig = chatModelService.queryById(id);
|
||||
public BaseEmbedModelService createModel(String embeddingModelName, Integer dimension) {
|
||||
return modelCache.computeIfAbsent(embeddingModelName, name -> {
|
||||
ChatModelVo modelConfig = chatModelService.selectModelByName(embeddingModelName);
|
||||
if (modelConfig == null) {
|
||||
throw new IllegalArgumentException("未找到模型配置,ID=" + id);
|
||||
throw new IllegalArgumentException("未找到模型配置,name=" + name);
|
||||
}
|
||||
if (modelConfig.getDimension() != null) {
|
||||
modelConfig.setDimension(dimension);
|
||||
}
|
||||
return createModelInstance(modelConfig.getProviderName(), modelConfig);
|
||||
});
|
||||
@@ -49,22 +52,22 @@ public class EmbeddingModelFactory {
|
||||
/**
|
||||
* 检查模型是否支持多模态
|
||||
*
|
||||
* @param embeddingModelId 嵌入模型的唯一标识ID
|
||||
* @param embeddingModelName 嵌入模型名称
|
||||
* @return boolean 如果模型支持多模态则返回true,否则返回false
|
||||
*/
|
||||
public boolean isMultimodalModel(Long embeddingModelId) {
|
||||
return createModel(embeddingModelId) instanceof MultiModalEmbedModelService;
|
||||
public boolean isMultimodalModel(String embeddingModelName) {
|
||||
return createModel(embeddingModelName, null) instanceof MultiModalEmbedModelService;
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建多模态嵌入模型实例
|
||||
*
|
||||
* @param tenantId 租户ID
|
||||
* @param embeddingModelName 嵌入模型名称
|
||||
* @return MultiModalEmbedModelService 多模态嵌入模型服务实例
|
||||
* @throws IllegalArgumentException 当模型不支持多模态时抛出
|
||||
*/
|
||||
public MultiModalEmbedModelService createMultimodalModel(Long tenantId) {
|
||||
BaseEmbedModelService model = createModel(tenantId);
|
||||
public MultiModalEmbedModelService createMultimodalModel(String embeddingModelName) {
|
||||
BaseEmbedModelService model = createModel(embeddingModelName, null);
|
||||
if (model instanceof MultiModalEmbedModelService) {
|
||||
return (MultiModalEmbedModelService) model;
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ public class OllamaEmbeddingProvider implements BaseEmbedModelService {
|
||||
return Set.of(ModalityType.TEXT);
|
||||
}
|
||||
|
||||
// ollama不能设置embedding维度,使用milvus时请注意!!创建向量表时需要先设定维度大小
|
||||
@Override
|
||||
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
|
||||
return OllamaEmbeddingModel.builder()
|
||||
|
||||
@@ -37,6 +37,7 @@ public class OpenAiEmbeddingProvider implements BaseEmbedModelService {
|
||||
.baseUrl(chatModelVo.getApiHost())
|
||||
.apiKey(chatModelVo.getApiKey())
|
||||
.modelName(chatModelVo.getModelName())
|
||||
.dimensions(chatModelVo.getDimension())
|
||||
.build()
|
||||
.embedAll(textSegments);
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ public class ZhiPuAiEmbeddingProvider implements BaseEmbedModelService {
|
||||
.baseUrl(chatModelVo.getApiHost())
|
||||
.apiKey(chatModelVo.getApiKey())
|
||||
.model(chatModelVo.getModelName())
|
||||
.dimensions(chatModelVo.getDimension())
|
||||
.build()
|
||||
.embedAll(textSegments);
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ public interface VectorStoreService {
|
||||
|
||||
List<String> getQueryVector(QueryVectorBo queryVectorBo);
|
||||
|
||||
void createSchema(String vectorModelName, String kid);
|
||||
void createSchema(String kid, String embeddingModelName);
|
||||
|
||||
void removeById(String id,String modelName) throws ServiceException;
|
||||
|
||||
|
||||
@@ -32,10 +32,9 @@ public class VectorStoreServiceImpl implements VectorStoreService {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void createSchema(String vectorModelName, String kid) {
|
||||
log.info("创建向量库schema: vectorModelName={}, kid={}, modelName={}", vectorModelName, kid);
|
||||
public void createSchema(String kid, String modelName) {
|
||||
VectorStoreService strategy = getCurrentStrategy();
|
||||
strategy.createSchema(vectorModelName, kid);
|
||||
strategy.createSchema(kid, modelName);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -10,6 +10,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.common.core.config.VectorStoreProperties;
|
||||
import org.ruoyi.common.core.utils.StringUtils;
|
||||
import org.ruoyi.service.VectorStoreService;
|
||||
import org.ruoyi.embedding.EmbeddingModelFactory;
|
||||
|
||||
/**
|
||||
* 向量库策略抽象基类
|
||||
@@ -23,34 +24,14 @@ public abstract class AbstractVectorStoreStrategy implements VectorStoreService
|
||||
|
||||
protected final VectorStoreProperties vectorStoreProperties;
|
||||
|
||||
private final EmbeddingModelFactory embeddingModelFactory;
|
||||
|
||||
/**
|
||||
* 获取向量模型
|
||||
*/
|
||||
@SneakyThrows
|
||||
protected EmbeddingModel getEmbeddingModel(String modelName, String apiKey, String baseUrl) {
|
||||
EmbeddingModel embeddingModel;
|
||||
if ("quentinz/bge-large-zh-v1.5".equals(modelName)) {
|
||||
embeddingModel = OllamaEmbeddingModel.builder()
|
||||
.baseUrl(baseUrl)
|
||||
.modelName(modelName)
|
||||
.build();
|
||||
} else if ("baai/bge-m3".equals(modelName)) {
|
||||
embeddingModel = OpenAiEmbeddingModel.builder()
|
||||
.apiKey(apiKey)
|
||||
.baseUrl(baseUrl)
|
||||
.modelName(modelName)
|
||||
.build();
|
||||
} else if (StringUtils.isNotEmpty(modelName)){
|
||||
embeddingModel = OpenAiEmbeddingModel.builder()
|
||||
.apiKey(apiKey)
|
||||
.baseUrl(baseUrl)
|
||||
.dimensions(2048)
|
||||
.modelName(modelName)
|
||||
.build();
|
||||
} else {
|
||||
throw new ServiceException("未找到对应向量化模型!");
|
||||
}
|
||||
return embeddingModel;
|
||||
protected EmbeddingModel getEmbeddingModel(String modelName, Integer dimension) {
|
||||
return embeddingModelFactory.createModel(modelName, dimension);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package org.ruoyi.service.strategy.impl;
|
||||
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
||||
@@ -17,22 +17,25 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.common.core.config.VectorStoreProperties;
|
||||
import org.ruoyi.domain.bo.QueryVectorBo;
|
||||
import org.ruoyi.domain.bo.StoreEmbeddingBo;
|
||||
import org.ruoyi.embedding.EmbeddingModelFactory;
|
||||
import org.ruoyi.service.strategy.AbstractVectorStoreStrategy;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.IntStream;
|
||||
// 新增导入
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
|
||||
public MilvusVectorStoreStrategy(VectorStoreProperties vectorStoreProperties) {
|
||||
super(vectorStoreProperties);
|
||||
|
||||
private final Integer DIMENSION = 2048;
|
||||
|
||||
public MilvusVectorStoreStrategy(VectorStoreProperties vectorStoreProperties, EmbeddingModelFactory embeddingModelFactory) {
|
||||
super(vectorStoreProperties, embeddingModelFactory);
|
||||
}
|
||||
|
||||
// 缓存不同集合与 autoFlush 配置的 Milvus 连接
|
||||
@@ -44,7 +47,7 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
MilvusEmbeddingStore.builder()
|
||||
.uri(vectorStoreProperties.getMilvus().getUrl())
|
||||
.collectionName(collectionName)
|
||||
.dimension(2048)
|
||||
.dimension(DIMENSION)
|
||||
.indexType(IndexType.IVF_FLAT)
|
||||
.metricType(MetricType.L2)
|
||||
.autoFlushOnInsert(autoFlushOnInsert)
|
||||
@@ -57,12 +60,7 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getVectorStoreType() {
|
||||
return "milvus";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void createSchema(String vectorModelName, String kid) {
|
||||
public void createSchema(String kid, String modelName) {
|
||||
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid;
|
||||
// 使用缓存获取连接以确保只初始化一次
|
||||
EmbeddingStore<TextSegment> store = getMilvusStore(collectionName, true);
|
||||
@@ -71,8 +69,7 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
|
||||
@Override
|
||||
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(),
|
||||
storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl());
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(), DIMENSION);
|
||||
|
||||
List<String> chunkList = storeEmbeddingBo.getChunkList();
|
||||
List<String> fidList = storeEmbeddingBo.getFids();
|
||||
@@ -104,8 +101,7 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
|
||||
@Override
|
||||
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(),
|
||||
queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl());
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(), DIMENSION);
|
||||
|
||||
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + queryVectorBo.getKid();
|
||||
@@ -153,4 +149,9 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
embeddingStore.removeAll(filter);
|
||||
log.info("Milvus成功删除 fid={} 的所有向量数据", fid);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getVectorStoreType() {
|
||||
return "milvus";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.common.core.config.VectorStoreProperties;
|
||||
import org.ruoyi.domain.bo.QueryVectorBo;
|
||||
import org.ruoyi.domain.bo.StoreEmbeddingBo;
|
||||
import org.ruoyi.embedding.EmbeddingModelFactory;
|
||||
import org.ruoyi.service.strategy.AbstractVectorStoreStrategy;
|
||||
import org.springframework.stereotype.Component;
|
||||
import java.util.*;
|
||||
@@ -35,8 +36,8 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
|
||||
private WeaviateClient client;
|
||||
|
||||
public WeaviateVectorStoreStrategy(VectorStoreProperties vectorStoreProperties) {
|
||||
super(vectorStoreProperties);
|
||||
public WeaviateVectorStoreStrategy(VectorStoreProperties vectorStoreProperties, EmbeddingModelFactory embeddingModelFactory) {
|
||||
super(vectorStoreProperties, embeddingModelFactory);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -45,7 +46,7 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void createSchema(String vectorModelName, String kid) {
|
||||
public void createSchema(String kid, String embeddingModelName) {
|
||||
String protocol = vectorStoreProperties.getWeaviate().getProtocol();
|
||||
String host = vectorStoreProperties.getWeaviate().getHost();
|
||||
String className = vectorStoreProperties.getWeaviate().getClassname() + kid;
|
||||
@@ -84,9 +85,8 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
|
||||
@Override
|
||||
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
|
||||
createSchema(storeEmbeddingBo.getVectorStoreName(), storeEmbeddingBo.getKid());
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(),
|
||||
storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl());
|
||||
createSchema(storeEmbeddingBo.getKid(),storeEmbeddingBo.getEmbeddingModelName());
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(), null);
|
||||
List<String> chunkList = storeEmbeddingBo.getChunkList();
|
||||
List<String> fidList = storeEmbeddingBo.getFids();
|
||||
String kid = storeEmbeddingBo.getKid();
|
||||
@@ -118,9 +118,8 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
|
||||
@Override
|
||||
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
|
||||
createSchema(queryVectorBo.getVectorModelName(), queryVectorBo.getKid());
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(),
|
||||
queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl());
|
||||
createSchema(queryVectorBo.getKid(),queryVectorBo.getEmbeddingModelName());
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(),null);
|
||||
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||
float[] vector = queryEmbedding.vector();
|
||||
List<String> vectorStrings = new ArrayList<>();
|
||||
|
||||
@@ -9,6 +9,7 @@ import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.ruoyi.chain.loader.ResourceLoader;
|
||||
import org.ruoyi.chain.loader.ResourceLoaderFactory;
|
||||
import org.ruoyi.chat.enums.ChatModeType;
|
||||
import org.ruoyi.common.core.domain.model.LoginUser;
|
||||
import org.ruoyi.common.core.utils.MapstructUtils;
|
||||
import org.ruoyi.common.core.utils.StringUtils;
|
||||
@@ -237,7 +238,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
|
||||
}
|
||||
baseMapper.insert(knowledgeInfo);
|
||||
if (knowledgeInfo != null) {
|
||||
vectorStoreService.createSchema(knowledgeInfo.getVectorModelName(),String.valueOf(knowledgeInfo.getId()));
|
||||
vectorStoreService.createSchema(String.valueOf(knowledgeInfo.getId()), bo.getEmbeddingModelName());
|
||||
}
|
||||
} else {
|
||||
baseMapper.updateById(knowledgeInfo);
|
||||
@@ -312,8 +313,11 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
|
||||
.eq(KnowledgeInfo::getId, kid));
|
||||
|
||||
// 通过向量模型查询模型信息
|
||||
ChatModelVo chatModelVo = chatModelService.queryById(knowledgeInfoVo.getEmbeddingModelId());
|
||||
|
||||
ChatModelVo chatModelVo = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModelName());
|
||||
// 未查到指定模型时,回退为向量分类最高优先级模型
|
||||
if (chatModelVo == null) {
|
||||
chatModelVo = chatModelService.selectModelByCategoryWithHighestPriority(ChatModeType.VECTOR.getCode());
|
||||
}
|
||||
StoreEmbeddingBo storeEmbeddingBo = new StoreEmbeddingBo();
|
||||
storeEmbeddingBo.setKid(kid);
|
||||
storeEmbeddingBo.setDocId(docId);
|
||||
|
||||
Reference in New Issue
Block a user