diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/pom.xml b/ruoyi-modules-api/ruoyi-knowledge-api/pom.xml
index 27e20736..cb35d342 100644
--- a/ruoyi-modules-api/ruoyi-knowledge-api/pom.xml
+++ b/ruoyi-modules-api/ruoyi-knowledge-api/pom.xml
@@ -16,8 +16,21 @@
17
17
UTF-8
+ 1.0.0-beta4
+
+
+
+ dev.langchain4j
+ langchain4j-bom
+ ${langchain4j.version}
+ pom
+ import
+
+
+
+
@@ -47,6 +60,35 @@
4.0.0
+
+
+ dev.langchain4j
+ langchain4j
+
+
+
+
+ dev.langchain4j
+ langchain4j-weaviate
+
+
+
+ dev.langchain4j
+ langchain4j-embeddings-all-minilm-l6-v2
+
+
+
+
+ org.testcontainers
+ weaviate
+ 1.19.6
+
+
+
+ dev.langchain4j
+ langchain4j-open-ai-spring-boot-starter
+
+
diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/EmbeddingService.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/EmbeddingService.java
deleted file mode 100644
index 98841189..00000000
--- a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/EmbeddingService.java
+++ /dev/null
@@ -1,20 +0,0 @@
-package org.ruoyi.service;
-
-import java.util.List;
-
-public interface EmbeddingService {
-
- void storeEmbeddings(List chunkList, String kid, String docId,List fidList);
-
- void removeByDocId(String kid,String docId);
-
- void removeByKid(String kid);
-
- List getQueryVector(String query, String kid);
-
- void createSchema(String kid);
-
- void removeByKidAndFid(String kid, String fid);
-
- void saveFragment(String kid, String docId, String fid, String content);
-}
diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/VectorStoreService.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/VectorStoreService.java
index d3294bb3..dbc1a9a3 100644
--- a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/VectorStoreService.java
+++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/VectorStoreService.java
@@ -2,22 +2,18 @@ package org.ruoyi.service;
import java.util.List;
-/**
- * 向量存储
- */
public interface VectorStoreService {
- void storeEmbeddings(List chunkList, List> vectorList, String kid, String docId, List fidList);
+ void storeEmbeddings(List chunkList, String kid);
- void removeByDocId(String kid, String docId);
+ void removeByDocId(String kid,String docId);
void removeByKid(String kid);
- List nearest(List queryVector, String kid);
+ List getQueryVector(String query, String kid);
- List nearest(String query, String kid);
-
- void newSchema(String kid);
+ void createSchema(String kid);
void removeByKidAndFid(String kid, String fid);
+
}
diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/EmbeddingServiceImpl.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/EmbeddingServiceImpl.java
deleted file mode 100644
index 00657399..00000000
--- a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/EmbeddingServiceImpl.java
+++ /dev/null
@@ -1,64 +0,0 @@
-package org.ruoyi.service.impl;
-
-import lombok.AllArgsConstructor;
-import org.ruoyi.service.EmbeddingService;
-import org.ruoyi.service.VectorStoreService;
-import org.ruoyi.service.VectorizationService;
-import org.springframework.stereotype.Service;
-
-import java.util.ArrayList;
-import java.util.List;
-
-@Service
-@AllArgsConstructor
-public class EmbeddingServiceImpl implements EmbeddingService {
-
- private final VectorStoreService vectorStore;
- private final VectorizationService vectorization;
-
- /**
- * 保存向量数据库
- * @param chunkList 文档按行切分的片段
- * @param kid 知识库ID
- * @param docId 文档ID
- */
- @Override
- public void storeEmbeddings(List chunkList, String kid, String docId,List fidList) {
- List> vectorList = vectorization.batchVectorization(chunkList, kid);
- vectorStore.storeEmbeddings(chunkList,vectorList,kid,docId,fidList);
- }
-
- @Override
- public void removeByDocId(String kid,String docId) {
- vectorStore.removeByDocId(kid,docId);
- }
-
- @Override
- public void removeByKid(String kid) {
- vectorStore.removeByKid(kid);
- }
-
- @Override
- public List getQueryVector(String query, String kid) {
- return vectorization.singleVectorization(query,kid);
- }
-
- @Override
- public void createSchema(String kid) {
- vectorStore.newSchema(kid);
- }
-
- @Override
- public void removeByKidAndFid(String kid, String fid) {
- vectorStore.removeByKidAndFid(kid,fid);
- }
-
- @Override
- public void saveFragment(String kid, String docId, String fid, String content) {
- List chunkList = new ArrayList<>();
- List fidList = new ArrayList<>();
- chunkList.add(content);
- fidList.add(fid);
- storeEmbeddings(chunkList,kid,docId,fidList);
- }
-}
diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/WeaviateVectorStoreImpl.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/WeaviateVectorStoreImpl.java
index 994bc727..ca3d6e76 100644
--- a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/WeaviateVectorStoreImpl.java
+++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/WeaviateVectorStoreImpl.java
@@ -1,37 +1,25 @@
package org.ruoyi.service.impl;
-import cn.hutool.core.lang.UUID;
-import cn.hutool.json.JSONObject;
-import com.google.gson.internal.LinkedTreeMap;
-import io.weaviate.client.Config;
-import io.weaviate.client.WeaviateClient;
-import io.weaviate.client.base.Result;
-import io.weaviate.client.v1.data.model.WeaviateObject;
-import io.weaviate.client.v1.data.replication.model.ConsistencyLevel;
-import io.weaviate.client.v1.filters.Operator;
-import io.weaviate.client.v1.filters.WhereFilter;
-import io.weaviate.client.v1.graphql.model.GraphQLResponse;
-import io.weaviate.client.v1.graphql.query.argument.NearTextArgument;
-import io.weaviate.client.v1.graphql.query.argument.NearVectorArgument;
-import io.weaviate.client.v1.graphql.query.fields.Field;
-import io.weaviate.client.v1.misc.model.Meta;
-import io.weaviate.client.v1.misc.model.ReplicationConfig;
-import io.weaviate.client.v1.misc.model.ShardingConfig;
-import io.weaviate.client.v1.misc.model.VectorIndexConfig;
-import io.weaviate.client.v1.schema.model.DataType;
-import io.weaviate.client.v1.schema.model.Property;
-import io.weaviate.client.v1.schema.model.Schema;
-import io.weaviate.client.v1.schema.model.WeaviateClass;
+import cn.hutool.core.util.RandomUtil;
+import dev.langchain4j.data.embedding.Embedding;
+import dev.langchain4j.data.segment.TextSegment;
+import dev.langchain4j.model.embedding.EmbeddingModel;
+import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
+import dev.langchain4j.store.embedding.EmbeddingMatch;
+import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
+import dev.langchain4j.store.embedding.EmbeddingStore;
+import dev.langchain4j.store.embedding.filter.Filter;
+import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
+import dev.langchain4j.store.embedding.weaviate.WeaviateEmbeddingStore;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
-import org.apache.commons.lang3.StringUtils;
import org.ruoyi.common.core.service.ConfigService;
-import org.ruoyi.domain.vo.KnowledgeInfoVo;
-import org.ruoyi.service.IKnowledgeInfoService;
import org.ruoyi.service.VectorStoreService;
+import org.ruoyi.service.IKnowledgeInfoService;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
+import org.testcontainers.weaviate.WeaviateContainer;
import java.util.ArrayList;
import java.util.HashMap;
@@ -54,6 +42,8 @@ public class WeaviateVectorStoreImpl implements VectorStoreService {
@Resource
private ConfigService configService;
+ private EmbeddingStore embeddingStore;
+
@PostConstruct
public void loadConfig() {
this.protocol = configService.getConfigValue("weaviate", "protocol");
@@ -61,342 +51,94 @@ public class WeaviateVectorStoreImpl implements VectorStoreService {
this.className = configService.getConfigValue("weaviate", "classname");
}
- public WeaviateClient getClient() {
- Config config = new Config(protocol, host);
- WeaviateClient client = new WeaviateClient(config);
- return client;
- }
- public Result getMeta() {
- WeaviateClient client = getClient();
- Result meta = client.misc().metaGetter().run();
- if (meta.getError() == null) {
- System.out.printf("meta.hostname: %s\n", meta.getResult().getHostname());
- System.out.printf("meta.version: %s\n", meta.getResult().getVersion());
- System.out.printf("meta.modules: %s\n", meta.getResult().getModules());
- } else {
- System.out.printf("Error: %s\n", meta.getError().getMessages());
- }
- return meta;
- }
-
- public Result getSchemas() {
- WeaviateClient client = getClient();
- Result result = client.schema().getter().run();
- if (result.hasErrors()) {
- System.out.println(result.getError());
- } else {
- System.out.println(result.getResult());
- }
- return result;
- }
-
-
- public Result createSchema(String kid) {
- WeaviateClient client = getClient();
-
- VectorIndexConfig vectorIndexConfig = VectorIndexConfig.builder()
- .distance("cosine")
- .cleanupIntervalSeconds(300)
- .efConstruction(128)
- .maxConnections(64)
- .vectorCacheMaxObjects(500000L)
- .ef(-1)
- .skip(false)
- .dynamicEfFactor(8)
- .dynamicEfMax(500)
- .dynamicEfMin(100)
- .flatSearchCutoff(40000)
+ @Override
+ public List getQueryVector(String query, String kid) {
+ EmbeddingModel embeddingModel = OpenAiEmbeddingModel.builder()
+ .apiKey(System.getenv("OPENAI_API_KEY"))
+ .baseUrl(System.getenv("OPENAI_BASE_URL"))
+ .modelName("text-embedding-3-small")
.build();
- ShardingConfig shardingConfig = ShardingConfig.builder()
- .desiredCount(3)
- .desiredVirtualCount(128)
- .function("murmur3")
- .key("_id")
- .strategy("hash")
- .virtualPerPhysical(128)
+ Filter simpleFilter = new IsEqualTo("kid", kid);
+
+ Embedding queryEmbedding = embeddingModel.embed("What is your favourite sport?").content();
+ EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
+ .queryEmbedding(queryEmbedding)
+ .maxResults(3)
+ // 添加过滤条件
+ .filter(simpleFilter)
.build();
+ List> matches = embeddingStore.search(embeddingSearchRequest).matches();
- ReplicationConfig replicationConfig = ReplicationConfig.builder()
- .factor(1)
- .build();
+ List results = new ArrayList<>();
- JSONObject classModuleConfigValue = new JSONObject();
- classModuleConfigValue.put("vectorizeClassName", false);
- JSONObject classModuleConfig = new JSONObject();
- classModuleConfig.put("text2vec-transformers", classModuleConfigValue);
+ matches.forEach(embeddingMatch -> {
+ results.add(embeddingMatch.embedded().text());
+ });
- JSONObject propertyModuleConfigValueSkipTrue = new JSONObject();
- propertyModuleConfigValueSkipTrue.put("vectorizePropertyName", false);
- propertyModuleConfigValueSkipTrue.put("skip", true);
- JSONObject propertyModuleConfigSkipTrue = new JSONObject();
- propertyModuleConfigSkipTrue.put("text2vec-transformers", propertyModuleConfigValueSkipTrue);
-
- JSONObject propertyModuleConfigValueSkipFalse = new JSONObject();
- propertyModuleConfigValueSkipFalse.put("vectorizePropertyName", false);
- propertyModuleConfigValueSkipFalse.put("skip", false);
- JSONObject propertyModuleConfigSkipFalse = new JSONObject();
- propertyModuleConfigSkipFalse.put("text2vec-transformers", propertyModuleConfigValueSkipFalse);
-
- WeaviateClass clazz = WeaviateClass.builder()
- .className(className + kid)
- .description("local knowledge")
- .vectorIndexType("hnsw")
- .vectorizer("text2vec-transformers")
- .shardingConfig(shardingConfig)
- .vectorIndexConfig(vectorIndexConfig)
- .replicationConfig(replicationConfig)
- .moduleConfig(classModuleConfig)
- .properties(new ArrayList() {
- {
- add(Property.builder()
- .dataType(new ArrayList() {
- {
- add(DataType.TEXT);
- }
- })
- .name("content")
- .description("The content of the local knowledge,for search")
- .moduleConfig(propertyModuleConfigSkipFalse)
- .build());
- add(Property.builder()
- .dataType(new ArrayList() {
- {
- add(DataType.TEXT);
- }
- })
- .name("kid")
- .description("The knowledge id of the local knowledge,for search")
- .moduleConfig(propertyModuleConfigSkipTrue)
- .build());
- add(Property.builder()
- .dataType(new ArrayList() {
- {
- add(DataType.TEXT);
- }
- })
- .name("docId")
- .description("The doc id of the local knowledge,for search")
- .moduleConfig(propertyModuleConfigSkipTrue)
- .build());
- add(Property.builder()
- .dataType(new ArrayList() {
- {
- add(DataType.TEXT);
- }
- })
- .name("fid")
- .description("The fragment id of the local knowledge,for search")
- .moduleConfig(propertyModuleConfigSkipTrue)
- .build());
- add(Property.builder()
- .dataType(new ArrayList() {
- {
- add(DataType.TEXT);
- }
- })
- .name("uuid")
- .description("The uuid id of the local knowledge fragment(same with id properties),for search")
- .moduleConfig(propertyModuleConfigSkipTrue)
- .build());
- } })
- .build();
-
- Result result = client.schema().classCreator().withClass(clazz).run();
- if (result.hasErrors()) {
- System.out.println(result.getError());
- }
- System.out.println(result.getResult());
- return result;
+ return results;
}
@Override
- public void newSchema(String kid) {
- createSchema(kid);
- }
-
- @Override
- public void removeByKidAndFid(String kid, String fid) {
- List resultList = new ArrayList<>();
- WeaviateClient client = getClient();
- Field fieldId = Field.builder().name("uuid").build();
- WhereFilter where = WhereFilter.builder()
- .path(new String[]{"fid"})
- .operator(Operator.Equal)
- .valueString(fid)
+ public void createSchema(String kid) {
+ WeaviateContainer weaviate = new WeaviateContainer(protocol);
+ weaviate.start();
+ this.embeddingStore = WeaviateEmbeddingStore.builder()
+ .scheme("http")
+ .host(host)
+ .objectClass(className+kid)
+ .scheme(protocol)
+ .avoidDups(true)
+ .consistencyLevel("ALL")
.build();
- Result result = client.graphQL().get()
- .withClassName(className + kid)
- .withFields(fieldId)
- .withWhere(where)
- .run();
- LinkedTreeMap t = (LinkedTreeMap) result.getResult().getData();
- LinkedTreeMap> l = (LinkedTreeMap>) t.get("Get");
- ArrayList m = l.get(className + kid);
- for (LinkedTreeMap linkedTreeMap : m) {
- String uuid = linkedTreeMap.get("uuid").toString();
- resultList.add(uuid);
- }
- for (String uuid : resultList) {
- Result deleteResult = client.data().deleter()
- .withID(uuid)
- .withClassName(className + kid)
- .withConsistencyLevel(ConsistencyLevel.ALL) // default QUORUM
- .run();
- }
}
@Override
- public void storeEmbeddings(List chunkList, List> vectorList, String kid, String docId, List fidList) {
- WeaviateClient client = getClient();
-
- for (int i = 0; i < Math.min(chunkList.size(), vectorList.size()); i++) {
- List vector = vectorList.get(i);
- Float[] vf = vector.stream().map(Double::floatValue).toArray(Float[]::new);
-
+ public void storeEmbeddings(List chunkList,String kid) {
+ EmbeddingModel embeddingModel = OpenAiEmbeddingModel.builder()
+ .apiKey(System.getenv("OPENAI_API_KEY"))
+ .baseUrl(System.getenv("OPENAI_BASE_URL"))
+ .modelName("text-embedding-3-small")
+ .build();
+ // 生成文档id
+ String docId = RandomUtil.randomString(10);
+ chunkList.forEach(chunk -> {
+ // 生成知识块id
+ String fid = RandomUtil.randomString(10);
Map dataSchema = new HashMap<>();
- dataSchema.put("content", chunkList.get(i));
dataSchema.put("kid", kid);
dataSchema.put("docId", docId);
- dataSchema.put("fid", fidList.get(i));
- String uuid = UUID.randomUUID().toString();
- dataSchema.put("uuid", uuid);
+ dataSchema.put("fid", fid);
+ TextSegment segment = TextSegment.from(chunk);
+ segment.metadata().putAll(dataSchema);
+ Embedding content = embeddingModel.embed(segment).content();
+ embeddingStore.add(content);
+ });
+ }
- Result result = client.data().creator()
- .withClassName(className + kid)
- .withID(uuid)
- .withVector(vf)
- .withProperties(dataSchema)
- .run();
- }
+ @Override
+ public void removeByKid(String kid) {
+ // 根据条件删除向量数据
+ Filter simpleFilter = new IsEqualTo("kid", kid);
+ embeddingStore.removeAll(simpleFilter);
}
@Override
public void removeByDocId(String kid, String docId) {
- List resultList = new ArrayList<>();
- WeaviateClient client = getClient();
- Field fieldId = Field.builder().name("uuid").build();
- WhereFilter where = WhereFilter.builder()
- .path(new String[]{"docId"})
- .operator(Operator.Equal)
- .valueString(docId)
- .build();
- Result result = client.graphQL().get()
- .withClassName(className + kid)
- .withFields(fieldId)
- .withWhere(where)
- .run();
- LinkedTreeMap t = (LinkedTreeMap) result.getResult().getData();
- LinkedTreeMap> l = (LinkedTreeMap>) t.get("Get");
- ArrayList m = l.get(className + kid);
- for (LinkedTreeMap linkedTreeMap : m) {
- String uuid = linkedTreeMap.get("uuid").toString();
- resultList.add(uuid);
- }
- for (String uuid : resultList) {
- Result deleteResult = client.data().deleter()
- .withID(uuid)
- .withClassName(className + kid)
- .withConsistencyLevel(ConsistencyLevel.ALL) // default QUORUM
- .run();
- }
+ // 根据条件删除向量数据
+ Filter simpleFilterByDocId = new IsEqualTo("docId", docId);
+ embeddingStore.removeAll(simpleFilterByDocId);
}
@Override
- public void removeByKid(String kid) {
- WeaviateClient client = getClient();
- Result result = client.schema().classDeleter().withClassName(className + kid).run();
- if (result.hasErrors()) {
- System.out.println("删除schema失败" + result.getError());
- } else {
- System.out.println("删除schema成功" + result.getResult());
- }
- log.info("drop schema by kid, result = {}", result);
+ public void removeByKidAndFid(String kid, String fid) {
+ // 根据条件删除向量数据
+ Filter simpleFilterByKid = new IsEqualTo("kid", kid);
+ Filter simpleFilterFid = new IsEqualTo("fid", fid);
+ Filter simpleFilterByAnd = Filter.and(simpleFilterFid, simpleFilterByKid);
+ embeddingStore.removeAll(simpleFilterByAnd);
}
- @Override
- public List nearest(List queryVector, String kid) {
- if (StringUtils.isBlank(kid)) {
- return new ArrayList();
- }
- List resultList = new ArrayList<>();
- Float[] vf = new Float[queryVector.size()];
- for (int j = 0; j < queryVector.size(); j++) {
- Double value = queryVector.get(j);
- vf[j] = value.floatValue();
- }
- WeaviateClient client = getClient();
- Field contentField = Field.builder().name("content").build();
- Field _additional = Field.builder()
- .name("_additional")
- .fields(new Field[]{
- Field.builder().name("distance").build()
- }).build();
- NearVectorArgument nearVector = NearVectorArgument.builder()
- .vector(vf)
- .distance(1.6f) // certainty = 1f - distance /2f
- .build();
- KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
- Result result = client.graphQL().get()
- .withClassName(className + kid)
- .withFields(contentField, _additional)
- .withNearVector(nearVector)
- .withLimit(knowledgeInfoVo.getRetrieveLimit())
- .run();
- LinkedTreeMap t = (LinkedTreeMap) result.getResult().getData();
- LinkedTreeMap> l = (LinkedTreeMap>) t.get("Get");
- ArrayList m = l.get(className + kid);
- for (LinkedTreeMap linkedTreeMap : m) {
- String content = linkedTreeMap.get("content").toString();
- resultList.add(content);
- }
- return resultList;
- }
-
- @Override
- public List nearest(String query, String kid) {
- if (StringUtils.isBlank(kid)) {
- return new ArrayList();
- }
- List resultList = new ArrayList<>();
- WeaviateClient client = getClient();
- Field contentField = Field.builder().name("content").build();
- Field _additional = Field.builder()
- .name("_additional")
- .fields(new Field[]{
- Field.builder().name("distance").build()
- }).build();
- NearTextArgument nearText = client.graphQL().arguments().nearTextArgBuilder()
- .concepts(new String[]{query})
- .distance(1.6f) // certainty = 1f - distance /2f
- .build();
- KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
- Result result = client.graphQL().get()
- .withClassName(className + kid)
- .withFields(contentField, _additional)
- .withNearText(nearText)
- .withLimit(knowledgeInfoVo.getRetrieveLimit())
- .run();
- LinkedTreeMap t = (LinkedTreeMap) result.getResult().getData();
- LinkedTreeMap> l = (LinkedTreeMap>) t.get("Get");
- ArrayList m = l.get(className + kid);
- for (LinkedTreeMap linkedTreeMap : m) {
- String content = linkedTreeMap.get("content").toString();
- resultList.add(content);
- }
- return resultList;
- }
-
- public Result deleteSchema(String kid) {
- WeaviateClient client = getClient();
- Result result = client.schema().classDeleter().withClassName(className + kid).run();
- if (result.hasErrors()) {
- System.out.println(result.getError());
- } else {
- System.out.println(result.getResult());
- }
- return result;
- }
}
diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/SseServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/SseServiceImpl.java
index ccf9f3a2..8dabcc27 100644
--- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/SseServiceImpl.java
+++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/SseServiceImpl.java
@@ -24,13 +24,11 @@ import org.ruoyi.common.core.utils.StringUtils;
import org.ruoyi.common.core.utils.file.FileUtils;
import org.ruoyi.common.core.utils.file.MimeTypeUtils;
import org.ruoyi.common.redis.utils.RedisUtils;
-import org.ruoyi.domain.ChatSession;
import org.ruoyi.domain.bo.ChatSessionBo;
import org.ruoyi.domain.vo.ChatModelVo;
-import org.ruoyi.service.EmbeddingService;
+import org.ruoyi.service.VectorStoreService;
import org.ruoyi.service.IChatModelService;
import org.ruoyi.service.IChatSessionService;
-import org.ruoyi.service.VectorStoreService;
import org.springframework.core.io.InputStreamResource;
import org.springframework.core.io.Resource;
import org.springframework.http.MediaType;
@@ -56,7 +54,7 @@ public class SseServiceImpl implements ISseService {
private final OpenAiStreamClient openAiStreamClient;
- private final EmbeddingService embeddingService;
+ private final VectorStoreService vectorStoreService;
private final VectorStoreService vectorStore;
@@ -184,9 +182,7 @@ public class SseServiceImpl implements ISseService {
if(StringUtils.isNotEmpty(chatRequest.getKid())){
List knMessages = new ArrayList<>();
String content = messages.get(messages.size() - 1).getContent().toString();
- List nearestList;
- List queryVector = embeddingService.getQueryVector(content, chatRequest.getKid());
- nearestList = vectorStore.nearest(queryVector, chatRequest.getKid());
+ List nearestList = vectorStoreService.getQueryVector(content, chatRequest.getKid());
for (String prompt : nearestList) {
Message userMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
knMessages.add(userMessage);
diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/knowledge/KnowledgeInfoServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/knowledge/KnowledgeInfoServiceImpl.java
index 7a489a6b..33d9c11b 100644
--- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/knowledge/KnowledgeInfoServiceImpl.java
+++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/knowledge/KnowledgeInfoServiceImpl.java
@@ -23,7 +23,7 @@ import org.ruoyi.domain.vo.KnowledgeInfoVo;
import org.ruoyi.mapper.KnowledgeAttachMapper;
import org.ruoyi.mapper.KnowledgeFragmentMapper;
import org.ruoyi.mapper.KnowledgeInfoMapper;
-import org.ruoyi.service.EmbeddingService;
+import org.ruoyi.service.VectorStoreService;
import org.ruoyi.service.IKnowledgeInfoService;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@@ -44,7 +44,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
private final KnowledgeInfoMapper baseMapper;
- private final EmbeddingService embeddingService;
+ private final VectorStoreService vectorStoreService;
private final ResourceLoaderFactory resourceLoaderFactory;
@@ -150,7 +150,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
knowledgeInfo.setUid(LoginHelper.getLoginUser().getUserId());
}
baseMapper.insert(knowledgeInfo);
- embeddingService.createSchema(String.valueOf(knowledgeInfo.getId()));
+ vectorStoreService.createSchema(String.valueOf(knowledgeInfo.getId()));
}else {
baseMapper.updateById(knowledgeInfo);
}
@@ -165,7 +165,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
check(knowledgeInfoList);
// 删除向量库信息
knowledgeInfoList.forEach(knowledgeInfoVo -> {
- embeddingService.removeByKid(String.valueOf(knowledgeInfoVo.getId()));
+ vectorStoreService.removeByKid(String.valueOf(knowledgeInfoVo.getId()));
});
// 删除附件和知识片段
fragmentMapper.deleteByMap(map);
@@ -197,7 +197,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
List knowledgeFragmentList = new ArrayList<>();
if (CollUtil.isNotEmpty(chunkList)) {
for (int i = 0; i < chunkList.size(); i++) {
- String fid = RandomUtil.randomString(16);
+ String fid = RandomUtil.randomString(10);
fids.add(fid);
KnowledgeFragment knowledgeFragment = new KnowledgeFragment();
knowledgeFragment.setKid(kid);
@@ -216,7 +216,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
knowledgeAttach.setContent(content);
knowledgeAttach.setCreateTime(new Date());
attachMapper.insert(knowledgeAttach);
- embeddingService.storeEmbeddings(chunkList,kid,docId,fids);
+ vectorStoreService.storeEmbeddings(chunkList,kid);
}