From 28ad29d6edb516420b0e9118c43b5130e3905e58 Mon Sep 17 00:00:00 2001 From: RobustH <1511209518@qq.com> Date: Sun, 12 Apr 2026 18:37:52 +0800 Subject: [PATCH 1/8] =?UTF-8?q?feat(knowledge):=20=E5=AE=8C=E5=96=84?= =?UTF-8?q?=E7=9F=A5=E8=AF=86=E5=BA=93=E5=8F=8A=E9=99=84=E4=BB=B6=E7=BB=9F?= =?UTF-8?q?=E8=AE=A1=E5=8A=9F=E8=83=BD=E5=B9=B6=E4=BF=AE=E5=A4=8D=E5=88=86?= =?UTF-8?q?=E5=9D=97=E6=95=B0=E7=BB=9F=E8=AE=A1=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../vo/knowledge/DocFragmentCountVo.java | 20 ++++++++++++ .../vo/knowledge/KnowledgeAttachVo.java | 12 +++++++ .../domain/vo/knowledge/KnowledgeInfoVo.java | 5 +++ .../knowledge/KnowledgeAttachMapper.java | 7 ++++ .../knowledge/KnowledgeFragmentMapper.java | 19 +++++++++++ .../impl/KnowledgeAttachServiceImpl.java | 32 ++++++++++++++++++- .../impl/KnowledgeInfoServiceImpl.java | 16 ++++++++++ 7 files changed, 110 insertions(+), 1 deletion(-) create mode 100644 ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/DocFragmentCountVo.java diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/DocFragmentCountVo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/DocFragmentCountVo.java new file mode 100644 index 00000000..f162d93b --- /dev/null +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/DocFragmentCountVo.java @@ -0,0 +1,20 @@ +package org.ruoyi.domain.vo.knowledge; + +import lombok.Data; + +/** + * 文档分块数统计 VO(用于 GROUP BY 查询结果接收) + */ +@Data +public class DocFragmentCountVo { + + /** + * 文档ID(关联 knowledge_attach.doc_id) + */ + private String docId; + + /** + * 该文档下的分块数量 + */ + private Integer fragmentCount; +} diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeAttachVo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeAttachVo.java index 8d3f2f08..adf27c9c 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeAttachVo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeAttachVo.java @@ -8,6 +8,7 @@ import org.ruoyi.domain.entity.knowledge.KnowledgeAttach; import java.io.Serial; import java.io.Serializable; +import java.util.Date; @@ -68,5 +69,16 @@ public class KnowledgeAttachVo implements Serializable { @ExcelProperty(value = "备注") private String remark; + /** + * 上传时间(来自 BaseEntity.createTime) + */ + @ExcelProperty(value = "上传时间") + private Date createTime; + + /** + * 分块数(统计字段,非数据库列) + */ + private Integer fragmentCount; + } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java index bf0580dd..53e136dd 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java @@ -100,5 +100,10 @@ public class KnowledgeInfoVo implements Serializable { @ExcelProperty(value = "备注") private String remark; + /** + * 文档数(统计字段,非数据库列) + */ + private Integer documentCount; + } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeAttachMapper.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeAttachMapper.java index 71ab3b28..b413916b 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeAttachMapper.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeAttachMapper.java @@ -1,5 +1,7 @@ package org.ruoyi.mapper.knowledge; +import org.apache.ibatis.annotations.Param; +import org.apache.ibatis.annotations.Select; import org.ruoyi.domain.entity.knowledge.KnowledgeAttach; import org.ruoyi.domain.vo.knowledge.KnowledgeAttachVo; import org.ruoyi.common.mybatis.core.mapper.BaseMapperPlus; @@ -12,4 +14,9 @@ import org.ruoyi.common.mybatis.core.mapper.BaseMapperPlus; */ public interface KnowledgeAttachMapper extends BaseMapperPlus { + /** + * 统计指定知识库下的文档数量 + */ + @Select("SELECT COUNT(*) FROM knowledge_attach WHERE knowledge_id = #{knowledgeId}") + int countByKnowledgeId(@Param("knowledgeId") Long knowledgeId); } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeFragmentMapper.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeFragmentMapper.java index b16995fa..2d4f5d32 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeFragmentMapper.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeFragmentMapper.java @@ -1,9 +1,14 @@ package org.ruoyi.mapper.knowledge; +import org.apache.ibatis.annotations.Param; +import org.apache.ibatis.annotations.Select; import org.ruoyi.domain.entity.knowledge.KnowledgeFragment; +import org.ruoyi.domain.vo.knowledge.DocFragmentCountVo; import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo; import org.ruoyi.common.mybatis.core.mapper.BaseMapperPlus; +import java.util.List; + /** * 知识片段Mapper接口 * @@ -12,4 +17,18 @@ import org.ruoyi.common.mybatis.core.mapper.BaseMapperPlus; */ public interface KnowledgeFragmentMapper extends BaseMapperPlus { + /** + * 批量统计各文档的分块数(强类型接收,避免 Map key 大小写问题) + * + * @param docIds 文档 ID 列表 + * @return 每个 docId 对应的分块数列表 + */ + @Select("") + List selectFragmentCountByDocIds(@Param("docIds") List docIds); } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeAttachServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeAttachServiceImpl.java index ce02785e..a030d599 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeAttachServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeAttachServiceImpl.java @@ -20,6 +20,7 @@ import org.ruoyi.domain.bo.knowledge.KnowledgeInfoUploadBo; import org.ruoyi.domain.bo.vector.StoreEmbeddingBo; import org.ruoyi.domain.entity.knowledge.KnowledgeAttach; import org.ruoyi.domain.entity.knowledge.KnowledgeFragment; +import org.ruoyi.domain.vo.knowledge.DocFragmentCountVo; import org.ruoyi.domain.vo.knowledge.KnowledgeAttachVo; import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo; import org.ruoyi.factory.ResourceLoaderFactory; @@ -81,6 +82,9 @@ public class KnowledgeAttachServiceImpl implements IKnowledgeAttachService { public TableDataInfo queryPageList(KnowledgeAttachBo bo, PageQuery pageQuery) { LambdaQueryWrapper lqw = buildQueryWrapper(bo); Page result = baseMapper.selectVoPage(pageQuery.build(), lqw); + // 批量填充分块数 + List records = result.getRecords(); + fillFragmentCount(records); return TableDataInfo.build(result); } @@ -93,7 +97,33 @@ public class KnowledgeAttachServiceImpl implements IKnowledgeAttachService { @Override public List queryList(KnowledgeAttachBo bo) { LambdaQueryWrapper lqw = buildQueryWrapper(bo); - return baseMapper.selectVoList(lqw); + List list = baseMapper.selectVoList(lqw); + fillFragmentCount(list); + return list; + } + + /** + * 批量填充每个附件记录的分块数(fragmentCount) + */ + private void fillFragmentCount(List records) { + if (records == null || records.isEmpty()) return; + List docIds = records.stream() + .map(KnowledgeAttachVo::getDocId) + .filter(docId -> docId != null && !docId.isEmpty()) + .distinct() + .collect(java.util.stream.Collectors.toList()); + if (docIds.isEmpty()) return; + List countList = + knowledgeFragmentMapper.selectFragmentCountByDocIds(docIds); + Map countMap = new java.util.HashMap<>(); + for (DocFragmentCountVo item : countList) { + if (item.getDocId() != null) { + countMap.put(item.getDocId(), item.getFragmentCount()); + } + } + for (KnowledgeAttachVo vo : records) { + vo.setFragmentCount(countMap.getOrDefault(vo.getDocId(), 0)); + } } private LambdaQueryWrapper buildQueryWrapper(KnowledgeAttachBo bo) { diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeInfoServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeInfoServiceImpl.java index 6e7a8d7b..266fd5c8 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeInfoServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeInfoServiceImpl.java @@ -12,6 +12,7 @@ import lombok.extern.slf4j.Slf4j; import org.ruoyi.domain.bo.knowledge.KnowledgeInfoBo; import org.ruoyi.domain.entity.knowledge.KnowledgeInfo; import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo; +import org.ruoyi.mapper.knowledge.KnowledgeAttachMapper; import org.ruoyi.mapper.knowledge.KnowledgeInfoMapper; import org.ruoyi.service.knowledge.IKnowledgeInfoService; import org.springframework.stereotype.Service; @@ -33,6 +34,8 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService { private final KnowledgeInfoMapper baseMapper; + private final KnowledgeAttachMapper knowledgeAttachMapper; + /** * 查询知识库 * @@ -55,6 +58,8 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService { public TableDataInfo queryPageList(KnowledgeInfoBo bo, PageQuery pageQuery) { LambdaQueryWrapper lqw = buildQueryWrapper(bo); Page result = baseMapper.selectVoPage(pageQuery.build(), lqw); + // 批量填充文档数 + fillDocumentCount(result.getRecords()); return TableDataInfo.build(result); } @@ -87,6 +92,17 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService { return lqw; } + /** + * 批量填充知识库列表每一条记录的文档数(documentCount) + */ + private void fillDocumentCount(List records) { + if (records == null || records.isEmpty()) return; + for (KnowledgeInfoVo vo : records) { + int count = knowledgeAttachMapper.countByKnowledgeId(vo.getId()); + vo.setDocumentCount(count); + } + } + /** * 新增知识库 * From 0fa25032a3935fa3b87a1c3b8771847077b9fc4f Mon Sep 17 00:00:00 2001 From: RobustH <1511209518@qq.com> Date: Mon, 13 Apr 2026 00:15:01 +0800 Subject: [PATCH 2/8] =?UTF-8?q?feat(knowledge):=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E7=9F=A5=E8=AF=86=E5=BA=93=E6=96=87=E4=BB=B6=E7=8A=B6=E6=80=81?= =?UTF-8?q?=E6=9E=9A=E4=B8=BE=E4=B8=BA"=E6=9C=AA=E8=A7=A3=E6=9E=90?= =?UTF-8?q?=EF=BC=8C=E8=A7=A3=E6=9E=90=E4=B8=AD=EF=BC=8C=E8=A7=A3=E6=9E=90?= =?UTF-8?q?=E6=88=90=E5=8A=9F=EF=BC=8C=E8=A7=A3=E6=9E=90=E5=A4=B1=E8=B4=A5?= =?UTF-8?q?"=EF=BC=8C=E6=94=AF=E6=8C=81=E5=BC=82=E6=AD=A5=E7=BA=BF?= =?UTF-8?q?=E7=A8=8B=E6=B1=A0=E8=A7=A3=E6=9E=90=E6=96=87=E6=A1=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../common/core/config/ThreadPoolConfig.java | 23 ++ .../knowledge/KnowledgeAttachController.java | 13 +- .../bo/knowledge/KnowledgeInfoUploadBo.java | 5 + .../entity/knowledge/KnowledgeAttach.java | 5 + .../vo/knowledge/KnowledgeAttachVo.java | 6 + .../ruoyi/enums/KnowledgeAttachStatus.java | 38 ++++ .../knowledge/KnowledgeAttachMapper.java | 2 + .../knowledge/KnowledgeFragmentMapper.java | 2 + .../knowledge/IKnowledgeAttachService.java | 7 + .../impl/KnowledgeAttachServiceImpl.java | 206 ++++++++---------- 10 files changed, 190 insertions(+), 117 deletions(-) create mode 100644 ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/enums/KnowledgeAttachStatus.java diff --git a/ruoyi-common/ruoyi-common-core/src/main/java/org/ruoyi/common/core/config/ThreadPoolConfig.java b/ruoyi-common/ruoyi-common-core/src/main/java/org/ruoyi/common/core/config/ThreadPoolConfig.java index cc218c21..313d2782 100644 --- a/ruoyi-common/ruoyi-common-core/src/main/java/org/ruoyi/common/core/config/ThreadPoolConfig.java +++ b/ruoyi-common/ruoyi-common-core/src/main/java/org/ruoyi/common/core/config/ThreadPoolConfig.java @@ -10,6 +10,7 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties import org.springframework.context.annotation.Bean; import org.springframework.core.task.VirtualThreadTaskExecutor; +import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import java.util.concurrent.*; /** @@ -22,6 +23,12 @@ import java.util.concurrent.*; @EnableConfigurationProperties(ThreadPoolProperties.class) public class ThreadPoolConfig { + private final ThreadPoolProperties properties; + + public ThreadPoolConfig(ThreadPoolProperties properties) { + this.properties = properties; + } + /** * 核心线程数 = cpu 核心数 + 1 */ @@ -54,6 +61,22 @@ public class ThreadPoolConfig { return scheduledThreadPoolExecutor; } + /** + * 知识库解析专用异步线程池 + */ + @Bean(name = "knowledgeParseExecutor") + public ThreadPoolTaskExecutor knowledgeParseExecutor() { + ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor(); + executor.setCorePoolSize(core); + executor.setMaxPoolSize(core * 2); + executor.setQueueCapacity(properties.getQueueCapacity()); + executor.setKeepAliveSeconds(properties.getKeepAliveSeconds()); + executor.setThreadNamePrefix("knowledge-parse-pool-"); + executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy()); + executor.initialize(); + return executor; + } + /** * 销毁事件 * 停止线程池 diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/controller/knowledge/KnowledgeAttachController.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/controller/knowledge/KnowledgeAttachController.java index 6367b618..e2e1b4e8 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/controller/knowledge/KnowledgeAttachController.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/controller/knowledge/KnowledgeAttachController.java @@ -110,6 +110,17 @@ public class KnowledgeAttachController extends BaseController { @PostMapping(value = "/upload") public R upload(KnowledgeInfoUploadBo bo){ knowledgeAttachService.upload(bo); - return R.ok("上传知识库附件成功!"); + return R.ok("上传成功!"); + } + + /** + * 手动解析附件内容 + * + * @param id 附件ID + */ + @PostMapping("/parse/{id}") + public R parse(@PathVariable Long id) { + knowledgeAttachService.parse(id); + return R.ok(); } } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoUploadBo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoUploadBo.java index a2e1abd0..04c39d05 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoUploadBo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoUploadBo.java @@ -16,6 +16,11 @@ public class KnowledgeInfoUploadBo { private MultipartFile file; + /** + * 是否自动解析 (true: 立即解析, false: 仅上传) + */ + private Boolean autoParse; + /** * 生效时间, 为空则立即生效 */ diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeAttach.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeAttach.java index 486c2466..2d836f88 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeAttach.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeAttach.java @@ -57,5 +57,10 @@ public class KnowledgeAttach extends BaseEntity { */ private String remark; + /** + * 解析状态: 0待解析, 1解析中, 2已解析, 3解析失败 + */ + private Integer status; + } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeAttachVo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeAttachVo.java index adf27c9c..09fbd62e 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeAttachVo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeAttachVo.java @@ -75,6 +75,12 @@ public class KnowledgeAttachVo implements Serializable { @ExcelProperty(value = "上传时间") private Date createTime; + /** + * 解析状态: 0待解析, 1解析中, 2已解析, 3解析失败 + */ + @ExcelProperty(value = "解析状态") + private Integer status; + /** * 分块数(统计字段,非数据库列) */ diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/enums/KnowledgeAttachStatus.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/enums/KnowledgeAttachStatus.java new file mode 100644 index 00000000..ea13da11 --- /dev/null +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/enums/KnowledgeAttachStatus.java @@ -0,0 +1,38 @@ +package org.ruoyi.enums; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * 知识库附件解析状态枚举 + * + * @author RobustH + */ +@Getter +@AllArgsConstructor +public enum KnowledgeAttachStatus { + + /** + * 待解析 + */ + WAITING(0, "待解析"), + + /** + * 解析中 + */ + PARSING(1, "解析中"), + + /** + * 已解析 + */ + COMPLETED(2, "已解析"), + + /** + * 解析失败 + */ + FAILED(3, "解析失败"); + + private final Integer code; + private final String info; + +} diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeAttachMapper.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeAttachMapper.java index b413916b..05b4d780 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeAttachMapper.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeAttachMapper.java @@ -1,5 +1,6 @@ package org.ruoyi.mapper.knowledge; +import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Param; import org.apache.ibatis.annotations.Select; import org.ruoyi.domain.entity.knowledge.KnowledgeAttach; @@ -12,6 +13,7 @@ import org.ruoyi.common.mybatis.core.mapper.BaseMapperPlus; * @author ageerle * @date 2025-12-17 */ +@Mapper public interface KnowledgeAttachMapper extends BaseMapperPlus { /** diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeFragmentMapper.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeFragmentMapper.java index 2d4f5d32..b99ad6af 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeFragmentMapper.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeFragmentMapper.java @@ -1,5 +1,6 @@ package org.ruoyi.mapper.knowledge; +import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Param; import org.apache.ibatis.annotations.Select; import org.ruoyi.domain.entity.knowledge.KnowledgeFragment; @@ -15,6 +16,7 @@ import java.util.List; * @author ageerle * @date 2025-12-17 */ +@Mapper public interface KnowledgeFragmentMapper extends BaseMapperPlus { /** diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/IKnowledgeAttachService.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/IKnowledgeAttachService.java index 013c9736..75d0a528 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/IKnowledgeAttachService.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/IKnowledgeAttachService.java @@ -72,4 +72,11 @@ public interface IKnowledgeAttachService { * 上传附件 */ void upload(KnowledgeInfoUploadBo bo); + + /** + * 解析附件知识片段 + * + * @param id 附件ID + */ + void parse(Long id); } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeAttachServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeAttachServiceImpl.java index a030d599..e8260f25 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeAttachServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeAttachServiceImpl.java @@ -2,19 +2,21 @@ package org.ruoyi.service.knowledge.impl; import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.RandomUtil; -import org.ruoyi.common.chat.service.chat.IChatModelService; +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; +import com.baomidou.mybatisplus.core.toolkit.Wrappers; +import com.baomidou.mybatisplus.extension.plugins.pagination.Page; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo; +import org.ruoyi.common.chat.service.chat.IChatModelService; +import org.ruoyi.enums.KnowledgeAttachStatus; import org.ruoyi.common.core.domain.dto.OssDTO; import org.ruoyi.common.core.service.OssService; import org.ruoyi.common.core.utils.MapstructUtils; +import org.ruoyi.common.core.utils.SpringUtils; import org.ruoyi.common.core.utils.StringUtils; -import org.ruoyi.common.mybatis.core.page.TableDataInfo; import org.ruoyi.common.mybatis.core.page.PageQuery; -import com.baomidou.mybatisplus.extension.plugins.pagination.Page; -import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; -import com.baomidou.mybatisplus.core.toolkit.Wrappers; -import lombok.RequiredArgsConstructor; -import lombok.extern.slf4j.Slf4j; +import org.ruoyi.common.mybatis.core.page.TableDataInfo; import org.ruoyi.domain.bo.knowledge.KnowledgeAttachBo; import org.ruoyi.domain.bo.knowledge.KnowledgeInfoUploadBo; import org.ruoyi.domain.bo.vector.StoreEmbeddingBo; @@ -30,11 +32,15 @@ import org.ruoyi.service.knowledge.IKnowledgeAttachService; import org.ruoyi.service.knowledge.IKnowledgeInfoService; import org.ruoyi.service.knowledge.ResourceLoader; import org.ruoyi.service.vector.VectorStoreService; +import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service; import org.springframework.web.multipart.MultipartFile; -import java.io.IOException; +import java.io.InputStream; + +import java.net.URL; import java.util.*; +import java.util.stream.Collectors; /** * 知识库附件Service业务层处理 @@ -48,52 +54,26 @@ import java.util.*; public class KnowledgeAttachServiceImpl implements IKnowledgeAttachService { private final KnowledgeAttachMapper baseMapper; - private final IKnowledgeInfoService knowledgeInfoService; - private final KnowledgeFragmentMapper knowledgeFragmentMapper; - private final IChatModelService chatModelService; - private final ResourceLoaderFactory resourceLoaderFactory; - private final VectorStoreService vectorStoreService; - private final OssService ossService; - /** - * 查询知识库附件 - * - * @param id 主键 - * @return 知识库附件 - */ + @Override - public KnowledgeAttachVo queryById(Long id){ + public KnowledgeAttachVo queryById(Long id) { return baseMapper.selectVoById(id); } - /** - * 分页查询知识库附件列表 - * - * @param bo 查询条件 - * @param pageQuery 分页参数 - * @return 知识库附件分页列表 - */ @Override public TableDataInfo queryPageList(KnowledgeAttachBo bo, PageQuery pageQuery) { LambdaQueryWrapper lqw = buildQueryWrapper(bo); Page result = baseMapper.selectVoPage(pageQuery.build(), lqw); - // 批量填充分块数 - List records = result.getRecords(); - fillFragmentCount(records); + fillFragmentCount(result.getRecords()); return TableDataInfo.build(result); } - /** - * 查询符合条件的知识库附件列表 - * - * @param bo 查询条件 - * @return 知识库附件列表 - */ @Override public List queryList(KnowledgeAttachBo bo) { LambdaQueryWrapper lqw = buildQueryWrapper(bo); @@ -102,32 +82,23 @@ public class KnowledgeAttachServiceImpl implements IKnowledgeAttachService { return list; } - /** - * 批量填充每个附件记录的分块数(fragmentCount) - */ private void fillFragmentCount(List records) { if (records == null || records.isEmpty()) return; List docIds = records.stream() .map(KnowledgeAttachVo::getDocId) - .filter(docId -> docId != null && !docId.isEmpty()) + .filter(StringUtils::isNotBlank) .distinct() - .collect(java.util.stream.Collectors.toList()); + .collect(Collectors.toList()); if (docIds.isEmpty()) return; - List countList = - knowledgeFragmentMapper.selectFragmentCountByDocIds(docIds); - Map countMap = new java.util.HashMap<>(); - for (DocFragmentCountVo item : countList) { - if (item.getDocId() != null) { - countMap.put(item.getDocId(), item.getFragmentCount()); - } - } + List countList = knowledgeFragmentMapper.selectFragmentCountByDocIds(docIds); + Map countMap = countList.stream() + .collect(Collectors.toMap(DocFragmentCountVo::getDocId, DocFragmentCountVo::getFragmentCount, (k1, k2) -> k1)); for (KnowledgeAttachVo vo : records) { vo.setFragmentCount(countMap.getOrDefault(vo.getDocId(), 0)); } } private LambdaQueryWrapper buildQueryWrapper(KnowledgeAttachBo bo) { - Map params = bo.getParams(); LambdaQueryWrapper lqw = Wrappers.lambdaQuery(); lqw.orderByAsc(KnowledgeAttach::getId); lqw.eq(bo.getKnowledgeId() != null, KnowledgeAttach::getKnowledgeId, bo.getKnowledgeId()); @@ -137,16 +108,9 @@ public class KnowledgeAttachServiceImpl implements IKnowledgeAttachService { return lqw; } - /** - * 新增知识库附件 - * - * @param bo 知识库附件 - * @return 是否新增成功 - */ @Override public Boolean insertByBo(KnowledgeAttachBo bo) { KnowledgeAttach add = MapstructUtils.convert(bo, KnowledgeAttach.class); - validEntityBeforeSave(add); boolean flag = baseMapper.insert(add) > 0; if (flag) { bo.setId(add.getId()); @@ -154,66 +118,72 @@ public class KnowledgeAttachServiceImpl implements IKnowledgeAttachService { return flag; } - /** - * 修改知识库附件 - * - * @param bo 知识库附件 - * @return 是否修改成功 - */ @Override public Boolean updateByBo(KnowledgeAttachBo bo) { KnowledgeAttach update = MapstructUtils.convert(bo, KnowledgeAttach.class); - validEntityBeforeSave(update); return baseMapper.updateById(update) > 0; } - /** - * 保存前的数据校验 - */ - private void validEntityBeforeSave(KnowledgeAttach entity){ - //TODO 做一些数据校验,如唯一约束 - } - - /** - * 校验并批量删除知识库附件信息 - * - * @param ids 待删除的主键集合 - * @param isValid 是否进行有效性校验 - * @return 是否删除成功 - */ @Override public Boolean deleteWithValidByIds(Collection ids, Boolean isValid) { - if(isValid){ - //TODO 做一些业务上的校验,判断是否需要校验 - } return baseMapper.deleteByIds(ids) > 0; } @Override public void upload(KnowledgeInfoUploadBo bo) { MultipartFile file = bo.getFile(); - // 保存文件信息 OssDTO ossDTO = ossService.uploadFile(file); - Long knowledgeId = bo.getKnowledgeId(); - List chunkList = new ArrayList<>(); + KnowledgeAttach knowledgeAttach = new KnowledgeAttach(); knowledgeAttach.setKnowledgeId(bo.getKnowledgeId()); - String docId = RandomUtil.randomString(10); knowledgeAttach.setOssId(ossDTO.getOssId()); - knowledgeAttach.setDocId(docId); + knowledgeAttach.setDocId(RandomUtil.randomString(10)); knowledgeAttach.setName(ossDTO.getOriginalName()); knowledgeAttach.setType(ossDTO.getFileSuffix()); - String content = ""; - ResourceLoader resourceLoader = resourceLoaderFactory.getLoaderByFileType(knowledgeAttach.getType()); - // 文档分段入库 - List fids = new ArrayList<>(); + knowledgeAttach.setStatus(KnowledgeAttachStatus.WAITING.getCode()); // 待解析 + + baseMapper.insert(knowledgeAttach); + + if (Boolean.TRUE.equals(bo.getAutoParse())) { + // 通过 SpringUtils 获取代理对象,确保 @Async 生效 + SpringUtils.getBean(IKnowledgeAttachService.class).parse(knowledgeAttach.getId()); + } + } + + @Async("knowledgeParseExecutor") + @Override + public void parse(Long id) { + KnowledgeAttach attach = baseMapper.selectById(id); + if (attach == null || (!KnowledgeAttachStatus.WAITING.getCode().equals(attach.getStatus()) && !KnowledgeAttachStatus.FAILED.getCode().equals(attach.getStatus()))) { + return; + } + try { - content = resourceLoader.getContent(file.getInputStream()); - chunkList = resourceLoader.getChunkList(content, String.valueOf(knowledgeId)); + attach.setStatus(KnowledgeAttachStatus.PARSING.getCode()); // 解析中 + baseMapper.updateById(attach); + + log.info("开始解析知识库文档... id: {}, docId: {}", id, attach.getDocId()); + + Long knowledgeId = attach.getKnowledgeId(); + String docId = attach.getDocId(); + + // 获取文件信息并下载 + List ossDTOs = ossService.selectByIds(String.valueOf(attach.getOssId())); + if (ossDTOs == null || ossDTOs.isEmpty()) { + throw new RuntimeException("未找到对应的 OSS 文件信息"); + } + OssDTO ossDTO = ossDTOs.get(0); + String content; + ResourceLoader resourceLoader = resourceLoaderFactory.getLoaderByFileType(attach.getType()); + try (InputStream inputStream = new URL(ossDTO.getUrl()).openStream()) { + content = resourceLoader.getContent(inputStream); + } + List chunkList = resourceLoader.getChunkList(content, String.valueOf(knowledgeId)); + + List fids = new ArrayList<>(); List knowledgeFragmentList = new ArrayList<>(); if (CollUtil.isNotEmpty(chunkList)) { for (int i = 0; i < chunkList.size(); i++) { - // 生成知识片段ID String fid = RandomUtil.randomString(10); fids.add(fid); KnowledgeFragment knowledgeFragment = new KnowledgeFragment(); @@ -223,29 +193,33 @@ public class KnowledgeAttachServiceImpl implements IKnowledgeAttachService { knowledgeFragment.setCreateTime(new Date()); knowledgeFragmentList.add(knowledgeFragment); } + knowledgeFragmentMapper.delete(Wrappers.lambdaQuery().eq(KnowledgeFragment::getDocId, docId)); + knowledgeFragmentMapper.insertBatch(knowledgeFragmentList); + log.info("文档切片并入库完成,共计 {} 个片段。id: {}", chunkList.size(), id); } - knowledgeFragmentMapper.insertBatch(knowledgeFragmentList); - } catch (IOException e) { - log.error("保存知识库信息失败!{}", e.getMessage()); + + KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(knowledgeId); + ChatModelVo chatModelVo = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel()); + + StoreEmbeddingBo storeEmbeddingBo = new StoreEmbeddingBo(); + storeEmbeddingBo.setKid(String.valueOf(knowledgeId)); + storeEmbeddingBo.setDocId(docId); + storeEmbeddingBo.setFids(fids); + storeEmbeddingBo.setChunkList(chunkList); + storeEmbeddingBo.setVectorStoreName(knowledgeInfoVo.getVectorModel()); + storeEmbeddingBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel()); + storeEmbeddingBo.setApiKey(chatModelVo.getApiKey()); + storeEmbeddingBo.setBaseUrl(chatModelVo.getApiHost()); + vectorStoreService.storeEmbeddings(storeEmbeddingBo); + + attach.setStatus(KnowledgeAttachStatus.COMPLETED.getCode()); // 已完成 + baseMapper.updateById(attach); + log.info("知识库文档解析、向量化并入库成功!id: {}", id); + } catch (Exception e) { + log.error("解析文档失败!id: {}, error: {}", id, e.getMessage(), e); + attach.setStatus(KnowledgeAttachStatus.FAILED.getCode()); // 失败 + attach.setRemark(StringUtils.substring(e.getMessage(), 0, 255)); // 保存错误原因,截取防止溢出 + baseMapper.updateById(attach); } - baseMapper.insert(knowledgeAttach); - - // 查询知识库信息 - KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(knowledgeId); - - // 查询向量模信息 - ChatModelVo chatModelVo = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel()); - - StoreEmbeddingBo storeEmbeddingBo = new StoreEmbeddingBo(); - storeEmbeddingBo.setKid(String.valueOf(knowledgeId)); - storeEmbeddingBo.setDocId(docId); - storeEmbeddingBo.setFids(fids); - storeEmbeddingBo.setChunkList(chunkList); - storeEmbeddingBo.setVectorStoreName(knowledgeInfoVo.getVectorModel()); - storeEmbeddingBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel()); - storeEmbeddingBo.setApiKey(chatModelVo.getApiKey()); - storeEmbeddingBo.setBaseUrl(chatModelVo.getApiHost()); - vectorStoreService.storeEmbeddings(storeEmbeddingBo); } - } From 06a63c377e655a1d5abe299415944d953453d5d9 Mon Sep 17 00:00:00 2001 From: RobustH <1511209518@qq.com> Date: Mon, 13 Apr 2026 23:33:56 +0800 Subject: [PATCH 3/8] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9E=E6=A3=80?= =?UTF-8?q?=E7=B4=A2=E6=B5=8B=E8=AF=95=E7=9B=B8=E5=85=B3=E6=8E=A5=E5=8F=A3?= =?UTF-8?q?=20-=20=E5=AE=9E=E7=8E=B0=E5=90=91=E9=87=8F=20L2=20=E5=BD=92?= =?UTF-8?q?=E4=B8=80=E5=8C=96=EF=BC=8C=E7=BB=9F=E4=B8=80=20Milvus/Qdrant/W?= =?UTF-8?q?eaviate=20=E6=A3=80=E7=B4=A2=E8=AF=84=E5=88=86=E4=B8=BA=20[0,?= =?UTF-8?q?=201]=20=E7=A9=BA=E9=97=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../KnowledgeFragmentController.java | 9 ++ .../bo/knowledge/KnowledgeFragmentBo.java | 19 ++++ .../vo/knowledge/KnowledgeRetrievalVo.java | 35 ++++++++ .../service/chat/impl/ChatServiceFacade.java | 90 ++++++++++++++----- .../knowledge/IKnowledgeFragmentService.java | 9 ++ .../impl/KnowledgeFragmentServiceImpl.java | 53 +++++++++++ .../knowledge/rerank/ScoringModelFactory.java | 37 ++++++++ .../retriever/CustomVectorRetriever.java | 54 +++++++++++ .../service/vector/VectorStoreService.java | 6 ++ .../impl/AbstractVectorStoreStrategy.java | 18 ++++ .../impl/MilvusVectorStoreStrategy.java | 66 +++++++++++++- .../impl/QdrantVectorStoreStrategy.java | 90 +++++++++++++++++-- .../vector/impl/VectorStoreServiceImpl.java | 8 ++ .../impl/WeaviateVectorStoreStrategy.java | 90 ++++++++++++++++++- 14 files changed, 548 insertions(+), 36 deletions(-) create mode 100644 ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeRetrievalVo.java create mode 100644 ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/ScoringModelFactory.java create mode 100644 ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/retriever/CustomVectorRetriever.java diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/controller/knowledge/KnowledgeFragmentController.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/controller/knowledge/KnowledgeFragmentController.java index ac79a62e..d739a30f 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/controller/knowledge/KnowledgeFragmentController.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/controller/knowledge/KnowledgeFragmentController.java @@ -8,6 +8,7 @@ import jakarta.validation.constraints.*; import cn.dev33.satoken.annotation.SaCheckPermission; import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo; import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import org.ruoyi.service.knowledge.IKnowledgeFragmentService; import org.springframework.web.bind.annotation.*; import org.springframework.validation.annotation.Validated; @@ -102,4 +103,12 @@ public class KnowledgeFragmentController extends BaseController { @PathVariable Long[] ids) { return toAjax(knowledgeFragmentService.deleteWithValidByIds(List.of(ids), true)); } + + /** + * 检索测试 + */ + @PostMapping("/retrieval") + public R> retrieval(@RequestBody KnowledgeFragmentBo bo) { + return R.ok(knowledgeFragmentService.retrieval(bo)); + } } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeFragmentBo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeFragmentBo.java index 7472f57f..e1925028 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeFragmentBo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeFragmentBo.java @@ -49,5 +49,24 @@ public class KnowledgeFragmentBo extends BaseEntity { */ private String remark; + /** + * 知识库ID + */ + private Long knowledgeId; + + /** + * 检索内容 + */ + private String query; + + /** + * 返回条数 + */ + private Integer topK; + + /** + * 相似度阈值 + */ + private Double threshold; } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeRetrievalVo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeRetrievalVo.java new file mode 100644 index 00000000..daeaae59 --- /dev/null +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeRetrievalVo.java @@ -0,0 +1,35 @@ +package org.ruoyi.domain.vo.knowledge; + +import lombok.Builder; +import lombok.Data; + +import java.io.Serial; +import java.io.Serializable; + +/** + * 知识检索测试结果视图对象 + * + * @author RobustH + */ +@Data +@Builder +public class KnowledgeRetrievalVo implements Serializable { + + @Serial + private static final long serialVersionUID = 1L; + + /** + * 片段内容 + */ + private String content; + + /** + * 相似度得分 + */ + private Double score; + + /** + * 来源文档名称 + */ + private String sourceName; +} diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java index ba1c5ac8..2332ca84 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java @@ -16,6 +16,15 @@ import dev.langchain4j.model.chat.StreamingChatModel; import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; import dev.langchain4j.model.openai.OpenAiChatModel; +import dev.langchain4j.rag.AugmentationRequest; +import dev.langchain4j.rag.AugmentationResult; +import dev.langchain4j.rag.DefaultRetrievalAugmentor; +import dev.langchain4j.rag.RetrievalAugmentor; +import dev.langchain4j.rag.content.aggregator.ContentAggregator; +import dev.langchain4j.rag.content.aggregator.DefaultContentAggregator; +import dev.langchain4j.rag.content.aggregator.ReRankingContentAggregator; +import dev.langchain4j.model.scoring.ScoringModel; +import dev.langchain4j.rag.query.Metadata; import dev.langchain4j.service.tool.ToolProvider; import lombok.RequiredArgsConstructor; import lombok.SneakyThrows; @@ -46,6 +55,8 @@ import org.ruoyi.mcp.service.core.ToolProviderFactory; import org.ruoyi.service.chat.AbstractChatService; import org.ruoyi.service.chat.IChatMessageService; import org.ruoyi.service.chat.impl.memory.PersistentChatMemoryStore; +import org.ruoyi.service.knowledge.retriever.CustomVectorRetriever; +import org.ruoyi.service.knowledge.rerank.ScoringModelFactory; import org.ruoyi.service.knowledge.IKnowledgeInfoService; import org.ruoyi.service.vector.VectorStoreService; import org.springframework.stereotype.Service; @@ -89,6 +100,8 @@ public class ChatServiceFacade implements IChatService { private final ToolProviderFactory toolProviderFactory; + private final ScoringModelFactory scoringModelFactory; + /** * 内存实例缓存,避免同一会话重复创建 * Key: sessionId, Value: MessageWindowChatMemory实例 @@ -119,7 +132,9 @@ public class ChatServiceFacade implements IChatService { // 2. 构建上下文消息列表 List contextMessages = buildContextMessages(chatRequest); - // 3. 处理特殊聊天模式(工作流、人机交互恢复、思考模式) + // 注意:buildContextMessages() 最后返回的列表中,最新的带有增强知识的 UserMessage 在最后。 + // 对于有些模型API(非langchain4j的代理),它们可能不识别增强后的复杂文本(取决于供应商适配度) + // 但是通过标准流,它被解析为 String。 SseEmitter specialResult = handleSpecialChatModes(chatRequest, contextMessages, chatModelVo, emitter); if (specialResult != null) { return specialResult; @@ -346,39 +361,63 @@ public class ChatServiceFacade implements IChatService { * @return 上下文消息列表 */ private List buildContextMessages(ChatRequest chatRequest) { - List messages = new ArrayList<>(); - // 构建用户消息 + List messages = new ArrayList<>(); + + // 初始化用户消息 UserMessage userMessage = UserMessage.userMessage(chatRequest.getContent()); - messages.add(userMessage); - // 从向量库查询相关历史消息 + // 使用 LangChain4j 的 RetrievalAugmentor 进行检索增强 if (chatRequest.getKnowledgeId() != null) { - // 查询知识库信息 KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(chatRequest.getKnowledgeId())); - if (knowledgeInfoVo == null) { - log.warn("知识库信息不存在,kid: {}", chatRequest.getKnowledgeId()); - return messages; - } + if (knowledgeInfoVo != null) { + ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel()); + if (chatModel != null) { - // 查询向量模型配置信息 - ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel()); - if (chatModel == null) { - log.warn("向量模型配置不存在,模型名称: {}", knowledgeInfoVo.getEmbeddingModel()); - return messages; - } + // 1. 构建适配器(Retriever) + CustomVectorRetriever retriever = new CustomVectorRetriever( + vectorStoreService, knowledgeInfoVo, chatModel); - // 构建向量查询参数 - QueryVectorBo queryVectorBo = buildQueryVectorBo(chatRequest, knowledgeInfoVo, chatModel); + // 2. 获取和构建重排模型聚合器(Aggregator) + // 假设已在 KnowledgeInfoVo 等加入 getRerankModelConfig/getRerankModel 等,这里演示通用逻辑 + // 若无重排需求,使用 DefaultContentAggregator 或无 ScoringModel 的聚合器 + ContentAggregator contentAggregator; + // TODO: 一旦实体类实现了重排模型的支持,此处可以从数据库读出: + // ChatModelVo scoringModelConfig = chatModelService.selectModelByName(knowledgeInfoVo.getRerankModel()); + ChatModelVo scoringModelConfig = null; // 当前暂无对应配置字段 - // 获取向量查询结果 - List nearestList = vectorStoreService.getQueryVector(queryVectorBo); - for (String prompt : nearestList) { - // 知识库内容作为系统上下文添加 - messages.add( new AiMessage(prompt)); + ScoringModel scoringModel = scoringModelFactory.createScoringModel(scoringModelConfig); + if (scoringModel != null) { + contentAggregator = ReRankingContentAggregator.builder() + .scoringModel(scoringModel) + // .maxResults(3) 这个数字将来从配置取 + .build(); + } else { + contentAggregator = new DefaultContentAggregator(); + } + + // 3. 构造流水线 + RetrievalAugmentor augmentor = DefaultRetrievalAugmentor.builder() + .contentRetriever(retriever) + .contentAggregator(contentAggregator) + .build(); + + // 4. 执行 Augmentor 增强:将检索到的知识内容编织进 UserMessage 中 + Metadata ragMetadata = Metadata.from(userMessage, chatRequest.getSessionId(), new ArrayList<>()); + AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, ragMetadata); + + AugmentationResult augmentedResult = augmentor.augment(augmentationRequest); + + ChatMessage augmentedMessage = augmentedResult.chatMessage(); + if (augmentedMessage instanceof UserMessage) { + userMessage = (UserMessage) augmentedMessage; + } + log.info("RAG 增强完成: UserMessage 已重构并附加上下文背景。"); + + } } } - // 从数据库查询历史对话消息 + // 从数据库查询历史对话消息(历史消息应放在当前提问前) if (chatRequest.getSessionId() != null) { MessageWindowChatMemory memory = createChatMemory(chatRequest.getSessionId()); if (memory != null) { @@ -390,6 +429,9 @@ public class ChatServiceFacade implements IChatService { } } + // 注入本次用户提问(经过 RAG 增强后的 UserMessage) + messages.add(userMessage); + return messages; } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/IKnowledgeFragmentService.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/IKnowledgeFragmentService.java index e323e79e..b8b88b45 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/IKnowledgeFragmentService.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/IKnowledgeFragmentService.java @@ -4,6 +4,7 @@ import org.ruoyi.common.mybatis.core.page.TableDataInfo; import org.ruoyi.common.mybatis.core.page.PageQuery; import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo; import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import java.util.Collection; import java.util.List; @@ -65,4 +66,12 @@ public interface IKnowledgeFragmentService { * @return 是否删除成功 */ Boolean deleteWithValidByIds(Collection ids, Boolean isValid); + + /** + * 检索测试 + * + * @param bo 检索参数 + * @return 检索结果 + */ + List retrieval(KnowledgeFragmentBo bo); } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java index 782b40af..9c2477df 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java @@ -13,8 +13,17 @@ import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo; import org.ruoyi.domain.entity.knowledge.KnowledgeFragment; import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo; import org.ruoyi.mapper.knowledge.KnowledgeFragmentMapper; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; +import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo; +import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo; +import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.service.knowledge.IKnowledgeFragmentService; +import org.ruoyi.service.knowledge.IKnowledgeInfoService; +import org.ruoyi.common.chat.service.chat.IChatModelService; +import org.ruoyi.service.vector.VectorStoreService; import org.springframework.stereotype.Service; +import java.util.ArrayList; +import java.util.stream.Collectors; import java.util.List; import java.util.Map; @@ -32,6 +41,9 @@ import java.util.Collection; public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService { private final KnowledgeFragmentMapper baseMapper; + private final IKnowledgeInfoService knowledgeInfoService; + private final IChatModelService chatModelService; + private final VectorStoreService vectorStoreService; /** * 查询知识片段 @@ -131,4 +143,45 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService { } return baseMapper.deleteByIds(ids) > 0; } + + /** + * 检索测试核心实现 + */ + @Override + public List retrieval(KnowledgeFragmentBo bo) { + if (bo.getKnowledgeId() == null || StringUtils.isBlank(bo.getQuery())) { + return new ArrayList<>(); + } + + // 1. 获取知识库及模型配置 + KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(bo.getKnowledgeId()); + if (knowledgeInfoVo == null) { + return new ArrayList<>(); + } + + ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel()); + if (chatModel == null) { + log.warn("未找到对应的向量模型配置: {}", knowledgeInfoVo.getEmbeddingModel()); + return new ArrayList<>(); + } + + // 2. 构造向量检索参数 + QueryVectorBo queryVectorBo = new QueryVectorBo(); + queryVectorBo.setQuery(bo.getQuery()); + queryVectorBo.setKid(String.valueOf(bo.getKnowledgeId())); + queryVectorBo.setMaxResults(bo.getTopK() != null ? bo.getTopK() : knowledgeInfoVo.getRetrieveLimit()); + queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel()); + queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModel()); + queryVectorBo.setApiKey(chatModel.getApiKey()); + queryVectorBo.setBaseUrl(chatModel.getApiHost()); + + // 3. 执行物理检索 + List allResults = vectorStoreService.search(queryVectorBo); + + // 4. 根据阈值过滤 (LangChain4j 结果 score 通常 0-1) + double threshold = bo.getThreshold() != null ? bo.getThreshold() : 0.0; + return allResults.stream() + .filter(res -> res.getScore() >= threshold) + .collect(Collectors.toList()); + } } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/ScoringModelFactory.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/ScoringModelFactory.java new file mode 100644 index 00000000..5f28b9c2 --- /dev/null +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/ScoringModelFactory.java @@ -0,0 +1,37 @@ +package org.ruoyi.service.knowledge.rerank; + +import dev.langchain4j.model.scoring.ScoringModel; +import lombok.extern.slf4j.Slf4j; +import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo; +import org.springframework.stereotype.Component; + +/** + * 重排模型提供商工厂 + * 用于将来无缝拓展硅基流动、百炼等支持重排的模型厂商 + * + * @author RobustH + */ +@Slf4j +@Component +public class ScoringModelFactory { + + /** + * 根据后台传递的模型配置创建具体的重排模型 + * + * @param rerankModelConfig 重排模型的配置 (例如其 providerCode, apiUrl, apiKey 等) + * @return 标准的 LangChain4j ScoringModel + */ + public ScoringModel createScoringModel(ChatModelVo rerankModelConfig) { + if (rerankModelConfig == null) { + return null; + } + + String providerCode = rerankModelConfig.getProviderCode(); + log.info("初始化重排模型,供应商代码: {}", providerCode); + + // TODO: 在这里通过 switch 或反射具体实例化支持的各种 ScoringModel (例如 CohereScoringModel, DascScope 等) + // 目前返回 null 代表暂时没有加载特定的重排底座,这不会影响流程,Aggregator 会忽略它返回原样结果 + + return null; + } +} diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/retriever/CustomVectorRetriever.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/retriever/CustomVectorRetriever.java new file mode 100644 index 00000000..6d876710 --- /dev/null +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/retriever/CustomVectorRetriever.java @@ -0,0 +1,54 @@ +package org.ruoyi.service.knowledge.retriever; + +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.rag.content.Content; +import dev.langchain4j.rag.content.retriever.ContentRetriever; +import dev.langchain4j.rag.query.Query; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo; +import org.ruoyi.domain.bo.vector.QueryVectorBo; +import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo; +import org.ruoyi.service.vector.VectorStoreService; + +import java.util.List; +import java.util.stream.Collectors; + +/** + * 自定义向量检索器:适配 LangChain4j ContentRetriever 接口 + * 桥接现有的 VectorStoreService 获取检索结果 + * + * @author RobustH + */ +@Slf4j +@RequiredArgsConstructor +public class CustomVectorRetriever implements ContentRetriever { + + private final VectorStoreService vectorStoreService; + private final KnowledgeInfoVo knowledgeInfoVo; + private final ChatModelVo chatModelVo; + + @Override + public List retrieve(Query query) { + log.info("执行自定义向量检索,关键字: {}", query.text()); + + // 构建内部查询参数 + QueryVectorBo queryVectorBo = new QueryVectorBo(); + queryVectorBo.setQuery(query.text()); + queryVectorBo.setKid(String.valueOf(knowledgeInfoVo.getId())); + queryVectorBo.setApiKey(chatModelVo.getApiKey()); + queryVectorBo.setBaseUrl(chatModelVo.getApiHost()); + queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModel()); + queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel()); + // 如果接入了重排,这里的 retrieveLimit 也就是 MaxResults 应当被放大,后续留给 Aggregator 截断 + queryVectorBo.setMaxResults(knowledgeInfoVo.getRetrieveLimit()); + + // 执行底层的多种向量库策略检索 + List nearestList = vectorStoreService.getQueryVector(queryVectorBo); + + // 将结果包装为标准的 Content 返回 + return nearestList.stream() + .map(text -> Content.from(TextSegment.from(text))) + .collect(Collectors.toList()); + } +} diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/VectorStoreService.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/VectorStoreService.java index 3c37835f..66a6a6f2 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/VectorStoreService.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/VectorStoreService.java @@ -3,6 +3,7 @@ package org.ruoyi.service.vector; import org.ruoyi.common.core.exception.ServiceException; import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.domain.bo.vector.StoreEmbeddingBo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import java.util.List; @@ -17,6 +18,11 @@ public interface VectorStoreService { List getQueryVector(QueryVectorBo queryVectorBo); + /** + * 带分数及元数据的检索(用于测试检索功能) + */ + List search(QueryVectorBo queryVectorBo); + void createSchema(String kid, String embeddingModelName); void removeById(String id, String modelName) throws ServiceException; diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/AbstractVectorStoreStrategy.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/AbstractVectorStoreStrategy.java index 906c8090..2fd2052d 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/AbstractVectorStoreStrategy.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/AbstractVectorStoreStrategy.java @@ -37,6 +37,24 @@ public abstract class AbstractVectorStoreStrategy implements VectorStoreService return result; } + /** + * 向量 L2 归一化 (单位化) + */ + protected static float[] normalize(float[] vector) { + if (vector == null) return null; + double sum = 0; + for (float v : vector) { + sum += v * v; + } + float norm = (float) Math.sqrt(sum); + if (norm > 1e-9) { + for (int i = 0; i < vector.length; i++) { + vector[i] /= norm; + } + } + return vector; + } + /** * 获取向量模型 */ diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/MilvusVectorStoreStrategy.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/MilvusVectorStoreStrategy.java index baf1c612..c06b9c73 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/MilvusVectorStoreStrategy.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/MilvusVectorStoreStrategy.java @@ -19,7 +19,11 @@ import org.ruoyi.common.chat.service.chat.IChatModelService; import org.ruoyi.config.VectorStoreProperties; import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.domain.bo.vector.StoreEmbeddingBo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import org.ruoyi.factory.EmbeddingModelFactory; +import org.ruoyi.mapper.knowledge.KnowledgeAttachMapper; +import org.ruoyi.domain.entity.knowledge.KnowledgeAttach; +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import org.springframework.stereotype.Component; import java.util.ArrayList; @@ -32,10 +36,14 @@ import java.util.stream.IntStream; @Component public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { + private final KnowledgeAttachMapper knowledgeAttachMapper; + public MilvusVectorStoreStrategy(VectorStoreProperties vectorStoreProperties, IChatModelService chatModelService, - EmbeddingModelFactory embeddingModelFactory) { + EmbeddingModelFactory embeddingModelFactory, + KnowledgeAttachMapper knowledgeAttachMapper) { super(vectorStoreProperties, embeddingModelFactory, chatModelService); + this.knowledgeAttachMapper = knowledgeAttachMapper; } // 缓存不同集合与 autoFlush 配置的 Milvus 连接 @@ -51,7 +59,7 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { .collectionName(collectionName) .dimension(dimension) .indexType(IndexType.IVF_FLAT) - .metricType(MetricType.L2) + .metricType(MetricType.COSINE) .autoFlushOnInsert(autoFlushOnInsert) .idFieldName("id") .textFieldName("text") @@ -104,7 +112,10 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { TextSegment textSegment = TextSegment.from(text, metadata); Embedding embedding = embeddingModel.embed(text).content(); - embeddingStore.add(embedding, textSegment); + // 单位化处理 + float[] vector = embedding.vector(); + normalize(vector); + embeddingStore.add(Embedding.from(vector), textSegment); }); long endTime = System.currentTimeMillis(); log.info("Milvus向量存储完成消耗时间:{}秒", (endTime - startTime) / 1000); @@ -136,6 +147,55 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { return resultList; } + @Override + public List search(QueryVectorBo queryVectorBo) { + int dimension = getModelDimension(queryVectorBo.getEmbeddingModelName()); + EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName()); + + Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content(); + // 查询向量单位化处理 + float[] queryVector = queryEmbedding.vector(); + normalize(queryVector); + + String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + queryVectorBo.getKid(); + + EmbeddingStore embeddingStore = getMilvusStore(collectionName, dimension, true); + + EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() + .queryEmbedding(Embedding.from(queryVector)) + .maxResults(queryVectorBo.getMaxResults()) + .build(); + + List> matches = embeddingStore.search(request).matches(); + List resultList = new ArrayList<>(); + + for (EmbeddingMatch match : matches) { + TextSegment segment = match.embedded(); + if (segment == null) continue; + + String docId = segment.metadata().getString("docId"); + String sourceName = "未知来源"; + if (docId != null) { + KnowledgeAttach attach = knowledgeAttachMapper.selectOne(new LambdaQueryWrapper() + .eq(KnowledgeAttach::getDocId, docId) + .last("limit 1")); + if (attach != null) { + sourceName = attach.getName(); + } + } + + // 提取内容、评分及来源 + double score = match.score(); + + resultList.add(org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo.builder() + .content(segment.text()) + .score(score) + .sourceName(sourceName) + .build()); + } + return resultList; + } + @Override @SneakyThrows public void removeById(String id, String modelName) { diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/QdrantVectorStoreStrategy.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/QdrantVectorStoreStrategy.java index 973d8485..da6bca80 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/QdrantVectorStoreStrategy.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/QdrantVectorStoreStrategy.java @@ -24,7 +24,11 @@ import org.ruoyi.common.core.exception.ServiceException; import org.ruoyi.config.VectorStoreProperties; import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.domain.bo.vector.StoreEmbeddingBo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import org.ruoyi.factory.EmbeddingModelFactory; +import org.ruoyi.domain.entity.knowledge.KnowledgeAttach; +import org.ruoyi.mapper.knowledge.KnowledgeAttachMapper; +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import org.springframework.stereotype.Component; import static io.qdrant.client.VectorInputFactory.vectorInput; @@ -47,10 +51,14 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy { private static final String METADATA_KID_KEY = "kid"; private static final String METADATA_DOC_ID_KEY = "doc_id"; + private final KnowledgeAttachMapper knowledgeAttachMapper; + public QdrantVectorStoreStrategy(VectorStoreProperties vectorStoreProperties, IChatModelService chatModelService, - EmbeddingModelFactory embeddingModelFactory) { + EmbeddingModelFactory embeddingModelFactory, + KnowledgeAttachMapper knowledgeAttachMapper) { super(vectorStoreProperties, embeddingModelFactory, chatModelService); + this.knowledgeAttachMapper = knowledgeAttachMapper; } private EmbeddingStore getQdrantStore(String collectionName) { @@ -129,7 +137,10 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy { metadata.put(METADATA_DOC_ID_KEY, docId); TextSegment textSegment = TextSegment.from(text, metadata); Embedding embedding = embeddingModel.embed(text).content(); - embeddingStore.add(embedding, textSegment); + // 单位化处理 + float[] vector = embedding.vector(); + normalize(vector); + embeddingStore.add(Embedding.from(vector), textSegment); }); long endTime = System.currentTimeMillis(); @@ -140,18 +151,22 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy { public List getQueryVector(QueryVectorBo queryVectorBo) { EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName()); Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content(); + // 查询向量单位化处理 + float[] queryVector = queryEmbedding.vector(); + normalize(queryVector); + String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + queryVectorBo.getKid(); - List vector = new ArrayList<>(); - for (float f : queryEmbedding.vector()) { - vector.add(f); + List vectorList = new ArrayList<>(); + for (float f : queryVector) { + vectorList.add(f); } try (QdrantClient client = buildQdrantClient()) { QueryPoints request = QueryPoints.newBuilder() .setCollectionName(collectionName) .setQuery(Query.newBuilder() - .setNearest(vectorInput(vector)) + .setNearest(vectorInput(vectorList)) .build()) .setLimit(queryVectorBo.getMaxResults()) .setWithPayload(enable(true)) @@ -172,6 +187,69 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy { } } + @Override + public List search(QueryVectorBo queryVectorBo) { + EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName()); + Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content(); + // 查询向量单位化处理 + float[] queryVector = queryEmbedding.vector(); + normalize(queryVector); + + String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + queryVectorBo.getKid(); + + List vectorList = new ArrayList<>(); + for (float f : queryVector) { + vectorList.add(f); + } + + try (QdrantClient client = buildQdrantClient()) { + QueryPoints request = QueryPoints.newBuilder() + .setCollectionName(collectionName) + .setQuery(Query.newBuilder() + .setNearest(vectorInput(vectorList)) + .build()) + .setLimit(queryVectorBo.getMaxResults()) + .setWithPayload(enable(true)) + .build(); + + List results = client.queryAsync(request).get(); + List resultList = new ArrayList<>(); + for (ScoredPoint point : results) { + String content = ""; + JsonWithInt.Value textValue = point.getPayloadMap().get(TEXT_SEGMENT_KEY); + if (textValue != null && textValue.hasStringValue()) { + content = textValue.getStringValue(); + } + + String docId = null; + JsonWithInt.Value docIdValue = point.getPayloadMap().get(METADATA_DOC_ID_KEY); + if (docIdValue != null && docIdValue.hasStringValue()) { + docId = docIdValue.getStringValue(); + } + + String sourceName = "未知来源"; + if (docId != null) { + KnowledgeAttach attach = knowledgeAttachMapper.selectOne(new LambdaQueryWrapper() + .eq(KnowledgeAttach::getDocId, docId) + .last("limit 1")); + if (attach != null) { + sourceName = attach.getName(); + } + } + + resultList.add(org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo.builder() + .content(content) + .score((double) point.getScore()) + .sourceName(sourceName) + .build()); + } + return resultList; + } catch (Exception e) { + log.error("Qdrant检索失败: {}", collectionName, e); + throw new ServiceException("Qdrant向量检索失败"); + } + } + @Override public void removeById(String id, String modelName) { String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + id; diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/VectorStoreServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/VectorStoreServiceImpl.java index 603ed84c..73b1fa2d 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/VectorStoreServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/VectorStoreServiceImpl.java @@ -4,6 +4,7 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.domain.bo.vector.StoreEmbeddingBo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import org.ruoyi.factory.VectorStoreStrategyFactory; import org.ruoyi.service.vector.VectorStoreService; import org.springframework.context.annotation.Primary; @@ -54,6 +55,13 @@ public class VectorStoreServiceImpl implements VectorStoreService { return strategy.getQueryVector(queryVectorBo); } + @Override + public List search(QueryVectorBo queryVectorBo) { + log.info("执行测试搜索: kid={}, query={}", queryVectorBo.getKid(), queryVectorBo.getQuery()); + VectorStoreService strategy = getCurrentStrategy(); + return strategy.search(queryVectorBo); + } + @Override public void removeById(String id, String modelName) { log.info("根据ID删除向量数据: id={}, modelName={}", id, modelName); diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/WeaviateVectorStoreStrategy.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/WeaviateVectorStoreStrategy.java index c62a8470..f90f0568 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/WeaviateVectorStoreStrategy.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/vector/impl/WeaviateVectorStoreStrategy.java @@ -12,6 +12,7 @@ import org.ruoyi.common.core.exception.ServiceException; import org.ruoyi.config.VectorStoreProperties; import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.domain.bo.vector.StoreEmbeddingBo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import org.ruoyi.factory.EmbeddingModelFactory; import org.springframework.stereotype.Component; import io.weaviate.client.Config; @@ -24,6 +25,9 @@ import io.weaviate.client.v1.graphql.model.GraphQLResponse; import io.weaviate.client.v1.schema.model.Property; import io.weaviate.client.v1.schema.model.Schema; import io.weaviate.client.v1.schema.model.WeaviateClass; +import org.ruoyi.domain.entity.knowledge.KnowledgeAttach; +import org.ruoyi.mapper.knowledge.KnowledgeAttachMapper; +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import java.util.ArrayList; import java.util.Collections; @@ -40,11 +44,14 @@ import java.util.Map; public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy { private WeaviateClient client; + private final KnowledgeAttachMapper knowledgeAttachMapper; public WeaviateVectorStoreStrategy(VectorStoreProperties vectorStoreProperties, IChatModelService chatModelService, - EmbeddingModelFactory embeddingModelFactory) { + EmbeddingModelFactory embeddingModelFactory, + KnowledgeAttachMapper knowledgeAttachMapper) { super(vectorStoreProperties, embeddingModelFactory,chatModelService); + this.knowledgeAttachMapper = knowledgeAttachMapper; } @Override @@ -110,9 +117,12 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy { "kid", kid, "docId", docId ); - Float[] vector = toObjectArray(embedding.vector()); + float[] vectorArray = embedding.vector(); + normalize(vectorArray); + Float[] vector = toObjectArray(vectorArray); + client.data().creator() - .withClassName("LocalKnowledge" + kid) + .withClassName(vectorStoreProperties.getWeaviate().getClassname() + kid) .withProperties(properties) .withVector(vector) .run(); @@ -128,6 +138,9 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy { EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName()); Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content(); float[] vector = queryEmbedding.vector(); + // 查询向量单位化处理 + normalize(vector); + List vectorStrings = new ArrayList<>(); for (float v : vector) { vectorStrings.add(String.valueOf(v)); @@ -178,6 +191,77 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy { } } + @Override + public List search(QueryVectorBo queryVectorBo) { + createSchema(queryVectorBo.getKid(), queryVectorBo.getEmbeddingModelName()); + EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName()); + Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content(); + float[] vector = queryEmbedding.vector(); + // 查询向量单位化处理 + normalize(vector); + List vectorStrings = new ArrayList<>(); + for (float v : vector) { + vectorStrings.add(String.valueOf(v)); + } + String vectorStr = String.join(",", vectorStrings); + String className = vectorStoreProperties.getWeaviate().getClassname(); + + String graphQLQuery = String.format( + "{\n" + + " Get {\n" + + " %s(nearVector: {vector: [%s]} limit: %d) {\n" + + " text\n" + + " docId\n" + + " _additional {\n" + + " distance\n" + + " }\n" + + " }\n" + + " }\n" + + "}", + className + queryVectorBo.getKid(), + vectorStr, + queryVectorBo.getMaxResults() + ); + + Result result = client.graphQL().raw().withQuery(graphQLQuery).run(); + List resultList = new ArrayList<>(); + + if (result != null && !result.hasErrors()) { + Object data = result.getResult().getData(); + JSONObject entries = new JSONObject(data); + Map entriesMap = entries.get("Get", Map.class); + cn.hutool.json.JSONArray objects = entriesMap.get(className + queryVectorBo.getKid()); + + for (Object obj : objects) { + Map map = (Map) obj; + String content = (String) map.get("text"); + String docId = (String) map.get("docId"); + + Map additional = (Map) map.get("_additional"); + Double distance = Double.valueOf(String.valueOf(additional.get("distance"))); + // 转换距离为得分 (Weaviate 0 是最相近,1 是最远;余弦距离下 1-dist 即为相似度) + double score = 1.0 - distance; + + String sourceName = "未知来源"; + if (docId != null) { + KnowledgeAttach attach = knowledgeAttachMapper.selectOne(new LambdaQueryWrapper() + .eq(KnowledgeAttach::getDocId, docId) + .last("limit 1")); + if (attach != null) { + sourceName = attach.getName(); + } + } + + resultList.add(org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo.builder() + .content(content) + .score(score) + .sourceName(sourceName) + .build()); + } + } + return resultList; + } + @Override @SneakyThrows public void removeById(String id, String modelName) { From 1208c46cca20c1a795dab0c250a32caecb9c9cd8 Mon Sep 17 00:00:00 2001 From: RobustH <1511209518@qq.com> Date: Tue, 14 Apr 2026 01:40:28 +0800 Subject: [PATCH 4/8] =?UTF-8?q?feat(rag):=20=E9=9B=86=E6=88=90=E7=A1=85?= =?UTF-8?q?=E5=9F=BA=E6=B5=81=E5=8A=A8=E3=80=81=E9=98=BF=E9=87=8C=E7=99=BE?= =?UTF-8?q?=E7=82=BC=E9=87=8D=E6=8E=92=E6=A8=A1=E5=9E=8B=E5=B9=B6=E5=85=A8?= =?UTF-8?q?=E6=96=B9=E4=BD=8D=E5=A2=9E=E5=BC=BA=E6=A3=80=E7=B4=A2=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E4=BD=93=E9=AA=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../bo/knowledge/KnowledgeFragmentBo.java | 10 ++ .../domain/bo/knowledge/KnowledgeInfoBo.java | 10 ++ .../entity/knowledge/KnowledgeInfo.java | 10 ++ .../domain/vo/knowledge/KnowledgeInfoVo.java | 13 ++ .../vo/knowledge/KnowledgeRetrievalVo.java | 10 ++ .../service/chat/impl/ChatServiceFacade.java | 28 ++-- .../impl/KnowledgeFragmentServiceImpl.java | 48 +++++- .../rerank/DashScopeScoringModel.java | 98 +++++++++++ .../knowledge/rerank/ScoringModelFactory.java | 23 ++- .../rerank/SiliconFlowScoringModel.java | 155 ++++++++++++++++++ 10 files changed, 389 insertions(+), 16 deletions(-) create mode 100644 ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/DashScopeScoringModel.java create mode 100644 ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/SiliconFlowScoringModel.java diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeFragmentBo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeFragmentBo.java index e1925028..1508462f 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeFragmentBo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeFragmentBo.java @@ -69,4 +69,14 @@ public class KnowledgeFragmentBo extends BaseEntity { */ private Double threshold; + /** + * 是否启用重排 + */ + private Boolean enableRerank; + + /** + * 重排模型名称 + */ + private String rerankModel; + } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoBo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoBo.java index 113a2847..8629018a 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoBo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoBo.java @@ -77,6 +77,16 @@ public class KnowledgeInfoBo extends BaseEntity { */ private String embeddingModel; + /** + * 重排模型 + */ + private String rerankModel; + + /** + * 是否启用重排(0 否 1 是) + */ + private Integer enableRerank; + /** * 备注 */ diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeInfo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeInfo.java index a51cf7da..a5211e69 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeInfo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeInfo.java @@ -78,6 +78,16 @@ public class KnowledgeInfo extends BaseEntity { */ private String embeddingModel; + /** + * 重排模型 + */ + private String rerankModel; + + /** + * 是否启用重排(0 否 1 是) + */ + private Integer enableRerank; + /** * 备注 */ diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java index 53e136dd..e65444e7 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java @@ -94,6 +94,19 @@ public class KnowledgeInfoVo implements Serializable { @ExcelProperty(value = "向量模型") private String embeddingModel; + /** + * 重排模型 + */ + @ExcelProperty(value = "重排模型") + private String rerankModel; + + /** + * 是否启用重排(0 否 1 是) + */ + @ExcelProperty(value = "是否启用重排", converter = ExcelDictConvert.class) + @ExcelDictFormat(readConverterExp = "0=否,1=是") + private Integer enableRerank; + /** * 备注 */ diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeRetrievalVo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeRetrievalVo.java index daeaae59..95c8e4cf 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeRetrievalVo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeRetrievalVo.java @@ -28,6 +28,16 @@ public class KnowledgeRetrievalVo implements Serializable { */ private Double score; + /** + * 原始检索排名 (重排前) + */ + private Integer originalIndex; + + /** + * 原始检索得分 (重排前) + */ + private Double rawScore; + /** * 来源文档名称 */ diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java index 2332ca84..9ca93d10 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java @@ -377,20 +377,24 @@ public class ChatServiceFacade implements IChatService { CustomVectorRetriever retriever = new CustomVectorRetriever( vectorStoreService, knowledgeInfoVo, chatModel); - // 2. 获取和构建重排模型聚合器(Aggregator) - // 假设已在 KnowledgeInfoVo 等加入 getRerankModelConfig/getRerankModel 等,这里演示通用逻辑 - // 若无重排需求,使用 DefaultContentAggregator 或无 ScoringModel 的聚合器 + // 2. 构建重排聚合器 (Aggregator) ContentAggregator contentAggregator; - // TODO: 一旦实体类实现了重排模型的支持,此处可以从数据库读出: - // ChatModelVo scoringModelConfig = chatModelService.selectModelByName(knowledgeInfoVo.getRerankModel()); - ChatModelVo scoringModelConfig = null; // 当前暂无对应配置字段 + if (knowledgeInfoVo.getEnableRerank() != null && knowledgeInfoVo.getEnableRerank() == 1 + && knowledgeInfoVo.getRerankModel() != null) { - ScoringModel scoringModel = scoringModelFactory.createScoringModel(scoringModelConfig); - if (scoringModel != null) { - contentAggregator = ReRankingContentAggregator.builder() - .scoringModel(scoringModel) - // .maxResults(3) 这个数字将来从配置取 - .build(); + ChatModelVo scoringModelConfig = chatModelService.selectModelByName(knowledgeInfoVo.getRerankModel()); + ScoringModel scoringModel = scoringModelFactory.createScoringModel(scoringModelConfig); + + if (scoringModel != null) { + contentAggregator = ReRankingContentAggregator.builder() + .scoringModel(scoringModel) + // 默认重排后只留前 5 条,避免上下文过长 + .maxResults(5) + .build(); + log.info("启用重排模型: {}", knowledgeInfoVo.getRerankModel()); + } else { + contentAggregator = new DefaultContentAggregator(); + } } else { contentAggregator = new DefaultContentAggregator(); } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java index 9c2477df..68ccf0bf 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java @@ -1,5 +1,8 @@ package org.ruoyi.service.knowledge.impl; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.scoring.ScoringModel; import org.ruoyi.common.core.utils.MapstructUtils; import org.ruoyi.common.core.utils.StringUtils; import org.ruoyi.common.mybatis.core.page.TableDataInfo; @@ -20,6 +23,7 @@ import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.service.knowledge.IKnowledgeFragmentService; import org.ruoyi.service.knowledge.IKnowledgeInfoService; import org.ruoyi.common.chat.service.chat.IChatModelService; +import org.ruoyi.service.knowledge.rerank.ScoringModelFactory; import org.ruoyi.service.vector.VectorStoreService; import org.springframework.stereotype.Service; import java.util.ArrayList; @@ -44,6 +48,7 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService { private final IKnowledgeInfoService knowledgeInfoService; private final IChatModelService chatModelService; private final VectorStoreService vectorStoreService; + private final ScoringModelFactory scoringModelFactory; /** * 查询知识片段 @@ -178,7 +183,48 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService { // 3. 执行物理检索 List allResults = vectorStoreService.search(queryVectorBo); - // 4. 根据阈值过滤 (LangChain4j 结果 score 通常 0-1) + // 初始化原始排名 + for (int i = 0; i < allResults.size(); i++) { + allResults.get(i).setOriginalIndex(i); + } + + // 4. 执行重排逻辑 (如果请求启用重排且配置了重排模型) + if (Boolean.TRUE.equals(bo.getEnableRerank()) && StringUtils.isNotBlank(bo.getRerankModel())) { + log.info("开始重排配置检索测试,传入模型名称: [{}]", bo.getRerankModel()); + ChatModelVo rerankModelConfig = chatModelService.selectModelByName(bo.getRerankModel()); + + if (rerankModelConfig == null) { + log.warn("未能找到重排模型配置: [{}]", bo.getRerankModel()); + } else { + ScoringModel scoringModel = scoringModelFactory.createScoringModel(rerankModelConfig); + if (scoringModel != null) { + log.info("执行重排精排,模型: {}, 供应商: {}", rerankModelConfig.getModelName(), rerankModelConfig.getProviderCode()); + + // 将 KnowledgeRetrievalVo 转换为 TextSegment 列表进行重排 + List segments = allResults.stream() + .map(res -> TextSegment.from(res.getContent())) + .collect(Collectors.toList()); + + Response> scoresResponse = scoringModel.scoreAll(segments, bo.getQuery()); + List scores = scoresResponse.content(); + + // 更新分数并重新排序 + for (int i = 0; i < allResults.size(); i++) { + KnowledgeRetrievalVo resultVo = allResults.get(i); + // 保存原始分数供前端展示对比 + resultVo.setRawScore(resultVo.getScore()); + if (i < scores.size()) { + resultVo.setScore(scores.get(i)); + } + } + + // 按重排后的分数从高到低排序 + allResults.sort((a, b) -> b.getScore().compareTo(a.getScore())); + } + } + } + + // 5. 根据阈值过滤 double threshold = bo.getThreshold() != null ? bo.getThreshold() : 0.0; return allResults.stream() .filter(res -> res.getScore() >= threshold) diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/DashScopeScoringModel.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/DashScopeScoringModel.java new file mode 100644 index 00000000..1086df8e --- /dev/null +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/DashScopeScoringModel.java @@ -0,0 +1,98 @@ +package org.ruoyi.service.knowledge.rerank; + +import com.alibaba.dashscope.exception.ApiException; +import com.alibaba.dashscope.exception.InputRequiredException; +import com.alibaba.dashscope.exception.NoApiKeyException; +import com.alibaba.dashscope.rerank.TextReRank; +import com.alibaba.dashscope.rerank.TextReRankParam; +import com.alibaba.dashscope.rerank.TextReRankResult; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.scoring.ScoringModel; +import lombok.Builder; +import lombok.extern.slf4j.Slf4j; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import static dev.langchain4j.internal.Utils.isNullOrEmpty; + +/** + * DashScope 重排模型实现 (GTE-Rerank) + * 包装了阿里云 DashScope 的 TextReRank API,使其符合 LangChain4j 的 ScoringModel 标准。 + */ +@Slf4j +public class DashScopeScoringModel implements ScoringModel { + + private final String apiKey; + private final String modelName; + private final TextReRank rerank; + + @Builder + public DashScopeScoringModel(String apiKey, String modelName) { + if (isNullOrEmpty(apiKey)) { + throw new IllegalArgumentException("DashScope API Key 不能为空"); + } + this.apiKey = apiKey; + this.modelName = isNullOrEmpty(modelName) ? "gte-rerank" : modelName; + this.rerank = new TextReRank(); + } + + @Override + public Response> scoreAll(List segments, String query) { + if (isNullOrEmpty(segments)) { + return Response.from(new ArrayList<>()); + } + + // 提取文本列表供阿里 SDK 使用 + List texts = segments.stream() + .map(TextSegment::text) + .collect(Collectors.toList()); + + try { + TextReRankParam param = TextReRankParam.builder() + .apiKey(apiKey) + .model(modelName) + .query(query) + .documents(texts) + .topN(texts.size()) + .returnDocuments(false) + .build(); + + TextReRankResult result = rerank.call(param); + + // 初始化分数组,默认值为 0.0 + Double[] scores = new Double[texts.size()]; + for (int i = 0; i < texts.size(); i++) { + scores[i] = 0.0; + } + + // 根据返回结果填充对应的分数值(返回结果中包含原文索引) + result.getOutput().getResults().forEach(item -> { + if (item.getIndex() != null && item.getIndex() < texts.size()) { + scores[item.getIndex()] = item.getRelevanceScore(); + } + }); + + List scoreList = new ArrayList<>(); + for (Double s : scores) { + scoreList.add(s); + } + + return Response.from(scoreList); + + } catch (ApiException | NoApiKeyException | InputRequiredException e) { + log.error("DashScope 重排处理出错: {}", e.getMessage(), e); + throw new RuntimeException("调用 DashScope 重排服务失败", e); + } + } + + @Override + public Response score(TextSegment segment, String query) { + List segments = new ArrayList<>(); + segments.add(segment); + Response> response = scoreAll(segments, query); + return Response.from(response.content().get(0), response.tokenUsage(), response.finishReason()); + } +} diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/ScoringModelFactory.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/ScoringModelFactory.java index 5f28b9c2..a3844d66 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/ScoringModelFactory.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/ScoringModelFactory.java @@ -27,10 +27,27 @@ public class ScoringModelFactory { } String providerCode = rerankModelConfig.getProviderCode(); - log.info("初始化重排模型,供应商代码: {}", providerCode); + log.info("初始化重排模型,供应商代码: {}, 模型名称: {}", providerCode, rerankModelConfig.getModelName()); - // TODO: 在这里通过 switch 或反射具体实例化支持的各种 ScoringModel (例如 CohereScoringModel, DascScope 等) - // 目前返回 null 代表暂时没有加载特定的重排底座,这不会影响流程,Aggregator 会忽略它返回原样结果 + try { + if ("alibailian".equalsIgnoreCase(providerCode)) { + return DashScopeScoringModel.builder() + .apiKey(rerankModelConfig.getApiKey()) + .modelName(rerankModelConfig.getModelName()) + .build(); + } + + if ("siliconflow".equalsIgnoreCase(providerCode)) { + return SiliconFlowScoringModel.builder() + .apiKey(rerankModelConfig.getApiKey()) + .modelName(rerankModelConfig.getModelName()) + // 如果后台配置了不同的 API Host,可以在此传递,否则使用默认值 + .baseUrl(rerankModelConfig.getApiHost()) + .build(); + } + } catch (Exception e) { + log.error("创建重排模型失败: {}", e.getMessage(), e); + } return null; } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/SiliconFlowScoringModel.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/SiliconFlowScoringModel.java new file mode 100644 index 00000000..ceae578e --- /dev/null +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/SiliconFlowScoringModel.java @@ -0,0 +1,155 @@ +package org.ruoyi.service.knowledge.rerank; + +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.scoring.ScoringModel; +import lombok.Builder; +import lombok.Data; +import lombok.extern.slf4j.Slf4j; +import okhttp3.*; +import org.ruoyi.common.json.utils.JsonUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static dev.langchain4j.internal.Utils.isNullOrEmpty; + +/** + * SiliconFlow 重排模型实现 + * 适配硅基流动的 /v1/rerank 接口 + */ +@Slf4j +public class SiliconFlowScoringModel implements ScoringModel { + + private final String apiKey; + private final String modelName; + private final String baseUrl; + private final OkHttpClient client; + + @Builder + public SiliconFlowScoringModel(String apiKey, String modelName, String baseUrl) { + if (isNullOrEmpty(apiKey)) { + throw new IllegalArgumentException("SiliconFlow API Key 不能为空"); + } + this.apiKey = apiKey; + this.modelName = isNullOrEmpty(modelName) ? "BAAI/bge-reranker-v2-m3" : modelName; + + // 鲁棒性处理:自动补全 /rerank 路径 + String finalUrl = baseUrl; + if (isNullOrEmpty(finalUrl)) { + finalUrl = "https://api.siliconflow.cn/v1/rerank"; + } else { + // 如果用户只填了基础路径 https://api.siliconflow.cn/v1,自动补全成 https://api.siliconflow.cn/v1/rerank + if (finalUrl.endsWith("/v1")) { + finalUrl = finalUrl + "/rerank"; + } else if (!finalUrl.endsWith("/rerank")) { + // 如果没有以 /rerank 结尾也不以斜杠结尾,尝试拼接 + finalUrl = finalUrl.endsWith("/") ? finalUrl + "rerank" : finalUrl + "/rerank"; + } + } + this.baseUrl = finalUrl; + log.info("初始化 SiliconFlow 重排模型: URL=[{}], Model=[{}]", this.baseUrl, this.modelName); + + this.client = new OkHttpClient.Builder() + .connectTimeout(60, TimeUnit.SECONDS) + .readTimeout(60, TimeUnit.SECONDS) + .build(); + } + + @Override + public Response> scoreAll(List segments, String query) { + if (isNullOrEmpty(segments)) { + return Response.from(new ArrayList<>()); + } + + List texts = segments.stream() + .map(TextSegment::text) + .collect(Collectors.toList()); + + RerankRequest requestBody = new RerankRequest(); + requestBody.setModel(modelName); + requestBody.setQuery(query); + requestBody.setDocuments(texts); + requestBody.setTop_n(texts.size()); + requestBody.setReturn_documents(false); + + String json = JsonUtils.toJsonString(requestBody); + RequestBody body = RequestBody.create(json, MediaType.parse("application/json; charset=utf-8")); + + Request request = new Request.Builder() + .url(baseUrl) + .header("Authorization", "Bearer " + apiKey) + .post(body) + .build(); + + try (okhttp3.Response response = client.newCall(request).execute()) { + if (!response.isSuccessful()) { + String errorBody = response.body() != null ? response.body().string() : "unknown error"; + log.error("SiliconFlow Rerank API 调用失败: code={}, body={}", response.code(), errorBody); + throw new RuntimeException("SiliconFlow Rerank API 调用失败: " + response.code()); + } + + String responseBody = response.body().string(); + RerankResponse rerankResponse = JsonUtils.parseObject(responseBody, RerankResponse.class); + + if (rerankResponse == null || rerankResponse.getResults() == null) { + return Response.from(new ArrayList<>()); + } + + // 初始化分数组,默认值为 0.0 + Double[] scores = new Double[texts.size()]; + for (int i = 0; i < texts.size(); i++) { + scores[i] = 0.0; + } + + // 填充得分 + rerankResponse.getResults().forEach(item -> { + if (item.getIndex() != null && item.getIndex() < texts.size()) { + scores[item.getIndex()] = item.getRelevance_score(); + } + }); + + List scoreList = new ArrayList<>(); + for (Double s : scores) { + scoreList.add(s); + } + + return Response.from(scoreList); + + } catch (IOException e) { + log.error("SiliconFlow Rerank 网络请求异常", e); + throw new RuntimeException("SiliconFlow Rerank 网络请求异常", e); + } + } + + @Override + public Response score(TextSegment segment, String query) { + List segments = new ArrayList<>(); + segments.add(segment); + Response> response = scoreAll(segments, query); + return Response.from(response.content().get(0)); + } + + @Data + public static class RerankRequest { + private String model; + private String query; + private List documents; + private Integer top_n; + private Boolean return_documents; + } + + @Data + public static class RerankResponse { + private List results; + } + + @Data + public static class RerankResultItem { + private Integer index; + private Double relevance_score; + } +} From ccbf5c9520b77da26bd821e02059ab9f4a89ad9b Mon Sep 17 00:00:00 2001 From: RobustH <1511209518@qq.com> Date: Tue, 14 Apr 2026 23:18:29 +0800 Subject: [PATCH 5/8] =?UTF-8?q?feat(rag):=20=E7=9F=A5=E8=AF=86=E5=BA=93?= =?UTF-8?q?=E6=A3=80=E7=B4=A2=E6=B5=8B=E8=AF=95=E6=96=B0=E5=A2=9E=E6=B7=B7?= =?UTF-8?q?=E5=90=88=E6=A3=80=E7=B4=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../bo/knowledge/KnowledgeFragmentBo.java | 10 ++ .../domain/bo/knowledge/KnowledgeInfoBo.java | 10 ++ .../entity/knowledge/KnowledgeFragment.java | 5 + .../entity/knowledge/KnowledgeInfo.java | 10 ++ .../vo/knowledge/KnowledgeFragmentVo.java | 7 +- .../domain/vo/knowledge/KnowledgeInfoVo.java | 13 +++ .../vo/knowledge/KnowledgeRetrievalVo.java | 24 +++++ .../knowledge/KnowledgeFragmentMapper.java | 9 ++ .../impl/KnowledgeAttachServiceImpl.java | 1 + .../impl/KnowledgeFragmentServiceImpl.java | 92 ++++++++++++++++++- 10 files changed, 178 insertions(+), 3 deletions(-) diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeFragmentBo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeFragmentBo.java index 1508462f..895ab69f 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeFragmentBo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeFragmentBo.java @@ -79,4 +79,14 @@ public class KnowledgeFragmentBo extends BaseEntity { */ private String rerankModel; + /** + * 是否启用混合检索 + */ + private Boolean enableHybrid; + + /** + * 混合检索权重 (0.0-1.0) + */ + private Double hybridAlpha; + } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoBo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoBo.java index 8629018a..5f7a143e 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoBo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoBo.java @@ -92,5 +92,15 @@ public class KnowledgeInfoBo extends BaseEntity { */ private String remark; + /** + * 是否启用混合检索(0 否 1 是) + */ + private Integer enableHybrid; + + /** + * 混合检索权重比例 (0.0-1.0) + */ + private Double hybridAlpha; + } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeFragment.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeFragment.java index 04716e15..d184a10a 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeFragment.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeFragment.java @@ -47,5 +47,10 @@ public class KnowledgeFragment extends BaseEntity { */ private String remark; + /** + * 知识库ID + */ + private Long knowledgeId; + } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeInfo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeInfo.java index a5211e69..e2e6da4b 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeInfo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeInfo.java @@ -93,5 +93,15 @@ public class KnowledgeInfo extends BaseEntity { */ private String remark; + /** + * 是否启用混合检索(0 否 1 是) + */ + private Integer enableHybrid; + + /** + * 混合检索权重比例 (0.0-1.0) + */ + private Double hybridAlpha; + } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeFragmentVo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeFragmentVo.java index b8be695e..45b4a9ab 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeFragmentVo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeFragmentVo.java @@ -39,7 +39,7 @@ public class KnowledgeFragmentVo implements Serializable { * 片段索引下标 */ @ExcelProperty(value = "片段索引下标") - private Long idx; + private Integer idx; /** * 文档内容 @@ -53,5 +53,10 @@ public class KnowledgeFragmentVo implements Serializable { @ExcelProperty(value = "备注") private String remark; + /** + * 知识库ID + */ + private Long knowledgeId; + } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java index e65444e7..6f9a148d 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java @@ -113,6 +113,19 @@ public class KnowledgeInfoVo implements Serializable { @ExcelProperty(value = "备注") private String remark; + /** + * 是否启用混合检索(0 否 1 是) + */ + @ExcelProperty(value = "是否启用混合检索", converter = ExcelDictConvert.class) + @ExcelDictFormat(readConverterExp = "0=否,1=是") + private Integer enableHybrid; + + /** + * 混合检索权重比例 (0.0-1.0) + */ + @ExcelProperty(value = "混合检索权重比例") + private Double hybridAlpha; + /** * 文档数(统计字段,非数据库列) */ diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeRetrievalVo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeRetrievalVo.java index 95c8e4cf..420015d8 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeRetrievalVo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeRetrievalVo.java @@ -1,7 +1,9 @@ package org.ruoyi.domain.vo.knowledge; +import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; +import lombok.NoArgsConstructor; import java.io.Serial; import java.io.Serializable; @@ -13,11 +15,33 @@ import java.io.Serializable; */ @Data @Builder +@NoArgsConstructor +@AllArgsConstructor public class KnowledgeRetrievalVo implements Serializable { @Serial private static final long serialVersionUID = 1L; + /** + * 片段ID + */ + private String id; + + /** + * 文档ID + */ + private String docId; + + /** + * 知识库ID + */ + private Long knowledgeId; + + /** + * 分片索引 + */ + private Integer idx; + /** * 片段内容 */ diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeFragmentMapper.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeFragmentMapper.java index b99ad6af..304bb7d5 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeFragmentMapper.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/mapper/knowledge/KnowledgeFragmentMapper.java @@ -33,4 +33,13 @@ public interface KnowledgeFragmentMapper extends BaseMapperPlus") List selectFragmentCountByDocIds(@Param("docIds") List docIds); + @Select("") + List searchByKeyword(@Param("knowledgeId") Long knowledgeId, @Param("query") String query, @Param("limit") Integer limit); } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeAttachServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeAttachServiceImpl.java index e8260f25..28b3d90f 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeAttachServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeAttachServiceImpl.java @@ -187,6 +187,7 @@ public class KnowledgeAttachServiceImpl implements IKnowledgeAttachService { String fid = RandomUtil.randomString(10); fids.add(fid); KnowledgeFragment knowledgeFragment = new KnowledgeFragment(); + knowledgeFragment.setKnowledgeId(knowledgeId); knowledgeFragment.setDocId(docId); knowledgeFragment.setIdx(i); knowledgeFragment.setContent(chunkList.get(i)); diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java index 68ccf0bf..bbe56794 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java @@ -28,6 +28,8 @@ import org.ruoyi.service.vector.VectorStoreService; import org.springframework.stereotype.Service; import java.util.ArrayList; import java.util.stream.Collectors; +import java.util.*; +import java.util.concurrent.CompletableFuture; import java.util.List; import java.util.Map; @@ -180,8 +182,47 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService { queryVectorBo.setApiKey(chatModel.getApiKey()); queryVectorBo.setBaseUrl(chatModel.getApiHost()); - // 3. 执行物理检索 - List allResults = vectorStoreService.search(queryVectorBo); + // 3. 执行搜索 (向量搜索 + 关键词搜索) + List allResults; + + boolean hybridEnabled = Boolean.TRUE.equals(bo.getEnableHybrid()) || + Integer.valueOf(1).equals(knowledgeInfoVo.getEnableHybrid()); + + if (hybridEnabled) { + log.info("执行混合检索: kid={}, query={}", bo.getKnowledgeId(), bo.getQuery()); + try { + // 并行执行向量搜索 + CompletableFuture> vectorFuture = CompletableFuture.supplyAsync(() -> + vectorStoreService.search(queryVectorBo)); + + // 执行关键词搜索 (MySQL) + int limit = bo.getTopK() != null ? bo.getTopK() : 50; + List keywordFragments = baseMapper.searchByKeyword(bo.getKnowledgeId(), bo.getQuery(), limit); + List keywordResults = keywordFragments.stream().map(f -> { + KnowledgeRetrievalVo vo = new KnowledgeRetrievalVo(); + vo.setId(f.getId().toString()); + vo.setContent(f.getContent()); + vo.setDocId(f.getDocId()); + vo.setIdx(f.getIdx()); + vo.setKnowledgeId(f.getKnowledgeId()); + vo.setScore(10.0); // 初始分,后续由 RRF 重新打分 + return vo; + }).collect(Collectors.toList()); + + List vectorResults = vectorFuture.get(); + log.info("抽取混合结果成功: Vector命中={}条, Keyword命中={}条", vectorResults.size(), keywordResults.size()); + + double alpha = bo.getHybridAlpha() != null ? bo.getHybridAlpha() : + (knowledgeInfoVo.getHybridAlpha() != null ? knowledgeInfoVo.getHybridAlpha() : 0.5); + + allResults = calculateRRF(vectorResults, keywordResults, alpha); + } catch (Exception e) { + log.error("混合检索执行或合并失败,已自动降级回退到纯向量检索", e); + allResults = vectorStoreService.search(queryVectorBo); + } + } else { + allResults = vectorStoreService.search(queryVectorBo); + } // 初始化原始排名 for (int i = 0; i < allResults.size(); i++) { @@ -230,4 +271,51 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService { .filter(res -> res.getScore() >= threshold) .collect(Collectors.toList()); } + + /** + * RRF (Reciprocal Rank Fusion) 融合算法 + * 公式: Score = (1-alpha) * (1 / (k + rank_vector)) + alpha * (1 / (k + rank_keyword)) + */ + private List calculateRRF(List vectorList, List keywordList, double alpha) { + Map allMap = new HashMap<>(); + Map vectorScores = new HashMap<>(); + Map keywordScores = new HashMap<>(); + + int k = 60; // 常用 RRF 常数 + + for (int i = 0; i < vectorList.size(); i++) { + KnowledgeRetrievalVo vo = vectorList.get(i); + allMap.put(vo.getId(), vo); + vectorScores.put(vo.getId(), 1.0 / (k + i + 1)); + } + + for (int i = 0; i < keywordList.size(); i++) { + KnowledgeRetrievalVo vo = keywordList.get(i); + if (!allMap.containsKey(vo.getId())) { + allMap.put(vo.getId(), vo); + } + keywordScores.put(vo.getId(), 1.0 / (k + i + 1)); + } + + // 重新计算得分 + List fusedResults = new ArrayList<>(); + for (Map.Entry entry : allMap.entrySet()) { + String id = entry.getKey(); + double vScore = vectorScores.getOrDefault(id, 0.0); + double kScore = keywordScores.getOrDefault(id, 0.0); + + // 混合分值 + double finalScore = (1 - alpha) * vScore + alpha * kScore; + + // 分值归一化/缩放:将 RRF 分值放大到 0-1 范围 + // 理论单路最大得分为 1/61 ≈ 0.016,乘以 60 使其处于相似度常用区间 + KnowledgeRetrievalVo vo = entry.getValue(); + vo.setScore(finalScore * 60.0); + fusedResults.add(vo); + } + + // 按融合分数从高到低排序 + fusedResults.sort((a, b) -> b.getScore().compareTo(a.getScore())); + return fusedResults; + } } From 1b50c7f9f1f83b906905eb5134ef3fe22308d665 Mon Sep 17 00:00:00 2001 From: RobustH <1511209518@qq.com> Date: Tue, 21 Apr 2026 22:41:00 +0800 Subject: [PATCH 6/8] =?UTF-8?q?fix(rag):=20=E4=BF=AE=E5=A4=8D=E5=90=88?= =?UTF-8?q?=E5=B9=B6=E9=87=8D=E5=A4=8D=EF=BC=8C=E9=87=8D=E6=8E=92=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E6=96=B0=E5=A2=9E=E7=A1=85=E5=9F=BA=E6=B5=81=E5=8A=A8?= =?UTF-8?q?=E4=BE=9B=E5=BA=94=E5=95=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../domain/bo/knowledge/KnowledgeInfoBo.java | 13 +- .../entity/knowledge/KnowledgeInfo.java | 10 + .../domain/vo/knowledge/KnowledgeInfoVo.java | 18 ++ .../impl/KnowledgeFragmentServiceImpl.java | 86 +++++---- .../rerank/DashScopeScoringModel.java | 98 ---------- .../knowledge/rerank/ScoringModelFactory.java | 54 ------ .../rerank/SiliconFlowScoringModel.java | 155 ---------------- .../impl/SiliconFlowRerankModelService.java | 174 ++++++++++++++++++ 8 files changed, 254 insertions(+), 354 deletions(-) delete mode 100644 ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/DashScopeScoringModel.java delete mode 100644 ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/ScoringModelFactory.java delete mode 100644 ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/SiliconFlowScoringModel.java create mode 100644 ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/rerank/impl/SiliconFlowRerankModelService.java diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoBo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoBo.java index 4d5cf619..2ecc548b 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoBo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoBo.java @@ -98,12 +98,19 @@ public class KnowledgeInfoBo extends BaseEntity { private Double rerankScoreThreshold; + /** + * 是否启用混合检索(0 否 1是) + */ + private Integer enableHybrid; + + /** + * 混合检索权重 (0.0-1.0) + */ + private Double hybridAlpha; + /** * 备注 */ private String remark; - - - } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeInfo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeInfo.java index 948d04c0..5ed96c15 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeInfo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeInfo.java @@ -98,6 +98,16 @@ public class KnowledgeInfo extends BaseEntity { */ private Double rerankScoreThreshold; + /** + * 是否启用混合检索(0 否 1是) + */ + private Integer enableHybrid; + + /** + * 混合检索权重 (0.0-1.0) + */ + private Double hybridAlpha; + /** * 备注 */ diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java index 41d48480..1ac81565 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java @@ -118,6 +118,24 @@ public class KnowledgeInfoVo implements Serializable { @ExcelProperty(value = "重排序分数阈值") private Double rerankScoreThreshold; + /** + * 是否启用混合检索(0 否 1是) + */ + @ExcelProperty(value = "是否启用混合检索") + private Integer enableHybrid; + + /** + * 混合检索权重 (0.0-1.0) + */ + @ExcelProperty(value = "混合检索权重") + private Double hybridAlpha; + + /** + * 文档数量 + */ + @ExcelProperty(value = "文档数量") + private Integer documentCount; + /** * 备注 */ diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java index bbe56794..8ca9e647 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java @@ -1,39 +1,36 @@ package org.ruoyi.service.knowledge.impl; -import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.model.output.Response; -import dev.langchain4j.model.scoring.ScoringModel; -import org.ruoyi.common.core.utils.MapstructUtils; -import org.ruoyi.common.core.utils.StringUtils; -import org.ruoyi.common.mybatis.core.page.TableDataInfo; -import org.ruoyi.common.mybatis.core.page.PageQuery; -import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.baomidou.mybatisplus.core.toolkit.Wrappers; +import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo; +import org.ruoyi.common.chat.service.chat.IChatModelService; +import org.ruoyi.common.core.utils.MapstructUtils; +import org.ruoyi.common.core.utils.StringUtils; +import org.ruoyi.common.mybatis.core.page.PageQuery; +import org.ruoyi.common.mybatis.core.page.TableDataInfo; import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo; +import org.ruoyi.domain.bo.rerank.RerankRequest; +import org.ruoyi.domain.bo.rerank.RerankResult; +import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.domain.entity.knowledge.KnowledgeFragment; import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo; -import org.ruoyi.mapper.knowledge.KnowledgeFragmentMapper; -import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo; -import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo; -import org.ruoyi.domain.bo.vector.QueryVectorBo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; +import org.ruoyi.factory.RerankModelFactory; +import org.ruoyi.mapper.knowledge.KnowledgeFragmentMapper; import org.ruoyi.service.knowledge.IKnowledgeFragmentService; import org.ruoyi.service.knowledge.IKnowledgeInfoService; -import org.ruoyi.common.chat.service.chat.IChatModelService; -import org.ruoyi.service.knowledge.rerank.ScoringModelFactory; +import org.ruoyi.service.rerank.RerankModelService; import org.ruoyi.service.vector.VectorStoreService; import org.springframework.stereotype.Service; -import java.util.ArrayList; -import java.util.stream.Collectors; + import java.util.*; import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; -import java.util.List; -import java.util.Map; -import java.util.Collection; /** * 知识片段Service业务层处理 @@ -50,7 +47,7 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService { private final IKnowledgeInfoService knowledgeInfoService; private final IChatModelService chatModelService; private final VectorStoreService vectorStoreService; - private final ScoringModelFactory scoringModelFactory; + private final RerankModelFactory rerankModelFactory; /** * 查询知识片段 @@ -231,37 +228,38 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService { // 4. 执行重排逻辑 (如果请求启用重排且配置了重排模型) if (Boolean.TRUE.equals(bo.getEnableRerank()) && StringUtils.isNotBlank(bo.getRerankModel())) { - log.info("开始重排配置检索测试,传入模型名称: [{}]", bo.getRerankModel()); - ChatModelVo rerankModelConfig = chatModelService.selectModelByName(bo.getRerankModel()); - - if (rerankModelConfig == null) { - log.warn("未能找到重排模型配置: [{}]", bo.getRerankModel()); - } else { - ScoringModel scoringModel = scoringModelFactory.createScoringModel(rerankModelConfig); - if (scoringModel != null) { - log.info("执行重排精排,模型: {}, 供应商: {}", rerankModelConfig.getModelName(), rerankModelConfig.getProviderCode()); + log.info("开始重排精排,模型: [{}]", bo.getRerankModel()); + try { + RerankModelService rerankModel = rerankModelFactory.createModel(bo.getRerankModel()); - // 将 KnowledgeRetrievalVo 转换为 TextSegment 列表进行重排 - List segments = allResults.stream() - .map(res -> TextSegment.from(res.getContent())) + List contents = allResults.stream() + .map(KnowledgeRetrievalVo::getContent) .collect(Collectors.toList()); - Response> scoresResponse = scoringModel.scoreAll(segments, bo.getQuery()); - List scores = scoresResponse.content(); + RerankRequest rerankRequest = RerankRequest.builder() + .query(bo.getQuery()) + .documents(contents) + .topN(contents.size()) + .returnDocuments(false) + .build(); - // 更新分数并重新排序 - for (int i = 0; i < allResults.size(); i++) { - KnowledgeRetrievalVo resultVo = allResults.get(i); - // 保存原始分数供前端展示对比 + RerankResult rerankResult = rerankModel.rerank(rerankRequest); + + // 将重排分数写回,并记录原始分数供前端对比 + for (RerankResult.RerankDocument doc : rerankResult.getDocuments()) { + if (doc.getIndex() != null && doc.getIndex() < allResults.size()) { + KnowledgeRetrievalVo resultVo = allResults.get(doc.getIndex()); resultVo.setRawScore(resultVo.getScore()); - if (i < scores.size()) { - resultVo.setScore(scores.get(i)); - } + resultVo.setScore(doc.getRelevanceScore()); } - - // 按重排后的分数从高到低排序 - allResults.sort((a, b) -> b.getScore().compareTo(a.getScore())); } + + // 按重排后的分数从高到低排序 + allResults.sort((a, b) -> b.getScore().compareTo(a.getScore())); + log.info("重排精排完成,结果数: {}", allResults.size()); + + } catch (Exception e) { + log.error("重排精排执行失败,已跳过重排步骤: {}", e.getMessage(), e); } } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/DashScopeScoringModel.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/DashScopeScoringModel.java deleted file mode 100644 index 1086df8e..00000000 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/DashScopeScoringModel.java +++ /dev/null @@ -1,98 +0,0 @@ -package org.ruoyi.service.knowledge.rerank; - -import com.alibaba.dashscope.exception.ApiException; -import com.alibaba.dashscope.exception.InputRequiredException; -import com.alibaba.dashscope.exception.NoApiKeyException; -import com.alibaba.dashscope.rerank.TextReRank; -import com.alibaba.dashscope.rerank.TextReRankParam; -import com.alibaba.dashscope.rerank.TextReRankResult; -import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.model.output.Response; -import dev.langchain4j.model.scoring.ScoringModel; -import lombok.Builder; -import lombok.extern.slf4j.Slf4j; - -import java.util.ArrayList; -import java.util.List; -import java.util.stream.Collectors; - -import static dev.langchain4j.internal.Utils.isNullOrEmpty; - -/** - * DashScope 重排模型实现 (GTE-Rerank) - * 包装了阿里云 DashScope 的 TextReRank API,使其符合 LangChain4j 的 ScoringModel 标准。 - */ -@Slf4j -public class DashScopeScoringModel implements ScoringModel { - - private final String apiKey; - private final String modelName; - private final TextReRank rerank; - - @Builder - public DashScopeScoringModel(String apiKey, String modelName) { - if (isNullOrEmpty(apiKey)) { - throw new IllegalArgumentException("DashScope API Key 不能为空"); - } - this.apiKey = apiKey; - this.modelName = isNullOrEmpty(modelName) ? "gte-rerank" : modelName; - this.rerank = new TextReRank(); - } - - @Override - public Response> scoreAll(List segments, String query) { - if (isNullOrEmpty(segments)) { - return Response.from(new ArrayList<>()); - } - - // 提取文本列表供阿里 SDK 使用 - List texts = segments.stream() - .map(TextSegment::text) - .collect(Collectors.toList()); - - try { - TextReRankParam param = TextReRankParam.builder() - .apiKey(apiKey) - .model(modelName) - .query(query) - .documents(texts) - .topN(texts.size()) - .returnDocuments(false) - .build(); - - TextReRankResult result = rerank.call(param); - - // 初始化分数组,默认值为 0.0 - Double[] scores = new Double[texts.size()]; - for (int i = 0; i < texts.size(); i++) { - scores[i] = 0.0; - } - - // 根据返回结果填充对应的分数值(返回结果中包含原文索引) - result.getOutput().getResults().forEach(item -> { - if (item.getIndex() != null && item.getIndex() < texts.size()) { - scores[item.getIndex()] = item.getRelevanceScore(); - } - }); - - List scoreList = new ArrayList<>(); - for (Double s : scores) { - scoreList.add(s); - } - - return Response.from(scoreList); - - } catch (ApiException | NoApiKeyException | InputRequiredException e) { - log.error("DashScope 重排处理出错: {}", e.getMessage(), e); - throw new RuntimeException("调用 DashScope 重排服务失败", e); - } - } - - @Override - public Response score(TextSegment segment, String query) { - List segments = new ArrayList<>(); - segments.add(segment); - Response> response = scoreAll(segments, query); - return Response.from(response.content().get(0), response.tokenUsage(), response.finishReason()); - } -} diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/ScoringModelFactory.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/ScoringModelFactory.java deleted file mode 100644 index a3844d66..00000000 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/ScoringModelFactory.java +++ /dev/null @@ -1,54 +0,0 @@ -package org.ruoyi.service.knowledge.rerank; - -import dev.langchain4j.model.scoring.ScoringModel; -import lombok.extern.slf4j.Slf4j; -import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo; -import org.springframework.stereotype.Component; - -/** - * 重排模型提供商工厂 - * 用于将来无缝拓展硅基流动、百炼等支持重排的模型厂商 - * - * @author RobustH - */ -@Slf4j -@Component -public class ScoringModelFactory { - - /** - * 根据后台传递的模型配置创建具体的重排模型 - * - * @param rerankModelConfig 重排模型的配置 (例如其 providerCode, apiUrl, apiKey 等) - * @return 标准的 LangChain4j ScoringModel - */ - public ScoringModel createScoringModel(ChatModelVo rerankModelConfig) { - if (rerankModelConfig == null) { - return null; - } - - String providerCode = rerankModelConfig.getProviderCode(); - log.info("初始化重排模型,供应商代码: {}, 模型名称: {}", providerCode, rerankModelConfig.getModelName()); - - try { - if ("alibailian".equalsIgnoreCase(providerCode)) { - return DashScopeScoringModel.builder() - .apiKey(rerankModelConfig.getApiKey()) - .modelName(rerankModelConfig.getModelName()) - .build(); - } - - if ("siliconflow".equalsIgnoreCase(providerCode)) { - return SiliconFlowScoringModel.builder() - .apiKey(rerankModelConfig.getApiKey()) - .modelName(rerankModelConfig.getModelName()) - // 如果后台配置了不同的 API Host,可以在此传递,否则使用默认值 - .baseUrl(rerankModelConfig.getApiHost()) - .build(); - } - } catch (Exception e) { - log.error("创建重排模型失败: {}", e.getMessage(), e); - } - - return null; - } -} diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/SiliconFlowScoringModel.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/SiliconFlowScoringModel.java deleted file mode 100644 index ceae578e..00000000 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/rerank/SiliconFlowScoringModel.java +++ /dev/null @@ -1,155 +0,0 @@ -package org.ruoyi.service.knowledge.rerank; - -import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.model.output.Response; -import dev.langchain4j.model.scoring.ScoringModel; -import lombok.Builder; -import lombok.Data; -import lombok.extern.slf4j.Slf4j; -import okhttp3.*; -import org.ruoyi.common.json.utils.JsonUtils; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; - -import static dev.langchain4j.internal.Utils.isNullOrEmpty; - -/** - * SiliconFlow 重排模型实现 - * 适配硅基流动的 /v1/rerank 接口 - */ -@Slf4j -public class SiliconFlowScoringModel implements ScoringModel { - - private final String apiKey; - private final String modelName; - private final String baseUrl; - private final OkHttpClient client; - - @Builder - public SiliconFlowScoringModel(String apiKey, String modelName, String baseUrl) { - if (isNullOrEmpty(apiKey)) { - throw new IllegalArgumentException("SiliconFlow API Key 不能为空"); - } - this.apiKey = apiKey; - this.modelName = isNullOrEmpty(modelName) ? "BAAI/bge-reranker-v2-m3" : modelName; - - // 鲁棒性处理:自动补全 /rerank 路径 - String finalUrl = baseUrl; - if (isNullOrEmpty(finalUrl)) { - finalUrl = "https://api.siliconflow.cn/v1/rerank"; - } else { - // 如果用户只填了基础路径 https://api.siliconflow.cn/v1,自动补全成 https://api.siliconflow.cn/v1/rerank - if (finalUrl.endsWith("/v1")) { - finalUrl = finalUrl + "/rerank"; - } else if (!finalUrl.endsWith("/rerank")) { - // 如果没有以 /rerank 结尾也不以斜杠结尾,尝试拼接 - finalUrl = finalUrl.endsWith("/") ? finalUrl + "rerank" : finalUrl + "/rerank"; - } - } - this.baseUrl = finalUrl; - log.info("初始化 SiliconFlow 重排模型: URL=[{}], Model=[{}]", this.baseUrl, this.modelName); - - this.client = new OkHttpClient.Builder() - .connectTimeout(60, TimeUnit.SECONDS) - .readTimeout(60, TimeUnit.SECONDS) - .build(); - } - - @Override - public Response> scoreAll(List segments, String query) { - if (isNullOrEmpty(segments)) { - return Response.from(new ArrayList<>()); - } - - List texts = segments.stream() - .map(TextSegment::text) - .collect(Collectors.toList()); - - RerankRequest requestBody = new RerankRequest(); - requestBody.setModel(modelName); - requestBody.setQuery(query); - requestBody.setDocuments(texts); - requestBody.setTop_n(texts.size()); - requestBody.setReturn_documents(false); - - String json = JsonUtils.toJsonString(requestBody); - RequestBody body = RequestBody.create(json, MediaType.parse("application/json; charset=utf-8")); - - Request request = new Request.Builder() - .url(baseUrl) - .header("Authorization", "Bearer " + apiKey) - .post(body) - .build(); - - try (okhttp3.Response response = client.newCall(request).execute()) { - if (!response.isSuccessful()) { - String errorBody = response.body() != null ? response.body().string() : "unknown error"; - log.error("SiliconFlow Rerank API 调用失败: code={}, body={}", response.code(), errorBody); - throw new RuntimeException("SiliconFlow Rerank API 调用失败: " + response.code()); - } - - String responseBody = response.body().string(); - RerankResponse rerankResponse = JsonUtils.parseObject(responseBody, RerankResponse.class); - - if (rerankResponse == null || rerankResponse.getResults() == null) { - return Response.from(new ArrayList<>()); - } - - // 初始化分数组,默认值为 0.0 - Double[] scores = new Double[texts.size()]; - for (int i = 0; i < texts.size(); i++) { - scores[i] = 0.0; - } - - // 填充得分 - rerankResponse.getResults().forEach(item -> { - if (item.getIndex() != null && item.getIndex() < texts.size()) { - scores[item.getIndex()] = item.getRelevance_score(); - } - }); - - List scoreList = new ArrayList<>(); - for (Double s : scores) { - scoreList.add(s); - } - - return Response.from(scoreList); - - } catch (IOException e) { - log.error("SiliconFlow Rerank 网络请求异常", e); - throw new RuntimeException("SiliconFlow Rerank 网络请求异常", e); - } - } - - @Override - public Response score(TextSegment segment, String query) { - List segments = new ArrayList<>(); - segments.add(segment); - Response> response = scoreAll(segments, query); - return Response.from(response.content().get(0)); - } - - @Data - public static class RerankRequest { - private String model; - private String query; - private List documents; - private Integer top_n; - private Boolean return_documents; - } - - @Data - public static class RerankResponse { - private List results; - } - - @Data - public static class RerankResultItem { - private Integer index; - private Double relevance_score; - } -} diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/rerank/impl/SiliconFlowRerankModelService.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/rerank/impl/SiliconFlowRerankModelService.java new file mode 100644 index 00000000..49c5aecf --- /dev/null +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/rerank/impl/SiliconFlowRerankModelService.java @@ -0,0 +1,174 @@ +package org.ruoyi.service.rerank.impl; + +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import lombok.Data; +import lombok.extern.slf4j.Slf4j; +import okhttp3.*; +import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo; +import org.ruoyi.domain.bo.rerank.RerankRequest; +import org.ruoyi.domain.bo.rerank.RerankResult; +import org.ruoyi.service.rerank.RerankModelService; +import org.springframework.stereotype.Component; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; + +/** + * 硅基流动重排序模型实现 + * 适配硅基流动的 /v1/rerank 接口 + * + * @author RobustH + * @date 2026-04-21 + */ +@Slf4j +@Component("siliconflowRerank") +public class SiliconFlowRerankModelService implements RerankModelService { + + private static final String DEFAULT_BASE_URL = "https://api.siliconflow.cn/v1/rerank"; + + private final OkHttpClient okHttpClient; + private final ObjectMapper objectMapper = new ObjectMapper() + .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + private ChatModelVo chatModelVo; + + public SiliconFlowRerankModelService() { + this.okHttpClient = new OkHttpClient.Builder() + .connectTimeout(30, TimeUnit.SECONDS) + .readTimeout(60, TimeUnit.SECONDS) + .writeTimeout(30, TimeUnit.SECONDS) + .build(); + } + + @Override + public void configure(ChatModelVo config) { + this.chatModelVo = config; + } + + @Override + public RerankResult rerank(RerankRequest rerankRequest) { + long startTime = System.currentTimeMillis(); + + try { + String url = buildUrl(); + String requestJson = buildRequestJson(rerankRequest); + + RequestBody body = RequestBody.create(requestJson, MediaType.get("application/json")); + Request httpRequest = new Request.Builder() + .url(url) + .addHeader("Authorization", "Bearer " + chatModelVo.getApiKey()) + .addHeader("Content-Type", "application/json") + .post(body) + .build(); + + log.info("硅基流动重排序请求: model={}, url={}", chatModelVo.getModelName(), url); + + try (Response response = okHttpClient.newCall(httpRequest).execute()) { + if (!response.isSuccessful()) { + String err = response.body() != null ? response.body().string() : "无错误信息"; + throw new IllegalArgumentException("硅基流动 Rerank API 调用失败: " + response.code() + " - " + err); + } + + ResponseBody responseBody = response.body(); + if (responseBody == null) { + throw new IllegalArgumentException("响应体为空"); + } + + SiliconFlowRerankResponse rerankResponse = objectMapper.readValue( + responseBody.string(), SiliconFlowRerankResponse.class); + + return buildRerankResult(rerankResponse, rerankRequest.getDocuments(), + System.currentTimeMillis() - startTime); + } + + } catch (Exception e) { + log.error("硅基流动重排序失败: {}", e.getMessage(), e); + throw new RuntimeException("重排序服务调用失败: " + e.getMessage(), e); + } + } + + /** + * 构建请求 URL,鲁棒处理 API Host 末尾路径 + */ + private String buildUrl() { + String apiHost = chatModelVo.getApiHost(); + if (apiHost == null || apiHost.isBlank()) { + return DEFAULT_BASE_URL; + } + if (apiHost.endsWith("/rerank")) { + return apiHost; + } + if (apiHost.endsWith("/v1")) { + return apiHost + "/rerank"; + } + return apiHost.endsWith("/") ? apiHost + "rerank" : apiHost + "/rerank"; + } + + /** + * 构建请求体 JSON + */ + private String buildRequestJson(RerankRequest rerankRequest) throws IOException { + SiliconFlowRerankRequest request = new SiliconFlowRerankRequest(); + request.setModel(chatModelVo.getModelName()); + request.setQuery(rerankRequest.getQuery()); + request.setDocuments(rerankRequest.getDocuments()); + request.setTop_n(rerankRequest.getTopN() != null ? rerankRequest.getTopN() : rerankRequest.getDocuments().size()); + request.setReturn_documents(rerankRequest.getReturnDocuments() != null ? rerankRequest.getReturnDocuments() : false); + return objectMapper.writeValueAsString(request); + } + + /** + * 构建标准 RerankResult + */ + private RerankResult buildRerankResult(SiliconFlowRerankResponse response, + List originalDocuments, long durationMs) { + Double[] scores = new Double[originalDocuments.size()]; + for (int i = 0; i < scores.length; i++) { + scores[i] = 0.0; + } + + List docs = new ArrayList<>(); + if (response != null && response.getResults() != null) { + response.getResults().forEach(item -> { + if (item.getIndex() != null && item.getIndex() < originalDocuments.size()) { + scores[item.getIndex()] = item.getRelevance_score(); + docs.add(RerankResult.RerankDocument.builder() + .index(item.getIndex()) + .relevanceScore(item.getRelevance_score()) + .document(originalDocuments.get(item.getIndex())) + .build()); + } + }); + } + + return RerankResult.builder() + .documents(docs) + .totalDocuments(originalDocuments.size()) + .durationMs(durationMs) + .build(); + } + + // ==================== 内部 DTO ==================== + + @Data + static class SiliconFlowRerankRequest { + private String model; + private String query; + private List documents; + private Integer top_n; + private Boolean return_documents; + } + + @Data + static class SiliconFlowRerankResponse { + private List results; + } + + @Data + static class SiliconFlowRerankResultItem { + private Integer index; + private Double relevance_score; + } +} From 058a4aee2a1f0c48785fadcca38df1fa6b1e59e5 Mon Sep 17 00:00:00 2001 From: RobustH <1511209518@qq.com> Date: Tue, 21 Apr 2026 22:54:11 +0800 Subject: [PATCH 7/8] =?UTF-8?q?feat(rag):=20=E6=96=B0=E5=A2=9E=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E9=85=8D=E7=BD=AE=E5=BA=94=E7=94=A8=E7=9A=84=E5=8A=9F?= =?UTF-8?q?=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoBo.java | 5 +++++ .../org/ruoyi/domain/entity/knowledge/KnowledgeInfo.java | 5 +++++ .../java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java | 6 ++++++ 3 files changed, 16 insertions(+) diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoBo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoBo.java index 2ecc548b..b0045a1d 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoBo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/knowledge/KnowledgeInfoBo.java @@ -62,6 +62,11 @@ public class KnowledgeInfoBo extends BaseEntity { */ private Long retrieveLimit; + /** + * 相似度阈值 + */ + private Double similarityThreshold; + /** * 文本块大小 */ diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeInfo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeInfo.java index 5ed96c15..95bd9286 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeInfo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/entity/knowledge/KnowledgeInfo.java @@ -63,6 +63,11 @@ public class KnowledgeInfo extends BaseEntity { */ private Long retrieveLimit; + /** + * 相似度阈值 + */ + private Double similarityThreshold; + /** * 文本块大小 */ diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java index 1ac81565..d742fad0 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/vo/knowledge/KnowledgeInfoVo.java @@ -76,6 +76,12 @@ public class KnowledgeInfoVo implements Serializable { @ExcelProperty(value = "知识库中检索的条数") private Integer retrieveLimit; + /** + * 相似度阈值 + */ + @ExcelProperty(value = "相似度阈值") + private Double similarityThreshold; + /** * 文本块大小 */ From b8d16b7669b3c575e651ff4d4ba6dc4e79802ca0 Mon Sep 17 00:00:00 2001 From: RobustH <1511209518@qq.com> Date: Thu, 23 Apr 2026 00:52:53 +0800 Subject: [PATCH 8/8] =?UTF-8?q?feat(rag):=20=E5=AF=B9=E6=8E=A5=E7=94=A8?= =?UTF-8?q?=E6=88=B7=E7=AB=AF=E7=94=A8=E6=88=B7=E7=9F=A5=E8=AF=86=E5=BA=93?= =?UTF-8?q?=E5=AF=B9=E8=AF=9D=EF=BC=8C=E9=9B=86=E6=88=90=E7=9F=A5=E8=AF=86?= =?UTF-8?q?=E5=BA=93=E9=85=8D=E7=BD=AE=E5=BA=94=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/script/sql/update/updat-0423.sql | 14 ++ .../ruoyi/domain/bo/vector/QueryVectorBo.java | 18 ++ .../service/chat/impl/ChatServiceFacade.java | 78 +++--- .../impl/KnowledgeFragmentServiceImpl.java | 168 ++----------- .../retriever/CustomVectorRetriever.java | 29 ++- .../retrieval/KnowledgeRetrievalService.java | 12 +- .../impl/KnowledgeRetrievalServiceImpl.java | 233 +++++++++++++----- 7 files changed, 301 insertions(+), 251 deletions(-) create mode 100644 docs/script/sql/update/updat-0423.sql diff --git a/docs/script/sql/update/updat-0423.sql b/docs/script/sql/update/updat-0423.sql new file mode 100644 index 00000000..4ed14430 --- /dev/null +++ b/docs/script/sql/update/updat-0423.sql @@ -0,0 +1,14 @@ +-- 为知识库信息表新增检索配置字段 (剔除了已存在的重排字段) +ALTER TABLE knowledge_info + ADD COLUMN similarity_threshold DOUBLE DEFAULT 0.5 COMMENT '相似度阈值' +AFTER retrieve_limit; + +ALTER TABLE knowledge_info ADD COLUMN enable_hybrid tinyint(1) DEFAULT 0 COMMENT '是否启用混合检索'; +ALTER TABLE knowledge_info ADD COLUMN hybrid_alpha double DEFAULT 0.5 COMMENT '混合检索权重比例 (0.0=纯向量, 1.0=纯关键词)'; + +-- 为知识片段表增加全文索引及关联ID +ALTER TABLE knowledge_fragment ADD COLUMN knowledge_id bigint COMMENT '知识库ID'; +ALTER TABLE knowledge_fragment ADD FULLTEXT INDEX ft_content (content) WITH PARSER ngram; + +-- 为知识库附件表增加解析状态字段 +ALTER TABLE `knowledge_attach` ADD COLUMN `status` TINYINT DEFAULT 0 COMMENT '解析状态: 0待解析, 1解析中, 2已解析, 3解析失败'; diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/vector/QueryVectorBo.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/vector/QueryVectorBo.java index bb5634a3..6f0b9352 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/vector/QueryVectorBo.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/domain/bo/vector/QueryVectorBo.java @@ -77,4 +77,22 @@ public class QueryVectorBo { */ private Double rerankScoreThreshold; + // ========== 混合检索与阈值相关参数 ========== + + /** + * 相似度阈值 (0.0-1.0) + * 应用于向量搜索阶段 + */ + private Double similarityThreshold; + + /** + * 是否启用混合检索 + */ + private Boolean enableHybrid = false; + + /** + * 混合检索权重 (0.0-1.0) + */ + private Double hybridAlpha; + } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java index 3bd0876e..16e0750d 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/chat/impl/ChatServiceFacade.java @@ -20,6 +20,11 @@ import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; import dev.langchain4j.model.openai.OpenAiChatModel; import dev.langchain4j.service.tool.ToolProvider; import dev.langchain4j.skills.shell.ShellSkills; +import dev.langchain4j.rag.AugmentationRequest; +import dev.langchain4j.rag.AugmentationResult; +import dev.langchain4j.rag.DefaultRetrievalAugmentor; +import dev.langchain4j.rag.RetrievalAugmentor; +import dev.langchain4j.rag.query.Metadata; import lombok.RequiredArgsConstructor; import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; @@ -55,6 +60,7 @@ import org.ruoyi.service.chat.IChatMessageService; import org.ruoyi.service.chat.impl.memory.PersistentChatMemoryStore; import org.ruoyi.service.knowledge.IKnowledgeInfoService; import org.ruoyi.service.retrieval.KnowledgeRetrievalService; +import org.ruoyi.service.knowledge.retriever.CustomVectorRetriever; import org.ruoyi.service.vector.VectorStoreService; import org.springframework.stereotype.Service; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; @@ -412,16 +418,49 @@ public class ChatServiceFacade implements IChatService { /** * 构建上下文消息列表 - * 消息顺序:历史消息 → 当前用户消息(确保 AI 正确理解对话上下文) * * @param chatRequest 聊天请求 * @return 上下文消息列表 */ private List buildContextMessages(ChatRequest chatRequest) { - List messages = new ArrayList<>(); + List messages = new ArrayList<>(); - // 从数据库查询历史对话消息(放在前面) + // 1. 初始化当前用户消息 + UserMessage userMessage = UserMessage.userMessage(chatRequest.getContent()); + + // 2. 知识库检索增强 (RAG) + if (chatRequest.getKnowledgeId() != null) { + KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(chatRequest.getKnowledgeId())); + if (knowledgeInfoVo != null) { + ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel()); + if (chatModel != null) { + log.info("执行高级 RAG 流程: kid={}", chatRequest.getKnowledgeId()); + + // 构建自定义检索器 + CustomVectorRetriever retriever = new CustomVectorRetriever( + knowledgeRetrievalService, knowledgeInfoVo, chatModel); + + // 构建增强流水线 + RetrievalAugmentor augmentor = DefaultRetrievalAugmentor.builder() + .contentRetriever(retriever) + .build(); + + // 执行增强:编织上下文到 UserMessage + Metadata metadata = Metadata.from(userMessage, chatRequest.getSessionId(), new ArrayList<>()); + AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata); + AugmentationResult result = augmentor.augment(augmentationRequest); + + ChatMessage augmented = result.chatMessage(); + if (augmented instanceof UserMessage) { + userMessage = (UserMessage) augmented; + log.debug("RAG 增强完成,UserMessage 已注入背景知识"); + } + } + } + } + + // 3. 从数据库查询历史对话消息(放在前面) if (chatRequest.getSessionId() != null) { MessageWindowChatMemory memory = createChatMemory(chatRequest.getSessionId()); if (memory != null) { @@ -433,38 +472,7 @@ public class ChatServiceFacade implements IChatService { } } - // 从向量库查询相关历史消息(知识库内容作为上下文) - if (chatRequest.getKnowledgeId() != null) { - // 查询知识库信息 - KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(chatRequest.getKnowledgeId())); - if (knowledgeInfoVo == null) { - log.warn("知识库信息不存在,kid: {}", chatRequest.getKnowledgeId()); - // 继续添加当前用户消息 - messages.add(UserMessage.userMessage(chatRequest.getContent())); - return messages; - } - - // 查询向量模型配置信息 - ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel()); - if (chatModel == null) { - log.warn("向量模型配置不存在,模型名称: {}", knowledgeInfoVo.getEmbeddingModel()); - messages.add(UserMessage.userMessage(chatRequest.getContent())); - return messages; - } - - // 构建向量查询参数 - QueryVectorBo queryVectorBo = buildQueryVectorBo(chatRequest, knowledgeInfoVo, chatModel); - - // 使用知识库检索服务(支持重排序) - List nearestList = knowledgeRetrievalService.retrieveTexts(queryVectorBo); - for (String prompt : nearestList) { - // 知识库内容作为系统上下文添加 - messages.add(new AiMessage(prompt)); - } - } - - // 构建当前用户消息(放在最后) - UserMessage userMessage = UserMessage.userMessage(chatRequest.getContent()); + // 4. 添加经过增强的用户消息(放在最后) messages.add(userMessage); return messages; diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java index 8ca9e647..da17f9b9 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/impl/KnowledgeFragmentServiceImpl.java @@ -12,25 +12,18 @@ import org.ruoyi.common.core.utils.StringUtils; import org.ruoyi.common.mybatis.core.page.PageQuery; import org.ruoyi.common.mybatis.core.page.TableDataInfo; import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo; -import org.ruoyi.domain.bo.rerank.RerankRequest; -import org.ruoyi.domain.bo.rerank.RerankResult; import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.domain.entity.knowledge.KnowledgeFragment; import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo; import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo; import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; -import org.ruoyi.factory.RerankModelFactory; import org.ruoyi.mapper.knowledge.KnowledgeFragmentMapper; import org.ruoyi.service.knowledge.IKnowledgeFragmentService; import org.ruoyi.service.knowledge.IKnowledgeInfoService; -import org.ruoyi.service.rerank.RerankModelService; -import org.ruoyi.service.vector.VectorStoreService; +import org.ruoyi.service.retrieval.KnowledgeRetrievalService; import org.springframework.stereotype.Service; import java.util.*; -import java.util.concurrent.CompletableFuture; -import java.util.stream.Collectors; - /** * 知识片段Service业务层处理 @@ -46,8 +39,7 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService { private final KnowledgeFragmentMapper baseMapper; private final IKnowledgeInfoService knowledgeInfoService; private final IChatModelService chatModelService; - private final VectorStoreService vectorStoreService; - private final RerankModelFactory rerankModelFactory; + private final KnowledgeRetrievalService knowledgeRetrievalService; /** * 查询知识片段 @@ -87,7 +79,6 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService { } private LambdaQueryWrapper buildQueryWrapper(KnowledgeFragmentBo bo) { - Map params = bo.getParams(); LambdaQueryWrapper lqw = Wrappers.lambdaQuery(); lqw.orderByAsc(KnowledgeFragment::getId); lqw.eq(bo.getDocId() != null, KnowledgeFragment::getDocId, bo.getDocId()); @@ -149,7 +140,7 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService { } /** - * 检索测试核心实现 + * 检索测试核心实现 - 委托给统一的 KnowledgeRetrievalService */ @Override public List retrieval(KnowledgeFragmentBo bo) { @@ -157,7 +148,7 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService { return new ArrayList<>(); } - // 1. 获取知识库及模型配置 + // 1. 获取知识库及模型配置(为了获取 API Key/Host 等模型参数) KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(bo.getKnowledgeId()); if (knowledgeInfoVo == null) { return new ArrayList<>(); @@ -169,151 +160,28 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService { return new ArrayList<>(); } - // 2. 构造向量检索参数 + // 2. 构造通用的参数对象 QueryVectorBo queryVectorBo = new QueryVectorBo(); queryVectorBo.setQuery(bo.getQuery()); queryVectorBo.setKid(String.valueOf(bo.getKnowledgeId())); - queryVectorBo.setMaxResults(bo.getTopK() != null ? bo.getTopK() : knowledgeInfoVo.getRetrieveLimit()); - queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel()); - queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModel()); queryVectorBo.setApiKey(chatModel.getApiKey()); queryVectorBo.setBaseUrl(chatModel.getApiHost()); + queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel()); + queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModel()); - // 3. 执行搜索 (向量搜索 + 关键词搜索) - List allResults; + // 使用前端传入的实时测试参数,若无则使用知识库默认参数 + queryVectorBo.setMaxResults(bo.getTopK() != null ? bo.getTopK() : knowledgeInfoVo.getRetrieveLimit()); + queryVectorBo.setSimilarityThreshold(bo.getThreshold() != null ? bo.getThreshold() : knowledgeInfoVo.getSimilarityThreshold()); - boolean hybridEnabled = Boolean.TRUE.equals(bo.getEnableHybrid()) || - Integer.valueOf(1).equals(knowledgeInfoVo.getEnableHybrid()); - - if (hybridEnabled) { - log.info("执行混合检索: kid={}, query={}", bo.getKnowledgeId(), bo.getQuery()); - try { - // 并行执行向量搜索 - CompletableFuture> vectorFuture = CompletableFuture.supplyAsync(() -> - vectorStoreService.search(queryVectorBo)); - - // 执行关键词搜索 (MySQL) - int limit = bo.getTopK() != null ? bo.getTopK() : 50; - List keywordFragments = baseMapper.searchByKeyword(bo.getKnowledgeId(), bo.getQuery(), limit); - List keywordResults = keywordFragments.stream().map(f -> { - KnowledgeRetrievalVo vo = new KnowledgeRetrievalVo(); - vo.setId(f.getId().toString()); - vo.setContent(f.getContent()); - vo.setDocId(f.getDocId()); - vo.setIdx(f.getIdx()); - vo.setKnowledgeId(f.getKnowledgeId()); - vo.setScore(10.0); // 初始分,后续由 RRF 重新打分 - return vo; - }).collect(Collectors.toList()); - - List vectorResults = vectorFuture.get(); - log.info("抽取混合结果成功: Vector命中={}条, Keyword命中={}条", vectorResults.size(), keywordResults.size()); + queryVectorBo.setEnableHybrid(bo.getEnableHybrid() != null ? bo.getEnableHybrid() : Objects.equals(knowledgeInfoVo.getEnableHybrid(), 1)); + queryVectorBo.setHybridAlpha(bo.getHybridAlpha() != null ? bo.getHybridAlpha() : knowledgeInfoVo.getHybridAlpha()); - double alpha = bo.getHybridAlpha() != null ? bo.getHybridAlpha() : - (knowledgeInfoVo.getHybridAlpha() != null ? knowledgeInfoVo.getHybridAlpha() : 0.5); - - allResults = calculateRRF(vectorResults, keywordResults, alpha); - } catch (Exception e) { - log.error("混合检索执行或合并失败,已自动降级回退到纯向量检索", e); - allResults = vectorStoreService.search(queryVectorBo); - } - } else { - allResults = vectorStoreService.search(queryVectorBo); - } + queryVectorBo.setEnableRerank(bo.getEnableRerank() != null ? bo.getEnableRerank() : Objects.equals(knowledgeInfoVo.getEnableRerank(), 1)); + queryVectorBo.setRerankModelName(StringUtils.isNotBlank(bo.getRerankModel()) ? bo.getRerankModel() : knowledgeInfoVo.getRerankModel()); + queryVectorBo.setRerankTopN(bo.getTopK() != null ? bo.getTopK() : knowledgeInfoVo.getRerankTopN()); + queryVectorBo.setRerankScoreThreshold(bo.getThreshold() != null ? bo.getThreshold() : knowledgeInfoVo.getRerankScoreThreshold()); - // 初始化原始排名 - for (int i = 0; i < allResults.size(); i++) { - allResults.get(i).setOriginalIndex(i); - } - - // 4. 执行重排逻辑 (如果请求启用重排且配置了重排模型) - if (Boolean.TRUE.equals(bo.getEnableRerank()) && StringUtils.isNotBlank(bo.getRerankModel())) { - log.info("开始重排精排,模型: [{}]", bo.getRerankModel()); - try { - RerankModelService rerankModel = rerankModelFactory.createModel(bo.getRerankModel()); - - List contents = allResults.stream() - .map(KnowledgeRetrievalVo::getContent) - .collect(Collectors.toList()); - - RerankRequest rerankRequest = RerankRequest.builder() - .query(bo.getQuery()) - .documents(contents) - .topN(contents.size()) - .returnDocuments(false) - .build(); - - RerankResult rerankResult = rerankModel.rerank(rerankRequest); - - // 将重排分数写回,并记录原始分数供前端对比 - for (RerankResult.RerankDocument doc : rerankResult.getDocuments()) { - if (doc.getIndex() != null && doc.getIndex() < allResults.size()) { - KnowledgeRetrievalVo resultVo = allResults.get(doc.getIndex()); - resultVo.setRawScore(resultVo.getScore()); - resultVo.setScore(doc.getRelevanceScore()); - } - } - - // 按重排后的分数从高到低排序 - allResults.sort((a, b) -> b.getScore().compareTo(a.getScore())); - log.info("重排精排完成,结果数: {}", allResults.size()); - - } catch (Exception e) { - log.error("重排精排执行失败,已跳过重排步骤: {}", e.getMessage(), e); - } - } - - // 5. 根据阈值过滤 - double threshold = bo.getThreshold() != null ? bo.getThreshold() : 0.0; - return allResults.stream() - .filter(res -> res.getScore() >= threshold) - .collect(Collectors.toList()); - } - - /** - * RRF (Reciprocal Rank Fusion) 融合算法 - * 公式: Score = (1-alpha) * (1 / (k + rank_vector)) + alpha * (1 / (k + rank_keyword)) - */ - private List calculateRRF(List vectorList, List keywordList, double alpha) { - Map allMap = new HashMap<>(); - Map vectorScores = new HashMap<>(); - Map keywordScores = new HashMap<>(); - - int k = 60; // 常用 RRF 常数 - - for (int i = 0; i < vectorList.size(); i++) { - KnowledgeRetrievalVo vo = vectorList.get(i); - allMap.put(vo.getId(), vo); - vectorScores.put(vo.getId(), 1.0 / (k + i + 1)); - } - - for (int i = 0; i < keywordList.size(); i++) { - KnowledgeRetrievalVo vo = keywordList.get(i); - if (!allMap.containsKey(vo.getId())) { - allMap.put(vo.getId(), vo); - } - keywordScores.put(vo.getId(), 1.0 / (k + i + 1)); - } - - // 重新计算得分 - List fusedResults = new ArrayList<>(); - for (Map.Entry entry : allMap.entrySet()) { - String id = entry.getKey(); - double vScore = vectorScores.getOrDefault(id, 0.0); - double kScore = keywordScores.getOrDefault(id, 0.0); - - // 混合分值 - double finalScore = (1 - alpha) * vScore + alpha * kScore; - - // 分值归一化/缩放:将 RRF 分值放大到 0-1 范围 - // 理论单路最大得分为 1/61 ≈ 0.016,乘以 60 使其处于相似度常用区间 - KnowledgeRetrievalVo vo = entry.getValue(); - vo.setScore(finalScore * 60.0); - fusedResults.add(vo); - } - - // 按融合分数从高到低排序 - fusedResults.sort((a, b) -> b.getScore().compareTo(a.getScore())); - return fusedResults; + // 3. 执行统一检索 + return knowledgeRetrievalService.retrieve(queryVectorBo); } } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/retriever/CustomVectorRetriever.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/retriever/CustomVectorRetriever.java index 6d876710..f79206bc 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/retriever/CustomVectorRetriever.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/knowledge/retriever/CustomVectorRetriever.java @@ -9,14 +9,15 @@ import lombok.extern.slf4j.Slf4j; import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo; import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo; -import org.ruoyi.service.vector.VectorStoreService; +import org.ruoyi.service.retrieval.KnowledgeRetrievalService; import java.util.List; +import java.util.Objects; import java.util.stream.Collectors; /** - * 自定义向量检索器:适配 LangChain4j ContentRetriever 接口 - * 桥接现有的 VectorStoreService 获取检索结果 + * 自定义检索器:适配 LangChain4j ContentRetriever 接口 + * 桥接统一的 KnowledgeRetrievalService,支持配置化的混合检索、阈值过滤等功能 * * @author RobustH */ @@ -24,15 +25,15 @@ import java.util.stream.Collectors; @RequiredArgsConstructor public class CustomVectorRetriever implements ContentRetriever { - private final VectorStoreService vectorStoreService; + private final KnowledgeRetrievalService knowledgeRetrievalService; private final KnowledgeInfoVo knowledgeInfoVo; private final ChatModelVo chatModelVo; @Override public List retrieve(Query query) { - log.info("执行自定义向量检索,关键字: {}", query.text()); + log.info("执行自定义检索,关键字: {}", query.text()); - // 构建内部查询参数 + // 构建增强后的查询参数 QueryVectorBo queryVectorBo = new QueryVectorBo(); queryVectorBo.setQuery(query.text()); queryVectorBo.setKid(String.valueOf(knowledgeInfoVo.getId())); @@ -40,11 +41,21 @@ public class CustomVectorRetriever implements ContentRetriever { queryVectorBo.setBaseUrl(chatModelVo.getApiHost()); queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModel()); queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel()); - // 如果接入了重排,这里的 retrieveLimit 也就是 MaxResults 应当被放大,后续留给 Aggregator 截断 + + // 应用知识库配置参数 queryVectorBo.setMaxResults(knowledgeInfoVo.getRetrieveLimit()); + queryVectorBo.setSimilarityThreshold(knowledgeInfoVo.getSimilarityThreshold()); + queryVectorBo.setEnableHybrid(Objects.equals(knowledgeInfoVo.getEnableHybrid(), 1)); + queryVectorBo.setHybridAlpha(knowledgeInfoVo.getHybridAlpha()); - // 执行底层的多种向量库策略检索 - List nearestList = vectorStoreService.getQueryVector(queryVectorBo); + // 设置重排序参数 (如果 retriever 阶段也想做初步重排,可以在此设置) + queryVectorBo.setEnableRerank(Objects.equals(knowledgeInfoVo.getEnableRerank(), 1)); + queryVectorBo.setRerankModelName(knowledgeInfoVo.getRerankModel()); + queryVectorBo.setRerankTopN(knowledgeInfoVo.getRerankTopN()); + queryVectorBo.setRerankScoreThreshold(knowledgeInfoVo.getRerankScoreThreshold()); + + // 通过统一服务执行检索 + List nearestList = knowledgeRetrievalService.retrieveTexts(queryVectorBo); // 将结果包装为标准的 Content 返回 return nearestList.stream() diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/retrieval/KnowledgeRetrievalService.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/retrieval/KnowledgeRetrievalService.java index 9c42dd1a..3e0a6cab 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/retrieval/KnowledgeRetrievalService.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/retrieval/KnowledgeRetrievalService.java @@ -1,12 +1,13 @@ package org.ruoyi.service.retrieval; import org.ruoyi.domain.bo.vector.QueryVectorBo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import java.util.List; /** * 知识库检索服务接口 - * 整合粗召回(向量检索)和重排序流程 + * 整合粗召回(向量检索/关键词检索)和重排序流程 * * @author yang * @date 2026-04-19 @@ -21,4 +22,13 @@ public interface KnowledgeRetrievalService { * @return 文本内容列表 */ List retrieveTexts(QueryVectorBo queryVectorBo); + + /** + * 执行知识库检索,返回详细结果对象(包含分数、文档ID等) + * 支持混合检索和重排序 + * + * @param queryVectorBo 查询参数 + * @return 检索结果列表 + */ + List retrieve(QueryVectorBo queryVectorBo); } diff --git a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/retrieval/impl/KnowledgeRetrievalServiceImpl.java b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/retrieval/impl/KnowledgeRetrievalServiceImpl.java index 42f6cf68..b6841ef1 100644 --- a/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/retrieval/impl/KnowledgeRetrievalServiceImpl.java +++ b/ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/service/retrieval/impl/KnowledgeRetrievalServiceImpl.java @@ -2,21 +2,26 @@ package org.ruoyi.service.retrieval.impl; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.ruoyi.common.core.utils.StringUtils; import org.ruoyi.domain.bo.rerank.RerankRequest; import org.ruoyi.domain.bo.rerank.RerankResult; import org.ruoyi.domain.bo.vector.QueryVectorBo; +import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo; +import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo; import org.ruoyi.factory.RerankModelFactory; +import org.ruoyi.mapper.knowledge.KnowledgeFragmentMapper; import org.ruoyi.service.rerank.RerankModelService; import org.ruoyi.service.retrieval.KnowledgeRetrievalService; import org.ruoyi.service.vector.VectorStoreService; import org.springframework.stereotype.Service; -import java.util.ArrayList; -import java.util.List; +import java.util.*; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; /** * 知识库检索服务实现 - * 整合粗召回(向量检索)和重排序流程 + * 整合粗召回(向量检索/关键词检索)、RRF融合和重排序流程 * * @author yang * @date 2026-04-19 @@ -28,6 +33,7 @@ public class KnowledgeRetrievalServiceImpl implements KnowledgeRetrievalService private final VectorStoreService vectorStoreService; private final RerankModelFactory rerankModelFactory; + private final KnowledgeFragmentMapper fragmentMapper; /** * 粗召回默认扩大倍数 @@ -37,99 +43,214 @@ public class KnowledgeRetrievalServiceImpl implements KnowledgeRetrievalService @Override public List retrieveTexts(QueryVectorBo queryVectorBo) { + List results = retrieve(queryVectorBo); + return results.stream() + .map(KnowledgeRetrievalVo::getContent) + .collect(Collectors.toList()); + } + + @Override + public List retrieve(QueryVectorBo queryVectorBo) { log.info("开始知识库检索, kid={}, query={}", queryVectorBo.getKid(), queryVectorBo.getQuery()); - // 1. 粗召回阶段 - 向量检索 - List coarseResults = coarseRetrieval(queryVectorBo); + // 1. 粗召回阶段 (向量检索 + 关键词搜索) + List coarseResults = performCoarseRetrieval(queryVectorBo); log.debug("粗召回返回 {} 条结果", coarseResults.size()); if (coarseResults.isEmpty()) { return coarseResults; } - // 2. 重排序阶段(可选) - if (Boolean.TRUE.equals(queryVectorBo.getEnableRerank()) && - queryVectorBo.getRerankModelName() != null) { - return rerank(queryVectorBo, coarseResults); + // 2. 初始化原始索引 + for (int i = 0; i < coarseResults.size(); i++) { + coarseResults.get(i).setOriginalIndex(i); } - return coarseResults; + // 3. 重排序阶段 (可选) + List finalResults = coarseResults; + if (Boolean.TRUE.equals(queryVectorBo.getEnableRerank()) && + StringUtils.isNotBlank(queryVectorBo.getRerankModelName())) { + finalResults = performRerank(queryVectorBo, coarseResults); + } + + // 4. 应用分值阈值过滤 (重排分值或 RRF 分值) + double threshold = queryVectorBo.getRerankScoreThreshold() != null ? + queryVectorBo.getRerankScoreThreshold() : 0.0; + + return finalResults.stream() + .filter(res -> res.getScore() >= threshold) + .collect(Collectors.toList()); } /** - * 粗召回阶段 - 向量检索 + * 粗召回阶段:根据配置执行向量搜索或混合搜索 */ - private List coarseRetrieval(QueryVectorBo queryVectorBo) { - // 如果启用重排序,扩大粗召回数量 - int originalMaxResults = queryVectorBo.getMaxResults(); - int expandedResults = originalMaxResults; - if (Boolean.TRUE.equals(queryVectorBo.getEnableRerank()) && - queryVectorBo.getRerankModelName() != null) { - expandedResults = originalMaxResults * RERANK_EXPANSION_FACTOR; - log.debug("启用重排序,粗召回数量从 {} 扩大到 {}", originalMaxResults, expandedResults); + private List performCoarseRetrieval(QueryVectorBo queryVectorBo) { + // 如果启用重排序,适当扩大召回数量 + int originalMaxResults = queryVectorBo.getMaxResults() != null ? queryVectorBo.getMaxResults() : 10; + int targetMaxResults = originalMaxResults; + if (Boolean.TRUE.equals(queryVectorBo.getEnableRerank()) && + StringUtils.isNotBlank(queryVectorBo.getRerankModelName())) { + targetMaxResults = originalMaxResults * RERANK_EXPANSION_FACTOR; } - // 临时修改查询数量 - queryVectorBo.setMaxResults(expandedResults); + // 如果未启用混合检索,直接走向量搜索 + if (!Boolean.TRUE.equals(queryVectorBo.getEnableHybrid())) { + QueryVectorBo vectorQuery = copyOf(queryVectorBo, targetMaxResults); + List results = vectorStoreService.search(vectorQuery); + + // 应用基础相似度阈值过滤(如果有) + if (queryVectorBo.getSimilarityThreshold() != null) { + results = results.stream() + .filter(r -> r.getScore() >= queryVectorBo.getSimilarityThreshold()) + .collect(Collectors.toList()); + } + return results; + } + + // 混合检索逻辑 + log.info("执行混合检索: kid={}, query={}", queryVectorBo.getKid(), queryVectorBo.getQuery()); try { - return vectorStoreService.getQueryVector(queryVectorBo); - } finally { - // 恢复原始值 - queryVectorBo.setMaxResults(originalMaxResults); + // A. 并行执行向量搜索 + int finalTargetMaxResults = targetMaxResults; + CompletableFuture> vectorFuture = CompletableFuture.supplyAsync(() -> { + QueryVectorBo vectorQuery = copyOf(queryVectorBo, finalTargetMaxResults); + List results = vectorStoreService.search(vectorQuery); + // 向量层初步过滤 + if (queryVectorBo.getSimilarityThreshold() != null) { + return results.stream() + .filter(r -> r.getScore() >= queryVectorBo.getSimilarityThreshold()) + .collect(Collectors.toList()); + } + return results; + }); + + // B. 并行执行关键词搜索 (MySQL Fulltext) + CompletableFuture> keywordFuture = CompletableFuture.supplyAsync(() -> { + try { + Long kid = Long.valueOf(queryVectorBo.getKid()); + List fragments = fragmentMapper.searchByKeyword(kid, queryVectorBo.getQuery(), finalTargetMaxResults); + return fragments.stream().map(f -> { + KnowledgeRetrievalVo vo = new KnowledgeRetrievalVo(); + vo.setId(f.getId().toString()); + vo.setContent(f.getContent()); + vo.setDocId(f.getDocId()); + vo.setIdx(f.getIdx()); + vo.setKnowledgeId(f.getKnowledgeId()); + vo.setScore(10.0); // RRF 初始占位分 + return vo; + }).collect(Collectors.toList()); + } catch (Exception e) { + log.error("关键词检索失败: {}", e.getMessage()); + return new ArrayList<>(); + } + }); + + List vectorResults = vectorFuture.get(); + List keywordResults = keywordFuture.get(); + + // C. RRF 融合 + double alpha = queryVectorBo.getHybridAlpha() != null ? queryVectorBo.getHybridAlpha() : 0.5; + return calculateRRF(vectorResults, keywordResults, alpha); + + } catch (Exception e) { + log.error("混合检索执行失败,回退到纯向量检索: {}", e.getMessage(), e); + return vectorStoreService.search(copyOf(queryVectorBo, targetMaxResults)); } } /** * 重排序阶段 */ - private List rerank(QueryVectorBo queryVectorBo, List coarseResults) { - long startTime = System.currentTimeMillis(); - + private List performRerank(QueryVectorBo queryVectorBo, List coarseResults) { try { - // 1. 通过工厂获取重排序模型 RerankModelService rerankModel = rerankModelFactory.createModel(queryVectorBo.getRerankModelName()); + + List contents = coarseResults.stream() + .map(KnowledgeRetrievalVo::getContent) + .collect(Collectors.toList()); - // 2. 构建重排序请求 - int topN = queryVectorBo.getRerankTopN() != null ? - queryVectorBo.getRerankTopN() : queryVectorBo.getMaxResults(); + // topN 默认为 maxResults + int topN = queryVectorBo.getRerankTopN() != null ? queryVectorBo.getRerankTopN() : queryVectorBo.getMaxResults(); RerankRequest rerankRequest = RerankRequest.builder() .query(queryVectorBo.getQuery()) - .documents(coarseResults) + .documents(contents) .topN(topN) - .returnDocuments(true) .build(); - log.info("执行重排序, model={}, documents={}, topN={}", - queryVectorBo.getRerankModelName(), coarseResults.size(), topN); - - // 3. 执行重排序 RerankResult rerankResult = rerankModel.rerank(rerankRequest); - // 4. 转换重排序结果 - List finalResults = new ArrayList<>(); + // 写回分数并记录原始分 for (RerankResult.RerankDocument doc : rerankResult.getDocuments()) { - // 应用分数阈值过滤 - if (queryVectorBo.getRerankScoreThreshold() != null && - doc.getRelevanceScore() < queryVectorBo.getRerankScoreThreshold()) { - continue; - } - - if (doc.getDocument() != null) { - finalResults.add(doc.getDocument()); + if (doc.getIndex() != null && doc.getIndex() < coarseResults.size()) { + KnowledgeRetrievalVo vo = coarseResults.get(doc.getIndex()); + vo.setRawScore(vo.getScore()); + vo.setScore(doc.getRelevanceScore()); } } - long duration = System.currentTimeMillis() - startTime; - log.info("重排序完成, 返回 {} 条结果, 耗时 {}ms", finalResults.size(), duration); - - return finalResults; + // 按新分排序 + coarseResults.sort((a, b) -> b.getScore().compareTo(a.getScore())); + + // 截断到 topN + return coarseResults.subList(0, Math.min(topN, coarseResults.size())); } catch (Exception e) { - log.error("重排序失败: {}", e.getMessage(), e); - // 重排序失败时返回原始粗召回结果(截取到期望数量) - int limit = Math.min(queryVectorBo.getMaxResults(), coarseResults.size()); - return new ArrayList<>(coarseResults.subList(0, limit)); + log.error("重排序流程失败: {}", e.getMessage()); + int limit = queryVectorBo.getMaxResults() != null ? queryVectorBo.getMaxResults() : 10; + return coarseResults.subList(0, Math.min(limit, coarseResults.size())); } } + + /** + * RRF (Reciprocal Rank Fusion) 融合计算 + */ + private List calculateRRF(List vectorList, List keywordList, double alpha) { + Map allMap = new LinkedHashMap<>(); + Map vectorScores = new HashMap<>(); + Map keywordScores = new HashMap<>(); + + int k = 60; // RRF 常数 + + for (int i = 0; i < vectorList.size(); i++) { + KnowledgeRetrievalVo vo = vectorList.get(i); + allMap.put(vo.getId(), vo); + vectorScores.put(vo.getId(), 1.0 / (k + i + 1)); + } + + for (int i = 0; i < keywordList.size(); i++) { + KnowledgeRetrievalVo vo = keywordList.get(i); + if (!allMap.containsKey(vo.getId())) { + allMap.put(vo.getId(), vo); + } + keywordScores.put(vo.getId(), 1.0 / (k + i + 1)); + } + + List fusedResults = new ArrayList<>(); + for (Map.Entry entry : allMap.entrySet()) { + String id = entry.getKey(); + double finalScore = (1 - alpha) * vectorScores.getOrDefault(id, 0.0) + + alpha * keywordScores.getOrDefault(id, 0.0); + + KnowledgeRetrievalVo vo = entry.getValue(); + vo.setScore(finalScore * 60.0); // 归一化缩放 + fusedResults.add(vo); + } + + fusedResults.sort((a, b) -> b.getScore().compareTo(a.getScore())); + return fusedResults; + } + + private QueryVectorBo copyOf(QueryVectorBo original, int maxResults) { + QueryVectorBo copy = new QueryVectorBo(); + copy.setQuery(original.getQuery()); + copy.setKid(original.getKid()); + copy.setMaxResults(maxResults); + copy.setVectorModelName(original.getVectorModelName()); + copy.setEmbeddingModelName(original.getEmbeddingModelName()); + copy.setApiKey(original.getApiKey()); + copy.setBaseUrl(original.getBaseUrl()); + return copy; + } }