Merge remote-tracking branch 'origin/main'

# Conflicts:
#	ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/knowledge/KnowledgeInfoServiceImpl.java
This commit is contained in:
zhouweiyi
2025-05-14 16:13:07 +08:00
11 changed files with 349 additions and 393 deletions

View File

@@ -34,6 +34,23 @@
<a href="https://github.com/ageerle/ruoyi-ai/issues">提出新特性</a>
</p>
## 快速启动
1. **克隆项目**
```bash
git clone https://github.com/alanpeng/ruoyi-ai-docker-deploy
cd ruoyi-ai-docker-deploy
```
2. **启动全套应用**
```bash
docker-compose up -d
```
3. **访问应用界面**
- 用户界面:`http://your-server-ip:8081`
- 管理员界面:`http://your-server-ip:8082`
## 目录
- [系统体验](#系统体验)

View File

@@ -1,5 +1,6 @@
package org.ruoyi.domain;
import com.alibaba.excel.annotation.ExcelProperty;
import com.baomidou.mybatisplus.annotation.*;
import lombok.Data;
import lombok.EqualsAndHashCode;
@@ -78,14 +79,19 @@ public class KnowledgeInfo extends BaseEntity {
private Long textBlockSize;
/**
* 向量库
* 向量库模型名称
*/
private String vector;
private String vectorModelName;
/**
* 向量模型
* 向量模型名称
*/
private String vectorModel;
private String embeddingModelName;
/**
* 系统提示词
*/
private String systemPrompt;
/**
* 备注

View File

@@ -83,16 +83,22 @@ public class KnowledgeInfoBo extends BaseEntity {
private Long textBlockSize;
/**
* 向量库
* 向量库模型名称
*/
@NotBlank(message = "向量库不能为空", groups = { AddGroup.class, EditGroup.class })
private String vector;
private String vectorModelName;
/**
* 向量模型
* 向量模型名称
*/
@NotBlank(message = "向量模型不能为空", groups = { AddGroup.class, EditGroup.class })
private String vectorModel;
private String embeddingModelName;
/**
* 系统提示词
*/
private String systemPrompt;
/**
* 备注

View File

@@ -26,9 +26,14 @@ public class QueryVectorBo {
private Integer maxResults;
/**
* 模型名称
* 向量库模型名称
*/
private String modelName;
private String vectorModelName;
/**
* 向量化模型名称
*/
private String embeddingModelName;
/**
* 请求key

View File

@@ -32,9 +32,14 @@ public class StoreEmbeddingBo {
private List<String> fids;
/**
* 模型名称
* 向量库模型名称
*/
private String modelName;
private String vectorModelName;
/**
* 向量化模型名称
*/
private String embeddingModelName;
/**
* 请求key

View File

@@ -98,16 +98,20 @@ public class KnowledgeInfoVo implements Serializable {
private Integer textBlockSize;
/**
* 向量库
* 向量库模型名称
*/
@ExcelProperty(value = "向量库")
private String vector;
private String vectorModelName;
/**
* 向量模型
* 向量模型名称
*/
@ExcelProperty(value = "向量模型")
private String vectorModel;
private String embeddingModelName;
/**
* 系统提示词
*/
private String systemPrompt;
/**
* 备注

View File

@@ -13,14 +13,14 @@ public interface VectorStoreService {
void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo);
void removeByDocId(String kid,String docId);
void removeByKid(String kid);
List<String> getQueryVector(QueryVectorBo queryVectorBo);
void createSchema(String kid,String modelName);
void removeByKidAndFid(String kid, String fid);
void removeByKid(String kid,String modelName);
void removeByDocId(String kid,String docId,String modelName);
void removeByKidAndFid(String kid, String fid,String modelName);
}

View File

@@ -1,5 +1,7 @@
package org.ruoyi.service.impl;
import cn.hutool.core.util.RandomUtil;
import com.google.protobuf.ServiceException;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
@@ -16,6 +18,7 @@ import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore;
import dev.langchain4j.store.embedding.qdrant.QdrantEmbeddingStore;
import dev.langchain4j.store.embedding.weaviate.WeaviateEmbeddingStore;
import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.core.service.ConfigService;
import org.ruoyi.domain.bo.QueryVectorBo;
@@ -40,11 +43,10 @@ public class VectorStoreServiceImpl implements VectorStoreService {
private final ConfigService configService;
Map<String,EmbeddingStore<TextSegment>> storeMap = new HashMap<>();
private EmbeddingStore<TextSegment> embeddingStore;
@Override
public void createSchema(String kid,String modelName) {
EmbeddingStore<TextSegment> embeddingStore;
switch (modelName) {
case "weaviate" -> {
String protocol = configService.getConfigValue("weaviate", "protocol");
@@ -84,88 +86,83 @@ public class VectorStoreServiceImpl implements VectorStoreService {
embeddingStore = new InMemoryEmbeddingStore<>();
}
}
storeMap.put(kid,embeddingStore);
}
@Override
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
EmbeddingStore<TextSegment> store = storeMap.get(storeEmbeddingBo.getKid());
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getModelName(),
createSchema(storeEmbeddingBo.getKid(),storeEmbeddingBo.getVectorModelName());
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(),
storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl());
for (int i = 0; i < storeEmbeddingBo.getChunkList().size(); i++) {
List<String> chunkList = storeEmbeddingBo.getChunkList();
for (int i = 0; i < chunkList.size(); i++) {
Map<String, Object> dataSchema = new HashMap<>();
dataSchema.put("kid", storeEmbeddingBo.getKid());
dataSchema.put("docId", storeEmbeddingBo.getKid());
dataSchema.put("fid", storeEmbeddingBo.getFids().get(i));
Response<Embedding> response = embeddingModel.embed(storeEmbeddingBo.getChunkList().get(i));
Embedding embedding = response.content();
TextSegment segment = TextSegment.from(storeEmbeddingBo.getChunkList().get(i));
Embedding embedding = embeddingModel.embed(chunkList.get(i)).content();
TextSegment segment = TextSegment.from(chunkList.get(i));
segment.metadata().putAll(dataSchema);
store.add(embedding,segment);
embeddingStore.add(embedding,segment);
}
}
@Override
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
EmbeddingStore<TextSegment> store = storeMap.get(queryVectorBo.getKid());
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getModelName(),
createSchema(queryVectorBo.getKid(),queryVectorBo.getVectorModelName());
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(),
queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl());
Filter simpleFilter = new IsEqualTo("kid", queryVectorBo.getKid());
// Filter simpleFilter = new IsEqualTo("kid", queryVectorBo.getKid());
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
.queryEmbedding(queryEmbedding)
.maxResults(queryVectorBo.getMaxResults())
// 添加过滤条件
.filter(simpleFilter)
// .filter(simpleFilter)
.build();
List<EmbeddingMatch<TextSegment>> matches = store.search(embeddingSearchRequest).matches();
List<EmbeddingMatch<TextSegment>> matches = embeddingStore.search(embeddingSearchRequest).matches();
List<String> results = new ArrayList<>();
matches.forEach(embeddingMatch -> results.add(embeddingMatch.embedded().text()));
return results;
}
@Override
public void removeByKid(String kid) {
EmbeddingStore<TextSegment> store = storeMap.get(kid);
public void removeByKid(String kid,String modelName) {
createSchema(kid,modelName);
// 根据条件删除向量数据
Filter simpleFilter = new IsEqualTo("kid", kid);
store.removeAll(simpleFilter);
embeddingStore.removeAll(simpleFilter);
}
@Override
public void removeByDocId(String kid, String docId) {
EmbeddingStore<TextSegment> store = storeMap.get(kid);
public void removeByDocId(String kid, String docId,String modelName) {
createSchema(kid,modelName);
// 根据条件删除向量数据
Filter simpleFilterByDocId = new IsEqualTo("docId", docId);
store.removeAll(simpleFilterByDocId);
embeddingStore.removeAll(simpleFilterByDocId);
}
@Override
public void removeByKidAndFid(String kid, String fid) {
EmbeddingStore<TextSegment> store = storeMap.get(kid);
public void removeByKidAndFid(String kid, String fid,String modelName) {
createSchema(kid,modelName);
// 根据条件删除向量数据
Filter simpleFilterByKid = new IsEqualTo("kid", kid);
Filter simpleFilterFid = new IsEqualTo("fid", fid);
Filter simpleFilterByAnd = Filter.and(simpleFilterFid, simpleFilterByKid);
store.removeAll(simpleFilterByAnd);
embeddingStore.removeAll(simpleFilterByAnd);
}
/**
* 获取向量模型
*/
public EmbeddingModel getEmbeddingModel(String modelName,String apiKey,String baseUrl) {
EmbeddingModel embeddingModel = OpenAiEmbeddingModel.builder().build();
@SneakyThrows
public EmbeddingModel getEmbeddingModel(String modelName, String apiKey, String baseUrl) {
EmbeddingModel embeddingModel;
if(TEXT_EMBEDDING_3_SMALL.toString().equals(modelName)) {
embeddingModel = OpenAiEmbeddingModel.builder()
.apiKey(apiKey)
.baseUrl(baseUrl)
.modelName(TEXT_EMBEDDING_3_SMALL)
.modelName(modelName)
.build();
// TODO 添加枚举
}else if("quentinz/bge-large-zh-v1.5".equals(modelName)) {
@@ -173,6 +170,14 @@ public class VectorStoreServiceImpl implements VectorStoreService {
.baseUrl(baseUrl)
.modelName(modelName)
.build();
}else if("baai/bge-m3".equals(modelName)) {
embeddingModel = OpenAiEmbeddingModel.builder()
.apiKey(apiKey)
.baseUrl(baseUrl)
.modelName(modelName)
.build();
}else {
throw new ServiceException("未找到对应向量化模型!");
}
return embeddingModel;
}

View File

@@ -2,6 +2,7 @@ package org.ruoyi.chat.service.chat.impl;
import cn.dev33.satoken.stp.StpUtil;
import cn.hutool.core.collection.CollectionUtil;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
import com.google.protobuf.ServiceException;
import jakarta.servlet.http.HttpServletRequest;
import lombok.RequiredArgsConstructor;
@@ -29,6 +30,8 @@ import org.ruoyi.common.redis.utils.RedisUtils;
import org.ruoyi.domain.bo.ChatSessionBo;
import org.ruoyi.domain.bo.QueryVectorBo;
import org.ruoyi.domain.vo.ChatModelVo;
import org.ruoyi.domain.vo.KnowledgeInfoVo;
import org.ruoyi.service.IKnowledgeInfoService;
import org.ruoyi.service.VectorStoreService;
import org.ruoyi.service.IChatModelService;
import org.ruoyi.service.IChatSessionService;
@@ -67,6 +70,8 @@ public class SseServiceImpl implements ISseService {
private final IChatSessionService chatSessionService;
private final IKnowledgeInfoService knowledgeInfoService;
private ChatModelVo chatModelVo;
@@ -148,50 +153,61 @@ public class SseServiceImpl implements ISseService {
}
}
/**
* 构建消息列表
*/
private void buildChatMessageList(ChatRequest chatRequest){
chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
String sysPrompt;
chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
// 获取对话消息列表
List<Message> messages = chatRequest.getMessages();
String sysPrompt = chatModelVo.getSystemPrompt();
// 查询向量库相关信息加入到上下文
if(StringUtils.isNotEmpty(chatRequest.getKid())){
List<Message> knMessages = new ArrayList<>();
String content = messages.get(messages.size() - 1).getContent().toString();
// 通过kid查询知识库信息
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(chatRequest.getKid()));
// 查询向量模型配置信息
ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModelName());
if(StringUtils.isEmpty(sysPrompt)){
// TODO 系统默认提示词,后续会增加提示词管理
sysPrompt ="你是一个由RuoYI-AI开发的人工智能助手名字叫熊猫助手。你擅长中英文对话能够理解并处理各种问题提供安全、有帮助、准确的回答。" +
"当前时间:"+ DateUtils.getDate()+
"#注意:回复之前注意结合上下文和工具返回内容进行回复。";
QueryVectorBo queryVectorBo = new QueryVectorBo();
queryVectorBo.setQuery(content);
queryVectorBo.setKid(chatRequest.getKid());
queryVectorBo.setApiKey(chatModel.getApiKey());
queryVectorBo.setBaseUrl(chatModel.getApiHost());
queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModelName());
queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModelName());
queryVectorBo.setMaxResults(knowledgeInfoVo.getRetrieveLimit());
List<String> nearestList = vectorStoreService.getQueryVector(queryVectorBo);
for (String prompt : nearestList) {
Message userMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
knMessages.add(userMessage);
}
messages.addAll(knMessages);
// 设置知识库系统提示词
sysPrompt = knowledgeInfoVo.getSystemPrompt();
if(StringUtils.isEmpty(sysPrompt)){
sysPrompt ="###角色设定\n" +
"你是一个智能知识助手,专注于利用上下文中的信息来提供准确和相关的回答。\n" +
"###指令\n" +
"当用户的问题与上下文知识匹配时,利用上下文信息进行回答。如果问题与上下文不匹配,运用自身的推理能力生成合适的回答。\n" +
"###限制\n" +
"确保回答清晰简洁,避免提供不必要的细节。始终保持语气友好" +
"当前时间:"+ DateUtils.getDate();
}
}else {
sysPrompt = chatModelVo.getSystemPrompt();
if(StringUtils.isEmpty(sysPrompt)){
sysPrompt ="你是一个由RuoYI-AI开发的人工智能助手名字叫熊猫助手。你擅长中英文对话能够理解并处理各种问题提供安全、有帮助、准确的回答。" +
"当前时间:"+ DateUtils.getDate()+
"#注意:回复之前注意结合上下文和工具返回内容进行回复。";
}
}
// 设置系统默认提示词
Message sysMessage = Message.builder().content(sysPrompt).role(Message.Role.SYSTEM).build();
messages.add(0,sysMessage);
chatRequest.setSysPrompt(sysPrompt);
// 查询向量库相关信息加入到上下文
if(StringUtils.isNotEmpty(chatRequest.getKid())){
List<Message> knMessages = new ArrayList<>();
String content = messages.get(messages.size() - 1).getContent().toString();
QueryVectorBo queryVectorBo = new QueryVectorBo();
queryVectorBo.setQuery(content);
queryVectorBo.setKid(chatRequest.getKid());
queryVectorBo.setApiKey(chatModelVo.getApiKey());
queryVectorBo.setBaseUrl(chatModelVo.getApiHost());
queryVectorBo.setModelName(chatModelVo.getModelName());
// TODO 查询向量返回条数,这里应该查询知识库配置
queryVectorBo.setMaxResults(3);
List<String> nearestList = vectorStoreService.getQueryVector(queryVectorBo);
for (String prompt : nearestList) {
Message userMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
knMessages.add(userMessage);
}
// TODO 提示词,这里应该查询知识库配置
Message userMessage = Message.builder().content(content + (!nearestList.isEmpty() ? "\n\n注意回答问题时须严格根据我给你的系统上下文内容原文进行回答请不要自己发挥,回答时保持原来文本的段落层级" : "")).role(Message.Role.USER).build();
knMessages.add(userMessage);
messages.addAll(knMessages);
}
// 用户对话内容
String chatString = null;
// 获取用户对话信息

View File

@@ -1,14 +1,11 @@
package org.ruoyi.chat.service.knowledge;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.RandomUtil;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import java.util.stream.Collectors;
import lombok.RequiredArgsConstructor;
import org.ruoyi.chain.loader.ResourceLoader;
import org.ruoyi.chain.loader.ResourceLoaderFactory;
@@ -16,8 +13,6 @@ import org.ruoyi.common.core.domain.model.LoginUser;
import org.ruoyi.common.core.utils.MapstructUtils;
import org.ruoyi.common.core.utils.StringUtils;
import org.ruoyi.common.satoken.utils.LoginHelper;
import org.ruoyi.constant.DealStatus;
import org.ruoyi.constant.FileType;
import org.ruoyi.core.page.PageQuery;
import org.ruoyi.core.page.TableDataInfo;
import org.ruoyi.domain.ChatModel;
@@ -35,15 +30,11 @@ import org.ruoyi.mapper.KnowledgeInfoMapper;
import org.ruoyi.service.IChatModelService;
import org.ruoyi.service.VectorStoreService;
import org.ruoyi.service.IKnowledgeInfoService;
import org.ruoyi.system.domain.vo.SysOssVo;
import org.ruoyi.system.service.ISysOssService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.scheduling.annotation.Async;
import java.io.IOException;
import java.util.*;
@@ -58,321 +49,216 @@ import java.util.*;
@Service
public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
private static final Logger log = LoggerFactory.getLogger(KnowledgeInfoServiceImpl.class);
private final KnowledgeInfoMapper baseMapper;
private static final Logger log = LoggerFactory.getLogger(KnowledgeInfoServiceImpl.class);
private final KnowledgeInfoMapper baseMapper;
private final VectorStoreService vectorStoreService;
private final VectorStoreService vectorStoreService;
private final ResourceLoaderFactory resourceLoaderFactory;
private final ResourceLoaderFactory resourceLoaderFactory;
private final KnowledgeFragmentMapper fragmentMapper;
private final KnowledgeFragmentMapper fragmentMapper;
private final KnowledgeAttachMapper attachMapper;
private final KnowledgeAttachMapper attachMapper;
private final IChatModelService chatModelService;
private final IChatModelService chatModelService;
private final ISysOssService ossService;
/**
* 查询知识库
*/
@Override
public KnowledgeInfoVo queryById(Long id) {
return baseMapper.selectVoById(id);
}
/**
* 查询知识库列表
*/
@Override
public TableDataInfo<KnowledgeInfoVo> queryPageList(KnowledgeInfoBo bo, PageQuery pageQuery) {
LambdaQueryWrapper<KnowledgeInfo> lqw = buildQueryWrapper(bo);
Page<KnowledgeInfoVo> result = baseMapper.selectVoPage(pageQuery.build(), lqw);
return TableDataInfo.build(result);
}
/**
* 查询知识库列表
*/
@Override
public List<KnowledgeInfoVo> queryList(KnowledgeInfoBo bo) {
LambdaQueryWrapper<KnowledgeInfo> lqw = buildQueryWrapper(bo);
return baseMapper.selectVoList(lqw);
}
private LambdaQueryWrapper<KnowledgeInfo> buildQueryWrapper(KnowledgeInfoBo bo) {
Map<String, Object> params = bo.getParams();
LambdaQueryWrapper<KnowledgeInfo> lqw = Wrappers.lambdaQuery();
lqw.eq(StringUtils.isNotBlank(bo.getKid()), KnowledgeInfo::getKid, bo.getKid());
lqw.eq(bo.getUid() != null, KnowledgeInfo::getUid, bo.getUid());
lqw.like(StringUtils.isNotBlank(bo.getKname()), KnowledgeInfo::getKname, bo.getKname());
lqw.eq(bo.getShare() != null, KnowledgeInfo::getShare, bo.getShare());
lqw.eq(StringUtils.isNotBlank(bo.getDescription()), KnowledgeInfo::getDescription,
bo.getDescription());
lqw.eq(StringUtils.isNotBlank(bo.getKnowledgeSeparator()), KnowledgeInfo::getKnowledgeSeparator,
bo.getKnowledgeSeparator());
lqw.eq(StringUtils.isNotBlank(bo.getQuestionSeparator()), KnowledgeInfo::getQuestionSeparator,
bo.getQuestionSeparator());
lqw.eq(bo.getOverlapChar() != null, KnowledgeInfo::getOverlapChar, bo.getOverlapChar());
lqw.eq(bo.getRetrieveLimit() != null, KnowledgeInfo::getRetrieveLimit, bo.getRetrieveLimit());
lqw.eq(bo.getTextBlockSize() != null, KnowledgeInfo::getTextBlockSize, bo.getTextBlockSize());
lqw.eq(StringUtils.isNotBlank(bo.getVector()), KnowledgeInfo::getVector, bo.getVector());
lqw.eq(StringUtils.isNotBlank(bo.getVectorModel()), KnowledgeInfo::getVectorModel,
bo.getVectorModel());
return lqw;
}
/**
* 新增知识库
*/
@Override
public Boolean insertByBo(KnowledgeInfoBo bo) {
KnowledgeInfo add = MapstructUtils.convert(bo, KnowledgeInfo.class);
validEntityBeforeSave(add);
boolean flag = baseMapper.insert(add) > 0;
if (flag) {
bo.setId(add.getId());
}
return flag;
}
/**
* 修改知识库
*/
@Override
public Boolean updateByBo(KnowledgeInfoBo bo) {
KnowledgeInfo update = MapstructUtils.convert(bo, KnowledgeInfo.class);
validEntityBeforeSave(update);
return baseMapper.updateById(update) > 0;
}
/**
* 保存前的数据校验
*/
private void validEntityBeforeSave(KnowledgeInfo entity) {
//TODO 做一些数据校验,如唯一约束
}
/**
* 批量删除知识库
*/
@Override
public Boolean deleteWithValidByIds(Collection<Long> ids, Boolean isValid) {
if (isValid) {
//TODO 做一些业务上的校验,判断是否需要校验
}
return baseMapper.deleteBatchIds(ids) > 0;
}
@Override
@Transactional(rollbackFor = Exception.class)
public void saveOne(KnowledgeInfoBo bo) {
KnowledgeInfo knowledgeInfo = MapstructUtils.convert(bo, KnowledgeInfo.class);
if (StringUtils.isBlank(bo.getKid())) {
String kid = RandomUtil.randomString(10);
if (knowledgeInfo != null) {
knowledgeInfo.setKid(kid);
knowledgeInfo.setUid(LoginHelper.getLoginUser().getUserId());
}
baseMapper.insert(knowledgeInfo);
if (knowledgeInfo != null) {
vectorStoreService.createSchema(String.valueOf(knowledgeInfo.getId()), bo.getVector());
}
} else {
baseMapper.updateById(knowledgeInfo);
}
}
@Override
@Transactional(rollbackFor = Exception.class)
public void removeKnowledge(String id) {
Map<String, Object> map = new HashMap<>();
map.put("kid", id);
List<KnowledgeInfoVo> knowledgeInfoList = baseMapper.selectVoByMap(map);
check(knowledgeInfoList);
// 删除向量库信息
knowledgeInfoList.forEach(knowledgeInfoVo -> {
vectorStoreService.removeByKid(String.valueOf(knowledgeInfoVo.getId()));
});
// 删除附件和知识片段
fragmentMapper.deleteByMap(map);
attachMapper.deleteByMap(map);
// 删除知识库
baseMapper.deleteByMap(map);
}
@Override
public void upload(KnowledgeInfoUploadBo bo) {
storeContent(bo.getFile(), bo.getKid());
}
public void storeContent(MultipartFile file, String kid) {
if (file == null || file.isEmpty()) {
throw new IllegalArgumentException("File cannot be null or empty");
/**
* 查询知识库
*/
@Override
public KnowledgeInfoVo queryById(Long id){
return baseMapper.selectVoById(id);
}
SysOssVo uploadDto = null;
/**
* 查询知识库列表
*/
@Override
public TableDataInfo<KnowledgeInfoVo> queryPageList(KnowledgeInfoBo bo, PageQuery pageQuery) {
LambdaQueryWrapper<KnowledgeInfo> lqw = buildQueryWrapper(bo);
Page<KnowledgeInfoVo> result = baseMapper.selectVoPage(pageQuery.build(), lqw);
return TableDataInfo.build(result);
}
String fileName = file.getOriginalFilename();
List<String> chunkList = new ArrayList<>();
KnowledgeAttach knowledgeAttach = new KnowledgeAttach();
knowledgeAttach.setKid(kid);
String docId = RandomUtil.randomString(10);
knowledgeAttach.setDocId(docId);
knowledgeAttach.setDocName(fileName);
knowledgeAttach.setDocType(fileName.substring(fileName.lastIndexOf(".") + 1));
String content = "";
ResourceLoader resourceLoader = resourceLoaderFactory.getLoaderByFileType(
knowledgeAttach.getDocType());
List<String> fids = new ArrayList<>();
try {
content = resourceLoader.getContent(file.getInputStream());
chunkList = resourceLoader.getChunkList(content, kid);
List<KnowledgeFragment> knowledgeFragmentList = new ArrayList<>();
if (CollUtil.isNotEmpty(chunkList)) {
// Upload file to OSS
uploadDto = ossService.upload(file);
/**
* 查询知识库列表
*/
@Override
public List<KnowledgeInfoVo> queryList(KnowledgeInfoBo bo) {
LambdaQueryWrapper<KnowledgeInfo> lqw = buildQueryWrapper(bo);
return baseMapper.selectVoList(lqw);
}
for (int i = 0; i < chunkList.size(); i++) {
String fid = RandomUtil.randomString(10);
fids.add(fid);
KnowledgeFragment knowledgeFragment = new KnowledgeFragment();
knowledgeFragment.setKid(kid);
knowledgeFragment.setDocId(docId);
knowledgeFragment.setFid(fid);
knowledgeFragment.setIdx(i);
knowledgeFragment.setContent(chunkList.get(i));
knowledgeFragment.setCreateTime(new Date());
knowledgeFragmentList.add(knowledgeFragment);
private LambdaQueryWrapper<KnowledgeInfo> buildQueryWrapper(KnowledgeInfoBo bo) {
Map<String, Object> params = bo.getParams();
LambdaQueryWrapper<KnowledgeInfo> lqw = Wrappers.lambdaQuery();
lqw.eq(StringUtils.isNotBlank(bo.getKid()), KnowledgeInfo::getKid, bo.getKid());
lqw.eq(bo.getUid() != null, KnowledgeInfo::getUid, bo.getUid());
lqw.like(StringUtils.isNotBlank(bo.getKname()), KnowledgeInfo::getKname, bo.getKname());
lqw.eq(bo.getShare() != null, KnowledgeInfo::getShare, bo.getShare());
lqw.eq(StringUtils.isNotBlank(bo.getDescription()), KnowledgeInfo::getDescription, bo.getDescription());
lqw.eq(StringUtils.isNotBlank(bo.getKnowledgeSeparator()), KnowledgeInfo::getKnowledgeSeparator, bo.getKnowledgeSeparator());
lqw.eq(StringUtils.isNotBlank(bo.getQuestionSeparator()), KnowledgeInfo::getQuestionSeparator, bo.getQuestionSeparator());
lqw.eq(bo.getOverlapChar() != null, KnowledgeInfo::getOverlapChar, bo.getOverlapChar());
lqw.eq(bo.getRetrieveLimit() != null, KnowledgeInfo::getRetrieveLimit, bo.getRetrieveLimit());
lqw.eq(bo.getTextBlockSize() != null, KnowledgeInfo::getTextBlockSize, bo.getTextBlockSize());
return lqw;
}
/**
* 新增知识库
*/
@Override
public Boolean insertByBo(KnowledgeInfoBo bo) {
KnowledgeInfo add = MapstructUtils.convert(bo, KnowledgeInfo.class);
validEntityBeforeSave(add);
boolean flag = baseMapper.insert(add) > 0;
if (flag) {
bo.setId(add.getId());
}
}
fragmentMapper.insertBatch(knowledgeFragmentList);
} catch (IOException e) {
log.error("保存知识库信息失败!{}", e.getMessage());
}
knowledgeAttach.setContent(content);
knowledgeAttach.setCreateTime(new Date());
if (ObjectUtil.isNotEmpty(uploadDto) && ObjectUtil.isNotEmpty(uploadDto.getOssId())) {
knowledgeAttach.setOssId(uploadDto.getOssId());
//只有pdf文件 才需要拆解图片和分析图片内容
if (FileType.PDF.equals(knowledgeAttach.getDocType())) {
knowledgeAttach.setPicStatus(DealStatus.STATUS_10);
knowledgeAttach.setPicAnysStatus(DealStatus.STATUS_10);
} else {
knowledgeAttach.setPicStatus(DealStatus.STATUS_30);
knowledgeAttach.setPicAnysStatus(DealStatus.STATUS_30);
}
//所有文件上传后,都需要同步到向量数据库
knowledgeAttach.setVectorStatus(DealStatus.STATUS_10);
return flag;
}
attachMapper.insert(knowledgeAttach);
}
/**
* 检查用户是否有删除知识库权限
*
* @param knowledgeInfoList 知识库列表
*/
public void check(List<KnowledgeInfoVo> knowledgeInfoList) {
LoginUser loginUser = LoginHelper.getLoginUser();
for (KnowledgeInfoVo knowledgeInfoVo : knowledgeInfoList) {
if (!knowledgeInfoVo.getUid().equals(loginUser.getUserId())) {
throw new SecurityException("权限不足");
}
/**
* 修改知识库
*/
@Override
public Boolean updateByBo(KnowledgeInfoBo bo) {
KnowledgeInfo update = MapstructUtils.convert(bo, KnowledgeInfo.class);
validEntityBeforeSave(update);
return baseMapper.updateById(update) > 0;
}
}
/**
* 定时 处理 附件上传后上传向量数据库和PDF文件图片拆解和分析内容
*/
@Scheduled(fixedDelay = 3000) // 每3秒执行一次
public void dealKnowledgeAttach() throws Exception {
//处理 需要上传向量数据库的记录
List<KnowledgeAttach> knowledgeAttaches = attachMapper.selectList(
new LambdaQueryWrapper<KnowledgeAttach>()
.eq(KnowledgeAttach::getPicStatus, DealStatus.STATUS_30)
.eq(KnowledgeAttach::getPicAnysStatus, DealStatus.STATUS_30)
.eq(KnowledgeAttach::getVectorStatus, DealStatus.STATUS_10)
);
if (ObjectUtil.isNotEmpty(knowledgeAttaches)) {
for (KnowledgeAttach attachItem : knowledgeAttaches) {
this.dealVectorStatus(attachItem);
}
/**
* 保存前的数据校验
*/
private void validEntityBeforeSave(KnowledgeInfo entity){
//TODO 做一些数据校验,如唯一约束
}
}
@Async
public void dealVectorStatus(KnowledgeAttach attachItem) throws Exception {
try {
//锁定数据 更改VectorStatus 到进行中
if (attachMapper.update(new LambdaUpdateWrapper<KnowledgeAttach>()
.set(KnowledgeAttach::getVectorStatus, DealStatus.STATUS_20)
.eq(KnowledgeAttach::getPicStatus, DealStatus.STATUS_30)
.eq(KnowledgeAttach::getPicAnysStatus, DealStatus.STATUS_30)
.eq(KnowledgeAttach::getVectorStatus, DealStatus.STATUS_10)
.eq(KnowledgeAttach::getId, attachItem.getId())
) == 0) {
return;
}
// 通过kid查询知识库信息
KnowledgeInfoVo knowledgeInfoVo = baseMapper.selectVoOne(Wrappers.<KnowledgeInfo>lambdaQuery()
.eq(KnowledgeInfo::getKid, attachItem.getKid()));
// 通过向量模型查询模型信息
ChatModelVo chatModelVo = chatModelService.selectModelByName(
knowledgeInfoVo.getVectorModel());
List<KnowledgeFragment> knowledgeFragments = fragmentMapper.selectList(
new LambdaQueryWrapper<KnowledgeFragment>()
.eq(KnowledgeFragment::getKid, attachItem.getKid())
.eq(KnowledgeFragment::getDocId, attachItem.getDocId())
);
if (ObjectUtil.isEmpty(knowledgeFragments)) {
throw new Exception("文件段落为空");
}
List<String> fids = knowledgeFragments.stream()
.map(KnowledgeFragment::getFid)
.collect(Collectors.toList());
if (ObjectUtil.isEmpty(fids)) {
throw new Exception("fids 为空");
}
List<String> chunkList = knowledgeFragments.stream()
.map(KnowledgeFragment::getContent)
.collect(Collectors.toList());
if (ObjectUtil.isEmpty(chunkList)) {
throw new Exception("chunkList 为空");
}
StoreEmbeddingBo storeEmbeddingBo = new StoreEmbeddingBo();
storeEmbeddingBo.setKid(attachItem.getKid());
storeEmbeddingBo.setDocId(attachItem.getDocId());
storeEmbeddingBo.setFids(fids);
storeEmbeddingBo.setChunkList(chunkList);
storeEmbeddingBo.setModelName(knowledgeInfoVo.getVectorModel());
storeEmbeddingBo.setApiKey(chatModelVo.getApiKey());
storeEmbeddingBo.setBaseUrl(chatModelVo.getApiHost());
vectorStoreService.storeEmbeddings(storeEmbeddingBo);
//设置处理完成
attachMapper.update(new LambdaUpdateWrapper<KnowledgeAttach>()
.set(KnowledgeAttach::getVectorStatus, DealStatus.STATUS_30)
.eq(KnowledgeAttach::getPicStatus, DealStatus.STATUS_30)
.eq(KnowledgeAttach::getPicAnysStatus, DealStatus.STATUS_30)
.eq(KnowledgeAttach::getVectorStatus, DealStatus.STATUS_20)
.eq(KnowledgeAttach::getId, attachItem.getId()));
} catch (Exception e) {
//设置处理失败
attachMapper.update(new LambdaUpdateWrapper<KnowledgeAttach>()
.set(KnowledgeAttach::getVectorStatus, DealStatus.STATUS_10)
.eq(KnowledgeAttach::getPicStatus, DealStatus.STATUS_30)
.eq(KnowledgeAttach::getPicAnysStatus, DealStatus.STATUS_30)
.eq(KnowledgeAttach::getVectorStatus, DealStatus.STATUS_20)
.eq(KnowledgeAttach::getId, attachItem.getId()));
throw new RuntimeException(e);
/**
* 批量删除知识库
*/
@Override
public Boolean deleteWithValidByIds(Collection<Long> ids, Boolean isValid) {
if(isValid){
//TODO 做一些业务上的校验,判断是否需要校验
}
return baseMapper.deleteBatchIds(ids) > 0;
}
@Override
@Transactional(rollbackFor = Exception.class)
public void saveOne(KnowledgeInfoBo bo) {
KnowledgeInfo knowledgeInfo = MapstructUtils.convert(bo, KnowledgeInfo.class);
if (StringUtils.isBlank(bo.getKid())){
String kid = RandomUtil.randomString(10);
if (knowledgeInfo != null) {
knowledgeInfo.setKid(kid);
knowledgeInfo.setUid(LoginHelper.getLoginUser().getUserId());
}
baseMapper.insert(knowledgeInfo);
if (knowledgeInfo != null) {
vectorStoreService.createSchema(String.valueOf(knowledgeInfo.getId()),bo.getVectorModelName());
}
}else {
baseMapper.updateById(knowledgeInfo);
}
}
@Override
@Transactional(rollbackFor = Exception.class)
public void removeKnowledge(String id) {
Map<String,Object> map = new HashMap<>();
map.put("kid",id);
List<KnowledgeInfoVo> knowledgeInfoList = baseMapper.selectVoByMap(map);
check(knowledgeInfoList);
// 删除向量库信息
knowledgeInfoList.forEach(knowledgeInfoVo -> {
vectorStoreService.removeByKid(String.valueOf(knowledgeInfoVo.getId()),knowledgeInfoVo.getVectorModelName());
});
// 删除附件和知识片段
fragmentMapper.deleteByMap(map);
attachMapper.deleteByMap(map);
// 删除知识库
baseMapper.deleteByMap(map);
}
@Override
public void upload(KnowledgeInfoUploadBo bo) {
storeContent(bo.getFile(), bo.getKid());
}
public void storeContent(MultipartFile file, String kid) {
String fileName = file.getOriginalFilename();
List<String> chunkList = new ArrayList<>();
KnowledgeAttach knowledgeAttach = new KnowledgeAttach();
knowledgeAttach.setKid(kid);
String docId = RandomUtil.randomString(10);
knowledgeAttach.setDocId(docId);
knowledgeAttach.setDocName(fileName);
knowledgeAttach.setDocType(fileName.substring(fileName.lastIndexOf(".")+1));
String content = "";
ResourceLoader resourceLoader = resourceLoaderFactory.getLoaderByFileType(knowledgeAttach.getDocType());
List<String> fids = new ArrayList<>();
try {
content = resourceLoader.getContent(file.getInputStream());
chunkList = resourceLoader.getChunkList(content, kid);
List<KnowledgeFragment> knowledgeFragmentList = new ArrayList<>();
if (CollUtil.isNotEmpty(chunkList)) {
for (int i = 0; i < chunkList.size(); i++) {
String fid = RandomUtil.randomString(10);
fids.add(fid);
KnowledgeFragment knowledgeFragment = new KnowledgeFragment();
knowledgeFragment.setKid(kid);
knowledgeFragment.setDocId(docId);
knowledgeFragment.setFid(fid);
knowledgeFragment.setIdx(i);
knowledgeFragment.setContent(chunkList.get(i));
knowledgeFragment.setCreateTime(new Date());
knowledgeFragmentList.add(knowledgeFragment);
}
}
fragmentMapper.insertBatch(knowledgeFragmentList);
} catch (IOException e) {
log.error("保存知识库信息失败!{}", e.getMessage());
}
knowledgeAttach.setContent(content);
knowledgeAttach.setCreateTime(new Date());
attachMapper.insert(knowledgeAttach);
// 通过kid查询知识库信息
KnowledgeInfoVo knowledgeInfoVo = baseMapper.selectVoOne(Wrappers.<KnowledgeInfo>lambdaQuery()
.eq(KnowledgeInfo::getId, kid));
// 通过向量模型查询模型信息
ChatModelVo chatModelVo = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModelName());
StoreEmbeddingBo storeEmbeddingBo = new StoreEmbeddingBo();
storeEmbeddingBo.setKid(kid);
storeEmbeddingBo.setDocId(docId);
storeEmbeddingBo.setFids(fids);
storeEmbeddingBo.setChunkList(chunkList);
storeEmbeddingBo.setVectorModelName(knowledgeInfoVo.getVectorModelName());
storeEmbeddingBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModelName());
storeEmbeddingBo.setApiKey(chatModelVo.getApiKey());
storeEmbeddingBo.setBaseUrl(chatModelVo.getApiHost());
vectorStoreService.storeEmbeddings(storeEmbeddingBo);
}
/**
* 检查用户是否有删除知识库权限
*
* @param knowledgeInfoList 知识库列表
*/
public void check(List<KnowledgeInfoVo> knowledgeInfoList){
LoginUser loginUser = LoginHelper.getLoginUser();
for (KnowledgeInfoVo knowledgeInfoVo : knowledgeInfoList) {
if(!knowledgeInfoVo.getUid().equals(loginUser.getUserId())){
throw new SecurityException("权限不足");
}
}
}
}
}

View File

@@ -0,0 +1,6 @@
-- Migration for `knowledge_info`:
--   1) adds the `system_prompt` column;
--   2) renames `vector` -> `vector_model_name` and `vector_model` -> `embedding_model_name`.
-- Statement order matters: step 1 positions the new column relative to
-- `vector_model`, which only exists under that name until step 2 renames it.

-- 1) Add the system prompt column.
-- NOTE(review): varchar(255) may be too short for long prompts - confirm the intended limit.
ALTER TABLE `knowledge_info`
ADD COLUMN `system_prompt` varchar(255) NULL COMMENT '系统提示词' AFTER `vector_model`;

-- 2) Rename the vector-store and embedding-model columns.
ALTER TABLE `knowledge_info`
CHANGE COLUMN `vector` `vector_model_name` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NULL DEFAULT NULL COMMENT '向量库' AFTER `text_block_size`,
CHANGE COLUMN `vector_model` `embedding_model_name` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NULL DEFAULT NULL COMMENT '向量模型' AFTER `vector_model_name`;