diff --git a/ruoyi-common/ruoyi-common-core/src/main/java/org/ruoyi/common/core/config/VectorStoreProperties.java b/ruoyi-common/ruoyi-common-core/src/main/java/org/ruoyi/common/core/config/VectorStoreProperties.java
index 98f3ddc4..b4bb135d 100644
--- a/ruoyi-common/ruoyi-common-core/src/main/java/org/ruoyi/common/core/config/VectorStoreProperties.java
+++ b/ruoyi-common/ruoyi-common-core/src/main/java/org/ruoyi/common/core/config/VectorStoreProperties.java
@@ -17,7 +17,7 @@ public class VectorStoreProperties {
/**
* 向量库类型
*/
- private String type = "weaviate";
+ private String type;
/**
* Weaviate配置
@@ -34,17 +34,17 @@ public class VectorStoreProperties {
/**
* 协议
*/
- private String protocol = "http";
+ private String protocol;
/**
* 主机地址
*/
- private String host = "localhost:8080";
+ private String host;
/**
* 类名
*/
- private String classname = "Document";
+ private String classname;
}
@Data
@@ -52,11 +52,11 @@ public class VectorStoreProperties {
/**
* 连接URL
*/
- private String url = "http://localhost:19530";
+ private String url;
/**
* 集合名称
*/
- private String collectionname = "knowledge_base";
+ private String collectionname;
}
}
\ No newline at end of file
diff --git a/ruoyi-modules-api/ruoyi-chat-api/src/main/java/org/ruoyi/domain/vo/ChatModelVo.java b/ruoyi-modules-api/ruoyi-chat-api/src/main/java/org/ruoyi/domain/vo/ChatModelVo.java
index 6a2de3cf..062a378a 100644
--- a/ruoyi-modules-api/ruoyi-chat-api/src/main/java/org/ruoyi/domain/vo/ChatModelVo.java
+++ b/ruoyi-modules-api/ruoyi-chat-api/src/main/java/org/ruoyi/domain/vo/ChatModelVo.java
@@ -70,6 +70,11 @@ public class ChatModelVo implements Serializable {
@ExcelProperty(value = "是否显示")
private String modelShow;
+ /**
+ * 模型维度
+ */
+ private Integer dimension;
+
/**
* 系统提示词
*/
diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/pom.xml b/ruoyi-modules-api/ruoyi-knowledge-api/pom.xml
index 90933a7d..11360516 100644
--- a/ruoyi-modules-api/ruoyi-knowledge-api/pom.xml
+++ b/ruoyi-modules-api/ruoyi-knowledge-api/pom.xml
@@ -80,6 +80,12 @@
2.6.4
+
+
+ dev.langchain4j
+ langchain4j-milvus
+
+
dev.langchain4j
langchain4j-open-ai
diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/domain/bo/StoreEmbeddingBo.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/domain/bo/StoreEmbeddingBo.java
index 2b87ce05..eedfe4d6 100644
--- a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/domain/bo/StoreEmbeddingBo.java
+++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/domain/bo/StoreEmbeddingBo.java
@@ -32,9 +32,9 @@ public class StoreEmbeddingBo {
private List fids;
/**
- * 向量库模型名称
+ * 向量库名称
*/
- private String vectorModelName;
+ private String vectorStoreName;
/**
* 向量化模型id
diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/embedding/EmbeddingModelFactory.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/embedding/EmbeddingModelFactory.java
index cf0c7b60..3acf8a91 100644
--- a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/embedding/EmbeddingModelFactory.java
+++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/embedding/EmbeddingModelFactory.java
@@ -27,20 +27,23 @@ public class EmbeddingModelFactory {
private final IChatModelService chatModelService;
// 模型缓存,使用ConcurrentHashMap保证线程安全
- private final Map modelCache = new ConcurrentHashMap<>();
+ private final Map modelCache = new ConcurrentHashMap<>();
/**
* 创建嵌入模型实例
* 如果模型已存在于缓存中,则直接返回;否则创建新的实例
*
- * @param embeddingModelId 嵌入模型的唯一标识ID
- * @return BaseEmbedModelService 嵌入模型服务实例
+ * @param embeddingModelName 嵌入模型名称
+ * @param dimension 模型维度大小
*/
- public BaseEmbedModelService createModel(Long embeddingModelId) {
- return modelCache.computeIfAbsent(embeddingModelId, id -> {
- ChatModelVo modelConfig = chatModelService.queryById(id);
+ public BaseEmbedModelService createModel(String embeddingModelName, Integer dimension) {
+ return modelCache.computeIfAbsent(embeddingModelName, name -> {
+ ChatModelVo modelConfig = chatModelService.selectModelByName(embeddingModelName);
if (modelConfig == null) {
- throw new IllegalArgumentException("未找到模型配置,ID=" + id);
+ throw new IllegalArgumentException("未找到模型配置,name=" + name);
+ }
+ if (modelConfig.getDimension() != null) {
+ modelConfig.setDimension(dimension);
}
return createModelInstance(modelConfig.getProviderName(), modelConfig);
});
@@ -49,22 +52,22 @@ public class EmbeddingModelFactory {
/**
* 检查模型是否支持多模态
*
- * @param embeddingModelId 嵌入模型的唯一标识ID
+ * @param embeddingModelName 嵌入模型名称
* @return boolean 如果模型支持多模态则返回true,否则返回false
*/
- public boolean isMultimodalModel(Long embeddingModelId) {
- return createModel(embeddingModelId) instanceof MultiModalEmbedModelService;
+ public boolean isMultimodalModel(String embeddingModelName) {
+ return createModel(embeddingModelName, null) instanceof MultiModalEmbedModelService;
}
/**
* 创建多模态嵌入模型实例
*
- * @param tenantId 租户ID
+ * @param embeddingModelName 嵌入模型名称
* @return MultiModalEmbedModelService 多模态嵌入模型服务实例
* @throws IllegalArgumentException 当模型不支持多模态时抛出
*/
- public MultiModalEmbedModelService createMultimodalModel(Long tenantId) {
- BaseEmbedModelService model = createModel(tenantId);
+ public MultiModalEmbedModelService createMultimodalModel(String embeddingModelName) {
+ BaseEmbedModelService model = createModel(embeddingModelName, null);
if (model instanceof MultiModalEmbedModelService) {
return (MultiModalEmbedModelService) model;
}
diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/embedding/impl/OllamaEmbeddingProvider.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/embedding/impl/OllamaEmbeddingProvider.java
index 17ed798c..1a179be7 100644
--- a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/embedding/impl/OllamaEmbeddingProvider.java
+++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/embedding/impl/OllamaEmbeddingProvider.java
@@ -30,6 +30,7 @@ public class OllamaEmbeddingProvider implements BaseEmbedModelService {
return Set.of(ModalityType.TEXT);
}
+ // ollama不能设置embedding维度,使用milvus时请注意!!创建向量表时需要先设定维度大小
@Override
public Response> embedAll(List textSegments) {
return OllamaEmbeddingModel.builder()
diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/embedding/impl/OpenAiEmbeddingProvider.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/embedding/impl/OpenAiEmbeddingProvider.java
index e58bfe46..8a0c9f62 100644
--- a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/embedding/impl/OpenAiEmbeddingProvider.java
+++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/embedding/impl/OpenAiEmbeddingProvider.java
@@ -37,6 +37,7 @@ public class OpenAiEmbeddingProvider implements BaseEmbedModelService {
.baseUrl(chatModelVo.getApiHost())
.apiKey(chatModelVo.getApiKey())
.modelName(chatModelVo.getModelName())
+ .dimensions(chatModelVo.getDimension())
.build()
.embedAll(textSegments);
}
diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/embedding/impl/ZhiPuAiEmbeddingProvider.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/embedding/impl/ZhiPuAiEmbeddingProvider.java
index e221749a..621e47c6 100644
--- a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/embedding/impl/ZhiPuAiEmbeddingProvider.java
+++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/embedding/impl/ZhiPuAiEmbeddingProvider.java
@@ -37,6 +37,7 @@ public class ZhiPuAiEmbeddingProvider implements BaseEmbedModelService {
.baseUrl(chatModelVo.getApiHost())
.apiKey(chatModelVo.getApiKey())
.model(chatModelVo.getModelName())
+ .dimensions(chatModelVo.getDimension())
.build()
.embedAll(textSegments);
}
diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/VectorStoreService.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/VectorStoreService.java
index 4e78f6f3..2937fc44 100644
--- a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/VectorStoreService.java
+++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/VectorStoreService.java
@@ -16,7 +16,7 @@ public interface VectorStoreService {
List getQueryVector(QueryVectorBo queryVectorBo);
- void createSchema(String vectorModelName, String kid,String modelName);
+ void createSchema(String kid, String embeddingModelName);
void removeById(String id,String modelName) throws ServiceException;
diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/VectorStoreServiceImpl.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/VectorStoreServiceImpl.java
index 58b44f25..c9aeeab4 100644
--- a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/VectorStoreServiceImpl.java
+++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/VectorStoreServiceImpl.java
@@ -2,16 +2,13 @@ package org.ruoyi.service.impl;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
-import org.ruoyi.common.core.service.ConfigService;
import org.ruoyi.domain.bo.QueryVectorBo;
import org.ruoyi.domain.bo.StoreEmbeddingBo;
import org.ruoyi.service.VectorStoreService;
-import org.ruoyi.service.strategy.VectorStoreStrategy;
import org.ruoyi.service.strategy.VectorStoreStrategyFactory;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Service;
import java.util.*;
-import java.util.stream.Collectors;
/**
* 向量库管理
@@ -30,22 +27,21 @@ public class VectorStoreServiceImpl implements VectorStoreService {
/**
* 获取当前配置的向量库策略
*/
- private VectorStoreStrategy getCurrentStrategy() {
+ private VectorStoreService getCurrentStrategy() {
return strategyFactory.getStrategy();
}
@Override
- public void createSchema(String vectorModelName, String kid, String modelName) {
- log.info("创建向量库schema: vectorModelName={}, kid={}, modelName={}", vectorModelName, kid, modelName);
- VectorStoreStrategy strategy = getCurrentStrategy();
- strategy.createSchema(vectorModelName, kid, modelName);
+ public void createSchema(String kid, String modelName) {
+ VectorStoreService strategy = getCurrentStrategy();
+ strategy.createSchema(kid, modelName);
}
@Override
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
log.info("存储向量数据: kid={}, docId={}, 数据条数={}",
storeEmbeddingBo.getKid(), storeEmbeddingBo.getDocId(), storeEmbeddingBo.getChunkList().size());
- VectorStoreStrategy strategy = getCurrentStrategy();
+ VectorStoreService strategy = getCurrentStrategy();
strategy.storeEmbeddings(storeEmbeddingBo);
}
@@ -53,28 +49,28 @@ public class VectorStoreServiceImpl implements VectorStoreService {
public List getQueryVector(QueryVectorBo queryVectorBo) {
log.info("查询向量数据: kid={}, query={}, maxResults={}",
queryVectorBo.getKid(), queryVectorBo.getQuery(), queryVectorBo.getMaxResults());
- VectorStoreStrategy strategy = getCurrentStrategy();
+ VectorStoreService strategy = getCurrentStrategy();
return strategy.getQueryVector(queryVectorBo);
}
@Override
public void removeById(String id, String modelName) {
log.info("根据ID删除向量数据: id={}, modelName={}", id, modelName);
- VectorStoreStrategy strategy = getCurrentStrategy();
+ VectorStoreService strategy = getCurrentStrategy();
strategy.removeById(id, modelName);
}
@Override
public void removeByDocId(String docId, String kid) {
log.info("根据docId删除向量数据: docId={}, kid={}", docId, kid);
- VectorStoreStrategy strategy = getCurrentStrategy();
+ VectorStoreService strategy = getCurrentStrategy();
strategy.removeByDocId(docId, kid);
}
@Override
public void removeByFid(String fid, String kid) {
log.info("根据fid删除向量数据: fid={}, kid={}", fid, kid);
- VectorStoreStrategy strategy = getCurrentStrategy();
+ VectorStoreService strategy = getCurrentStrategy();
strategy.removeByFid(fid, kid);
}
}
diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/AbstractVectorStoreStrategy.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/AbstractVectorStoreStrategy.java
index 7fdeb195..d35d9739 100644
--- a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/AbstractVectorStoreStrategy.java
+++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/AbstractVectorStoreStrategy.java
@@ -8,40 +8,30 @@ import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.core.config.VectorStoreProperties;
+import org.ruoyi.common.core.utils.StringUtils;
+import org.ruoyi.service.VectorStoreService;
+import org.ruoyi.embedding.EmbeddingModelFactory;
/**
* 向量库策略抽象基类
* 提供公共的方法实现,如embedding模型获取等
*
- * @author ageer
+ * @author Yzm
*/
@Slf4j
@RequiredArgsConstructor
-public abstract class AbstractVectorStoreStrategy implements VectorStoreStrategy {
+public abstract class AbstractVectorStoreStrategy implements VectorStoreService {
protected final VectorStoreProperties vectorStoreProperties;
+ private final EmbeddingModelFactory embeddingModelFactory;
+
/**
* 获取向量模型
*/
@SneakyThrows
- protected EmbeddingModel getEmbeddingModel(String modelName, String apiKey, String baseUrl) {
- EmbeddingModel embeddingModel;
- if ("quentinz/bge-large-zh-v1.5".equals(modelName)) {
- embeddingModel = OllamaEmbeddingModel.builder()
- .baseUrl(baseUrl)
- .modelName(modelName)
- .build();
- } else if ("baai/bge-m3".equals(modelName)) {
- embeddingModel = OpenAiEmbeddingModel.builder()
- .apiKey(apiKey)
- .baseUrl(baseUrl)
- .modelName(modelName)
- .build();
- } else {
- throw new ServiceException("未找到对应向量化模型!");
- }
- return embeddingModel;
+ protected EmbeddingModel getEmbeddingModel(String modelName, Integer dimension) {
+ return embeddingModelFactory.createModel(modelName, dimension);
}
/**
diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/VectorStoreStrategy.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/VectorStoreStrategy.java
deleted file mode 100644
index bd93e6fa..00000000
--- a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/VectorStoreStrategy.java
+++ /dev/null
@@ -1,18 +0,0 @@
-package org.ruoyi.service.strategy;
-
-import org.ruoyi.service.VectorStoreService;
-
-/**
- * 向量库策略接口
- * 继承VectorStoreService以避免重复定义相同的方法
- *
- * @author ageer
- */
-public interface VectorStoreStrategy extends VectorStoreService {
-
- /**
- * 获取向量库类型标识
- * @return 向量库类型(如:weaviate, milvus)
- */
- String getVectorStoreType();
-}
\ No newline at end of file
diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/VectorStoreStrategyFactory.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/VectorStoreStrategyFactory.java
index fbd5b27b..0bab68cc 100644
--- a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/VectorStoreStrategyFactory.java
+++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/VectorStoreStrategyFactory.java
@@ -6,6 +6,7 @@ import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.core.config.VectorStoreProperties;
import org.ruoyi.service.strategy.impl.MilvusVectorStoreStrategy;
import org.ruoyi.service.strategy.impl.WeaviateVectorStoreStrategy;
+import org.ruoyi.service.VectorStoreService;
import org.springframework.stereotype.Component;
import java.util.HashMap;
@@ -15,7 +16,7 @@ import java.util.Map;
* 向量库策略工厂
* 根据配置动态选择向量库实现
*
- * @author ageer
+ * @author Yzm
*/
@Slf4j
@Component
@@ -26,7 +27,7 @@ public class VectorStoreStrategyFactory {
private final WeaviateVectorStoreStrategy weaviateStrategy;
private final MilvusVectorStoreStrategy milvusStrategy;
- private Map strategies;
+ private Map strategies;
@PostConstruct
public void init() {
@@ -39,36 +40,18 @@ public class VectorStoreStrategyFactory {
/**
* 获取当前配置的向量库策略
*/
- public VectorStoreStrategy getStrategy() {
+ public VectorStoreService getStrategy() {
String vectorStoreType = vectorStoreProperties.getType();
if (vectorStoreType == null || vectorStoreType.trim().isEmpty()) {
vectorStoreType = "weaviate"; // 默认使用weaviate
}
-
- VectorStoreStrategy strategy = strategies.get(vectorStoreType.toLowerCase());
+ VectorStoreService strategy = strategies.get(vectorStoreType.toLowerCase());
if (strategy == null) {
log.warn("未找到向量库策略: {}, 使用默认策略: weaviate", vectorStoreType);
strategy = strategies.get("weaviate");
}
-
log.debug("使用向量库策略: {}", vectorStoreType);
return strategy;
}
- /**
- * 根据类型获取向量库策略
- */
- public VectorStoreStrategy getStrategy(String type) {
- if (type == null || type.trim().isEmpty()) {
- return getStrategy();
- }
-
- VectorStoreStrategy strategy = strategies.get(type.toLowerCase());
- if (strategy == null) {
- log.warn("未找到向量库策略: {}, 使用默认策略", type);
- return getStrategy();
- }
-
- return strategy;
- }
}
\ No newline at end of file
diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/MilvusVectorStoreStrategy.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/MilvusVectorStoreStrategy.java
index 26a09f1f..8d3d50fa 100644
--- a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/MilvusVectorStoreStrategy.java
+++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/MilvusVectorStoreStrategy.java
@@ -1,337 +1,157 @@
package org.ruoyi.service.strategy.impl;
-import org.ruoyi.common.core.exception.ServiceException;
+import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
+import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
-import io.milvus.client.MilvusServiceClient;
-import io.milvus.common.clientenum.ConsistencyLevelEnum;
-import io.milvus.grpc.*;
-import io.milvus.param.*;
-import io.milvus.param.collection.*;
-import io.milvus.param.dml.DeleteParam;
-import io.milvus.param.dml.InsertParam;
-import io.milvus.param.dml.SearchParam;
-import io.milvus.param.index.CreateIndexParam;
-import io.milvus.param.index.DescribeIndexParam;
-import io.milvus.response.DescCollResponseWrapper;
-import io.milvus.response.SearchResultsWrapper;
+import dev.langchain4j.store.embedding.EmbeddingMatch;
+import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
+import dev.langchain4j.store.embedding.EmbeddingStore;
+import dev.langchain4j.store.embedding.filter.Filter;
+import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder;
+import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore;
+import io.milvus.param.IndexType;
+import io.milvus.param.MetricType;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.core.config.VectorStoreProperties;
import org.ruoyi.domain.bo.QueryVectorBo;
import org.ruoyi.domain.bo.StoreEmbeddingBo;
+import org.ruoyi.embedding.EmbeddingModelFactory;
import org.ruoyi.service.strategy.AbstractVectorStoreStrategy;
import org.springframework.stereotype.Component;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.stream.IntStream;
-/**
- * Milvus向量库策略实现
- *
- * @author ageer
- */
@Slf4j
@Component
public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
- private MilvusServiceClient milvusClient;
- public MilvusVectorStoreStrategy(VectorStoreProperties vectorStoreProperties) {
- super(vectorStoreProperties);
+ private final Integer DIMENSION = 2048;
+
+ public MilvusVectorStoreStrategy(VectorStoreProperties vectorStoreProperties, EmbeddingModelFactory embeddingModelFactory) {
+ super(vectorStoreProperties, embeddingModelFactory);
+ }
+
+ // 缓存不同集合与 autoFlush 配置的 Milvus 连接
+ private final Map> storeCache = new ConcurrentHashMap<>();
+
+ private EmbeddingStore getMilvusStore(String collectionName, boolean autoFlushOnInsert) {
+ String key = collectionName + "|" + autoFlushOnInsert;
+ return storeCache.computeIfAbsent(key, k ->
+ MilvusEmbeddingStore.builder()
+ .uri(vectorStoreProperties.getMilvus().getUrl())
+ .collectionName(collectionName)
+ .dimension(DIMENSION)
+ .indexType(IndexType.IVF_FLAT)
+ .metricType(MetricType.L2)
+ .autoFlushOnInsert(autoFlushOnInsert)
+ .idFieldName("id")
+ .textFieldName("text")
+ .metadataFieldName("metadata")
+ .vectorFieldName("vector")
+ .build()
+ );
}
@Override
- public String getVectorStoreType() {
- return "milvus";
- }
-
- @Override
- public void createSchema(String vectorModelName, String kid, String modelName) {
- String url = vectorStoreProperties.getMilvus().getUrl();
+ public void createSchema(String kid, String modelName) {
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid;
-
- // 创建Milvus客户端连接
- ConnectParam connectParam = ConnectParam.newBuilder()
- .withUri(url)
- .build();
- milvusClient = new MilvusServiceClient(connectParam);
-
- // 检查集合是否存在
- HasCollectionParam hasCollectionParam = HasCollectionParam.newBuilder()
- .withCollectionName(collectionName)
- .build();
-
- R hasCollectionResponse = milvusClient.hasCollection(hasCollectionParam);
- if (hasCollectionResponse.getStatus() != R.Status.Success.getCode()) {
- log.error("检查集合是否存在失败: {}", hasCollectionResponse.getMessage());
- return;
- }
-
- if (!hasCollectionResponse.getData()) {
- // 创建字段
- List fields = new ArrayList<>();
-
- // ID字段 (主键)
- fields.add(FieldType.newBuilder()
- .withName("id")
- .withDataType(DataType.Int64)
- .withPrimaryKey(true)
- .withAutoID(true)
- .build());
-
- // 文本字段
- fields.add(FieldType.newBuilder()
- .withName("text")
- .withDataType(DataType.VarChar)
- .withMaxLength(65535)
- .build());
-
- // fid字段
- fields.add(FieldType.newBuilder()
- .withName("fid")
- .withDataType(DataType.VarChar)
- .withMaxLength(255)
- .build());
-
- // kid字段
- fields.add(FieldType.newBuilder()
- .withName("kid")
- .withDataType(DataType.VarChar)
- .withMaxLength(255)
- .build());
-
- // docId字段
- fields.add(FieldType.newBuilder()
- .withName("docId")
- .withDataType(DataType.VarChar)
- .withMaxLength(255)
- .build());
-
- // 向量字段
- fields.add(FieldType.newBuilder()
- .withName("vector")
- .withDataType(DataType.FloatVector)
- .withDimension(1024) // 根据实际embedding维度调整
- .build());
-
- // 创建集合
- CreateCollectionParam createCollectionParam = CreateCollectionParam.newBuilder()
- .withCollectionName(collectionName)
- .withDescription("Knowledge base collection for " + kid)
- .withShardsNum(2)
- .withFieldTypes(fields)
- .build();
-
- R createCollectionResponse = milvusClient.createCollection(createCollectionParam);
- if (createCollectionResponse.getStatus() != R.Status.Success.getCode()) {
- log.error("创建集合失败: {}", createCollectionResponse.getMessage());
- return;
- }
-
- // 创建索引
- CreateIndexParam createIndexParam = CreateIndexParam.newBuilder()
- .withCollectionName(collectionName)
- .withFieldName("vector")
- .withIndexType(IndexType.IVF_FLAT)
- .withMetricType(MetricType.L2)
- .withExtraParam("{\"nlist\":1024}")
- .build();
-
- R createIndexResponse = milvusClient.createIndex(createIndexParam);
- if (createIndexResponse.getStatus() != R.Status.Success.getCode()) {
- log.error("创建索引失败: {}", createIndexResponse.getMessage());
- } else {
- log.info("Milvus集合和索引创建成功: {}", collectionName);
- }
- } else {
- log.info("Milvus集合已存在: {}", collectionName);
- }
+ // 使用缓存获取连接以确保只初始化一次
+ EmbeddingStore store = getMilvusStore(collectionName, true);
+ log.info("Milvus集合初始化完成: {}", collectionName);
}
@Override
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
- createSchema(storeEmbeddingBo.getVectorModelName(), storeEmbeddingBo.getKid(), storeEmbeddingBo.getVectorModelName());
-
- EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(),
- storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl());
-
+ EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(), DIMENSION);
+
List chunkList = storeEmbeddingBo.getChunkList();
List fidList = storeEmbeddingBo.getFids();
String kid = storeEmbeddingBo.getKid();
String docId = storeEmbeddingBo.getDocId();
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid;
-
- log.info("Milvus向量存储条数记录: " + chunkList.size());
+
+ log.info("Milvus向量存储条数记录: {}", chunkList.size());
long startTime = System.currentTimeMillis();
-
- // 准备批量插入数据
- List fields = new ArrayList<>();
- List textList = new ArrayList<>();
- List fidListData = new ArrayList<>();
- List kidList = new ArrayList<>();
- List docIdList = new ArrayList<>();
- List> vectorList = new ArrayList<>();
-
- for (int i = 0; i < chunkList.size(); i++) {
+
+ // 复用连接,写入场景使用 autoFlush=false 以提升批量插入性能
+ EmbeddingStore embeddingStore = getMilvusStore(collectionName, false);
+
+ IntStream.range(0, chunkList.size()).forEach(i -> {
String text = chunkList.get(i);
String fid = fidList.get(i);
+ Metadata metadata = new Metadata();
+ metadata.put("fid", fid);
+ metadata.put("kid", kid);
+ metadata.put("docId", docId);
+
+ TextSegment textSegment = TextSegment.from(text, metadata);
Embedding embedding = embeddingModel.embed(text).content();
-
- textList.add(text);
- fidListData.add(fid);
- kidList.add(kid);
- docIdList.add(docId);
-
- List vector = new ArrayList<>();
- for (float f : embedding.vector()) {
- vector.add(f);
- }
- vectorList.add(vector);
- }
-
- // 构建字段数据
- fields.add(new InsertParam.Field("text", textList));
- fields.add(new InsertParam.Field("fid", fidListData));
- fields.add(new InsertParam.Field("kid", kidList));
- fields.add(new InsertParam.Field("docId", docIdList));
- fields.add(new InsertParam.Field("vector", vectorList));
-
- // 执行插入
- InsertParam insertParam = InsertParam.newBuilder()
- .withCollectionName(collectionName)
- .withFields(fields)
- .build();
-
- R insertResponse = milvusClient.insert(insertParam);
- if (insertResponse.getStatus() != R.Status.Success.getCode()) {
- log.error("Milvus向量存储失败: {}", insertResponse.getMessage());
- throw new ServiceException("Milvus向量存储失败");
- } else {
- log.info("Milvus向量存储成功,插入条数: {}", insertResponse.getData().getInsertCnt());
- }
-
+ embeddingStore.add(embedding, textSegment);
+ });
long endTime = System.currentTimeMillis();
- log.info("Milvus向量存储完成消耗时间:" + (endTime - startTime) / 1000 + "秒");
+ log.info("Milvus向量存储完成消耗时间:{}秒", (endTime - startTime) / 1000);
}
@Override
public List getQueryVector(QueryVectorBo queryVectorBo) {
- createSchema(queryVectorBo.getVectorModelName(), queryVectorBo.getKid(), queryVectorBo.getVectorModelName());
-
- EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(),
- queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl());
-
+ EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(), DIMENSION);
+
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + queryVectorBo.getKid();
-
+
+ // 查询复用连接,autoFlush 对查询无影响,此处保持 true
+ EmbeddingStore embeddingStore = getMilvusStore(collectionName, true);
+
List resultList = new ArrayList<>();
-
- // 加载集合到内存
- LoadCollectionParam loadCollectionParam = LoadCollectionParam.newBuilder()
- .withCollectionName(collectionName)
+ EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
+ .queryEmbedding(queryEmbedding)
+ .maxResults(queryVectorBo.getMaxResults())
.build();
- milvusClient.loadCollection(loadCollectionParam);
-
- // 准备查询向量
- List> searchVectors = new ArrayList<>();
- List queryVector = new ArrayList<>();
- for (float f : queryEmbedding.vector()) {
- queryVector.add(f);
- }
- searchVectors.add(queryVector);
-
- // 构建搜索参数
- SearchParam searchParam = SearchParam.newBuilder()
- .withCollectionName(collectionName)
- .withMetricType(MetricType.L2)
- .withOutFields(Arrays.asList("text", "fid", "kid", "docId"))
- .withTopK(queryVectorBo.getMaxResults())
- .withVectors(searchVectors)
- .withVectorFieldName("vector")
- .withParams("{\"nprobe\":10}")
- .build();
-
- R searchResponse = milvusClient.search(searchParam);
- if (searchResponse.getStatus() != R.Status.Success.getCode()) {
- log.error("Milvus查询失败: {}", searchResponse.getMessage());
- return resultList;
- }
-
- SearchResultsWrapper wrapper = new SearchResultsWrapper(searchResponse.getData().getResults());
-
- // 遍历搜索结果
- for (int i = 0; i < wrapper.getIDScore(0).size(); i++) {
- SearchResultsWrapper.IDScore idScore = wrapper.getIDScore(0).get(i);
-
- // 获取text字段数据
- List> textFieldData = wrapper.getFieldData("text", 0);
- if (textFieldData != null && i < textFieldData.size()) {
- Object textObj = textFieldData.get(i);
- if (textObj != null) {
- resultList.add(textObj.toString());
- log.debug("找到相似文本,ID: {}, 距离: {}, 内容: {}",
- idScore.getLongID(), idScore.getScore(), textObj.toString());
- }
+ List> matches = embeddingStore.search(request).matches();
+ for (EmbeddingMatch match : matches) {
+ TextSegment segment = match.embedded();
+ if (segment != null) {
+ resultList.add(segment.text());
}
}
-
return resultList;
}
@Override
@SneakyThrows
public void removeById(String id, String modelName) {
- String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + id;
-
- // 删除整个集合
- DropCollectionParam dropCollectionParam = DropCollectionParam.newBuilder()
- .withCollectionName(collectionName)
- .build();
-
- R dropResponse = milvusClient.dropCollection(dropCollectionParam);
- if (dropResponse.getStatus() != R.Status.Success.getCode()) {
- log.error("Milvus集合删除失败: {}", dropResponse.getMessage());
- throw new ServiceException("Milvus集合删除失败");
- } else {
- log.info("Milvus集合删除成功: {}", collectionName);
- }
+ // 注意:此处原逻辑使用 collectionname + id,保持现状
+ EmbeddingStore embeddingStore = getMilvusStore(vectorStoreProperties.getMilvus().getCollectionname() + id, false);
+ embeddingStore.remove(id);
}
@Override
public void removeByDocId(String docId, String kid) {
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid;
-
- String expr = "docId == \"" + docId + "\"";
- DeleteParam deleteParam = DeleteParam.newBuilder()
- .withCollectionName(collectionName)
- .withExpr(expr)
- .build();
-
- R deleteResponse = milvusClient.delete(deleteParam);
- if (deleteResponse.getStatus() != R.Status.Success.getCode()) {
- log.error("Milvus删除失败: {}", deleteResponse.getMessage());
- throw new ServiceException("Milvus删除失败");
- } else {
- log.info("Milvus成功删除 docId={} 的所有向量数据,删除条数: {}", docId, deleteResponse.getData().getDeleteCnt());
- }
+ EmbeddingStore embeddingStore = getMilvusStore(collectionName, false);
+ Filter filter = MetadataFilterBuilder.metadataKey("docId").isEqualTo(docId);
+ embeddingStore.removeAll(filter);
+ log.info("Milvus成功删除 docId={} 的所有向量数据", docId);
}
@Override
public void removeByFid(String fid, String kid) {
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid;
-
- String expr = "fid == \"" + fid + "\"";
- DeleteParam deleteParam = DeleteParam.newBuilder()
- .withCollectionName(collectionName)
- .withExpr(expr)
- .build();
-
- R deleteResponse = milvusClient.delete(deleteParam);
- if (deleteResponse.getStatus() != R.Status.Success.getCode()) {
- log.error("Milvus删除失败: {}", deleteResponse.getMessage());
- throw new ServiceException("Milvus删除失败");
- } else {
- log.info("Milvus成功删除 fid={} 的所有向量数据,删除条数: {}", fid, deleteResponse.getData().getDeleteCnt());
- }
+ EmbeddingStore embeddingStore = getMilvusStore(collectionName, false);
+ Filter filter = MetadataFilterBuilder.metadataKey("fid").isEqualTo(fid);
+ embeddingStore.removeAll(filter);
+ log.info("Milvus成功删除 fid={} 的所有向量数据", fid);
}
-}
\ No newline at end of file
+
+ @Override
+ public String getVectorStoreType() {
+ return "milvus";
+ }
+}
diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/WeaviateVectorStoreStrategy.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/WeaviateVectorStoreStrategy.java
index 6275d939..3d61f8ac 100644
--- a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/WeaviateVectorStoreStrategy.java
+++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/WeaviateVectorStoreStrategy.java
@@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.core.config.VectorStoreProperties;
import org.ruoyi.domain.bo.QueryVectorBo;
import org.ruoyi.domain.bo.StoreEmbeddingBo;
+import org.ruoyi.embedding.EmbeddingModelFactory;
import org.ruoyi.service.strategy.AbstractVectorStoreStrategy;
import org.springframework.stereotype.Component;
import java.util.*;
@@ -27,7 +28,7 @@ import java.util.*;
/**
* Weaviate向量库策略实现
*
- * @author ageer
+ * @author Yzm
*/
@Slf4j
@Component
@@ -35,8 +36,8 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
private WeaviateClient client;
- public WeaviateVectorStoreStrategy(VectorStoreProperties vectorStoreProperties) {
- super(vectorStoreProperties);
+ public WeaviateVectorStoreStrategy(VectorStoreProperties vectorStoreProperties, EmbeddingModelFactory embeddingModelFactory) {
+ super(vectorStoreProperties, embeddingModelFactory);
}
@Override
@@ -45,7 +46,7 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
}
@Override
- public void createSchema(String vectorModelName, String kid, String modelName) {
+ public void createSchema(String kid, String embeddingModelName) {
String protocol = vectorStoreProperties.getWeaviate().getProtocol();
String host = vectorStoreProperties.getWeaviate().getHost();
String className = vectorStoreProperties.getWeaviate().getClassname() + kid;
@@ -84,9 +85,8 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
@Override
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
- createSchema(storeEmbeddingBo.getVectorModelName(), storeEmbeddingBo.getKid(), storeEmbeddingBo.getVectorModelName());
- EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(),
- storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl());
+ createSchema(storeEmbeddingBo.getKid(),storeEmbeddingBo.getEmbeddingModelName());
+ EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(), null);
List chunkList = storeEmbeddingBo.getChunkList();
List fidList = storeEmbeddingBo.getFids();
String kid = storeEmbeddingBo.getKid();
@@ -118,9 +118,8 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
@Override
public List getQueryVector(QueryVectorBo queryVectorBo) {
- createSchema(queryVectorBo.getVectorModelName(), queryVectorBo.getKid(), queryVectorBo.getVectorModelName());
- EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(),
- queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl());
+ createSchema(queryVectorBo.getKid(),queryVectorBo.getEmbeddingModelName());
+ EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(),null);
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
float[] vector = queryEmbedding.vector();
List vectorStrings = new ArrayList<>();
diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/knowledge/KnowledgeInfoServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/knowledge/KnowledgeInfoServiceImpl.java
index ab5f4ae9..6a588185 100644
--- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/knowledge/KnowledgeInfoServiceImpl.java
+++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/knowledge/KnowledgeInfoServiceImpl.java
@@ -9,6 +9,7 @@ import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import lombok.RequiredArgsConstructor;
import org.ruoyi.chain.loader.ResourceLoader;
import org.ruoyi.chain.loader.ResourceLoaderFactory;
+import org.ruoyi.chat.enums.ChatModeType;
import org.ruoyi.common.core.domain.model.LoginUser;
import org.ruoyi.common.core.utils.MapstructUtils;
import org.ruoyi.common.core.utils.StringUtils;
@@ -237,8 +238,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
}
baseMapper.insert(knowledgeInfo);
if (knowledgeInfo != null) {
- vectorStoreService.createSchema(knowledgeInfo.getVectorModelName(),String.valueOf(knowledgeInfo.getId()),
- bo.getVectorModelName());
+ vectorStoreService.createSchema(String.valueOf(knowledgeInfo.getId()), bo.getEmbeddingModelName());
}
} else {
baseMapper.updateById(knowledgeInfo);
@@ -313,15 +313,18 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
.eq(KnowledgeInfo::getId, kid));
// 通过向量模型查询模型信息
- ChatModelVo chatModelVo = chatModelService.queryById(knowledgeInfoVo.getEmbeddingModelId());
-
+ ChatModelVo chatModelVo = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModelName());
+ // 未查到指定模型时,回退为向量分类最高优先级模型
+ if (chatModelVo == null) {
+ chatModelVo = chatModelService.selectModelByCategoryWithHighestPriority(ChatModeType.VECTOR.getCode());
+ }
StoreEmbeddingBo storeEmbeddingBo = new StoreEmbeddingBo();
storeEmbeddingBo.setKid(kid);
storeEmbeddingBo.setDocId(docId);
storeEmbeddingBo.setFids(fids);
storeEmbeddingBo.setChunkList(chunkList);
- storeEmbeddingBo.setVectorModelName(knowledgeInfoVo.getVectorModelName());
- storeEmbeddingBo.setEmbeddingModelId(knowledgeInfoVo.getEmbeddingModelId());
+ storeEmbeddingBo.setVectorStoreName(knowledgeInfoVo.getVectorModelName());
+ storeEmbeddingBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModelName());
storeEmbeddingBo.setApiKey(chatModelVo.getApiKey());
storeEmbeddingBo.setBaseUrl(chatModelVo.getApiHost());
vectorStoreService.storeEmbeddings(storeEmbeddingBo);