mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-04-23 08:43:40 +00:00
feat: 新增检索测试相关接口
- 实现向量 L2 归一化,统一 Milvus/Qdrant/Weaviate 检索评分为 [0, 1] 空间
This commit is contained in:
@@ -8,6 +8,7 @@ import jakarta.validation.constraints.*;
|
||||
import cn.dev33.satoken.annotation.SaCheckPermission;
|
||||
import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo;
|
||||
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
|
||||
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||
import org.ruoyi.service.knowledge.IKnowledgeFragmentService;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
import org.springframework.validation.annotation.Validated;
|
||||
@@ -102,4 +103,12 @@ public class KnowledgeFragmentController extends BaseController {
|
||||
@PathVariable Long[] ids) {
|
||||
return toAjax(knowledgeFragmentService.deleteWithValidByIds(List.of(ids), true));
|
||||
}
|
||||
|
||||
/**
|
||||
* 检索测试
|
||||
*/
|
||||
@PostMapping("/retrieval")
|
||||
public R<List<KnowledgeRetrievalVo>> retrieval(@RequestBody KnowledgeFragmentBo bo) {
|
||||
return R.ok(knowledgeFragmentService.retrieval(bo));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,5 +49,24 @@ public class KnowledgeFragmentBo extends BaseEntity {
|
||||
*/
|
||||
private String remark;
|
||||
|
||||
/**
|
||||
* 知识库ID
|
||||
*/
|
||||
private Long knowledgeId;
|
||||
|
||||
/**
|
||||
* 检索内容
|
||||
*/
|
||||
private String query;
|
||||
|
||||
/**
|
||||
* 返回条数
|
||||
*/
|
||||
private Integer topK;
|
||||
|
||||
/**
|
||||
* 相似度阈值
|
||||
*/
|
||||
private Double threshold;
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
package org.ruoyi.domain.vo.knowledge;
|
||||
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
|
||||
import java.io.Serial;
|
||||
import java.io.Serializable;
|
||||
|
||||
/**
|
||||
* 知识检索测试结果视图对象
|
||||
*
|
||||
* @author RobustH
|
||||
*/
|
||||
@Data
|
||||
@Builder
|
||||
public class KnowledgeRetrievalVo implements Serializable {
|
||||
|
||||
@Serial
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
/**
|
||||
* 片段内容
|
||||
*/
|
||||
private String content;
|
||||
|
||||
/**
|
||||
* 相似度得分
|
||||
*/
|
||||
private Double score;
|
||||
|
||||
/**
|
||||
* 来源文档名称
|
||||
*/
|
||||
private String sourceName;
|
||||
}
|
||||
@@ -16,6 +16,15 @@ import dev.langchain4j.model.chat.StreamingChatModel;
|
||||
import dev.langchain4j.model.chat.response.ChatResponse;
|
||||
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
|
||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||
import dev.langchain4j.rag.AugmentationRequest;
|
||||
import dev.langchain4j.rag.AugmentationResult;
|
||||
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
|
||||
import dev.langchain4j.rag.RetrievalAugmentor;
|
||||
import dev.langchain4j.rag.content.aggregator.ContentAggregator;
|
||||
import dev.langchain4j.rag.content.aggregator.DefaultContentAggregator;
|
||||
import dev.langchain4j.rag.content.aggregator.ReRankingContentAggregator;
|
||||
import dev.langchain4j.model.scoring.ScoringModel;
|
||||
import dev.langchain4j.rag.query.Metadata;
|
||||
import dev.langchain4j.service.tool.ToolProvider;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.SneakyThrows;
|
||||
@@ -46,6 +55,8 @@ import org.ruoyi.mcp.service.core.ToolProviderFactory;
|
||||
import org.ruoyi.service.chat.AbstractChatService;
|
||||
import org.ruoyi.service.chat.IChatMessageService;
|
||||
import org.ruoyi.service.chat.impl.memory.PersistentChatMemoryStore;
|
||||
import org.ruoyi.service.knowledge.retriever.CustomVectorRetriever;
|
||||
import org.ruoyi.service.knowledge.rerank.ScoringModelFactory;
|
||||
import org.ruoyi.service.knowledge.IKnowledgeInfoService;
|
||||
import org.ruoyi.service.vector.VectorStoreService;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -89,6 +100,8 @@ public class ChatServiceFacade implements IChatService {
|
||||
|
||||
private final ToolProviderFactory toolProviderFactory;
|
||||
|
||||
private final ScoringModelFactory scoringModelFactory;
|
||||
|
||||
/**
|
||||
* 内存实例缓存,避免同一会话重复创建
|
||||
* Key: sessionId, Value: MessageWindowChatMemory实例
|
||||
@@ -119,7 +132,9 @@ public class ChatServiceFacade implements IChatService {
|
||||
// 2. 构建上下文消息列表
|
||||
List<ChatMessage> contextMessages = buildContextMessages(chatRequest);
|
||||
|
||||
// 3. 处理特殊聊天模式(工作流、人机交互恢复、思考模式)
|
||||
// 注意:buildContextMessages() 最后返回的列表中,最新的带有增强知识的 UserMessage 在最后。
|
||||
// 对于有些模型API(非langchain4j的代理),它们可能不识别增强后的复杂文本(取决于供应商适配度)
|
||||
// 但是通过标准流,它被解析为 String。
|
||||
SseEmitter specialResult = handleSpecialChatModes(chatRequest, contextMessages, chatModelVo, emitter);
|
||||
if (specialResult != null) {
|
||||
return specialResult;
|
||||
@@ -346,39 +361,63 @@ public class ChatServiceFacade implements IChatService {
|
||||
* @return 上下文消息列表
|
||||
*/
|
||||
private List<ChatMessage> buildContextMessages(ChatRequest chatRequest) {
|
||||
List<ChatMessage> messages = new ArrayList<>();
|
||||
// 构建用户消息
|
||||
List<ChatMessage> messages = new ArrayList<>();
|
||||
|
||||
// 初始化用户消息
|
||||
UserMessage userMessage = UserMessage.userMessage(chatRequest.getContent());
|
||||
messages.add(userMessage);
|
||||
|
||||
// 从向量库查询相关历史消息
|
||||
// 使用 LangChain4j 的 RetrievalAugmentor 进行检索增强
|
||||
if (chatRequest.getKnowledgeId() != null) {
|
||||
// 查询知识库信息
|
||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(chatRequest.getKnowledgeId()));
|
||||
if (knowledgeInfoVo == null) {
|
||||
log.warn("知识库信息不存在,kid: {}", chatRequest.getKnowledgeId());
|
||||
return messages;
|
||||
}
|
||||
if (knowledgeInfoVo != null) {
|
||||
ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel());
|
||||
if (chatModel != null) {
|
||||
|
||||
// 查询向量模型配置信息
|
||||
ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel());
|
||||
if (chatModel == null) {
|
||||
log.warn("向量模型配置不存在,模型名称: {}", knowledgeInfoVo.getEmbeddingModel());
|
||||
return messages;
|
||||
}
|
||||
// 1. 构建适配器(Retriever)
|
||||
CustomVectorRetriever retriever = new CustomVectorRetriever(
|
||||
vectorStoreService, knowledgeInfoVo, chatModel);
|
||||
|
||||
// 构建向量查询参数
|
||||
QueryVectorBo queryVectorBo = buildQueryVectorBo(chatRequest, knowledgeInfoVo, chatModel);
|
||||
// 2. 获取和构建重排模型聚合器(Aggregator)
|
||||
// 假设已在 KnowledgeInfoVo 等加入 getRerankModelConfig/getRerankModel 等,这里演示通用逻辑
|
||||
// 若无重排需求,使用 DefaultContentAggregator 或无 ScoringModel 的聚合器
|
||||
ContentAggregator contentAggregator;
|
||||
// TODO: 一旦实体类实现了重排模型的支持,此处可以从数据库读出:
|
||||
// ChatModelVo scoringModelConfig = chatModelService.selectModelByName(knowledgeInfoVo.getRerankModel());
|
||||
ChatModelVo scoringModelConfig = null; // 当前暂无对应配置字段
|
||||
|
||||
// 获取向量查询结果
|
||||
List<String> nearestList = vectorStoreService.getQueryVector(queryVectorBo);
|
||||
for (String prompt : nearestList) {
|
||||
// 知识库内容作为系统上下文添加
|
||||
messages.add( new AiMessage(prompt));
|
||||
ScoringModel scoringModel = scoringModelFactory.createScoringModel(scoringModelConfig);
|
||||
if (scoringModel != null) {
|
||||
contentAggregator = ReRankingContentAggregator.builder()
|
||||
.scoringModel(scoringModel)
|
||||
// .maxResults(3) 这个数字将来从配置取
|
||||
.build();
|
||||
} else {
|
||||
contentAggregator = new DefaultContentAggregator();
|
||||
}
|
||||
|
||||
// 3. 构造流水线
|
||||
RetrievalAugmentor augmentor = DefaultRetrievalAugmentor.builder()
|
||||
.contentRetriever(retriever)
|
||||
.contentAggregator(contentAggregator)
|
||||
.build();
|
||||
|
||||
// 4. 执行 Augmentor 增强:将检索到的知识内容编织进 UserMessage 中
|
||||
Metadata ragMetadata = Metadata.from(userMessage, chatRequest.getSessionId(), new ArrayList<>());
|
||||
AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, ragMetadata);
|
||||
|
||||
AugmentationResult augmentedResult = augmentor.augment(augmentationRequest);
|
||||
|
||||
ChatMessage augmentedMessage = augmentedResult.chatMessage();
|
||||
if (augmentedMessage instanceof UserMessage) {
|
||||
userMessage = (UserMessage) augmentedMessage;
|
||||
}
|
||||
log.info("RAG 增强完成: UserMessage 已重构并附加上下文背景。");
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 从数据库查询历史对话消息
|
||||
// 从数据库查询历史对话消息(历史消息应放在当前提问前)
|
||||
if (chatRequest.getSessionId() != null) {
|
||||
MessageWindowChatMemory memory = createChatMemory(chatRequest.getSessionId());
|
||||
if (memory != null) {
|
||||
@@ -390,6 +429,9 @@ public class ChatServiceFacade implements IChatService {
|
||||
}
|
||||
}
|
||||
|
||||
// 注入本次用户提问(经过 RAG 增强后的 UserMessage)
|
||||
messages.add(userMessage);
|
||||
|
||||
return messages;
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import org.ruoyi.common.mybatis.core.page.TableDataInfo;
|
||||
import org.ruoyi.common.mybatis.core.page.PageQuery;
|
||||
import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo;
|
||||
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
|
||||
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
@@ -65,4 +66,12 @@ public interface IKnowledgeFragmentService {
|
||||
* @return 是否删除成功
|
||||
*/
|
||||
Boolean deleteWithValidByIds(Collection<Long> ids, Boolean isValid);
|
||||
|
||||
/**
|
||||
* 检索测试
|
||||
*
|
||||
* @param bo 检索参数
|
||||
* @return 检索结果
|
||||
*/
|
||||
List<KnowledgeRetrievalVo> retrieval(KnowledgeFragmentBo bo);
|
||||
}
|
||||
|
||||
@@ -13,8 +13,17 @@ import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo;
|
||||
import org.ruoyi.domain.entity.knowledge.KnowledgeFragment;
|
||||
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
|
||||
import org.ruoyi.mapper.knowledge.KnowledgeFragmentMapper;
|
||||
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||
import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo;
|
||||
import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo;
|
||||
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||
import org.ruoyi.service.knowledge.IKnowledgeFragmentService;
|
||||
import org.ruoyi.service.knowledge.IKnowledgeInfoService;
|
||||
import org.ruoyi.common.chat.service.chat.IChatModelService;
|
||||
import org.ruoyi.service.vector.VectorStoreService;
|
||||
import org.springframework.stereotype.Service;
|
||||
import java.util.ArrayList;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -32,6 +41,9 @@ import java.util.Collection;
|
||||
public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
|
||||
|
||||
private final KnowledgeFragmentMapper baseMapper;
|
||||
private final IKnowledgeInfoService knowledgeInfoService;
|
||||
private final IChatModelService chatModelService;
|
||||
private final VectorStoreService vectorStoreService;
|
||||
|
||||
/**
|
||||
* 查询知识片段
|
||||
@@ -131,4 +143,45 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
|
||||
}
|
||||
return baseMapper.deleteByIds(ids) > 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* 检索测试核心实现
|
||||
*/
|
||||
@Override
|
||||
public List<KnowledgeRetrievalVo> retrieval(KnowledgeFragmentBo bo) {
|
||||
if (bo.getKnowledgeId() == null || StringUtils.isBlank(bo.getQuery())) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
|
||||
// 1. 获取知识库及模型配置
|
||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(bo.getKnowledgeId());
|
||||
if (knowledgeInfoVo == null) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
|
||||
ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel());
|
||||
if (chatModel == null) {
|
||||
log.warn("未找到对应的向量模型配置: {}", knowledgeInfoVo.getEmbeddingModel());
|
||||
return new ArrayList<>();
|
||||
}
|
||||
|
||||
// 2. 构造向量检索参数
|
||||
QueryVectorBo queryVectorBo = new QueryVectorBo();
|
||||
queryVectorBo.setQuery(bo.getQuery());
|
||||
queryVectorBo.setKid(String.valueOf(bo.getKnowledgeId()));
|
||||
queryVectorBo.setMaxResults(bo.getTopK() != null ? bo.getTopK() : knowledgeInfoVo.getRetrieveLimit());
|
||||
queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel());
|
||||
queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModel());
|
||||
queryVectorBo.setApiKey(chatModel.getApiKey());
|
||||
queryVectorBo.setBaseUrl(chatModel.getApiHost());
|
||||
|
||||
// 3. 执行物理检索
|
||||
List<KnowledgeRetrievalVo> allResults = vectorStoreService.search(queryVectorBo);
|
||||
|
||||
// 4. 根据阈值过滤 (LangChain4j 结果 score 通常 0-1)
|
||||
double threshold = bo.getThreshold() != null ? bo.getThreshold() : 0.0;
|
||||
return allResults.stream()
|
||||
.filter(res -> res.getScore() >= threshold)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
package org.ruoyi.service.knowledge.rerank;
|
||||
|
||||
import dev.langchain4j.model.scoring.ScoringModel;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
/**
|
||||
* 重排模型提供商工厂
|
||||
* 用于将来无缝拓展硅基流动、百炼等支持重排的模型厂商
|
||||
*
|
||||
* @author RobustH
|
||||
*/
|
||||
@Slf4j
|
||||
@Component
|
||||
public class ScoringModelFactory {
|
||||
|
||||
/**
|
||||
* 根据后台传递的模型配置创建具体的重排模型
|
||||
*
|
||||
* @param rerankModelConfig 重排模型的配置 (例如其 providerCode, apiUrl, apiKey 等)
|
||||
* @return 标准的 LangChain4j ScoringModel
|
||||
*/
|
||||
public ScoringModel createScoringModel(ChatModelVo rerankModelConfig) {
|
||||
if (rerankModelConfig == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
String providerCode = rerankModelConfig.getProviderCode();
|
||||
log.info("初始化重排模型,供应商代码: {}", providerCode);
|
||||
|
||||
// TODO: 在这里通过 switch 或反射具体实例化支持的各种 ScoringModel (例如 CohereScoringModel, DascScope 等)
|
||||
// 目前返回 null 代表暂时没有加载特定的重排底座,这不会影响流程,Aggregator 会忽略它返回原样结果
|
||||
|
||||
return null;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
package org.ruoyi.service.knowledge.retriever;
|
||||
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.rag.content.Content;
|
||||
import dev.langchain4j.rag.content.retriever.ContentRetriever;
|
||||
import dev.langchain4j.rag.query.Query;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo;
|
||||
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||
import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo;
|
||||
import org.ruoyi.service.vector.VectorStoreService;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* 自定义向量检索器:适配 LangChain4j ContentRetriever 接口
|
||||
* 桥接现有的 VectorStoreService 获取检索结果
|
||||
*
|
||||
* @author RobustH
|
||||
*/
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
public class CustomVectorRetriever implements ContentRetriever {
|
||||
|
||||
private final VectorStoreService vectorStoreService;
|
||||
private final KnowledgeInfoVo knowledgeInfoVo;
|
||||
private final ChatModelVo chatModelVo;
|
||||
|
||||
@Override
|
||||
public List<Content> retrieve(Query query) {
|
||||
log.info("执行自定义向量检索,关键字: {}", query.text());
|
||||
|
||||
// 构建内部查询参数
|
||||
QueryVectorBo queryVectorBo = new QueryVectorBo();
|
||||
queryVectorBo.setQuery(query.text());
|
||||
queryVectorBo.setKid(String.valueOf(knowledgeInfoVo.getId()));
|
||||
queryVectorBo.setApiKey(chatModelVo.getApiKey());
|
||||
queryVectorBo.setBaseUrl(chatModelVo.getApiHost());
|
||||
queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModel());
|
||||
queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel());
|
||||
// 如果接入了重排,这里的 retrieveLimit 也就是 MaxResults 应当被放大,后续留给 Aggregator 截断
|
||||
queryVectorBo.setMaxResults(knowledgeInfoVo.getRetrieveLimit());
|
||||
|
||||
// 执行底层的多种向量库策略检索
|
||||
List<String> nearestList = vectorStoreService.getQueryVector(queryVectorBo);
|
||||
|
||||
// 将结果包装为标准的 Content 返回
|
||||
return nearestList.stream()
|
||||
.map(text -> Content.from(TextSegment.from(text)))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package org.ruoyi.service.vector;
|
||||
import org.ruoyi.common.core.exception.ServiceException;
|
||||
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
||||
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@@ -17,6 +18,11 @@ public interface VectorStoreService {
|
||||
|
||||
List<String> getQueryVector(QueryVectorBo queryVectorBo);
|
||||
|
||||
/**
|
||||
* 带分数及元数据的检索(用于测试检索功能)
|
||||
*/
|
||||
List<KnowledgeRetrievalVo> search(QueryVectorBo queryVectorBo);
|
||||
|
||||
void createSchema(String kid, String embeddingModelName);
|
||||
|
||||
void removeById(String id, String modelName) throws ServiceException;
|
||||
|
||||
@@ -37,6 +37,24 @@ public abstract class AbstractVectorStoreStrategy implements VectorStoreService
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* 向量 L2 归一化 (单位化)
|
||||
*/
|
||||
protected static float[] normalize(float[] vector) {
|
||||
if (vector == null) return null;
|
||||
double sum = 0;
|
||||
for (float v : vector) {
|
||||
sum += v * v;
|
||||
}
|
||||
float norm = (float) Math.sqrt(sum);
|
||||
if (norm > 1e-9) {
|
||||
for (int i = 0; i < vector.length; i++) {
|
||||
vector[i] /= norm;
|
||||
}
|
||||
}
|
||||
return vector;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取向量模型
|
||||
*/
|
||||
|
||||
@@ -19,7 +19,11 @@ import org.ruoyi.common.chat.service.chat.IChatModelService;
|
||||
import org.ruoyi.config.VectorStoreProperties;
|
||||
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
||||
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||
import org.ruoyi.factory.EmbeddingModelFactory;
|
||||
import org.ruoyi.mapper.knowledge.KnowledgeAttachMapper;
|
||||
import org.ruoyi.domain.entity.knowledge.KnowledgeAttach;
|
||||
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@@ -32,10 +36,14 @@ import java.util.stream.IntStream;
|
||||
@Component
|
||||
public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
|
||||
private final KnowledgeAttachMapper knowledgeAttachMapper;
|
||||
|
||||
public MilvusVectorStoreStrategy(VectorStoreProperties vectorStoreProperties,
|
||||
IChatModelService chatModelService,
|
||||
EmbeddingModelFactory embeddingModelFactory) {
|
||||
EmbeddingModelFactory embeddingModelFactory,
|
||||
KnowledgeAttachMapper knowledgeAttachMapper) {
|
||||
super(vectorStoreProperties, embeddingModelFactory, chatModelService);
|
||||
this.knowledgeAttachMapper = knowledgeAttachMapper;
|
||||
}
|
||||
|
||||
// 缓存不同集合与 autoFlush 配置的 Milvus 连接
|
||||
@@ -51,7 +59,7 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
.collectionName(collectionName)
|
||||
.dimension(dimension)
|
||||
.indexType(IndexType.IVF_FLAT)
|
||||
.metricType(MetricType.L2)
|
||||
.metricType(MetricType.COSINE)
|
||||
.autoFlushOnInsert(autoFlushOnInsert)
|
||||
.idFieldName("id")
|
||||
.textFieldName("text")
|
||||
@@ -104,7 +112,10 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
|
||||
TextSegment textSegment = TextSegment.from(text, metadata);
|
||||
Embedding embedding = embeddingModel.embed(text).content();
|
||||
embeddingStore.add(embedding, textSegment);
|
||||
// 单位化处理
|
||||
float[] vector = embedding.vector();
|
||||
normalize(vector);
|
||||
embeddingStore.add(Embedding.from(vector), textSegment);
|
||||
});
|
||||
long endTime = System.currentTimeMillis();
|
||||
log.info("Milvus向量存储完成消耗时间:{}秒", (endTime - startTime) / 1000);
|
||||
@@ -136,6 +147,55 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
return resultList;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<KnowledgeRetrievalVo> search(QueryVectorBo queryVectorBo) {
|
||||
int dimension = getModelDimension(queryVectorBo.getEmbeddingModelName());
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName());
|
||||
|
||||
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||
// 查询向量单位化处理
|
||||
float[] queryVector = queryEmbedding.vector();
|
||||
normalize(queryVector);
|
||||
|
||||
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + queryVectorBo.getKid();
|
||||
|
||||
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, dimension, true);
|
||||
|
||||
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
|
||||
.queryEmbedding(Embedding.from(queryVector))
|
||||
.maxResults(queryVectorBo.getMaxResults())
|
||||
.build();
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> matches = embeddingStore.search(request).matches();
|
||||
List<KnowledgeRetrievalVo> resultList = new ArrayList<>();
|
||||
|
||||
for (EmbeddingMatch<TextSegment> match : matches) {
|
||||
TextSegment segment = match.embedded();
|
||||
if (segment == null) continue;
|
||||
|
||||
String docId = segment.metadata().getString("docId");
|
||||
String sourceName = "未知来源";
|
||||
if (docId != null) {
|
||||
KnowledgeAttach attach = knowledgeAttachMapper.selectOne(new LambdaQueryWrapper<KnowledgeAttach>()
|
||||
.eq(KnowledgeAttach::getDocId, docId)
|
||||
.last("limit 1"));
|
||||
if (attach != null) {
|
||||
sourceName = attach.getName();
|
||||
}
|
||||
}
|
||||
|
||||
// 提取内容、评分及来源
|
||||
double score = match.score();
|
||||
|
||||
resultList.add(org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo.builder()
|
||||
.content(segment.text())
|
||||
.score(score)
|
||||
.sourceName(sourceName)
|
||||
.build());
|
||||
}
|
||||
return resultList;
|
||||
}
|
||||
|
||||
@Override
|
||||
@SneakyThrows
|
||||
public void removeById(String id, String modelName) {
|
||||
|
||||
@@ -24,7 +24,11 @@ import org.ruoyi.common.core.exception.ServiceException;
|
||||
import org.ruoyi.config.VectorStoreProperties;
|
||||
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
||||
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||
import org.ruoyi.factory.EmbeddingModelFactory;
|
||||
import org.ruoyi.domain.entity.knowledge.KnowledgeAttach;
|
||||
import org.ruoyi.mapper.knowledge.KnowledgeAttachMapper;
|
||||
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import static io.qdrant.client.VectorInputFactory.vectorInput;
|
||||
@@ -47,10 +51,14 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
private static final String METADATA_KID_KEY = "kid";
|
||||
private static final String METADATA_DOC_ID_KEY = "doc_id";
|
||||
|
||||
private final KnowledgeAttachMapper knowledgeAttachMapper;
|
||||
|
||||
public QdrantVectorStoreStrategy(VectorStoreProperties vectorStoreProperties,
|
||||
IChatModelService chatModelService,
|
||||
EmbeddingModelFactory embeddingModelFactory) {
|
||||
EmbeddingModelFactory embeddingModelFactory,
|
||||
KnowledgeAttachMapper knowledgeAttachMapper) {
|
||||
super(vectorStoreProperties, embeddingModelFactory, chatModelService);
|
||||
this.knowledgeAttachMapper = knowledgeAttachMapper;
|
||||
}
|
||||
|
||||
private EmbeddingStore<TextSegment> getQdrantStore(String collectionName) {
|
||||
@@ -129,7 +137,10 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
metadata.put(METADATA_DOC_ID_KEY, docId);
|
||||
TextSegment textSegment = TextSegment.from(text, metadata);
|
||||
Embedding embedding = embeddingModel.embed(text).content();
|
||||
embeddingStore.add(embedding, textSegment);
|
||||
// 单位化处理
|
||||
float[] vector = embedding.vector();
|
||||
normalize(vector);
|
||||
embeddingStore.add(Embedding.from(vector), textSegment);
|
||||
});
|
||||
|
||||
long endTime = System.currentTimeMillis();
|
||||
@@ -140,18 +151,22 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName());
|
||||
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||
// 查询向量单位化处理
|
||||
float[] queryVector = queryEmbedding.vector();
|
||||
normalize(queryVector);
|
||||
|
||||
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + queryVectorBo.getKid();
|
||||
|
||||
List<Float> vector = new ArrayList<>();
|
||||
for (float f : queryEmbedding.vector()) {
|
||||
vector.add(f);
|
||||
List<Float> vectorList = new ArrayList<>();
|
||||
for (float f : queryVector) {
|
||||
vectorList.add(f);
|
||||
}
|
||||
|
||||
try (QdrantClient client = buildQdrantClient()) {
|
||||
QueryPoints request = QueryPoints.newBuilder()
|
||||
.setCollectionName(collectionName)
|
||||
.setQuery(Query.newBuilder()
|
||||
.setNearest(vectorInput(vector))
|
||||
.setNearest(vectorInput(vectorList))
|
||||
.build())
|
||||
.setLimit(queryVectorBo.getMaxResults())
|
||||
.setWithPayload(enable(true))
|
||||
@@ -172,6 +187,69 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<KnowledgeRetrievalVo> search(QueryVectorBo queryVectorBo) {
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName());
|
||||
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||
// 查询向量单位化处理
|
||||
float[] queryVector = queryEmbedding.vector();
|
||||
normalize(queryVector);
|
||||
|
||||
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + queryVectorBo.getKid();
|
||||
|
||||
List<Float> vectorList = new ArrayList<>();
|
||||
for (float f : queryVector) {
|
||||
vectorList.add(f);
|
||||
}
|
||||
|
||||
try (QdrantClient client = buildQdrantClient()) {
|
||||
QueryPoints request = QueryPoints.newBuilder()
|
||||
.setCollectionName(collectionName)
|
||||
.setQuery(Query.newBuilder()
|
||||
.setNearest(vectorInput(vectorList))
|
||||
.build())
|
||||
.setLimit(queryVectorBo.getMaxResults())
|
||||
.setWithPayload(enable(true))
|
||||
.build();
|
||||
|
||||
List<ScoredPoint> results = client.queryAsync(request).get();
|
||||
List<org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo> resultList = new ArrayList<>();
|
||||
for (ScoredPoint point : results) {
|
||||
String content = "";
|
||||
JsonWithInt.Value textValue = point.getPayloadMap().get(TEXT_SEGMENT_KEY);
|
||||
if (textValue != null && textValue.hasStringValue()) {
|
||||
content = textValue.getStringValue();
|
||||
}
|
||||
|
||||
String docId = null;
|
||||
JsonWithInt.Value docIdValue = point.getPayloadMap().get(METADATA_DOC_ID_KEY);
|
||||
if (docIdValue != null && docIdValue.hasStringValue()) {
|
||||
docId = docIdValue.getStringValue();
|
||||
}
|
||||
|
||||
String sourceName = "未知来源";
|
||||
if (docId != null) {
|
||||
KnowledgeAttach attach = knowledgeAttachMapper.selectOne(new LambdaQueryWrapper<KnowledgeAttach>()
|
||||
.eq(KnowledgeAttach::getDocId, docId)
|
||||
.last("limit 1"));
|
||||
if (attach != null) {
|
||||
sourceName = attach.getName();
|
||||
}
|
||||
}
|
||||
|
||||
resultList.add(org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo.builder()
|
||||
.content(content)
|
||||
.score((double) point.getScore())
|
||||
.sourceName(sourceName)
|
||||
.build());
|
||||
}
|
||||
return resultList;
|
||||
} catch (Exception e) {
|
||||
log.error("Qdrant检索失败: {}", collectionName, e);
|
||||
throw new ServiceException("Qdrant向量检索失败");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeById(String id, String modelName) {
|
||||
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + id;
|
||||
|
||||
@@ -4,6 +4,7 @@ import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
||||
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||
import org.ruoyi.factory.VectorStoreStrategyFactory;
|
||||
import org.ruoyi.service.vector.VectorStoreService;
|
||||
import org.springframework.context.annotation.Primary;
|
||||
@@ -54,6 +55,13 @@ public class VectorStoreServiceImpl implements VectorStoreService {
|
||||
return strategy.getQueryVector(queryVectorBo);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<KnowledgeRetrievalVo> search(QueryVectorBo queryVectorBo) {
|
||||
log.info("执行测试搜索: kid={}, query={}", queryVectorBo.getKid(), queryVectorBo.getQuery());
|
||||
VectorStoreService strategy = getCurrentStrategy();
|
||||
return strategy.search(queryVectorBo);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeById(String id, String modelName) {
|
||||
log.info("根据ID删除向量数据: id={}, modelName={}", id, modelName);
|
||||
|
||||
@@ -12,6 +12,7 @@ import org.ruoyi.common.core.exception.ServiceException;
|
||||
import org.ruoyi.config.VectorStoreProperties;
|
||||
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
||||
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||
import org.ruoyi.factory.EmbeddingModelFactory;
|
||||
import org.springframework.stereotype.Component;
|
||||
import io.weaviate.client.Config;
|
||||
@@ -24,6 +25,9 @@ import io.weaviate.client.v1.graphql.model.GraphQLResponse;
|
||||
import io.weaviate.client.v1.schema.model.Property;
|
||||
import io.weaviate.client.v1.schema.model.Schema;
|
||||
import io.weaviate.client.v1.schema.model.WeaviateClass;
|
||||
import org.ruoyi.domain.entity.knowledge.KnowledgeAttach;
|
||||
import org.ruoyi.mapper.knowledge.KnowledgeAttachMapper;
|
||||
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
@@ -40,11 +44,14 @@ import java.util.Map;
|
||||
public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
|
||||
private WeaviateClient client;
|
||||
private final KnowledgeAttachMapper knowledgeAttachMapper;
|
||||
|
||||
public WeaviateVectorStoreStrategy(VectorStoreProperties vectorStoreProperties,
|
||||
IChatModelService chatModelService,
|
||||
EmbeddingModelFactory embeddingModelFactory) {
|
||||
EmbeddingModelFactory embeddingModelFactory,
|
||||
KnowledgeAttachMapper knowledgeAttachMapper) {
|
||||
super(vectorStoreProperties, embeddingModelFactory,chatModelService);
|
||||
this.knowledgeAttachMapper = knowledgeAttachMapper;
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -110,9 +117,12 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
"kid", kid,
|
||||
"docId", docId
|
||||
);
|
||||
Float[] vector = toObjectArray(embedding.vector());
|
||||
float[] vectorArray = embedding.vector();
|
||||
normalize(vectorArray);
|
||||
Float[] vector = toObjectArray(vectorArray);
|
||||
|
||||
client.data().creator()
|
||||
.withClassName("LocalKnowledge" + kid)
|
||||
.withClassName(vectorStoreProperties.getWeaviate().getClassname() + kid)
|
||||
.withProperties(properties)
|
||||
.withVector(vector)
|
||||
.run();
|
||||
@@ -128,6 +138,9 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName());
|
||||
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||
float[] vector = queryEmbedding.vector();
|
||||
// 查询向量单位化处理
|
||||
normalize(vector);
|
||||
|
||||
List<String> vectorStrings = new ArrayList<>();
|
||||
for (float v : vector) {
|
||||
vectorStrings.add(String.valueOf(v));
|
||||
@@ -178,6 +191,77 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<KnowledgeRetrievalVo> search(QueryVectorBo queryVectorBo) {
|
||||
createSchema(queryVectorBo.getKid(), queryVectorBo.getEmbeddingModelName());
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName());
|
||||
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||
float[] vector = queryEmbedding.vector();
|
||||
// 查询向量单位化处理
|
||||
normalize(vector);
|
||||
List<String> vectorStrings = new ArrayList<>();
|
||||
for (float v : vector) {
|
||||
vectorStrings.add(String.valueOf(v));
|
||||
}
|
||||
String vectorStr = String.join(",", vectorStrings);
|
||||
String className = vectorStoreProperties.getWeaviate().getClassname();
|
||||
|
||||
String graphQLQuery = String.format(
|
||||
"{\n" +
|
||||
" Get {\n" +
|
||||
" %s(nearVector: {vector: [%s]} limit: %d) {\n" +
|
||||
" text\n" +
|
||||
" docId\n" +
|
||||
" _additional {\n" +
|
||||
" distance\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
"}",
|
||||
className + queryVectorBo.getKid(),
|
||||
vectorStr,
|
||||
queryVectorBo.getMaxResults()
|
||||
);
|
||||
|
||||
Result<GraphQLResponse> result = client.graphQL().raw().withQuery(graphQLQuery).run();
|
||||
List<org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo> resultList = new ArrayList<>();
|
||||
|
||||
if (result != null && !result.hasErrors()) {
|
||||
Object data = result.getResult().getData();
|
||||
JSONObject entries = new JSONObject(data);
|
||||
Map<String, cn.hutool.json.JSONArray> entriesMap = entries.get("Get", Map.class);
|
||||
cn.hutool.json.JSONArray objects = entriesMap.get(className + queryVectorBo.getKid());
|
||||
|
||||
for (Object obj : objects) {
|
||||
Map<String, Object> map = (Map<String, Object>) obj;
|
||||
String content = (String) map.get("text");
|
||||
String docId = (String) map.get("docId");
|
||||
|
||||
Map<String, Object> additional = (Map<String, Object>) map.get("_additional");
|
||||
Double distance = Double.valueOf(String.valueOf(additional.get("distance")));
|
||||
// 转换距离为得分 (Weaviate 0 是最相近,1 是最远;余弦距离下 1-dist 即为相似度)
|
||||
double score = 1.0 - distance;
|
||||
|
||||
String sourceName = "未知来源";
|
||||
if (docId != null) {
|
||||
KnowledgeAttach attach = knowledgeAttachMapper.selectOne(new LambdaQueryWrapper<KnowledgeAttach>()
|
||||
.eq(KnowledgeAttach::getDocId, docId)
|
||||
.last("limit 1"));
|
||||
if (attach != null) {
|
||||
sourceName = attach.getName();
|
||||
}
|
||||
}
|
||||
|
||||
resultList.add(org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo.builder()
|
||||
.content(content)
|
||||
.score(score)
|
||||
.sourceName(sourceName)
|
||||
.build());
|
||||
}
|
||||
}
|
||||
return resultList;
|
||||
}
|
||||
|
||||
@Override
|
||||
@SneakyThrows
|
||||
public void removeById(String id, String modelName) {
|
||||
|
||||
Reference in New Issue
Block a user