diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/VectorStoreService.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/VectorStoreService.java
index 8d232728..29bfecc3 100644
--- a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/VectorStoreService.java
+++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/VectorStoreService.java
@@ -19,5 +19,9 @@ public interface VectorStoreService {
     void removeById(String id,String modelName);
+    /** Removes every stored vector whose {@code docId} property equals {@code docId} in the class selected by {@code kid}. */
+    void removeByDocId(String docId, String kid);
+    /** Removes every stored vector whose {@code fid} property equals {@code fid} in the class selected by {@code kid}. */
+    void removeByFid(String fid, String kid);
 }
diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/VectorStoreServiceImpl.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/VectorStoreServiceImpl.java
index db17a580..85534b2c 100644
--- a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/VectorStoreServiceImpl.java
+++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/VectorStoreServiceImpl.java
@@ -1,18 +1,25 @@
 package org.ruoyi.service.impl;
+import cn.hutool.json.JSONObject;
 import com.google.protobuf.ServiceException;
 import dev.langchain4j.data.embedding.Embedding;
 import dev.langchain4j.data.segment.TextSegment;
 import dev.langchain4j.model.embedding.EmbeddingModel;
 import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
 import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
-import dev.langchain4j.store.embedding.EmbeddingMatch;
-import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
 import dev.langchain4j.store.embedding.EmbeddingStore;
 import dev.langchain4j.store.embedding.weaviate.WeaviateEmbeddingStore;
 import io.weaviate.client.Config;
 import io.weaviate.client.WeaviateClient;
 import io.weaviate.client.base.Result;
+import io.weaviate.client.v1.batch.api.ObjectsBatchDeleter;
+import io.weaviate.client.v1.batch.model.BatchDeleteResponse;
+import io.weaviate.client.v1.filters.Operator;
+import io.weaviate.client.v1.filters.WhereFilter;
+import io.weaviate.client.v1.graphql.model.GraphQLResponse;
+import io.weaviate.client.v1.schema.model.Property;
+import io.weaviate.client.v1.schema.model.Schema;
+import io.weaviate.client.v1.schema.model.WeaviateClass;
 import lombok.RequiredArgsConstructor;
 import lombok.SneakyThrows;
 import lombok.extern.slf4j.Slf4j;
@@ -21,9 +28,7 @@ import org.ruoyi.domain.bo.QueryVectorBo;
 import org.ruoyi.domain.bo.StoreEmbeddingBo;
 import org.ruoyi.service.VectorStoreService;
 import org.springframework.stereotype.Service;
-
-import java.util.ArrayList;
-import java.util.List;
+import java.util.*;
 /**
  * 向量库管理
@@ -38,17 +43,50 @@ public class VectorStoreServiceImpl implements VectorStoreService {
     private final ConfigService configService;
     private EmbeddingStore<TextSegment> embeddingStore;
+    private WeaviateClient client;
     @Override
     public void createSchema(String kid, String modelName) {
         String protocol = configService.getConfigValue("weaviate", "protocol");
         String host = configService.getConfigValue("weaviate", "host");
-        String className = configService.getConfigValue("weaviate", "classname");
+        String className = configService.getConfigValue("weaviate", "classname") + kid;
+        // Create the Weaviate client used by the raw data, GraphQL and batch calls in this class.
+        client = new WeaviateClient(new Config(protocol, host));
+        // Check whether the class exists; create the schema when it does not.
+        Result<Schema> schemaResult = client.schema().getter().run();
+        if (schemaResult == null || schemaResult.hasErrors() || schemaResult.getResult() == null) {
+            // Fail with a descriptive error instead of a NullPointerException on the missing schema.
+            throw new IllegalStateException("Failed to read Weaviate schema: "
+                    + (schemaResult == null ? "no result returned" : schemaResult.getError()));
+        }
+        // NOTE(review): getClasses() is assumed to be nullable for an empty schema; guarded either way.
+        List<WeaviateClass> existingClasses = schemaResult.getResult().getClasses();
+        boolean classExists = existingClasses != null
+                && existingClasses.stream().anyMatch(c -> className.equals(c.getClassName()));
+        if (!classExists) {
+            // The class is missing: create it with the four text properties that storeEmbeddings writes.
+            WeaviateClass build = WeaviateClass.builder()
+                    .className(className)
+                    .vectorizer("none")
+                    .properties(
+                            List.of(Property.builder().name("text").dataType(Collections.singletonList("text")).build(),
+                                    Property.builder().name("fid").dataType(Collections.singletonList("text")).build(),
+                                    Property.builder().name("kid").dataType(Collections.singletonList("text")).build(),
+                                    Property.builder().name("docId").dataType(Collections.singletonList("text")).build())
+                    )
+                    .build();
+            Result<Boolean> createResult = client.schema().classCreator().withClass(build).run();
+            if (createResult.hasErrors()) {
+                log.error("Schema 创建失败: {}", createResult.getError());
+            } else {
+                log.info("Schema 创建成功: {}", className);
+            }
+        }
         embeddingStore = WeaviateEmbeddingStore.builder()
                 .scheme(protocol)
                 .host(host)
-                .objectClass(className+kid)
+                .objectClass(className)
                 .scheme(protocol)
                 .avoidDups(true)
                 .consistencyLevel("ALL")
@@ -61,33 +99,116 @@ public class VectorStoreServiceImpl implements VectorStoreService {
         EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(), storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl());
         List<String> chunkList = storeEmbeddingBo.getChunkList();
-        for (String s : chunkList) {
-            Embedding embedding = embeddingModel.embed(s).content();
-            TextSegment segment = TextSegment.from(s);
-            embeddingStore.add(embedding, segment);
+        List<String> fidList = storeEmbeddingBo.getFids();
+        String kid = storeEmbeddingBo.getKid();
+        String docId = storeEmbeddingBo.getDocId();
+        // Chunks and fids are paired by index, so fail before writing anything if a fid is missing.
+        if (!chunkList.isEmpty() && (fidList == null || fidList.size() < chunkList.size())) {
+            throw new IllegalArgumentException("fids must contain one entry per chunk");
+        }
+        // Write into the same configured class that createSchema, getQueryVector and removeBy* address.
+        String targetClass = configService.getConfigValue("weaviate", "classname") + kid;
+        log.info("向量存储条数记录: {}", chunkList.size());
+        long startTime = System.currentTimeMillis();
+        for (int i = 0; i < chunkList.size(); i++) {
+            String text = chunkList.get(i);
+            String fid = fidList.get(i);
+            Embedding embedding = embeddingModel.embed(text).content();
+            Map<String, Object> properties = Map.of(
+                    "text", text,
+                    "fid", fid,
+                    "kid", kid,
+                    "docId", docId
+            );
+            Float[] vector = toObjectArray(embedding.vector());
+            Result<?> created = weaviateClient().data().creator()
+                    .withClassName(targetClass)
+                    .withProperties(properties)
+                    .withVector(vector)
+                    .run();
+            if (created == null || created.hasErrors()) {
+                // Report the failed chunk instead of dropping the error silently.
+                log.error("向量存储失败 fid={}: {}", fid, created == null ? "no result returned" : created.getError());
+            }
         }
+        long endTime = System.currentTimeMillis();
+        log.info("向量存储完成消耗时间:{}秒", (endTime - startTime) / 1000);
     }
+    /**
+     * Boxes a primitive float array into a {@code Float[]}, the vector type handed to the Weaviate object creator.
+     *
+     * @param primitive embedding values to box
+     * @return a new array of the same length holding the boxed values
+     */
+    private static Float[] toObjectArray(float[] primitive) {
+        Float[] result = new Float[primitive.length];
+        for (int i = 0; i < primitive.length; i++) {
+            result[i] = primitive[i]; // autoboxing
+        }
+        return result;
+    }
     @Override
     public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
         createSchema(queryVectorBo.getKid(), queryVectorBo.getVectorModelName());
         EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(), queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl());
         Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
-        EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
-                .queryEmbedding(queryEmbedding)
-                .maxResults(queryVectorBo.getMaxResults())
-                .build();
-        List<EmbeddingMatch<TextSegment>> matches = embeddingStore.search(embeddingSearchRequest).matches();
-        List<String> results = new ArrayList<>();
-        matches.forEach(embeddingMatch -> results.add(embeddingMatch.embedded().text()));
-        return results;
-    }
+        // Serialize the query embedding as the comma-separated list placed inside nearVector.
+        float[] vector = queryEmbedding.vector();
+        List<String> vectorStrings = new ArrayList<>(vector.length);
+        for (float v : vector) {
+            vectorStrings.add(String.valueOf(v));
+        }
+        String vectorStr = String.join(",", vectorStrings);
+        String className = configService.getConfigValue("weaviate", "classname") + queryVectorBo.getKid();
+        // Build the GraphQL query. No certainty threshold is sent, matching the replaced
+        // EmbeddingSearchRequest (which set none); every placeholder below has exactly one argument.
+        String graphQLQuery = String.format(
+                "{\n" +
+                        "  Get {\n" +
+                        "    %s(nearVector: {vector: [%s]} limit: %d) {\n" +
+                        "      text\n" +
+                        "      fid\n" +
+                        "      kid\n" +
+                        "      docId\n" +
+                        "      _additional {\n" +
+                        "        distance\n" +
+                        "        id\n" +
+                        "      }\n" +
+                        "    }\n" +
+                        "  }\n" +
+                        "}",
+                className,
+                vectorStr,
+                queryVectorBo.getMaxResults()
+        );
+        Result<GraphQLResponse> result = weaviateClient().graphQL().raw().withQuery(graphQLQuery).run();
+        List<String> resultList = new ArrayList<>();
+        if (result == null || result.hasErrors()) {
+            // Handle the null case here so that logging the error cannot itself throw.
+            log.error("GraphQL 查询失败: {}", result == null ? "no result returned" : result.getError());
+            return resultList;
+        }
+        Object data = result.getResult().getData();
+        JSONObject entries = new JSONObject(data);
+        Map<String, cn.hutool.json.JSONArray> entriesMap = entries.get("Get", Map.class);
+        cn.hutool.json.JSONArray objects = entriesMap == null ? null : entriesMap.get(className);
+        if (objects == null || objects.isEmpty()) {
+            return resultList;
+        }
+        for (Object object : objects) {
+            Map<String, String> map = (Map<String, String>) object;
+            String content = map.get("text");
+            resultList.add(content);
+        }
+        return resultList;
+    }
     @Override
     @SneakyThrows
-    public void removeById(String id, String modelName) {
+    public void removeById(String id, String modelName) {
         String protocol = configService.getConfigValue("weaviate", "protocol");
         String host = configService.getConfigValue("weaviate", "host");
         String className = configService.getConfigValue("weaviate", "classname");
@@ -102,6 +223,63 @@ public class VectorStoreServiceImpl implements VectorStoreService {
         }
     }
+    /**
+     * Deletes every object whose {@code docId} property equals {@code docId}.
+     *
+     * @param docId value of the {@code docId} property on the objects to delete
+     * @param kid   suffix appended to the configured class name to select the class
+     */
+    @Override
+    public void removeByDocId(String docId, String kid) {
+        removeByProperty("docId", docId, kid);
+    }
+
+    /**
+     * Deletes every object whose {@code fid} property equals {@code fid}.
+     *
+     * @param fid value of the {@code fid} property on the objects to delete
+     * @param kid suffix appended to the configured class name to select the class
+     */
+    @Override
+    public void removeByFid(String fid, String kid) {
+        removeByProperty("fid", fid, kid);
+    }
+
+    /**
+     * Batch-deletes the objects of class {@code classname + kid} whose {@code property} equals {@code value}.
+     */
+    private void removeByProperty(String property, String value, String kid) {
+        String className = configService.getConfigValue("weaviate", "classname") + kid;
+        // Build the where condition.
+        WhereFilter whereFilter = WhereFilter.builder()
+                .path(property)
+                .operator(Operator.Equal)
+                .valueText(value)
+                .build();
+        ObjectsBatchDeleter deleter = weaviateClient().batch().objectsBatchDeleter();
+        Result<BatchDeleteResponse> result = deleter.withClassName(className)
+                .withWhere(whereFilter)
+                .run();
+        if (result != null && !result.hasErrors()) {
+            log.info("成功删除 {}={} 的所有向量数据", property, value);
+        } else {
+            // result may be null on this branch, so it must not be dereferenced unconditionally.
+            log.error("删除失败: {}", result == null ? "no result returned" : result.getError());
+        }
+    }
+
+    /**
+     * Returns the Weaviate client held by this service, building it from configuration when none exists yet.
+     */
+    private WeaviateClient weaviateClient() {
+        if (client == null) {
+            String protocol = configService.getConfigValue("weaviate", "protocol");
+            String host = configService.getConfigValue("weaviate", "host");
+            client = new WeaviateClient(new Config(protocol, host));
+        }
+        return client;
+    }
+
     /**
      * 获取向量模型
      */