mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-04-07 08:47:32 +00:00
perf: 优化‘嵌入模型’工厂,添加缓存机制
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package org.ruoyi.embedding;
|
package org.ruoyi.embedding;
|
||||||
|
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.ruoyi.domain.vo.ChatModelVo;
|
import org.ruoyi.domain.vo.ChatModelVo;
|
||||||
import org.ruoyi.service.IChatModelService;
|
import org.ruoyi.service.IChatModelService;
|
||||||
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
|
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
|
||||||
@@ -18,38 +19,50 @@ import java.util.concurrent.ConcurrentHashMap;
|
|||||||
*/
|
*/
|
||||||
@Service
|
@Service
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
|
@Slf4j
|
||||||
public class EmbeddingModelFactory {
|
public class EmbeddingModelFactory {
|
||||||
|
|
||||||
private final ApplicationContext applicationContext;
|
private final ApplicationContext applicationContext;
|
||||||
|
|
||||||
private final IChatModelService iChatModelService;
|
private final IChatModelService chatModelService;
|
||||||
|
|
||||||
private final Map<String, BaseEmbedModelService> modelCache = new ConcurrentHashMap<>();
|
// 模型缓存,使用ConcurrentHashMap保证线程安全
|
||||||
|
private final Map<Long, BaseEmbedModelService> modelCache = new ConcurrentHashMap<>();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建嵌入模型实例
|
||||||
|
* 如果模型已存在于缓存中,则直接返回;否则创建新的实例
|
||||||
|
*
|
||||||
|
* @param embeddingModelId 嵌入模型的唯一标识ID
|
||||||
|
* @return BaseEmbedModelService 嵌入模型服务实例
|
||||||
|
*/
|
||||||
public BaseEmbedModelService createModel(Long embeddingModelId) {
|
public BaseEmbedModelService createModel(Long embeddingModelId) {
|
||||||
ChatModelVo chatModelVo = iChatModelService.queryById(embeddingModelId);
|
return modelCache.computeIfAbsent(embeddingModelId, id -> {
|
||||||
|
ChatModelVo modelConfig = chatModelService.queryById(id);
|
||||||
return createModelInstance(chatModelVo.getProviderName(), chatModelVo);
|
if (modelConfig == null) {
|
||||||
|
throw new IllegalArgumentException("未找到模型配置,ID=" + id);
|
||||||
|
}
|
||||||
|
return createModelInstance(modelConfig.getProviderName(), modelConfig);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
private BaseEmbedModelService createModelInstance(String factory, ChatModelVo config) {
|
/**
|
||||||
try {
|
* 检查模型是否支持多模态
|
||||||
BaseEmbedModelService model = applicationContext.getBean(factory, BaseEmbedModelService.class);
|
*
|
||||||
// TODO 缓存设置
|
* @param embeddingModelId 嵌入模型的唯一标识ID
|
||||||
model.configure(config);
|
* @return boolean 如果模型支持多模态则返回true,否则返回false
|
||||||
|
*/
|
||||||
return model;
|
public boolean isMultimodalModel(Long embeddingModelId) {
|
||||||
} catch (NoSuchBeanDefinitionException e) {
|
return createModel(embeddingModelId) instanceof MultiModalEmbedModelService;
|
||||||
throw new IllegalArgumentException("获取不到嵌入模型: " + factory, e);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查模型是否支持多模态
|
/**
|
||||||
public boolean isMultimodalModel(Long tenantId) {
|
* 创建多模态嵌入模型实例
|
||||||
BaseEmbedModelService model = createModel(tenantId);
|
*
|
||||||
return model instanceof MultiModalEmbedModelService;
|
* @param tenantId embedding model ID (forwarded unchanged to createModel), not a tenant ID
|
||||||
}
|
* @return MultiModalEmbedModelService 多模态嵌入模型服务实例
|
||||||
|
* @throws IllegalArgumentException 当模型不支持多模态时抛出
|
||||||
// 获取多模态模型(如果支持)
|
*/
|
||||||
public MultiModalEmbedModelService createMultimodalModel(Long tenantId) {
|
public MultiModalEmbedModelService createMultimodalModel(Long tenantId) {
|
||||||
BaseEmbedModelService model = createModel(tenantId);
|
BaseEmbedModelService model = createModel(tenantId);
|
||||||
if (model instanceof MultiModalEmbedModelService) {
|
if (model instanceof MultiModalEmbedModelService) {
|
||||||
@@ -58,13 +71,47 @@ public class EmbeddingModelFactory {
|
|||||||
throw new IllegalArgumentException("该模型不支持多模态");
|
throw new IllegalArgumentException("该模型不支持多模态");
|
||||||
}
|
}
|
||||||
|
|
||||||
public void refreshModel(String tenantId, String factory) {
|
/**
|
||||||
String cacheKey = tenantId + ":" + factory;
|
* 刷新模型缓存
|
||||||
modelCache.remove(cacheKey);
|
* 根据给定的嵌入模型ID从缓存中移除对应的模型
|
||||||
|
*
|
||||||
|
* @param embeddingModelId 嵌入模型的唯一标识ID
|
||||||
|
*/
|
||||||
|
public void refreshModel(Long embeddingModelId) {
|
||||||
|
// 从模型缓存中移除指定ID的模型
|
||||||
|
modelCache.remove(embeddingModelId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取所有支持模型工厂的列表
|
||||||
|
*
|
||||||
|
* @return List<String> 支持的模型工厂名称列表
|
||||||
|
*/
|
||||||
public List<String> getSupportedFactories() {
|
public List<String> getSupportedFactories() {
|
||||||
return new ArrayList<>(applicationContext.getBeansOfType(BaseEmbedModelService.class)
|
return new ArrayList<>(applicationContext.getBeansOfType(BaseEmbedModelService.class)
|
||||||
.keySet());
|
.keySet());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建具体的模型实例
|
||||||
|
* 根据提供的工厂名称和配置信息创建并配置模型实例
|
||||||
|
*
|
||||||
|
* @param factory 工厂名称,用于标识模型类型
|
||||||
|
* @param config 模型配置信息
|
||||||
|
* @return BaseEmbedModelService 配置好的模型实例
|
||||||
|
* @throws IllegalArgumentException 当无法获取指定的模型实例时抛出
|
||||||
|
*/
|
||||||
|
private BaseEmbedModelService createModelInstance(String factory, ChatModelVo config) {
|
||||||
|
try {
|
||||||
|
// 从Spring上下文中获取模型实例
|
||||||
|
BaseEmbedModelService model = applicationContext.getBean(factory, BaseEmbedModelService.class);
|
||||||
|
// 配置模型参数
|
||||||
|
model.configure(config);
|
||||||
|
log.info("成功创建嵌入模型: factory={}, modelId={}", config.getProviderName(), config.getId());
|
||||||
|
|
||||||
|
return model;
|
||||||
|
} catch (NoSuchBeanDefinitionException e) {
|
||||||
|
throw new IllegalArgumentException("获取不到嵌入模型: " + factory, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user