perf: 优化‘嵌入模型’工厂,添加缓存机制

This commit is contained in:
Robust_H
2025-10-04 18:27:54 +08:00
parent b47da3f438
commit 2cef4e17dc

View File

@@ -1,6 +1,7 @@
package org.ruoyi.embedding;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.domain.vo.ChatModelVo;
import org.ruoyi.service.IChatModelService;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
@@ -18,38 +19,50 @@ import java.util.concurrent.ConcurrentHashMap;
*/
@Service
@RequiredArgsConstructor
@Slf4j
public class EmbeddingModelFactory {
// Spring context used to look up embedding model beans by name.
private final ApplicationContext applicationContext;
// NOTE(review): two chat-model service fields and two modelCache declarations are visible here,
// and the methods in view reference both service names — confirm which declarations should remain.
private final IChatModelService iChatModelService;
private final IChatModelService chatModelService;
private final Map<String, BaseEmbedModelService> modelCache = new ConcurrentHashMap<>();
// Model cache; a ConcurrentHashMap is used for thread safety.
private final Map<Long, BaseEmbedModelService> modelCache = new ConcurrentHashMap<>();
/**
 * Returns the embedding model service for the given model ID.
 * A cached instance is returned when one exists; otherwise the model configuration is
 * loaded, a new instance is created from it, and that instance is stored in the cache.
 *
 * @param embeddingModelId unique ID of the embedding model; must not be null
 * @return the embedding model service instance for that ID
 * @throws IllegalArgumentException if the ID is null, if no configuration exists for the ID,
 *                                  or if no bean is registered for the configured provider
 */
public BaseEmbedModelService createModel(Long embeddingModelId) {
    // ConcurrentHashMap rejects null keys with a bare NullPointerException; fail with a clear message instead.
    if (embeddingModelId == null) {
        throw new IllegalArgumentException("嵌入模型ID不能为空");
    }
    // All lookups go through the cache so the configuration is only queried on a cache miss.
    // When the mapping function throws, nothing is stored and the exception reaches the caller.
    // NOTE(review): the mapping function calls the model service while the map entry is being
    // computed — confirm that lookup is fast enough not to stall other threads on the same bin.
    return modelCache.computeIfAbsent(embeddingModelId, id -> {
        ChatModelVo modelConfig = chatModelService.queryById(id);
        if (modelConfig == null) {
            throw new IllegalArgumentException("未找到模型配置ID=" + id);
        }
        return createModelInstance(modelConfig.getProviderName(), modelConfig);
    });
}
// Looks up the bean named by `factory` as a BaseEmbedModelService, applies `config` to it and returns it.
// A missing bean is reported as an IllegalArgumentException that keeps the original cause.
// NOTE(review): no closing brace for this method is visible here, and a definition with the same
// signature appears later in the file — confirm which one should remain.
private BaseEmbedModelService createModelInstance(String factory, ChatModelVo config) {
try {
BaseEmbedModelService model = applicationContext.getBean(factory, BaseEmbedModelService.class);
// TODO: set up caching
model.configure(config);
return model;
} catch (NoSuchBeanDefinitionException e) {
throw new IllegalArgumentException("获取不到嵌入模型: " + factory, e);
}
/**
 * Reports whether the embedding model with the given ID supports multimodal input.
 * The model is obtained through {@code createModel}, so any exception it raises propagates.
 *
 * @param embeddingModelId unique ID of the embedding model
 * @return {@code true} if the resolved model is a {@code MultiModalEmbedModelService}, otherwise {@code false}
 */
public boolean isMultimodalModel(Long embeddingModelId) {
    BaseEmbedModelService resolved = createModel(embeddingModelId);
    return resolved instanceof MultiModalEmbedModelService;
}
// Reports whether the model resolved by createModel for the given ID is a MultiModalEmbedModelService.
public boolean isMultimodalModel(Long tenantId) {
    return createModel(tenantId) instanceof MultiModalEmbedModelService;
}
// Get the multimodal model (if supported)
/**
 * Obtains the embedding model as a multimodal service.
 * The model is resolved through {@code createModel} and must be a
 * {@code MultiModalEmbedModelService}; otherwise the call fails.
 *
 * @param tenantId documented as the tenant ID, but the value is passed straight to
 *                 {@code createModel}, whose parameter is the embedding model ID —
 *                 NOTE(review): confirm which meaning is intended
 * @return MultiModalEmbedModelService the multimodal embedding model service instance
 * @throws IllegalArgumentException when the model does not support multimodal input
 */
public MultiModalEmbedModelService createMultimodalModel(Long tenantId) {
BaseEmbedModelService model = createModel(tenantId);
if (model instanceof MultiModalEmbedModelService) {
@@ -58,13 +71,47 @@ public class EmbeddingModelFactory {
throw new IllegalArgumentException("该模型不支持多模态");
}
// Removes the cache entry stored under the composite key "tenantId:factory".
// NOTE(review): a modelCache keyed by Long is also declared in this class; a String key cannot
// match entries stored under Long keys — confirm which cache this overload is meant to evict from.
public void refreshModel(String tenantId, String factory) {
String cacheKey = tenantId + ":" + factory;
modelCache.remove(cacheKey);
/**
 * Evicts one model from the cache.
 * Only the cache entry is removed here; the next {@code createModel} call for the same ID
 * finds no cached instance and builds a new one.
 *
 * @param embeddingModelId unique ID of the embedding model whose cache entry is dropped
 */
public void refreshModel(Long embeddingModelId) {
    modelCache.remove(embeddingModelId);
}
/**
 * Lists the names of all embedding model factories known to the application context.
 * The names are the bean names under which {@code BaseEmbedModelService} beans are registered.
 *
 * @return a new list holding the bean names of every {@code BaseEmbedModelService} bean
 */
public List<String> getSupportedFactories() {
    Map<String, BaseEmbedModelService> registered =
            applicationContext.getBeansOfType(BaseEmbedModelService.class);
    // Copy the key set so callers receive an independent list.
    return new ArrayList<>(registered.keySet());
}
/**
 * Builds a configured model instance.
 * The bean registered under the given factory name is fetched from the Spring context,
 * configured with the supplied settings, and returned.
 * NOTE(review): if that bean is singleton-scoped, {@code configure} mutates an instance shared
 * by every model ID that uses the same factory — confirm the bean scope.
 *
 * @param factory factory (bean) name identifying the model type
 * @param config  model configuration applied to the instance
 * @return the configured model instance
 * @throws IllegalArgumentException when no bean is registered under the factory name
 */
private BaseEmbedModelService createModelInstance(String factory, ChatModelVo config) {
    final BaseEmbedModelService embedModel;
    try {
        // Resolve the bean by name, then apply the model parameters.
        embedModel = applicationContext.getBean(factory, BaseEmbedModelService.class);
        embedModel.configure(config);
        log.info("成功创建嵌入模型: factory={}, modelId={}", config.getProviderName(), config.getId());
    } catch (NoSuchBeanDefinitionException notFound) {
        throw new IllegalArgumentException("获取不到嵌入模型: " + factory, notFound);
    }
    return embedModel;
}
}