feat: 新增检索测试相关接口

- 实现向量 L2 归一化,统一 Milvus/Qdrant/Weaviate 检索评分为 [0, 1] 空间
This commit is contained in:
RobustH
2026-04-13 23:33:56 +08:00
parent 0fa25032a3
commit 06a63c377e
14 changed files with 548 additions and 36 deletions

View File

@@ -8,6 +8,7 @@ import jakarta.validation.constraints.*;
import cn.dev33.satoken.annotation.SaCheckPermission; import cn.dev33.satoken.annotation.SaCheckPermission;
import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo; import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo;
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo; import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
import org.ruoyi.service.knowledge.IKnowledgeFragmentService; import org.ruoyi.service.knowledge.IKnowledgeFragmentService;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import org.springframework.validation.annotation.Validated; import org.springframework.validation.annotation.Validated;
@@ -102,4 +103,12 @@ public class KnowledgeFragmentController extends BaseController {
@PathVariable Long[] ids) { @PathVariable Long[] ids) {
return toAjax(knowledgeFragmentService.deleteWithValidByIds(List.of(ids), true)); return toAjax(knowledgeFragmentService.deleteWithValidByIds(List.of(ids), true));
} }
/**
 * Retrieval test endpoint.
 * Delegates the request body to the fragment service and wraps the returned hits
 * in the standard response envelope.
 *
 * @param bo request body carrying the retrieval parameters
 * @return response envelope containing the retrieval results
 */
@PostMapping("/retrieval")
public R<List<KnowledgeRetrievalVo>> retrieval(@RequestBody KnowledgeFragmentBo bo) {
    List<KnowledgeRetrievalVo> hits = knowledgeFragmentService.retrieval(bo);
    return R.ok(hits);
}
} }

View File

@@ -49,5 +49,24 @@ public class KnowledgeFragmentBo extends BaseEntity {
*/ */
private String remark; private String remark;
/**
* 知识库ID
*/
private Long knowledgeId;
/**
* 检索内容
*/
private String query;
/**
* 返回条数
*/
private Integer topK;
/**
* 相似度阈值
*/
private Double threshold;
} }

View File

@@ -0,0 +1,35 @@
package org.ruoyi.domain.vo.knowledge;
import lombok.Builder;
import lombok.Data;
import java.io.Serial;
import java.io.Serializable;
/**
 * View object for a single knowledge retrieval test result.
 * Accessors and the builder are generated by Lombok ({@code @Data}, {@code @Builder}).
 *
 * @author RobustH
 */
@Data
@Builder
public class KnowledgeRetrievalVo implements Serializable {
@Serial
private static final long serialVersionUID = 1L;
/**
 * Content of the matched fragment.
 */
private String content;
/**
 * Similarity score of the match.
 */
private Double score;
/**
 * Name of the source document.
 */
private String sourceName;
}

View File

@@ -16,6 +16,15 @@ import dev.langchain4j.model.chat.StreamingChatModel;
import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
import dev.langchain4j.model.openai.OpenAiChatModel; import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.rag.AugmentationRequest;
import dev.langchain4j.rag.AugmentationResult;
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.aggregator.ContentAggregator;
import dev.langchain4j.rag.content.aggregator.DefaultContentAggregator;
import dev.langchain4j.rag.content.aggregator.ReRankingContentAggregator;
import dev.langchain4j.model.scoring.ScoringModel;
import dev.langchain4j.rag.query.Metadata;
import dev.langchain4j.service.tool.ToolProvider; import dev.langchain4j.service.tool.ToolProvider;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows; import lombok.SneakyThrows;
@@ -46,6 +55,8 @@ import org.ruoyi.mcp.service.core.ToolProviderFactory;
import org.ruoyi.service.chat.AbstractChatService; import org.ruoyi.service.chat.AbstractChatService;
import org.ruoyi.service.chat.IChatMessageService; import org.ruoyi.service.chat.IChatMessageService;
import org.ruoyi.service.chat.impl.memory.PersistentChatMemoryStore; import org.ruoyi.service.chat.impl.memory.PersistentChatMemoryStore;
import org.ruoyi.service.knowledge.retriever.CustomVectorRetriever;
import org.ruoyi.service.knowledge.rerank.ScoringModelFactory;
import org.ruoyi.service.knowledge.IKnowledgeInfoService; import org.ruoyi.service.knowledge.IKnowledgeInfoService;
import org.ruoyi.service.vector.VectorStoreService; import org.ruoyi.service.vector.VectorStoreService;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@@ -89,6 +100,8 @@ public class ChatServiceFacade implements IChatService {
private final ToolProviderFactory toolProviderFactory; private final ToolProviderFactory toolProviderFactory;
private final ScoringModelFactory scoringModelFactory;
/** /**
* 内存实例缓存,避免同一会话重复创建 * 内存实例缓存,避免同一会话重复创建
* Key: sessionId, Value: MessageWindowChatMemory实例 * Key: sessionId, Value: MessageWindowChatMemory实例
@@ -119,7 +132,9 @@ public class ChatServiceFacade implements IChatService {
// 2. 构建上下文消息列表 // 2. 构建上下文消息列表
List<ChatMessage> contextMessages = buildContextMessages(chatRequest); List<ChatMessage> contextMessages = buildContextMessages(chatRequest);
// 3. 处理特殊聊天模式(工作流、人机交互恢复、思考模式) // 注意buildContextMessages() 最后返回的列表中,最新的带有增强知识的 UserMessage 在最后。
// 对于有些模型API非langchain4j的代理它们可能不识别增强后的复杂文本取决于供应商适配度
// 但是通过标准流,它被解析为 String。
SseEmitter specialResult = handleSpecialChatModes(chatRequest, contextMessages, chatModelVo, emitter); SseEmitter specialResult = handleSpecialChatModes(chatRequest, contextMessages, chatModelVo, emitter);
if (specialResult != null) { if (specialResult != null) {
return specialResult; return specialResult;
@@ -346,39 +361,63 @@ public class ChatServiceFacade implements IChatService {
* @return 上下文消息列表 * @return 上下文消息列表
*/ */
private List<ChatMessage> buildContextMessages(ChatRequest chatRequest) { private List<ChatMessage> buildContextMessages(ChatRequest chatRequest) {
List<ChatMessage> messages = new ArrayList<>(); List<ChatMessage> messages = new ArrayList<>();
// 构建用户消息
// 初始化用户消息
UserMessage userMessage = UserMessage.userMessage(chatRequest.getContent()); UserMessage userMessage = UserMessage.userMessage(chatRequest.getContent());
messages.add(userMessage);
// 从向量库查询相关历史消息 // 使用 LangChain4j 的 RetrievalAugmentor 进行检索增强
if (chatRequest.getKnowledgeId() != null) { if (chatRequest.getKnowledgeId() != null) {
// 查询知识库信息
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(chatRequest.getKnowledgeId())); KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(chatRequest.getKnowledgeId()));
if (knowledgeInfoVo == null) { if (knowledgeInfoVo != null) {
log.warn("知识库信息不存在kid: {}", chatRequest.getKnowledgeId()); ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel());
return messages; if (chatModel != null) {
}
// 查询向量模型配置信息 // 1. 构建适配器Retriever
ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel()); CustomVectorRetriever retriever = new CustomVectorRetriever(
if (chatModel == null) { vectorStoreService, knowledgeInfoVo, chatModel);
log.warn("向量模型配置不存在,模型名称: {}", knowledgeInfoVo.getEmbeddingModel());
return messages;
}
// 构建向量查询参数 // 2. 获取和构建重排模型聚合器Aggregator
QueryVectorBo queryVectorBo = buildQueryVectorBo(chatRequest, knowledgeInfoVo, chatModel); // 假设已在 KnowledgeInfoVo 等加入 getRerankModelConfig/getRerankModel 等,这里演示通用逻辑
// 若无重排需求,使用 DefaultContentAggregator 或无 ScoringModel 的聚合器
ContentAggregator contentAggregator;
// TODO: 一旦实体类实现了重排模型的支持,此处可以从数据库读出:
// ChatModelVo scoringModelConfig = chatModelService.selectModelByName(knowledgeInfoVo.getRerankModel());
ChatModelVo scoringModelConfig = null; // 当前暂无对应配置字段
// 获取向量查询结果 ScoringModel scoringModel = scoringModelFactory.createScoringModel(scoringModelConfig);
List<String> nearestList = vectorStoreService.getQueryVector(queryVectorBo); if (scoringModel != null) {
for (String prompt : nearestList) { contentAggregator = ReRankingContentAggregator.builder()
// 知识库内容作为系统上下文添加 .scoringModel(scoringModel)
messages.add( new AiMessage(prompt)); // .maxResults(3) 这个数字将来从配置取
.build();
} else {
contentAggregator = new DefaultContentAggregator();
}
// 3. 构造流水线
RetrievalAugmentor augmentor = DefaultRetrievalAugmentor.builder()
.contentRetriever(retriever)
.contentAggregator(contentAggregator)
.build();
// 4. 执行 Augmentor 增强:将检索到的知识内容编织进 UserMessage 中
Metadata ragMetadata = Metadata.from(userMessage, chatRequest.getSessionId(), new ArrayList<>());
AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, ragMetadata);
AugmentationResult augmentedResult = augmentor.augment(augmentationRequest);
ChatMessage augmentedMessage = augmentedResult.chatMessage();
if (augmentedMessage instanceof UserMessage) {
userMessage = (UserMessage) augmentedMessage;
}
log.info("RAG 增强完成: UserMessage 已重构并附加上下文背景。");
}
} }
} }
// 从数据库查询历史对话消息 // 从数据库查询历史对话消息(历史消息应放在当前提问前)
if (chatRequest.getSessionId() != null) { if (chatRequest.getSessionId() != null) {
MessageWindowChatMemory memory = createChatMemory(chatRequest.getSessionId()); MessageWindowChatMemory memory = createChatMemory(chatRequest.getSessionId());
if (memory != null) { if (memory != null) {
@@ -390,6 +429,9 @@ public class ChatServiceFacade implements IChatService {
} }
} }
// 注入本次用户提问(经过 RAG 增强后的 UserMessage
messages.add(userMessage);
return messages; return messages;
} }

View File

@@ -4,6 +4,7 @@ import org.ruoyi.common.mybatis.core.page.TableDataInfo;
import org.ruoyi.common.mybatis.core.page.PageQuery; import org.ruoyi.common.mybatis.core.page.PageQuery;
import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo; import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo;
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo; import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
import java.util.Collection; import java.util.Collection;
import java.util.List; import java.util.List;
@@ -65,4 +66,12 @@ public interface IKnowledgeFragmentService {
* @return 是否删除成功 * @return 是否删除成功
*/ */
Boolean deleteWithValidByIds(Collection<Long> ids, Boolean isValid); Boolean deleteWithValidByIds(Collection<Long> ids, Boolean isValid);
/**
 * Retrieval test.
 *
 * @param bo retrieval parameters
 * @return retrieval results
 */
List<KnowledgeRetrievalVo> retrieval(KnowledgeFragmentBo bo);
} }

View File

@@ -13,8 +13,17 @@ import org.ruoyi.domain.bo.knowledge.KnowledgeFragmentBo;
import org.ruoyi.domain.entity.knowledge.KnowledgeFragment; import org.ruoyi.domain.entity.knowledge.KnowledgeFragment;
import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo; import org.ruoyi.domain.vo.knowledge.KnowledgeFragmentVo;
import org.ruoyi.mapper.knowledge.KnowledgeFragmentMapper; import org.ruoyi.mapper.knowledge.KnowledgeFragmentMapper;
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo;
import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo;
import org.ruoyi.domain.bo.vector.QueryVectorBo;
import org.ruoyi.service.knowledge.IKnowledgeFragmentService; import org.ruoyi.service.knowledge.IKnowledgeFragmentService;
import org.ruoyi.service.knowledge.IKnowledgeInfoService;
import org.ruoyi.common.chat.service.chat.IChatModelService;
import org.ruoyi.service.vector.VectorStoreService;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.stream.Collectors;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@@ -32,6 +41,9 @@ import java.util.Collection;
public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService { public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
private final KnowledgeFragmentMapper baseMapper; private final KnowledgeFragmentMapper baseMapper;
private final IKnowledgeInfoService knowledgeInfoService;
private final IChatModelService chatModelService;
private final VectorStoreService vectorStoreService;
/** /**
* 查询知识片段 * 查询知识片段
@@ -131,4 +143,45 @@ public class KnowledgeFragmentServiceImpl implements IKnowledgeFragmentService {
} }
return baseMapper.deleteByIds(ids) > 0; return baseMapper.deleteByIds(ids) > 0;
} }
/**
 * Core implementation of the retrieval test.
 * <p>
 * Looks up the knowledge base and its embedding model configuration, runs a scored
 * vector search, and keeps only the hits whose score reaches the requested threshold.
 * An empty list is returned when the request is incomplete or the knowledge base /
 * embedding model configuration cannot be found.
 *
 * @param bo retrieval parameters; {@code knowledgeId} and {@code query} are required,
 *           {@code topK} and {@code threshold} are optional
 * @return hits with a score greater than or equal to the threshold (0.0 when not given)
 */
@Override
public List<KnowledgeRetrievalVo> retrieval(KnowledgeFragmentBo bo) {
    if (bo == null || bo.getKnowledgeId() == null || StringUtils.isBlank(bo.getQuery())) {
        return new ArrayList<>();
    }

    // 1. Load the knowledge base and the embedding model configuration.
    KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(bo.getKnowledgeId());
    if (knowledgeInfoVo == null) {
        return new ArrayList<>();
    }
    ChatModelVo chatModel = chatModelService.selectModelByName(knowledgeInfoVo.getEmbeddingModel());
    if (chatModel == null) {
        log.warn("未找到对应的向量模型配置: {}", knowledgeInfoVo.getEmbeddingModel());
        return new ArrayList<>();
    }

    // 2. Build the vector search parameters.
    // A missing or non-positive topK falls back to the knowledge base's own limit,
    // because a result limit <= 0 is not a meaningful search request.
    Integer requestedTopK = bo.getTopK();
    boolean useRequestedTopK = requestedTopK != null && requestedTopK > 0;
    QueryVectorBo queryVectorBo = new QueryVectorBo();
    queryVectorBo.setQuery(bo.getQuery());
    queryVectorBo.setKid(String.valueOf(bo.getKnowledgeId()));
    queryVectorBo.setMaxResults(useRequestedTopK ? requestedTopK : knowledgeInfoVo.getRetrieveLimit());
    queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel());
    queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModel());
    queryVectorBo.setApiKey(chatModel.getApiKey());
    queryVectorBo.setBaseUrl(chatModel.getApiHost());

    // 3. Run the search against the vector store.
    List<KnowledgeRetrievalVo> allResults = vectorStoreService.search(queryVectorBo);
    if (allResults == null) {
        return new ArrayList<>();
    }

    // 4. Filter by threshold (scores are presumably in [0, 1] - verify per vector store).
    // The score is a boxed Double, so entries without a score are dropped rather than
    // triggering a NullPointerException on unboxing.
    double threshold = bo.getThreshold() != null ? bo.getThreshold() : 0.0;
    return allResults.stream()
            .filter(res -> res != null && res.getScore() != null && res.getScore() >= threshold)
            .collect(Collectors.toList());
}
} }

View File

@@ -0,0 +1,37 @@
package org.ruoyi.service.knowledge.rerank;
import dev.langchain4j.model.scoring.ScoringModel;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo;
import org.springframework.stereotype.Component;
/**
 * Factory for re-ranking (scoring) model providers.
 * Intended as the extension point for vendors that offer re-ranking models.
 *
 * @author RobustH
 */
@Slf4j
@Component
public class ScoringModelFactory {

    /**
     * Creates the re-ranking model described by the given configuration.
     * <p>
     * No provider is wired in yet, so this currently always returns {@code null};
     * when a configuration is supplied its provider code is logged first.
     *
     * @param rerankModelConfig re-ranking model configuration, may be {@code null}
     * @return currently always {@code null}
     */
    public ScoringModel createScoringModel(ChatModelVo rerankModelConfig) {
        if (rerankModelConfig != null) {
            log.info("初始化重排模型,供应商代码: {}", rerankModelConfig.getProviderCode());
            // TODO: instantiate the concrete ScoringModel for the provider here
            // (via a switch or reflection). Until then no re-ranking backend is loaded.
        }
        return null;
    }
}

View File

@@ -0,0 +1,54 @@
package org.ruoyi.service.knowledge.retriever;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Query;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.chat.domain.vo.chat.ChatModelVo;
import org.ruoyi.domain.bo.vector.QueryVectorBo;
import org.ruoyi.domain.vo.knowledge.KnowledgeInfoVo;
import org.ruoyi.service.vector.VectorStoreService;
import java.util.List;
import java.util.stream.Collectors;
/**
 * Custom vector retriever that adapts the existing {@link VectorStoreService}
 * to the LangChain4j {@link ContentRetriever} interface.
 *
 * @author RobustH
 */
@Slf4j
@RequiredArgsConstructor
public class CustomVectorRetriever implements ContentRetriever {
    private final VectorStoreService vectorStoreService;
    private final KnowledgeInfoVo knowledgeInfoVo;
    private final ChatModelVo chatModelVo;

    /**
     * Runs a vector search for the query text and wraps every returned text as a {@link Content}.
     *
     * @param query the query whose text is searched for
     * @return the retrieved contents; empty when the store returns nothing usable
     */
    @Override
    public List<Content> retrieve(Query query) {
        log.info("执行自定义向量检索,关键字: {}", query.text());

        // Build the internal search parameters from the knowledge base and model configuration.
        QueryVectorBo queryVectorBo = new QueryVectorBo();
        queryVectorBo.setQuery(query.text());
        queryVectorBo.setKid(String.valueOf(knowledgeInfoVo.getId()));
        queryVectorBo.setApiKey(chatModelVo.getApiKey());
        queryVectorBo.setBaseUrl(chatModelVo.getApiHost());
        queryVectorBo.setVectorModelName(knowledgeInfoVo.getVectorModel());
        queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModel());
        // If re-ranking is enabled, this limit should be enlarged and the aggregator
        // left to truncate the final result.
        queryVectorBo.setMaxResults(knowledgeInfoVo.getRetrieveLimit());

        // Delegate to the configured vector store strategy.
        List<String> nearestList = vectorStoreService.getQueryVector(queryVectorBo);
        List<String> candidates = nearestList != null ? nearestList : List.of();

        // TextSegment.from rejects null/blank text, so such rows are skipped instead of
        // failing the whole retrieval.
        return candidates.stream()
                .filter(text -> text != null && !text.isBlank())
                .map(text -> Content.from(TextSegment.from(text)))
                .collect(Collectors.toList());
    }
}

View File

@@ -3,6 +3,7 @@ package org.ruoyi.service.vector;
import org.ruoyi.common.core.exception.ServiceException; import org.ruoyi.common.core.exception.ServiceException;
import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.domain.bo.vector.QueryVectorBo;
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo; import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
import java.util.List; import java.util.List;
@@ -17,6 +18,11 @@ public interface VectorStoreService {
List<String> getQueryVector(QueryVectorBo queryVectorBo); List<String> getQueryVector(QueryVectorBo queryVectorBo);
/**
 * Search that returns scores and metadata (used by the retrieval test feature).
 */
List<KnowledgeRetrievalVo> search(QueryVectorBo queryVectorBo);
void createSchema(String kid, String embeddingModelName); void createSchema(String kid, String embeddingModelName);
void removeById(String id, String modelName) throws ServiceException; void removeById(String id, String modelName) throws ServiceException;

View File

@@ -37,6 +37,24 @@ public abstract class AbstractVectorStoreStrategy implements VectorStoreService
return result; return result;
} }
/**
 * L2-normalizes a vector in place (scales it to unit length).
 * <p>
 * The squared sum is accumulated in double precision: squaring in float overflows to
 * Infinity for large components, which would turn the whole vector into zeros.
 * Vectors whose norm is not greater than 1e-9 (e.g. the zero vector) are left unchanged.
 *
 * @param vector the vector to normalize; modified in place, may be {@code null}
 * @return the same array instance that was passed in, or {@code null} for {@code null} input
 */
protected static float[] normalize(float[] vector) {
    if (vector == null) {
        return null;
    }
    double sumOfSquares = 0;
    for (float v : vector) {
        // Widen before multiplying so the product cannot overflow the float range.
        sumOfSquares += (double) v * v;
    }
    double norm = Math.sqrt(sumOfSquares);
    if (norm > 1e-9) {
        for (int i = 0; i < vector.length; i++) {
            vector[i] = (float) (vector[i] / norm);
        }
    }
    return vector;
}
/** /**
* 获取向量模型 * 获取向量模型
*/ */

View File

@@ -19,7 +19,11 @@ import org.ruoyi.common.chat.service.chat.IChatModelService;
import org.ruoyi.config.VectorStoreProperties; import org.ruoyi.config.VectorStoreProperties;
import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.domain.bo.vector.QueryVectorBo;
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo; import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
import org.ruoyi.factory.EmbeddingModelFactory; import org.ruoyi.factory.EmbeddingModelFactory;
import org.ruoyi.mapper.knowledge.KnowledgeAttachMapper;
import org.ruoyi.domain.entity.knowledge.KnowledgeAttach;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import java.util.ArrayList; import java.util.ArrayList;
@@ -32,10 +36,14 @@ import java.util.stream.IntStream;
@Component @Component
public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy { public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
private final KnowledgeAttachMapper knowledgeAttachMapper;
public MilvusVectorStoreStrategy(VectorStoreProperties vectorStoreProperties, public MilvusVectorStoreStrategy(VectorStoreProperties vectorStoreProperties,
IChatModelService chatModelService, IChatModelService chatModelService,
EmbeddingModelFactory embeddingModelFactory) { EmbeddingModelFactory embeddingModelFactory,
KnowledgeAttachMapper knowledgeAttachMapper) {
super(vectorStoreProperties, embeddingModelFactory, chatModelService); super(vectorStoreProperties, embeddingModelFactory, chatModelService);
this.knowledgeAttachMapper = knowledgeAttachMapper;
} }
// 缓存不同集合与 autoFlush 配置的 Milvus 连接 // 缓存不同集合与 autoFlush 配置的 Milvus 连接
@@ -51,7 +59,7 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
.collectionName(collectionName) .collectionName(collectionName)
.dimension(dimension) .dimension(dimension)
.indexType(IndexType.IVF_FLAT) .indexType(IndexType.IVF_FLAT)
.metricType(MetricType.L2) .metricType(MetricType.COSINE)
.autoFlushOnInsert(autoFlushOnInsert) .autoFlushOnInsert(autoFlushOnInsert)
.idFieldName("id") .idFieldName("id")
.textFieldName("text") .textFieldName("text")
@@ -104,7 +112,10 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
TextSegment textSegment = TextSegment.from(text, metadata); TextSegment textSegment = TextSegment.from(text, metadata);
Embedding embedding = embeddingModel.embed(text).content(); Embedding embedding = embeddingModel.embed(text).content();
embeddingStore.add(embedding, textSegment); // 单位化处理
float[] vector = embedding.vector();
normalize(vector);
embeddingStore.add(Embedding.from(vector), textSegment);
}); });
long endTime = System.currentTimeMillis(); long endTime = System.currentTimeMillis();
log.info("Milvus向量存储完成消耗时间{}秒", (endTime - startTime) / 1000); log.info("Milvus向量存储完成消耗时间{}秒", (endTime - startTime) / 1000);
@@ -136,6 +147,55 @@ public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
return resultList; return resultList;
} }
/**
 * Scored search against the Milvus collection of a knowledge base.
 * <p>
 * The query text is embedded, the embedding is normalized, and each match is returned
 * with its text, score and the name of the attachment it came from.
 *
 * @param queryVectorBo search parameters (query text, kid, embedding model name, max results)
 * @return one result per match that carries an embedded text segment
 */
@Override
public List<KnowledgeRetrievalVo> search(QueryVectorBo queryVectorBo) {
    int dimension = getModelDimension(queryVectorBo.getEmbeddingModelName());
    EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName());
    Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();

    // Normalize the query vector (normalize works in place on the array).
    float[] queryVector = queryEmbedding.vector();
    normalize(queryVector);

    String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + queryVectorBo.getKid();
    EmbeddingStore<TextSegment> embeddingStore = getMilvusStore(collectionName, dimension, true);

    EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
            .queryEmbedding(Embedding.from(queryVector))
            .maxResults(queryVectorBo.getMaxResults())
            .build();
    List<EmbeddingMatch<TextSegment>> matches = embeddingStore.search(request).matches();

    // Several matches commonly share a docId, so each document's name is looked up
    // at most once per search instead of once per match.
    java.util.Map<String, String> sourceNameByDocId = new java.util.HashMap<>();
    List<KnowledgeRetrievalVo> resultList = new ArrayList<>();
    for (EmbeddingMatch<TextSegment> match : matches) {
        TextSegment segment = match.embedded();
        if (segment == null) {
            continue;
        }

        String docId = segment.metadata().getString("docId");
        String sourceName = "未知来源";
        if (docId != null) {
            sourceName = sourceNameByDocId.computeIfAbsent(docId, id -> {
                KnowledgeAttach attach = knowledgeAttachMapper.selectOne(new LambdaQueryWrapper<KnowledgeAttach>()
                        .eq(KnowledgeAttach::getDocId, id)
                        .last("limit 1"));
                return attach != null ? attach.getName() : "未知来源";
            });
        }

        // Collect content, score and source.
        double score = match.score();
        resultList.add(KnowledgeRetrievalVo.builder()
                .content(segment.text())
                .score(score)
                .sourceName(sourceName)
                .build());
    }
    return resultList;
}
@Override @Override
@SneakyThrows @SneakyThrows
public void removeById(String id, String modelName) { public void removeById(String id, String modelName) {

View File

@@ -24,7 +24,11 @@ import org.ruoyi.common.core.exception.ServiceException;
import org.ruoyi.config.VectorStoreProperties; import org.ruoyi.config.VectorStoreProperties;
import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.domain.bo.vector.QueryVectorBo;
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo; import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
import org.ruoyi.factory.EmbeddingModelFactory; import org.ruoyi.factory.EmbeddingModelFactory;
import org.ruoyi.domain.entity.knowledge.KnowledgeAttach;
import org.ruoyi.mapper.knowledge.KnowledgeAttachMapper;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import static io.qdrant.client.VectorInputFactory.vectorInput; import static io.qdrant.client.VectorInputFactory.vectorInput;
@@ -47,10 +51,14 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy {
private static final String METADATA_KID_KEY = "kid"; private static final String METADATA_KID_KEY = "kid";
private static final String METADATA_DOC_ID_KEY = "doc_id"; private static final String METADATA_DOC_ID_KEY = "doc_id";
private final KnowledgeAttachMapper knowledgeAttachMapper;
public QdrantVectorStoreStrategy(VectorStoreProperties vectorStoreProperties, public QdrantVectorStoreStrategy(VectorStoreProperties vectorStoreProperties,
IChatModelService chatModelService, IChatModelService chatModelService,
EmbeddingModelFactory embeddingModelFactory) { EmbeddingModelFactory embeddingModelFactory,
KnowledgeAttachMapper knowledgeAttachMapper) {
super(vectorStoreProperties, embeddingModelFactory, chatModelService); super(vectorStoreProperties, embeddingModelFactory, chatModelService);
this.knowledgeAttachMapper = knowledgeAttachMapper;
} }
private EmbeddingStore<TextSegment> getQdrantStore(String collectionName) { private EmbeddingStore<TextSegment> getQdrantStore(String collectionName) {
@@ -129,7 +137,10 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy {
metadata.put(METADATA_DOC_ID_KEY, docId); metadata.put(METADATA_DOC_ID_KEY, docId);
TextSegment textSegment = TextSegment.from(text, metadata); TextSegment textSegment = TextSegment.from(text, metadata);
Embedding embedding = embeddingModel.embed(text).content(); Embedding embedding = embeddingModel.embed(text).content();
embeddingStore.add(embedding, textSegment); // 单位化处理
float[] vector = embedding.vector();
normalize(vector);
embeddingStore.add(Embedding.from(vector), textSegment);
}); });
long endTime = System.currentTimeMillis(); long endTime = System.currentTimeMillis();
@@ -140,18 +151,22 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy {
public List<String> getQueryVector(QueryVectorBo queryVectorBo) { public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName()); EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName());
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content(); Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
// 查询向量单位化处理
float[] queryVector = queryEmbedding.vector();
normalize(queryVector);
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + queryVectorBo.getKid(); String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + queryVectorBo.getKid();
List<Float> vector = new ArrayList<>(); List<Float> vectorList = new ArrayList<>();
for (float f : queryEmbedding.vector()) { for (float f : queryVector) {
vector.add(f); vectorList.add(f);
} }
try (QdrantClient client = buildQdrantClient()) { try (QdrantClient client = buildQdrantClient()) {
QueryPoints request = QueryPoints.newBuilder() QueryPoints request = QueryPoints.newBuilder()
.setCollectionName(collectionName) .setCollectionName(collectionName)
.setQuery(Query.newBuilder() .setQuery(Query.newBuilder()
.setNearest(vectorInput(vector)) .setNearest(vectorInput(vectorList))
.build()) .build())
.setLimit(queryVectorBo.getMaxResults()) .setLimit(queryVectorBo.getMaxResults())
.setWithPayload(enable(true)) .setWithPayload(enable(true))
@@ -172,6 +187,69 @@ public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy {
} }
} }
/**
 * Scored search against the Qdrant collection of a knowledge base.
 * <p>
 * The query text is embedded, the embedding is normalized, and each scored point is
 * returned with its text payload, score and the name of the attachment it came from.
 *
 * @param queryVectorBo search parameters (query text, kid, embedding model name, max results)
 * @return one result per scored point returned by Qdrant
 * @throws ServiceException if the search fails for any reason
 */
@Override
public List<KnowledgeRetrievalVo> search(QueryVectorBo queryVectorBo) {
    EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName());
    Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();

    // Normalize the query vector (normalize works in place on the array).
    float[] queryVector = queryEmbedding.vector();
    normalize(queryVector);

    String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + queryVectorBo.getKid();

    List<Float> vectorList = new ArrayList<>(queryVector.length);
    for (float f : queryVector) {
        vectorList.add(f);
    }

    try (QdrantClient client = buildQdrantClient()) {
        QueryPoints request = QueryPoints.newBuilder()
                .setCollectionName(collectionName)
                .setQuery(Query.newBuilder()
                        .setNearest(vectorInput(vectorList))
                        .build())
                .setLimit(queryVectorBo.getMaxResults())
                .setWithPayload(enable(true))
                .build();

        List<ScoredPoint> results = client.queryAsync(request).get();

        // Several points commonly share a doc id, so each document's name is looked up
        // at most once per search instead of once per point.
        java.util.Map<String, String> sourceNameByDocId = new java.util.HashMap<>();
        List<KnowledgeRetrievalVo> resultList = new ArrayList<>();
        for (ScoredPoint point : results) {
            String content = "";
            JsonWithInt.Value textValue = point.getPayloadMap().get(TEXT_SEGMENT_KEY);
            if (textValue != null && textValue.hasStringValue()) {
                content = textValue.getStringValue();
            }

            String docId = null;
            JsonWithInt.Value docIdValue = point.getPayloadMap().get(METADATA_DOC_ID_KEY);
            if (docIdValue != null && docIdValue.hasStringValue()) {
                docId = docIdValue.getStringValue();
            }

            String sourceName = "未知来源";
            if (docId != null) {
                sourceName = sourceNameByDocId.computeIfAbsent(docId, id -> {
                    KnowledgeAttach attach = knowledgeAttachMapper.selectOne(new LambdaQueryWrapper<KnowledgeAttach>()
                            .eq(KnowledgeAttach::getDocId, id)
                            .last("limit 1"));
                    return attach != null ? attach.getName() : "未知来源";
                });
            }

            resultList.add(KnowledgeRetrievalVo.builder()
                    .content(content)
                    .score((double) point.getScore())
                    .sourceName(sourceName)
                    .build());
        }
        return resultList;
    } catch (Exception e) {
        // The broad catch also receives InterruptedException from Future.get();
        // restore the interrupt flag so callers can still observe the interruption.
        if (e instanceof InterruptedException) {
            Thread.currentThread().interrupt();
        }
        log.error("Qdrant检索失败: {}", collectionName, e);
        throw new ServiceException("Qdrant向量检索失败");
    }
}
@Override @Override
public void removeById(String id, String modelName) { public void removeById(String id, String modelName) {
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + id; String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + id;

View File

@@ -4,6 +4,7 @@ import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.domain.bo.vector.QueryVectorBo;
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo; import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
import org.ruoyi.factory.VectorStoreStrategyFactory; import org.ruoyi.factory.VectorStoreStrategyFactory;
import org.ruoyi.service.vector.VectorStoreService; import org.ruoyi.service.vector.VectorStoreService;
import org.springframework.context.annotation.Primary; import org.springframework.context.annotation.Primary;
@@ -54,6 +55,13 @@ public class VectorStoreServiceImpl implements VectorStoreService {
return strategy.getQueryVector(queryVectorBo); return strategy.getQueryVector(queryVectorBo);
} }
/**
 * Logs the test search and forwards it to the currently selected vector store strategy.
 *
 * @param queryVectorBo search parameters
 * @return the results produced by the current strategy
 */
@Override
public List<KnowledgeRetrievalVo> search(QueryVectorBo queryVectorBo) {
    log.info("执行测试搜索: kid={}, query={}", queryVectorBo.getKid(), queryVectorBo.getQuery());
    return getCurrentStrategy().search(queryVectorBo);
}
@Override @Override
public void removeById(String id, String modelName) { public void removeById(String id, String modelName) {
log.info("根据ID删除向量数据: id={}, modelName={}", id, modelName); log.info("根据ID删除向量数据: id={}, modelName={}", id, modelName);

View File

@@ -12,6 +12,7 @@ import org.ruoyi.common.core.exception.ServiceException;
import org.ruoyi.config.VectorStoreProperties; import org.ruoyi.config.VectorStoreProperties;
import org.ruoyi.domain.bo.vector.QueryVectorBo; import org.ruoyi.domain.bo.vector.QueryVectorBo;
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo; import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
import org.ruoyi.domain.vo.knowledge.KnowledgeRetrievalVo;
import org.ruoyi.factory.EmbeddingModelFactory; import org.ruoyi.factory.EmbeddingModelFactory;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import io.weaviate.client.Config; import io.weaviate.client.Config;
@@ -24,6 +25,9 @@ import io.weaviate.client.v1.graphql.model.GraphQLResponse;
import io.weaviate.client.v1.schema.model.Property; import io.weaviate.client.v1.schema.model.Property;
import io.weaviate.client.v1.schema.model.Schema; import io.weaviate.client.v1.schema.model.Schema;
import io.weaviate.client.v1.schema.model.WeaviateClass; import io.weaviate.client.v1.schema.model.WeaviateClass;
import org.ruoyi.domain.entity.knowledge.KnowledgeAttach;
import org.ruoyi.mapper.knowledge.KnowledgeAttachMapper;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
@@ -40,11 +44,14 @@ import java.util.Map;
public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy { public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
private WeaviateClient client; private WeaviateClient client;
private final KnowledgeAttachMapper knowledgeAttachMapper;
public WeaviateVectorStoreStrategy(VectorStoreProperties vectorStoreProperties, public WeaviateVectorStoreStrategy(VectorStoreProperties vectorStoreProperties,
IChatModelService chatModelService, IChatModelService chatModelService,
EmbeddingModelFactory embeddingModelFactory) { EmbeddingModelFactory embeddingModelFactory,
KnowledgeAttachMapper knowledgeAttachMapper) {
super(vectorStoreProperties, embeddingModelFactory,chatModelService); super(vectorStoreProperties, embeddingModelFactory,chatModelService);
this.knowledgeAttachMapper = knowledgeAttachMapper;
} }
@Override @Override
@@ -110,9 +117,12 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
"kid", kid, "kid", kid,
"docId", docId "docId", docId
); );
Float[] vector = toObjectArray(embedding.vector()); float[] vectorArray = embedding.vector();
normalize(vectorArray);
Float[] vector = toObjectArray(vectorArray);
client.data().creator() client.data().creator()
.withClassName("LocalKnowledge" + kid) .withClassName(vectorStoreProperties.getWeaviate().getClassname() + kid)
.withProperties(properties) .withProperties(properties)
.withVector(vector) .withVector(vector)
.run(); .run();
@@ -128,6 +138,9 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName()); EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName());
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content(); Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
float[] vector = queryEmbedding.vector(); float[] vector = queryEmbedding.vector();
// 查询向量单位化处理
normalize(vector);
List<String> vectorStrings = new ArrayList<>(); List<String> vectorStrings = new ArrayList<>();
for (float v : vector) { for (float v : vector) {
vectorStrings.add(String.valueOf(v)); vectorStrings.add(String.valueOf(v));
@@ -178,6 +191,77 @@ public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
} }
} }
/**
 * Runs a retrieval test against the Weaviate class of the given knowledge base.
 *
 * <p>The query text is embedded, the vector is passed through {@code normalize},
 * and a raw GraphQL {@code nearVector} query is issued. Each hit is turned into a
 * {@link KnowledgeRetrievalVo} carrying the fragment text, a similarity score and
 * the name of the attachment the fragment came from.
 *
 * @param queryVectorBo query parameters; reads kid, embedding model name, query
 *                      text and max results
 * @return the scored hits in the order Weaviate returned them; an empty list when
 *         the request reports errors or the response carries no hits
 */
@Override
@SuppressWarnings("unchecked") // the parsed GraphQL payload is only available as untyped maps
public List<KnowledgeRetrievalVo> search(QueryVectorBo queryVectorBo) {
    createSchema(queryVectorBo.getKid(), queryVectorBo.getEmbeddingModelName());
    EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName());
    Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
    float[] vector = queryEmbedding.vector();
    // Normalize the query vector to unit length before searching.
    normalize(vector);

    // Render the vector as a comma-separated list for the GraphQL literal.
    StringBuilder vectorStr = new StringBuilder();
    for (int i = 0; i < vector.length; i++) {
        if (i > 0) {
            vectorStr.append(',');
        }
        vectorStr.append(vector[i]);
    }

    String className = vectorStoreProperties.getWeaviate().getClassname() + queryVectorBo.getKid();
    String graphQLQuery = String.format(
            "{\n" +
            "  Get {\n" +
            "    %s(nearVector: {vector: [%s]} limit: %d) {\n" +
            "      text\n" +
            "      docId\n" +
            "      _additional {\n" +
            "        distance\n" +
            "      }\n" +
            "    }\n" +
            "  }\n" +
            "}",
            className,
            vectorStr,
            queryVectorBo.getMaxResults()
    );

    Result<GraphQLResponse> result = client.graphQL().raw().withQuery(graphQLQuery).run();
    List<KnowledgeRetrievalVo> resultList = new ArrayList<>();
    if (result == null || result.hasErrors() || result.getResult() == null) {
        return resultList;
    }
    Object data = result.getResult().getData();
    if (data == null) {
        return resultList;
    }

    JSONObject entries = new JSONObject(data);
    Map<String, cn.hutool.json.JSONArray> entriesMap = entries.get("Get", Map.class);
    cn.hutool.json.JSONArray objects = entriesMap == null ? null : entriesMap.get(className);
    if (objects == null) {
        // The response has no array for this class: nothing to report.
        return resultList;
    }

    // Several fragments usually share one docId; remember resolved names so each
    // document is looked up at most once per search.
    Map<String, String> sourceNames = new java.util.HashMap<>();
    for (Object obj : objects) {
        Map<String, Object> map = (Map<String, Object>) obj;
        String content = (String) map.get("text");
        String docId = (String) map.get("docId");
        Map<String, Object> additional = (Map<String, Object>) map.get("_additional");
        Object rawDistance = additional == null ? null : additional.get("distance");
        if (rawDistance == null) {
            // Without a distance the hit cannot be scored; skip it instead of failing the whole search.
            continue;
        }
        double distance = Double.parseDouble(String.valueOf(rawDistance));

        // Convert distance to a similarity score. Weaviate's cosine distance lies in
        // [0, 2] (0 = identical), so 1 - distance can fall below 0, and rounding can
        // push it slightly above 1; clamp to keep the score inside [0, 1].
        // NOTE(review): assumes the class is created with the cosine metric - confirm in createSchema.
        double score = Math.max(0.0, Math.min(1.0, 1.0 - distance));

        String sourceName = docId == null
                ? "未知来源"
                : sourceNames.computeIfAbsent(docId, this::lookupSourceName);

        resultList.add(KnowledgeRetrievalVo.builder()
                .content(content)
                .score(score)
                .sourceName(sourceName)
                .build());
    }
    return resultList;
}

/**
 * Looks up the name of the attachment registered under the given document ID.
 *
 * @param docId document ID stored alongside the fragment
 * @return the attachment name, or the "未知来源" placeholder when no attachment row matches
 */
private String lookupSourceName(String docId) {
    KnowledgeAttach attach = knowledgeAttachMapper.selectOne(new LambdaQueryWrapper<KnowledgeAttach>()
            .eq(KnowledgeAttach::getDocId, docId)
            .last("limit 1"));
    return attach != null ? attach.getName() : "未知来源";
}
@Override @Override
@SneakyThrows @SneakyThrows
public void removeById(String id, String modelName) { public void removeById(String id, String modelName) {