diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/VectorStoreService.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/VectorStoreService.java index 29bfecc3..21d2410d 100644 --- a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/VectorStoreService.java +++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/VectorStoreService.java @@ -15,7 +15,7 @@ public interface VectorStoreService { List getQueryVector(QueryVectorBo queryVectorBo); - void createSchema(String kid,String modelName); + void createSchema(String vectorModelName, String kid,String modelName); void removeById(String id,String modelName); diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/VectorStoreServiceImpl.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/VectorStoreServiceImpl.java index 799ce729..e7b023cc 100644 --- a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/VectorStoreServiceImpl.java +++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/VectorStoreServiceImpl.java @@ -1,40 +1,19 @@ package org.ruoyi.service.impl; -import cn.hutool.json.JSONObject; -import com.google.protobuf.ServiceException; -import dev.langchain4j.data.embedding.Embedding; -import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.model.embedding.EmbeddingModel; -import dev.langchain4j.model.ollama.OllamaEmbeddingModel; -import dev.langchain4j.model.openai.OpenAiEmbeddingModel; -import dev.langchain4j.store.embedding.EmbeddingMatch; -import dev.langchain4j.store.embedding.EmbeddingSearchRequest; -import dev.langchain4j.store.embedding.EmbeddingStore; -import dev.langchain4j.store.embedding.weaviate.WeaviateEmbeddingStore; -import io.weaviate.client.Config; -import io.weaviate.client.WeaviateClient; -import io.weaviate.client.base.Result; -import io.weaviate.client.v1.batch.api.ObjectsBatchDeleter; -import io.weaviate.client.v1.batch.model.BatchDeleteResponse; -import io.weaviate.client.v1.filters.Operator; -import io.weaviate.client.v1.filters.WhereFilter; -import io.weaviate.client.v1.graphql.model.GraphQLResponse; -import io.weaviate.client.v1.schema.model.Property; -import io.weaviate.client.v1.schema.model.Schema; -import io.weaviate.client.v1.schema.model.WeaviateClass; import lombok.RequiredArgsConstructor; -import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import org.ruoyi.common.core.service.ConfigService; import org.ruoyi.domain.bo.QueryVectorBo; import org.ruoyi.domain.bo.StoreEmbeddingBo; import org.ruoyi.service.VectorStoreService; +import org.ruoyi.service.strategy.VectorStoreStrategy; +import org.ruoyi.service.strategy.VectorStoreStrategyFactory; import org.springframework.stereotype.Service; -import java.util.*; -import java.util.stream.Collectors; + +import java.util.List; /** - * 向量库管理 + * 向量库管理服务实现 - 使用策略模式 * * @author ageer */ @@ -44,230 +23,61 @@ import java.util.stream.Collectors; public class VectorStoreServiceImpl implements VectorStoreService { private final ConfigService configService; + private final VectorStoreStrategyFactory strategyFactory; -// private EmbeddingStore embeddingStore; - private WeaviateClient client; + /** + * 获取当前配置的向量库策略 + */ + private VectorStoreStrategy getCurrentStrategy() { + String vectorStoreType = configService.getConfigValue("vector", "type"); + if (vectorStoreType == null || vectorStoreType.trim().isEmpty()) { + vectorStoreType = "weaviate"; // 默认使用weaviate + } + return strategyFactory.getStrategy(vectorStoreType); + } @Override - public void createSchema(String kid, String modelName) { - String protocol = configService.getConfigValue("weaviate", "protocol"); - String host = configService.getConfigValue("weaviate", "host"); - String className = configService.getConfigValue("weaviate", "classname")+kid; - // 创建 Weaviate 客户端 - client= new WeaviateClient(new Config(protocol, host)); - // 检查类是否存在,如果不存在就创建 schema - Result schemaResult = client.schema().getter().run(); - Schema schema = schemaResult.getResult(); - boolean classExists = false; - for (WeaviateClass weaviateClass : schema.getClasses()) { - if (weaviateClass.getClassName().equals(className)) { - classExists = true; - break; - } - } - if (!classExists) { - // 类不存在,创建 schema - WeaviateClass build = WeaviateClass.builder() - .className(className) - .vectorizer("none") - .properties( - List.of(Property.builder().name("text").dataType(Collections.singletonList("text")).build(), - Property.builder().name("fid").dataType(Collections.singletonList("text")).build(), - Property.builder().name("kid").dataType(Collections.singletonList("text")).build(), - Property.builder().name("docId").dataType(Collections.singletonList("text")).build()) - ) - .build(); - Result createResult = client.schema().classCreator().withClass(build).run(); - if (createResult.hasErrors()) { - log.error("Schema 创建失败: {}", createResult.getError()); - } else { - log.info("Schema 创建成功: {}", className); - } - } -// embeddingStore = WeaviateEmbeddingStore.builder() -// .scheme(protocol) -// .host(host) -// .objectClass(className) -// .scheme(protocol) -// .avoidDups(true) -// .consistencyLevel("ALL") -// .build(); + public void createSchema(String vectorModelName, String kid, String modelName) { + log.info("创建向量库schema: vectorModelName={}, kid={}, modelName={}", vectorModelName, kid, modelName); + VectorStoreStrategy strategy = getCurrentStrategy(); + strategy.createSchema(vectorModelName, kid, modelName); } @Override public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) { - createSchema(storeEmbeddingBo.getKid(), storeEmbeddingBo.getVectorModelName()); - EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(), - storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl()); - List chunkList = storeEmbeddingBo.getChunkList(); - List fidList = storeEmbeddingBo.getFids(); - String kid = storeEmbeddingBo.getKid(); - String docId = storeEmbeddingBo.getDocId(); - log.info("向量存储条数记录: " + chunkList.size()); - long startTime = System.currentTimeMillis(); - for (int i = 0; i < chunkList.size(); i++) { - String text = chunkList.get(i); - String fid = fidList.get(i); - Embedding embedding = embeddingModel.embed(text).content(); - Map properties = Map.of( - "text", text, - "fid",fid, - "kid", kid, - "docId", docId - ); - Float[] vector = toObjectArray(embedding.vector()); - client.data().creator() - .withClassName("LocalKnowledge" + kid) // 注意替换成实际类名 - .withProperties(properties) - .withVector(vector) - .run(); - } - long endTime = System.currentTimeMillis(); - log.info("向量存储完成消耗时间:"+ (endTime-startTime)/1000+"秒"); + log.info("存储向量数据: kid={}, docId={}, 数据条数={}", + storeEmbeddingBo.getKid(), storeEmbeddingBo.getDocId(), storeEmbeddingBo.getChunkList().size()); + VectorStoreStrategy strategy = getCurrentStrategy(); + strategy.storeEmbeddings(storeEmbeddingBo); } - private static Float[] toObjectArray(float[] primitive) { - Float[] result = new Float[primitive.length]; - for (int i = 0; i < primitive.length; i++) { - result[i] = primitive[i]; // 自动装箱 - } - return result; - } @Override public List getQueryVector(QueryVectorBo queryVectorBo) { - createSchema(queryVectorBo.getKid(), queryVectorBo.getVectorModelName()); - EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(), - queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl()); - Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content(); - float[] vector = queryEmbedding.vector(); - List vectorStrings = new ArrayList<>(); - for (float v : vector) { - vectorStrings.add(String.valueOf(v)); - } - String vectorStr = String.join(",", vectorStrings); - String className = configService.getConfigValue("weaviate", "classname") ; - // 构建 GraphQL 查询 - String graphQLQuery = String.format( - "{\n" + - " Get {\n" + - " %s(nearVector: {vector: [%s]} limit: %d) {\n" + - " text\n" + - " fid\n" + - " kid\n" + - " docId\n" + - " _additional {\n" + - " distance\n" + - " id\n" + - " }\n" + - " }\n" + - " }\n" + - "}", - className+ queryVectorBo.getKid(), - vectorStr, - queryVectorBo.getMaxResults() - ); - - Result result = client.graphQL().raw().withQuery(graphQLQuery).run(); - List resultList = new ArrayList<>(); - if (result != null && !result.hasErrors()) { - Object data = result.getResult().getData(); - JSONObject entries = new JSONObject(data); - Map entriesMap = entries.get("Get", Map.class); - cn.hutool.json.JSONArray objects = entriesMap.get(className + queryVectorBo.getKid()); - if(objects.isEmpty()){ - return resultList; - } - for (Object object : objects) { - Map map = (Map) object; - String content = map.get("text"); - resultList.add( content); - } - return resultList; - } else { - log.error("GraphQL 查询失败: {}", result.getError()); - return resultList; - } + log.info("查询向量数据: kid={}, query={}, maxResults={}", + queryVectorBo.getKid(), queryVectorBo.getQuery(), queryVectorBo.getMaxResults()); + VectorStoreStrategy strategy = getCurrentStrategy(); + return strategy.getQueryVector(queryVectorBo); } @Override - @SneakyThrows public void removeById(String id, String modelName) { - String protocol = configService.getConfigValue("weaviate", "protocol"); - String host = configService.getConfigValue("weaviate", "host"); - String className = configService.getConfigValue("weaviate", "classname"); - String finalClassName = className + id; - WeaviateClient client = new WeaviateClient(new Config(protocol, host)); - Result result = client.schema().classDeleter().withClassName(finalClassName).run(); - if (result.hasErrors()) { - log.error("失败删除向量: " + result.getError()); - throw new ServiceException("失败删除向量数据!"); - } else { - log.info("成功删除向量数据: " + result.getResult()); - } + log.info("根据ID删除向量数据: id={}, modelName={}", id, modelName); + VectorStoreStrategy strategy = getCurrentStrategy(); + strategy.removeById(id, modelName); } @Override public void removeByDocId(String docId, String kid) { - String className = configService.getConfigValue("weaviate", "classname") + kid; - // 构建 Where 条件 - WhereFilter whereFilter = WhereFilter.builder() - .path("docId") - .operator(Operator.Equal) - .valueText(docId) - .build(); - ObjectsBatchDeleter deleter = client.batch().objectsBatchDeleter(); - Result result = deleter.withClassName(className) - .withWhere(whereFilter) - .run(); - if (result != null && !result.hasErrors()) { - log.info("成功删除 docId={} 的所有向量数据", docId); - } else { - log.error("删除失败: {}", result.getError()); - } + log.info("根据docId删除向量数据: docId={}, kid={}", docId, kid); + VectorStoreStrategy strategy = getCurrentStrategy(); + strategy.removeByDocId(docId, kid); } @Override public void removeByFid(String fid, String kid) { - String className = configService.getConfigValue("weaviate", "classname") + kid; - // 构建 Where 条件 - WhereFilter whereFilter = WhereFilter.builder() - .path("fid") - .operator(Operator.Equal) - .valueText(fid) - .build(); - ObjectsBatchDeleter deleter = client.batch().objectsBatchDeleter(); - Result result = deleter.withClassName(className) - .withWhere(whereFilter) - .run(); - if (result != null && !result.hasErrors()) { - log.info("成功删除 fid={} 的所有向量数据", fid); - } else { - log.error("删除失败: {}", result.getError()); - } + log.info("根据fid删除向量数据: fid={}, kid={}", fid, kid); + VectorStoreStrategy strategy = getCurrentStrategy(); + strategy.removeByFid(fid, kid); } - - /** - * 获取向量模型 - */ - @SneakyThrows - public EmbeddingModel getEmbeddingModel(String modelName, String apiKey, String baseUrl) { - EmbeddingModel embeddingModel; - if ("quentinz/bge-large-zh-v1.5".equals(modelName)) { - embeddingModel = OllamaEmbeddingModel.builder() - .baseUrl(baseUrl) - .modelName(modelName) - .build(); - } else if ("baai/bge-m3".equals(modelName)) { - embeddingModel = OpenAiEmbeddingModel.builder() - .apiKey(apiKey) - .baseUrl(baseUrl) - .modelName(modelName) - .build(); - } else { - throw new ServiceException("未找到对应向量化模型!"); - } - return embeddingModel; - } - } diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/AbstractVectorStoreStrategy.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/AbstractVectorStoreStrategy.java new file mode 100644 index 00000000..104714cb --- /dev/null +++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/AbstractVectorStoreStrategy.java @@ -0,0 +1,62 @@ +package org.ruoyi.service.strategy; + +import com.google.protobuf.ServiceException; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.ollama.OllamaEmbeddingModel; +import dev.langchain4j.model.openai.OpenAiEmbeddingModel; +import lombok.RequiredArgsConstructor; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import org.ruoyi.common.core.service.ConfigService; + +/** + * 向量库策略抽象基类 + * 提供公共的方法实现,如embedding模型获取等 + * + * @author ageer + */ +@Slf4j +@RequiredArgsConstructor +public abstract class AbstractVectorStoreStrategy implements VectorStoreStrategy { + + protected final ConfigService configService; + + /** + * 获取向量模型 + */ + @SneakyThrows + protected EmbeddingModel getEmbeddingModel(String modelName, String apiKey, String baseUrl) { + EmbeddingModel embeddingModel; + if ("quentinz/bge-large-zh-v1.5".equals(modelName)) { + embeddingModel = OllamaEmbeddingModel.builder() + .baseUrl(baseUrl) + .modelName(modelName) + .build(); + } else if ("baai/bge-m3".equals(modelName)) { + embeddingModel = OpenAiEmbeddingModel.builder() + .apiKey(apiKey) + .baseUrl(baseUrl) + .modelName(modelName) + .build(); + } else { + throw new ServiceException("未找到对应向量化模型!"); + } + return embeddingModel; + } + + /** + * 将float数组转换为Float对象数组 + */ + protected static Float[] toObjectArray(float[] primitive) { + Float[] result = new Float[primitive.length]; + for (int i = 0; i < primitive.length; i++) { + result[i] = primitive[i]; // 自动装箱 + } + return result; + } + + /** + * 获取向量库类型标识 + */ + public abstract String getVectorStoreType(); +} \ No newline at end of file diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/VectorStoreStrategy.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/VectorStoreStrategy.java new file mode 100644 index 00000000..bd93e6fa --- /dev/null +++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/VectorStoreStrategy.java @@ -0,0 +1,18 @@ +package org.ruoyi.service.strategy; + +import org.ruoyi.service.VectorStoreService; + +/** + * 向量库策略接口 + * 继承VectorStoreService以避免重复定义相同的方法 + * + * @author ageer + */ +public interface VectorStoreStrategy extends VectorStoreService { + + /** + * 获取向量库类型标识 + * @return 向量库类型(如:weaviate, milvus) + */ + String getVectorStoreType(); +} \ No newline at end of file diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/VectorStoreStrategyFactory.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/VectorStoreStrategyFactory.java new file mode 100644 index 00000000..1e606b22 --- /dev/null +++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/VectorStoreStrategyFactory.java @@ -0,0 +1,88 @@ +package org.ruoyi.service.strategy; + +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.ruoyi.common.core.service.ConfigService; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextAware; +import org.springframework.stereotype.Component; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * 向量库策略工厂 + * 根据配置动态选择向量库实现 + * + * @author ageer + */ +@Slf4j +@Component +@RequiredArgsConstructor +public class VectorStoreStrategyFactory implements ApplicationContextAware { + + private final ConfigService configService; + private final Map strategyMap = new ConcurrentHashMap<>(); + private ApplicationContext applicationContext; + + @Override + public void setApplicationContext(ApplicationContext applicationContext) { + this.applicationContext = applicationContext; + initStrategies(); + } + + /** + * 初始化所有策略实现 + */ + private void initStrategies() { + Map strategies = applicationContext.getBeansOfType(VectorStoreStrategy.class); + for (VectorStoreStrategy strategy : strategies.values()) { + if (strategy instanceof AbstractVectorStoreStrategy) { + AbstractVectorStoreStrategy abstractStrategy = (AbstractVectorStoreStrategy) strategy; + strategyMap.put(abstractStrategy.getVectorStoreType(), strategy); + log.info("注册向量库策略: {}", abstractStrategy.getVectorStoreType()); + } + } + } + + /** + * 获取当前配置的向量库策略 + */ + public VectorStoreStrategy getStrategy() { + String vectorStoreType = configService.getConfigValue("vector", "store_type"); + if (vectorStoreType == null || vectorStoreType.isEmpty()) { + vectorStoreType = "weaviate"; // 默认使用weaviate + } + + VectorStoreStrategy strategy = strategyMap.get(vectorStoreType); + if (strategy == null) { + log.warn("未找到向量库策略: {}, 使用默认策略: weaviate", vectorStoreType); + strategy = strategyMap.get("weaviate"); + } + + if (strategy == null) { + throw new RuntimeException("未找到可用的向量库策略实现"); + } + + return strategy; + } + + /** + * 根据类型获取特定的向量库策略 + */ + public VectorStoreStrategy getStrategy(String vectorStoreType) { + VectorStoreStrategy strategy = strategyMap.get(vectorStoreType); + if (strategy == null) { + throw new RuntimeException("未找到向量库策略: " + vectorStoreType); + } + return strategy; + } + + /** + * 获取所有可用的向量库类型 + */ + public String[] getAvailableTypes() { + return strategyMap.keySet().toArray(new String[0]); + } +} \ No newline at end of file diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/MilvusVectorStoreStrategy.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/MilvusVectorStoreStrategy.java new file mode 100644 index 00000000..25605629 --- /dev/null +++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/MilvusVectorStoreStrategy.java @@ -0,0 +1,312 @@ +package org.ruoyi.service.strategy.impl; + +import com.google.protobuf.ServiceException; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.model.embedding.EmbeddingModel; +import lombok.extern.slf4j.Slf4j; +import org.ruoyi.common.core.service.ConfigService; +import org.ruoyi.domain.bo.QueryVectorBo; +import org.ruoyi.domain.bo.StoreEmbeddingBo; +import org.ruoyi.service.strategy.AbstractVectorStoreStrategy; +import org.springframework.stereotype.Component; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Milvus向量库策略实现 + * + * @author ageer + */ +@Slf4j +@Component +public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { + + // Milvus客户端和相关配置 + // private MilvusClient milvusClient; + + public MilvusVectorStoreStrategy(ConfigService configService) { + super(configService); + } + + @Override + public String getVectorStoreType() { + return "milvus"; + } + + @Override + public void createSchema(String vectorModelName, String kid, String modelName) { + log.info("Milvus创建schema: vectorModelName={}, kid={}, modelName={}", vectorModelName, kid, modelName); + + // 1. 获取Milvus配置 + String host = configService.getConfigValue("milvus", "host"); + String port = configService.getConfigValue("milvus", "port"); + String collectionName = configService.getConfigValue("milvus", "collectionname") + kid; + + // 2. 初始化Milvus客户端 + // ConnectParam connectParam = ConnectParam.newBuilder() + // .withHost(host) + // .withPort(Integer.parseInt(port)) + // .build(); + // milvusClient = new MilvusClient(connectParam); + + // 3. 检查集合是否存在,如果不存在则创建 + // HasCollectionParam hasCollectionParam = HasCollectionParam.newBuilder() + // .withCollectionName(collectionName) + // .build(); + // R hasCollectionResponse = milvusClient.hasCollection(hasCollectionParam); + // + // if (!hasCollectionResponse.getData()) { + // // 创建集合 + // List fieldsSchema = new ArrayList<>(); + // + // // 主键字段 + // fieldsSchema.add(FieldType.newBuilder() + // .withName("id") + // .withDataType(DataType.Int64) + // .withPrimaryKey(true) + // .withAutoID(true) + // .build()); + // + // // 文本字段 + // fieldsSchema.add(FieldType.newBuilder() + // .withName("text") + // .withDataType(DataType.VarChar) + // .withMaxLength(65535) + // .build()); + // + // // fid字段 + // fieldsSchema.add(FieldType.newBuilder() + // .withName("fid") + // .withDataType(DataType.VarChar) + // .withMaxLength(255) + // .build()); + // + // // kid字段 + // fieldsSchema.add(FieldType.newBuilder() + // .withName("kid") + // .withDataType(DataType.VarChar) + // .withMaxLength(255) + // .build()); + // + // // docId字段 + // fieldsSchema.add(FieldType.newBuilder() + // .withName("docId") + // .withDataType(DataType.VarChar) + // .withMaxLength(255) + // .build()); + // + // // 向量字段 + // fieldsSchema.add(FieldType.newBuilder() + // .withName("vector") + // .withDataType(DataType.FloatVector) + // .withDimension(1536) // 根据实际embedding维度调整 + // .build()); + // + // CreateCollectionParam createCollectionParam = CreateCollectionParam.newBuilder() + // .withCollectionName(collectionName) + // .withDescription("Knowledge base collection for " + kid) + // .withShardsNum(2) + // .withFieldTypes(fieldsSchema) + // .build(); + // + // R createCollectionResponse = milvusClient.createCollection(createCollectionParam); + // if (createCollectionResponse.getStatus() == R.Status.Success.getCode()) { + // log.info("Milvus集合创建成功: {}", collectionName); + // + // // 创建索引 + // IndexParam indexParam = IndexParam.newBuilder() + // .withCollectionName(collectionName) + // .withFieldName("vector") + // .withIndexType(IndexType.IVF_FLAT) + // .withMetricType(MetricType.L2) + // .withExtraParam("{\"nlist\":1024}") + // .build(); + // + // R createIndexResponse = milvusClient.createIndex(indexParam); + // if (createIndexResponse.getStatus() == R.Status.Success.getCode()) { + // log.info("Milvus索引创建成功: {}", collectionName); + // } else { + // log.error("Milvus索引创建失败: {}", createIndexResponse.getMessage()); + // } + // } else { + // log.error("Milvus集合创建失败: {}", createCollectionResponse.getMessage()); + // } + // } + + log.info("Milvus schema创建完成: {}", collectionName); + } + + @Override + public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) { + createSchema(storeEmbeddingBo.getVectorModelName(), storeEmbeddingBo.getKid(), storeEmbeddingBo.getVectorModelName()); + + EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(), + storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl()); + + List chunkList = storeEmbeddingBo.getChunkList(); + List fidList = storeEmbeddingBo.getFids(); + String kid = storeEmbeddingBo.getKid(); + String docId = storeEmbeddingBo.getDocId(); + String collectionName = configService.getConfigValue("milvus", "collectionname") + kid; + + log.info("Milvus向量存储条数记录: " + chunkList.size()); + long startTime = System.currentTimeMillis(); + + // List fields = new ArrayList<>(); + // List textList = new ArrayList<>(); + // List fidListData = new ArrayList<>(); + // List kidList = new ArrayList<>(); + // List docIdList = new ArrayList<>(); + // List> vectorList = new ArrayList<>(); + + for (int i = 0; i < chunkList.size(); i++) { + String text = chunkList.get(i); + String fid = fidList.get(i); + Embedding embedding = embeddingModel.embed(text).content(); + + // textList.add(text); + // fidListData.add(fid); + // kidList.add(kid); + // docIdList.add(docId); + // + // List vector = new ArrayList<>(); + // for (float f : embedding.vector()) { + // vector.add(f); + // } + // vectorList.add(vector); + } + + // fields.add(new InsertParam.Field("text", textList)); + // fields.add(new InsertParam.Field("fid", fidListData)); + // fields.add(new InsertParam.Field("kid", kidList)); + // fields.add(new InsertParam.Field("docId", docIdList)); + // fields.add(new InsertParam.Field("vector", vectorList)); + // + // InsertParam insertParam = InsertParam.newBuilder() + // .withCollectionName(collectionName) + // .withFields(fields) + // .build(); + // + // R insertResponse = milvusClient.insert(insertParam); + // if (insertResponse.getStatus() == R.Status.Success.getCode()) { + // log.info("Milvus向量存储成功,插入条数: {}", insertResponse.getData().getInsertCnt()); + // } else { + // log.error("Milvus向量存储失败: {}", insertResponse.getMessage()); + // throw new ServiceException("Milvus向量存储失败: " + insertResponse.getMessage()); + // } + + long endTime = System.currentTimeMillis(); + log.info("Milvus向量存储完成消耗时间:" + (endTime - startTime) / 1000 + "秒"); + } + + @Override + public List getQueryVector(QueryVectorBo queryVectorBo) { + createSchema(queryVectorBo.getVectorModelName(), queryVectorBo.getKid(), queryVectorBo.getVectorModelName()); + + EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(), + queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl()); + + Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content(); + String collectionName = configService.getConfigValue("milvus", "collectionname") + queryVectorBo.getKid(); + + List resultList = new ArrayList<>(); + + // List searchOutputFields = List.of("text", "fid", "kid", "docId"); + // List> searchVectors = new ArrayList<>(); + // List queryVector = new ArrayList<>(); + // for (float f : queryEmbedding.vector()) { + // queryVector.add(f); + // } + // searchVectors.add(queryVector); + // + // SearchParam searchParam = SearchParam.newBuilder() + // .withCollectionName(collectionName) + // .withMetricType(MetricType.L2) + // .withOutFields(searchOutputFields) + // .withTopK(queryVectorBo.getMaxResults()) + // .withVectors(searchVectors) + // .withVectorFieldName("vector") + // .withParams("{\"nprobe\":10}") + // .build(); + // + // R searchResponse = milvusClient.search(searchParam); + // if (searchResponse.getStatus() == R.Status.Success.getCode()) { + // SearchResults searchResults = searchResponse.getData(); + // List queryResults = searchResults.getResults(); + // + // for (SearchResults.QueryResult queryResult : queryResults) { + // List rows = queryResult.getRows(); + // for (SearchResults.QueryResult.Row row : rows) { + // String text = (String) row.get("text"); + // resultList.add(text); + // } + // } + // } else { + // log.error("Milvus查询失败: {}", searchResponse.getMessage()); + // } + + return resultList; + } + + @Override + public void removeById(String id, String modelName) { + String collectionName = configService.getConfigValue("milvus", "collectionname") + id; + + // DropCollectionParam dropCollectionParam = DropCollectionParam.newBuilder() + // .withCollectionName(collectionName) + // .build(); + // + // R dropResponse = milvusClient.dropCollection(dropCollectionParam); + // if (dropResponse.getStatus() == R.Status.Success.getCode()) { + // log.info("Milvus集合删除成功: {}", collectionName); + // } else { + // log.error("Milvus集合删除失败: {}", dropResponse.getMessage()); + // throw new ServiceException("Milvus集合删除失败: " + dropResponse.getMessage()); + // } + + log.info("Milvus删除集合: {}", collectionName); + } + + @Override + public void removeByDocId(String docId, String kid) { + String collectionName = configService.getConfigValue("milvus", "collectionname") + kid; + + // String expr = "docId == \"" + docId + "\""; + // DeleteParam deleteParam = DeleteParam.newBuilder() + // .withCollectionName(collectionName) + // .withExpr(expr) + // .build(); + // + // R deleteResponse = milvusClient.delete(deleteParam); + // if (deleteResponse.getStatus() == R.Status.Success.getCode()) { + // log.info("Milvus成功删除 docId={} 的所有向量数据,删除条数: {}", docId, deleteResponse.getData().getDeleteCnt()); + // } else { + // log.error("Milvus删除失败: {}", deleteResponse.getMessage()); + // } + + log.info("Milvus删除docId={}的数据", docId); + } + + @Override + public void removeByFid(String fid, String kid) { + String collectionName = configService.getConfigValue("milvus", "collectionname") + kid; + + // String expr = "fid == \"" + fid + "\""; + // DeleteParam deleteParam = DeleteParam.newBuilder() + // .withCollectionName(collectionName) + // .withExpr(expr) + // .build(); + // + // R deleteResponse = milvusClient.delete(deleteParam); + // if (deleteResponse.getStatus() == R.Status.Success.getCode()) { + // log.info("Milvus成功删除 fid={} 的所有向量数据,删除条数: {}", fid, deleteResponse.getData().getDeleteCnt()); + // } else { + // log.error("Milvus删除失败: {}", deleteResponse.getMessage()); + // } + + log.info("Milvus删除fid={}的数据", fid); + } +} \ No newline at end of file diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/WeaviateVectorStoreStrategy.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/WeaviateVectorStoreStrategy.java new file mode 100644 index 00000000..9e3d3aec --- /dev/null +++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/WeaviateVectorStoreStrategy.java @@ -0,0 +1,233 @@ +package org.ruoyi.service.strategy.impl; + +import cn.hutool.json.JSONObject; +import com.google.protobuf.ServiceException; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.model.embedding.EmbeddingModel; +import io.weaviate.client.Config; +import io.weaviate.client.WeaviateClient; +import io.weaviate.client.base.Result; +import io.weaviate.client.v1.batch.api.ObjectsBatchDeleter; +import io.weaviate.client.v1.batch.model.BatchDeleteResponse; +import io.weaviate.client.v1.filters.Operator; +import io.weaviate.client.v1.filters.WhereFilter; +import io.weaviate.client.v1.graphql.model.GraphQLResponse; +import io.weaviate.client.v1.schema.model.Property; +import io.weaviate.client.v1.schema.model.Schema; +import io.weaviate.client.v1.schema.model.WeaviateClass; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import org.ruoyi.common.core.service.ConfigService; +import org.ruoyi.domain.bo.QueryVectorBo; +import org.ruoyi.domain.bo.StoreEmbeddingBo; +import org.ruoyi.service.strategy.AbstractVectorStoreStrategy; +import org.springframework.stereotype.Component; +import java.util.*; + +/** + * Weaviate向量库策略实现 + * + * @author ageer + */ +@Slf4j +@Component +public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy { + + private WeaviateClient client; + + public WeaviateVectorStoreStrategy(ConfigService configService) { + super(configService); + } + + @Override + public String getVectorStoreType() { + return "weaviate"; + } + + @Override + public void createSchema(String vectorModelName, String kid, String modelName) { + String protocol = configService.getConfigValue("weaviate", "protocol"); + String host = configService.getConfigValue("weaviate", "host"); + String className = configService.getConfigValue("weaviate", "classname") + kid; + // 创建 Weaviate 客户端 + client = new WeaviateClient(new Config(protocol, host)); + // 检查类是否存在,如果不存在就创建 schema + Result schemaResult = client.schema().getter().run(); + Schema schema = schemaResult.getResult(); + boolean classExists = false; + for (WeaviateClass weaviateClass : schema.getClasses()) { + if (weaviateClass.getClassName().equals(className)) { + classExists = true; + break; + } + } + if (!classExists) { + // 类不存在,创建 schema + WeaviateClass build = WeaviateClass.builder() + .className(className) + .vectorizer("none") + .properties( + List.of(Property.builder().name("text").dataType(Collections.singletonList("text")).build(), + Property.builder().name("fid").dataType(Collections.singletonList("text")).build(), + Property.builder().name("kid").dataType(Collections.singletonList("text")).build(), + Property.builder().name("docId").dataType(Collections.singletonList("text")).build()) + ) + .build(); + Result createResult = client.schema().classCreator().withClass(build).run(); + if (createResult.hasErrors()) { + log.error("Schema 创建失败: {}", createResult.getError()); + } else { + log.info("Schema 创建成功: {}", className); + } + } + } + + @Override + public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) { + createSchema(storeEmbeddingBo.getVectorModelName(), storeEmbeddingBo.getKid(), storeEmbeddingBo.getVectorModelName()); + EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(), + storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl()); + List chunkList = storeEmbeddingBo.getChunkList(); + List fidList = storeEmbeddingBo.getFids(); + String kid = storeEmbeddingBo.getKid(); + String docId = storeEmbeddingBo.getDocId(); + log.info("向量存储条数记录: " + chunkList.size()); + long startTime = System.currentTimeMillis(); + for (int i = 0; i < chunkList.size(); i++) { + String text = chunkList.get(i); + String fid = fidList.get(i); + Embedding embedding = embeddingModel.embed(text).content(); + Map properties = Map.of( + "text", text, + "fid", fid, + "kid", kid, + "docId", docId + ); + Float[] vector = toObjectArray(embedding.vector()); + client.data().creator() + .withClassName("LocalKnowledge" + kid) + .withProperties(properties) + .withVector(vector) + .run(); + } + long endTime = System.currentTimeMillis(); + log.info("向量存储完成消耗时间:" + (endTime - startTime) / 1000 + "秒"); + } + + + + @Override + public List getQueryVector(QueryVectorBo queryVectorBo) { + createSchema(queryVectorBo.getVectorModelName(), queryVectorBo.getKid(), queryVectorBo.getVectorModelName()); + EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(), + queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl()); + Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content(); + float[] vector = queryEmbedding.vector(); + List vectorStrings = new ArrayList<>(); + for (float v : vector) { + vectorStrings.add(String.valueOf(v)); + } + String vectorStr = String.join(",", vectorStrings); + String className = configService.getConfigValue("weaviate", "classname"); + + // 构建 GraphQL 查询 + String graphQLQuery = String.format( + "{\n" + + " Get {\n" + + " %s(nearVector: {vector: [%s]} limit: %d) {\n" + + " text\n" + + " fid\n" + + " kid\n" + + " docId\n" + + " _additional {\n" + + " distance\n" + + " id\n" + + " }\n" + + " }\n" + + " }\n" + + "}", + className + queryVectorBo.getKid(), + vectorStr, + queryVectorBo.getMaxResults() + ); + + Result result = client.graphQL().raw().withQuery(graphQLQuery).run(); + List resultList = new ArrayList<>(); + if (result != null && !result.hasErrors()) { + Object data = result.getResult().getData(); + JSONObject entries = new JSONObject(data); + Map entriesMap = entries.get("Get", Map.class); + cn.hutool.json.JSONArray objects = entriesMap.get(className + queryVectorBo.getKid()); + if (objects.isEmpty()) { + return resultList; + } + for (Object object : objects) { + Map map = (Map) object; + String content = map.get("text"); + resultList.add(content); + } + return resultList; + } else { + log.error("GraphQL 查询失败: {}", result.getError()); + return resultList; + } + } + + @Override + @SneakyThrows + public void removeById(String id, String modelName) { + String protocol = configService.getConfigValue("weaviate", "protocol"); + String host = configService.getConfigValue("weaviate", "host"); + String className = configService.getConfigValue("weaviate", "classname"); + String finalClassName = className + id; + WeaviateClient client = new WeaviateClient(new Config(protocol, host)); + Result result = client.schema().classDeleter().withClassName(finalClassName).run(); + if (result.hasErrors()) { + log.error("失败删除向量: " + result.getError()); + throw new ServiceException("失败删除向量数据!"); + } else { + log.info("成功删除向量数据: " + result.getResult()); + } + } + + @Override + public void removeByDocId(String docId, String kid) { + String className = configService.getConfigValue("weaviate", "classname") + kid; + // 构建 Where 条件 + WhereFilter whereFilter = WhereFilter.builder() + .path("docId") + .operator(Operator.Equal) + .valueText(docId) + .build(); + ObjectsBatchDeleter deleter = client.batch().objectsBatchDeleter(); + Result result = deleter.withClassName(className) + .withWhere(whereFilter) + .run(); + if (result != null && !result.hasErrors()) { + log.info("成功删除 docId={} 的所有向量数据", docId); + } else { + log.error("删除失败: {}", result.getError()); + } + } + + @Override + public void removeByFid(String fid, String kid) { + String className = configService.getConfigValue("weaviate", "classname") + kid; + // 构建 Where 条件 + WhereFilter whereFilter = WhereFilter.builder() + .path("fid") + .operator(Operator.Equal) + .valueText(fid) + .build(); + ObjectsBatchDeleter deleter = client.batch().objectsBatchDeleter(); + Result result = deleter.withClassName(className) + .withWhere(whereFilter) + .run(); + if (result != null && !result.hasErrors()) { + log.info("成功删除 fid={} 的所有向量数据", fid); + } else { + log.error("删除失败: {}", result.getError()); + } + } + +} \ No newline at end of file diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/knowledge/KnowledgeInfoServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/knowledge/KnowledgeInfoServiceImpl.java index a5be768b..23148c44 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/knowledge/KnowledgeInfoServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/knowledge/KnowledgeInfoServiceImpl.java @@ -216,7 +216,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService { } baseMapper.insert(knowledgeInfo); if (knowledgeInfo != null) { - vectorStoreService.createSchema(String.valueOf(knowledgeInfo.getId()), + vectorStoreService.createSchema(knowledgeInfo.getVectorModelName(),String.valueOf(knowledgeInfo.getId()), bo.getVectorModelName()); } } else { @@ -257,6 +257,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService { knowledgeAttach.setDocType(fileName.substring(fileName.lastIndexOf(".") + 1)); String content = ""; ResourceLoader resourceLoader = resourceLoaderFactory.getLoaderByFileType(knowledgeAttach.getDocType()); + // 文档分段入库 List fids = new ArrayList<>(); try { content = resourceLoader.getContent(file.getInputStream()); @@ -264,6 +265,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService { List knowledgeFragmentList = new ArrayList<>(); if (CollUtil.isNotEmpty(chunkList)) { for (int i = 0; i < chunkList.size(); i++) { + // 生成知识片段ID String fid = RandomUtil.randomString(10); fids.add(fid); KnowledgeFragment knowledgeFragment = new KnowledgeFragment();