diff --git a/docs/script/sql/update/updat-0423.sql b/docs/script/sql/update/updat-0423.sql new file mode 100644 index 00000000..4ed14430 --- /dev/null +++ b/docs/script/sql/update/updat-0423.sql @@ -0,0 +1,14 @@ +-- 为知识库信息表新增检索配置字段 (剔除了已存在的重排字段) +ALTER TABLE knowledge_info + ADD COLUMN similarity_threshold DOUBLE DEFAULT 0.5 COMMENT '相似度阈值' +AFTER retrieve_limit; + +ALTER TABLE knowledge_info ADD COLUMN enable_hybrid tinyint(1) DEFAULT 0 COMMENT '是否启用混合检索'; +ALTER TABLE knowledge_info ADD COLUMN hybrid_alpha double DEFAULT 0.5 COMMENT '混合检索权重比例 (0.0=纯向量, 1.0=纯关键词)'; + +-- 为知识片段表增加全文索引及关联ID +ALTER TABLE knowledge_fragment ADD COLUMN knowledge_id bigint COMMENT '知识库ID'; +ALTER TABLE knowledge_fragment ADD FULLTEXT INDEX ft_content (content) WITH PARSER ngram; + +-- 为知识库附件表增加解析状态字段 +ALTER TABLE `knowledge_attach` ADD COLUMN `status` TINYINT DEFAULT 0 COMMENT '解析状态: 0待解析, 1解析中, 2已解析, 3解析失败'; diff --git a/ruoyi-common/ruoyi-common-core/src/main/java/org/ruoyi/common/core/config/ThreadPoolConfig.java b/ruoyi-common/ruoyi-common-core/src/main/java/org/ruoyi/common/core/config/ThreadPoolConfig.java index cc218c21..313d2782 100644 --- a/ruoyi-common/ruoyi-common-core/src/main/java/org/ruoyi/common/core/config/ThreadPoolConfig.java +++ b/ruoyi-common/ruoyi-common-core/src/main/java/org/ruoyi/common/core/config/ThreadPoolConfig.java @@ -10,6 +10,7 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties import org.springframework.context.annotation.Bean; import org.springframework.core.task.VirtualThreadTaskExecutor; +import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import java.util.concurrent.*; /** @@ -22,6 +23,12 @@ import java.util.concurrent.*; @EnableConfigurationProperties(ThreadPoolProperties.class) public class ThreadPoolConfig { + private final ThreadPoolProperties properties; + + public ThreadPoolConfig(ThreadPoolProperties properties) { + this.properties = properties; + } + /** * 核心线程数 = cpu 核心数 + 1 */ @@ -54,6 +61,22 @@ public class ThreadPoolConfig { return scheduledThreadPoolExecutor; } + /** + * 知识库解析专用异步线程池 + */ + @Bean(name = "knowledgeParseExecutor") + public ThreadPoolTaskExecutor knowledgeParseExecutor() { + ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor(); + executor.setCorePoolSize(core); + executor.setMaxPoolSize(core * 2); + executor.setQueueCapacity(properties.getQueueCapacity()); + executor.setKeepAliveSeconds(properties.getKeepAliveSeconds()); + executor.setThreadNamePrefix("knowledge-parse-pool-"); + executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy()); + executor.initialize(); + return executor; + } + /** * 销毁事件 * 停止线程池 diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/controller/knowledge/KnowledgeAttachController.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/controller/knowledge/KnowledgeAttachController.java index 6367b618..e2e1b4e8 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/controller/knowledge/KnowledgeAttachController.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/controller/knowledge/KnowledgeAttachController.java @@ -110,6 +110,17 @@ public class KnowledgeAttachController extends BaseController { @PostMapping(value = "/upload") public R upload(KnowledgeInfoUploadBo bo){ knowledgeAttachService.upload(bo); - return R.ok("上传知识库附件成功!"); + return R.ok("上传成功!"); + } + + /** + * 手动解析附件内容 + * + * @param id 附件ID + */ + @PostMapping("/parse/{id}") + public R parse(@PathVariable Long id) { + knowledgeAttachService.parse(id); + return R.ok(); } } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/controller/knowledge/KnowledgeFragmentController.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/controller/knowledge/KnowledgeFragmentController.java index ac79a62e..d739a30f 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/controller/knowledge/KnowledgeFragmentController.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/controller/knowledge/KnowledgeFragmentController.java @@ -8,6 +8,7 @@ import jakarta.validation.constraints.*; import cn.dev33.satoken.annotation.SaCheckPermission; import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo; import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import org.ruoyi.service.knowledge.IKnowledgeFragmentService; import org.springframework.web.bind.annotation.*; import org.springframework.validation.annotation.Validated; @@ -102,4 +103,12 @@ public class KnowledgeFragmentController extends BaseController { @PathVariable Long[] ids) { return toAjax(knowledgeFragmentService.deleteWithValidByIds(List.of(ids), true)); } + + /** + * 检索测试 + */ + @PostMapping("/retrieval") + public R> retrieval(@RequestBody KnowledgeFragmentBo bo) { + return R.ok(knowledgeFragmentService.retrieval(bo)); + } } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeFragmentBo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeFragmentBo.java index 7472f57f..895ab69f 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeFragmentBo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeFragmentBo.java @@ -49,5 +49,44 @@ public class KnowledgeFragmentBo extends BaseEntity { */ private String remark; + /** + * 知识库ID + */ + private Long knowledgeId; + + /** + * 检索内容 + */ + private String query; + + /** + * 返回条数 + */ + private Integer topK; + + /** + * 相似度阈值 + */ + private Double threshold; + + /** + * 是否启用重排 + */ + private Boolean enableRerank; + + /** + * 重排模型名称 + */ + private String rerankModel; + + /** + * 是否启用混合检索 + */ + private Boolean enableHybrid; + + /** + * 混合检索权重 (0.0-1.0) + */ + private Double hybridAlpha; } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoBo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoBo.java index 4d5cf619..b0045a1d 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoBo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoBo.java @@ -62,6 +62,11 @@ public class KnowledgeInfoBo extends BaseEntity { */ private Long retrieveLimit; + /** + * 相似度阈值 + */ + private Double similarityThreshold; + /** * 文本块大小 */ @@ -98,12 +103,19 @@ public class KnowledgeInfoBo extends BaseEntity { private Double rerankScoreThreshold; + /** + * 是否启用混合检索(0 否 1是) + */ + private Integer enableHybrid; + + /** + * 混合检索权重 (0.0-1.0) + */ + private Double hybridAlpha; + /** * 备注 */ private String remark; - - - } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoUploadBo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoUploadBo.java index a2e1abd0..04c39d05 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoUploadBo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoUploadBo.java @@ -16,6 +16,11 @@ public class KnowledgeInfoUploadBo { private MultipartFile file; + /** + * 是否自动解析 (true: 立即解析, false: 仅上传) + */ + private Boolean autoParse; + /** * 生效时间, 为空则立即生效 */ diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/vector/QueryVectorBo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/vector/QueryVectorBo.java index bb5634a3..6f0b9352 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/vector/QueryVectorBo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/vector/QueryVectorBo.java @@ -77,4 +77,22 @@ public class QueryVectorBo { */ private Double rerankScoreThreshold; + // ========== 混合检索与阈值相关参数 ========== + + /** + * 相似度阈值 (0.0-1.0) + * 应用于向量搜索阶段 + */ + private Double similarityThreshold; + + /** + * 是否启用混合检索 + */ + private Boolean enableHybrid = false; + + /** + * 混合检索权重 (0.0-1.0) + */ + private Double hybridAlpha; + } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeAttach.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeAttach.java index 486c2466..2d836f88 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeAttach.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeAttach.java @@ -57,5 +57,10 @@ public class KnowledgeAttach extends BaseEntity { */ private String remark; + /** + * 解析状态: 0待解析, 1解析中, 2已解析, 3解析失败 + */ + private Integer status; + } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeFragment.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeFragment.java index 04716e15..d184a10a 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeFragment.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeFragment.java @@ -47,5 +47,10 @@ public class KnowledgeFragment extends BaseEntity { */ private String remark; + /** + * 知识库ID + */ + private Long knowledgeId; + } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeInfo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeInfo.java index 948d04c0..95bd9286 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeInfo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeInfo.java @@ -63,6 +63,11 @@ public class KnowledgeInfo extends BaseEntity { */ private Long retrieveLimit; + /** + * 相似度阈值 + */ + private Double similarityThreshold; + /** * 文本块大小 */ @@ -98,6 +103,16 @@ public class KnowledgeInfo extends BaseEntity { */ private Double rerankScoreThreshold; + /** + * 是否启用混合检索(0 否 1是) + */ + private Integer enableHybrid; + + /** + * 混合检索权重 (0.0-1.0) + */ + private Double hybridAlpha; + /** * 备注 */ diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/DocFragmentCountVo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/DocFragmentCountVo.java new file mode 100644 index 00000000..f162d93b --- /dev/null +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/DocFragmentCountVo.java @@ -0,0 +1,20 @@ +package org.ruoyi.domain.vo.knowledge; + +import lombok.Data; + +/** + * 文档分块数统计 VO(用于 GROUP BY 查询结果接收) + */ +@Data +public class DocFragmentCountVo { + + /** + * 文档ID(关联 knowledge_attach.doc_id) + */ + private String docId; + + /** + * 该文档下的分块数量 + */ + private Integer fragmentCount; +} diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeAttachVo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeAttachVo.java index 8d3f2f08..09fbd62e 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeAttachVo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeAttachVo.java @@ -8,6 +8,7 @@ import org.ruoyi.domain.entity.knowledge.KnowledgeAttach; import java.io.Serial; import java.io.Serializable; +import java.util.Date; @@ -68,5 +69,22 @@ public class KnowledgeAttachVo implements Serializable { @ExcelProperty(value = "备注") private String remark; + /** + * 上传时间(来自 BaseEntity.createTime) + */ + @ExcelProperty(value = "上传时间") + private Date createTime; + + /** + * 解析状态: 0待解析, 1解析中, 2已解析, 3解析失败 + */ + @ExcelProperty(value = "解析状态") + private Integer status; + + /** + * 分块数(统计字段,非数据库列) + */ + private Integer fragmentCount; + } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeFragmentVo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeFragmentVo.java index b8be695e..45b4a9ab 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeFragmentVo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeFragmentVo.java @@ -39,7 +39,7 @@ public class KnowledgeFragmentVo implements Serializable { * 片段索引下标 */ @ExcelProperty(value = "片段索引下标") - private Long idx; + private Integer idx; /** * 文档内容 @@ -53,5 +53,10 @@ public class KnowledgeFragmentVo implements Serializable { @ExcelProperty(value = "备注") private String remark; + /** + * 知识库ID + */ + private Long knowledgeId; + } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java index 41d48480..d742fad0 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java @@ -76,6 +76,12 @@ public class KnowledgeInfoVo implements Serializable { @ExcelProperty(value = "知识库中检索的条数") private Integer retrieveLimit; + /** + * 相似度阈值 + */ + @ExcelProperty(value = "相似度阈值") + private Double similarityThreshold; + /** * 文本块大小 */ @@ -118,6 +124,24 @@ public class KnowledgeInfoVo implements Serializable { @ExcelProperty(value = "重排序分数阈值") private Double rerankScoreThreshold; + /** + * 是否启用混合检索(0 否 1是) + */ + @ExcelProperty(value = "是否启用混合检索") + private Integer enableHybrid; + + /** + * 混合检索权重 (0.0-1.0) + */ + @ExcelProperty(value = "混合检索权重") + private Double hybridAlpha; + + /** + * 文档数量 + */ + @ExcelProperty(value = "文档数量") + private Integer documentCount; + /** * 备注 */ diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeRetrievalVo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeRetrievalVo.java new file mode 100644 index 00000000..420015d8 --- /dev/null +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeRetrievalVo.java @@ -0,0 +1,69 @@ +package org.ruoyi.domain.vo.knowledge; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.io.Serial; +import java.io.Serializable; + +/** + * 知识检索测试结果视图对象 + * + * @author RobustH + */ +@Data +@Builder +@NoArgsConstructor +@AllArgsConstructor +public class KnowledgeRetrievalVo implements Serializable { + + @Serial + private static final long serialVersionUID = 1L; + + /** + * 片段ID + */ + private String id; + + /** + * 文档ID + */ + private String docId; + + /** + * 知识库ID + */ + private Long knowledgeId; + + /** + * 分片索引 + */ + private Integer idx; + + /** + * 片段内容 + */ + private String content; + + /** + * 相似度得分 + */ + private Double score; + + /** + * 原始检索排名 (重排前) + */ + private Integer originalIndex; + + /** + * 原始检索得分 (重排前) + */ + private Double rawScore; + + /** + * 来源文档名称 + */ + private String sourceName; +} diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/enums/KnowledgeAttachStatus.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/enums/KnowledgeAttachStatus.java new file mode 100644 index 00000000..ea13da11 --- /dev/null +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/enums/KnowledgeAttachStatus.java @@ -0,0 +1,38 @@ +package org.ruoyi.enums; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * 知识库附件解析状态枚举 + * + * @author RobustH + */ +@Getter +@AllArgsConstructor +public enum KnowledgeAttachStatus { + + /** + * 待解析 + */ + WAITING(0, "待解析"), + + /** + * 解析中 + */ + PARSING(1, "解析中"), + + /** + * 已解析 + */ + COMPLETED(2, "已解析"), + + /** + * 解析失败 + */ + FAILED(3, "解析失败"); + + private final Integer code; + private final String info; + +} diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeAttachMapper.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeAttachMapper.java index 71ab3b28..05b4d780 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeAttachMapper.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeAttachMapper.java @@ -1,5 +1,8 @@ package org.ruoyi.mapper.knowledge; +import org.apache.ibatis.annotations.Mapper; +import org.apache.ibatis.annotations.Param; +import org.apache.ibatis.annotations.Select; import org.ruoyi.domain.entity.knowledge.KnowledgeAttach; import org.ruoyi.domain.vo.knowledge.KnowledgeAttachVo; import org.ruoyi.common.mybatis.core.mapper.BaseMapperPlus; @@ -10,6 +13,12 @@ import org.ruoyi.common.mybatis.core.mapper.BaseMapperPlus; * @author ageerle * @date 2025-12-17 */ +@Mapper public interface KnowledgeAttachMapper extends BaseMapperPlus { + /** + * 统计指定知识库下的文档数量 + */ + @Select("SELECT COUNT(*) FROM knowledge_attach WHERE knowledge_id = #{knowledgeId}") + int countByKnowledgeId(@Param("knowledgeId") Long knowledgeId); } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeFragmentMapper.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeFragmentMapper.java index b16995fa..304bb7d5 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeFragmentMapper.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeFragmentMapper.java @@ -1,15 +1,45 @@ package org.ruoyi.mapper.knowledge; +import org.apache.ibatis.annotations.Mapper; +import org.apache.ibatis.annotations.Param; +import org.apache.ibatis.annotations.Select; import org.ruoyi.domain.entity.knowledge.KnowledgeFragment; +import org.ruoyi.domain.vo.knowledge.DocFragmentCountVo; import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo; import org.ruoyi.common.mybatis.core.mapper.BaseMapperPlus; +import java.util.List; + /** * 知识片段Mapper接口 * * @author ageerle * @date 2025-12-17 */ +@Mapper public interface KnowledgeFragmentMapper extends BaseMapperPlus { + /** + * 批量统计各文档的分块数(强类型接收,避免 Map key 大小写问题) + * + * @param docIds 文档 ID 列表 + * @return 每个 docId 对应的分块数列表 + */ + @Select("") + List selectFragmentCountByDocIds(@Param("docIds") List docIds); + @Select("") + List searchByKeyword(@Param("knowledgeId") Long knowledgeId, @Param("query") String query, @Param("limit") Integer limit); } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java index 3bd0876e..16e0750d 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java @@ -20,6 +20,11 @@ import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; import dev.langchain4j.model.openai.OpenAiChatModel; import dev.langchain4j.service.tool.ToolProvider; import dev.langchain4j.skills.shell.ShellSkills; +import dev.langchain4j.rag.AugmentationRequest; +import dev.langchain4j.rag.AugmentationResult; +import dev.langchain4j.rag.DefaultRetrievalAugmentor; +import dev.langchain4j.rag.RetrievalAugmentor; +import dev.langchain4j.rag.query.Metadata; import lombok.RequiredArgsConstructor; import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; @@ -55,6 +60,7 @@ import org.ruoyi.service.chat.IChatMessageService; import org.ruoyi.service.chat.impl.memory.PersistentChatMemoryStore; import org.ruoyi.service.knowledge.IKnowledgeInfoService; import org.ruoyi.service.retrieval.KnowledgeRetrievalService; +import org.ruoyi.service.knowledge.retriever.CustomVectorRetriever; import org.ruoyi.service.vector.VectorStoreService; import org.springframework.stereotype.Service; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; @@ -412,16 +418,49 @@ public class ChatServiceFacade implements IChatService { /** * 构建上下文消息列表 - * 消息顺序:历史消息 → 当前用户消息(确保 AI 正确理解对话上下文) * * @param chatRequest 聊天请求 * @return 上下文消息列表 */ private List buildContextMessages(ChatRequest chatRequest) { - List messages = new ArrayList<>(); + List messages = new ArrayList<>(); - // 从数据库查询历史对话消息(放在前面) + // 1. 初始化当前用户消息 + UserMessage userMessage = UserMessage.userMessage(chatRequest.getContent()); + + // 2. 知识库检索增强 (RAG) + if (chatRequest.getKnowledgeId() != null) { + KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(chatRequest.getKnowledgeId())); + if (knowledgeInfoVo != null) { + ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel()); + if (chatModel != null) { + log.info("执行高级 RAG 流程: kid={}", chatRequest.getKnowledgeId()); + + // 构建自定义检索器 + CustomVectorRetriever retriever = new CustomVectorRetriever( + knowledgeRetrievalService, knowledgeInfoVo, chatModel); + + // 构建增强流水线 + RetrievalAugmentor augmentor = DefaultRetrievalAugmentor.builder() + .contentRetriever(retriever) + .build(); + + // 执行增强:编织上下文到 UserMessage + Metadata metadata = Metadata.from(userMessage, chatRequest.getSessionId(), new ArrayList<>()); + AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata); + AugmentationResult result = augmentor.augment(augmentationRequest); + + ChatMessage augmented = result.chatMessage(); + if (augmented instanceof UserMessage) { + userMessage = (UserMessage) augmented; + log.debug("RAG 增强完成,UserMessage 已注入背景知识"); + } + } + } + } + + // 3. 从数据库查询历史对话消息(放在前面) if (chatRequest.getSessionId() != null) { MessageWindowChatMemory memory = createChatMemory(chatRequest.getSessionId()); if (memory != null) { @@ -433,38 +472,7 @@ public class ChatServiceFacade implements IChatService { } } - // 从向量库查询相关历史消息(知识库内容作为上下文) - if (chatRequest.getKnowledgeId() != null) { - // 查询知识库信息 - KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(chatRequest.getKnowledgeId())); - if (knowledgeInfoVo == null) { - log.warn("知识库信息不存在,kid: {}", chatRequest.getKnowledgeId()); - // 继续添加当前用户消息 - messages.add(UserMessage.userMessage(chatRequest.getContent())); - return messages; - } - - // 查询向量模型配置信息 - ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel()); - if (chatModel == null) { - log.warn("向量模型配置不存在,模型名称: {}", knowledgeInfoVo.getEmbeddingModel()); - messages.add(UserMessage.userMessage(chatRequest.getContent())); - return messages; - } - - // 构建向量查询参数 - QueryVectorBo queryVectorBo = buildQueryVectorBo(chatRequest, knowledgeInfoVo, chatModel); - - // 使用知识库检索服务(支持重排序) - List nearestList = knowledgeRetrievalService.retrieveTexts(queryVectorBo); - for (String prompt : nearestList) { - // 知识库内容作为系统上下文添加 - messages.add(new AiMessage(prompt)); - } - } - - // 构建当前用户消息(放在最后) - UserMessage userMessage = UserMessage.userMessage(chatRequest.getContent()); + // 4. 添加经过增强的用户消息(放在最后) messages.add(userMessage); return messages; diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/IKnowledgeAttachService.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/IKnowledgeAttachService.java index 013c9736..75d0a528 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/IKnowledgeAttachService.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/IKnowledgeAttachService.java @@ -72,4 +72,11 @@ public interface IKnowledgeAttachService { * 上传附件 */ void upload(KnowledgeInfoUploadBo bo); + + /** + * 解析附件知识片段 + * + * @param id 附件ID + */ + void parse(Long id); } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/IKnowledgeFragmentService.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/IKnowledgeFragmentService.java index e323e79e..b8b88b45 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/IKnowledgeFragmentService.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/IKnowledgeFragmentService.java @@ -4,6 +4,7 @@ import org.ruoyi.common.mybatis.core.page.TableDataInfo; import org.ruoyi.common.mybatis.core.page.PageQuery; import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo; import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import java.util.Collection; import java.util.List; @@ -65,4 +66,12 @@ public interface IKnowledgeFragmentService { * @return 是否删除成功 */ Boolean deleteWithValidByIds(Collection ids, Boolean isValid); + + /** + * 检索测试 + * + * @param bo 检索参数 + * @return 检索结果 + */ + List retrieval(KnowledgeFragmentBo bo); } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeAttachServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeAttachServiceImpl.java index ce02785e..28b3d90f 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeAttachServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeAttachServiceImpl.java @@ -2,24 +2,27 @@ package org.ruoyi.service.knowledge.impl; import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.RandomUtil; -import org.ruoyi.common.chat.service.chat.IChatModelService; +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; +import com.baomidou.mybatisplus.core.toolkit.Wrappers; +import com.baomidou.mybatisplus.extension.plugins.pagination.Page; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo; +import org.ruoyi.common.chat.service.chat.IChatModelService; +import org.ruoyi.enums.KnowledgeAttachStatus; import org.ruoyi.common.core.domain.dto.OssDTO; import org.ruoyi.common.core.service.OssService; import org.ruoyi.common.core.utils.MapstructUtils; +import org.ruoyi.common.core.utils.SpringUtils; import org.ruoyi.common.core.utils.StringUtils; -import org.ruoyi.common.mybatis.core.page.TableDataInfo; import org.ruoyi.common.mybatis.core.page.PageQuery; -import com.baomidou.mybatisplus.extension.plugins.pagination.Page; -import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; -import com.baomidou.mybatisplus.core.toolkit.Wrappers; -import lombok.RequiredArgsConstructor; -import lombok.extern.slf4j.Slf4j; +import org.ruoyi.common.mybatis.core.page.TableDataInfo; import org.ruoyi.domain.bo.knowledge.KnowledgeAttachBo; import org.ruoyi.domain.bo.knowledge.KnowledgeInfoUploadBo; import org.ruoyi.domain.bo.vector.StoreEmbeddingBo; import org.ruoyi.domain.entity.knowledge.KnowledgeAttach; import org.ruoyi.domain.entity.knowledge.KnowledgeFragment; +import org.ruoyi.domain.vo.knowledge.DocFragmentCountVo; import org.ruoyi.domain.vo.knowledge.KnowledgeAttachVo; import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo; import org.ruoyi.factory.ResourceLoaderFactory; @@ -29,11 +32,15 @@ import org.ruoyi.service.knowledge.IKnowledgeAttachService; import org.ruoyi.service.knowledge.IKnowledgeInfoService; import org.ruoyi.service.knowledge.ResourceLoader; import org.ruoyi.service.vector.VectorStoreService; +import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service; import org.springframework.web.multipart.MultipartFile; -import java.io.IOException; +import java.io.InputStream; + +import java.net.URL; import java.util.*; +import java.util.stream.Collectors; /** * 知识库附件Service业务层处理 @@ -47,57 +54,51 @@ import java.util.*; public class KnowledgeAttachServiceImpl implements IKnowledgeAttachService { private final KnowledgeAttachMapper baseMapper; - private final IKnowledgeInfoService knowledgeInfoService; - private final KnowledgeFragmentMapper knowledgeFragmentMapper; - private final IChatModelService chatModelService; - private final ResourceLoaderFactory resourceLoaderFactory; - private final VectorStoreService vectorStoreService; - private final OssService ossService; - /** - * 查询知识库附件 - * - * @param id 主键 - * @return 知识库附件 - */ + @Override - public KnowledgeAttachVo queryById(Long id){ + public KnowledgeAttachVo queryById(Long id) { return baseMapper.selectVoById(id); } - /** - * 分页查询知识库附件列表 - * - * @param bo 查询条件 - * @param pageQuery 分页参数 - * @return 知识库附件分页列表 - */ @Override public TableDataInfo queryPageList(KnowledgeAttachBo bo, PageQuery pageQuery) { LambdaQueryWrapper lqw = buildQueryWrapper(bo); Page result = baseMapper.selectVoPage(pageQuery.build(), lqw); + fillFragmentCount(result.getRecords()); return TableDataInfo.build(result); } - /** - * 查询符合条件的知识库附件列表 - * - * @param bo 查询条件 - * @return 知识库附件列表 - */ @Override public List queryList(KnowledgeAttachBo bo) { LambdaQueryWrapper lqw = buildQueryWrapper(bo); - return baseMapper.selectVoList(lqw); + List list = baseMapper.selectVoList(lqw); + fillFragmentCount(list); + return list; + } + + private void fillFragmentCount(List records) { + if (records == null || records.isEmpty()) return; + List docIds = records.stream() + .map(KnowledgeAttachVo::getDocId) + .filter(StringUtils::isNotBlank) + .distinct() + .collect(Collectors.toList()); + if (docIds.isEmpty()) return; + List countList = knowledgeFragmentMapper.selectFragmentCountByDocIds(docIds); + Map countMap = countList.stream() + .collect(Collectors.toMap(DocFragmentCountVo::getDocId, DocFragmentCountVo::getFragmentCount, (k1, k2) -> k1)); + for (KnowledgeAttachVo vo : records) { + vo.setFragmentCount(countMap.getOrDefault(vo.getDocId(), 0)); + } } private LambdaQueryWrapper buildQueryWrapper(KnowledgeAttachBo bo) { - Map params = bo.getParams(); LambdaQueryWrapper lqw = Wrappers.lambdaQuery(); lqw.orderByAsc(KnowledgeAttach::getId); lqw.eq(bo.getKnowledgeId() != null, KnowledgeAttach::getKnowledgeId, bo.getKnowledgeId()); @@ -107,16 +108,9 @@ public class KnowledgeAttachServiceImpl implements IKnowledgeAttachService { return lqw; } - /** - * 新增知识库附件 - * - * @param bo 知识库附件 - * @return 是否新增成功 - */ @Override public Boolean insertByBo(KnowledgeAttachBo bo) { KnowledgeAttach add = MapstructUtils.convert(bo, KnowledgeAttach.class); - validEntityBeforeSave(add); boolean flag = baseMapper.insert(add) > 0; if (flag) { bo.setId(add.getId()); @@ -124,98 +118,109 @@ public class KnowledgeAttachServiceImpl implements IKnowledgeAttachService { return flag; } - /** - * 修改知识库附件 - * - * @param bo 知识库附件 - * @return 是否修改成功 - */ @Override public Boolean updateByBo(KnowledgeAttachBo bo) { KnowledgeAttach update = MapstructUtils.convert(bo, KnowledgeAttach.class); - validEntityBeforeSave(update); return baseMapper.updateById(update) > 0; } - /** - * 保存前的数据校验 - */ - private void validEntityBeforeSave(KnowledgeAttach entity){ - //TODO 做一些数据校验,如唯一约束 - } - - /** - * 校验并批量删除知识库附件信息 - * - * @param ids 待删除的主键集合 - * @param isValid 是否进行有效性校验 - * @return 是否删除成功 - */ @Override public Boolean deleteWithValidByIds(Collection ids, Boolean isValid) { - if(isValid){ - //TODO 做一些业务上的校验,判断是否需要校验 - } return baseMapper.deleteByIds(ids) > 0; } @Override public void upload(KnowledgeInfoUploadBo bo) { MultipartFile file = bo.getFile(); - // 保存文件信息 OssDTO ossDTO = ossService.uploadFile(file); - Long knowledgeId = bo.getKnowledgeId(); - List chunkList = new ArrayList<>(); + KnowledgeAttach knowledgeAttach = new KnowledgeAttach(); knowledgeAttach.setKnowledgeId(bo.getKnowledgeId()); - String docId = RandomUtil.randomString(10); knowledgeAttach.setOssId(ossDTO.getOssId()); - knowledgeAttach.setDocId(docId); + knowledgeAttach.setDocId(RandomUtil.randomString(10)); knowledgeAttach.setName(ossDTO.getOriginalName()); knowledgeAttach.setType(ossDTO.getFileSuffix()); - String content = ""; - ResourceLoader resourceLoader = resourceLoaderFactory.getLoaderByFileType(knowledgeAttach.getType()); - // 文档分段入库 - List fids = new ArrayList<>(); + knowledgeAttach.setStatus(KnowledgeAttachStatus.WAITING.getCode()); // 待解析 + + baseMapper.insert(knowledgeAttach); + + if (Boolean.TRUE.equals(bo.getAutoParse())) { + // 通过 SpringUtils 获取代理对象,确保 @Async 生效 + SpringUtils.getBean(IKnowledgeAttachService.class).parse(knowledgeAttach.getId()); + } + } + + @Async("knowledgeParseExecutor") + @Override + public void parse(Long id) { + KnowledgeAttach attach = baseMapper.selectById(id); + if (attach == null || (!KnowledgeAttachStatus.WAITING.getCode().equals(attach.getStatus()) && !KnowledgeAttachStatus.FAILED.getCode().equals(attach.getStatus()))) { + return; + } + try { - content = resourceLoader.getContent(file.getInputStream()); - chunkList = resourceLoader.getChunkList(content, String.valueOf(knowledgeId)); + attach.setStatus(KnowledgeAttachStatus.PARSING.getCode()); // 解析中 + baseMapper.updateById(attach); + + log.info("开始解析知识库文档... id: {}, docId: {}", id, attach.getDocId()); + + Long knowledgeId = attach.getKnowledgeId(); + String docId = attach.getDocId(); + + // 获取文件信息并下载 + List ossDTOs = ossService.selectByIds(String.valueOf(attach.getOssId())); + if (ossDTOs == null || ossDTOs.isEmpty()) { + throw new RuntimeException("未找到对应的 OSS 文件信息"); + } + OssDTO ossDTO = ossDTOs.get(0); + String content; + ResourceLoader resourceLoader = resourceLoaderFactory.getLoaderByFileType(attach.getType()); + try (InputStream inputStream = new URL(ossDTO.getUrl()).openStream()) { + content = resourceLoader.getContent(inputStream); + } + List chunkList = resourceLoader.getChunkList(content, String.valueOf(knowledgeId)); + + List fids = new ArrayList<>(); List knowledgeFragmentList = new ArrayList<>(); if (CollUtil.isNotEmpty(chunkList)) { for (int i = 0; i < chunkList.size(); i++) { - // 生成知识片段ID String fid = RandomUtil.randomString(10); fids.add(fid); KnowledgeFragment knowledgeFragment = new KnowledgeFragment(); + knowledgeFragment.setKnowledgeId(knowledgeId); knowledgeFragment.setDocId(docId); knowledgeFragment.setIdx(i); knowledgeFragment.setContent(chunkList.get(i)); knowledgeFragment.setCreateTime(new Date()); knowledgeFragmentList.add(knowledgeFragment); } + knowledgeFragmentMapper.delete(Wrappers.lambdaQuery().eq(KnowledgeFragment::getDocId, docId)); + knowledgeFragmentMapper.insertBatch(knowledgeFragmentList); + log.info("文档切片并入库完成,共计 {} 个片段。id: {}", chunkList.size(), id); } - knowledgeFragmentMapper.insertBatch(knowledgeFragmentList); - } catch (IOException e) { - log.error("保存知识库信息失败!{}", e.getMessage()); + + KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(knowledgeId); + ChatModelVo chatModelVo = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel()); + + StoreEmbeddingBo storeEmbeddingBo = new StoreEmbeddingBo(); + storeEmbeddingBo.setKid(String.valueOf(knowledgeId)); + storeEmbeddingBo.setDocId(docId); + storeEmbeddingBo.setFids(fids); + storeEmbeddingBo.setChunkList(chunkList); + storeEmbeddingBo.setVectorStoreName(knowledgeInfoVo.getVectorModel()); + storeEmbeddingBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel()); + storeEmbeddingBo.setApiKey(chatModelVo.getApiKey()); + storeEmbeddingBo.setBaseUrl(chatModelVo.getApiHost()); + vectorStoreService.storeEmbeddings(storeEmbeddingBo); + + attach.setStatus(KnowledgeAttachStatus.COMPLETED.getCode()); // 已完成 + baseMapper.updateById(attach); + log.info("知识库文档解析、向量化并入库成功!id: {}", id); + } catch (Exception e) { + log.error("解析文档失败!id: {}, error: {}", id, e.getMessage(), e); + attach.setStatus(KnowledgeAttachStatus.FAILED.getCode()); // 失败 + attach.setRemark(StringUtils.substring(e.getMessage(), 0, 255)); // 保存错误原因,截取防止溢出 + baseMapper.updateById(attach); } - baseMapper.insert(knowledgeAttach); - - // 查询知识库信息 - KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(knowledgeId); - - // 查询向量模信息 - ChatModelVo chatModelVo = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel()); - - StoreEmbeddingBo storeEmbeddingBo = new StoreEmbeddingBo(); - storeEmbeddingBo.setKid(String.valueOf(knowledgeId)); - storeEmbeddingBo.setDocId(docId); - storeEmbeddingBo.setFids(fids); - storeEmbeddingBo.setChunkList(chunkList); - storeEmbeddingBo.setVectorStoreName(knowledgeInfoVo.getVectorModel()); - storeEmbeddingBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel()); - storeEmbeddingBo.setApiKey(chatModelVo.getApiKey()); - storeEmbeddingBo.setBaseUrl(chatModelVo.getApiHost()); - vectorStoreService.storeEmbeddings(storeEmbeddingBo); } - } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java index 782b40af..da17f9b9 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java @@ -1,24 +1,29 @@ package org.ruoyi.service.knowledge.impl; -import org.ruoyi.common.core.utils.MapstructUtils; -import org.ruoyi.common.core.utils.StringUtils; -import org.ruoyi.common.mybatis.core.page.TableDataInfo; -import org.ruoyi.common.mybatis.core.page.PageQuery; -import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.baomidou.mybatisplus.core.toolkit.Wrappers; +import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo; +import org.ruoyi.common.chat.service.chat.IChatModelService; +import org.ruoyi.common.core.utils.MapstructUtils; +import org.ruoyi.common.core.utils.StringUtils; +import org.ruoyi.common.mybatis.core.page.PageQuery; +import org.ruoyi.common.mybatis.core.page.TableDataInfo; import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo; +import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.domain.entity.knowledge.KnowledgeFragment; import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo; +import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import org.ruoyi.mapper.knowledge.KnowledgeFragmentMapper; import org.ruoyi.service.knowledge.IKnowledgeFragmentService; +import org.ruoyi.service.knowledge.IKnowledgeInfoService; +import org.ruoyi.service.retrieval.KnowledgeRetrievalService; import org.springframework.stereotype.Service; -import java.util.List; -import java.util.Map; -import java.util.Collection; +import java.util.*; /** * 知识片段Service业务层处理 @@ -32,6 +37,9 @@ import java.util.Collection; public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService { private final KnowledgeFragmentMapper baseMapper; + private final IKnowledgeInfoService knowledgeInfoService; + private final IChatModelService chatModelService; + private final KnowledgeRetrievalService knowledgeRetrievalService; /** * 查询知识片段 @@ -71,7 +79,6 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService { } private LambdaQueryWrapper buildQueryWrapper(KnowledgeFragmentBo bo) { - Map params = bo.getParams(); LambdaQueryWrapper lqw = Wrappers.lambdaQuery(); lqw.orderByAsc(KnowledgeFragment::getId); lqw.eq(bo.getDocId() != null, KnowledgeFragment::getDocId, bo.getDocId()); @@ -131,4 +138,50 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService { } return baseMapper.deleteByIds(ids) > 0; } + + /** + * 检索测试核心实现 - 委托给统一的 KnowledgeRetrievalService + */ + @Override + public List retrieval(KnowledgeFragmentBo bo) { + if (bo.getKnowledgeId() == null || StringUtils.isBlank(bo.getQuery())) { + return new ArrayList<>(); + } + + // 1. 获取知识库及模型配置(为了获取 API Key/Host 等模型参数) + KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(bo.getKnowledgeId()); + if (knowledgeInfoVo == null) { + return new ArrayList<>(); + } + + ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel()); + if (chatModel == null) { + log.warn("未找到对应的向量模型配置: {}", knowledgeInfoVo.getEmbeddingModel()); + return new ArrayList<>(); + } + + // 2. 构造通用的参数对象 + QueryVectorBo queryVectorBo = new QueryVectorBo(); + queryVectorBo.setQuery(bo.getQuery()); + queryVectorBo.setKid(String.valueOf(bo.getKnowledgeId())); + queryVectorBo.setApiKey(chatModel.getApiKey()); + queryVectorBo.setBaseUrl(chatModel.getApiHost()); + queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel()); + queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModel()); + + // 使用前端传入的实时测试参数,若无则使用知识库默认参数 + queryVectorBo.setMaxResults(bo.getTopK() != null ? bo.getTopK() : knowledgeInfoVo.getRetrieveLimit()); + queryVectorBo.setSimilarityThreshold(bo.getThreshold() != null ? bo.getThreshold() : knowledgeInfoVo.getSimilarityThreshold()); + + queryVectorBo.setEnableHybrid(bo.getEnableHybrid() != null ? bo.getEnableHybrid() : Objects.equals(knowledgeInfoVo.getEnableHybrid(), 1)); + queryVectorBo.setHybridAlpha(bo.getHybridAlpha() != null ? bo.getHybridAlpha() : knowledgeInfoVo.getHybridAlpha()); + + queryVectorBo.setEnableRerank(bo.getEnableRerank() != null ? bo.getEnableRerank() : Objects.equals(knowledgeInfoVo.getEnableRerank(), 1)); + queryVectorBo.setRerankModelName(StringUtils.isNotBlank(bo.getRerankModel()) ? bo.getRerankModel() : knowledgeInfoVo.getRerankModel()); + queryVectorBo.setRerankTopN(bo.getTopK() != null ? bo.getTopK() : knowledgeInfoVo.getRerankTopN()); + queryVectorBo.setRerankScoreThreshold(bo.getThreshold() != null ? bo.getThreshold() : knowledgeInfoVo.getRerankScoreThreshold()); + + // 3. 执行统一检索 + return knowledgeRetrievalService.retrieve(queryVectorBo); + } } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeInfoServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeInfoServiceImpl.java index 6e7a8d7b..266fd5c8 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeInfoServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeInfoServiceImpl.java @@ -12,6 +12,7 @@ import lombok.extern.slf4j.Slf4j; import org.ruoyi.domain.bo.knowledge.KnowledgeInfoBo; import org.ruoyi.domain.entity.knowledge.KnowledgeInfo; import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo; +import org.ruoyi.mapper.knowledge.KnowledgeAttachMapper; import org.ruoyi.mapper.knowledge.KnowledgeInfoMapper; import org.ruoyi.service.knowledge.IKnowledgeInfoService; import org.springframework.stereotype.Service; @@ -33,6 +34,8 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService { private final KnowledgeInfoMapper baseMapper; + private final KnowledgeAttachMapper knowledgeAttachMapper; + /** * 查询知识库 * @@ -55,6 +58,8 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService { public TableDataInfo queryPageList(KnowledgeInfoBo bo, PageQuery pageQuery) { LambdaQueryWrapper lqw = buildQueryWrapper(bo); Page result = baseMapper.selectVoPage(pageQuery.build(), lqw); + // 批量填充文档数 + fillDocumentCount(result.getRecords()); return TableDataInfo.build(result); } @@ -87,6 +92,17 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService { return lqw; } + /** + * 批量填充知识库列表每一条记录的文档数(documentCount) + */ + private void fillDocumentCount(List records) { + if (records == null || records.isEmpty()) return; + for (KnowledgeInfoVo vo : records) { + int count = knowledgeAttachMapper.countByKnowledgeId(vo.getId()); + vo.setDocumentCount(count); + } + } + /** * 新增知识库 * diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/retriever/CustomVectorRetriever.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/retriever/CustomVectorRetriever.java new file mode 100644 index 00000000..f79206bc --- /dev/null +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/retriever/CustomVectorRetriever.java @@ -0,0 +1,65 @@ +package org.ruoyi.service.knowledge.retriever; + +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.rag.content.Content; +import dev.langchain4j.rag.content.retriever.ContentRetriever; +import dev.langchain4j.rag.query.Query; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo; +import org.ruoyi.domain.bo.vector.QueryVectorBo; +import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo; +import org.ruoyi.service.retrieval.KnowledgeRetrievalService; + +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +/** + * 自定义检索器:适配 LangChain4j ContentRetriever 接口 + * 桥接统一的 KnowledgeRetrievalService,支持配置化的混合检索、阈值过滤等功能 + * + * @author RobustH + */ +@Slf4j +@RequiredArgsConstructor +public class CustomVectorRetriever implements ContentRetriever { + + private final KnowledgeRetrievalService knowledgeRetrievalService; + private final KnowledgeInfoVo knowledgeInfoVo; + private final ChatModelVo chatModelVo; + + @Override + public List retrieve(Query query) { + log.info("执行自定义检索,关键字: {}", query.text()); + + // 构建增强后的查询参数 + QueryVectorBo queryVectorBo = new QueryVectorBo(); + queryVectorBo.setQuery(query.text()); + queryVectorBo.setKid(String.valueOf(knowledgeInfoVo.getId())); + queryVectorBo.setApiKey(chatModelVo.getApiKey()); + queryVectorBo.setBaseUrl(chatModelVo.getApiHost()); + queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModel()); + queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel()); + + // 应用知识库配置参数 + queryVectorBo.setMaxResults(knowledgeInfoVo.getRetrieveLimit()); + queryVectorBo.setSimilarityThreshold(knowledgeInfoVo.getSimilarityThreshold()); + queryVectorBo.setEnableHybrid(Objects.equals(knowledgeInfoVo.getEnableHybrid(), 1)); + queryVectorBo.setHybridAlpha(knowledgeInfoVo.getHybridAlpha()); + + // 设置重排序参数 (如果 retriever 阶段也想做初步重排,可以在此设置) + queryVectorBo.setEnableRerank(Objects.equals(knowledgeInfoVo.getEnableRerank(), 1)); + queryVectorBo.setRerankModelName(knowledgeInfoVo.getRerankModel()); + queryVectorBo.setRerankTopN(knowledgeInfoVo.getRerankTopN()); + queryVectorBo.setRerankScoreThreshold(knowledgeInfoVo.getRerankScoreThreshold()); + + // 通过统一服务执行检索 + List nearestList = knowledgeRetrievalService.retrieveTexts(queryVectorBo); + + // 将结果包装为标准的 Content 返回 + return nearestList.stream() + .map(text -> Content.from(TextSegment.from(text))) + .collect(Collectors.toList()); + } +} diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/rerank/impl/SiliconFlowRerankModelService.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/rerank/impl/SiliconFlowRerankModelService.java new file mode 100644 index 00000000..49c5aecf --- /dev/null +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/rerank/impl/SiliconFlowRerankModelService.java @@ -0,0 +1,174 @@ +package org.ruoyi.service.rerank.impl; + +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import lombok.Data; +import lombok.extern.slf4j.Slf4j; +import okhttp3.*; +import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo; +import org.ruoyi.domain.bo.rerank.RerankRequest; +import org.ruoyi.domain.bo.rerank.RerankResult; +import org.ruoyi.service.rerank.RerankModelService; +import org.springframework.stereotype.Component; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; + +/** + * 硅基流动重排序模型实现 + * 适配硅基流动的 /v1/rerank 接口 + * + * @author RobustH + * @date 2026-04-21 + */ +@Slf4j +@Component("siliconflowRerank") +public class SiliconFlowRerankModelService implements RerankModelService { + + private static final String DEFAULT_BASE_URL = "https://api.siliconflow.cn/v1/rerank"; + + private final OkHttpClient okHttpClient; + private final ObjectMapper objectMapper = new ObjectMapper() + .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + private ChatModelVo chatModelVo; + + public SiliconFlowRerankModelService() { + this.okHttpClient = new OkHttpClient.Builder() + .connectTimeout(30, TimeUnit.SECONDS) + .readTimeout(60, TimeUnit.SECONDS) + .writeTimeout(30, TimeUnit.SECONDS) + .build(); + } + + @Override + public void configure(ChatModelVo config) { + this.chatModelVo = config; + } + + @Override + public RerankResult rerank(RerankRequest rerankRequest) { + long startTime = System.currentTimeMillis(); + + try { + String url = buildUrl(); + String requestJson = buildRequestJson(rerankRequest); + + RequestBody body = RequestBody.create(requestJson, MediaType.get("application/json")); + Request httpRequest = new Request.Builder() + .url(url) + .addHeader("Authorization", "Bearer " + chatModelVo.getApiKey()) + .addHeader("Content-Type", "application/json") + .post(body) + .build(); + + log.info("硅基流动重排序请求: model={}, url={}", chatModelVo.getModelName(), url); + + try (Response response = okHttpClient.newCall(httpRequest).execute()) { + if (!response.isSuccessful()) { + String err = response.body() != null ? response.body().string() : "无错误信息"; + throw new IllegalArgumentException("硅基流动 Rerank API 调用失败: " + response.code() + " - " + err); + } + + ResponseBody responseBody = response.body(); + if (responseBody == null) { + throw new IllegalArgumentException("响应体为空"); + } + + SiliconFlowRerankResponse rerankResponse = objectMapper.readValue( + responseBody.string(), SiliconFlowRerankResponse.class); + + return buildRerankResult(rerankResponse, rerankRequest.getDocuments(), + System.currentTimeMillis() - startTime); + } + + } catch (Exception e) { + log.error("硅基流动重排序失败: {}", e.getMessage(), e); + throw new RuntimeException("重排序服务调用失败: " + e.getMessage(), e); + } + } + + /** + * 构建请求 URL,鲁棒处理 API Host 末尾路径 + */ + private String buildUrl() { + String apiHost = chatModelVo.getApiHost(); + if (apiHost == null || apiHost.isBlank()) { + return DEFAULT_BASE_URL; + } + if (apiHost.endsWith("/rerank")) { + return apiHost; + } + if (apiHost.endsWith("/v1")) { + return apiHost + "/rerank"; + } + return apiHost.endsWith("/") ? apiHost + "rerank" : apiHost + "/rerank"; + } + + /** + * 构建请求体 JSON + */ + private String buildRequestJson(RerankRequest rerankRequest) throws IOException { + SiliconFlowRerankRequest request = new SiliconFlowRerankRequest(); + request.setModel(chatModelVo.getModelName()); + request.setQuery(rerankRequest.getQuery()); + request.setDocuments(rerankRequest.getDocuments()); + request.setTop_n(rerankRequest.getTopN() != null ? rerankRequest.getTopN() : rerankRequest.getDocuments().size()); + request.setReturn_documents(rerankRequest.getReturnDocuments() != null ? rerankRequest.getReturnDocuments() : false); + return objectMapper.writeValueAsString(request); + } + + /** + * 构建标准 RerankResult + */ + private RerankResult buildRerankResult(SiliconFlowRerankResponse response, + List originalDocuments, long durationMs) { + Double[] scores = new Double[originalDocuments.size()]; + for (int i = 0; i < scores.length; i++) { + scores[i] = 0.0; + } + + List docs = new ArrayList<>(); + if (response != null && response.getResults() != null) { + response.getResults().forEach(item -> { + if (item.getIndex() != null && item.getIndex() < originalDocuments.size()) { + scores[item.getIndex()] = item.getRelevance_score(); + docs.add(RerankResult.RerankDocument.builder() + .index(item.getIndex()) + .relevanceScore(item.getRelevance_score()) + .document(originalDocuments.get(item.getIndex())) + .build()); + } + }); + } + + return RerankResult.builder() + .documents(docs) + .totalDocuments(originalDocuments.size()) + .durationMs(durationMs) + .build(); + } + + // ==================== 内部 DTO ==================== + + @Data + static class SiliconFlowRerankRequest { + private String model; + private String query; + private List documents; + private Integer top_n; + private Boolean return_documents; + } + + @Data + static class SiliconFlowRerankResponse { + private List results; + } + + @Data + static class SiliconFlowRerankResultItem { + private Integer index; + private Double relevance_score; + } +} diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/retrieval/KnowledgeRetrievalService.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/retrieval/KnowledgeRetrievalService.java index 9c42dd1a..3e0a6cab 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/retrieval/KnowledgeRetrievalService.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/retrieval/KnowledgeRetrievalService.java @@ -1,12 +1,13 @@ package org.ruoyi.service.retrieval; import org.ruoyi.domain.bo.vector.QueryVectorBo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import java.util.List; /** * 知识库检索服务接口 - * 整合粗召回(向量检索)和重排序流程 + * 整合粗召回(向量检索/关键词检索)和重排序流程 * * @author yang * @date 2026-04-19 @@ -21,4 +22,13 @@ public interface KnowledgeRetrievalService { * @return 文本内容列表 */ List retrieveTexts(QueryVectorBo queryVectorBo); + + /** + * 执行知识库检索,返回详细结果对象(包含分数、文档ID等) + * 支持混合检索和重排序 + * + * @param queryVectorBo 查询参数 + * @return 检索结果列表 + */ + List retrieve(QueryVectorBo queryVectorBo); } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/retrieval/impl/KnowledgeRetrievalServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/retrieval/impl/KnowledgeRetrievalServiceImpl.java index 42f6cf68..b6841ef1 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/retrieval/impl/KnowledgeRetrievalServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/retrieval/impl/KnowledgeRetrievalServiceImpl.java @@ -2,21 +2,26 @@ package org.ruoyi.service.retrieval.impl; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.ruoyi.common.core.utils.StringUtils; import org.ruoyi.domain.bo.rerank.RerankRequest; import org.ruoyi.domain.bo.rerank.RerankResult; import org.ruoyi.domain.bo.vector.QueryVectorBo; +import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import org.ruoyi.factory.RerankModelFactory; +import org.ruoyi.mapper.knowledge.KnowledgeFragmentMapper; import org.ruoyi.service.rerank.RerankModelService; import org.ruoyi.service.retrieval.KnowledgeRetrievalService; import org.ruoyi.service.vector.VectorStoreService; import org.springframework.stereotype.Service; -import java.util.ArrayList; -import java.util.List; +import java.util.*; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; /** * 知识库检索服务实现 - * 整合粗召回(向量检索)和重排序流程 + * 整合粗召回(向量检索/关键词检索)、RRF融合和重排序流程 * * @author yang * @date 2026-04-19 @@ -28,6 +33,7 @@ public class KnowledgeRetrievalServiceImpl implements KnowledgeRetrievalService private final VectorStoreService vectorStoreService; private final RerankModelFactory rerankModelFactory; + private final KnowledgeFragmentMapper fragmentMapper; /** * 粗召回默认扩大倍数 @@ -37,99 +43,214 @@ public class KnowledgeRetrievalServiceImpl implements KnowledgeRetrievalService @Override public List retrieveTexts(QueryVectorBo queryVectorBo) { + List results = retrieve(queryVectorBo); + return results.stream() + .map(KnowledgeRetrievalVo::getContent) + .collect(Collectors.toList()); + } + + @Override + public List retrieve(QueryVectorBo queryVectorBo) { log.info("开始知识库检索, kid={}, query={}", queryVectorBo.getKid(), queryVectorBo.getQuery()); - // 1. 粗召回阶段 - 向量检索 - List coarseResults = coarseRetrieval(queryVectorBo); + // 1. 粗召回阶段 (向量检索 + 关键词搜索) + List coarseResults = performCoarseRetrieval(queryVectorBo); log.debug("粗召回返回 {} 条结果", coarseResults.size()); if (coarseResults.isEmpty()) { return coarseResults; } - // 2. 重排序阶段(可选) - if (Boolean.TRUE.equals(queryVectorBo.getEnableRerank()) && - queryVectorBo.getRerankModelName() != null) { - return rerank(queryVectorBo, coarseResults); + // 2. 初始化原始索引 + for (int i = 0; i < coarseResults.size(); i++) { + coarseResults.get(i).setOriginalIndex(i); } - return coarseResults; + // 3. 重排序阶段 (可选) + List finalResults = coarseResults; + if (Boolean.TRUE.equals(queryVectorBo.getEnableRerank()) && + StringUtils.isNotBlank(queryVectorBo.getRerankModelName())) { + finalResults = performRerank(queryVectorBo, coarseResults); + } + + // 4. 应用分值阈值过滤 (重排分值或 RRF 分值) + double threshold = queryVectorBo.getRerankScoreThreshold() != null ? + queryVectorBo.getRerankScoreThreshold() : 0.0; + + return finalResults.stream() + .filter(res -> res.getScore() >= threshold) + .collect(Collectors.toList()); } /** - * 粗召回阶段 - 向量检索 + * 粗召回阶段:根据配置执行向量搜索或混合搜索 */ - private List coarseRetrieval(QueryVectorBo queryVectorBo) { - // 如果启用重排序,扩大粗召回数量 - int originalMaxResults = queryVectorBo.getMaxResults(); - int expandedResults = originalMaxResults; - if (Boolean.TRUE.equals(queryVectorBo.getEnableRerank()) && - queryVectorBo.getRerankModelName() != null) { - expandedResults = originalMaxResults * RERANK_EXPANSION_FACTOR; - log.debug("启用重排序,粗召回数量从 {} 扩大到 {}", originalMaxResults, expandedResults); + private List performCoarseRetrieval(QueryVectorBo queryVectorBo) { + // 如果启用重排序,适当扩大召回数量 + int originalMaxResults = queryVectorBo.getMaxResults() != null ? queryVectorBo.getMaxResults() : 10; + int targetMaxResults = originalMaxResults; + if (Boolean.TRUE.equals(queryVectorBo.getEnableRerank()) && + StringUtils.isNotBlank(queryVectorBo.getRerankModelName())) { + targetMaxResults = originalMaxResults * RERANK_EXPANSION_FACTOR; } - // 临时修改查询数量 - queryVectorBo.setMaxResults(expandedResults); + // 如果未启用混合检索,直接走向量搜索 + if (!Boolean.TRUE.equals(queryVectorBo.getEnableHybrid())) { + QueryVectorBo vectorQuery = copyOf(queryVectorBo, targetMaxResults); + List results = vectorStoreService.search(vectorQuery); + + // 应用基础相似度阈值过滤(如果有) + if (queryVectorBo.getSimilarityThreshold() != null) { + results = results.stream() + .filter(r -> r.getScore() >= queryVectorBo.getSimilarityThreshold()) + .collect(Collectors.toList()); + } + return results; + } + + // 混合检索逻辑 + log.info("执行混合检索: kid={}, query={}", queryVectorBo.getKid(), queryVectorBo.getQuery()); try { - return vectorStoreService.getQueryVector(queryVectorBo); - } finally { - // 恢复原始值 - queryVectorBo.setMaxResults(originalMaxResults); + // A. 并行执行向量搜索 + int finalTargetMaxResults = targetMaxResults; + CompletableFuture> vectorFuture = CompletableFuture.supplyAsync(() -> { + QueryVectorBo vectorQuery = copyOf(queryVectorBo, finalTargetMaxResults); + List results = vectorStoreService.search(vectorQuery); + // 向量层初步过滤 + if (queryVectorBo.getSimilarityThreshold() != null) { + return results.stream() + .filter(r -> r.getScore() >= queryVectorBo.getSimilarityThreshold()) + .collect(Collectors.toList()); + } + return results; + }); + + // B. 并行执行关键词搜索 (MySQL Fulltext) + CompletableFuture> keywordFuture = CompletableFuture.supplyAsync(() -> { + try { + Long kid = Long.valueOf(queryVectorBo.getKid()); + List fragments = fragmentMapper.searchByKeyword(kid, queryVectorBo.getQuery(), finalTargetMaxResults); + return fragments.stream().map(f -> { + KnowledgeRetrievalVo vo = new KnowledgeRetrievalVo(); + vo.setId(f.getId().toString()); + vo.setContent(f.getContent()); + vo.setDocId(f.getDocId()); + vo.setIdx(f.getIdx()); + vo.setKnowledgeId(f.getKnowledgeId()); + vo.setScore(10.0); // RRF 初始占位分 + return vo; + }).collect(Collectors.toList()); + } catch (Exception e) { + log.error("关键词检索失败: {}", e.getMessage()); + return new ArrayList<>(); + } + }); + + List vectorResults = vectorFuture.get(); + List keywordResults = keywordFuture.get(); + + // C. RRF 融合 + double alpha = queryVectorBo.getHybridAlpha() != null ? queryVectorBo.getHybridAlpha() : 0.5; + return calculateRRF(vectorResults, keywordResults, alpha); + + } catch (Exception e) { + log.error("混合检索执行失败,回退到纯向量检索: {}", e.getMessage(), e); + return vectorStoreService.search(copyOf(queryVectorBo, targetMaxResults)); } } /** * 重排序阶段 */ - private List rerank(QueryVectorBo queryVectorBo, List coarseResults) { - long startTime = System.currentTimeMillis(); - + private List performRerank(QueryVectorBo queryVectorBo, List coarseResults) { try { - // 1. 通过工厂获取重排序模型 RerankModelService rerankModel = rerankModelFactory.createModel(queryVectorBo.getRerankModelName()); + + List contents = coarseResults.stream() + .map(KnowledgeRetrievalVo::getContent) + .collect(Collectors.toList()); - // 2. 构建重排序请求 - int topN = queryVectorBo.getRerankTopN() != null ? - queryVectorBo.getRerankTopN() : queryVectorBo.getMaxResults(); + // topN 默认为 maxResults + int topN = queryVectorBo.getRerankTopN() != null ? queryVectorBo.getRerankTopN() : queryVectorBo.getMaxResults(); RerankRequest rerankRequest = RerankRequest.builder() .query(queryVectorBo.getQuery()) - .documents(coarseResults) + .documents(contents) .topN(topN) - .returnDocuments(true) .build(); - log.info("执行重排序, model={}, documents={}, topN={}", - queryVectorBo.getRerankModelName(), coarseResults.size(), topN); - - // 3. 执行重排序 RerankResult rerankResult = rerankModel.rerank(rerankRequest); - // 4. 转换重排序结果 - List finalResults = new ArrayList<>(); + // 写回分数并记录原始分 for (RerankResult.RerankDocument doc : rerankResult.getDocuments()) { - // 应用分数阈值过滤 - if (queryVectorBo.getRerankScoreThreshold() != null && - doc.getRelevanceScore() < queryVectorBo.getRerankScoreThreshold()) { - continue; - } - - if (doc.getDocument() != null) { - finalResults.add(doc.getDocument()); + if (doc.getIndex() != null && doc.getIndex() < coarseResults.size()) { + KnowledgeRetrievalVo vo = coarseResults.get(doc.getIndex()); + vo.setRawScore(vo.getScore()); + vo.setScore(doc.getRelevanceScore()); } } - long duration = System.currentTimeMillis() - startTime; - log.info("重排序完成, 返回 {} 条结果, 耗时 {}ms", finalResults.size(), duration); - - return finalResults; + // 按新分排序 + coarseResults.sort((a, b) -> b.getScore().compareTo(a.getScore())); + + // 截断到 topN + return coarseResults.subList(0, Math.min(topN, coarseResults.size())); } catch (Exception e) { - log.error("重排序失败: {}", e.getMessage(), e); - // 重排序失败时返回原始粗召回结果(截取到期望数量) - int limit = Math.min(queryVectorBo.getMaxResults(), coarseResults.size()); - return new ArrayList<>(coarseResults.subList(0, limit)); + log.error("重排序流程失败: {}", e.getMessage()); + int limit = queryVectorBo.getMaxResults() != null ? queryVectorBo.getMaxResults() : 10; + return coarseResults.subList(0, Math.min(limit, coarseResults.size())); } } + + /** + * RRF (Reciprocal Rank Fusion) 融合计算 + */ + private List calculateRRF(List vectorList, List keywordList, double alpha) { + Map allMap = new LinkedHashMap<>(); + Map vectorScores = new HashMap<>(); + Map keywordScores = new HashMap<>(); + + int k = 60; // RRF 常数 + + for (int i = 0; i < vectorList.size(); i++) { + KnowledgeRetrievalVo vo = vectorList.get(i); + allMap.put(vo.getId(), vo); + vectorScores.put(vo.getId(), 1.0 / (k + i + 1)); + } + + for (int i = 0; i < keywordList.size(); i++) { + KnowledgeRetrievalVo vo = keywordList.get(i); + if (!allMap.containsKey(vo.getId())) { + allMap.put(vo.getId(), vo); + } + keywordScores.put(vo.getId(), 1.0 / (k + i + 1)); + } + + List fusedResults = new ArrayList<>(); + for (Map.Entry entry : allMap.entrySet()) { + String id = entry.getKey(); + double finalScore = (1 - alpha) * vectorScores.getOrDefault(id, 0.0) + + alpha * keywordScores.getOrDefault(id, 0.0); + + KnowledgeRetrievalVo vo = entry.getValue(); + vo.setScore(finalScore * 60.0); // 归一化缩放 + fusedResults.add(vo); + } + + fusedResults.sort((a, b) -> b.getScore().compareTo(a.getScore())); + return fusedResults; + } + + private QueryVectorBo copyOf(QueryVectorBo original, int maxResults) { + QueryVectorBo copy = new QueryVectorBo(); + copy.setQuery(original.getQuery()); + copy.setKid(original.getKid()); + copy.setMaxResults(maxResults); + copy.setVectorModelName(original.getVectorModelName()); + copy.setEmbeddingModelName(original.getEmbeddingModelName()); + copy.setApiKey(original.getApiKey()); + copy.setBaseUrl(original.getBaseUrl()); + return copy; + } } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/VectorStoreService.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/VectorStoreService.java index 3c37835f..66a6a6f2 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/VectorStoreService.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/VectorStoreService.java @@ -3,6 +3,7 @@ package org.ruoyi.service.vector; import org.ruoyi.common.core.exception.ServiceException; import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.domain.bo.vector.StoreEmbeddingBo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import java.util.List; @@ -17,6 +18,11 @@ public interface VectorStoreService { List getQueryVector(QueryVectorBo queryVectorBo); + /** + * 带分数及元数据的检索(用于测试检索功能) + */ + List search(QueryVectorBo queryVectorBo); + void createSchema(String kid, String embeddingModelName); void removeById(String id, String modelName) throws ServiceException; diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/AbstractVectorStoreStrategy.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/AbstractVectorStoreStrategy.java index 906c8090..2fd2052d 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/AbstractVectorStoreStrategy.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/AbstractVectorStoreStrategy.java @@ -37,6 +37,24 @@ public abstract class AbstractVectorStoreStrategy implements VectorStoreService return result; } + /** + * 向量 L2 归一化 (单位化) + */ + protected static float[] normalize(float[] vector) { + if (vector == null) return null; + double sum = 0; + for (float v : vector) { + sum += v * v; + } + float norm = (float) Math.sqrt(sum); + if (norm > 1e-9) { + for (int i = 0; i < vector.length; i++) { + vector[i] /= norm; + } + } + return vector; + } + /** * 获取向量模型 */ diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/MilvusVectorStoreStrategy.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/MilvusVectorStoreStrategy.java index baf1c612..c06b9c73 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/MilvusVectorStoreStrategy.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/MilvusVectorStoreStrategy.java @@ -19,7 +19,11 @@ import org.ruoyi.common.chat.service.chat.IChatModelService; import org.ruoyi.config.VectorStoreProperties; import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.domain.bo.vector.StoreEmbeddingBo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import org.ruoyi.factory.EmbeddingModelFactory; +import org.ruoyi.mapper.knowledge.KnowledgeAttachMapper; +import org.ruoyi.domain.entity.knowledge.KnowledgeAttach; +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import org.springframework.stereotype.Component; import java.util.ArrayList; @@ -32,10 +36,14 @@ import java.util.stream.IntStream; @Component public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { + private final KnowledgeAttachMapper knowledgeAttachMapper; + public MilvusVectorStoreStrategy(VectorStoreProperties vectorStoreProperties, IChatModelService chatModelService, - EmbeddingModelFactory embeddingModelFactory) { + EmbeddingModelFactory embeddingModelFactory, + KnowledgeAttachMapper knowledgeAttachMapper) { super(vectorStoreProperties, embeddingModelFactory, chatModelService); + this.knowledgeAttachMapper = knowledgeAttachMapper; } // 缓存不同集合与 autoFlush 配置的 Milvus 连接 @@ -51,7 +59,7 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { .collectionName(collectionName) .dimension(dimension) .indexType(IndexType.IVF_FLAT) - .metricType(MetricType.L2) + .metricType(MetricType.COSINE) .autoFlushOnInsert(autoFlushOnInsert) .idFieldName("id") .textFieldName("text") @@ -104,7 +112,10 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { TextSegment textSegment = TextSegment.from(text, metadata); Embedding embedding = embeddingModel.embed(text).content(); - embeddingStore.add(embedding, textSegment); + // 单位化处理 + float[] vector = embedding.vector(); + normalize(vector); + embeddingStore.add(Embedding.from(vector), textSegment); }); long endTime = System.currentTimeMillis(); log.info("Milvus向量存储完成消耗时间:{}秒", (endTime - startTime) / 1000); @@ -136,6 +147,55 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { return resultList; } + @Override + public List search(QueryVectorBo queryVectorBo) { + int dimension = getModelDimension(queryVectorBo.getEmbeddingModelName()); + EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName()); + + Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content(); + // 查询向量单位化处理 + float[] queryVector = queryEmbedding.vector(); + normalize(queryVector); + + String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + queryVectorBo.getKid(); + + EmbeddingStore embeddingStore = getMilvusStore(collectionName, dimension, true); + + EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() + .queryEmbedding(Embedding.from(queryVector)) + .maxResults(queryVectorBo.getMaxResults()) + .build(); + + List> matches = embeddingStore.search(request).matches(); + List resultList = new ArrayList<>(); + + for (EmbeddingMatch match : matches) { + TextSegment segment = match.embedded(); + if (segment == null) continue; + + String docId = segment.metadata().getString("docId"); + String sourceName = "未知来源"; + if (docId != null) { + KnowledgeAttach attach = knowledgeAttachMapper.selectOne(new LambdaQueryWrapper() + .eq(KnowledgeAttach::getDocId, docId) + .last("limit 1")); + if (attach != null) { + sourceName = attach.getName(); + } + } + + // 提取内容、评分及来源 + double score = match.score(); + + resultList.add(org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo.builder() + .content(segment.text()) + .score(score) + .sourceName(sourceName) + .build()); + } + return resultList; + } + @Override @SneakyThrows public void removeById(String id, String modelName) { diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/QdrantVectorStoreStrategy.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/QdrantVectorStoreStrategy.java index 973d8485..da6bca80 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/QdrantVectorStoreStrategy.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/QdrantVectorStoreStrategy.java @@ -24,7 +24,11 @@ import org.ruoyi.common.core.exception.ServiceException; import org.ruoyi.config.VectorStoreProperties; import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.domain.bo.vector.StoreEmbeddingBo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import org.ruoyi.factory.EmbeddingModelFactory; +import org.ruoyi.domain.entity.knowledge.KnowledgeAttach; +import org.ruoyi.mapper.knowledge.KnowledgeAttachMapper; +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import org.springframework.stereotype.Component; import static io.qdrant.client.VectorInputFactory.vectorInput; @@ -47,10 +51,14 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy { private static final String METADATA_KID_KEY = "kid"; private static final String METADATA_DOC_ID_KEY = "doc_id"; + private final KnowledgeAttachMapper knowledgeAttachMapper; + public QdrantVectorStoreStrategy(VectorStoreProperties vectorStoreProperties, IChatModelService chatModelService, - EmbeddingModelFactory embeddingModelFactory) { + EmbeddingModelFactory embeddingModelFactory, + KnowledgeAttachMapper knowledgeAttachMapper) { super(vectorStoreProperties, embeddingModelFactory, chatModelService); + this.knowledgeAttachMapper = knowledgeAttachMapper; } private EmbeddingStore getQdrantStore(String collectionName) { @@ -129,7 +137,10 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy { metadata.put(METADATA_DOC_ID_KEY, docId); TextSegment textSegment = TextSegment.from(text, metadata); Embedding embedding = embeddingModel.embed(text).content(); - embeddingStore.add(embedding, textSegment); + // 单位化处理 + float[] vector = embedding.vector(); + normalize(vector); + embeddingStore.add(Embedding.from(vector), textSegment); }); long endTime = System.currentTimeMillis(); @@ -140,18 +151,22 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy { public List getQueryVector(QueryVectorBo queryVectorBo) { EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName()); Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content(); + // 查询向量单位化处理 + float[] queryVector = queryEmbedding.vector(); + normalize(queryVector); + String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + queryVectorBo.getKid(); - List vector = new ArrayList<>(); - for (float f : queryEmbedding.vector()) { - vector.add(f); + List vectorList = new ArrayList<>(); + for (float f : queryVector) { + vectorList.add(f); } try (QdrantClient client = buildQdrantClient()) { QueryPoints request = QueryPoints.newBuilder() .setCollectionName(collectionName) .setQuery(Query.newBuilder() - .setNearest(vectorInput(vector)) + .setNearest(vectorInput(vectorList)) .build()) .setLimit(queryVectorBo.getMaxResults()) .setWithPayload(enable(true)) @@ -172,6 +187,69 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy { } } + @Override + public List search(QueryVectorBo queryVectorBo) { + EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName()); + Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content(); + // 查询向量单位化处理 + float[] queryVector = queryEmbedding.vector(); + normalize(queryVector); + + String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + queryVectorBo.getKid(); + + List vectorList = new ArrayList<>(); + for (float f : queryVector) { + vectorList.add(f); + } + + try (QdrantClient client = buildQdrantClient()) { + QueryPoints request = QueryPoints.newBuilder() + .setCollectionName(collectionName) + .setQuery(Query.newBuilder() + .setNearest(vectorInput(vectorList)) + .build()) + .setLimit(queryVectorBo.getMaxResults()) + .setWithPayload(enable(true)) + .build(); + + List results = client.queryAsync(request).get(); + List resultList = new ArrayList<>(); + for (ScoredPoint point : results) { + String content = ""; + JsonWithInt.Value textValue = point.getPayloadMap().get(TEXT_SEGMENT_KEY); + if (textValue != null && textValue.hasStringValue()) { + content = textValue.getStringValue(); + } + + String docId = null; + JsonWithInt.Value docIdValue = point.getPayloadMap().get(METADATA_DOC_ID_KEY); + if (docIdValue != null && docIdValue.hasStringValue()) { + docId = docIdValue.getStringValue(); + } + + String sourceName = "未知来源"; + if (docId != null) { + KnowledgeAttach attach = knowledgeAttachMapper.selectOne(new LambdaQueryWrapper() + .eq(KnowledgeAttach::getDocId, docId) + .last("limit 1")); + if (attach != null) { + sourceName = attach.getName(); + } + } + + resultList.add(org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo.builder() + .content(content) + .score((double) point.getScore()) + .sourceName(sourceName) + .build()); + } + return resultList; + } catch (Exception e) { + log.error("Qdrant检索失败: {}", collectionName, e); + throw new ServiceException("Qdrant向量检索失败"); + } + } + @Override public void removeById(String id, String modelName) { String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + id; diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/VectorStoreServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/VectorStoreServiceImpl.java index 603ed84c..73b1fa2d 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/VectorStoreServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/VectorStoreServiceImpl.java @@ -4,6 +4,7 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.domain.bo.vector.StoreEmbeddingBo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import org.ruoyi.factory.VectorStoreStrategyFactory; import org.ruoyi.service.vector.VectorStoreService; import org.springframework.context.annotation.Primary; @@ -54,6 +55,13 @@ public class VectorStoreServiceImpl implements VectorStoreService { return strategy.getQueryVector(queryVectorBo); } + @Override + public List search(QueryVectorBo queryVectorBo) { + log.info("执行测试搜索: kid={}, query={}", queryVectorBo.getKid(), queryVectorBo.getQuery()); + VectorStoreService strategy = getCurrentStrategy(); + return strategy.search(queryVectorBo); + } + @Override public void removeById(String id, String modelName) { log.info("根据ID删除向量数据: id={}, modelName={}", id, modelName); diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/WeaviateVectorStoreStrategy.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/WeaviateVectorStoreStrategy.java index c62a8470..f90f0568 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/WeaviateVectorStoreStrategy.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/WeaviateVectorStoreStrategy.java @@ -12,6 +12,7 @@ import org.ruoyi.common.core.exception.ServiceException; import org.ruoyi.config.VectorStoreProperties; import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.domain.bo.vector.StoreEmbeddingBo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import org.ruoyi.factory.EmbeddingModelFactory; import org.springframework.stereotype.Component; import io.weaviate.client.Config; @@ -24,6 +25,9 @@ import io.weaviate.client.v1.graphql.model.GraphQLResponse; import io.weaviate.client.v1.schema.model.Property; import io.weaviate.client.v1.schema.model.Schema; import io.weaviate.client.v1.schema.model.WeaviateClass; +import org.ruoyi.domain.entity.knowledge.KnowledgeAttach; +import org.ruoyi.mapper.knowledge.KnowledgeAttachMapper; +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import java.util.ArrayList; import java.util.Collections; @@ -40,11 +44,14 @@ import java.util.Map; public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy { private WeaviateClient client; + private final KnowledgeAttachMapper knowledgeAttachMapper; public WeaviateVectorStoreStrategy(VectorStoreProperties vectorStoreProperties, IChatModelService chatModelService, - EmbeddingModelFactory embeddingModelFactory) { + EmbeddingModelFactory embeddingModelFactory, + KnowledgeAttachMapper knowledgeAttachMapper) { super(vectorStoreProperties, embeddingModelFactory,chatModelService); + this.knowledgeAttachMapper = knowledgeAttachMapper; } @Override @@ -110,9 +117,12 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy { "kid", kid, "docId", docId ); - Float[] vector = toObjectArray(embedding.vector()); + float[] vectorArray = embedding.vector(); + normalize(vectorArray); + Float[] vector = toObjectArray(vectorArray); + client.data().creator() - .withClassName("LocalKnowledge" + kid) + .withClassName(vectorStoreProperties.getWeaviate().getClassname() + kid) .withProperties(properties) .withVector(vector) .run(); @@ -128,6 +138,9 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy { EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName()); Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content(); float[] vector = queryEmbedding.vector(); + // 查询向量单位化处理 + normalize(vector); + List vectorStrings = new ArrayList<>(); for (float v : vector) { vectorStrings.add(String.valueOf(v)); @@ -178,6 +191,77 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy { } } + @Override + public List search(QueryVectorBo queryVectorBo) { + createSchema(queryVectorBo.getKid(), queryVectorBo.getEmbeddingModelName()); + EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName()); + Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content(); + float[] vector = queryEmbedding.vector(); + // 查询向量单位化处理 + normalize(vector); + List vectorStrings = new ArrayList<>(); + for (float v : vector) { + vectorStrings.add(String.valueOf(v)); + } + String vectorStr = String.join(",", vectorStrings); + String className = vectorStoreProperties.getWeaviate().getClassname(); + + String graphQLQuery = String.format( + "{\n" + + " Get {\n" + + " %s(nearVector: {vector: [%s]} limit: %d) {\n" + + " text\n" + + " docId\n" + + " _additional {\n" + + " distance\n" + + " }\n" + + " }\n" + + " }\n" + + "}", + className + queryVectorBo.getKid(), + vectorStr, + queryVectorBo.getMaxResults() + ); + + Result result = client.graphQL().raw().withQuery(graphQLQuery).run(); + List resultList = new ArrayList<>(); + + if (result != null && !result.hasErrors()) { + Object data = result.getResult().getData(); + JSONObject entries = new JSONObject(data); + Map entriesMap = entries.get("Get", Map.class); + cn.hutool.json.JSONArray objects = entriesMap.get(className + queryVectorBo.getKid()); + + for (Object obj : objects) { + Map map = (Map) obj; + String content = (String) map.get("text"); + String docId = (String) map.get("docId"); + + Map additional = (Map) map.get("_additional"); + Double distance = Double.valueOf(String.valueOf(additional.get("distance"))); + // 转换距离为得分 (Weaviate 0 是最相近,1 是最远;余弦距离下 1-dist 即为相似度) + double score = 1.0 - distance; + + String sourceName = "未知来源"; + if (docId != null) { + KnowledgeAttach attach = knowledgeAttachMapper.selectOne(new LambdaQueryWrapper() + .eq(KnowledgeAttach::getDocId, docId) + .last("limit 1")); + if (attach != null) { + sourceName = attach.getName(); + } + } + + resultList.add(org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo.builder() + .content(content) + .score(score) + .sourceName(sourceName) + .build()); + } + } + return resultList; + } + @Override @SneakyThrows public void removeById(String id, String modelName) {