mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-04-23 08:43:40 +00:00
14
docs/script/sql/update/updat-0423.sql
Normal file
14
docs/script/sql/update/updat-0423.sql
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
-- 为知识库信息表新增检索配置字段 (剔除了已存在的重排字段)
|
||||||
|
ALTER TABLE knowledge_info
|
||||||
|
ADD COLUMN similarity_threshold DOUBLE DEFAULT 0.5 COMMENT '相似度阈值'
|
||||||
|
AFTER retrieve_limit;
|
||||||
|
|
||||||
|
ALTER TABLE knowledge_info ADD COLUMN enable_hybrid tinyint(1) DEFAULT 0 COMMENT '是否启用混合检索';
|
||||||
|
ALTER TABLE knowledge_info ADD COLUMN hybrid_alpha double DEFAULT 0.5 COMMENT '混合检索权重比例 (0.0=纯向量, 1.0=纯关键词)';
|
||||||
|
|
||||||
|
-- 为知识片段表增加全文索引及关联ID
|
||||||
|
ALTER TABLE knowledge_fragment ADD COLUMN knowledge_id bigint COMMENT '知识库ID';
|
||||||
|
ALTER TABLE knowledge_fragment ADD FULLTEXT INDEX ft_content (content) WITH PARSER ngram;
|
||||||
|
|
||||||
|
-- 为知识库附件表增加解析状态字段
|
||||||
|
ALTER TABLE `knowledge_attach` ADD COLUMN `status` TINYINT DEFAULT 0 COMMENT '解析状态: 0待解析, 1解析中, 2已解析, 3解析失败';
|
||||||
@@ -10,6 +10,7 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties
|
|||||||
import org.springframework.context.annotation.Bean;
|
import org.springframework.context.annotation.Bean;
|
||||||
import org.springframework.core.task.VirtualThreadTaskExecutor;
|
import org.springframework.core.task.VirtualThreadTaskExecutor;
|
||||||
|
|
||||||
|
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
|
||||||
import java.util.concurrent.*;
|
import java.util.concurrent.*;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -22,6 +23,12 @@ import java.util.concurrent.*;
|
|||||||
@EnableConfigurationProperties(ThreadPoolProperties.class)
|
@EnableConfigurationProperties(ThreadPoolProperties.class)
|
||||||
public class ThreadPoolConfig {
|
public class ThreadPoolConfig {
|
||||||
|
|
||||||
|
private final ThreadPoolProperties properties;
|
||||||
|
|
||||||
|
public ThreadPoolConfig(ThreadPoolProperties properties) {
|
||||||
|
this.properties = properties;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 核心线程数 = cpu 核心数 + 1
|
* 核心线程数 = cpu 核心数 + 1
|
||||||
*/
|
*/
|
||||||
@@ -54,6 +61,22 @@ public class ThreadPoolConfig {
|
|||||||
return scheduledThreadPoolExecutor;
|
return scheduledThreadPoolExecutor;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 知识库解析专用异步线程池
|
||||||
|
*/
|
||||||
|
@Bean(name = "knowledgeParseExecutor")
|
||||||
|
public ThreadPoolTaskExecutor knowledgeParseExecutor() {
|
||||||
|
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
|
||||||
|
executor.setCorePoolSize(core);
|
||||||
|
executor.setMaxPoolSize(core * 2);
|
||||||
|
executor.setQueueCapacity(properties.getQueueCapacity());
|
||||||
|
executor.setKeepAliveSeconds(properties.getKeepAliveSeconds());
|
||||||
|
executor.setThreadNamePrefix("knowledge-parse-pool-");
|
||||||
|
executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
|
||||||
|
executor.initialize();
|
||||||
|
return executor;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 销毁事件
|
* 销毁事件
|
||||||
* 停止线程池
|
* 停止线程池
|
||||||
|
|||||||
@@ -110,6 +110,17 @@ public class KnowledgeAttachController extends BaseController {
|
|||||||
@PostMapping(value = "/upload")
|
@PostMapping(value = "/upload")
|
||||||
public R<String> upload(KnowledgeInfoUploadBo bo){
|
public R<String> upload(KnowledgeInfoUploadBo bo){
|
||||||
knowledgeAttachService.upload(bo);
|
knowledgeAttachService.upload(bo);
|
||||||
return R.ok("上传知识库附件成功!");
|
return R.ok("上传成功!");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 手动解析附件内容
|
||||||
|
*
|
||||||
|
* @param id 附件ID
|
||||||
|
*/
|
||||||
|
@PostMapping("/parse/{id}")
|
||||||
|
public R<Void> parse(@PathVariable Long id) {
|
||||||
|
knowledgeAttachService.parse(id);
|
||||||
|
return R.ok();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import jakarta.validation.constraints.*;
|
|||||||
import cn.dev33.satoken.annotation.SaCheckPermission;
|
import cn.dev33.satoken.annotation.SaCheckPermission;
|
||||||
import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo;
|
import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo;
|
||||||
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
|
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
|
||||||
|
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||||
import org.ruoyi.service.knowledge.IKnowledgeFragmentService;
|
import org.ruoyi.service.knowledge.IKnowledgeFragmentService;
|
||||||
import org.springframework.web.bind.annotation.*;
|
import org.springframework.web.bind.annotation.*;
|
||||||
import org.springframework.validation.annotation.Validated;
|
import org.springframework.validation.annotation.Validated;
|
||||||
@@ -102,4 +103,12 @@ public class KnowledgeFragmentController extends BaseController {
|
|||||||
@PathVariable Long[] ids) {
|
@PathVariable Long[] ids) {
|
||||||
return toAjax(knowledgeFragmentService.deleteWithValidByIds(List.of(ids), true));
|
return toAjax(knowledgeFragmentService.deleteWithValidByIds(List.of(ids), true));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检索测试
|
||||||
|
*/
|
||||||
|
@PostMapping("/retrieval")
|
||||||
|
public R<List<KnowledgeRetrievalVo>> retrieval(@RequestBody KnowledgeFragmentBo bo) {
|
||||||
|
return R.ok(knowledgeFragmentService.retrieval(bo));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,5 +49,44 @@ public class KnowledgeFragmentBo extends BaseEntity {
|
|||||||
*/
|
*/
|
||||||
private String remark;
|
private String remark;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 知识库ID
|
||||||
|
*/
|
||||||
|
private Long knowledgeId;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检索内容
|
||||||
|
*/
|
||||||
|
private String query;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 返回条数
|
||||||
|
*/
|
||||||
|
private Integer topK;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 相似度阈值
|
||||||
|
*/
|
||||||
|
private Double threshold;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 是否启用重排
|
||||||
|
*/
|
||||||
|
private Boolean enableRerank;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 重排模型名称
|
||||||
|
*/
|
||||||
|
private String rerankModel;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 是否启用混合检索
|
||||||
|
*/
|
||||||
|
private Boolean enableHybrid;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 混合检索权重 (0.0-1.0)
|
||||||
|
*/
|
||||||
|
private Double hybridAlpha;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -62,6 +62,11 @@ public class KnowledgeInfoBo extends BaseEntity {
|
|||||||
*/
|
*/
|
||||||
private Long retrieveLimit;
|
private Long retrieveLimit;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 相似度阈值
|
||||||
|
*/
|
||||||
|
private Double similarityThreshold;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 文本块大小
|
* 文本块大小
|
||||||
*/
|
*/
|
||||||
@@ -98,12 +103,19 @@ public class KnowledgeInfoBo extends BaseEntity {
|
|||||||
private Double rerankScoreThreshold;
|
private Double rerankScoreThreshold;
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 是否启用混合检索(0 否 1是)
|
||||||
|
*/
|
||||||
|
private Integer enableHybrid;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 混合检索权重 (0.0-1.0)
|
||||||
|
*/
|
||||||
|
private Double hybridAlpha;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 备注
|
* 备注
|
||||||
*/
|
*/
|
||||||
private String remark;
|
private String remark;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,6 +16,11 @@ public class KnowledgeInfoUploadBo {
|
|||||||
|
|
||||||
private MultipartFile file;
|
private MultipartFile file;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 是否自动解析 (true: 立即解析, false: 仅上传)
|
||||||
|
*/
|
||||||
|
private Boolean autoParse;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 生效时间, 为空则立即生效
|
* 生效时间, 为空则立即生效
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -77,4 +77,22 @@ public class QueryVectorBo {
|
|||||||
*/
|
*/
|
||||||
private Double rerankScoreThreshold;
|
private Double rerankScoreThreshold;
|
||||||
|
|
||||||
|
// ========== 混合检索与阈值相关参数 ==========
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 相似度阈值 (0.0-1.0)
|
||||||
|
* 应用于向量搜索阶段
|
||||||
|
*/
|
||||||
|
private Double similarityThreshold;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 是否启用混合检索
|
||||||
|
*/
|
||||||
|
private Boolean enableHybrid = false;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 混合检索权重 (0.0-1.0)
|
||||||
|
*/
|
||||||
|
private Double hybridAlpha;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -57,5 +57,10 @@ public class KnowledgeAttach extends BaseEntity {
|
|||||||
*/
|
*/
|
||||||
private String remark;
|
private String remark;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析状态: 0待解析, 1解析中, 2已解析, 3解析失败
|
||||||
|
*/
|
||||||
|
private Integer status;
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -47,5 +47,10 @@ public class KnowledgeFragment extends BaseEntity {
|
|||||||
*/
|
*/
|
||||||
private String remark;
|
private String remark;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 知识库ID
|
||||||
|
*/
|
||||||
|
private Long knowledgeId;
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -63,6 +63,11 @@ public class KnowledgeInfo extends BaseEntity {
|
|||||||
*/
|
*/
|
||||||
private Long retrieveLimit;
|
private Long retrieveLimit;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 相似度阈值
|
||||||
|
*/
|
||||||
|
private Double similarityThreshold;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 文本块大小
|
* 文本块大小
|
||||||
*/
|
*/
|
||||||
@@ -98,6 +103,16 @@ public class KnowledgeInfo extends BaseEntity {
|
|||||||
*/
|
*/
|
||||||
private Double rerankScoreThreshold;
|
private Double rerankScoreThreshold;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 是否启用混合检索(0 否 1是)
|
||||||
|
*/
|
||||||
|
private Integer enableHybrid;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 混合检索权重 (0.0-1.0)
|
||||||
|
*/
|
||||||
|
private Double hybridAlpha;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 备注
|
* 备注
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -0,0 +1,20 @@
|
|||||||
|
package org.ruoyi.domain.vo.knowledge;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 文档分块数统计 VO(用于 GROUP BY 查询结果接收)
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
public class DocFragmentCountVo {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 文档ID(关联 knowledge_attach.doc_id)
|
||||||
|
*/
|
||||||
|
private String docId;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 该文档下的分块数量
|
||||||
|
*/
|
||||||
|
private Integer fragmentCount;
|
||||||
|
}
|
||||||
@@ -8,6 +8,7 @@ import org.ruoyi.domain.entity.knowledge.KnowledgeAttach;
|
|||||||
|
|
||||||
import java.io.Serial;
|
import java.io.Serial;
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
import java.util.Date;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -68,5 +69,22 @@ public class KnowledgeAttachVo implements Serializable {
|
|||||||
@ExcelProperty(value = "备注")
|
@ExcelProperty(value = "备注")
|
||||||
private String remark;
|
private String remark;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 上传时间(来自 BaseEntity.createTime)
|
||||||
|
*/
|
||||||
|
@ExcelProperty(value = "上传时间")
|
||||||
|
private Date createTime;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析状态: 0待解析, 1解析中, 2已解析, 3解析失败
|
||||||
|
*/
|
||||||
|
@ExcelProperty(value = "解析状态")
|
||||||
|
private Integer status;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 分块数(统计字段,非数据库列)
|
||||||
|
*/
|
||||||
|
private Integer fragmentCount;
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ public class KnowledgeFragmentVo implements Serializable {
|
|||||||
* 片段索引下标
|
* 片段索引下标
|
||||||
*/
|
*/
|
||||||
@ExcelProperty(value = "片段索引下标")
|
@ExcelProperty(value = "片段索引下标")
|
||||||
private Long idx;
|
private Integer idx;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 文档内容
|
* 文档内容
|
||||||
@@ -53,5 +53,10 @@ public class KnowledgeFragmentVo implements Serializable {
|
|||||||
@ExcelProperty(value = "备注")
|
@ExcelProperty(value = "备注")
|
||||||
private String remark;
|
private String remark;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 知识库ID
|
||||||
|
*/
|
||||||
|
private Long knowledgeId;
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -76,6 +76,12 @@ public class KnowledgeInfoVo implements Serializable {
|
|||||||
@ExcelProperty(value = "知识库中检索的条数")
|
@ExcelProperty(value = "知识库中检索的条数")
|
||||||
private Integer retrieveLimit;
|
private Integer retrieveLimit;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 相似度阈值
|
||||||
|
*/
|
||||||
|
@ExcelProperty(value = "相似度阈值")
|
||||||
|
private Double similarityThreshold;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 文本块大小
|
* 文本块大小
|
||||||
*/
|
*/
|
||||||
@@ -118,6 +124,24 @@ public class KnowledgeInfoVo implements Serializable {
|
|||||||
@ExcelProperty(value = "重排序分数阈值")
|
@ExcelProperty(value = "重排序分数阈值")
|
||||||
private Double rerankScoreThreshold;
|
private Double rerankScoreThreshold;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 是否启用混合检索(0 否 1是)
|
||||||
|
*/
|
||||||
|
@ExcelProperty(value = "是否启用混合检索")
|
||||||
|
private Integer enableHybrid;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 混合检索权重 (0.0-1.0)
|
||||||
|
*/
|
||||||
|
@ExcelProperty(value = "混合检索权重")
|
||||||
|
private Double hybridAlpha;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 文档数量
|
||||||
|
*/
|
||||||
|
@ExcelProperty(value = "文档数量")
|
||||||
|
private Integer documentCount;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 备注
|
* 备注
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -0,0 +1,69 @@
|
|||||||
|
package org.ruoyi.domain.vo.knowledge;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
import java.io.Serial;
|
||||||
|
import java.io.Serializable;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 知识检索测试结果视图对象
|
||||||
|
*
|
||||||
|
* @author RobustH
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class KnowledgeRetrievalVo implements Serializable {
|
||||||
|
|
||||||
|
@Serial
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 片段ID
|
||||||
|
*/
|
||||||
|
private String id;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 文档ID
|
||||||
|
*/
|
||||||
|
private String docId;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 知识库ID
|
||||||
|
*/
|
||||||
|
private Long knowledgeId;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 分片索引
|
||||||
|
*/
|
||||||
|
private Integer idx;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 片段内容
|
||||||
|
*/
|
||||||
|
private String content;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 相似度得分
|
||||||
|
*/
|
||||||
|
private Double score;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 原始检索排名 (重排前)
|
||||||
|
*/
|
||||||
|
private Integer originalIndex;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 原始检索得分 (重排前)
|
||||||
|
*/
|
||||||
|
private Double rawScore;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 来源文档名称
|
||||||
|
*/
|
||||||
|
private String sourceName;
|
||||||
|
}
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
package org.ruoyi.enums;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Getter;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 知识库附件解析状态枚举
|
||||||
|
*
|
||||||
|
* @author RobustH
|
||||||
|
*/
|
||||||
|
@Getter
|
||||||
|
@AllArgsConstructor
|
||||||
|
public enum KnowledgeAttachStatus {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 待解析
|
||||||
|
*/
|
||||||
|
WAITING(0, "待解析"),
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析中
|
||||||
|
*/
|
||||||
|
PARSING(1, "解析中"),
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 已解析
|
||||||
|
*/
|
||||||
|
COMPLETED(2, "已解析"),
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析失败
|
||||||
|
*/
|
||||||
|
FAILED(3, "解析失败");
|
||||||
|
|
||||||
|
private final Integer code;
|
||||||
|
private final String info;
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,5 +1,8 @@
|
|||||||
package org.ruoyi.mapper.knowledge;
|
package org.ruoyi.mapper.knowledge;
|
||||||
|
|
||||||
|
import org.apache.ibatis.annotations.Mapper;
|
||||||
|
import org.apache.ibatis.annotations.Param;
|
||||||
|
import org.apache.ibatis.annotations.Select;
|
||||||
import org.ruoyi.domain.entity.knowledge.KnowledgeAttach;
|
import org.ruoyi.domain.entity.knowledge.KnowledgeAttach;
|
||||||
import org.ruoyi.domain.vo.knowledge.KnowledgeAttachVo;
|
import org.ruoyi.domain.vo.knowledge.KnowledgeAttachVo;
|
||||||
import org.ruoyi.common.mybatis.core.mapper.BaseMapperPlus;
|
import org.ruoyi.common.mybatis.core.mapper.BaseMapperPlus;
|
||||||
@@ -10,6 +13,12 @@ import org.ruoyi.common.mybatis.core.mapper.BaseMapperPlus;
|
|||||||
* @author ageerle
|
* @author ageerle
|
||||||
* @date 2025-12-17
|
* @date 2025-12-17
|
||||||
*/
|
*/
|
||||||
|
@Mapper
|
||||||
public interface KnowledgeAttachMapper extends BaseMapperPlus<KnowledgeAttach, KnowledgeAttachVo> {
|
public interface KnowledgeAttachMapper extends BaseMapperPlus<KnowledgeAttach, KnowledgeAttachVo> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 统计指定知识库下的文档数量
|
||||||
|
*/
|
||||||
|
@Select("SELECT COUNT(*) FROM knowledge_attach WHERE knowledge_id = #{knowledgeId}")
|
||||||
|
int countByKnowledgeId(@Param("knowledgeId") Long knowledgeId);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,15 +1,45 @@
|
|||||||
package org.ruoyi.mapper.knowledge;
|
package org.ruoyi.mapper.knowledge;
|
||||||
|
|
||||||
|
import org.apache.ibatis.annotations.Mapper;
|
||||||
|
import org.apache.ibatis.annotations.Param;
|
||||||
|
import org.apache.ibatis.annotations.Select;
|
||||||
import org.ruoyi.domain.entity.knowledge.KnowledgeFragment;
|
import org.ruoyi.domain.entity.knowledge.KnowledgeFragment;
|
||||||
|
import org.ruoyi.domain.vo.knowledge.DocFragmentCountVo;
|
||||||
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
|
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
|
||||||
import org.ruoyi.common.mybatis.core.mapper.BaseMapperPlus;
|
import org.ruoyi.common.mybatis.core.mapper.BaseMapperPlus;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 知识片段Mapper接口
|
* 知识片段Mapper接口
|
||||||
*
|
*
|
||||||
* @author ageerle
|
* @author ageerle
|
||||||
* @date 2025-12-17
|
* @date 2025-12-17
|
||||||
*/
|
*/
|
||||||
|
@Mapper
|
||||||
public interface KnowledgeFragmentMapper extends BaseMapperPlus<KnowledgeFragment, KnowledgeFragmentVo> {
|
public interface KnowledgeFragmentMapper extends BaseMapperPlus<KnowledgeFragment, KnowledgeFragmentVo> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 批量统计各文档的分块数(强类型接收,避免 Map key 大小写问题)
|
||||||
|
*
|
||||||
|
* @param docIds 文档 ID 列表
|
||||||
|
* @return 每个 docId 对应的分块数列表
|
||||||
|
*/
|
||||||
|
@Select("<script>" +
|
||||||
|
"SELECT doc_id AS docId, COUNT(*) AS fragmentCount " +
|
||||||
|
"FROM knowledge_fragment " +
|
||||||
|
"WHERE doc_id IN " +
|
||||||
|
"<foreach collection='docIds' item='id' open='(' separator=',' close=')'>#{id}</foreach> " +
|
||||||
|
"GROUP BY doc_id" +
|
||||||
|
"</script>")
|
||||||
|
List<DocFragmentCountVo> selectFragmentCountByDocIds(@Param("docIds") List<String> docIds);
|
||||||
|
@Select("<script>" +
|
||||||
|
"SELECT id, doc_id AS docId, content, idx, knowledge_id AS knowledgeId " +
|
||||||
|
"FROM knowledge_fragment " +
|
||||||
|
"WHERE knowledge_id = #{knowledgeId} " +
|
||||||
|
"AND MATCH (content) AGAINST (#{query} IN NATURAL LANGUAGE MODE) " +
|
||||||
|
"ORDER BY MATCH (content) AGAINST (#{query} IN NATURAL LANGUAGE MODE) DESC " +
|
||||||
|
"LIMIT #{limit}" +
|
||||||
|
"</script>")
|
||||||
|
List<KnowledgeFragmentVo> searchByKeyword(@Param("knowledgeId") Long knowledgeId, @Param("query") String query, @Param("limit") Integer limit);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,11 @@ import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
|
|||||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||||
import dev.langchain4j.service.tool.ToolProvider;
|
import dev.langchain4j.service.tool.ToolProvider;
|
||||||
import dev.langchain4j.skills.shell.ShellSkills;
|
import dev.langchain4j.skills.shell.ShellSkills;
|
||||||
|
import dev.langchain4j.rag.AugmentationRequest;
|
||||||
|
import dev.langchain4j.rag.AugmentationResult;
|
||||||
|
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
|
||||||
|
import dev.langchain4j.rag.RetrievalAugmentor;
|
||||||
|
import dev.langchain4j.rag.query.Metadata;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.SneakyThrows;
|
import lombok.SneakyThrows;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@@ -55,6 +60,7 @@ import org.ruoyi.service.chat.IChatMessageService;
|
|||||||
import org.ruoyi.service.chat.impl.memory.PersistentChatMemoryStore;
|
import org.ruoyi.service.chat.impl.memory.PersistentChatMemoryStore;
|
||||||
import org.ruoyi.service.knowledge.IKnowledgeInfoService;
|
import org.ruoyi.service.knowledge.IKnowledgeInfoService;
|
||||||
import org.ruoyi.service.retrieval.KnowledgeRetrievalService;
|
import org.ruoyi.service.retrieval.KnowledgeRetrievalService;
|
||||||
|
import org.ruoyi.service.knowledge.retriever.CustomVectorRetriever;
|
||||||
import org.ruoyi.service.vector.VectorStoreService;
|
import org.ruoyi.service.vector.VectorStoreService;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||||
@@ -412,16 +418,49 @@ public class ChatServiceFacade implements IChatService {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* 构建上下文消息列表
|
* 构建上下文消息列表
|
||||||
|
|
||||||
* 消息顺序:历史消息 → 当前用户消息(确保 AI 正确理解对话上下文)
|
* 消息顺序:历史消息 → 当前用户消息(确保 AI 正确理解对话上下文)
|
||||||
*
|
*
|
||||||
* @param chatRequest 聊天请求
|
* @param chatRequest 聊天请求
|
||||||
* @return 上下文消息列表
|
* @return 上下文消息列表
|
||||||
*/
|
*/
|
||||||
private List<ChatMessage> buildContextMessages(ChatRequest chatRequest) {
|
private List<ChatMessage> buildContextMessages(ChatRequest chatRequest) {
|
||||||
List<ChatMessage> messages = new ArrayList<>();
|
List<ChatMessage> messages = new ArrayList<>();
|
||||||
|
|
||||||
// 从数据库查询历史对话消息(放在前面)
|
// 1. 初始化当前用户消息
|
||||||
|
UserMessage userMessage = UserMessage.userMessage(chatRequest.getContent());
|
||||||
|
|
||||||
|
// 2. 知识库检索增强 (RAG)
|
||||||
|
if (chatRequest.getKnowledgeId() != null) {
|
||||||
|
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(chatRequest.getKnowledgeId()));
|
||||||
|
if (knowledgeInfoVo != null) {
|
||||||
|
ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel());
|
||||||
|
if (chatModel != null) {
|
||||||
|
log.info("执行高级 RAG 流程: kid={}", chatRequest.getKnowledgeId());
|
||||||
|
|
||||||
|
// 构建自定义检索器
|
||||||
|
CustomVectorRetriever retriever = new CustomVectorRetriever(
|
||||||
|
knowledgeRetrievalService, knowledgeInfoVo, chatModel);
|
||||||
|
|
||||||
|
// 构建增强流水线
|
||||||
|
RetrievalAugmentor augmentor = DefaultRetrievalAugmentor.builder()
|
||||||
|
.contentRetriever(retriever)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
// 执行增强:编织上下文到 UserMessage
|
||||||
|
Metadata metadata = Metadata.from(userMessage, chatRequest.getSessionId(), new ArrayList<>());
|
||||||
|
AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata);
|
||||||
|
AugmentationResult result = augmentor.augment(augmentationRequest);
|
||||||
|
|
||||||
|
ChatMessage augmented = result.chatMessage();
|
||||||
|
if (augmented instanceof UserMessage) {
|
||||||
|
userMessage = (UserMessage) augmented;
|
||||||
|
log.debug("RAG 增强完成,UserMessage 已注入背景知识");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 从数据库查询历史对话消息(放在前面)
|
||||||
if (chatRequest.getSessionId() != null) {
|
if (chatRequest.getSessionId() != null) {
|
||||||
MessageWindowChatMemory memory = createChatMemory(chatRequest.getSessionId());
|
MessageWindowChatMemory memory = createChatMemory(chatRequest.getSessionId());
|
||||||
if (memory != null) {
|
if (memory != null) {
|
||||||
@@ -433,38 +472,7 @@ public class ChatServiceFacade implements IChatService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 从向量库查询相关历史消息(知识库内容作为上下文)
|
// 4. 添加经过增强的用户消息(放在最后)
|
||||||
if (chatRequest.getKnowledgeId() != null) {
|
|
||||||
// 查询知识库信息
|
|
||||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(chatRequest.getKnowledgeId()));
|
|
||||||
if (knowledgeInfoVo == null) {
|
|
||||||
log.warn("知识库信息不存在,kid: {}", chatRequest.getKnowledgeId());
|
|
||||||
// 继续添加当前用户消息
|
|
||||||
messages.add(UserMessage.userMessage(chatRequest.getContent()));
|
|
||||||
return messages;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查询向量模型配置信息
|
|
||||||
ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel());
|
|
||||||
if (chatModel == null) {
|
|
||||||
log.warn("向量模型配置不存在,模型名称: {}", knowledgeInfoVo.getEmbeddingModel());
|
|
||||||
messages.add(UserMessage.userMessage(chatRequest.getContent()));
|
|
||||||
return messages;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 构建向量查询参数
|
|
||||||
QueryVectorBo queryVectorBo = buildQueryVectorBo(chatRequest, knowledgeInfoVo, chatModel);
|
|
||||||
|
|
||||||
// 使用知识库检索服务(支持重排序)
|
|
||||||
List<String> nearestList = knowledgeRetrievalService.retrieveTexts(queryVectorBo);
|
|
||||||
for (String prompt : nearestList) {
|
|
||||||
// 知识库内容作为系统上下文添加
|
|
||||||
messages.add(new AiMessage(prompt));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 构建当前用户消息(放在最后)
|
|
||||||
UserMessage userMessage = UserMessage.userMessage(chatRequest.getContent());
|
|
||||||
messages.add(userMessage);
|
messages.add(userMessage);
|
||||||
|
|
||||||
return messages;
|
return messages;
|
||||||
|
|||||||
@@ -72,4 +72,11 @@ public interface IKnowledgeAttachService {
|
|||||||
* 上传附件
|
* 上传附件
|
||||||
*/
|
*/
|
||||||
void upload(KnowledgeInfoUploadBo bo);
|
void upload(KnowledgeInfoUploadBo bo);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析附件知识片段
|
||||||
|
*
|
||||||
|
* @param id 附件ID
|
||||||
|
*/
|
||||||
|
void parse(Long id);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import org.ruoyi.common.mybatis.core.page.TableDataInfo;
|
|||||||
import org.ruoyi.common.mybatis.core.page.PageQuery;
|
import org.ruoyi.common.mybatis.core.page.PageQuery;
|
||||||
import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo;
|
import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo;
|
||||||
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
|
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
|
||||||
|
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||||
|
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -65,4 +66,12 @@ public interface IKnowledgeFragmentService {
|
|||||||
* @return 是否删除成功
|
* @return 是否删除成功
|
||||||
*/
|
*/
|
||||||
Boolean deleteWithValidByIds(Collection<Long> ids, Boolean isValid);
|
Boolean deleteWithValidByIds(Collection<Long> ids, Boolean isValid);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检索测试
|
||||||
|
*
|
||||||
|
* @param bo 检索参数
|
||||||
|
* @return 检索结果
|
||||||
|
*/
|
||||||
|
List<KnowledgeRetrievalVo> retrieval(KnowledgeFragmentBo bo);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,24 +2,27 @@ package org.ruoyi.service.knowledge.impl;
|
|||||||
|
|
||||||
import cn.hutool.core.collection.CollUtil;
|
import cn.hutool.core.collection.CollUtil;
|
||||||
import cn.hutool.core.util.RandomUtil;
|
import cn.hutool.core.util.RandomUtil;
|
||||||
import org.ruoyi.common.chat.service.chat.IChatModelService;
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
||||||
|
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
|
||||||
|
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo;
|
import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo;
|
||||||
|
import org.ruoyi.common.chat.service.chat.IChatModelService;
|
||||||
|
import org.ruoyi.enums.KnowledgeAttachStatus;
|
||||||
import org.ruoyi.common.core.domain.dto.OssDTO;
|
import org.ruoyi.common.core.domain.dto.OssDTO;
|
||||||
import org.ruoyi.common.core.service.OssService;
|
import org.ruoyi.common.core.service.OssService;
|
||||||
import org.ruoyi.common.core.utils.MapstructUtils;
|
import org.ruoyi.common.core.utils.MapstructUtils;
|
||||||
|
import org.ruoyi.common.core.utils.SpringUtils;
|
||||||
import org.ruoyi.common.core.utils.StringUtils;
|
import org.ruoyi.common.core.utils.StringUtils;
|
||||||
import org.ruoyi.common.mybatis.core.page.TableDataInfo;
|
|
||||||
import org.ruoyi.common.mybatis.core.page.PageQuery;
|
import org.ruoyi.common.mybatis.core.page.PageQuery;
|
||||||
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
import org.ruoyi.common.mybatis.core.page.TableDataInfo;
|
||||||
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
|
||||||
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
|
|
||||||
import lombok.RequiredArgsConstructor;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.ruoyi.domain.bo.knowledge.KnowledgeAttachBo;
|
import org.ruoyi.domain.bo.knowledge.KnowledgeAttachBo;
|
||||||
import org.ruoyi.domain.bo.knowledge.KnowledgeInfoUploadBo;
|
import org.ruoyi.domain.bo.knowledge.KnowledgeInfoUploadBo;
|
||||||
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
||||||
import org.ruoyi.domain.entity.knowledge.KnowledgeAttach;
|
import org.ruoyi.domain.entity.knowledge.KnowledgeAttach;
|
||||||
import org.ruoyi.domain.entity.knowledge.KnowledgeFragment;
|
import org.ruoyi.domain.entity.knowledge.KnowledgeFragment;
|
||||||
|
import org.ruoyi.domain.vo.knowledge.DocFragmentCountVo;
|
||||||
import org.ruoyi.domain.vo.knowledge.KnowledgeAttachVo;
|
import org.ruoyi.domain.vo.knowledge.KnowledgeAttachVo;
|
||||||
import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo;
|
import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo;
|
||||||
import org.ruoyi.factory.ResourceLoaderFactory;
|
import org.ruoyi.factory.ResourceLoaderFactory;
|
||||||
@@ -29,11 +32,15 @@ import org.ruoyi.service.knowledge.IKnowledgeAttachService;
|
|||||||
import org.ruoyi.service.knowledge.IKnowledgeInfoService;
|
import org.ruoyi.service.knowledge.IKnowledgeInfoService;
|
||||||
import org.ruoyi.service.knowledge.ResourceLoader;
|
import org.ruoyi.service.knowledge.ResourceLoader;
|
||||||
import org.ruoyi.service.vector.VectorStoreService;
|
import org.ruoyi.service.vector.VectorStoreService;
|
||||||
|
import org.springframework.scheduling.annotation.Async;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.web.multipart.MultipartFile;
|
import org.springframework.web.multipart.MultipartFile;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.InputStream;
|
||||||
|
|
||||||
|
import java.net.URL;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 知识库附件Service业务层处理
|
* 知识库附件Service业务层处理
|
||||||
@@ -47,57 +54,51 @@ import java.util.*;
|
|||||||
public class KnowledgeAttachServiceImpl implements IKnowledgeAttachService {
|
public class KnowledgeAttachServiceImpl implements IKnowledgeAttachService {
|
||||||
|
|
||||||
private final KnowledgeAttachMapper baseMapper;
|
private final KnowledgeAttachMapper baseMapper;
|
||||||
|
|
||||||
private final IKnowledgeInfoService knowledgeInfoService;
|
private final IKnowledgeInfoService knowledgeInfoService;
|
||||||
|
|
||||||
private final KnowledgeFragmentMapper knowledgeFragmentMapper;
|
private final KnowledgeFragmentMapper knowledgeFragmentMapper;
|
||||||
|
|
||||||
private final IChatModelService chatModelService;
|
private final IChatModelService chatModelService;
|
||||||
|
|
||||||
private final ResourceLoaderFactory resourceLoaderFactory;
|
private final ResourceLoaderFactory resourceLoaderFactory;
|
||||||
|
|
||||||
private final VectorStoreService vectorStoreService;
|
private final VectorStoreService vectorStoreService;
|
||||||
|
|
||||||
private final OssService ossService;
|
private final OssService ossService;
|
||||||
/**
|
|
||||||
* 查询知识库附件
|
|
||||||
*
|
|
||||||
* @param id 主键
|
|
||||||
* @return 知识库附件
|
|
||||||
*/
|
|
||||||
@Override
|
@Override
|
||||||
public KnowledgeAttachVo queryById(Long id){
|
public KnowledgeAttachVo queryById(Long id) {
|
||||||
return baseMapper.selectVoById(id);
|
return baseMapper.selectVoById(id);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* 分页查询知识库附件列表
|
|
||||||
*
|
|
||||||
* @param bo 查询条件
|
|
||||||
* @param pageQuery 分页参数
|
|
||||||
* @return 知识库附件分页列表
|
|
||||||
*/
|
|
||||||
@Override
|
@Override
|
||||||
public TableDataInfo<KnowledgeAttachVo> queryPageList(KnowledgeAttachBo bo, PageQuery pageQuery) {
|
public TableDataInfo<KnowledgeAttachVo> queryPageList(KnowledgeAttachBo bo, PageQuery pageQuery) {
|
||||||
LambdaQueryWrapper<KnowledgeAttach> lqw = buildQueryWrapper(bo);
|
LambdaQueryWrapper<KnowledgeAttach> lqw = buildQueryWrapper(bo);
|
||||||
Page<KnowledgeAttachVo> result = baseMapper.selectVoPage(pageQuery.build(), lqw);
|
Page<KnowledgeAttachVo> result = baseMapper.selectVoPage(pageQuery.build(), lqw);
|
||||||
|
fillFragmentCount(result.getRecords());
|
||||||
return TableDataInfo.build(result);
|
return TableDataInfo.build(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* 查询符合条件的知识库附件列表
|
|
||||||
*
|
|
||||||
* @param bo 查询条件
|
|
||||||
* @return 知识库附件列表
|
|
||||||
*/
|
|
||||||
@Override
|
@Override
|
||||||
public List<KnowledgeAttachVo> queryList(KnowledgeAttachBo bo) {
|
public List<KnowledgeAttachVo> queryList(KnowledgeAttachBo bo) {
|
||||||
LambdaQueryWrapper<KnowledgeAttach> lqw = buildQueryWrapper(bo);
|
LambdaQueryWrapper<KnowledgeAttach> lqw = buildQueryWrapper(bo);
|
||||||
return baseMapper.selectVoList(lqw);
|
List<KnowledgeAttachVo> list = baseMapper.selectVoList(lqw);
|
||||||
|
fillFragmentCount(list);
|
||||||
|
return list;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void fillFragmentCount(List<KnowledgeAttachVo> records) {
|
||||||
|
if (records == null || records.isEmpty()) return;
|
||||||
|
List<String> docIds = records.stream()
|
||||||
|
.map(KnowledgeAttachVo::getDocId)
|
||||||
|
.filter(StringUtils::isNotBlank)
|
||||||
|
.distinct()
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
if (docIds.isEmpty()) return;
|
||||||
|
List<DocFragmentCountVo> countList = knowledgeFragmentMapper.selectFragmentCountByDocIds(docIds);
|
||||||
|
Map<String, Integer> countMap = countList.stream()
|
||||||
|
.collect(Collectors.toMap(DocFragmentCountVo::getDocId, DocFragmentCountVo::getFragmentCount, (k1, k2) -> k1));
|
||||||
|
for (KnowledgeAttachVo vo : records) {
|
||||||
|
vo.setFragmentCount(countMap.getOrDefault(vo.getDocId(), 0));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private LambdaQueryWrapper<KnowledgeAttach> buildQueryWrapper(KnowledgeAttachBo bo) {
|
private LambdaQueryWrapper<KnowledgeAttach> buildQueryWrapper(KnowledgeAttachBo bo) {
|
||||||
Map<String, Object> params = bo.getParams();
|
|
||||||
LambdaQueryWrapper<KnowledgeAttach> lqw = Wrappers.lambdaQuery();
|
LambdaQueryWrapper<KnowledgeAttach> lqw = Wrappers.lambdaQuery();
|
||||||
lqw.orderByAsc(KnowledgeAttach::getId);
|
lqw.orderByAsc(KnowledgeAttach::getId);
|
||||||
lqw.eq(bo.getKnowledgeId() != null, KnowledgeAttach::getKnowledgeId, bo.getKnowledgeId());
|
lqw.eq(bo.getKnowledgeId() != null, KnowledgeAttach::getKnowledgeId, bo.getKnowledgeId());
|
||||||
@@ -107,16 +108,9 @@ public class KnowledgeAttachServiceImpl implements IKnowledgeAttachService {
|
|||||||
return lqw;
|
return lqw;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* 新增知识库附件
|
|
||||||
*
|
|
||||||
* @param bo 知识库附件
|
|
||||||
* @return 是否新增成功
|
|
||||||
*/
|
|
||||||
@Override
|
@Override
|
||||||
public Boolean insertByBo(KnowledgeAttachBo bo) {
|
public Boolean insertByBo(KnowledgeAttachBo bo) {
|
||||||
KnowledgeAttach add = MapstructUtils.convert(bo, KnowledgeAttach.class);
|
KnowledgeAttach add = MapstructUtils.convert(bo, KnowledgeAttach.class);
|
||||||
validEntityBeforeSave(add);
|
|
||||||
boolean flag = baseMapper.insert(add) > 0;
|
boolean flag = baseMapper.insert(add) > 0;
|
||||||
if (flag) {
|
if (flag) {
|
||||||
bo.setId(add.getId());
|
bo.setId(add.getId());
|
||||||
@@ -124,98 +118,109 @@ public class KnowledgeAttachServiceImpl implements IKnowledgeAttachService {
|
|||||||
return flag;
|
return flag;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* 修改知识库附件
|
|
||||||
*
|
|
||||||
* @param bo 知识库附件
|
|
||||||
* @return 是否修改成功
|
|
||||||
*/
|
|
||||||
@Override
|
@Override
|
||||||
public Boolean updateByBo(KnowledgeAttachBo bo) {
|
public Boolean updateByBo(KnowledgeAttachBo bo) {
|
||||||
KnowledgeAttach update = MapstructUtils.convert(bo, KnowledgeAttach.class);
|
KnowledgeAttach update = MapstructUtils.convert(bo, KnowledgeAttach.class);
|
||||||
validEntityBeforeSave(update);
|
|
||||||
return baseMapper.updateById(update) > 0;
|
return baseMapper.updateById(update) > 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* 保存前的数据校验
|
|
||||||
*/
|
|
||||||
private void validEntityBeforeSave(KnowledgeAttach entity){
|
|
||||||
//TODO 做一些数据校验,如唯一约束
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 校验并批量删除知识库附件信息
|
|
||||||
*
|
|
||||||
* @param ids 待删除的主键集合
|
|
||||||
* @param isValid 是否进行有效性校验
|
|
||||||
* @return 是否删除成功
|
|
||||||
*/
|
|
||||||
@Override
|
@Override
|
||||||
public Boolean deleteWithValidByIds(Collection<Long> ids, Boolean isValid) {
|
public Boolean deleteWithValidByIds(Collection<Long> ids, Boolean isValid) {
|
||||||
if(isValid){
|
|
||||||
//TODO 做一些业务上的校验,判断是否需要校验
|
|
||||||
}
|
|
||||||
return baseMapper.deleteByIds(ids) > 0;
|
return baseMapper.deleteByIds(ids) > 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void upload(KnowledgeInfoUploadBo bo) {
|
public void upload(KnowledgeInfoUploadBo bo) {
|
||||||
MultipartFile file = bo.getFile();
|
MultipartFile file = bo.getFile();
|
||||||
// 保存文件信息
|
|
||||||
OssDTO ossDTO = ossService.uploadFile(file);
|
OssDTO ossDTO = ossService.uploadFile(file);
|
||||||
Long knowledgeId = bo.getKnowledgeId();
|
|
||||||
List<String> chunkList = new ArrayList<>();
|
|
||||||
KnowledgeAttach knowledgeAttach = new KnowledgeAttach();
|
KnowledgeAttach knowledgeAttach = new KnowledgeAttach();
|
||||||
knowledgeAttach.setKnowledgeId(bo.getKnowledgeId());
|
knowledgeAttach.setKnowledgeId(bo.getKnowledgeId());
|
||||||
String docId = RandomUtil.randomString(10);
|
|
||||||
knowledgeAttach.setOssId(ossDTO.getOssId());
|
knowledgeAttach.setOssId(ossDTO.getOssId());
|
||||||
knowledgeAttach.setDocId(docId);
|
knowledgeAttach.setDocId(RandomUtil.randomString(10));
|
||||||
knowledgeAttach.setName(ossDTO.getOriginalName());
|
knowledgeAttach.setName(ossDTO.getOriginalName());
|
||||||
knowledgeAttach.setType(ossDTO.getFileSuffix());
|
knowledgeAttach.setType(ossDTO.getFileSuffix());
|
||||||
String content = "";
|
knowledgeAttach.setStatus(KnowledgeAttachStatus.WAITING.getCode()); // 待解析
|
||||||
ResourceLoader resourceLoader = resourceLoaderFactory.getLoaderByFileType(knowledgeAttach.getType());
|
|
||||||
// 文档分段入库
|
baseMapper.insert(knowledgeAttach);
|
||||||
List<String> fids = new ArrayList<>();
|
|
||||||
|
if (Boolean.TRUE.equals(bo.getAutoParse())) {
|
||||||
|
// 通过 SpringUtils 获取代理对象,确保 @Async 生效
|
||||||
|
SpringUtils.getBean(IKnowledgeAttachService.class).parse(knowledgeAttach.getId());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Async("knowledgeParseExecutor")
|
||||||
|
@Override
|
||||||
|
public void parse(Long id) {
|
||||||
|
KnowledgeAttach attach = baseMapper.selectById(id);
|
||||||
|
if (attach == null || (!KnowledgeAttachStatus.WAITING.getCode().equals(attach.getStatus()) && !KnowledgeAttachStatus.FAILED.getCode().equals(attach.getStatus()))) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
content = resourceLoader.getContent(file.getInputStream());
|
attach.setStatus(KnowledgeAttachStatus.PARSING.getCode()); // 解析中
|
||||||
chunkList = resourceLoader.getChunkList(content, String.valueOf(knowledgeId));
|
baseMapper.updateById(attach);
|
||||||
|
|
||||||
|
log.info("开始解析知识库文档... id: {}, docId: {}", id, attach.getDocId());
|
||||||
|
|
||||||
|
Long knowledgeId = attach.getKnowledgeId();
|
||||||
|
String docId = attach.getDocId();
|
||||||
|
|
||||||
|
// 获取文件信息并下载
|
||||||
|
List<OssDTO> ossDTOs = ossService.selectByIds(String.valueOf(attach.getOssId()));
|
||||||
|
if (ossDTOs == null || ossDTOs.isEmpty()) {
|
||||||
|
throw new RuntimeException("未找到对应的 OSS 文件信息");
|
||||||
|
}
|
||||||
|
OssDTO ossDTO = ossDTOs.get(0);
|
||||||
|
String content;
|
||||||
|
ResourceLoader resourceLoader = resourceLoaderFactory.getLoaderByFileType(attach.getType());
|
||||||
|
try (InputStream inputStream = new URL(ossDTO.getUrl()).openStream()) {
|
||||||
|
content = resourceLoader.getContent(inputStream);
|
||||||
|
}
|
||||||
|
List<String> chunkList = resourceLoader.getChunkList(content, String.valueOf(knowledgeId));
|
||||||
|
|
||||||
|
List<String> fids = new ArrayList<>();
|
||||||
List<KnowledgeFragment> knowledgeFragmentList = new ArrayList<>();
|
List<KnowledgeFragment> knowledgeFragmentList = new ArrayList<>();
|
||||||
if (CollUtil.isNotEmpty(chunkList)) {
|
if (CollUtil.isNotEmpty(chunkList)) {
|
||||||
for (int i = 0; i < chunkList.size(); i++) {
|
for (int i = 0; i < chunkList.size(); i++) {
|
||||||
// 生成知识片段ID
|
|
||||||
String fid = RandomUtil.randomString(10);
|
String fid = RandomUtil.randomString(10);
|
||||||
fids.add(fid);
|
fids.add(fid);
|
||||||
KnowledgeFragment knowledgeFragment = new KnowledgeFragment();
|
KnowledgeFragment knowledgeFragment = new KnowledgeFragment();
|
||||||
|
knowledgeFragment.setKnowledgeId(knowledgeId);
|
||||||
knowledgeFragment.setDocId(docId);
|
knowledgeFragment.setDocId(docId);
|
||||||
knowledgeFragment.setIdx(i);
|
knowledgeFragment.setIdx(i);
|
||||||
knowledgeFragment.setContent(chunkList.get(i));
|
knowledgeFragment.setContent(chunkList.get(i));
|
||||||
knowledgeFragment.setCreateTime(new Date());
|
knowledgeFragment.setCreateTime(new Date());
|
||||||
knowledgeFragmentList.add(knowledgeFragment);
|
knowledgeFragmentList.add(knowledgeFragment);
|
||||||
}
|
}
|
||||||
|
knowledgeFragmentMapper.delete(Wrappers.<KnowledgeFragment>lambdaQuery().eq(KnowledgeFragment::getDocId, docId));
|
||||||
|
knowledgeFragmentMapper.insertBatch(knowledgeFragmentList);
|
||||||
|
log.info("文档切片并入库完成,共计 {} 个片段。id: {}", chunkList.size(), id);
|
||||||
}
|
}
|
||||||
knowledgeFragmentMapper.insertBatch(knowledgeFragmentList);
|
|
||||||
} catch (IOException e) {
|
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(knowledgeId);
|
||||||
log.error("保存知识库信息失败!{}", e.getMessage());
|
ChatModelVo chatModelVo = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel());
|
||||||
|
|
||||||
|
StoreEmbeddingBo storeEmbeddingBo = new StoreEmbeddingBo();
|
||||||
|
storeEmbeddingBo.setKid(String.valueOf(knowledgeId));
|
||||||
|
storeEmbeddingBo.setDocId(docId);
|
||||||
|
storeEmbeddingBo.setFids(fids);
|
||||||
|
storeEmbeddingBo.setChunkList(chunkList);
|
||||||
|
storeEmbeddingBo.setVectorStoreName(knowledgeInfoVo.getVectorModel());
|
||||||
|
storeEmbeddingBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel());
|
||||||
|
storeEmbeddingBo.setApiKey(chatModelVo.getApiKey());
|
||||||
|
storeEmbeddingBo.setBaseUrl(chatModelVo.getApiHost());
|
||||||
|
vectorStoreService.storeEmbeddings(storeEmbeddingBo);
|
||||||
|
|
||||||
|
attach.setStatus(KnowledgeAttachStatus.COMPLETED.getCode()); // 已完成
|
||||||
|
baseMapper.updateById(attach);
|
||||||
|
log.info("知识库文档解析、向量化并入库成功!id: {}", id);
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("解析文档失败!id: {}, error: {}", id, e.getMessage(), e);
|
||||||
|
attach.setStatus(KnowledgeAttachStatus.FAILED.getCode()); // 失败
|
||||||
|
attach.setRemark(StringUtils.substring(e.getMessage(), 0, 255)); // 保存错误原因,截取防止溢出
|
||||||
|
baseMapper.updateById(attach);
|
||||||
}
|
}
|
||||||
baseMapper.insert(knowledgeAttach);
|
|
||||||
|
|
||||||
// 查询知识库信息
|
|
||||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(knowledgeId);
|
|
||||||
|
|
||||||
// 查询向量模信息
|
|
||||||
ChatModelVo chatModelVo = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel());
|
|
||||||
|
|
||||||
StoreEmbeddingBo storeEmbeddingBo = new StoreEmbeddingBo();
|
|
||||||
storeEmbeddingBo.setKid(String.valueOf(knowledgeId));
|
|
||||||
storeEmbeddingBo.setDocId(docId);
|
|
||||||
storeEmbeddingBo.setFids(fids);
|
|
||||||
storeEmbeddingBo.setChunkList(chunkList);
|
|
||||||
storeEmbeddingBo.setVectorStoreName(knowledgeInfoVo.getVectorModel());
|
|
||||||
storeEmbeddingBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel());
|
|
||||||
storeEmbeddingBo.setApiKey(chatModelVo.getApiKey());
|
|
||||||
storeEmbeddingBo.setBaseUrl(chatModelVo.getApiHost());
|
|
||||||
vectorStoreService.storeEmbeddings(storeEmbeddingBo);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,24 +1,29 @@
|
|||||||
package org.ruoyi.service.knowledge.impl;
|
package org.ruoyi.service.knowledge.impl;
|
||||||
|
|
||||||
import org.ruoyi.common.core.utils.MapstructUtils;
|
|
||||||
import org.ruoyi.common.core.utils.StringUtils;
|
|
||||||
import org.ruoyi.common.mybatis.core.page.TableDataInfo;
|
|
||||||
import org.ruoyi.common.mybatis.core.page.PageQuery;
|
|
||||||
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
|
||||||
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
||||||
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
|
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
|
||||||
|
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo;
|
||||||
|
import org.ruoyi.common.chat.service.chat.IChatModelService;
|
||||||
|
import org.ruoyi.common.core.utils.MapstructUtils;
|
||||||
|
import org.ruoyi.common.core.utils.StringUtils;
|
||||||
|
import org.ruoyi.common.mybatis.core.page.PageQuery;
|
||||||
|
import org.ruoyi.common.mybatis.core.page.TableDataInfo;
|
||||||
import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo;
|
import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo;
|
||||||
|
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||||
import org.ruoyi.domain.entity.knowledge.KnowledgeFragment;
|
import org.ruoyi.domain.entity.knowledge.KnowledgeFragment;
|
||||||
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
|
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
|
||||||
|
import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo;
|
||||||
|
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||||
import org.ruoyi.mapper.knowledge.KnowledgeFragmentMapper;
|
import org.ruoyi.mapper.knowledge.KnowledgeFragmentMapper;
|
||||||
import org.ruoyi.service.knowledge.IKnowledgeFragmentService;
|
import org.ruoyi.service.knowledge.IKnowledgeFragmentService;
|
||||||
|
import org.ruoyi.service.knowledge.IKnowledgeInfoService;
|
||||||
|
import org.ruoyi.service.retrieval.KnowledgeRetrievalService;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.*;
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Collection;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 知识片段Service业务层处理
|
* 知识片段Service业务层处理
|
||||||
@@ -32,6 +37,9 @@ import java.util.Collection;
|
|||||||
public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
|
public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
|
||||||
|
|
||||||
private final KnowledgeFragmentMapper baseMapper;
|
private final KnowledgeFragmentMapper baseMapper;
|
||||||
|
private final IKnowledgeInfoService knowledgeInfoService;
|
||||||
|
private final IChatModelService chatModelService;
|
||||||
|
private final KnowledgeRetrievalService knowledgeRetrievalService;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 查询知识片段
|
* 查询知识片段
|
||||||
@@ -71,7 +79,6 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private LambdaQueryWrapper<KnowledgeFragment> buildQueryWrapper(KnowledgeFragmentBo bo) {
|
private LambdaQueryWrapper<KnowledgeFragment> buildQueryWrapper(KnowledgeFragmentBo bo) {
|
||||||
Map<String, Object> params = bo.getParams();
|
|
||||||
LambdaQueryWrapper<KnowledgeFragment> lqw = Wrappers.lambdaQuery();
|
LambdaQueryWrapper<KnowledgeFragment> lqw = Wrappers.lambdaQuery();
|
||||||
lqw.orderByAsc(KnowledgeFragment::getId);
|
lqw.orderByAsc(KnowledgeFragment::getId);
|
||||||
lqw.eq(bo.getDocId() != null, KnowledgeFragment::getDocId, bo.getDocId());
|
lqw.eq(bo.getDocId() != null, KnowledgeFragment::getDocId, bo.getDocId());
|
||||||
@@ -131,4 +138,50 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
|
|||||||
}
|
}
|
||||||
return baseMapper.deleteByIds(ids) > 0;
|
return baseMapper.deleteByIds(ids) > 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检索测试核心实现 - 委托给统一的 KnowledgeRetrievalService
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public List<KnowledgeRetrievalVo> retrieval(KnowledgeFragmentBo bo) {
|
||||||
|
if (bo.getKnowledgeId() == null || StringUtils.isBlank(bo.getQuery())) {
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. 获取知识库及模型配置(为了获取 API Key/Host 等模型参数)
|
||||||
|
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(bo.getKnowledgeId());
|
||||||
|
if (knowledgeInfoVo == null) {
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
|
||||||
|
ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel());
|
||||||
|
if (chatModel == null) {
|
||||||
|
log.warn("未找到对应的向量模型配置: {}", knowledgeInfoVo.getEmbeddingModel());
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 构造通用的参数对象
|
||||||
|
QueryVectorBo queryVectorBo = new QueryVectorBo();
|
||||||
|
queryVectorBo.setQuery(bo.getQuery());
|
||||||
|
queryVectorBo.setKid(String.valueOf(bo.getKnowledgeId()));
|
||||||
|
queryVectorBo.setApiKey(chatModel.getApiKey());
|
||||||
|
queryVectorBo.setBaseUrl(chatModel.getApiHost());
|
||||||
|
queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel());
|
||||||
|
queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModel());
|
||||||
|
|
||||||
|
// 使用前端传入的实时测试参数,若无则使用知识库默认参数
|
||||||
|
queryVectorBo.setMaxResults(bo.getTopK() != null ? bo.getTopK() : knowledgeInfoVo.getRetrieveLimit());
|
||||||
|
queryVectorBo.setSimilarityThreshold(bo.getThreshold() != null ? bo.getThreshold() : knowledgeInfoVo.getSimilarityThreshold());
|
||||||
|
|
||||||
|
queryVectorBo.setEnableHybrid(bo.getEnableHybrid() != null ? bo.getEnableHybrid() : Objects.equals(knowledgeInfoVo.getEnableHybrid(), 1));
|
||||||
|
queryVectorBo.setHybridAlpha(bo.getHybridAlpha() != null ? bo.getHybridAlpha() : knowledgeInfoVo.getHybridAlpha());
|
||||||
|
|
||||||
|
queryVectorBo.setEnableRerank(bo.getEnableRerank() != null ? bo.getEnableRerank() : Objects.equals(knowledgeInfoVo.getEnableRerank(), 1));
|
||||||
|
queryVectorBo.setRerankModelName(StringUtils.isNotBlank(bo.getRerankModel()) ? bo.getRerankModel() : knowledgeInfoVo.getRerankModel());
|
||||||
|
queryVectorBo.setRerankTopN(bo.getTopK() != null ? bo.getTopK() : knowledgeInfoVo.getRerankTopN());
|
||||||
|
queryVectorBo.setRerankScoreThreshold(bo.getThreshold() != null ? bo.getThreshold() : knowledgeInfoVo.getRerankScoreThreshold());
|
||||||
|
|
||||||
|
// 3. 执行统一检索
|
||||||
|
return knowledgeRetrievalService.retrieve(queryVectorBo);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import lombok.extern.slf4j.Slf4j;
|
|||||||
import org.ruoyi.domain.bo.knowledge.KnowledgeInfoBo;
|
import org.ruoyi.domain.bo.knowledge.KnowledgeInfoBo;
|
||||||
import org.ruoyi.domain.entity.knowledge.KnowledgeInfo;
|
import org.ruoyi.domain.entity.knowledge.KnowledgeInfo;
|
||||||
import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo;
|
import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo;
|
||||||
|
import org.ruoyi.mapper.knowledge.KnowledgeAttachMapper;
|
||||||
import org.ruoyi.mapper.knowledge.KnowledgeInfoMapper;
|
import org.ruoyi.mapper.knowledge.KnowledgeInfoMapper;
|
||||||
import org.ruoyi.service.knowledge.IKnowledgeInfoService;
|
import org.ruoyi.service.knowledge.IKnowledgeInfoService;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
@@ -33,6 +34,8 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
|
|||||||
|
|
||||||
private final KnowledgeInfoMapper baseMapper;
|
private final KnowledgeInfoMapper baseMapper;
|
||||||
|
|
||||||
|
private final KnowledgeAttachMapper knowledgeAttachMapper;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 查询知识库
|
* 查询知识库
|
||||||
*
|
*
|
||||||
@@ -55,6 +58,8 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
|
|||||||
public TableDataInfo<KnowledgeInfoVo> queryPageList(KnowledgeInfoBo bo, PageQuery pageQuery) {
|
public TableDataInfo<KnowledgeInfoVo> queryPageList(KnowledgeInfoBo bo, PageQuery pageQuery) {
|
||||||
LambdaQueryWrapper<KnowledgeInfo> lqw = buildQueryWrapper(bo);
|
LambdaQueryWrapper<KnowledgeInfo> lqw = buildQueryWrapper(bo);
|
||||||
Page<KnowledgeInfoVo> result = baseMapper.selectVoPage(pageQuery.build(), lqw);
|
Page<KnowledgeInfoVo> result = baseMapper.selectVoPage(pageQuery.build(), lqw);
|
||||||
|
// 批量填充文档数
|
||||||
|
fillDocumentCount(result.getRecords());
|
||||||
return TableDataInfo.build(result);
|
return TableDataInfo.build(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,6 +92,17 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
|
|||||||
return lqw;
|
return lqw;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 批量填充知识库列表每一条记录的文档数(documentCount)
|
||||||
|
*/
|
||||||
|
private void fillDocumentCount(List<KnowledgeInfoVo> records) {
|
||||||
|
if (records == null || records.isEmpty()) return;
|
||||||
|
for (KnowledgeInfoVo vo : records) {
|
||||||
|
int count = knowledgeAttachMapper.countByKnowledgeId(vo.getId());
|
||||||
|
vo.setDocumentCount(count);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 新增知识库
|
* 新增知识库
|
||||||
*
|
*
|
||||||
|
|||||||
@@ -0,0 +1,65 @@
|
|||||||
|
package org.ruoyi.service.knowledge.retriever;
|
||||||
|
|
||||||
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
|
import dev.langchain4j.rag.content.Content;
|
||||||
|
import dev.langchain4j.rag.content.retriever.ContentRetriever;
|
||||||
|
import dev.langchain4j.rag.query.Query;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo;
|
||||||
|
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||||
|
import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo;
|
||||||
|
import org.ruoyi.service.retrieval.KnowledgeRetrievalService;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 自定义检索器:适配 LangChain4j ContentRetriever 接口
|
||||||
|
* 桥接统一的 KnowledgeRetrievalService,支持配置化的混合检索、阈值过滤等功能
|
||||||
|
*
|
||||||
|
* @author RobustH
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class CustomVectorRetriever implements ContentRetriever {
|
||||||
|
|
||||||
|
private final KnowledgeRetrievalService knowledgeRetrievalService;
|
||||||
|
private final KnowledgeInfoVo knowledgeInfoVo;
|
||||||
|
private final ChatModelVo chatModelVo;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Content> retrieve(Query query) {
|
||||||
|
log.info("执行自定义检索,关键字: {}", query.text());
|
||||||
|
|
||||||
|
// 构建增强后的查询参数
|
||||||
|
QueryVectorBo queryVectorBo = new QueryVectorBo();
|
||||||
|
queryVectorBo.setQuery(query.text());
|
||||||
|
queryVectorBo.setKid(String.valueOf(knowledgeInfoVo.getId()));
|
||||||
|
queryVectorBo.setApiKey(chatModelVo.getApiKey());
|
||||||
|
queryVectorBo.setBaseUrl(chatModelVo.getApiHost());
|
||||||
|
queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModel());
|
||||||
|
queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel());
|
||||||
|
|
||||||
|
// 应用知识库配置参数
|
||||||
|
queryVectorBo.setMaxResults(knowledgeInfoVo.getRetrieveLimit());
|
||||||
|
queryVectorBo.setSimilarityThreshold(knowledgeInfoVo.getSimilarityThreshold());
|
||||||
|
queryVectorBo.setEnableHybrid(Objects.equals(knowledgeInfoVo.getEnableHybrid(), 1));
|
||||||
|
queryVectorBo.setHybridAlpha(knowledgeInfoVo.getHybridAlpha());
|
||||||
|
|
||||||
|
// 设置重排序参数 (如果 retriever 阶段也想做初步重排,可以在此设置)
|
||||||
|
queryVectorBo.setEnableRerank(Objects.equals(knowledgeInfoVo.getEnableRerank(), 1));
|
||||||
|
queryVectorBo.setRerankModelName(knowledgeInfoVo.getRerankModel());
|
||||||
|
queryVectorBo.setRerankTopN(knowledgeInfoVo.getRerankTopN());
|
||||||
|
queryVectorBo.setRerankScoreThreshold(knowledgeInfoVo.getRerankScoreThreshold());
|
||||||
|
|
||||||
|
// 通过统一服务执行检索
|
||||||
|
List<String> nearestList = knowledgeRetrievalService.retrieveTexts(queryVectorBo);
|
||||||
|
|
||||||
|
// 将结果包装为标准的 Content 返回
|
||||||
|
return nearestList.stream()
|
||||||
|
.map(text -> Content.from(TextSegment.from(text)))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,174 @@
|
|||||||
|
package org.ruoyi.service.rerank.impl;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.DeserializationFeature;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import okhttp3.*;
|
||||||
|
import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo;
|
||||||
|
import org.ruoyi.domain.bo.rerank.RerankRequest;
|
||||||
|
import org.ruoyi.domain.bo.rerank.RerankResult;
|
||||||
|
import org.ruoyi.service.rerank.RerankModelService;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 硅基流动重排序模型实现
|
||||||
|
* 适配硅基流动的 /v1/rerank 接口
|
||||||
|
*
|
||||||
|
* @author RobustH
|
||||||
|
* @date 2026-04-21
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
@Component("siliconflowRerank")
|
||||||
|
public class SiliconFlowRerankModelService implements RerankModelService {
|
||||||
|
|
||||||
|
private static final String DEFAULT_BASE_URL = "https://api.siliconflow.cn/v1/rerank";
|
||||||
|
|
||||||
|
private final OkHttpClient okHttpClient;
|
||||||
|
private final ObjectMapper objectMapper = new ObjectMapper()
|
||||||
|
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
|
||||||
|
private ChatModelVo chatModelVo;
|
||||||
|
|
||||||
|
public SiliconFlowRerankModelService() {
|
||||||
|
this.okHttpClient = new OkHttpClient.Builder()
|
||||||
|
.connectTimeout(30, TimeUnit.SECONDS)
|
||||||
|
.readTimeout(60, TimeUnit.SECONDS)
|
||||||
|
.writeTimeout(30, TimeUnit.SECONDS)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void configure(ChatModelVo config) {
|
||||||
|
this.chatModelVo = config;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public RerankResult rerank(RerankRequest rerankRequest) {
|
||||||
|
long startTime = System.currentTimeMillis();
|
||||||
|
|
||||||
|
try {
|
||||||
|
String url = buildUrl();
|
||||||
|
String requestJson = buildRequestJson(rerankRequest);
|
||||||
|
|
||||||
|
RequestBody body = RequestBody.create(requestJson, MediaType.get("application/json"));
|
||||||
|
Request httpRequest = new Request.Builder()
|
||||||
|
.url(url)
|
||||||
|
.addHeader("Authorization", "Bearer " + chatModelVo.getApiKey())
|
||||||
|
.addHeader("Content-Type", "application/json")
|
||||||
|
.post(body)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
log.info("硅基流动重排序请求: model={}, url={}", chatModelVo.getModelName(), url);
|
||||||
|
|
||||||
|
try (Response response = okHttpClient.newCall(httpRequest).execute()) {
|
||||||
|
if (!response.isSuccessful()) {
|
||||||
|
String err = response.body() != null ? response.body().string() : "无错误信息";
|
||||||
|
throw new IllegalArgumentException("硅基流动 Rerank API 调用失败: " + response.code() + " - " + err);
|
||||||
|
}
|
||||||
|
|
||||||
|
ResponseBody responseBody = response.body();
|
||||||
|
if (responseBody == null) {
|
||||||
|
throw new IllegalArgumentException("响应体为空");
|
||||||
|
}
|
||||||
|
|
||||||
|
SiliconFlowRerankResponse rerankResponse = objectMapper.readValue(
|
||||||
|
responseBody.string(), SiliconFlowRerankResponse.class);
|
||||||
|
|
||||||
|
return buildRerankResult(rerankResponse, rerankRequest.getDocuments(),
|
||||||
|
System.currentTimeMillis() - startTime);
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("硅基流动重排序失败: {}", e.getMessage(), e);
|
||||||
|
throw new RuntimeException("重排序服务调用失败: " + e.getMessage(), e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 构建请求 URL,鲁棒处理 API Host 末尾路径
|
||||||
|
*/
|
||||||
|
private String buildUrl() {
|
||||||
|
String apiHost = chatModelVo.getApiHost();
|
||||||
|
if (apiHost == null || apiHost.isBlank()) {
|
||||||
|
return DEFAULT_BASE_URL;
|
||||||
|
}
|
||||||
|
if (apiHost.endsWith("/rerank")) {
|
||||||
|
return apiHost;
|
||||||
|
}
|
||||||
|
if (apiHost.endsWith("/v1")) {
|
||||||
|
return apiHost + "/rerank";
|
||||||
|
}
|
||||||
|
return apiHost.endsWith("/") ? apiHost + "rerank" : apiHost + "/rerank";
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 构建请求体 JSON
|
||||||
|
*/
|
||||||
|
private String buildRequestJson(RerankRequest rerankRequest) throws IOException {
|
||||||
|
SiliconFlowRerankRequest request = new SiliconFlowRerankRequest();
|
||||||
|
request.setModel(chatModelVo.getModelName());
|
||||||
|
request.setQuery(rerankRequest.getQuery());
|
||||||
|
request.setDocuments(rerankRequest.getDocuments());
|
||||||
|
request.setTop_n(rerankRequest.getTopN() != null ? rerankRequest.getTopN() : rerankRequest.getDocuments().size());
|
||||||
|
request.setReturn_documents(rerankRequest.getReturnDocuments() != null ? rerankRequest.getReturnDocuments() : false);
|
||||||
|
return objectMapper.writeValueAsString(request);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 构建标准 RerankResult
|
||||||
|
*/
|
||||||
|
private RerankResult buildRerankResult(SiliconFlowRerankResponse response,
|
||||||
|
List<String> originalDocuments, long durationMs) {
|
||||||
|
Double[] scores = new Double[originalDocuments.size()];
|
||||||
|
for (int i = 0; i < scores.length; i++) {
|
||||||
|
scores[i] = 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
List<RerankResult.RerankDocument> docs = new ArrayList<>();
|
||||||
|
if (response != null && response.getResults() != null) {
|
||||||
|
response.getResults().forEach(item -> {
|
||||||
|
if (item.getIndex() != null && item.getIndex() < originalDocuments.size()) {
|
||||||
|
scores[item.getIndex()] = item.getRelevance_score();
|
||||||
|
docs.add(RerankResult.RerankDocument.builder()
|
||||||
|
.index(item.getIndex())
|
||||||
|
.relevanceScore(item.getRelevance_score())
|
||||||
|
.document(originalDocuments.get(item.getIndex()))
|
||||||
|
.build());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
return RerankResult.builder()
|
||||||
|
.documents(docs)
|
||||||
|
.totalDocuments(originalDocuments.size())
|
||||||
|
.durationMs(durationMs)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== 内部 DTO ====================
|
||||||
|
|
||||||
|
@Data
|
||||||
|
static class SiliconFlowRerankRequest {
|
||||||
|
private String model;
|
||||||
|
private String query;
|
||||||
|
private List<String> documents;
|
||||||
|
private Integer top_n;
|
||||||
|
private Boolean return_documents;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
static class SiliconFlowRerankResponse {
|
||||||
|
private List<SiliconFlowRerankResultItem> results;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
static class SiliconFlowRerankResultItem {
|
||||||
|
private Integer index;
|
||||||
|
private Double relevance_score;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,12 +1,13 @@
|
|||||||
package org.ruoyi.service.retrieval;
|
package org.ruoyi.service.retrieval;
|
||||||
|
|
||||||
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||||
|
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 知识库检索服务接口
|
* 知识库检索服务接口
|
||||||
* 整合粗召回(向量检索)和重排序流程
|
* 整合粗召回(向量检索/关键词检索)和重排序流程
|
||||||
*
|
*
|
||||||
* @author yang
|
* @author yang
|
||||||
* @date 2026-04-19
|
* @date 2026-04-19
|
||||||
@@ -21,4 +22,13 @@ public interface KnowledgeRetrievalService {
|
|||||||
* @return 文本内容列表
|
* @return 文本内容列表
|
||||||
*/
|
*/
|
||||||
List<String> retrieveTexts(QueryVectorBo queryVectorBo);
|
List<String> retrieveTexts(QueryVectorBo queryVectorBo);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 执行知识库检索,返回详细结果对象(包含分数、文档ID等)
|
||||||
|
* 支持混合检索和重排序
|
||||||
|
*
|
||||||
|
* @param queryVectorBo 查询参数
|
||||||
|
* @return 检索结果列表
|
||||||
|
*/
|
||||||
|
List<KnowledgeRetrievalVo> retrieve(QueryVectorBo queryVectorBo);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,21 +2,26 @@ package org.ruoyi.service.retrieval.impl;
|
|||||||
|
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.ruoyi.common.core.utils.StringUtils;
|
||||||
import org.ruoyi.domain.bo.rerank.RerankRequest;
|
import org.ruoyi.domain.bo.rerank.RerankRequest;
|
||||||
import org.ruoyi.domain.bo.rerank.RerankResult;
|
import org.ruoyi.domain.bo.rerank.RerankResult;
|
||||||
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||||
|
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
|
||||||
|
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||||
import org.ruoyi.factory.RerankModelFactory;
|
import org.ruoyi.factory.RerankModelFactory;
|
||||||
|
import org.ruoyi.mapper.knowledge.KnowledgeFragmentMapper;
|
||||||
import org.ruoyi.service.rerank.RerankModelService;
|
import org.ruoyi.service.rerank.RerankModelService;
|
||||||
import org.ruoyi.service.retrieval.KnowledgeRetrievalService;
|
import org.ruoyi.service.retrieval.KnowledgeRetrievalService;
|
||||||
import org.ruoyi.service.vector.VectorStoreService;
|
import org.ruoyi.service.vector.VectorStoreService;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.*;
|
||||||
import java.util.List;
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 知识库检索服务实现
|
* 知识库检索服务实现
|
||||||
* 整合粗召回(向量检索)和重排序流程
|
* 整合粗召回(向量检索/关键词检索)、RRF融合和重排序流程
|
||||||
*
|
*
|
||||||
* @author yang
|
* @author yang
|
||||||
* @date 2026-04-19
|
* @date 2026-04-19
|
||||||
@@ -28,6 +33,7 @@ public class KnowledgeRetrievalServiceImpl implements KnowledgeRetrievalService
|
|||||||
|
|
||||||
private final VectorStoreService vectorStoreService;
|
private final VectorStoreService vectorStoreService;
|
||||||
private final RerankModelFactory rerankModelFactory;
|
private final RerankModelFactory rerankModelFactory;
|
||||||
|
private final KnowledgeFragmentMapper fragmentMapper;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 粗召回默认扩大倍数
|
* 粗召回默认扩大倍数
|
||||||
@@ -37,99 +43,214 @@ public class KnowledgeRetrievalServiceImpl implements KnowledgeRetrievalService
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<String> retrieveTexts(QueryVectorBo queryVectorBo) {
|
public List<String> retrieveTexts(QueryVectorBo queryVectorBo) {
|
||||||
|
List<KnowledgeRetrievalVo> results = retrieve(queryVectorBo);
|
||||||
|
return results.stream()
|
||||||
|
.map(KnowledgeRetrievalVo::getContent)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<KnowledgeRetrievalVo> retrieve(QueryVectorBo queryVectorBo) {
|
||||||
log.info("开始知识库检索, kid={}, query={}", queryVectorBo.getKid(), queryVectorBo.getQuery());
|
log.info("开始知识库检索, kid={}, query={}", queryVectorBo.getKid(), queryVectorBo.getQuery());
|
||||||
|
|
||||||
// 1. 粗召回阶段 - 向量检索
|
// 1. 粗召回阶段 (向量检索 + 关键词搜索)
|
||||||
List<String> coarseResults = coarseRetrieval(queryVectorBo);
|
List<KnowledgeRetrievalVo> coarseResults = performCoarseRetrieval(queryVectorBo);
|
||||||
log.debug("粗召回返回 {} 条结果", coarseResults.size());
|
log.debug("粗召回返回 {} 条结果", coarseResults.size());
|
||||||
|
|
||||||
if (coarseResults.isEmpty()) {
|
if (coarseResults.isEmpty()) {
|
||||||
return coarseResults;
|
return coarseResults;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 重排序阶段(可选)
|
// 2. 初始化原始索引
|
||||||
if (Boolean.TRUE.equals(queryVectorBo.getEnableRerank()) &&
|
for (int i = 0; i < coarseResults.size(); i++) {
|
||||||
queryVectorBo.getRerankModelName() != null) {
|
coarseResults.get(i).setOriginalIndex(i);
|
||||||
return rerank(queryVectorBo, coarseResults);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return coarseResults;
|
// 3. 重排序阶段 (可选)
|
||||||
|
List<KnowledgeRetrievalVo> finalResults = coarseResults;
|
||||||
|
if (Boolean.TRUE.equals(queryVectorBo.getEnableRerank()) &&
|
||||||
|
StringUtils.isNotBlank(queryVectorBo.getRerankModelName())) {
|
||||||
|
finalResults = performRerank(queryVectorBo, coarseResults);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. 应用分值阈值过滤 (重排分值或 RRF 分值)
|
||||||
|
double threshold = queryVectorBo.getRerankScoreThreshold() != null ?
|
||||||
|
queryVectorBo.getRerankScoreThreshold() : 0.0;
|
||||||
|
|
||||||
|
return finalResults.stream()
|
||||||
|
.filter(res -> res.getScore() >= threshold)
|
||||||
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 粗召回阶段 - 向量检索
|
* 粗召回阶段:根据配置执行向量搜索或混合搜索
|
||||||
*/
|
*/
|
||||||
private List<String> coarseRetrieval(QueryVectorBo queryVectorBo) {
|
private List<KnowledgeRetrievalVo> performCoarseRetrieval(QueryVectorBo queryVectorBo) {
|
||||||
// 如果启用重排序,扩大粗召回数量
|
// 如果启用重排序,适当扩大召回数量
|
||||||
int originalMaxResults = queryVectorBo.getMaxResults();
|
int originalMaxResults = queryVectorBo.getMaxResults() != null ? queryVectorBo.getMaxResults() : 10;
|
||||||
int expandedResults = originalMaxResults;
|
int targetMaxResults = originalMaxResults;
|
||||||
if (Boolean.TRUE.equals(queryVectorBo.getEnableRerank()) &&
|
if (Boolean.TRUE.equals(queryVectorBo.getEnableRerank()) &&
|
||||||
queryVectorBo.getRerankModelName() != null) {
|
StringUtils.isNotBlank(queryVectorBo.getRerankModelName())) {
|
||||||
expandedResults = originalMaxResults * RERANK_EXPANSION_FACTOR;
|
targetMaxResults = originalMaxResults * RERANK_EXPANSION_FACTOR;
|
||||||
log.debug("启用重排序,粗召回数量从 {} 扩大到 {}", originalMaxResults, expandedResults);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 临时修改查询数量
|
// 如果未启用混合检索,直接走向量搜索
|
||||||
queryVectorBo.setMaxResults(expandedResults);
|
if (!Boolean.TRUE.equals(queryVectorBo.getEnableHybrid())) {
|
||||||
|
QueryVectorBo vectorQuery = copyOf(queryVectorBo, targetMaxResults);
|
||||||
|
List<KnowledgeRetrievalVo> results = vectorStoreService.search(vectorQuery);
|
||||||
|
|
||||||
|
// 应用基础相似度阈值过滤(如果有)
|
||||||
|
if (queryVectorBo.getSimilarityThreshold() != null) {
|
||||||
|
results = results.stream()
|
||||||
|
.filter(r -> r.getScore() >= queryVectorBo.getSimilarityThreshold())
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 混合检索逻辑
|
||||||
|
log.info("执行混合检索: kid={}, query={}", queryVectorBo.getKid(), queryVectorBo.getQuery());
|
||||||
try {
|
try {
|
||||||
return vectorStoreService.getQueryVector(queryVectorBo);
|
// A. 并行执行向量搜索
|
||||||
} finally {
|
int finalTargetMaxResults = targetMaxResults;
|
||||||
// 恢复原始值
|
CompletableFuture<List<KnowledgeRetrievalVo>> vectorFuture = CompletableFuture.supplyAsync(() -> {
|
||||||
queryVectorBo.setMaxResults(originalMaxResults);
|
QueryVectorBo vectorQuery = copyOf(queryVectorBo, finalTargetMaxResults);
|
||||||
|
List<KnowledgeRetrievalVo> results = vectorStoreService.search(vectorQuery);
|
||||||
|
// 向量层初步过滤
|
||||||
|
if (queryVectorBo.getSimilarityThreshold() != null) {
|
||||||
|
return results.stream()
|
||||||
|
.filter(r -> r.getScore() >= queryVectorBo.getSimilarityThreshold())
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
return results;
|
||||||
|
});
|
||||||
|
|
||||||
|
// B. 并行执行关键词搜索 (MySQL Fulltext)
|
||||||
|
CompletableFuture<List<KnowledgeRetrievalVo>> keywordFuture = CompletableFuture.supplyAsync(() -> {
|
||||||
|
try {
|
||||||
|
Long kid = Long.valueOf(queryVectorBo.getKid());
|
||||||
|
List<KnowledgeFragmentVo> fragments = fragmentMapper.searchByKeyword(kid, queryVectorBo.getQuery(), finalTargetMaxResults);
|
||||||
|
return fragments.stream().map(f -> {
|
||||||
|
KnowledgeRetrievalVo vo = new KnowledgeRetrievalVo();
|
||||||
|
vo.setId(f.getId().toString());
|
||||||
|
vo.setContent(f.getContent());
|
||||||
|
vo.setDocId(f.getDocId());
|
||||||
|
vo.setIdx(f.getIdx());
|
||||||
|
vo.setKnowledgeId(f.getKnowledgeId());
|
||||||
|
vo.setScore(10.0); // RRF 初始占位分
|
||||||
|
return vo;
|
||||||
|
}).collect(Collectors.toList());
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("关键词检索失败: {}", e.getMessage());
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
List<KnowledgeRetrievalVo> vectorResults = vectorFuture.get();
|
||||||
|
List<KnowledgeRetrievalVo> keywordResults = keywordFuture.get();
|
||||||
|
|
||||||
|
// C. RRF 融合
|
||||||
|
double alpha = queryVectorBo.getHybridAlpha() != null ? queryVectorBo.getHybridAlpha() : 0.5;
|
||||||
|
return calculateRRF(vectorResults, keywordResults, alpha);
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("混合检索执行失败,回退到纯向量检索: {}", e.getMessage(), e);
|
||||||
|
return vectorStoreService.search(copyOf(queryVectorBo, targetMaxResults));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 重排序阶段
|
* 重排序阶段
|
||||||
*/
|
*/
|
||||||
private List<String> rerank(QueryVectorBo queryVectorBo, List<String> coarseResults) {
|
private List<KnowledgeRetrievalVo> performRerank(QueryVectorBo queryVectorBo, List<KnowledgeRetrievalVo> coarseResults) {
|
||||||
long startTime = System.currentTimeMillis();
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// 1. 通过工厂获取重排序模型
|
|
||||||
RerankModelService rerankModel = rerankModelFactory.createModel(queryVectorBo.getRerankModelName());
|
RerankModelService rerankModel = rerankModelFactory.createModel(queryVectorBo.getRerankModelName());
|
||||||
|
|
||||||
|
List<String> contents = coarseResults.stream()
|
||||||
|
.map(KnowledgeRetrievalVo::getContent)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
// 2. 构建重排序请求
|
// topN 默认为 maxResults
|
||||||
int topN = queryVectorBo.getRerankTopN() != null ?
|
int topN = queryVectorBo.getRerankTopN() != null ? queryVectorBo.getRerankTopN() : queryVectorBo.getMaxResults();
|
||||||
queryVectorBo.getRerankTopN() : queryVectorBo.getMaxResults();
|
|
||||||
|
|
||||||
RerankRequest rerankRequest = RerankRequest.builder()
|
RerankRequest rerankRequest = RerankRequest.builder()
|
||||||
.query(queryVectorBo.getQuery())
|
.query(queryVectorBo.getQuery())
|
||||||
.documents(coarseResults)
|
.documents(contents)
|
||||||
.topN(topN)
|
.topN(topN)
|
||||||
.returnDocuments(true)
|
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
log.info("执行重排序, model={}, documents={}, topN={}",
|
|
||||||
queryVectorBo.getRerankModelName(), coarseResults.size(), topN);
|
|
||||||
|
|
||||||
// 3. 执行重排序
|
|
||||||
RerankResult rerankResult = rerankModel.rerank(rerankRequest);
|
RerankResult rerankResult = rerankModel.rerank(rerankRequest);
|
||||||
|
|
||||||
// 4. 转换重排序结果
|
// 写回分数并记录原始分
|
||||||
List<String> finalResults = new ArrayList<>();
|
|
||||||
for (RerankResult.RerankDocument doc : rerankResult.getDocuments()) {
|
for (RerankResult.RerankDocument doc : rerankResult.getDocuments()) {
|
||||||
// 应用分数阈值过滤
|
if (doc.getIndex() != null && doc.getIndex() < coarseResults.size()) {
|
||||||
if (queryVectorBo.getRerankScoreThreshold() != null &&
|
KnowledgeRetrievalVo vo = coarseResults.get(doc.getIndex());
|
||||||
doc.getRelevanceScore() < queryVectorBo.getRerankScoreThreshold()) {
|
vo.setRawScore(vo.getScore());
|
||||||
continue;
|
vo.setScore(doc.getRelevanceScore());
|
||||||
}
|
|
||||||
|
|
||||||
if (doc.getDocument() != null) {
|
|
||||||
finalResults.add(doc.getDocument());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
long duration = System.currentTimeMillis() - startTime;
|
// 按新分排序
|
||||||
log.info("重排序完成, 返回 {} 条结果, 耗时 {}ms", finalResults.size(), duration);
|
coarseResults.sort((a, b) -> b.getScore().compareTo(a.getScore()));
|
||||||
|
|
||||||
return finalResults;
|
// 截断到 topN
|
||||||
|
return coarseResults.subList(0, Math.min(topN, coarseResults.size()));
|
||||||
|
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("重排序失败: {}", e.getMessage(), e);
|
log.error("重排序流程失败: {}", e.getMessage());
|
||||||
// 重排序失败时返回原始粗召回结果(截取到期望数量)
|
int limit = queryVectorBo.getMaxResults() != null ? queryVectorBo.getMaxResults() : 10;
|
||||||
int limit = Math.min(queryVectorBo.getMaxResults(), coarseResults.size());
|
return coarseResults.subList(0, Math.min(limit, coarseResults.size()));
|
||||||
return new ArrayList<>(coarseResults.subList(0, limit));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* RRF (Reciprocal Rank Fusion) 融合计算
|
||||||
|
*/
|
||||||
|
private List<KnowledgeRetrievalVo> calculateRRF(List<KnowledgeRetrievalVo> vectorList, List<KnowledgeRetrievalVo> keywordList, double alpha) {
|
||||||
|
Map<String, KnowledgeRetrievalVo> allMap = new LinkedHashMap<>();
|
||||||
|
Map<String, Double> vectorScores = new HashMap<>();
|
||||||
|
Map<String, Double> keywordScores = new HashMap<>();
|
||||||
|
|
||||||
|
int k = 60; // RRF 常数
|
||||||
|
|
||||||
|
for (int i = 0; i < vectorList.size(); i++) {
|
||||||
|
KnowledgeRetrievalVo vo = vectorList.get(i);
|
||||||
|
allMap.put(vo.getId(), vo);
|
||||||
|
vectorScores.put(vo.getId(), 1.0 / (k + i + 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < keywordList.size(); i++) {
|
||||||
|
KnowledgeRetrievalVo vo = keywordList.get(i);
|
||||||
|
if (!allMap.containsKey(vo.getId())) {
|
||||||
|
allMap.put(vo.getId(), vo);
|
||||||
|
}
|
||||||
|
keywordScores.put(vo.getId(), 1.0 / (k + i + 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
List<KnowledgeRetrievalVo> fusedResults = new ArrayList<>();
|
||||||
|
for (Map.Entry<String, KnowledgeRetrievalVo> entry : allMap.entrySet()) {
|
||||||
|
String id = entry.getKey();
|
||||||
|
double finalScore = (1 - alpha) * vectorScores.getOrDefault(id, 0.0) +
|
||||||
|
alpha * keywordScores.getOrDefault(id, 0.0);
|
||||||
|
|
||||||
|
KnowledgeRetrievalVo vo = entry.getValue();
|
||||||
|
vo.setScore(finalScore * 60.0); // 归一化缩放
|
||||||
|
fusedResults.add(vo);
|
||||||
|
}
|
||||||
|
|
||||||
|
fusedResults.sort((a, b) -> b.getScore().compareTo(a.getScore()));
|
||||||
|
return fusedResults;
|
||||||
|
}
|
||||||
|
|
||||||
|
private QueryVectorBo copyOf(QueryVectorBo original, int maxResults) {
|
||||||
|
QueryVectorBo copy = new QueryVectorBo();
|
||||||
|
copy.setQuery(original.getQuery());
|
||||||
|
copy.setKid(original.getKid());
|
||||||
|
copy.setMaxResults(maxResults);
|
||||||
|
copy.setVectorModelName(original.getVectorModelName());
|
||||||
|
copy.setEmbeddingModelName(original.getEmbeddingModelName());
|
||||||
|
copy.setApiKey(original.getApiKey());
|
||||||
|
copy.setBaseUrl(original.getBaseUrl());
|
||||||
|
return copy;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package org.ruoyi.service.vector;
|
|||||||
import org.ruoyi.common.core.exception.ServiceException;
|
import org.ruoyi.common.core.exception.ServiceException;
|
||||||
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||||
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
||||||
|
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@@ -17,6 +18,11 @@ public interface VectorStoreService {
|
|||||||
|
|
||||||
List<String> getQueryVector(QueryVectorBo queryVectorBo);
|
List<String> getQueryVector(QueryVectorBo queryVectorBo);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 带分数及元数据的检索(用于测试检索功能)
|
||||||
|
*/
|
||||||
|
List<KnowledgeRetrievalVo> search(QueryVectorBo queryVectorBo);
|
||||||
|
|
||||||
void createSchema(String kid, String embeddingModelName);
|
void createSchema(String kid, String embeddingModelName);
|
||||||
|
|
||||||
void removeById(String id, String modelName) throws ServiceException;
|
void removeById(String id, String modelName) throws ServiceException;
|
||||||
|
|||||||
@@ -37,6 +37,24 @@ public abstract class AbstractVectorStoreStrategy implements VectorStoreService
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 向量 L2 归一化 (单位化)
|
||||||
|
*/
|
||||||
|
protected static float[] normalize(float[] vector) {
|
||||||
|
if (vector == null) return null;
|
||||||
|
double sum = 0;
|
||||||
|
for (float v : vector) {
|
||||||
|
sum += v * v;
|
||||||
|
}
|
||||||
|
float norm = (float) Math.sqrt(sum);
|
||||||
|
if (norm > 1e-9) {
|
||||||
|
for (int i = 0; i < vector.length; i++) {
|
||||||
|
vector[i] /= norm;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return vector;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取向量模型
|
* 获取向量模型
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -19,7 +19,11 @@ import org.ruoyi.common.chat.service.chat.IChatModelService;
|
|||||||
import org.ruoyi.config.VectorStoreProperties;
|
import org.ruoyi.config.VectorStoreProperties;
|
||||||
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||||
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
||||||
|
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||||
import org.ruoyi.factory.EmbeddingModelFactory;
|
import org.ruoyi.factory.EmbeddingModelFactory;
|
||||||
|
import org.ruoyi.mapper.knowledge.KnowledgeAttachMapper;
|
||||||
|
import org.ruoyi.domain.entity.knowledge.KnowledgeAttach;
|
||||||
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -32,10 +36,14 @@ import java.util.stream.IntStream;
|
|||||||
@Component
|
@Component
|
||||||
public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||||
|
|
||||||
|
private final KnowledgeAttachMapper knowledgeAttachMapper;
|
||||||
|
|
||||||
public MilvusVectorStoreStrategy(VectorStoreProperties vectorStoreProperties,
|
public MilvusVectorStoreStrategy(VectorStoreProperties vectorStoreProperties,
|
||||||
IChatModelService chatModelService,
|
IChatModelService chatModelService,
|
||||||
EmbeddingModelFactory embeddingModelFactory) {
|
EmbeddingModelFactory embeddingModelFactory,
|
||||||
|
KnowledgeAttachMapper knowledgeAttachMapper) {
|
||||||
super(vectorStoreProperties, embeddingModelFactory, chatModelService);
|
super(vectorStoreProperties, embeddingModelFactory, chatModelService);
|
||||||
|
this.knowledgeAttachMapper = knowledgeAttachMapper;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 缓存不同集合与 autoFlush 配置的 Milvus 连接
|
// 缓存不同集合与 autoFlush 配置的 Milvus 连接
|
||||||
@@ -51,7 +59,7 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
|||||||
.collectionName(collectionName)
|
.collectionName(collectionName)
|
||||||
.dimension(dimension)
|
.dimension(dimension)
|
||||||
.indexType(IndexType.IVF_FLAT)
|
.indexType(IndexType.IVF_FLAT)
|
||||||
.metricType(MetricType.L2)
|
.metricType(MetricType.COSINE)
|
||||||
.autoFlushOnInsert(autoFlushOnInsert)
|
.autoFlushOnInsert(autoFlushOnInsert)
|
||||||
.idFieldName("id")
|
.idFieldName("id")
|
||||||
.textFieldName("text")
|
.textFieldName("text")
|
||||||
@@ -104,7 +112,10 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
|||||||
|
|
||||||
TextSegment textSegment = TextSegment.from(text, metadata);
|
TextSegment textSegment = TextSegment.from(text, metadata);
|
||||||
Embedding embedding = embeddingModel.embed(text).content();
|
Embedding embedding = embeddingModel.embed(text).content();
|
||||||
embeddingStore.add(embedding, textSegment);
|
// 单位化处理
|
||||||
|
float[] vector = embedding.vector();
|
||||||
|
normalize(vector);
|
||||||
|
embeddingStore.add(Embedding.from(vector), textSegment);
|
||||||
});
|
});
|
||||||
long endTime = System.currentTimeMillis();
|
long endTime = System.currentTimeMillis();
|
||||||
log.info("Milvus向量存储完成消耗时间:{}秒", (endTime - startTime) / 1000);
|
log.info("Milvus向量存储完成消耗时间:{}秒", (endTime - startTime) / 1000);
|
||||||
@@ -136,6 +147,55 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
|||||||
return resultList;
|
return resultList;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<KnowledgeRetrievalVo> search(QueryVectorBo queryVectorBo) {
|
||||||
|
int dimension = getModelDimension(queryVectorBo.getEmbeddingModelName());
|
||||||
|
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName());
|
||||||
|
|
||||||
|
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||||
|
// 查询向量单位化处理
|
||||||
|
float[] queryVector = queryEmbedding.vector();
|
||||||
|
normalize(queryVector);
|
||||||
|
|
||||||
|
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + queryVectorBo.getKid();
|
||||||
|
|
||||||
|
EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, dimension, true);
|
||||||
|
|
||||||
|
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
|
||||||
|
.queryEmbedding(Embedding.from(queryVector))
|
||||||
|
.maxResults(queryVectorBo.getMaxResults())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
List<EmbeddingMatch<TextSegment>> matches = embeddingStore.search(request).matches();
|
||||||
|
List<KnowledgeRetrievalVo> resultList = new ArrayList<>();
|
||||||
|
|
||||||
|
for (EmbeddingMatch<TextSegment> match : matches) {
|
||||||
|
TextSegment segment = match.embedded();
|
||||||
|
if (segment == null) continue;
|
||||||
|
|
||||||
|
String docId = segment.metadata().getString("docId");
|
||||||
|
String sourceName = "未知来源";
|
||||||
|
if (docId != null) {
|
||||||
|
KnowledgeAttach attach = knowledgeAttachMapper.selectOne(new LambdaQueryWrapper<KnowledgeAttach>()
|
||||||
|
.eq(KnowledgeAttach::getDocId, docId)
|
||||||
|
.last("limit 1"));
|
||||||
|
if (attach != null) {
|
||||||
|
sourceName = attach.getName();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提取内容、评分及来源
|
||||||
|
double score = match.score();
|
||||||
|
|
||||||
|
resultList.add(org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo.builder()
|
||||||
|
.content(segment.text())
|
||||||
|
.score(score)
|
||||||
|
.sourceName(sourceName)
|
||||||
|
.build());
|
||||||
|
}
|
||||||
|
return resultList;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@SneakyThrows
|
@SneakyThrows
|
||||||
public void removeById(String id, String modelName) {
|
public void removeById(String id, String modelName) {
|
||||||
|
|||||||
@@ -24,7 +24,11 @@ import org.ruoyi.common.core.exception.ServiceException;
|
|||||||
import org.ruoyi.config.VectorStoreProperties;
|
import org.ruoyi.config.VectorStoreProperties;
|
||||||
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||||
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
||||||
|
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||||
import org.ruoyi.factory.EmbeddingModelFactory;
|
import org.ruoyi.factory.EmbeddingModelFactory;
|
||||||
|
import org.ruoyi.domain.entity.knowledge.KnowledgeAttach;
|
||||||
|
import org.ruoyi.mapper.knowledge.KnowledgeAttachMapper;
|
||||||
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
import static io.qdrant.client.VectorInputFactory.vectorInput;
|
import static io.qdrant.client.VectorInputFactory.vectorInput;
|
||||||
@@ -47,10 +51,14 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
|||||||
private static final String METADATA_KID_KEY = "kid";
|
private static final String METADATA_KID_KEY = "kid";
|
||||||
private static final String METADATA_DOC_ID_KEY = "doc_id";
|
private static final String METADATA_DOC_ID_KEY = "doc_id";
|
||||||
|
|
||||||
|
private final KnowledgeAttachMapper knowledgeAttachMapper;
|
||||||
|
|
||||||
public QdrantVectorStoreStrategy(VectorStoreProperties vectorStoreProperties,
|
public QdrantVectorStoreStrategy(VectorStoreProperties vectorStoreProperties,
|
||||||
IChatModelService chatModelService,
|
IChatModelService chatModelService,
|
||||||
EmbeddingModelFactory embeddingModelFactory) {
|
EmbeddingModelFactory embeddingModelFactory,
|
||||||
|
KnowledgeAttachMapper knowledgeAttachMapper) {
|
||||||
super(vectorStoreProperties, embeddingModelFactory, chatModelService);
|
super(vectorStoreProperties, embeddingModelFactory, chatModelService);
|
||||||
|
this.knowledgeAttachMapper = knowledgeAttachMapper;
|
||||||
}
|
}
|
||||||
|
|
||||||
private EmbeddingStore<TextSegment> getQdrantStore(String collectionName) {
|
private EmbeddingStore<TextSegment> getQdrantStore(String collectionName) {
|
||||||
@@ -129,7 +137,10 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
|||||||
metadata.put(METADATA_DOC_ID_KEY, docId);
|
metadata.put(METADATA_DOC_ID_KEY, docId);
|
||||||
TextSegment textSegment = TextSegment.from(text, metadata);
|
TextSegment textSegment = TextSegment.from(text, metadata);
|
||||||
Embedding embedding = embeddingModel.embed(text).content();
|
Embedding embedding = embeddingModel.embed(text).content();
|
||||||
embeddingStore.add(embedding, textSegment);
|
// 单位化处理
|
||||||
|
float[] vector = embedding.vector();
|
||||||
|
normalize(vector);
|
||||||
|
embeddingStore.add(Embedding.from(vector), textSegment);
|
||||||
});
|
});
|
||||||
|
|
||||||
long endTime = System.currentTimeMillis();
|
long endTime = System.currentTimeMillis();
|
||||||
@@ -140,18 +151,22 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
|||||||
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
|
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
|
||||||
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName());
|
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName());
|
||||||
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||||
|
// 查询向量单位化处理
|
||||||
|
float[] queryVector = queryEmbedding.vector();
|
||||||
|
normalize(queryVector);
|
||||||
|
|
||||||
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + queryVectorBo.getKid();
|
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + queryVectorBo.getKid();
|
||||||
|
|
||||||
List<Float> vector = new ArrayList<>();
|
List<Float> vectorList = new ArrayList<>();
|
||||||
for (float f : queryEmbedding.vector()) {
|
for (float f : queryVector) {
|
||||||
vector.add(f);
|
vectorList.add(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
try (QdrantClient client = buildQdrantClient()) {
|
try (QdrantClient client = buildQdrantClient()) {
|
||||||
QueryPoints request = QueryPoints.newBuilder()
|
QueryPoints request = QueryPoints.newBuilder()
|
||||||
.setCollectionName(collectionName)
|
.setCollectionName(collectionName)
|
||||||
.setQuery(Query.newBuilder()
|
.setQuery(Query.newBuilder()
|
||||||
.setNearest(vectorInput(vector))
|
.setNearest(vectorInput(vectorList))
|
||||||
.build())
|
.build())
|
||||||
.setLimit(queryVectorBo.getMaxResults())
|
.setLimit(queryVectorBo.getMaxResults())
|
||||||
.setWithPayload(enable(true))
|
.setWithPayload(enable(true))
|
||||||
@@ -172,6 +187,69 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<KnowledgeRetrievalVo> search(QueryVectorBo queryVectorBo) {
|
||||||
|
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName());
|
||||||
|
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||||
|
// 查询向量单位化处理
|
||||||
|
float[] queryVector = queryEmbedding.vector();
|
||||||
|
normalize(queryVector);
|
||||||
|
|
||||||
|
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + queryVectorBo.getKid();
|
||||||
|
|
||||||
|
List<Float> vectorList = new ArrayList<>();
|
||||||
|
for (float f : queryVector) {
|
||||||
|
vectorList.add(f);
|
||||||
|
}
|
||||||
|
|
||||||
|
try (QdrantClient client = buildQdrantClient()) {
|
||||||
|
QueryPoints request = QueryPoints.newBuilder()
|
||||||
|
.setCollectionName(collectionName)
|
||||||
|
.setQuery(Query.newBuilder()
|
||||||
|
.setNearest(vectorInput(vectorList))
|
||||||
|
.build())
|
||||||
|
.setLimit(queryVectorBo.getMaxResults())
|
||||||
|
.setWithPayload(enable(true))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
List<ScoredPoint> results = client.queryAsync(request).get();
|
||||||
|
List<org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo> resultList = new ArrayList<>();
|
||||||
|
for (ScoredPoint point : results) {
|
||||||
|
String content = "";
|
||||||
|
JsonWithInt.Value textValue = point.getPayloadMap().get(TEXT_SEGMENT_KEY);
|
||||||
|
if (textValue != null && textValue.hasStringValue()) {
|
||||||
|
content = textValue.getStringValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
String docId = null;
|
||||||
|
JsonWithInt.Value docIdValue = point.getPayloadMap().get(METADATA_DOC_ID_KEY);
|
||||||
|
if (docIdValue != null && docIdValue.hasStringValue()) {
|
||||||
|
docId = docIdValue.getStringValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
String sourceName = "未知来源";
|
||||||
|
if (docId != null) {
|
||||||
|
KnowledgeAttach attach = knowledgeAttachMapper.selectOne(new LambdaQueryWrapper<KnowledgeAttach>()
|
||||||
|
.eq(KnowledgeAttach::getDocId, docId)
|
||||||
|
.last("limit 1"));
|
||||||
|
if (attach != null) {
|
||||||
|
sourceName = attach.getName();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resultList.add(org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo.builder()
|
||||||
|
.content(content)
|
||||||
|
.score((double) point.getScore())
|
||||||
|
.sourceName(sourceName)
|
||||||
|
.build());
|
||||||
|
}
|
||||||
|
return resultList;
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("Qdrant检索失败: {}", collectionName, e);
|
||||||
|
throw new ServiceException("Qdrant向量检索失败");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void removeById(String id, String modelName) {
|
public void removeById(String id, String modelName) {
|
||||||
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + id;
|
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + id;
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import lombok.RequiredArgsConstructor;
|
|||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||||
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
||||||
|
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||||
import org.ruoyi.factory.VectorStoreStrategyFactory;
|
import org.ruoyi.factory.VectorStoreStrategyFactory;
|
||||||
import org.ruoyi.service.vector.VectorStoreService;
|
import org.ruoyi.service.vector.VectorStoreService;
|
||||||
import org.springframework.context.annotation.Primary;
|
import org.springframework.context.annotation.Primary;
|
||||||
@@ -54,6 +55,13 @@ public class VectorStoreServiceImpl implements VectorStoreService {
|
|||||||
return strategy.getQueryVector(queryVectorBo);
|
return strategy.getQueryVector(queryVectorBo);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<KnowledgeRetrievalVo> search(QueryVectorBo queryVectorBo) {
|
||||||
|
log.info("执行测试搜索: kid={}, query={}", queryVectorBo.getKid(), queryVectorBo.getQuery());
|
||||||
|
VectorStoreService strategy = getCurrentStrategy();
|
||||||
|
return strategy.search(queryVectorBo);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void removeById(String id, String modelName) {
|
public void removeById(String id, String modelName) {
|
||||||
log.info("根据ID删除向量数据: id={}, modelName={}", id, modelName);
|
log.info("根据ID删除向量数据: id={}, modelName={}", id, modelName);
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import org.ruoyi.common.core.exception.ServiceException;
|
|||||||
import org.ruoyi.config.VectorStoreProperties;
|
import org.ruoyi.config.VectorStoreProperties;
|
||||||
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||||
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
||||||
|
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
|
||||||
import org.ruoyi.factory.EmbeddingModelFactory;
|
import org.ruoyi.factory.EmbeddingModelFactory;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
import io.weaviate.client.Config;
|
import io.weaviate.client.Config;
|
||||||
@@ -24,6 +25,9 @@ import io.weaviate.client.v1.graphql.model.GraphQLResponse;
|
|||||||
import io.weaviate.client.v1.schema.model.Property;
|
import io.weaviate.client.v1.schema.model.Property;
|
||||||
import io.weaviate.client.v1.schema.model.Schema;
|
import io.weaviate.client.v1.schema.model.Schema;
|
||||||
import io.weaviate.client.v1.schema.model.WeaviateClass;
|
import io.weaviate.client.v1.schema.model.WeaviateClass;
|
||||||
|
import org.ruoyi.domain.entity.knowledge.KnowledgeAttach;
|
||||||
|
import org.ruoyi.mapper.knowledge.KnowledgeAttachMapper;
|
||||||
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
@@ -40,11 +44,14 @@ import java.util.Map;
|
|||||||
public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||||
|
|
||||||
private WeaviateClient client;
|
private WeaviateClient client;
|
||||||
|
private final KnowledgeAttachMapper knowledgeAttachMapper;
|
||||||
|
|
||||||
public WeaviateVectorStoreStrategy(VectorStoreProperties vectorStoreProperties,
|
public WeaviateVectorStoreStrategy(VectorStoreProperties vectorStoreProperties,
|
||||||
IChatModelService chatModelService,
|
IChatModelService chatModelService,
|
||||||
EmbeddingModelFactory embeddingModelFactory) {
|
EmbeddingModelFactory embeddingModelFactory,
|
||||||
|
KnowledgeAttachMapper knowledgeAttachMapper) {
|
||||||
super(vectorStoreProperties, embeddingModelFactory,chatModelService);
|
super(vectorStoreProperties, embeddingModelFactory,chatModelService);
|
||||||
|
this.knowledgeAttachMapper = knowledgeAttachMapper;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -110,9 +117,12 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
|||||||
"kid", kid,
|
"kid", kid,
|
||||||
"docId", docId
|
"docId", docId
|
||||||
);
|
);
|
||||||
Float[] vector = toObjectArray(embedding.vector());
|
float[] vectorArray = embedding.vector();
|
||||||
|
normalize(vectorArray);
|
||||||
|
Float[] vector = toObjectArray(vectorArray);
|
||||||
|
|
||||||
client.data().creator()
|
client.data().creator()
|
||||||
.withClassName("LocalKnowledge" + kid)
|
.withClassName(vectorStoreProperties.getWeaviate().getClassname() + kid)
|
||||||
.withProperties(properties)
|
.withProperties(properties)
|
||||||
.withVector(vector)
|
.withVector(vector)
|
||||||
.run();
|
.run();
|
||||||
@@ -128,6 +138,9 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
|||||||
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName());
|
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName());
|
||||||
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||||
float[] vector = queryEmbedding.vector();
|
float[] vector = queryEmbedding.vector();
|
||||||
|
// 查询向量单位化处理
|
||||||
|
normalize(vector);
|
||||||
|
|
||||||
List<String> vectorStrings = new ArrayList<>();
|
List<String> vectorStrings = new ArrayList<>();
|
||||||
for (float v : vector) {
|
for (float v : vector) {
|
||||||
vectorStrings.add(String.valueOf(v));
|
vectorStrings.add(String.valueOf(v));
|
||||||
@@ -178,6 +191,77 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<KnowledgeRetrievalVo> search(QueryVectorBo queryVectorBo) {
|
||||||
|
createSchema(queryVectorBo.getKid(), queryVectorBo.getEmbeddingModelName());
|
||||||
|
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName());
|
||||||
|
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||||
|
float[] vector = queryEmbedding.vector();
|
||||||
|
// 查询向量单位化处理
|
||||||
|
normalize(vector);
|
||||||
|
List<String> vectorStrings = new ArrayList<>();
|
||||||
|
for (float v : vector) {
|
||||||
|
vectorStrings.add(String.valueOf(v));
|
||||||
|
}
|
||||||
|
String vectorStr = String.join(",", vectorStrings);
|
||||||
|
String className = vectorStoreProperties.getWeaviate().getClassname();
|
||||||
|
|
||||||
|
String graphQLQuery = String.format(
|
||||||
|
"{\n" +
|
||||||
|
" Get {\n" +
|
||||||
|
" %s(nearVector: {vector: [%s]} limit: %d) {\n" +
|
||||||
|
" text\n" +
|
||||||
|
" docId\n" +
|
||||||
|
" _additional {\n" +
|
||||||
|
" distance\n" +
|
||||||
|
" }\n" +
|
||||||
|
" }\n" +
|
||||||
|
" }\n" +
|
||||||
|
"}",
|
||||||
|
className + queryVectorBo.getKid(),
|
||||||
|
vectorStr,
|
||||||
|
queryVectorBo.getMaxResults()
|
||||||
|
);
|
||||||
|
|
||||||
|
Result<GraphQLResponse> result = client.graphQL().raw().withQuery(graphQLQuery).run();
|
||||||
|
List<org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo> resultList = new ArrayList<>();
|
||||||
|
|
||||||
|
if (result != null && !result.hasErrors()) {
|
||||||
|
Object data = result.getResult().getData();
|
||||||
|
JSONObject entries = new JSONObject(data);
|
||||||
|
Map<String, cn.hutool.json.JSONArray> entriesMap = entries.get("Get", Map.class);
|
||||||
|
cn.hutool.json.JSONArray objects = entriesMap.get(className + queryVectorBo.getKid());
|
||||||
|
|
||||||
|
for (Object obj : objects) {
|
||||||
|
Map<String, Object> map = (Map<String, Object>) obj;
|
||||||
|
String content = (String) map.get("text");
|
||||||
|
String docId = (String) map.get("docId");
|
||||||
|
|
||||||
|
Map<String, Object> additional = (Map<String, Object>) map.get("_additional");
|
||||||
|
Double distance = Double.valueOf(String.valueOf(additional.get("distance")));
|
||||||
|
// 转换距离为得分 (Weaviate 0 是最相近,1 是最远;余弦距离下 1-dist 即为相似度)
|
||||||
|
double score = 1.0 - distance;
|
||||||
|
|
||||||
|
String sourceName = "未知来源";
|
||||||
|
if (docId != null) {
|
||||||
|
KnowledgeAttach attach = knowledgeAttachMapper.selectOne(new LambdaQueryWrapper<KnowledgeAttach>()
|
||||||
|
.eq(KnowledgeAttach::getDocId, docId)
|
||||||
|
.last("limit 1"));
|
||||||
|
if (attach != null) {
|
||||||
|
sourceName = attach.getName();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resultList.add(org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo.builder()
|
||||||
|
.content(content)
|
||||||
|
.score(score)
|
||||||
|
.sourceName(sourceName)
|
||||||
|
.build());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return resultList;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@SneakyThrows
|
@SneakyThrows
|
||||||
public void removeById(String id, String modelName) {
|
public void removeById(String id, String modelName) {
|
||||||
|
|||||||
Reference in New Issue
Block a user