diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/controller/knowledge/KnowledgeFragmentController.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/controller/knowledge/KnowledgeFragmentController.java index ac79a62e..d739a30f 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/controller/knowledge/KnowledgeFragmentController.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/controller/knowledge/KnowledgeFragmentController.java @@ -8,6 +8,7 @@ import jakarta.validation.constraints.*; import cn.dev33.satoken.annotation.SaCheckPermission; import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo; import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import org.ruoyi.service.knowledge.IKnowledgeFragmentService; import org.springframework.web.bind.annotation.*; import org.springframework.validation.annotation.Validated; @@ -102,4 +103,12 @@ public class KnowledgeFragmentController extends BaseController { @PathVariable Long[] ids) { return toAjax(knowledgeFragmentService.deleteWithValidByIds(List.of(ids), true)); } + + /** + * 检索测试 + */ + @PostMapping("/retrieval") + public R> retrieval(@RequestBody KnowledgeFragmentBo bo) { + return R.ok(knowledgeFragmentService.retrieval(bo)); + } } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeFragmentBo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeFragmentBo.java index 7472f57f..e1925028 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeFragmentBo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeFragmentBo.java @@ -49,5 +49,24 @@ public class KnowledgeFragmentBo extends BaseEntity { */ private String remark; + /** + * 知识库ID + */ + private Long knowledgeId; + + /** + * 检索内容 + */ + private String query; + + /** + * 返回条数 + */ + private Integer topK; + + /** + * 相似度阈值 + */ + private Double threshold; } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeRetrievalVo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeRetrievalVo.java new file mode 100644 index 00000000..daeaae59 --- /dev/null +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeRetrievalVo.java @@ -0,0 +1,35 @@ +package org.ruoyi.domain.vo.knowledge; + +import lombok.Builder; +import lombok.Data; + +import java.io.Serial; +import java.io.Serializable; + +/** + * 知识检索测试结果视图对象 + * + * @author RobustH + */ +@Data +@Builder +public class KnowledgeRetrievalVo implements Serializable { + + @Serial + private static final long serialVersionUID = 1L; + + /** + * 片段内容 + */ + private String content; + + /** + * 相似度得分 + */ + private Double score; + + /** + * 来源文档名称 + */ + private String sourceName; +} diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java index ba1c5ac8..2332ca84 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java @@ -16,6 +16,15 @@ import dev.langchain4j.model.chat.StreamingChatModel; import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; import dev.langchain4j.model.openai.OpenAiChatModel; +import dev.langchain4j.rag.AugmentationRequest; +import dev.langchain4j.rag.AugmentationResult; +import dev.langchain4j.rag.DefaultRetrievalAugmentor; +import dev.langchain4j.rag.RetrievalAugmentor; +import dev.langchain4j.rag.content.aggregator.ContentAggregator; +import dev.langchain4j.rag.content.aggregator.DefaultContentAggregator; +import dev.langchain4j.rag.content.aggregator.ReRankingContentAggregator; +import dev.langchain4j.model.scoring.ScoringModel; +import dev.langchain4j.rag.query.Metadata; import dev.langchain4j.service.tool.ToolProvider; import lombok.RequiredArgsConstructor; import lombok.SneakyThrows; @@ -46,6 +55,8 @@ import org.ruoyi.mcp.service.core.ToolProviderFactory; import org.ruoyi.service.chat.AbstractChatService; import org.ruoyi.service.chat.IChatMessageService; import org.ruoyi.service.chat.impl.memory.PersistentChatMemoryStore; +import org.ruoyi.service.knowledge.retriever.CustomVectorRetriever; +import org.ruoyi.service.knowledge.rerank.ScoringModelFactory; import org.ruoyi.service.knowledge.IKnowledgeInfoService; import org.ruoyi.service.vector.VectorStoreService; import org.springframework.stereotype.Service; @@ -89,6 +100,8 @@ public class ChatServiceFacade implements IChatService { private final ToolProviderFactory toolProviderFactory; + private final ScoringModelFactory scoringModelFactory; + /** * 内存实例缓存,避免同一会话重复创建 * Key: sessionId, Value: MessageWindowChatMemory实例 @@ -119,7 +132,9 @@ public class ChatServiceFacade implements IChatService { // 2. 构建上下文消息列表 List contextMessages = buildContextMessages(chatRequest); - // 3. 处理特殊聊天模式(工作流、人机交互恢复、思考模式) + // 注意:buildContextMessages() 最后返回的列表中,最新的带有增强知识的 UserMessage 在最后。 + // 对于有些模型API(非langchain4j的代理),它们可能不识别增强后的复杂文本(取决于供应商适配度) + // 但是通过标准流,它被解析为 String。 SseEmitter specialResult = handleSpecialChatModes(chatRequest, contextMessages, chatModelVo, emitter); if (specialResult != null) { return specialResult; @@ -346,39 +361,63 @@ public class ChatServiceFacade implements IChatService { * @return 上下文消息列表 */ private List buildContextMessages(ChatRequest chatRequest) { - List messages = new ArrayList<>(); - // 构建用户消息 + List messages = new ArrayList<>(); + + // 初始化用户消息 UserMessage userMessage = UserMessage.userMessage(chatRequest.getContent()); - messages.add(userMessage); - // 从向量库查询相关历史消息 + // 使用 LangChain4j 的 RetrievalAugmentor 进行检索增强 if (chatRequest.getKnowledgeId() != null) { - // 查询知识库信息 KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(chatRequest.getKnowledgeId())); - if (knowledgeInfoVo == null) { - log.warn("知识库信息不存在,kid: {}", chatRequest.getKnowledgeId()); - return messages; - } + if (knowledgeInfoVo != null) { + ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel()); + if (chatModel != null) { - // 查询向量模型配置信息 - ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel()); - if (chatModel == null) { - log.warn("向量模型配置不存在,模型名称: {}", knowledgeInfoVo.getEmbeddingModel()); - return messages; - } + // 1. 构建适配器(Retriever) + CustomVectorRetriever retriever = new CustomVectorRetriever( + vectorStoreService, knowledgeInfoVo, chatModel); - // 构建向量查询参数 - QueryVectorBo queryVectorBo = buildQueryVectorBo(chatRequest, knowledgeInfoVo, chatModel); + // 2. 获取和构建重排模型聚合器(Aggregator) + // 假设已在 KnowledgeInfoVo 等加入 getRerankModelConfig/getRerankModel 等,这里演示通用逻辑 + // 若无重排需求,使用 DefaultContentAggregator 或无 ScoringModel 的聚合器 + ContentAggregator contentAggregator; + // TODO: 一旦实体类实现了重排模型的支持,此处可以从数据库读出: + // ChatModelVo scoringModelConfig = chatModelService.selectModelByName(knowledgeInfoVo.getRerankModel()); + ChatModelVo scoringModelConfig = null; // 当前暂无对应配置字段 - // 获取向量查询结果 - List nearestList = vectorStoreService.getQueryVector(queryVectorBo); - for (String prompt : nearestList) { - // 知识库内容作为系统上下文添加 - messages.add( new AiMessage(prompt)); + ScoringModel scoringModel = scoringModelFactory.createScoringModel(scoringModelConfig); + if (scoringModel != null) { + contentAggregator = ReRankingContentAggregator.builder() + .scoringModel(scoringModel) + // .maxResults(3) 这个数字将来从配置取 + .build(); + } else { + contentAggregator = new DefaultContentAggregator(); + } + + // 3. 构造流水线 + RetrievalAugmentor augmentor = DefaultRetrievalAugmentor.builder() + .contentRetriever(retriever) + .contentAggregator(contentAggregator) + .build(); + + // 4. 执行 Augmentor 增强:将检索到的知识内容编织进 UserMessage 中 + Metadata ragMetadata = Metadata.from(userMessage, chatRequest.getSessionId(), new ArrayList<>()); + AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, ragMetadata); + + AugmentationResult augmentedResult = augmentor.augment(augmentationRequest); + + ChatMessage augmentedMessage = augmentedResult.chatMessage(); + if (augmentedMessage instanceof UserMessage) { + userMessage = (UserMessage) augmentedMessage; + } + log.info("RAG 增强完成: UserMessage 已重构并附加上下文背景。"); + + } } } - // 从数据库查询历史对话消息 + // 从数据库查询历史对话消息(历史消息应放在当前提问前) if (chatRequest.getSessionId() != null) { MessageWindowChatMemory memory = createChatMemory(chatRequest.getSessionId()); if (memory != null) { @@ -390,6 +429,9 @@ public class ChatServiceFacade implements IChatService { } } + // 注入本次用户提问(经过 RAG 增强后的 UserMessage) + messages.add(userMessage); + return messages; } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/IKnowledgeFragmentService.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/IKnowledgeFragmentService.java index e323e79e..b8b88b45 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/IKnowledgeFragmentService.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/IKnowledgeFragmentService.java @@ -4,6 +4,7 @@ import org.ruoyi.common.mybatis.core.page.TableDataInfo; import org.ruoyi.common.mybatis.core.page.PageQuery; import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo; import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import java.util.Collection; import java.util.List; @@ -65,4 +66,12 @@ public interface IKnowledgeFragmentService { * @return 是否删除成功 */ Boolean deleteWithValidByIds(Collection ids, Boolean isValid); + + /** + * 检索测试 + * + * @param bo 检索参数 + * @return 检索结果 + */ + List retrieval(KnowledgeFragmentBo bo); } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java index 782b40af..9c2477df 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java @@ -13,8 +13,17 @@ import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo; import org.ruoyi.domain.entity.knowledge.KnowledgeFragment; import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo; import org.ruoyi.mapper.knowledge.KnowledgeFragmentMapper; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; +import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo; +import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo; +import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.service.knowledge.IKnowledgeFragmentService; +import org.ruoyi.service.knowledge.IKnowledgeInfoService; +import org.ruoyi.common.chat.service.chat.IChatModelService; +import org.ruoyi.service.vector.VectorStoreService; import org.springframework.stereotype.Service; +import java.util.ArrayList; +import java.util.stream.Collectors; import java.util.List; import java.util.Map; @@ -32,6 +41,9 @@ import java.util.Collection; public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService { private final KnowledgeFragmentMapper baseMapper; + private final IKnowledgeInfoService knowledgeInfoService; + private final IChatModelService chatModelService; + private final VectorStoreService vectorStoreService; /** * 查询知识片段 @@ -131,4 +143,45 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService { } return baseMapper.deleteByIds(ids) > 0; } + + /** + * 检索测试核心实现 + */ + @Override + public List retrieval(KnowledgeFragmentBo bo) { + if (bo.getKnowledgeId() == null || StringUtils.isBlank(bo.getQuery())) { + return new ArrayList<>(); + } + + // 1. 获取知识库及模型配置 + KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(bo.getKnowledgeId()); + if (knowledgeInfoVo == null) { + return new ArrayList<>(); + } + + ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel()); + if (chatModel == null) { + log.warn("未找到对应的向量模型配置: {}", knowledgeInfoVo.getEmbeddingModel()); + return new ArrayList<>(); + } + + // 2. 构造向量检索参数 + QueryVectorBo queryVectorBo = new QueryVectorBo(); + queryVectorBo.setQuery(bo.getQuery()); + queryVectorBo.setKid(String.valueOf(bo.getKnowledgeId())); + queryVectorBo.setMaxResults(bo.getTopK() != null ? bo.getTopK() : knowledgeInfoVo.getRetrieveLimit()); + queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel()); + queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModel()); + queryVectorBo.setApiKey(chatModel.getApiKey()); + queryVectorBo.setBaseUrl(chatModel.getApiHost()); + + // 3. 执行物理检索 + List allResults = vectorStoreService.search(queryVectorBo); + + // 4. 根据阈值过滤 (LangChain4j 结果 score 通常 0-1) + double threshold = bo.getThreshold() != null ? bo.getThreshold() : 0.0; + return allResults.stream() + .filter(res -> res.getScore() >= threshold) + .collect(Collectors.toList()); + } } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/ScoringModelFactory.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/ScoringModelFactory.java new file mode 100644 index 00000000..5f28b9c2 --- /dev/null +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/ScoringModelFactory.java @@ -0,0 +1,37 @@ +package org.ruoyi.service.knowledge.rerank; + +import dev.langchain4j.model.scoring.ScoringModel; +import lombok.extern.slf4j.Slf4j; +import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo; +import org.springframework.stereotype.Component; + +/** + * 重排模型提供商工厂 + * 用于将来无缝拓展硅基流动、百炼等支持重排的模型厂商 + * + * @author RobustH + */ +@Slf4j +@Component +public class ScoringModelFactory { + + /** + * 根据后台传递的模型配置创建具体的重排模型 + * + * @param rerankModelConfig 重排模型的配置 (例如其 providerCode, apiUrl, apiKey 等) + * @return 标准的 LangChain4j ScoringModel + */ + public ScoringModel createScoringModel(ChatModelVo rerankModelConfig) { + if (rerankModelConfig == null) { + return null; + } + + String providerCode = rerankModelConfig.getProviderCode(); + log.info("初始化重排模型,供应商代码: {}", providerCode); + + // TODO: 在这里通过 switch 或反射具体实例化支持的各种 ScoringModel (例如 CohereScoringModel, DascScope 等) + // 目前返回 null 代表暂时没有加载特定的重排底座,这不会影响流程,Aggregator 会忽略它返回原样结果 + + return null; + } +} diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/retriever/CustomVectorRetriever.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/retriever/CustomVectorRetriever.java new file mode 100644 index 00000000..6d876710 --- /dev/null +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/retriever/CustomVectorRetriever.java @@ -0,0 +1,54 @@ +package org.ruoyi.service.knowledge.retriever; + +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.rag.content.Content; +import dev.langchain4j.rag.content.retriever.ContentRetriever; +import dev.langchain4j.rag.query.Query; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo; +import org.ruoyi.domain.bo.vector.QueryVectorBo; +import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo; +import org.ruoyi.service.vector.VectorStoreService; + +import java.util.List; +import java.util.stream.Collectors; + +/** + * 自定义向量检索器:适配 LangChain4j ContentRetriever 接口 + * 桥接现有的 VectorStoreService 获取检索结果 + * + * @author RobustH + */ +@Slf4j +@RequiredArgsConstructor +public class CustomVectorRetriever implements ContentRetriever { + + private final VectorStoreService vectorStoreService; + private final KnowledgeInfoVo knowledgeInfoVo; + private final ChatModelVo chatModelVo; + + @Override + public List retrieve(Query query) { + log.info("执行自定义向量检索,关键字: {}", query.text()); + + // 构建内部查询参数 + QueryVectorBo queryVectorBo = new QueryVectorBo(); + queryVectorBo.setQuery(query.text()); + queryVectorBo.setKid(String.valueOf(knowledgeInfoVo.getId())); + queryVectorBo.setApiKey(chatModelVo.getApiKey()); + queryVectorBo.setBaseUrl(chatModelVo.getApiHost()); + queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModel()); + queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel()); + // 如果接入了重排,这里的 retrieveLimit 也就是 MaxResults 应当被放大,后续留给 Aggregator 截断 + queryVectorBo.setMaxResults(knowledgeInfoVo.getRetrieveLimit()); + + // 执行底层的多种向量库策略检索 + List nearestList = vectorStoreService.getQueryVector(queryVectorBo); + + // 将结果包装为标准的 Content 返回 + return nearestList.stream() + .map(text -> Content.from(TextSegment.from(text))) + .collect(Collectors.toList()); + } +} diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/VectorStoreService.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/VectorStoreService.java index 3c37835f..66a6a6f2 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/VectorStoreService.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/VectorStoreService.java @@ -3,6 +3,7 @@ package org.ruoyi.service.vector; import org.ruoyi.common.core.exception.ServiceException; import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.domain.bo.vector.StoreEmbeddingBo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import java.util.List; @@ -17,6 +18,11 @@ public interface VectorStoreService { List getQueryVector(QueryVectorBo queryVectorBo); + /** + * 带分数及元数据的检索(用于测试检索功能) + */ + List search(QueryVectorBo queryVectorBo); + void createSchema(String kid, String embeddingModelName); void removeById(String id, String modelName) throws ServiceException; diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/AbstractVectorStoreStrategy.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/AbstractVectorStoreStrategy.java index 906c8090..2fd2052d 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/AbstractVectorStoreStrategy.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/AbstractVectorStoreStrategy.java @@ -37,6 +37,24 @@ public abstract class AbstractVectorStoreStrategy implements VectorStoreService return result; } + /** + * 向量 L2 归一化 (单位化) + */ + protected static float[] normalize(float[] vector) { + if (vector == null) return null; + double sum = 0; + for (float v : vector) { + sum += v * v; + } + float norm = (float) Math.sqrt(sum); + if (norm > 1e-9) { + for (int i = 0; i < vector.length; i++) { + vector[i] /= norm; + } + } + return vector; + } + /** * 获取向量模型 */ diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/MilvusVectorStoreStrategy.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/MilvusVectorStoreStrategy.java index baf1c612..c06b9c73 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/MilvusVectorStoreStrategy.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/MilvusVectorStoreStrategy.java @@ -19,7 +19,11 @@ import org.ruoyi.common.chat.service.chat.IChatModelService; import org.ruoyi.config.VectorStoreProperties; import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.domain.bo.vector.StoreEmbeddingBo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import org.ruoyi.factory.EmbeddingModelFactory; +import org.ruoyi.mapper.knowledge.KnowledgeAttachMapper; +import org.ruoyi.domain.entity.knowledge.KnowledgeAttach; +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import org.springframework.stereotype.Component; import java.util.ArrayList; @@ -32,10 +36,14 @@ import java.util.stream.IntStream; @Component public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { + private final KnowledgeAttachMapper knowledgeAttachMapper; + public MilvusVectorStoreStrategy(VectorStoreProperties vectorStoreProperties, IChatModelService chatModelService, - EmbeddingModelFactory embeddingModelFactory) { + EmbeddingModelFactory embeddingModelFactory, + KnowledgeAttachMapper knowledgeAttachMapper) { super(vectorStoreProperties, embeddingModelFactory, chatModelService); + this.knowledgeAttachMapper = knowledgeAttachMapper; } // 缓存不同集合与 autoFlush 配置的 Milvus 连接 @@ -51,7 +59,7 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { .collectionName(collectionName) .dimension(dimension) .indexType(IndexType.IVF_FLAT) - .metricType(MetricType.L2) + .metricType(MetricType.COSINE) .autoFlushOnInsert(autoFlushOnInsert) .idFieldName("id") .textFieldName("text") @@ -104,7 +112,10 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { TextSegment textSegment = TextSegment.from(text, metadata); Embedding embedding = embeddingModel.embed(text).content(); - embeddingStore.add(embedding, textSegment); + // 单位化处理 + float[] vector = embedding.vector(); + normalize(vector); + embeddingStore.add(Embedding.from(vector), textSegment); }); long endTime = System.currentTimeMillis(); log.info("Milvus向量存储完成消耗时间:{}秒", (endTime - startTime) / 1000); @@ -136,6 +147,55 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { return resultList; } + @Override + public List search(QueryVectorBo queryVectorBo) { + int dimension = getModelDimension(queryVectorBo.getEmbeddingModelName()); + EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName()); + + Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content(); + // 查询向量单位化处理 + float[] queryVector = queryEmbedding.vector(); + normalize(queryVector); + + String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + queryVectorBo.getKid(); + + EmbeddingStore embeddingStore = getMilvusStore(collectionName, dimension, true); + + EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() + .queryEmbedding(Embedding.from(queryVector)) + .maxResults(queryVectorBo.getMaxResults()) + .build(); + + List> matches = embeddingStore.search(request).matches(); + List resultList = new ArrayList<>(); + + for (EmbeddingMatch match : matches) { + TextSegment segment = match.embedded(); + if (segment == null) continue; + + String docId = segment.metadata().getString("docId"); + String sourceName = "未知来源"; + if (docId != null) { + KnowledgeAttach attach = knowledgeAttachMapper.selectOne(new LambdaQueryWrapper() + .eq(KnowledgeAttach::getDocId, docId) + .last("limit 1")); + if (attach != null) { + sourceName = attach.getName(); + } + } + + // 提取内容、评分及来源 + double score = match.score(); + + resultList.add(org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo.builder() + .content(segment.text()) + .score(score) + .sourceName(sourceName) + .build()); + } + return resultList; + } + @Override @SneakyThrows public void removeById(String id, String modelName) { diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/QdrantVectorStoreStrategy.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/QdrantVectorStoreStrategy.java index 973d8485..da6bca80 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/QdrantVectorStoreStrategy.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/QdrantVectorStoreStrategy.java @@ -24,7 +24,11 @@ import org.ruoyi.common.core.exception.ServiceException; import org.ruoyi.config.VectorStoreProperties; import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.domain.bo.vector.StoreEmbeddingBo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import org.ruoyi.factory.EmbeddingModelFactory; +import org.ruoyi.domain.entity.knowledge.KnowledgeAttach; +import org.ruoyi.mapper.knowledge.KnowledgeAttachMapper; +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import org.springframework.stereotype.Component; import static io.qdrant.client.VectorInputFactory.vectorInput; @@ -47,10 +51,14 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy { private static final String METADATA_KID_KEY = "kid"; private static final String METADATA_DOC_ID_KEY = "doc_id"; + private final KnowledgeAttachMapper knowledgeAttachMapper; + public QdrantVectorStoreStrategy(VectorStoreProperties vectorStoreProperties, IChatModelService chatModelService, - EmbeddingModelFactory embeddingModelFactory) { + EmbeddingModelFactory embeddingModelFactory, + KnowledgeAttachMapper knowledgeAttachMapper) { super(vectorStoreProperties, embeddingModelFactory, chatModelService); + this.knowledgeAttachMapper = knowledgeAttachMapper; } private EmbeddingStore getQdrantStore(String collectionName) { @@ -129,7 +137,10 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy { metadata.put(METADATA_DOC_ID_KEY, docId); TextSegment textSegment = TextSegment.from(text, metadata); Embedding embedding = embeddingModel.embed(text).content(); - embeddingStore.add(embedding, textSegment); + // 单位化处理 + float[] vector = embedding.vector(); + normalize(vector); + embeddingStore.add(Embedding.from(vector), textSegment); }); long endTime = System.currentTimeMillis(); @@ -140,18 +151,22 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy { public List getQueryVector(QueryVectorBo queryVectorBo) { EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName()); Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content(); + // 查询向量单位化处理 + float[] queryVector = queryEmbedding.vector(); + normalize(queryVector); + String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + queryVectorBo.getKid(); - List vector = new ArrayList<>(); - for (float f : queryEmbedding.vector()) { - vector.add(f); + List vectorList = new ArrayList<>(); + for (float f : queryVector) { + vectorList.add(f); } try (QdrantClient client = buildQdrantClient()) { QueryPoints request = QueryPoints.newBuilder() .setCollectionName(collectionName) .setQuery(Query.newBuilder() - .setNearest(vectorInput(vector)) + .setNearest(vectorInput(vectorList)) .build()) .setLimit(queryVectorBo.getMaxResults()) .setWithPayload(enable(true)) @@ -172,6 +187,69 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy { } } + @Override + public List search(QueryVectorBo queryVectorBo) { + EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName()); + Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content(); + // 查询向量单位化处理 + float[] queryVector = queryEmbedding.vector(); + normalize(queryVector); + + String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + queryVectorBo.getKid(); + + List vectorList = new ArrayList<>(); + for (float f : queryVector) { + vectorList.add(f); + } + + try (QdrantClient client = buildQdrantClient()) { + QueryPoints request = QueryPoints.newBuilder() + .setCollectionName(collectionName) + .setQuery(Query.newBuilder() + .setNearest(vectorInput(vectorList)) + .build()) + .setLimit(queryVectorBo.getMaxResults()) + .setWithPayload(enable(true)) + .build(); + + List results = client.queryAsync(request).get(); + List resultList = new ArrayList<>(); + for (ScoredPoint point : results) { + String content = ""; + JsonWithInt.Value textValue = point.getPayloadMap().get(TEXT_SEGMENT_KEY); + if (textValue != null && textValue.hasStringValue()) { + content = textValue.getStringValue(); + } + + String docId = null; + JsonWithInt.Value docIdValue = point.getPayloadMap().get(METADATA_DOC_ID_KEY); + if (docIdValue != null && docIdValue.hasStringValue()) { + docId = docIdValue.getStringValue(); + } + + String sourceName = "未知来源"; + if (docId != null) { + KnowledgeAttach attach = knowledgeAttachMapper.selectOne(new LambdaQueryWrapper() + .eq(KnowledgeAttach::getDocId, docId) + .last("limit 1")); + if (attach != null) { + sourceName = attach.getName(); + } + } + + resultList.add(org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo.builder() + .content(content) + .score((double) point.getScore()) + .sourceName(sourceName) + .build()); + } + return resultList; + } catch (Exception e) { + log.error("Qdrant检索失败: {}", collectionName, e); + throw new ServiceException("Qdrant向量检索失败"); + } + } + @Override public void removeById(String id, String modelName) { String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + id; diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/VectorStoreServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/VectorStoreServiceImpl.java index 603ed84c..73b1fa2d 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/VectorStoreServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/VectorStoreServiceImpl.java @@ -4,6 +4,7 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.domain.bo.vector.StoreEmbeddingBo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import org.ruoyi.factory.VectorStoreStrategyFactory; import org.ruoyi.service.vector.VectorStoreService; import org.springframework.context.annotation.Primary; @@ -54,6 +55,13 @@ public class VectorStoreServiceImpl implements VectorStoreService { return strategy.getQueryVector(queryVectorBo); } + @Override + public List search(QueryVectorBo queryVectorBo) { + log.info("执行测试搜索: kid={}, query={}", queryVectorBo.getKid(), queryVectorBo.getQuery()); + VectorStoreService strategy = getCurrentStrategy(); + return strategy.search(queryVectorBo); + } + @Override public void removeById(String id, String modelName) { log.info("根据ID删除向量数据: id={}, modelName={}", id, modelName); diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/WeaviateVectorStoreStrategy.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/WeaviateVectorStoreStrategy.java index c62a8470..f90f0568 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/WeaviateVectorStoreStrategy.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/WeaviateVectorStoreStrategy.java @@ -12,6 +12,7 @@ import org.ruoyi.common.core.exception.ServiceException; import org.ruoyi.config.VectorStoreProperties; import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.domain.bo.vector.StoreEmbeddingBo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import org.ruoyi.factory.EmbeddingModelFactory; import org.springframework.stereotype.Component; import io.weaviate.client.Config; @@ -24,6 +25,9 @@ import io.weaviate.client.v1.graphql.model.GraphQLResponse; import io.weaviate.client.v1.schema.model.Property; import io.weaviate.client.v1.schema.model.Schema; import io.weaviate.client.v1.schema.model.WeaviateClass; +import org.ruoyi.domain.entity.knowledge.KnowledgeAttach; +import org.ruoyi.mapper.knowledge.KnowledgeAttachMapper; +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import java.util.ArrayList; import java.util.Collections; @@ -40,11 +44,14 @@ import java.util.Map; public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy { private WeaviateClient client; + private final KnowledgeAttachMapper knowledgeAttachMapper; public WeaviateVectorStoreStrategy(VectorStoreProperties vectorStoreProperties, IChatModelService chatModelService, - EmbeddingModelFactory embeddingModelFactory) { + EmbeddingModelFactory embeddingModelFactory, + KnowledgeAttachMapper knowledgeAttachMapper) { super(vectorStoreProperties, embeddingModelFactory,chatModelService); + this.knowledgeAttachMapper = knowledgeAttachMapper; } @Override @@ -110,9 +117,12 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy { "kid", kid, "docId", docId ); - Float[] vector = toObjectArray(embedding.vector()); + float[] vectorArray = embedding.vector(); + normalize(vectorArray); + Float[] vector = toObjectArray(vectorArray); + client.data().creator() - .withClassName("LocalKnowledge" + kid) + .withClassName(vectorStoreProperties.getWeaviate().getClassname() + kid) .withProperties(properties) .withVector(vector) .run(); @@ -128,6 +138,9 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy { EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName()); Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content(); float[] vector = queryEmbedding.vector(); + // 查询向量单位化处理 + normalize(vector); + List vectorStrings = new ArrayList<>(); for (float v : vector) { vectorStrings.add(String.valueOf(v)); @@ -178,6 +191,77 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy { } } + @Override + public List search(QueryVectorBo queryVectorBo) { + createSchema(queryVectorBo.getKid(), queryVectorBo.getEmbeddingModelName()); + EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName()); + Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content(); + float[] vector = queryEmbedding.vector(); + // 查询向量单位化处理 + normalize(vector); + List vectorStrings = new ArrayList<>(); + for (float v : vector) { + vectorStrings.add(String.valueOf(v)); + } + String vectorStr = String.join(",", vectorStrings); + String className = vectorStoreProperties.getWeaviate().getClassname(); + + String graphQLQuery = String.format( + "{\n" + + " Get {\n" + + " %s(nearVector: {vector: [%s]} limit: %d) {\n" + + " text\n" + + " docId\n" + + " _additional {\n" + + " distance\n" + + " }\n" + + " }\n" + + " }\n" + + "}", + className + queryVectorBo.getKid(), + vectorStr, + queryVectorBo.getMaxResults() + ); + + Result result = client.graphQL().raw().withQuery(graphQLQuery).run(); + List resultList = new ArrayList<>(); + + if (result != null && !result.hasErrors()) { + Object data = result.getResult().getData(); + JSONObject entries = new JSONObject(data); + Map entriesMap = entries.get("Get", Map.class); + cn.hutool.json.JSONArray objects = entriesMap.get(className + queryVectorBo.getKid()); + + for (Object obj : objects) { + Map map = (Map) obj; + String content = (String) map.get("text"); + String docId = (String) map.get("docId"); + + Map additional = (Map) map.get("_additional"); + Double distance = Double.valueOf(String.valueOf(additional.get("distance"))); + // 转换距离为得分 (Weaviate 0 是最相近,1 是最远;余弦距离下 1-dist 即为相似度) + double score = 1.0 - distance; + + String sourceName = "未知来源"; + if (docId != null) { + KnowledgeAttach attach = knowledgeAttachMapper.selectOne(new LambdaQueryWrapper() + .eq(KnowledgeAttach::getDocId, docId) + .last("limit 1")); + if (attach != null) { + sourceName = attach.getName(); + } + } + + resultList.add(org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo.builder() + .content(content) + .score(score) + .sourceName(sourceName) + .build()); + } + } + return resultList; + } + @Override @SneakyThrows public void removeById(String id, String modelName) {