feat(embedding): 添加模型维度支持并重构向量存储策略

- 在ChatModelVo中添加dimension字段用于存储模型维度
- 重构EmbeddingModelFactory以支持按模型名称和维度创建实例
- 修改向量存储策略接口参数顺序并统一维度处理
- 为OpenAI和ZhiPuAI嵌入提供者添加维度配置支持
- 优化知识库服务中模型选择逻辑,添加回退机制
This commit is contained in:
Yzm
2025-10-17 19:55:24 +08:00
parent 9d4a0e0b36
commit e242a67c74
11 changed files with 64 additions and 69 deletions

View File

@@ -70,6 +70,11 @@ public class ChatModelVo implements Serializable {
@ExcelProperty(value = "是否显示")
private String modelShow;
/**
* 模型维度
*/
private Integer dimension;
/**
* 系统提示词
*/

View File

@@ -27,20 +27,23 @@ public class EmbeddingModelFactory {
private final IChatModelService chatModelService;
// 模型缓存使用ConcurrentHashMap保证线程安全
private final Map<Long, BaseEmbedModelService> modelCache = new ConcurrentHashMap<>();
private final Map<String, BaseEmbedModelService> modelCache = new ConcurrentHashMap<>();
/**
* 创建嵌入模型实例
* 如果模型已存在于缓存中,则直接返回;否则创建新的实例
*
* @param embeddingModelId 嵌入模型的唯一标识ID
* @return BaseEmbedModelService 嵌入模型服务实例
* @param embeddingModelName 嵌入模型名称
* @param dimension 模型维度大小
*/
public BaseEmbedModelService createModel(Long embeddingModelId) {
return modelCache.computeIfAbsent(embeddingModelId, id -> {
ChatModelVo modelConfig = chatModelService.queryById(id);
public BaseEmbedModelService createModel(String embeddingModelName, Integer dimension) {
return modelCache.computeIfAbsent(embeddingModelName, name -> {
ChatModelVo modelConfig = chatModelService.selectModelByName(embeddingModelName);
if (modelConfig == null) {
throw new IllegalArgumentException("未找到模型配置,ID=" + id);
throw new IllegalArgumentException("未找到模型配置,name=" + name);
}
if (modelConfig.getDimension() != null) {
modelConfig.setDimension(dimension);
}
return createModelInstance(modelConfig.getProviderName(), modelConfig);
});
@@ -49,22 +52,22 @@ public class EmbeddingModelFactory {
/**
* 检查模型是否支持多模态
*
* @param embeddingModelId 嵌入模型的唯一标识ID
* @param embeddingModelName 嵌入模型名称
* @return boolean 如果模型支持多模态则返回true否则返回false
*/
public boolean isMultimodalModel(Long embeddingModelId) {
return createModel(embeddingModelId) instanceof MultiModalEmbedModelService;
public boolean isMultimodalModel(String embeddingModelName) {
return createModel(embeddingModelName, null) instanceof MultiModalEmbedModelService;
}
/**
* 创建多模态嵌入模型实例
*
* @param tenantId 租户ID
* @param embeddingModelName 嵌入模型名称
* @return MultiModalEmbedModelService 多模态嵌入模型服务实例
* @throws IllegalArgumentException 当模型不支持多模态时抛出
*/
public MultiModalEmbedModelService createMultimodalModel(Long tenantId) {
BaseEmbedModelService model = createModel(tenantId);
public MultiModalEmbedModelService createMultimodalModel(String embeddingModelName) {
BaseEmbedModelService model = createModel(embeddingModelName, null);
if (model instanceof MultiModalEmbedModelService) {
return (MultiModalEmbedModelService) model;
}

View File

@@ -30,6 +30,7 @@ public class OllamaEmbeddingProvider implements BaseEmbedModelService {
return Set.of(ModalityType.TEXT);
}
// ollama不能设置embedding维度使用milvus时请注意创建向量表时需要先设定维度大小
@Override
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
return OllamaEmbeddingModel.builder()

View File

@@ -37,6 +37,7 @@ public class OpenAiEmbeddingProvider implements BaseEmbedModelService {
.baseUrl(chatModelVo.getApiHost())
.apiKey(chatModelVo.getApiKey())
.modelName(chatModelVo.getModelName())
.dimensions(chatModelVo.getDimension())
.build()
.embedAll(textSegments);
}

View File

@@ -37,6 +37,7 @@ public class ZhiPuAiEmbeddingProvider implements BaseEmbedModelService {
.baseUrl(chatModelVo.getApiHost())
.apiKey(chatModelVo.getApiKey())
.model(chatModelVo.getModelName())
.dimensions(chatModelVo.getDimension())
.build()
.embedAll(textSegments);
}

View File

@@ -16,7 +16,7 @@ public interface VectorStoreService {
List<String> getQueryVector(QueryVectorBo queryVectorBo);
void createSchema(String vectorModelName, String kid);
void createSchema(String kid, String embeddingModelName);
void removeById(String id,String modelName) throws ServiceException;

View File

@@ -32,10 +32,9 @@ public class VectorStoreServiceImpl implements VectorStoreService {
}
@Override
public void createSchema(String vectorModelName, String kid) {
log.info("创建向量库schema: vectorModelName={}, kid={}, modelName={}", vectorModelName, kid);
public void createSchema(String kid, String modelName) {
VectorStoreService strategy = getCurrentStrategy();
strategy.createSchema(vectorModelName, kid);
strategy.createSchema(kid, modelName);
}
@Override

View File

@@ -10,6 +10,7 @@ import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.core.config.VectorStoreProperties;
import org.ruoyi.common.core.utils.StringUtils;
import org.ruoyi.service.VectorStoreService;
import org.ruoyi.embedding.EmbeddingModelFactory;
/**
* 向量库策略抽象基类
@@ -23,34 +24,14 @@ public abstract class AbstractVectorStoreStrategy implements VectorStoreService
protected final VectorStoreProperties vectorStoreProperties;
private final EmbeddingModelFactory embeddingModelFactory;
/**
* 获取向量模型
*/
@SneakyThrows
protected EmbeddingModel getEmbeddingModel(String modelName, String apiKey, String baseUrl) {
EmbeddingModel embeddingModel;
if ("quentinz/bge-large-zh-v1.5".equals(modelName)) {
embeddingModel = OllamaEmbeddingModel.builder()
.baseUrl(baseUrl)
.modelName(modelName)
.build();
} else if ("baai/bge-m3".equals(modelName)) {
embeddingModel = OpenAiEmbeddingModel.builder()
.apiKey(apiKey)
.baseUrl(baseUrl)
.modelName(modelName)
.build();
} else if (StringUtils.isNotEmpty(modelName)){
embeddingModel = OpenAiEmbeddingModel.builder()
.apiKey(apiKey)
.baseUrl(baseUrl)
.dimensions(2048)
.modelName(modelName)
.build();
} else {
throw new ServiceException("未找到对应向量化模型!");
}
return embeddingModel;
protected EmbeddingModel getEmbeddingModel(String modelName, Integer dimension) {
return embeddingModelFactory.createModel(modelName, dimension);
}
/**

View File

@@ -1,8 +1,8 @@
package org.ruoyi.service.strategy.impl;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
@@ -17,22 +17,25 @@ import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.core.config.VectorStoreProperties;
import org.ruoyi.domain.bo.QueryVectorBo;
import org.ruoyi.domain.bo.StoreEmbeddingBo;
import org.ruoyi.embedding.EmbeddingModelFactory;
import org.ruoyi.service.strategy.AbstractVectorStoreStrategy;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.IntStream;
// 新增导入
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.IntStream;
@Slf4j
@Component
public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
public MilvusVectorStoreStrategy(VectorStoreProperties vectorStoreProperties) {
super(vectorStoreProperties);
private final Integer DIMENSION = 2048;
public MilvusVectorStoreStrategy(VectorStoreProperties vectorStoreProperties, EmbeddingModelFactory embeddingModelFactory) {
super(vectorStoreProperties, embeddingModelFactory);
}
// 缓存不同集合与 autoFlush 配置的 Milvus 连接
@@ -44,7 +47,7 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
MilvusEmbeddingStore.builder()
.uri(vectorStoreProperties.getMilvus().getUrl())
.collectionName(collectionName)
.dimension(2048)
.dimension(DIMENSION)
.indexType(IndexType.IVF_FLAT)
.metricType(MetricType.L2)
.autoFlushOnInsert(autoFlushOnInsert)
@@ -57,12 +60,7 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
}
@Override
public String getVectorStoreType() {
return "milvus";
}
@Override
public void createSchema(String vectorModelName, String kid) {
public void createSchema(String kid, String modelName) {
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid;
// 使用缓存获取连接以确保只初始化一次
EmbeddingStore<TextSegment> store = getMilvusStore(collectionName, true);
@@ -71,8 +69,7 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
@Override
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(),
storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl());
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(), DIMENSION);
List<String> chunkList = storeEmbeddingBo.getChunkList();
List<String> fidList = storeEmbeddingBo.getFids();
@@ -104,8 +101,7 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
@Override
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(),
queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl());
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(), DIMENSION);
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + queryVectorBo.getKid();
@@ -153,4 +149,9 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
embeddingStore.removeAll(filter);
log.info("Milvus成功删除 fid={} 的所有向量数据", fid);
}
@Override
public String getVectorStoreType() {
return "milvus";
}
}

View File

@@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.core.config.VectorStoreProperties;
import org.ruoyi.domain.bo.QueryVectorBo;
import org.ruoyi.domain.bo.StoreEmbeddingBo;
import org.ruoyi.embedding.EmbeddingModelFactory;
import org.ruoyi.service.strategy.AbstractVectorStoreStrategy;
import org.springframework.stereotype.Component;
import java.util.*;
@@ -35,8 +36,8 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
private WeaviateClient client;
public WeaviateVectorStoreStrategy(VectorStoreProperties vectorStoreProperties) {
super(vectorStoreProperties);
public WeaviateVectorStoreStrategy(VectorStoreProperties vectorStoreProperties, EmbeddingModelFactory embeddingModelFactory) {
super(vectorStoreProperties, embeddingModelFactory);
}
@Override
@@ -45,7 +46,7 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
}
@Override
public void createSchema(String vectorModelName, String kid) {
public void createSchema(String kid, String embeddingModelName) {
String protocol = vectorStoreProperties.getWeaviate().getProtocol();
String host = vectorStoreProperties.getWeaviate().getHost();
String className = vectorStoreProperties.getWeaviate().getClassname() + kid;
@@ -84,9 +85,8 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
@Override
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
createSchema(storeEmbeddingBo.getVectorStoreName(), storeEmbeddingBo.getKid());
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(),
storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl());
createSchema(storeEmbeddingBo.getKid(),storeEmbeddingBo.getEmbeddingModelName());
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(), null);
List<String> chunkList = storeEmbeddingBo.getChunkList();
List<String> fidList = storeEmbeddingBo.getFids();
String kid = storeEmbeddingBo.getKid();
@@ -118,9 +118,8 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
@Override
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
createSchema(queryVectorBo.getVectorModelName(), queryVectorBo.getKid());
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(),
queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl());
createSchema(queryVectorBo.getKid(),queryVectorBo.getEmbeddingModelName());
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(),null);
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
float[] vector = queryEmbedding.vector();
List<String> vectorStrings = new ArrayList<>();

View File

@@ -9,6 +9,7 @@ import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import lombok.RequiredArgsConstructor;
import org.ruoyi.chain.loader.ResourceLoader;
import org.ruoyi.chain.loader.ResourceLoaderFactory;
import org.ruoyi.chat.enums.ChatModeType;
import org.ruoyi.common.core.domain.model.LoginUser;
import org.ruoyi.common.core.utils.MapstructUtils;
import org.ruoyi.common.core.utils.StringUtils;
@@ -237,7 +238,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
}
baseMapper.insert(knowledgeInfo);
if (knowledgeInfo != null) {
vectorStoreService.createSchema(knowledgeInfo.getVectorModelName(),String.valueOf(knowledgeInfo.getId()));
vectorStoreService.createSchema(String.valueOf(knowledgeInfo.getId()), bo.getEmbeddingModelName());
}
} else {
baseMapper.updateById(knowledgeInfo);
@@ -312,8 +313,11 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
.eq(KnowledgeInfo::getId, kid));
// 通过向量模型查询模型信息
ChatModelVo chatModelVo = chatModelService.queryById(knowledgeInfoVo.getEmbeddingModelId());
ChatModelVo chatModelVo = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModelName());
// 未查到指定模型时,回退为向量分类最高优先级模型
if (chatModelVo == null) {
chatModelVo = chatModelService.selectModelByCategoryWithHighestPriority(ChatModeType.VECTOR.getCode());
}
StoreEmbeddingBo storeEmbeddingBo = new StoreEmbeddingBo();
storeEmbeddingBo.setKid(kid);
storeEmbeddingBo.setDocId(docId);