diff --git a/ruoyi-admin/src/main/resources/application.yml b/ruoyi-admin/src/main/resources/application.yml index 6d5e6d2f..2bb9f120 100644 --- a/ruoyi-admin/src/main/resources/application.yml +++ b/ruoyi-admin/src/main/resources/application.yml @@ -328,3 +328,17 @@ spring: servers-configuration: classpath:mcp-server.json request-timeout: 300s +--- # 向量库配置 +vector-store: + # 向量存储类型 (weaviate/milvus) + type: weaviate + # Weaviate配置 + weaviate: + protocol: http + host: 127.0.0.1:6038 + classname: LocalKnowledge + # Milvus配置 + milvus: + url: http://localhost:19530 + collectionname: LocalKnowledge + diff --git a/ruoyi-common/ruoyi-common-core/src/main/java/org/ruoyi/common/core/config/VectorStoreProperties.java b/ruoyi-common/ruoyi-common-core/src/main/java/org/ruoyi/common/core/config/VectorStoreProperties.java new file mode 100644 index 00000000..98f3ddc4 --- /dev/null +++ b/ruoyi-common/ruoyi-common-core/src/main/java/org/ruoyi/common/core/config/VectorStoreProperties.java @@ -0,0 +1,62 @@ +package org.ruoyi.common.core.config; + +import lombok.Data; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.stereotype.Component; + +/** + * 向量库配置属性 + * + * @author ageer + */ +@Data +@Component +@ConfigurationProperties(prefix = "vector-store") +public class VectorStoreProperties { + + /** + * 向量库类型 + */ + private String type = "weaviate"; + + /** + * Weaviate配置 + */ + private Weaviate weaviate = new Weaviate(); + + /** + * Milvus配置 + */ + private Milvus milvus = new Milvus(); + + @Data + public static class Weaviate { + /** + * 协议 + */ + private String protocol = "http"; + + /** + * 主机地址 + */ + private String host = "localhost:8080"; + + /** + * 类名 + */ + private String classname = "Document"; + } + + @Data + public static class Milvus { + /** + * 连接URL + */ + private String url = "http://localhost:19530"; + + /** + * 集合名称 + */ + private String collectionname = "knowledge_base"; + } +} \ No newline at end of file diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/pom.xml b/ruoyi-modules-api/ruoyi-knowledge-api/pom.xml index 0394b6d1..90933a7d 100644 --- a/ruoyi-modules-api/ruoyi-knowledge-api/pom.xml +++ b/ruoyi-modules-api/ruoyi-knowledge-api/pom.xml @@ -74,6 +74,12 @@ 1.19.6 + + io.milvus + milvus-sdk-java + 2.6.4 + + dev.langchain4j langchain4j-open-ai diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/VectorStoreService.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/VectorStoreService.java index 29bfecc3..4e78f6f3 100644 --- a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/VectorStoreService.java +++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/VectorStoreService.java @@ -1,5 +1,6 @@ package org.ruoyi.service; +import org.ruoyi.common.core.exception.ServiceException; import org.ruoyi.domain.bo.QueryVectorBo; import org.ruoyi.domain.bo.StoreEmbeddingBo; @@ -11,15 +12,15 @@ import java.util.List; */ public interface VectorStoreService { - void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo); + void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) throws ServiceException; List getQueryVector(QueryVectorBo queryVectorBo); - void createSchema(String kid,String modelName); + void createSchema(String vectorModelName, String kid,String modelName); - void removeById(String id,String modelName); + void removeById(String id,String modelName) throws ServiceException; - void removeByDocId(String docId, String kid); + void removeByDocId(String docId, String kid) throws ServiceException; - void removeByFid(String fid, String kid); + void removeByFid(String fid, String kid) throws ServiceException; } diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/VectorStoreServiceImpl.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/VectorStoreServiceImpl.java index d81905ae..58b44f25 100644 --- a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/VectorStoreServiceImpl.java +++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/VectorStoreServiceImpl.java @@ -1,36 +1,14 @@ package org.ruoyi.service.impl; -import cn.hutool.json.JSONObject; -import com.google.protobuf.ServiceException; -import dev.langchain4j.data.embedding.Embedding; -import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.model.embedding.EmbeddingModel; -import dev.langchain4j.model.ollama.OllamaEmbeddingModel; -import dev.langchain4j.model.openai.OpenAiEmbeddingModel; -import dev.langchain4j.store.embedding.EmbeddingMatch; -import dev.langchain4j.store.embedding.EmbeddingSearchRequest; -import dev.langchain4j.store.embedding.EmbeddingStore; -import dev.langchain4j.store.embedding.weaviate.WeaviateEmbeddingStore; -import io.weaviate.client.Config; -import io.weaviate.client.WeaviateClient; -import io.weaviate.client.base.Result; -import io.weaviate.client.v1.batch.api.ObjectsBatchDeleter; -import io.weaviate.client.v1.batch.model.BatchDeleteResponse; -import io.weaviate.client.v1.filters.Operator; -import io.weaviate.client.v1.filters.WhereFilter; -import io.weaviate.client.v1.graphql.model.GraphQLResponse; -import io.weaviate.client.v1.schema.model.Property; -import io.weaviate.client.v1.schema.model.Schema; -import io.weaviate.client.v1.schema.model.WeaviateClass; import lombok.RequiredArgsConstructor; -import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import org.ruoyi.common.core.service.ConfigService; import org.ruoyi.domain.bo.QueryVectorBo; import org.ruoyi.domain.bo.StoreEmbeddingBo; -import org.ruoyi.embedding.BaseEmbedModelService; -import org.ruoyi.embedding.EmbeddingModelFactory; import org.ruoyi.service.VectorStoreService; +import org.ruoyi.service.strategy.VectorStoreStrategy; +import org.ruoyi.service.strategy.VectorStoreStrategyFactory; +import org.springframework.context.annotation.Primary; import org.springframework.stereotype.Service; import java.util.*; import java.util.stream.Collectors; @@ -41,210 +19,62 @@ import java.util.stream.Collectors; * @author ageer */ @Service +@Primary @Slf4j @RequiredArgsConstructor public class VectorStoreServiceImpl implements VectorStoreService { - private final ConfigService configService; + private final VectorStoreStrategyFactory strategyFactory; -// private EmbeddingStore embeddingStore; - private WeaviateClient client; - - private final EmbeddingModelFactory embeddingModelFactory; + /** + * 获取当前配置的向量库策略 + */ + private VectorStoreStrategy getCurrentStrategy() { + return strategyFactory.getStrategy(); + } @Override - public void createSchema(String kid, String modelName) { - String protocol = configService.getConfigValue("weaviate", "protocol"); - String host = configService.getConfigValue("weaviate", "host"); - String className = configService.getConfigValue("weaviate", "classname")+kid; - // 创建 Weaviate 客户端 - client= new WeaviateClient(new Config(protocol, host)); - // 检查类是否存在,如果不存在就创建 schema - Result schemaResult = client.schema().getter().run(); - Schema schema = schemaResult.getResult(); - boolean classExists = false; - for (WeaviateClass weaviateClass : schema.getClasses()) { - if (weaviateClass.getClassName().equals(className)) { - classExists = true; - break; - } - } - if (!classExists) { - // 类不存在,创建 schema - WeaviateClass build = WeaviateClass.builder() - .className(className) - .vectorizer("none") - .properties( - List.of(Property.builder().name("text").dataType(Collections.singletonList("text")).build(), - Property.builder().name("fid").dataType(Collections.singletonList("text")).build(), - Property.builder().name("kid").dataType(Collections.singletonList("text")).build(), - Property.builder().name("docId").dataType(Collections.singletonList("text")).build()) - ) - .build(); - Result createResult = client.schema().classCreator().withClass(build).run(); - if (createResult.hasErrors()) { - log.error("Schema 创建失败: {}", createResult.getError()); - } else { - log.info("Schema 创建成功: {}", className); - } - } -// embeddingStore = WeaviateEmbeddingStore.builder() -// .scheme(protocol) -// .host(host) -// .objectClass(className) -// .scheme(protocol) -// .avoidDups(true) -// .consistencyLevel("ALL") -// .build(); + public void createSchema(String vectorModelName, String kid, String modelName) { + log.info("创建向量库schema: vectorModelName={}, kid={}, modelName={}", vectorModelName, kid, modelName); + VectorStoreStrategy strategy = getCurrentStrategy(); + strategy.createSchema(vectorModelName, kid, modelName); } @Override public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) { - createSchema(storeEmbeddingBo.getKid(), storeEmbeddingBo.getVectorModelName()); - BaseEmbedModelService model = embeddingModelFactory.createModel(storeEmbeddingBo.getEmbeddingModelId()); - List chunkList = storeEmbeddingBo.getChunkList(); - List fidList = storeEmbeddingBo.getFids(); - String kid = storeEmbeddingBo.getKid(); - String docId = storeEmbeddingBo.getDocId(); - long startTime = System.currentTimeMillis(); - for (int i = 0; i < chunkList.size(); i++) { - String text = chunkList.get(i); - String fid = fidList.get(i); - Embedding embedding = model.embed(text).content(); - Map properties = Map.of( - "text", text, - "fid",fid, - "kid", kid, - "docId", docId - ); - Float[] vector = toObjectArray(embedding.vector()); - client.data().creator() - .withClassName("LocalKnowledge" + kid) // 注意替换成实际类名 - .withProperties(properties) - .withVector(vector) - .run(); - } - long endTime = System.currentTimeMillis(); - log.info("向量存储完成消耗时间:"+ (endTime-startTime)/1000+"秒"); + log.info("存储向量数据: kid={}, docId={}, 数据条数={}", + storeEmbeddingBo.getKid(), storeEmbeddingBo.getDocId(), storeEmbeddingBo.getChunkList().size()); + VectorStoreStrategy strategy = getCurrentStrategy(); + strategy.storeEmbeddings(storeEmbeddingBo); } - private static Float[] toObjectArray(float[] primitive) { - Float[] result = new Float[primitive.length]; - for (int i = 0; i < primitive.length; i++) { - result[i] = primitive[i]; // 自动装箱 - } - return result; - } @Override public List getQueryVector(QueryVectorBo queryVectorBo) { - createSchema(queryVectorBo.getKid(), queryVectorBo.getVectorModelName()); - BaseEmbedModelService model = embeddingModelFactory.createModel(queryVectorBo.getEmbeddingModelId()); - Embedding queryEmbedding = model.embed(queryVectorBo.getQuery()).content(); - float[] vector = queryEmbedding.vector(); - List vectorStrings = new ArrayList<>(); - for (float v : vector) { - vectorStrings.add(String.valueOf(v)); - } - String vectorStr = String.join(",", vectorStrings); - String className = configService.getConfigValue("weaviate", "classname") ; - // 构建 GraphQL 查询 - String graphQLQuery = String.format( - "{\n" + - " Get {\n" + - " %s(nearVector: {vector: [%s]} limit: %d) {\n" + - " text\n" + - " fid\n" + - " kid\n" + - " docId\n" + - " _additional {\n" + - " distance\n" + - " id\n" + - " }\n" + - " }\n" + - " }\n" + - "}", - className+ queryVectorBo.getKid(), - vectorStr, - queryVectorBo.getMaxResults() - ); - - Result result = client.graphQL().raw().withQuery(graphQLQuery).run(); - List resultList = new ArrayList<>(); - if (result != null && !result.hasErrors()) { - Object data = result.getResult().getData(); - JSONObject entries = new JSONObject(data); - Map entriesMap = entries.get("Get", Map.class); - cn.hutool.json.JSONArray objects = entriesMap.get(className + queryVectorBo.getKid()); - if(objects.isEmpty()){ - return resultList; - } - for (Object object : objects) { - Map map = (Map) object; - String content = map.get("text"); - resultList.add( content); - } - return resultList; - } else { - log.error("GraphQL 查询失败: {}", result.getError()); - return resultList; - } + log.info("查询向量数据: kid={}, query={}, maxResults={}", + queryVectorBo.getKid(), queryVectorBo.getQuery(), queryVectorBo.getMaxResults()); + VectorStoreStrategy strategy = getCurrentStrategy(); + return strategy.getQueryVector(queryVectorBo); } @Override - @SneakyThrows public void removeById(String id, String modelName) { - String protocol = configService.getConfigValue("weaviate", "protocol"); - String host = configService.getConfigValue("weaviate", "host"); - String className = configService.getConfigValue("weaviate", "classname"); - String finalClassName = className + id; - WeaviateClient client = new WeaviateClient(new Config(protocol, host)); - Result result = client.schema().classDeleter().withClassName(finalClassName).run(); - if (result.hasErrors()) { - log.error("失败删除向量: " + result.getError()); - throw new ServiceException("失败删除向量数据!"); - } else { - log.info("成功删除向量数据: " + result.getResult()); - } + log.info("根据ID删除向量数据: id={}, modelName={}", id, modelName); + VectorStoreStrategy strategy = getCurrentStrategy(); + strategy.removeById(id, modelName); } @Override public void removeByDocId(String docId, String kid) { - String className = configService.getConfigValue("weaviate", "classname") + kid; - // 构建 Where 条件 - WhereFilter whereFilter = WhereFilter.builder() - .path("docId") - .operator(Operator.Equal) - .valueText(docId) - .build(); - ObjectsBatchDeleter deleter = client.batch().objectsBatchDeleter(); - Result result = deleter.withClassName(className) - .withWhere(whereFilter) - .run(); - if (result != null && !result.hasErrors()) { - log.info("成功删除 docId={} 的所有向量数据", docId); - } else { - log.error("删除失败: {}", result.getError()); - } + log.info("根据docId删除向量数据: docId={}, kid={}", docId, kid); + VectorStoreStrategy strategy = getCurrentStrategy(); + strategy.removeByDocId(docId, kid); } @Override public void removeByFid(String fid, String kid) { - String className = configService.getConfigValue("weaviate", "classname") + kid; - // 构建 Where 条件 - WhereFilter whereFilter = WhereFilter.builder() - .path("fid") - .operator(Operator.Equal) - .valueText(fid) - .build(); - ObjectsBatchDeleter deleter = client.batch().objectsBatchDeleter(); - Result result = deleter.withClassName(className) - .withWhere(whereFilter) - .run(); - if (result != null && !result.hasErrors()) { - log.info("成功删除 fid={} 的所有向量数据", fid); - } else { - log.error("删除失败: {}", result.getError()); - } + log.info("根据fid删除向量数据: fid={}, kid={}", fid, kid); + VectorStoreStrategy strategy = getCurrentStrategy(); + strategy.removeByFid(fid, kid); } } diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/AbstractVectorStoreStrategy.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/AbstractVectorStoreStrategy.java new file mode 100644 index 00000000..7fdeb195 --- /dev/null +++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/AbstractVectorStoreStrategy.java @@ -0,0 +1,62 @@ +package org.ruoyi.service.strategy; + +import org.ruoyi.common.core.exception.ServiceException; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.ollama.OllamaEmbeddingModel; +import dev.langchain4j.model.openai.OpenAiEmbeddingModel; +import lombok.RequiredArgsConstructor; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import org.ruoyi.common.core.config.VectorStoreProperties; + +/** + * 向量库策略抽象基类 + * 提供公共的方法实现,如embedding模型获取等 + * + * @author ageer + */ +@Slf4j +@RequiredArgsConstructor +public abstract class AbstractVectorStoreStrategy implements VectorStoreStrategy { + + protected final VectorStoreProperties vectorStoreProperties; + + /** + * 获取向量模型 + */ + @SneakyThrows + protected EmbeddingModel getEmbeddingModel(String modelName, String apiKey, String baseUrl) { + EmbeddingModel embeddingModel; + if ("quentinz/bge-large-zh-v1.5".equals(modelName)) { + embeddingModel = OllamaEmbeddingModel.builder() + .baseUrl(baseUrl) + .modelName(modelName) + .build(); + } else if ("baai/bge-m3".equals(modelName)) { + embeddingModel = OpenAiEmbeddingModel.builder() + .apiKey(apiKey) + .baseUrl(baseUrl) + .modelName(modelName) + .build(); + } else { + throw new ServiceException("未找到对应向量化模型!"); + } + return embeddingModel; + } + + /** + * 将float数组转换为Float对象数组 + */ + protected static Float[] toObjectArray(float[] primitive) { + Float[] result = new Float[primitive.length]; + for (int i = 0; i < primitive.length; i++) { + result[i] = primitive[i]; // 自动装箱 + } + return result; + } + + /** + * 获取向量库类型标识 + */ + public abstract String getVectorStoreType(); +} \ No newline at end of file diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/VectorStoreStrategy.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/VectorStoreStrategy.java new file mode 100644 index 00000000..bd93e6fa --- /dev/null +++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/VectorStoreStrategy.java @@ -0,0 +1,18 @@ +package org.ruoyi.service.strategy; + +import org.ruoyi.service.VectorStoreService; + +/** + * 向量库策略接口 + * 继承VectorStoreService以避免重复定义相同的方法 + * + * @author ageer + */ +public interface VectorStoreStrategy extends VectorStoreService { + + /** + * 获取向量库类型标识 + * @return 向量库类型(如:weaviate, milvus) + */ + String getVectorStoreType(); +} \ No newline at end of file diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/VectorStoreStrategyFactory.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/VectorStoreStrategyFactory.java new file mode 100644 index 00000000..fbd5b27b --- /dev/null +++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/VectorStoreStrategyFactory.java @@ -0,0 +1,74 @@ +package org.ruoyi.service.strategy; + +import jakarta.annotation.PostConstruct; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.ruoyi.common.core.config.VectorStoreProperties; +import org.ruoyi.service.strategy.impl.MilvusVectorStoreStrategy; +import org.ruoyi.service.strategy.impl.WeaviateVectorStoreStrategy; +import org.springframework.stereotype.Component; + +import java.util.HashMap; +import java.util.Map; + +/** + * 向量库策略工厂 + * 根据配置动态选择向量库实现 + * + * @author ageer + */ +@Slf4j +@Component +@RequiredArgsConstructor +public class VectorStoreStrategyFactory { + + private final VectorStoreProperties vectorStoreProperties; + private final WeaviateVectorStoreStrategy weaviateStrategy; + private final MilvusVectorStoreStrategy milvusStrategy; + + private Map strategies; + + @PostConstruct + public void init() { + strategies = new HashMap<>(); + strategies.put("weaviate", weaviateStrategy); + strategies.put("milvus", milvusStrategy); + log.info("向量库策略工厂初始化完成,支持的策略: {}", strategies.keySet()); + } + + /** + * 获取当前配置的向量库策略 + */ + public VectorStoreStrategy getStrategy() { + String vectorStoreType = vectorStoreProperties.getType(); + if (vectorStoreType == null || vectorStoreType.trim().isEmpty()) { + vectorStoreType = "weaviate"; // 默认使用weaviate + } + + VectorStoreStrategy strategy = strategies.get(vectorStoreType.toLowerCase()); + if (strategy == null) { + log.warn("未找到向量库策略: {}, 使用默认策略: weaviate", vectorStoreType); + strategy = strategies.get("weaviate"); + } + + log.debug("使用向量库策略: {}", vectorStoreType); + return strategy; + } + + /** + * 根据类型获取向量库策略 + */ + public VectorStoreStrategy getStrategy(String type) { + if (type == null || type.trim().isEmpty()) { + return getStrategy(); + } + + VectorStoreStrategy strategy = strategies.get(type.toLowerCase()); + if (strategy == null) { + log.warn("未找到向量库策略: {}, 使用默认策略", type); + return getStrategy(); + } + + return strategy; + } +} \ No newline at end of file diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/MilvusVectorStoreStrategy.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/MilvusVectorStoreStrategy.java new file mode 100644 index 00000000..26a09f1f --- /dev/null +++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/MilvusVectorStoreStrategy.java @@ -0,0 +1,337 @@ +package org.ruoyi.service.strategy.impl; + +import org.ruoyi.common.core.exception.ServiceException; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.model.embedding.EmbeddingModel; +import io.milvus.client.MilvusServiceClient; +import io.milvus.common.clientenum.ConsistencyLevelEnum; +import io.milvus.grpc.*; +import io.milvus.param.*; +import io.milvus.param.collection.*; +import io.milvus.param.dml.DeleteParam; +import io.milvus.param.dml.InsertParam; +import io.milvus.param.dml.SearchParam; +import io.milvus.param.index.CreateIndexParam; +import io.milvus.param.index.DescribeIndexParam; +import io.milvus.response.DescCollResponseWrapper; +import io.milvus.response.SearchResultsWrapper; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import org.ruoyi.common.core.config.VectorStoreProperties; +import org.ruoyi.domain.bo.QueryVectorBo; +import org.ruoyi.domain.bo.StoreEmbeddingBo; +import org.ruoyi.service.strategy.AbstractVectorStoreStrategy; +import org.springframework.stereotype.Component; + +import java.util.*; + +/** + * Milvus向量库策略实现 + * + * @author ageer + */ +@Slf4j +@Component +public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { + + private MilvusServiceClient milvusClient; + + public MilvusVectorStoreStrategy(VectorStoreProperties vectorStoreProperties) { + super(vectorStoreProperties); + } + + @Override + public String getVectorStoreType() { + return "milvus"; + } + + @Override + public void createSchema(String vectorModelName, String kid, String modelName) { + String url = vectorStoreProperties.getMilvus().getUrl(); + String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid; + + // 创建Milvus客户端连接 + ConnectParam connectParam = ConnectParam.newBuilder() + .withUri(url) + .build(); + milvusClient = new MilvusServiceClient(connectParam); + + // 检查集合是否存在 + HasCollectionParam hasCollectionParam = HasCollectionParam.newBuilder() + .withCollectionName(collectionName) + .build(); + + R hasCollectionResponse = milvusClient.hasCollection(hasCollectionParam); + if (hasCollectionResponse.getStatus() != R.Status.Success.getCode()) { + log.error("检查集合是否存在失败: {}", hasCollectionResponse.getMessage()); + return; + } + + if (!hasCollectionResponse.getData()) { + // 创建字段 + List fields = new ArrayList<>(); + + // ID字段 (主键) + fields.add(FieldType.newBuilder() + .withName("id") + .withDataType(DataType.Int64) + .withPrimaryKey(true) + .withAutoID(true) + .build()); + + // 文本字段 + fields.add(FieldType.newBuilder() + .withName("text") + .withDataType(DataType.VarChar) + .withMaxLength(65535) + .build()); + + // fid字段 + fields.add(FieldType.newBuilder() + .withName("fid") + .withDataType(DataType.VarChar) + .withMaxLength(255) + .build()); + + // kid字段 + fields.add(FieldType.newBuilder() + .withName("kid") + .withDataType(DataType.VarChar) + .withMaxLength(255) + .build()); + + // docId字段 + fields.add(FieldType.newBuilder() + .withName("docId") + .withDataType(DataType.VarChar) + .withMaxLength(255) + .build()); + + // 向量字段 + fields.add(FieldType.newBuilder() + .withName("vector") + .withDataType(DataType.FloatVector) + .withDimension(1024) // 根据实际embedding维度调整 + .build()); + + // 创建集合 + CreateCollectionParam createCollectionParam = CreateCollectionParam.newBuilder() + .withCollectionName(collectionName) + .withDescription("Knowledge base collection for " + kid) + .withShardsNum(2) + .withFieldTypes(fields) + .build(); + + R createCollectionResponse = milvusClient.createCollection(createCollectionParam); + if (createCollectionResponse.getStatus() != R.Status.Success.getCode()) { + log.error("创建集合失败: {}", createCollectionResponse.getMessage()); + return; + } + + // 创建索引 + CreateIndexParam createIndexParam = CreateIndexParam.newBuilder() + .withCollectionName(collectionName) + .withFieldName("vector") + .withIndexType(IndexType.IVF_FLAT) + .withMetricType(MetricType.L2) + .withExtraParam("{\"nlist\":1024}") + .build(); + + R createIndexResponse = milvusClient.createIndex(createIndexParam); + if (createIndexResponse.getStatus() != R.Status.Success.getCode()) { + log.error("创建索引失败: {}", createIndexResponse.getMessage()); + } else { + log.info("Milvus集合和索引创建成功: {}", collectionName); + } + } else { + log.info("Milvus集合已存在: {}", collectionName); + } + } + + @Override + public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) { + createSchema(storeEmbeddingBo.getVectorModelName(), storeEmbeddingBo.getKid(), storeEmbeddingBo.getVectorModelName()); + + EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(), + storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl()); + + List chunkList = storeEmbeddingBo.getChunkList(); + List fidList = storeEmbeddingBo.getFids(); + String kid = storeEmbeddingBo.getKid(); + String docId = storeEmbeddingBo.getDocId(); + String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid; + + log.info("Milvus向量存储条数记录: " + chunkList.size()); + long startTime = System.currentTimeMillis(); + + // 准备批量插入数据 + List fields = new ArrayList<>(); + List textList = new ArrayList<>(); + List fidListData = new ArrayList<>(); + List kidList = new ArrayList<>(); + List docIdList = new ArrayList<>(); + List> vectorList = new ArrayList<>(); + + for (int i = 0; i < chunkList.size(); i++) { + String text = chunkList.get(i); + String fid = fidList.get(i); + Embedding embedding = embeddingModel.embed(text).content(); + + textList.add(text); + fidListData.add(fid); + kidList.add(kid); + docIdList.add(docId); + + List vector = new ArrayList<>(); + for (float f : embedding.vector()) { + vector.add(f); + } + vectorList.add(vector); + } + + // 构建字段数据 + fields.add(new InsertParam.Field("text", textList)); + fields.add(new InsertParam.Field("fid", fidListData)); + fields.add(new InsertParam.Field("kid", kidList)); + fields.add(new InsertParam.Field("docId", docIdList)); + fields.add(new InsertParam.Field("vector", vectorList)); + + // 执行插入 + InsertParam insertParam = InsertParam.newBuilder() + .withCollectionName(collectionName) + .withFields(fields) + .build(); + + R insertResponse = milvusClient.insert(insertParam); + if (insertResponse.getStatus() != R.Status.Success.getCode()) { + log.error("Milvus向量存储失败: {}", insertResponse.getMessage()); + throw new ServiceException("Milvus向量存储失败"); + } else { + log.info("Milvus向量存储成功,插入条数: {}", insertResponse.getData().getInsertCnt()); + } + + long endTime = System.currentTimeMillis(); + log.info("Milvus向量存储完成消耗时间:" + (endTime - startTime) / 1000 + "秒"); + } + + @Override + public List getQueryVector(QueryVectorBo queryVectorBo) { + createSchema(queryVectorBo.getVectorModelName(), queryVectorBo.getKid(), queryVectorBo.getVectorModelName()); + + EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(), + queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl()); + + Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content(); + String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + queryVectorBo.getKid(); + + List resultList = new ArrayList<>(); + + // 加载集合到内存 + LoadCollectionParam loadCollectionParam = LoadCollectionParam.newBuilder() + .withCollectionName(collectionName) + .build(); + milvusClient.loadCollection(loadCollectionParam); + + // 准备查询向量 + List> searchVectors = new ArrayList<>(); + List queryVector = new ArrayList<>(); + for (float f : queryEmbedding.vector()) { + queryVector.add(f); + } + searchVectors.add(queryVector); + + // 构建搜索参数 + SearchParam searchParam = SearchParam.newBuilder() + .withCollectionName(collectionName) + .withMetricType(MetricType.L2) + .withOutFields(Arrays.asList("text", "fid", "kid", "docId")) + .withTopK(queryVectorBo.getMaxResults()) + .withVectors(searchVectors) + .withVectorFieldName("vector") + .withParams("{\"nprobe\":10}") + .build(); + + R searchResponse = milvusClient.search(searchParam); + if (searchResponse.getStatus() != R.Status.Success.getCode()) { + log.error("Milvus查询失败: {}", searchResponse.getMessage()); + return resultList; + } + + SearchResultsWrapper wrapper = new SearchResultsWrapper(searchResponse.getData().getResults()); + + // 遍历搜索结果 + for (int i = 0; i < wrapper.getIDScore(0).size(); i++) { + SearchResultsWrapper.IDScore idScore = wrapper.getIDScore(0).get(i); + + // 获取text字段数据 + List textFieldData = wrapper.getFieldData("text", 0); + if (textFieldData != null && i < textFieldData.size()) { + Object textObj = textFieldData.get(i); + if (textObj != null) { + resultList.add(textObj.toString()); + log.debug("找到相似文本,ID: {}, 距离: {}, 内容: {}", + idScore.getLongID(), idScore.getScore(), textObj.toString()); + } + } + } + + return resultList; + } + + @Override + @SneakyThrows + public void removeById(String id, String modelName) { + String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + id; + + // 删除整个集合 + DropCollectionParam dropCollectionParam = DropCollectionParam.newBuilder() + .withCollectionName(collectionName) + .build(); + + R dropResponse = milvusClient.dropCollection(dropCollectionParam); + if (dropResponse.getStatus() != R.Status.Success.getCode()) { + log.error("Milvus集合删除失败: {}", dropResponse.getMessage()); + throw new ServiceException("Milvus集合删除失败"); + } else { + log.info("Milvus集合删除成功: {}", collectionName); + } + } + + @Override + public void removeByDocId(String docId, String kid) { + String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid; + + String expr = "docId == \"" + docId + "\""; + DeleteParam deleteParam = DeleteParam.newBuilder() + .withCollectionName(collectionName) + .withExpr(expr) + .build(); + + R deleteResponse = milvusClient.delete(deleteParam); + if (deleteResponse.getStatus() != R.Status.Success.getCode()) { + log.error("Milvus删除失败: {}", deleteResponse.getMessage()); + throw new ServiceException("Milvus删除失败"); + } else { + log.info("Milvus成功删除 docId={} 的所有向量数据,删除条数: {}", docId, deleteResponse.getData().getDeleteCnt()); + } + } + + @Override + public void removeByFid(String fid, String kid) { + String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid; + + String expr = "fid == \"" + fid + "\""; + DeleteParam deleteParam = DeleteParam.newBuilder() + .withCollectionName(collectionName) + .withExpr(expr) + .build(); + + R deleteResponse = milvusClient.delete(deleteParam); + if (deleteResponse.getStatus() != R.Status.Success.getCode()) { + log.error("Milvus删除失败: {}", deleteResponse.getMessage()); + throw new ServiceException("Milvus删除失败"); + } else { + log.info("Milvus成功删除 fid={} 的所有向量数据,删除条数: {}", fid, deleteResponse.getData().getDeleteCnt()); + } + } +} \ No newline at end of file diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/WeaviateVectorStoreStrategy.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/WeaviateVectorStoreStrategy.java new file mode 100644 index 00000000..6275d939 --- /dev/null +++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/strategy/impl/WeaviateVectorStoreStrategy.java @@ -0,0 +1,233 @@ +package org.ruoyi.service.strategy.impl; + +import cn.hutool.json.JSONObject; +import org.ruoyi.common.core.exception.ServiceException; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.model.embedding.EmbeddingModel; +import io.weaviate.client.Config; +import io.weaviate.client.WeaviateClient; +import io.weaviate.client.base.Result; +import io.weaviate.client.v1.batch.api.ObjectsBatchDeleter; +import io.weaviate.client.v1.batch.model.BatchDeleteResponse; +import io.weaviate.client.v1.filters.Operator; +import io.weaviate.client.v1.filters.WhereFilter; +import io.weaviate.client.v1.graphql.model.GraphQLResponse; +import io.weaviate.client.v1.schema.model.Property; +import io.weaviate.client.v1.schema.model.Schema; +import io.weaviate.client.v1.schema.model.WeaviateClass; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import org.ruoyi.common.core.config.VectorStoreProperties; +import org.ruoyi.domain.bo.QueryVectorBo; +import org.ruoyi.domain.bo.StoreEmbeddingBo; +import org.ruoyi.service.strategy.AbstractVectorStoreStrategy; +import org.springframework.stereotype.Component; +import java.util.*; + +/** + * Weaviate向量库策略实现 + * + * @author ageer + */ +@Slf4j +@Component +public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy { + + private WeaviateClient client; + + public WeaviateVectorStoreStrategy(VectorStoreProperties vectorStoreProperties) { + super(vectorStoreProperties); + } + + @Override + public String getVectorStoreType() { + return "weaviate"; + } + + @Override + public void createSchema(String vectorModelName, String kid, String modelName) { + String protocol = vectorStoreProperties.getWeaviate().getProtocol(); + String host = vectorStoreProperties.getWeaviate().getHost(); + String className = vectorStoreProperties.getWeaviate().getClassname() + kid; + // 创建 Weaviate 客户端 + client = new WeaviateClient(new Config(protocol, host)); + // 检查类是否存在,如果不存在就创建 schema + Result schemaResult = client.schema().getter().run(); + Schema schema = schemaResult.getResult(); + boolean classExists = false; + for (WeaviateClass weaviateClass : schema.getClasses()) { + if (weaviateClass.getClassName().equals(className)) { + classExists = true; + break; + } + } + if (!classExists) { + // 类不存在,创建 schema + WeaviateClass build = WeaviateClass.builder() + .className(className) + .vectorizer("none") + .properties( + List.of(Property.builder().name("text").dataType(Collections.singletonList("text")).build(), + Property.builder().name("fid").dataType(Collections.singletonList("text")).build(), + Property.builder().name("kid").dataType(Collections.singletonList("text")).build(), + Property.builder().name("docId").dataType(Collections.singletonList("text")).build()) + ) + .build(); + Result createResult = client.schema().classCreator().withClass(build).run(); + if (createResult.hasErrors()) { + log.error("Schema 创建失败: {}", createResult.getError()); + } else { + log.info("Schema 创建成功: {}", className); + } + } + } + + @Override + public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) { + createSchema(storeEmbeddingBo.getVectorModelName(), storeEmbeddingBo.getKid(), storeEmbeddingBo.getVectorModelName()); + EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(), + storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl()); + List chunkList = storeEmbeddingBo.getChunkList(); + List fidList = storeEmbeddingBo.getFids(); + String kid = storeEmbeddingBo.getKid(); + String docId = storeEmbeddingBo.getDocId(); + log.info("向量存储条数记录: " + chunkList.size()); + long startTime = System.currentTimeMillis(); + for (int i = 0; i < chunkList.size(); i++) { + String text = chunkList.get(i); + String fid = fidList.get(i); + Embedding embedding = embeddingModel.embed(text).content(); + Map properties = Map.of( + "text", text, + "fid", fid, + "kid", kid, + "docId", docId + ); + Float[] vector = toObjectArray(embedding.vector()); + client.data().creator() + .withClassName("LocalKnowledge" + kid) + .withProperties(properties) + .withVector(vector) + .run(); + } + long endTime = System.currentTimeMillis(); + log.info("向量存储完成消耗时间:" + (endTime - startTime) / 1000 + "秒"); + } + + + + @Override + public List getQueryVector(QueryVectorBo queryVectorBo) { + createSchema(queryVectorBo.getVectorModelName(), queryVectorBo.getKid(), queryVectorBo.getVectorModelName()); + EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(), + queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl()); + Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content(); + float[] vector = queryEmbedding.vector(); + List vectorStrings = new ArrayList<>(); + for (float v : vector) { + vectorStrings.add(String.valueOf(v)); + } + String vectorStr = String.join(",", vectorStrings); + String className = vectorStoreProperties.getWeaviate().getClassname(); + + // 构建 GraphQL 查询 + String graphQLQuery = String.format( + "{\n" + + " Get {\n" + + " %s(nearVector: {vector: [%s]} limit: %d) {\n" + + " text\n" + + " fid\n" + + " kid\n" + + " docId\n" + + " _additional {\n" + + " distance\n" + + " id\n" + + " }\n" + + " }\n" + + " }\n" + + "}", + className + queryVectorBo.getKid(), + vectorStr, + queryVectorBo.getMaxResults() + ); + + Result result = client.graphQL().raw().withQuery(graphQLQuery).run(); + List resultList = new ArrayList<>(); + if (result != null && !result.hasErrors()) { + Object data = result.getResult().getData(); + JSONObject entries = new JSONObject(data); + Map entriesMap = entries.get("Get", Map.class); + cn.hutool.json.JSONArray objects = entriesMap.get(className + queryVectorBo.getKid()); + if (objects.isEmpty()) { + return resultList; + } + for (Object object : objects) { + Map map = (Map) object; + String content = map.get("text"); + resultList.add(content); + } + return resultList; + } else { + log.error("GraphQL 查询失败: {}", result.getError()); + return resultList; + } + } + + @Override + @SneakyThrows + public void removeById(String id, String modelName) { + String protocol = vectorStoreProperties.getWeaviate().getProtocol(); + String host = vectorStoreProperties.getWeaviate().getHost(); + String className = vectorStoreProperties.getWeaviate().getClassname(); + String finalClassName = className + id; + WeaviateClient client = new WeaviateClient(new Config(protocol, host)); + Result result = client.schema().classDeleter().withClassName(finalClassName).run(); + if (result.hasErrors()) { + log.error("失败删除向量: " + result.getError()); + throw new ServiceException("失败删除向量数据!"); + } else { + log.info("成功删除向量数据: " + result.getResult()); + } + } + + @Override + public void removeByDocId(String docId, String kid) { + String className = vectorStoreProperties.getWeaviate().getClassname() + kid; + // 构建 Where 条件 + WhereFilter whereFilter = WhereFilter.builder() + .path("docId") + .operator(Operator.Equal) + .valueText(docId) + .build(); + ObjectsBatchDeleter deleter = client.batch().objectsBatchDeleter(); + Result result = deleter.withClassName(className) + .withWhere(whereFilter) + .run(); + if (result != null && !result.hasErrors()) { + log.info("成功删除 docId={} 的所有向量数据", docId); + } else { + log.error("删除失败: {}", result.getError()); + } + } + + @Override + public void removeByFid(String fid, String kid) { + String className = vectorStoreProperties.getWeaviate().getClassname() + kid; + // 构建 Where 条件 + WhereFilter whereFilter = WhereFilter.builder() + .path("fid") + .operator(Operator.Equal) + .valueText(fid) + .build(); + ObjectsBatchDeleter deleter = client.batch().objectsBatchDeleter(); + Result result = deleter.withClassName(className) + .withWhere(whereFilter) + .run(); + if (result != null && !result.hasErrors()) { + log.info("成功删除 fid={} 的所有向量数据", fid); + } else { + log.error("删除失败: {}", result.getError()); + } + } + +} \ No newline at end of file diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/controller/chat/ChatMessageController.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/controller/chat/ChatMessageController.java index 10ddd3c6..40a5971e 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/controller/chat/ChatMessageController.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/controller/chat/ChatMessageController.java @@ -45,6 +45,18 @@ public class ChatMessageController extends BaseController { return chatMessageService.queryPageList(bo, pageQuery); } + /** + * 根据会话ID查询聊天消息列表 + */ + @GetMapping("/listBySession/{sessionId}") + public TableDataInfo listBySession(@NotNull(message = "会话ID不能为空") + @PathVariable Long sessionId, + PageQuery pageQuery) { + ChatMessageBo bo = new ChatMessageBo(); + bo.setSessionId(sessionId); + return chatMessageService.queryPageList(bo, pageQuery); + } + /** * 导出聊天消息列表 */ diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/SseServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/SseServiceImpl.java index 0250177a..e9a38259 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/SseServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/SseServiceImpl.java @@ -100,24 +100,7 @@ public class SseServiceImpl implements ISseService { // 设置用户id chatRequest.setUserId(LoginHelper.getUserId()); - - //待优化的地方 (这里请前端提交send的时候传递uuid进来或者sessionId) - //待优化的地方 (这里请前端提交send的时候传递uuid进来或者sessionId) - //待优化的地方 (这里请前端提交send的时候传递uuid进来或者sessionId) - { - // 设置会话id - if (chatRequest.getUuid() == null) { - //暂时随机生成会话id - chatRequest.setSessionId(System.currentTimeMillis()); - } else { - //这里或许需要修改一下,这里应该用uuid 或者 前端传递 sessionId - chatRequest.setSessionId(chatRequest.getUuid()); - } - } - - - - chatRequest.setUserId(chatCostService.getUserId()); + // 设置会话id if (chatRequest.getSessionId() == null) { ChatSessionBo chatSessionBo = new ChatSessionBo(); chatSessionBo.setUserId(chatCostService.getUserId()); diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/knowledge/KnowledgeInfoServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/knowledge/KnowledgeInfoServiceImpl.java index 3b88de9b..9346323b 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/knowledge/KnowledgeInfoServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/knowledge/KnowledgeInfoServiceImpl.java @@ -216,7 +216,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService { } baseMapper.insert(knowledgeInfo); if (knowledgeInfo != null) { - vectorStoreService.createSchema(String.valueOf(knowledgeInfo.getId()), + vectorStoreService.createSchema(knowledgeInfo.getVectorModelName(),String.valueOf(knowledgeInfo.getId()), bo.getVectorModelName()); } } else { @@ -258,6 +258,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService { knowledgeAttach.setDocType(fileName.substring(fileName.lastIndexOf(".") + 1)); String content = ""; ResourceLoader resourceLoader = resourceLoaderFactory.getLoaderByFileType(knowledgeAttach.getDocType()); + // 文档分段入库 List fids = new ArrayList<>(); try { content = resourceLoader.getContent(file.getInputStream()); @@ -265,6 +266,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService { List knowledgeFragmentList = new ArrayList<>(); if (CollUtil.isNotEmpty(chunkList)) { for (int i = 0; i < chunkList.size(); i++) { + // 生成知识片段ID String fid = RandomUtil.randomString(10); fids.add(fid); KnowledgeFragment knowledgeFragment = new KnowledgeFragment();