perf: 优化‘嵌入模型’工厂,添加缓存机制

This commit is contained in:
Robust_H
2025-10-04 18:27:54 +08:00
parent b47da3f438
commit 2cef4e17dc

View File

@@ -1,6 +1,7 @@
package org.ruoyi.embedding; package org.ruoyi.embedding;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.domain.vo.ChatModelVo; import org.ruoyi.domain.vo.ChatModelVo;
import org.ruoyi.service.IChatModelService; import org.ruoyi.service.IChatModelService;
import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.beans.factory.NoSuchBeanDefinitionException;
@@ -18,38 +19,50 @@ import java.util.concurrent.ConcurrentHashMap;
*/ */
@Service @Service
@RequiredArgsConstructor @RequiredArgsConstructor
@Slf4j
public class EmbeddingModelFactory { public class EmbeddingModelFactory {
private final ApplicationContext applicationContext; private final ApplicationContext applicationContext;
private final IChatModelService iChatModelService; private final IChatModelService chatModelService;
private final Map<String, BaseEmbedModelService> modelCache = new ConcurrentHashMap<>(); // 模型缓存,使用ConcurrentHashMap保证线程安全
private final Map<Long, BaseEmbedModelService> modelCache = new ConcurrentHashMap<>();
/**
* 创建嵌入模型实例
* 如果模型已存在于缓存中,则直接返回;否则创建新的实例
*
* @param embeddingModelId 嵌入模型的唯一标识ID
* @return BaseEmbedModelService 嵌入模型服务实例
*/
public BaseEmbedModelService createModel(Long embeddingModelId) { public BaseEmbedModelService createModel(Long embeddingModelId) {
ChatModelVo chatModelVo = iChatModelService.queryById(embeddingModelId); return modelCache.computeIfAbsent(embeddingModelId, id -> {
ChatModelVo modelConfig = chatModelService.queryById(id);
return createModelInstance(chatModelVo.getProviderName(), chatModelVo); if (modelConfig == null) {
throw new IllegalArgumentException("未找到模型配置ID=" + id);
}
return createModelInstance(modelConfig.getProviderName(), modelConfig);
});
} }
private BaseEmbedModelService createModelInstance(String factory, ChatModelVo config) { /**
try { * 检查模型是否支持多模态
BaseEmbedModelService model = applicationContext.getBean(factory, BaseEmbedModelService.class); *
// TODO 缓存设置 * @param embeddingModelId 嵌入模型的唯一标识ID
model.configure(config); * @return boolean 如果模型支持多模态则返回true否则返回false
*/
return model; public boolean isMultimodalModel(Long embeddingModelId) {
} catch (NoSuchBeanDefinitionException e) { return createModel(embeddingModelId) instanceof MultiModalEmbedModelService;
throw new IllegalArgumentException("获取不到嵌入模型: " + factory, e);
}
} }
// 检查模型是否支持多模态 /**
public boolean isMultimodalModel(Long tenantId) { * 创建多模态嵌入模型实例
BaseEmbedModelService model = createModel(tenantId); *
return model instanceof MultiModalEmbedModelService; * @param tenantId 租户ID
} * @return MultiModalEmbedModelService 多模态嵌入模型服务实例
* @throws IllegalArgumentException 当模型不支持多模态时抛出
// 获取多模态模型(如果支持) */
public MultiModalEmbedModelService createMultimodalModel(Long tenantId) { public MultiModalEmbedModelService createMultimodalModel(Long tenantId) {
BaseEmbedModelService model = createModel(tenantId); BaseEmbedModelService model = createModel(tenantId);
if (model instanceof MultiModalEmbedModelService) { if (model instanceof MultiModalEmbedModelService) {
@@ -58,13 +71,47 @@ public class EmbeddingModelFactory {
throw new IllegalArgumentException("该模型不支持多模态"); throw new IllegalArgumentException("该模型不支持多模态");
} }
public void refreshModel(String tenantId, String factory) { /**
String cacheKey = tenantId + ":" + factory; * 刷新模型缓存
modelCache.remove(cacheKey); * 根据给定的嵌入模型ID从缓存中移除对应的模型
*
* @param embeddingModelId 嵌入模型的唯一标识ID
*/
public void refreshModel(Long embeddingModelId) {
// 从模型缓存中移除指定ID的模型
modelCache.remove(embeddingModelId);
} }
/**
* 获取所有支持模型工厂的列表
*
* @return List<String> 支持的模型工厂名称列表
*/
public List<String> getSupportedFactories() { public List<String> getSupportedFactories() {
return new ArrayList<>(applicationContext.getBeansOfType(BaseEmbedModelService.class) return new ArrayList<>(applicationContext.getBeansOfType(BaseEmbedModelService.class)
.keySet()); .keySet());
} }
/**
* 创建具体的模型实例
* 根据提供的工厂名称和配置信息创建并配置模型实例
*
* @param factory 工厂名称,用于标识模型类型
* @param config 模型配置信息
* @return BaseEmbedModelService 配置好的模型实例
* @throws IllegalArgumentException 当无法获取指定的模型实例时抛出
*/
private BaseEmbedModelService createModelInstance(String factory, ChatModelVo config) {
try {
// 从Spring上下文中获取模型实例
BaseEmbedModelService model = applicationContext.getBean(factory, BaseEmbedModelService.class);
// 配置模型参数
model.configure(config);
log.info("成功创建嵌入模型: factory={}, modelId={}", config.getProviderName(), config.getId());
return model;
} catch (NoSuchBeanDefinitionException e) {
throw new IllegalArgumentException("获取不到嵌入模型: " + factory, e);
}
}
} }