diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/pom.xml b/ruoyi-modules-api/ruoyi-knowledge-api/pom.xml
index cb35d342..8d7d3963 100644
--- a/ruoyi-modules-api/ruoyi-knowledge-api/pom.xml
+++ b/ruoyi-modules-api/ruoyi-knowledge-api/pom.xml
@@ -48,17 +48,17 @@
-
- io.milvus
- milvus-sdk-java
- 2.3.2
-
+
+
+
+
+
-
- io.weaviate
- client
- 4.0.0
-
+
+
+
+
+
@@ -86,7 +86,12 @@
dev.langchain4j
- langchain4j-open-ai-spring-boot-starter
+ langchain4j-open-ai
+
+
+
+ dev.langchain4j
+ langchain4j-ollama
diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/domain/bo/QueryVectorBo.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/domain/bo/QueryVectorBo.java
new file mode 100644
index 00000000..33e82049
--- /dev/null
+++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/domain/bo/QueryVectorBo.java
@@ -0,0 +1,43 @@
+package org.ruoyi.domain.bo;
+
+
+import lombok.Data;
+
+/**
+ * 查询向量所需参数
+ * @author ageer
+ */
+@Data
+public class QueryVectorBo {
+
+ /**
+ * 查询内容
+ */
+ private String query;
+
+ /**
+ * 知识库kid
+ */
+ private String kid;
+
+ /**
+ * 查询向量返回条数
+ */
+ private Integer maxResults;
+
+ /**
+ * 模型名称
+ */
+ private String modelName;
+
+ /**
+ * 请求key
+ */
+ private String apiKey;
+
+ /**
+ * 请求地址
+ */
+ private String baseUrl;
+
+}
diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/domain/bo/StoreEmbeddingBo.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/domain/bo/StoreEmbeddingBo.java
new file mode 100644
index 00000000..95104037
--- /dev/null
+++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/domain/bo/StoreEmbeddingBo.java
@@ -0,0 +1,49 @@
+package org.ruoyi.domain.bo;
+
+import lombok.Data;
+
+import java.util.List;
+
+/**
+ * 保存向量所需参数
+ * @author ageer
+ */
+@Data
+public class StoreEmbeddingBo {
+
+ /**
+ * 切分文本块列表
+ */
+ private List chunkList;
+
+ /**
+ * 知识库kid
+ */
+ private String kid;
+
+ /**
+ * 文档id
+ */
+ private String docId;
+
+ /**
+ * 知识块id列表
+ */
+ private List fids;
+
+ /**
+ * 模型名称
+ */
+ private String modelName;
+
+ /**
+ * 请求key
+ */
+ private String apiKey;
+
+ /**
+ * 请求地址
+ */
+ private String baseUrl;
+
+}
diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/VectorStoreService.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/VectorStoreService.java
index 6edaa5d3..277d0b11 100644
--- a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/VectorStoreService.java
+++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/VectorStoreService.java
@@ -1,20 +1,23 @@
package org.ruoyi.service;
+import org.ruoyi.domain.bo.QueryVectorBo;
+import org.ruoyi.domain.bo.StoreEmbeddingBo;
+
import java.util.List;
/**
- * @author ageer
* 向量库管理
+ * @author ageer
*/
public interface VectorStoreService {
- void storeEmbeddings(List chunkList, String kid,String docId,List fids);
+ void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo);
void removeByDocId(String kid,String docId);
void removeByKid(String kid);
- List getQueryVector(String query, String kid);
+ List getQueryVector(QueryVectorBo queryVectorBo);
void createSchema(String kid);
diff --git a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/WeaviateVectorStoreImpl.java b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/WeaviateVectorStoreImpl.java
index 9d5b9299..680a1bb6 100644
--- a/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/WeaviateVectorStoreImpl.java
+++ b/ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/WeaviateVectorStoreImpl.java
@@ -3,6 +3,7 @@ package org.ruoyi.service.impl;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
+import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.store.embedding.EmbeddingMatch;
@@ -11,9 +12,12 @@ import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
import dev.langchain4j.store.embedding.weaviate.WeaviateEmbeddingStore;
+import jakarta.annotation.PostConstruct;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.core.service.ConfigService;
+import org.ruoyi.domain.bo.QueryVectorBo;
+import org.ruoyi.domain.bo.StoreEmbeddingBo;
import org.ruoyi.service.VectorStoreService;
import org.springframework.stereotype.Service;
@@ -23,9 +27,11 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
+
+
/**
+ * Weaviate向量库管理
* @author ageer
- * Weaviate 向量库管理
*/
@Service
@Slf4j
@@ -37,38 +43,7 @@ public class WeaviateVectorStoreImpl implements VectorStoreService {
private final ConfigService configService;
@Override
- public List getQueryVector(String query, String kid) {
- EmbeddingModel embeddingModel = OpenAiEmbeddingModel.builder()
- .apiKey("sk-xxx")
- .baseUrl("https://api.pandarobot.chat/v1/")
- .modelName(TEXT_EMBEDDING_3_SMALL)
- .build();
-
- // Filter simpleFilter = new IsEqualTo("kid", kid);
-
- // createSchema(kid);
-
- Embedding queryEmbedding = embeddingModel.embed("聊天补全模型").content();
- EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
- .queryEmbedding(queryEmbedding)
- .maxResults(2)
- // 添加过滤条件
- // .filter(simpleFilter)
- .build();
- List> matches = embeddingStore.search(embeddingSearchRequest).matches();
-
-
-
- List results = new ArrayList<>();
-
- matches.forEach(embeddingMatch -> {
- results.add(embeddingMatch.embedded().text());
- });
-
- return results;
- }
-
- @Override
+ @PostConstruct
public void createSchema(String kid) {
String protocol = configService.getConfigValue("weaviate", "protocol");
String host = configService.getConfigValue("weaviate", "host");
@@ -84,24 +59,42 @@ public class WeaviateVectorStoreImpl implements VectorStoreService {
}
@Override
- public void storeEmbeddings(List chunkList,String kid,String docId,List fids) {
- EmbeddingModel embeddingModel = OpenAiEmbeddingModel.builder()
- .apiKey("sk-xxxx")
- .baseUrl("https://api.pandarobot.chat/v1/")
- .modelName(TEXT_EMBEDDING_3_SMALL)
- .build();
-
- chunkList.forEach(chunk -> {
+ public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
+ EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getModelName(),
+ storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl());
+ for (int i = 0; i < storeEmbeddingBo.getChunkList().size(); i++) {
Map dataSchema = new HashMap<>();
- dataSchema.put("kid", kid);
- dataSchema.put("docId", docId);
- dataSchema.put("fid", fids.get(0));
- Response response = embeddingModel.embed(chunk);
+ dataSchema.put("kid", storeEmbeddingBo.getKid());
+ dataSchema.put("docId", storeEmbeddingBo.getKid());
+ dataSchema.put("fid", storeEmbeddingBo.getFids().get(i));
+ Response response = embeddingModel.embed(storeEmbeddingBo.getChunkList().get(i));
Embedding embedding = response.content();
- TextSegment segment = TextSegment.from(chunk);
+ TextSegment segment = TextSegment.from(storeEmbeddingBo.getChunkList().get(i));
segment.metadata().putAll(dataSchema);
embeddingStore.add(embedding,segment);
+ }
+ }
+
+ @Override
+ public List getQueryVector(QueryVectorBo queryVectorBo) {
+ EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getModelName(),
+ queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl());
+ Filter simpleFilter = new IsEqualTo("kid", queryVectorBo.getKid());
+ Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
+ EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
+ .queryEmbedding(queryEmbedding)
+ .maxResults(queryVectorBo.getMaxResults())
+ // 添加过滤条件
+ .filter(simpleFilter)
+ .build();
+ List> matches = embeddingStore.search(embeddingSearchRequest).matches();
+
+ List results = new ArrayList<>();
+
+ matches.forEach(embeddingMatch -> {
+ results.add(embeddingMatch.embedded().text());
});
+ return results;
}
@@ -128,4 +121,25 @@ public class WeaviateVectorStoreImpl implements VectorStoreService {
embeddingStore.removeAll(simpleFilterByAnd);
}
+ /**
+ * 获取向量模型
+ */
+ public EmbeddingModel getEmbeddingModel(String modelName,String apiKey,String baseUrl) {
+ EmbeddingModel embeddingModel = OpenAiEmbeddingModel.builder().build();
+ if(TEXT_EMBEDDING_3_SMALL.toString().equals(modelName)) {
+ embeddingModel = OpenAiEmbeddingModel.builder()
+ .apiKey(apiKey)
+ .baseUrl(baseUrl)
+ .modelName(TEXT_EMBEDDING_3_SMALL)
+ .build();
+ // TODO 添加枚举
+ }else if("quentinz/bge-large-zh-v1.5".equals(modelName)) {
+ embeddingModel = OllamaEmbeddingModel.builder()
+ .baseUrl(baseUrl)
+ .modelName(modelName)
+ .build();
+ }
+ return embeddingModel;
+ }
+
}
diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/SseServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/SseServiceImpl.java
index 5a86f953..c12ed44e 100644
--- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/SseServiceImpl.java
+++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/SseServiceImpl.java
@@ -25,6 +25,7 @@ import org.ruoyi.common.core.utils.file.FileUtils;
import org.ruoyi.common.core.utils.file.MimeTypeUtils;
import org.ruoyi.common.redis.utils.RedisUtils;
import org.ruoyi.domain.bo.ChatSessionBo;
+import org.ruoyi.domain.bo.QueryVectorBo;
import org.ruoyi.domain.vo.ChatModelVo;
import org.ruoyi.service.VectorStoreService;
import org.ruoyi.service.IChatModelService;
@@ -166,7 +167,10 @@ public class SseServiceImpl implements ISseService {
// 获取对话消息列表
List messages = chatRequest.getMessages();
String sysPrompt = chatModelVo.getSystemPrompt();
+
+
if(StringUtils.isEmpty(sysPrompt)){
+ // TODO 系统默认提示词,后续会增加提示词管理
sysPrompt ="你是一个由RuoYI-AI开发的人工智能助手,名字叫熊猫助手。你擅长中英文对话,能够理解并处理各种问题,提供安全、有帮助、准确的回答。" +
"当前时间:"+ DateUtils.getDate()+
"#注意:回复之前注意结合上下文和工具返回内容进行回复。";
@@ -180,11 +184,20 @@ public class SseServiceImpl implements ISseService {
if(StringUtils.isNotEmpty(chatRequest.getKid())){
List knMessages = new ArrayList<>();
String content = messages.get(messages.size() - 1).getContent().toString();
- List nearestList = vectorStoreService.getQueryVector(content, chatRequest.getKid());
+ QueryVectorBo queryVectorBo = new QueryVectorBo();
+ queryVectorBo.setQuery(content);
+ queryVectorBo.setKid(chatRequest.getKid());
+ queryVectorBo.setApiKey(chatModelVo.getApiKey());
+ queryVectorBo.setBaseUrl(chatModelVo.getApiHost());
+ queryVectorBo.setModelName(chatModelVo.getModelName());
+ // TODO 查询向量返回条数,这里应该查询知识库配置
+ queryVectorBo.setMaxResults(3);
+ List nearestList = vectorStoreService.getQueryVector(queryVectorBo);
for (String prompt : nearestList) {
Message userMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
knMessages.add(userMessage);
}
+ // TODO 提示词,这里应该查询知识库配置
Message userMessage = Message.builder().content(content + (!nearestList.isEmpty() ? "\n\n注意:回答问题时,须严格根据我给你的系统上下文内容原文进行回答,请不要自己发挥,回答时保持原来文本的段落层级" : "")).role(Message.Role.USER).build();
knMessages.add(userMessage);
messages.addAll(knMessages);
diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/knowledge/KnowledgeInfoServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/knowledge/KnowledgeInfoServiceImpl.java
index 259e8a30..6cf62517 100644
--- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/knowledge/KnowledgeInfoServiceImpl.java
+++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/knowledge/KnowledgeInfoServiceImpl.java
@@ -3,6 +3,7 @@ package org.ruoyi.chat.service.knowledge;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.RandomUtil;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
+import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import lombok.RequiredArgsConstructor;
@@ -14,15 +15,19 @@ import org.ruoyi.common.core.utils.StringUtils;
import org.ruoyi.common.satoken.utils.LoginHelper;
import org.ruoyi.core.page.PageQuery;
import org.ruoyi.core.page.TableDataInfo;
+import org.ruoyi.domain.ChatModel;
import org.ruoyi.domain.KnowledgeAttach;
import org.ruoyi.domain.KnowledgeFragment;
import org.ruoyi.domain.KnowledgeInfo;
import org.ruoyi.domain.bo.KnowledgeInfoBo;
import org.ruoyi.domain.bo.KnowledgeInfoUploadBo;
+import org.ruoyi.domain.bo.StoreEmbeddingBo;
+import org.ruoyi.domain.vo.ChatModelVo;
import org.ruoyi.domain.vo.KnowledgeInfoVo;
import org.ruoyi.mapper.KnowledgeAttachMapper;
import org.ruoyi.mapper.KnowledgeFragmentMapper;
import org.ruoyi.mapper.KnowledgeInfoMapper;
+import org.ruoyi.service.IChatModelService;
import org.ruoyi.service.VectorStoreService;
import org.ruoyi.service.IKnowledgeInfoService;
import org.slf4j.Logger;
@@ -55,6 +60,8 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
private final KnowledgeAttachMapper attachMapper;
+ private final IChatModelService chatModelService;
+
/**
* 查询知识库
*/
@@ -219,10 +226,31 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
knowledgeAttach.setContent(content);
knowledgeAttach.setCreateTime(new Date());
attachMapper.insert(knowledgeAttach);
- vectorStoreService.storeEmbeddings(chunkList,kid,docId,fids);
+
+ // 通过kid查询知识库信息
+ KnowledgeInfoVo knowledgeInfoVo = baseMapper.selectVoOne(Wrappers.lambdaQuery()
+ .eq(KnowledgeInfo::getKid, kid));
+
+ // 通过向量模型查询模型信息
+ ChatModelVo chatModelVo = chatModelService.selectModelByName(knowledgeInfoVo.getVectorModel());
+
+ StoreEmbeddingBo storeEmbeddingBo = new StoreEmbeddingBo();
+ storeEmbeddingBo.setKid(kid);
+ storeEmbeddingBo.setDocId(docId);
+ storeEmbeddingBo.setFids(fids);
+ storeEmbeddingBo.setChunkList(chunkList);
+ storeEmbeddingBo.setModelName(knowledgeInfoVo.getVectorModel());
+ storeEmbeddingBo.setApiKey(chatModelVo.getApiKey());
+ storeEmbeddingBo.setBaseUrl(chatModelVo.getApiHost());
+ vectorStoreService.storeEmbeddings(storeEmbeddingBo);
}
+ /**
+ * 检查用户是否有删除知识库权限
+ *
+ * @param knowledgeInfoList 知识库列表
+ */
public void check(List knowledgeInfoList){
LoginUser loginUser = LoginHelper.getLoginUser();
for (KnowledgeInfoVo knowledgeInfoVo : knowledgeInfoList) {