mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-04-23 16:53:38 +00:00
feat: 新增检索测试相关接口
- 实现向量 L2 归一化,统一 Milvus/Qdrant/Weaviate 检索评分为 [0, 1] 空间
This commit is contained in:
@@ -8,6 +8,7 @@ import jakarta.validation.constraints.*;
|
|||||||
import cn.dev33.satoken.annotation.SaCheckPermission;
|
import cn.dev33.satoken.annotation.SaCheckPermission;
|
||||||
import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo;
|
import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo;
|
||||||
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
|
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
|
||||||
|
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||||
import org.ruoyi.service.knowledge.IKnowledgeFragmentService;
|
import org.ruoyi.service.knowledge.IKnowledgeFragmentService;
|
||||||
import org.springframework.web.bind.annotation.*;
|
import org.springframework.web.bind.annotation.*;
|
||||||
import org.springframework.validation.annotation.Validated;
|
import org.springframework.validation.annotation.Validated;
|
||||||
@@ -102,4 +103,12 @@ public class KnowledgeFragmentController extends BaseController {
|
|||||||
@PathVariable Long[] ids) {
|
@PathVariable Long[] ids) {
|
||||||
return toAjax(knowledgeFragmentService.deleteWithValidByIds(List.of(ids), true));
|
return toAjax(knowledgeFragmentService.deleteWithValidByIds(List.of(ids), true));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检索测试
|
||||||
|
*/
|
||||||
|
@PostMapping("/retrieval")
|
||||||
|
public R<List<KnowledgeRetrievalVo>> retrieval(@RequestBody KnowledgeFragmentBo bo) {
|
||||||
|
return R.ok(knowledgeFragmentService.retrieval(bo));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,5 +49,24 @@ public class KnowledgeFragmentBo extends BaseEntity {
|
|||||||
*/
|
*/
|
||||||
private String remark;
|
private String remark;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Knowledge base ID
|
||||||
|
*/
|
||||||
|
private Long knowledgeId;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Query text to retrieve against
|
||||||
|
*/
|
||||||
|
private String query;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Number of results to return
|
||||||
|
*/
|
||||||
|
private Integer topK;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Similarity threshold
|
||||||
|
*/
|
||||||
|
private Double threshold;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,35 @@
|
|||||||
|
package org.ruoyi.domain.vo.knowledge;
|
||||||
|
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.io.Serial;
|
||||||
|
import java.io.Serializable;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 知识检索测试结果视图对象
|
||||||
|
*
|
||||||
|
* @author RobustH
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
public class KnowledgeRetrievalVo implements Serializable {
|
||||||
|
|
||||||
|
@Serial
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 片段内容
|
||||||
|
*/
|
||||||
|
private String content;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 相似度得分
|
||||||
|
*/
|
||||||
|
private Double score;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 来源文档名称
|
||||||
|
*/
|
||||||
|
private String sourceName;
|
||||||
|
}
|
||||||
@@ -16,6 +16,15 @@ import dev.langchain4j.model.chat.StreamingChatModel;
|
|||||||
import dev.langchain4j.model.chat.response.ChatResponse;
|
import dev.langchain4j.model.chat.response.ChatResponse;
|
||||||
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
|
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
|
||||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||||
|
import dev.langchain4j.rag.AugmentationRequest;
|
||||||
|
import dev.langchain4j.rag.AugmentationResult;
|
||||||
|
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
|
||||||
|
import dev.langchain4j.rag.RetrievalAugmentor;
|
||||||
|
import dev.langchain4j.rag.content.aggregator.ContentAggregator;
|
||||||
|
import dev.langchain4j.rag.content.aggregator.DefaultContentAggregator;
|
||||||
|
import dev.langchain4j.rag.content.aggregator.ReRankingContentAggregator;
|
||||||
|
import dev.langchain4j.model.scoring.ScoringModel;
|
||||||
|
import dev.langchain4j.rag.query.Metadata;
|
||||||
import dev.langchain4j.service.tool.ToolProvider;
|
import dev.langchain4j.service.tool.ToolProvider;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.SneakyThrows;
|
import lombok.SneakyThrows;
|
||||||
@@ -46,6 +55,8 @@ import org.ruoyi.mcp.service.core.ToolProviderFactory;
|
|||||||
import org.ruoyi.service.chat.AbstractChatService;
|
import org.ruoyi.service.chat.AbstractChatService;
|
||||||
import org.ruoyi.service.chat.IChatMessageService;
|
import org.ruoyi.service.chat.IChatMessageService;
|
||||||
import org.ruoyi.service.chat.impl.memory.PersistentChatMemoryStore;
|
import org.ruoyi.service.chat.impl.memory.PersistentChatMemoryStore;
|
||||||
|
import org.ruoyi.service.knowledge.retriever.CustomVectorRetriever;
|
||||||
|
import org.ruoyi.service.knowledge.rerank.ScoringModelFactory;
|
||||||
import org.ruoyi.service.knowledge.IKnowledgeInfoService;
|
import org.ruoyi.service.knowledge.IKnowledgeInfoService;
|
||||||
import org.ruoyi.service.vector.VectorStoreService;
|
import org.ruoyi.service.vector.VectorStoreService;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
@@ -89,6 +100,8 @@ public class ChatServiceFacade implements IChatService {
|
|||||||
|
|
||||||
private final ToolProviderFactory toolProviderFactory;
|
private final ToolProviderFactory toolProviderFactory;
|
||||||
|
|
||||||
|
private final ScoringModelFactory scoringModelFactory;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 内存实例缓存,避免同一会话重复创建
|
* 内存实例缓存,避免同一会话重复创建
|
||||||
* Key: sessionId, Value: MessageWindowChatMemory实例
|
* Key: sessionId, Value: MessageWindowChatMemory实例
|
||||||
@@ -119,7 +132,9 @@ public class ChatServiceFacade implements IChatService {
|
|||||||
// 2. 构建上下文消息列表
|
// 2. 构建上下文消息列表
|
||||||
List<ChatMessage> contextMessages = buildContextMessages(chatRequest);
|
List<ChatMessage> contextMessages = buildContextMessages(chatRequest);
|
||||||
|
|
||||||
// 3. 处理特殊聊天模式(工作流、人机交互恢复、思考模式)
|
// 注意:buildContextMessages() 最后返回的列表中,最新的带有增强知识的 UserMessage 在最后。
|
||||||
|
// 对于有些模型API(非langchain4j的代理),它们可能不识别增强后的复杂文本(取决于供应商适配度)
|
||||||
|
// 但是通过标准流,它被解析为 String。
|
||||||
SseEmitter specialResult = handleSpecialChatModes(chatRequest, contextMessages, chatModelVo, emitter);
|
SseEmitter specialResult = handleSpecialChatModes(chatRequest, contextMessages, chatModelVo, emitter);
|
||||||
if (specialResult != null) {
|
if (specialResult != null) {
|
||||||
return specialResult;
|
return specialResult;
|
||||||
@@ -347,38 +362,62 @@ public class ChatServiceFacade implements IChatService {
|
|||||||
*/
|
*/
|
||||||
private List<ChatMessage> buildContextMessages(ChatRequest chatRequest) {
|
private List<ChatMessage> buildContextMessages(ChatRequest chatRequest) {
|
||||||
List<ChatMessage> messages = new ArrayList<>();
|
List<ChatMessage> messages = new ArrayList<>();
|
||||||
// 构建用户消息
|
|
||||||
|
// 初始化用户消息
|
||||||
UserMessage userMessage = UserMessage.userMessage(chatRequest.getContent());
|
UserMessage userMessage = UserMessage.userMessage(chatRequest.getContent());
|
||||||
messages.add(userMessage);
|
|
||||||
|
|
||||||
// 从向量库查询相关历史消息
|
// 使用 LangChain4j 的 RetrievalAugmentor 进行检索增强
|
||||||
if (chatRequest.getKnowledgeId() != null) {
|
if (chatRequest.getKnowledgeId() != null) {
|
||||||
// 查询知识库信息
|
|
||||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(chatRequest.getKnowledgeId()));
|
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(chatRequest.getKnowledgeId()));
|
||||||
if (knowledgeInfoVo == null) {
|
if (knowledgeInfoVo != null) {
|
||||||
log.warn("知识库信息不存在,kid: {}", chatRequest.getKnowledgeId());
|
|
||||||
return messages;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查询向量模型配置信息
|
|
||||||
ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel());
|
ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel());
|
||||||
if (chatModel == null) {
|
if (chatModel != null) {
|
||||||
log.warn("向量模型配置不存在,模型名称: {}", knowledgeInfoVo.getEmbeddingModel());
|
|
||||||
return messages;
|
// 1. 构建适配器(Retriever)
|
||||||
|
CustomVectorRetriever retriever = new CustomVectorRetriever(
|
||||||
|
vectorStoreService, knowledgeInfoVo, chatModel);
|
||||||
|
|
||||||
|
// 2. 获取和构建重排模型聚合器(Aggregator)
|
||||||
|
// 假设已在 KnowledgeInfoVo 等加入 getRerankModelConfig/getRerankModel 等,这里演示通用逻辑
|
||||||
|
// 若无重排需求,使用 DefaultContentAggregator 或无 ScoringModel 的聚合器
|
||||||
|
ContentAggregator contentAggregator;
|
||||||
|
// TODO: 一旦实体类实现了重排模型的支持,此处可以从数据库读出:
|
||||||
|
// ChatModelVo scoringModelConfig = chatModelService.selectModelByName(knowledgeInfoVo.getRerankModel());
|
||||||
|
ChatModelVo scoringModelConfig = null; // 当前暂无对应配置字段
|
||||||
|
|
||||||
|
ScoringModel scoringModel = scoringModelFactory.createScoringModel(scoringModelConfig);
|
||||||
|
if (scoringModel != null) {
|
||||||
|
contentAggregator = ReRankingContentAggregator.builder()
|
||||||
|
.scoringModel(scoringModel)
|
||||||
|
// .maxResults(3) 这个数字将来从配置取
|
||||||
|
.build();
|
||||||
|
} else {
|
||||||
|
contentAggregator = new DefaultContentAggregator();
|
||||||
}
|
}
|
||||||
|
|
||||||
// 构建向量查询参数
|
// 3. 构造流水线
|
||||||
QueryVectorBo queryVectorBo = buildQueryVectorBo(chatRequest, knowledgeInfoVo, chatModel);
|
RetrievalAugmentor augmentor = DefaultRetrievalAugmentor.builder()
|
||||||
|
.contentRetriever(retriever)
|
||||||
|
.contentAggregator(contentAggregator)
|
||||||
|
.build();
|
||||||
|
|
||||||
// 获取向量查询结果
|
// 4. 执行 Augmentor 增强:将检索到的知识内容编织进 UserMessage 中
|
||||||
List<String> nearestList = vectorStoreService.getQueryVector(queryVectorBo);
|
Metadata ragMetadata = Metadata.from(userMessage, chatRequest.getSessionId(), new ArrayList<>());
|
||||||
for (String prompt : nearestList) {
|
AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, ragMetadata);
|
||||||
// 知识库内容作为系统上下文添加
|
|
||||||
messages.add( new AiMessage(prompt));
|
AugmentationResult augmentedResult = augmentor.augment(augmentationRequest);
|
||||||
|
|
||||||
|
ChatMessage augmentedMessage = augmentedResult.chatMessage();
|
||||||
|
if (augmentedMessage instanceof UserMessage) {
|
||||||
|
userMessage = (UserMessage) augmentedMessage;
|
||||||
|
}
|
||||||
|
log.info("RAG 增强完成: UserMessage 已重构并附加上下文背景。");
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 从数据库查询历史对话消息
|
// 从数据库查询历史对话消息(历史消息应放在当前提问前)
|
||||||
if (chatRequest.getSessionId() != null) {
|
if (chatRequest.getSessionId() != null) {
|
||||||
MessageWindowChatMemory memory = createChatMemory(chatRequest.getSessionId());
|
MessageWindowChatMemory memory = createChatMemory(chatRequest.getSessionId());
|
||||||
if (memory != null) {
|
if (memory != null) {
|
||||||
@@ -390,6 +429,9 @@ public class ChatServiceFacade implements IChatService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 注入本次用户提问(经过 RAG 增强后的 UserMessage)
|
||||||
|
messages.add(userMessage);
|
||||||
|
|
||||||
return messages;
|
return messages;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import org.ruoyi.common.mybatis.core.page.TableDataInfo;
|
|||||||
import org.ruoyi.common.mybatis.core.page.PageQuery;
|
import org.ruoyi.common.mybatis.core.page.PageQuery;
|
||||||
import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo;
|
import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo;
|
||||||
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
|
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
|
||||||
|
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||||
|
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -65,4 +66,12 @@ public interface IKnowledgeFragmentService {
|
|||||||
* @return 是否删除成功
|
* @return 是否删除成功
|
||||||
*/
|
*/
|
||||||
Boolean deleteWithValidByIds(Collection<Long> ids, Boolean isValid);
|
Boolean deleteWithValidByIds(Collection<Long> ids, Boolean isValid);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Retrieval test
|
||||||
|
*
|
||||||
|
* @param bo retrieval parameters
|
||||||
|
* @return retrieval results
|
||||||
|
*/
|
||||||
|
List<KnowledgeRetrievalVo> retrieval(KnowledgeFragmentBo bo);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,8 +13,17 @@ import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo;
|
|||||||
import org.ruoyi.domain.entity.knowledge.KnowledgeFragment;
|
import org.ruoyi.domain.entity.knowledge.KnowledgeFragment;
|
||||||
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
|
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
|
||||||
import org.ruoyi.mapper.knowledge.KnowledgeFragmentMapper;
|
import org.ruoyi.mapper.knowledge.KnowledgeFragmentMapper;
|
||||||
|
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||||
|
import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo;
|
||||||
|
import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo;
|
||||||
|
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||||
import org.ruoyi.service.knowledge.IKnowledgeFragmentService;
|
import org.ruoyi.service.knowledge.IKnowledgeFragmentService;
|
||||||
|
import org.ruoyi.service.knowledge.IKnowledgeInfoService;
|
||||||
|
import org.ruoyi.common.chat.service.chat.IChatModelService;
|
||||||
|
import org.ruoyi.service.vector.VectorStoreService;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
@@ -32,6 +41,9 @@ import java.util.Collection;
|
|||||||
public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
|
public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
|
||||||
|
|
||||||
private final KnowledgeFragmentMapper baseMapper;
|
private final KnowledgeFragmentMapper baseMapper;
|
||||||
|
private final IKnowledgeInfoService knowledgeInfoService;
|
||||||
|
private final IChatModelService chatModelService;
|
||||||
|
private final VectorStoreService vectorStoreService;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 查询知识片段
|
* 查询知识片段
|
||||||
@@ -131,4 +143,45 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
|
|||||||
}
|
}
|
||||||
return baseMapper.deleteByIds(ids) > 0;
|
return baseMapper.deleteByIds(ids) > 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检索测试核心实现
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public List<KnowledgeRetrievalVo> retrieval(KnowledgeFragmentBo bo) {
|
||||||
|
if (bo.getKnowledgeId() == null || StringUtils.isBlank(bo.getQuery())) {
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. 获取知识库及模型配置
|
||||||
|
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(bo.getKnowledgeId());
|
||||||
|
if (knowledgeInfoVo == null) {
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
|
||||||
|
ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel());
|
||||||
|
if (chatModel == null) {
|
||||||
|
log.warn("未找到对应的向量模型配置: {}", knowledgeInfoVo.getEmbeddingModel());
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 构造向量检索参数
|
||||||
|
QueryVectorBo queryVectorBo = new QueryVectorBo();
|
||||||
|
queryVectorBo.setQuery(bo.getQuery());
|
||||||
|
queryVectorBo.setKid(String.valueOf(bo.getKnowledgeId()));
|
||||||
|
queryVectorBo.setMaxResults(bo.getTopK() != null ? bo.getTopK() : knowledgeInfoVo.getRetrieveLimit());
|
||||||
|
queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel());
|
||||||
|
queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModel());
|
||||||
|
queryVectorBo.setApiKey(chatModel.getApiKey());
|
||||||
|
queryVectorBo.setBaseUrl(chatModel.getApiHost());
|
||||||
|
|
||||||
|
// 3. 执行物理检索
|
||||||
|
List<KnowledgeRetrievalVo> allResults = vectorStoreService.search(queryVectorBo);
|
||||||
|
|
||||||
|
// 4. 根据阈值过滤 (LangChain4j 结果 score 通常 0-1)
|
||||||
|
double threshold = bo.getThreshold() != null ? bo.getThreshold() : 0.0;
|
||||||
|
return allResults.stream()
|
||||||
|
.filter(res -> res.getScore() >= threshold)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,37 @@
|
|||||||
|
package org.ruoyi.service.knowledge.rerank;
|
||||||
|
|
||||||
|
import dev.langchain4j.model.scoring.ScoringModel;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 重排模型提供商工厂
|
||||||
|
* 用于将来无缝拓展硅基流动、百炼等支持重排的模型厂商
|
||||||
|
*
|
||||||
|
* @author RobustH
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
@Component
|
||||||
|
public class ScoringModelFactory {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 根据后台传递的模型配置创建具体的重排模型
|
||||||
|
*
|
||||||
|
* @param rerankModelConfig 重排模型的配置 (例如其 providerCode, apiUrl, apiKey 等)
|
||||||
|
* @return 标准的 LangChain4j ScoringModel
|
||||||
|
*/
|
||||||
|
public ScoringModel createScoringModel(ChatModelVo rerankModelConfig) {
|
||||||
|
if (rerankModelConfig == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
String providerCode = rerankModelConfig.getProviderCode();
|
||||||
|
log.info("初始化重排模型,供应商代码: {}", providerCode);
|
||||||
|
|
||||||
|
// TODO: 在这里通过 switch 或反射具体实例化支持的各种 ScoringModel (例如 CohereScoringModel, DascScope 等)
|
||||||
|
// 目前返回 null 代表暂时没有加载特定的重排底座,这不会影响流程,Aggregator 会忽略它返回原样结果
|
||||||
|
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
package org.ruoyi.service.knowledge.retriever;
|
||||||
|
|
||||||
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
|
import dev.langchain4j.rag.content.Content;
|
||||||
|
import dev.langchain4j.rag.content.retriever.ContentRetriever;
|
||||||
|
import dev.langchain4j.rag.query.Query;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo;
|
||||||
|
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||||
|
import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo;
|
||||||
|
import org.ruoyi.service.vector.VectorStoreService;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 自定义向量检索器:适配 LangChain4j ContentRetriever 接口
|
||||||
|
* 桥接现有的 VectorStoreService 获取检索结果
|
||||||
|
*
|
||||||
|
* @author RobustH
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class CustomVectorRetriever implements ContentRetriever {
|
||||||
|
|
||||||
|
private final VectorStoreService vectorStoreService;
|
||||||
|
private final KnowledgeInfoVo knowledgeInfoVo;
|
||||||
|
private final ChatModelVo chatModelVo;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Content> retrieve(Query query) {
|
||||||
|
log.info("执行自定义向量检索,关键字: {}", query.text());
|
||||||
|
|
||||||
|
// 构建内部查询参数
|
||||||
|
QueryVectorBo queryVectorBo = new QueryVectorBo();
|
||||||
|
queryVectorBo.setQuery(query.text());
|
||||||
|
queryVectorBo.setKid(String.valueOf(knowledgeInfoVo.getId()));
|
||||||
|
queryVectorBo.setApiKey(chatModelVo.getApiKey());
|
||||||
|
queryVectorBo.setBaseUrl(chatModelVo.getApiHost());
|
||||||
|
queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModel());
|
||||||
|
queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel());
|
||||||
|
// 如果接入了重排,这里的 retrieveLimit 也就是 MaxResults 应当被放大,后续留给 Aggregator 截断
|
||||||
|
queryVectorBo.setMaxResults(knowledgeInfoVo.getRetrieveLimit());
|
||||||
|
|
||||||
|
// 执行底层的多种向量库策略检索
|
||||||
|
List<String> nearestList = vectorStoreService.getQueryVector(queryVectorBo);
|
||||||
|
|
||||||
|
// 将结果包装为标准的 Content 返回
|
||||||
|
return nearestList.stream()
|
||||||
|
.map(text -> Content.from(TextSegment.from(text)))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@ package org.ruoyi.service.vector;
|
|||||||
import org.ruoyi.common.core.exception.ServiceException;
|
import org.ruoyi.common.core.exception.ServiceException;
|
||||||
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||||
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
||||||
|
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@@ -17,6 +18,11 @@ public interface VectorStoreService {
|
|||||||
|
|
||||||
List<String> getQueryVector(QueryVectorBo queryVectorBo);
|
List<String> getQueryVector(QueryVectorBo queryVectorBo);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Retrieval that also returns scores and metadata (used by the retrieval-test feature)
|
||||||
|
*/
|
||||||
|
List<KnowledgeRetrievalVo> search(QueryVectorBo queryVectorBo);
|
||||||
|
|
||||||
void createSchema(String kid, String embeddingModelName);
|
void createSchema(String kid, String embeddingModelName);
|
||||||
|
|
||||||
void removeById(String id, String modelName) throws ServiceException;
|
void removeById(String id, String modelName) throws ServiceException;
|
||||||
|
|||||||
@@ -37,6 +37,24 @@ public abstract class AbstractVectorStoreStrategy implements VectorStoreService
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 向量 L2 归一化 (单位化)
|
||||||
|
*/
|
||||||
|
protected static float[] normalize(float[] vector) {
|
||||||
|
if (vector == null) return null;
|
||||||
|
double sum = 0;
|
||||||
|
for (float v : vector) {
|
||||||
|
sum += v * v;
|
||||||
|
}
|
||||||
|
float norm = (float) Math.sqrt(sum);
|
||||||
|
if (norm > 1e-9) {
|
||||||
|
for (int i = 0; i < vector.length; i++) {
|
||||||
|
vector[i] /= norm;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return vector;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取向量模型
|
* 获取向量模型
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -19,7 +19,11 @@ import org.ruoyi.common.chat.service.chat.IChatModelService;
|
|||||||
import org.ruoyi.config.VectorStoreProperties;
|
import org.ruoyi.config.VectorStoreProperties;
|
||||||
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||||
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
||||||
|
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||||
import org.ruoyi.factory.EmbeddingModelFactory;
|
import org.ruoyi.factory.EmbeddingModelFactory;
|
||||||
|
import org.ruoyi.mapper.knowledge.KnowledgeAttachMapper;
|
||||||
|
import org.ruoyi.domain.entity.knowledge.KnowledgeAttach;
|
||||||
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -32,10 +36,14 @@ import java.util.stream.IntStream;
|
|||||||
@Component
|
@Component
|
||||||
public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||||
|
|
||||||
|
private final KnowledgeAttachMapper knowledgeAttachMapper;
|
||||||
|
|
||||||
public MilvusVectorStoreStrategy(VectorStoreProperties vectorStoreProperties,
|
public MilvusVectorStoreStrategy(VectorStoreProperties vectorStoreProperties,
|
||||||
IChatModelService chatModelService,
|
IChatModelService chatModelService,
|
||||||
EmbeddingModelFactory embeddingModelFactory) {
|
EmbeddingModelFactory embeddingModelFactory,
|
||||||
|
KnowledgeAttachMapper knowledgeAttachMapper) {
|
||||||
super(vectorStoreProperties, embeddingModelFactory, chatModelService);
|
super(vectorStoreProperties, embeddingModelFactory, chatModelService);
|
||||||
|
this.knowledgeAttachMapper = knowledgeAttachMapper;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 缓存不同集合与 autoFlush 配置的 Milvus 连接
|
// 缓存不同集合与 autoFlush 配置的 Milvus 连接
|
||||||
@@ -51,7 +59,7 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
|||||||
.collectionName(collectionName)
|
.collectionName(collectionName)
|
||||||
.dimension(dimension)
|
.dimension(dimension)
|
||||||
.indexType(IndexType.IVF_FLAT)
|
.indexType(IndexType.IVF_FLAT)
|
||||||
.metricType(MetricType.L2)
|
.metricType(MetricType.COSINE)
|
||||||
.autoFlushOnInsert(autoFlushOnInsert)
|
.autoFlushOnInsert(autoFlushOnInsert)
|
||||||
.idFieldName("id")
|
.idFieldName("id")
|
||||||
.textFieldName("text")
|
.textFieldName("text")
|
||||||
@@ -104,7 +112,10 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
|||||||
|
|
||||||
TextSegment textSegment = TextSegment.from(text, metadata);
|
TextSegment textSegment = TextSegment.from(text, metadata);
|
||||||
Embedding embedding = embeddingModel.embed(text).content();
|
Embedding embedding = embeddingModel.embed(text).content();
|
||||||
embeddingStore.add(embedding, textSegment);
|
// 单位化处理
|
||||||
|
float[] vector = embedding.vector();
|
||||||
|
normalize(vector);
|
||||||
|
embeddingStore.add(Embedding.from(vector), textSegment);
|
||||||
});
|
});
|
||||||
long endTime = System.currentTimeMillis();
|
long endTime = System.currentTimeMillis();
|
||||||
log.info("Milvus向量存储完成消耗时间:{}秒", (endTime - startTime) / 1000);
|
log.info("Milvus向量存储完成消耗时间:{}秒", (endTime - startTime) / 1000);
|
||||||
@@ -136,6 +147,55 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
|||||||
return resultList;
|
return resultList;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<KnowledgeRetrievalVo> search(QueryVectorBo queryVectorBo) {
|
||||||
|
int dimension = getModelDimension(queryVectorBo.getEmbeddingModelName());
|
||||||
|
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName());
|
||||||
|
|
||||||
|
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||||
|
// 查询向量单位化处理
|
||||||
|
float[] queryVector = queryEmbedding.vector();
|
||||||
|
normalize(queryVector);
|
||||||
|
|
||||||
|
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + queryVectorBo.getKid();
|
||||||
|
|
||||||
|
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, dimension, true);
|
||||||
|
|
||||||
|
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
|
||||||
|
.queryEmbedding(Embedding.from(queryVector))
|
||||||
|
.maxResults(queryVectorBo.getMaxResults())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
List<EmbeddingMatch<TextSegment>> matches = embeddingStore.search(request).matches();
|
||||||
|
List<KnowledgeRetrievalVo> resultList = new ArrayList<>();
|
||||||
|
|
||||||
|
for (EmbeddingMatch<TextSegment> match : matches) {
|
||||||
|
TextSegment segment = match.embedded();
|
||||||
|
if (segment == null) continue;
|
||||||
|
|
||||||
|
String docId = segment.metadata().getString("docId");
|
||||||
|
String sourceName = "未知来源";
|
||||||
|
if (docId != null) {
|
||||||
|
KnowledgeAttach attach = knowledgeAttachMapper.selectOne(new LambdaQueryWrapper<KnowledgeAttach>()
|
||||||
|
.eq(KnowledgeAttach::getDocId, docId)
|
||||||
|
.last("limit 1"));
|
||||||
|
if (attach != null) {
|
||||||
|
sourceName = attach.getName();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提取内容、评分及来源
|
||||||
|
double score = match.score();
|
||||||
|
|
||||||
|
resultList.add(org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo.builder()
|
||||||
|
.content(segment.text())
|
||||||
|
.score(score)
|
||||||
|
.sourceName(sourceName)
|
||||||
|
.build());
|
||||||
|
}
|
||||||
|
return resultList;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@SneakyThrows
|
@SneakyThrows
|
||||||
public void removeById(String id, String modelName) {
|
public void removeById(String id, String modelName) {
|
||||||
|
|||||||
@@ -24,7 +24,11 @@ import org.ruoyi.common.core.exception.ServiceException;
|
|||||||
import org.ruoyi.config.VectorStoreProperties;
|
import org.ruoyi.config.VectorStoreProperties;
|
||||||
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||||
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
||||||
|
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||||
import org.ruoyi.factory.EmbeddingModelFactory;
|
import org.ruoyi.factory.EmbeddingModelFactory;
|
||||||
|
import org.ruoyi.domain.entity.knowledge.KnowledgeAttach;
|
||||||
|
import org.ruoyi.mapper.knowledge.KnowledgeAttachMapper;
|
||||||
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
import static io.qdrant.client.VectorInputFactory.vectorInput;
|
import static io.qdrant.client.VectorInputFactory.vectorInput;
|
||||||
@@ -47,10 +51,14 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
|||||||
private static final String METADATA_KID_KEY = "kid";
|
private static final String METADATA_KID_KEY = "kid";
|
||||||
private static final String METADATA_DOC_ID_KEY = "doc_id";
|
private static final String METADATA_DOC_ID_KEY = "doc_id";
|
||||||
|
|
||||||
|
private final KnowledgeAttachMapper knowledgeAttachMapper;
|
||||||
|
|
||||||
public QdrantVectorStoreStrategy(VectorStoreProperties vectorStoreProperties,
|
public QdrantVectorStoreStrategy(VectorStoreProperties vectorStoreProperties,
|
||||||
IChatModelService chatModelService,
|
IChatModelService chatModelService,
|
||||||
EmbeddingModelFactory embeddingModelFactory) {
|
EmbeddingModelFactory embeddingModelFactory,
|
||||||
|
KnowledgeAttachMapper knowledgeAttachMapper) {
|
||||||
super(vectorStoreProperties, embeddingModelFactory, chatModelService);
|
super(vectorStoreProperties, embeddingModelFactory, chatModelService);
|
||||||
|
this.knowledgeAttachMapper = knowledgeAttachMapper;
|
||||||
}
|
}
|
||||||
|
|
||||||
private EmbeddingStore<TextSegment> getQdrantStore(String collectionName) {
|
private EmbeddingStore<TextSegment> getQdrantStore(String collectionName) {
|
||||||
@@ -129,7 +137,10 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
|||||||
metadata.put(METADATA_DOC_ID_KEY, docId);
|
metadata.put(METADATA_DOC_ID_KEY, docId);
|
||||||
TextSegment textSegment = TextSegment.from(text, metadata);
|
TextSegment textSegment = TextSegment.from(text, metadata);
|
||||||
Embedding embedding = embeddingModel.embed(text).content();
|
Embedding embedding = embeddingModel.embed(text).content();
|
||||||
embeddingStore.add(embedding, textSegment);
|
// 单位化处理
|
||||||
|
float[] vector = embedding.vector();
|
||||||
|
normalize(vector);
|
||||||
|
embeddingStore.add(Embedding.from(vector), textSegment);
|
||||||
});
|
});
|
||||||
|
|
||||||
long endTime = System.currentTimeMillis();
|
long endTime = System.currentTimeMillis();
|
||||||
@@ -140,18 +151,22 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
|||||||
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
|
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
|
||||||
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName());
|
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName());
|
||||||
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||||
|
// 查询向量单位化处理
|
||||||
|
float[] queryVector = queryEmbedding.vector();
|
||||||
|
normalize(queryVector);
|
||||||
|
|
||||||
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + queryVectorBo.getKid();
|
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + queryVectorBo.getKid();
|
||||||
|
|
||||||
List<Float> vector = new ArrayList<>();
|
List<Float> vectorList = new ArrayList<>();
|
||||||
for (float f : queryEmbedding.vector()) {
|
for (float f : queryVector) {
|
||||||
vector.add(f);
|
vectorList.add(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
try (QdrantClient client = buildQdrantClient()) {
|
try (QdrantClient client = buildQdrantClient()) {
|
||||||
QueryPoints request = QueryPoints.newBuilder()
|
QueryPoints request = QueryPoints.newBuilder()
|
||||||
.setCollectionName(collectionName)
|
.setCollectionName(collectionName)
|
||||||
.setQuery(Query.newBuilder()
|
.setQuery(Query.newBuilder()
|
||||||
.setNearest(vectorInput(vector))
|
.setNearest(vectorInput(vectorList))
|
||||||
.build())
|
.build())
|
||||||
.setLimit(queryVectorBo.getMaxResults())
|
.setLimit(queryVectorBo.getMaxResults())
|
||||||
.setWithPayload(enable(true))
|
.setWithPayload(enable(true))
|
||||||
@@ -172,6 +187,69 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<KnowledgeRetrievalVo> search(QueryVectorBo queryVectorBo) {
|
||||||
|
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName());
|
||||||
|
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||||
|
// 查询向量单位化处理
|
||||||
|
float[] queryVector = queryEmbedding.vector();
|
||||||
|
normalize(queryVector);
|
||||||
|
|
||||||
|
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + queryVectorBo.getKid();
|
||||||
|
|
||||||
|
List<Float> vectorList = new ArrayList<>();
|
||||||
|
for (float f : queryVector) {
|
||||||
|
vectorList.add(f);
|
||||||
|
}
|
||||||
|
|
||||||
|
try (QdrantClient client = buildQdrantClient()) {
|
||||||
|
QueryPoints request = QueryPoints.newBuilder()
|
||||||
|
.setCollectionName(collectionName)
|
||||||
|
.setQuery(Query.newBuilder()
|
||||||
|
.setNearest(vectorInput(vectorList))
|
||||||
|
.build())
|
||||||
|
.setLimit(queryVectorBo.getMaxResults())
|
||||||
|
.setWithPayload(enable(true))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
List<ScoredPoint> results = client.queryAsync(request).get();
|
||||||
|
List<org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo> resultList = new ArrayList<>();
|
||||||
|
for (ScoredPoint point : results) {
|
||||||
|
String content = "";
|
||||||
|
JsonWithInt.Value textValue = point.getPayloadMap().get(TEXT_SEGMENT_KEY);
|
||||||
|
if (textValue != null && textValue.hasStringValue()) {
|
||||||
|
content = textValue.getStringValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
String docId = null;
|
||||||
|
JsonWithInt.Value docIdValue = point.getPayloadMap().get(METADATA_DOC_ID_KEY);
|
||||||
|
if (docIdValue != null && docIdValue.hasStringValue()) {
|
||||||
|
docId = docIdValue.getStringValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
String sourceName = "未知来源";
|
||||||
|
if (docId != null) {
|
||||||
|
KnowledgeAttach attach = knowledgeAttachMapper.selectOne(new LambdaQueryWrapper<KnowledgeAttach>()
|
||||||
|
.eq(KnowledgeAttach::getDocId, docId)
|
||||||
|
.last("limit 1"));
|
||||||
|
if (attach != null) {
|
||||||
|
sourceName = attach.getName();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resultList.add(org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo.builder()
|
||||||
|
.content(content)
|
||||||
|
.score((double) point.getScore())
|
||||||
|
.sourceName(sourceName)
|
||||||
|
.build());
|
||||||
|
}
|
||||||
|
return resultList;
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("Qdrant检索失败: {}", collectionName, e);
|
||||||
|
throw new ServiceException("Qdrant向量检索失败");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void removeById(String id, String modelName) {
|
public void removeById(String id, String modelName) {
|
||||||
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + id;
|
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + id;
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import lombok.RequiredArgsConstructor;
|
|||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||||
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
||||||
|
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||||
import org.ruoyi.factory.VectorStoreStrategyFactory;
|
import org.ruoyi.factory.VectorStoreStrategyFactory;
|
||||||
import org.ruoyi.service.vector.VectorStoreService;
|
import org.ruoyi.service.vector.VectorStoreService;
|
||||||
import org.springframework.context.annotation.Primary;
|
import org.springframework.context.annotation.Primary;
|
||||||
@@ -54,6 +55,13 @@ public class VectorStoreServiceImpl implements VectorStoreService {
|
|||||||
return strategy.getQueryVector(queryVectorBo);
|
return strategy.getQueryVector(queryVectorBo);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<KnowledgeRetrievalVo> search(QueryVectorBo queryVectorBo) {
|
||||||
|
log.info("执行测试搜索: kid={}, query={}", queryVectorBo.getKid(), queryVectorBo.getQuery());
|
||||||
|
VectorStoreService strategy = getCurrentStrategy();
|
||||||
|
return strategy.search(queryVectorBo);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void removeById(String id, String modelName) {
|
public void removeById(String id, String modelName) {
|
||||||
log.info("根据ID删除向量数据: id={}, modelName={}", id, modelName);
|
log.info("根据ID删除向量数据: id={}, modelName={}", id, modelName);
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import org.ruoyi.common.core.exception.ServiceException;
|
|||||||
import org.ruoyi.config.VectorStoreProperties;
|
import org.ruoyi.config.VectorStoreProperties;
|
||||||
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||||
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
||||||
|
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||||
import org.ruoyi.factory.EmbeddingModelFactory;
|
import org.ruoyi.factory.EmbeddingModelFactory;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
import io.weaviate.client.Config;
|
import io.weaviate.client.Config;
|
||||||
@@ -24,6 +25,9 @@ import io.weaviate.client.v1.graphql.model.GraphQLResponse;
|
|||||||
import io.weaviate.client.v1.schema.model.Property;
|
import io.weaviate.client.v1.schema.model.Property;
|
||||||
import io.weaviate.client.v1.schema.model.Schema;
|
import io.weaviate.client.v1.schema.model.Schema;
|
||||||
import io.weaviate.client.v1.schema.model.WeaviateClass;
|
import io.weaviate.client.v1.schema.model.WeaviateClass;
|
||||||
|
import org.ruoyi.domain.entity.knowledge.KnowledgeAttach;
|
||||||
|
import org.ruoyi.mapper.knowledge.KnowledgeAttachMapper;
|
||||||
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
@@ -40,11 +44,14 @@ import java.util.Map;
|
|||||||
public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||||
|
|
||||||
private WeaviateClient client;
|
private WeaviateClient client;
|
||||||
|
private final KnowledgeAttachMapper knowledgeAttachMapper;
|
||||||
|
|
||||||
public WeaviateVectorStoreStrategy(VectorStoreProperties vectorStoreProperties,
|
public WeaviateVectorStoreStrategy(VectorStoreProperties vectorStoreProperties,
|
||||||
IChatModelService chatModelService,
|
IChatModelService chatModelService,
|
||||||
EmbeddingModelFactory embeddingModelFactory) {
|
EmbeddingModelFactory embeddingModelFactory,
|
||||||
|
KnowledgeAttachMapper knowledgeAttachMapper) {
|
||||||
super(vectorStoreProperties, embeddingModelFactory,chatModelService);
|
super(vectorStoreProperties, embeddingModelFactory,chatModelService);
|
||||||
|
this.knowledgeAttachMapper = knowledgeAttachMapper;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -110,9 +117,12 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
|||||||
"kid", kid,
|
"kid", kid,
|
||||||
"docId", docId
|
"docId", docId
|
||||||
);
|
);
|
||||||
Float[] vector = toObjectArray(embedding.vector());
|
float[] vectorArray = embedding.vector();
|
||||||
|
normalize(vectorArray);
|
||||||
|
Float[] vector = toObjectArray(vectorArray);
|
||||||
|
|
||||||
client.data().creator()
|
client.data().creator()
|
||||||
.withClassName("LocalKnowledge" + kid)
|
.withClassName(vectorStoreProperties.getWeaviate().getClassname() + kid)
|
||||||
.withProperties(properties)
|
.withProperties(properties)
|
||||||
.withVector(vector)
|
.withVector(vector)
|
||||||
.run();
|
.run();
|
||||||
@@ -128,6 +138,9 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
|||||||
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName());
|
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName());
|
||||||
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||||
float[] vector = queryEmbedding.vector();
|
float[] vector = queryEmbedding.vector();
|
||||||
|
// 查询向量单位化处理
|
||||||
|
normalize(vector);
|
||||||
|
|
||||||
List<String> vectorStrings = new ArrayList<>();
|
List<String> vectorStrings = new ArrayList<>();
|
||||||
for (float v : vector) {
|
for (float v : vector) {
|
||||||
vectorStrings.add(String.valueOf(v));
|
vectorStrings.add(String.valueOf(v));
|
||||||
@@ -178,6 +191,77 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<KnowledgeRetrievalVo> search(QueryVectorBo queryVectorBo) {
|
||||||
|
createSchema(queryVectorBo.getKid(), queryVectorBo.getEmbeddingModelName());
|
||||||
|
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName());
|
||||||
|
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||||
|
float[] vector = queryEmbedding.vector();
|
||||||
|
// 查询向量单位化处理
|
||||||
|
normalize(vector);
|
||||||
|
List<String> vectorStrings = new ArrayList<>();
|
||||||
|
for (float v : vector) {
|
||||||
|
vectorStrings.add(String.valueOf(v));
|
||||||
|
}
|
||||||
|
String vectorStr = String.join(",", vectorStrings);
|
||||||
|
String className = vectorStoreProperties.getWeaviate().getClassname();
|
||||||
|
|
||||||
|
String graphQLQuery = String.format(
|
||||||
|
"{\n" +
|
||||||
|
" Get {\n" +
|
||||||
|
" %s(nearVector: {vector: [%s]} limit: %d) {\n" +
|
||||||
|
" text\n" +
|
||||||
|
" docId\n" +
|
||||||
|
" _additional {\n" +
|
||||||
|
" distance\n" +
|
||||||
|
" }\n" +
|
||||||
|
" }\n" +
|
||||||
|
" }\n" +
|
||||||
|
"}",
|
||||||
|
className + queryVectorBo.getKid(),
|
||||||
|
vectorStr,
|
||||||
|
queryVectorBo.getMaxResults()
|
||||||
|
);
|
||||||
|
|
||||||
|
Result<GraphQLResponse> result = client.graphQL().raw().withQuery(graphQLQuery).run();
|
||||||
|
List<org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo> resultList = new ArrayList<>();
|
||||||
|
|
||||||
|
if (result != null && !result.hasErrors()) {
|
||||||
|
Object data = result.getResult().getData();
|
||||||
|
JSONObject entries = new JSONObject(data);
|
||||||
|
Map<String, cn.hutool.json.JSONArray> entriesMap = entries.get("Get", Map.class);
|
||||||
|
cn.hutool.json.JSONArray objects = entriesMap.get(className + queryVectorBo.getKid());
|
||||||
|
|
||||||
|
for (Object obj : objects) {
|
||||||
|
Map<String, Object> map = (Map<String, Object>) obj;
|
||||||
|
String content = (String) map.get("text");
|
||||||
|
String docId = (String) map.get("docId");
|
||||||
|
|
||||||
|
Map<String, Object> additional = (Map<String, Object>) map.get("_additional");
|
||||||
|
Double distance = Double.valueOf(String.valueOf(additional.get("distance")));
|
||||||
|
// 转换距离为得分 (Weaviate 0 是最相近,1 是最远;余弦距离下 1-dist 即为相似度)
|
||||||
|
double score = 1.0 - distance;
|
||||||
|
|
||||||
|
String sourceName = "未知来源";
|
||||||
|
if (docId != null) {
|
||||||
|
KnowledgeAttach attach = knowledgeAttachMapper.selectOne(new LambdaQueryWrapper<KnowledgeAttach>()
|
||||||
|
.eq(KnowledgeAttach::getDocId, docId)
|
||||||
|
.last("limit 1"));
|
||||||
|
if (attach != null) {
|
||||||
|
sourceName = attach.getName();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resultList.add(org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo.builder()
|
||||||
|
.content(content)
|
||||||
|
.score(score)
|
||||||
|
.sourceName(sourceName)
|
||||||
|
.build());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return resultList;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@SneakyThrows
|
@SneakyThrows
|
||||||
public void removeById(String id, String modelName) {
|
public void removeById(String id, String modelName) {
|
||||||
|
|||||||
Reference in New Issue
Block a user