mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-04-11 10:37:20 +00:00
Merge pull request #281 from Anush008/main
feat: Adds support for Qdrant vector search
This commit is contained in:
@@ -275,7 +275,7 @@ warm-flow:
|
|||||||
|
|
||||||
# 向量库配置
|
# 向量库配置
|
||||||
vector-store:
|
vector-store:
|
||||||
# 向量存储类型 可选(weaviate/milvus)
|
# 向量存储类型 可选(weaviate/milvus/qdrant)
|
||||||
# 如需修改向量库类型,请修改此配置值!
|
# 如需修改向量库类型,请修改此配置值!
|
||||||
type: milvus
|
type: milvus
|
||||||
# Weaviate配置
|
# Weaviate配置
|
||||||
@@ -287,3 +287,10 @@ vector-store:
|
|||||||
milvus:
|
milvus:
|
||||||
url: http://localhost:19530
|
url: http://localhost:19530
|
||||||
collectionname: LocalKnowledge
|
collectionname: LocalKnowledge
|
||||||
|
# Qdrant配置
|
||||||
|
qdrant:
|
||||||
|
host: localhost
|
||||||
|
port: 6334
|
||||||
|
collectionname: LocalKnowledge
|
||||||
|
api-key:
|
||||||
|
use-tls: false
|
||||||
|
|||||||
@@ -91,6 +91,12 @@
|
|||||||
<version>${langchain4j.community.version}</version>
|
<version>${langchain4j.community.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>dev.langchain4j</groupId>
|
||||||
|
<artifactId>langchain4j-qdrant</artifactId>
|
||||||
|
<version>${langchain4j.community.version}</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>dev.langchain4j</groupId>
|
<groupId>dev.langchain4j</groupId>
|
||||||
<artifactId>langchain4j-mcp</artifactId>
|
<artifactId>langchain4j-mcp</artifactId>
|
||||||
|
|||||||
@@ -59,4 +59,37 @@ public class VectorStoreProperties {
|
|||||||
*/
|
*/
|
||||||
private String collectionname;
|
private String collectionname;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Qdrant配置
|
||||||
|
*/
|
||||||
|
private Qdrant qdrant = new Qdrant();
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public static class Qdrant {
|
||||||
|
/**
|
||||||
|
* 主机地址
|
||||||
|
*/
|
||||||
|
private String host = "localhost";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* gRPC端口
|
||||||
|
*/
|
||||||
|
private int port = 6334;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 集合名称
|
||||||
|
*/
|
||||||
|
private String collectionname = "LocalKnowledge";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* API密钥(可选)
|
||||||
|
*/
|
||||||
|
private String apiKey;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 是否启用TLS
|
||||||
|
*/
|
||||||
|
private boolean useTls = false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import lombok.extern.slf4j.Slf4j;
|
|||||||
import org.ruoyi.config.VectorStoreProperties;
|
import org.ruoyi.config.VectorStoreProperties;
|
||||||
import org.ruoyi.service.vector.VectorStoreService;
|
import org.ruoyi.service.vector.VectorStoreService;
|
||||||
import org.ruoyi.service.vector.impl.MilvusVectorStoreStrategy;
|
import org.ruoyi.service.vector.impl.MilvusVectorStoreStrategy;
|
||||||
|
import org.ruoyi.service.vector.impl.QdrantVectorStoreStrategy;
|
||||||
import org.ruoyi.service.vector.impl.WeaviateVectorStoreStrategy;
|
import org.ruoyi.service.vector.impl.WeaviateVectorStoreStrategy;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
@@ -27,6 +28,7 @@ public class VectorStoreStrategyFactory {
|
|||||||
private final VectorStoreProperties vectorStoreProperties;
|
private final VectorStoreProperties vectorStoreProperties;
|
||||||
private final WeaviateVectorStoreStrategy weaviateStrategy;
|
private final WeaviateVectorStoreStrategy weaviateStrategy;
|
||||||
private final MilvusVectorStoreStrategy milvusStrategy;
|
private final MilvusVectorStoreStrategy milvusStrategy;
|
||||||
|
private final QdrantVectorStoreStrategy qdrantStrategy;
|
||||||
|
|
||||||
private Map<String, VectorStoreService> strategies;
|
private Map<String, VectorStoreService> strategies;
|
||||||
|
|
||||||
@@ -35,6 +37,7 @@ public class VectorStoreStrategyFactory {
|
|||||||
strategies = new HashMap<>();
|
strategies = new HashMap<>();
|
||||||
strategies.put("weaviate", weaviateStrategy);
|
strategies.put("weaviate", weaviateStrategy);
|
||||||
strategies.put("milvus", milvusStrategy);
|
strategies.put("milvus", milvusStrategy);
|
||||||
|
strategies.put("qdrant", qdrantStrategy);
|
||||||
log.info("向量库策略工厂初始化完成,支持的策略: {}", strategies.keySet());
|
log.info("向量库策略工厂初始化完成,支持的策略: {}", strategies.keySet());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,204 @@
|
|||||||
|
package org.ruoyi.service.vector.impl;
|
||||||
|
|
||||||
|
import dev.langchain4j.data.document.Metadata;
|
||||||
|
import dev.langchain4j.data.embedding.Embedding;
|
||||||
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
|
import dev.langchain4j.store.embedding.filter.Filter;
|
||||||
|
import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder;
|
||||||
|
import dev.langchain4j.store.embedding.qdrant.QdrantEmbeddingStore;
|
||||||
|
import io.qdrant.client.QdrantClient;
|
||||||
|
import io.qdrant.client.QdrantGrpcClient;
|
||||||
|
import io.qdrant.client.grpc.Collections.Distance;
|
||||||
|
import io.qdrant.client.grpc.Collections.VectorParams;
|
||||||
|
import io.qdrant.client.grpc.JsonWithInt;
|
||||||
|
import io.qdrant.client.grpc.Points.DenseVector;
|
||||||
|
import io.qdrant.client.grpc.Points.Query;
|
||||||
|
import io.qdrant.client.grpc.Points.QueryPoints;
|
||||||
|
import io.qdrant.client.grpc.Points.ScoredPoint;
|
||||||
|
import io.qdrant.client.grpc.Points.VectorInput;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.ruoyi.common.chat.service.chat.IChatModelService;
|
||||||
|
import org.ruoyi.common.core.exception.ServiceException;
|
||||||
|
import org.ruoyi.config.VectorStoreProperties;
|
||||||
|
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||||
|
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
||||||
|
import org.ruoyi.factory.EmbeddingModelFactory;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
import static io.qdrant.client.VectorInputFactory.vectorInput;
|
||||||
|
import static io.qdrant.client.WithPayloadSelectorFactory.enable;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.IntStream;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Qdrant向量库策略实现
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
@Component
|
||||||
|
public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||||
|
|
||||||
|
private static final String VECTOR_STORE_TYPE = "qdrant";
|
||||||
|
private static final String TEXT_SEGMENT_KEY = "text_segment";
|
||||||
|
private static final String METADATA_FID_KEY = "fid";
|
||||||
|
private static final String METADATA_KID_KEY = "kid";
|
||||||
|
private static final String METADATA_DOC_ID_KEY = "doc_id";
|
||||||
|
|
||||||
|
public QdrantVectorStoreStrategy(VectorStoreProperties vectorStoreProperties,
|
||||||
|
IChatModelService chatModelService,
|
||||||
|
EmbeddingModelFactory embeddingModelFactory) {
|
||||||
|
super(vectorStoreProperties, embeddingModelFactory, chatModelService);
|
||||||
|
}
|
||||||
|
|
||||||
|
private EmbeddingStore<TextSegment> getQdrantStore(String collectionName) {
|
||||||
|
VectorStoreProperties.Qdrant cfg = vectorStoreProperties.getQdrant();
|
||||||
|
QdrantEmbeddingStore.Builder builder = QdrantEmbeddingStore.builder()
|
||||||
|
.host(cfg.getHost())
|
||||||
|
.port(cfg.getPort())
|
||||||
|
.collectionName(collectionName)
|
||||||
|
.useTls(cfg.isUseTls());
|
||||||
|
if (cfg.getApiKey() != null && !cfg.getApiKey().isEmpty()) {
|
||||||
|
builder.apiKey(cfg.getApiKey());
|
||||||
|
}
|
||||||
|
return builder.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
private QdrantClient buildQdrantClient() {
|
||||||
|
VectorStoreProperties.Qdrant cfg = vectorStoreProperties.getQdrant();
|
||||||
|
QdrantGrpcClient.Builder grpcBuilder = QdrantGrpcClient.newBuilder(cfg.getHost(), cfg.getPort(), cfg.isUseTls());
|
||||||
|
if (cfg.getApiKey() != null && !cfg.getApiKey().isEmpty()) {
|
||||||
|
grpcBuilder.withApiKey(cfg.getApiKey());
|
||||||
|
}
|
||||||
|
return new QdrantClient(grpcBuilder.build());
|
||||||
|
}
|
||||||
|
|
||||||
|
private int getModelDimension(String modelName) {
|
||||||
|
return chatModelService.selectModelByName(modelName).getModelDimension();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getVectorStoreType() {
|
||||||
|
return VECTOR_STORE_TYPE;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void createSchema(String kid, String modelName) {
|
||||||
|
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + kid;
|
||||||
|
int dimension = getModelDimension(modelName);
|
||||||
|
try (QdrantClient client = buildQdrantClient()) {
|
||||||
|
Boolean exists = client.collectionExistsAsync(collectionName).get();
|
||||||
|
if (!exists) {
|
||||||
|
VectorParams params = VectorParams.newBuilder()
|
||||||
|
.setSize(dimension)
|
||||||
|
.setDistance(Distance.Cosine)
|
||||||
|
.build();
|
||||||
|
client.createCollectionAsync(collectionName, params).get();
|
||||||
|
log.info("Qdrant集合创建成功: {}, dimension: {}", collectionName, dimension);
|
||||||
|
} else {
|
||||||
|
log.info("Qdrant集合已存在: {}", collectionName);
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("Qdrant集合创建失败: {}", collectionName, e);
|
||||||
|
throw new ServiceException("Qdrant集合创建失败: " + collectionName);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
|
||||||
|
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName());
|
||||||
|
List<String> chunkList = storeEmbeddingBo.getChunkList();
|
||||||
|
List<String> fidList = storeEmbeddingBo.getFids();
|
||||||
|
String kid = storeEmbeddingBo.getKid();
|
||||||
|
String docId = storeEmbeddingBo.getDocId();
|
||||||
|
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + kid;
|
||||||
|
|
||||||
|
EmbeddingStore<TextSegment> embeddingStore = getQdrantStore(collectionName);
|
||||||
|
|
||||||
|
log.info("Qdrant向量存储条数记录: {}", chunkList.size());
|
||||||
|
long startTime = System.currentTimeMillis();
|
||||||
|
|
||||||
|
IntStream.range(0, chunkList.size()).forEach(i -> {
|
||||||
|
String text = chunkList.get(i);
|
||||||
|
String fid = fidList.get(i);
|
||||||
|
Metadata metadata = new Metadata();
|
||||||
|
metadata.put(METADATA_FID_KEY, fid);
|
||||||
|
metadata.put(METADATA_KID_KEY, kid);
|
||||||
|
metadata.put(METADATA_DOC_ID_KEY, docId);
|
||||||
|
TextSegment textSegment = TextSegment.from(text, metadata);
|
||||||
|
Embedding embedding = embeddingModel.embed(text).content();
|
||||||
|
embeddingStore.add(embedding, textSegment);
|
||||||
|
});
|
||||||
|
|
||||||
|
long endTime = System.currentTimeMillis();
|
||||||
|
log.info("Qdrant向量存储完成消耗时间:{}秒", (endTime - startTime) / 1000);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
|
||||||
|
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName());
|
||||||
|
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||||
|
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + queryVectorBo.getKid();
|
||||||
|
|
||||||
|
List<Float> vector = new ArrayList<>();
|
||||||
|
for (float f : queryEmbedding.vector()) {
|
||||||
|
vector.add(f);
|
||||||
|
}
|
||||||
|
|
||||||
|
try (QdrantClient client = buildQdrantClient()) {
|
||||||
|
QueryPoints request = QueryPoints.newBuilder()
|
||||||
|
.setCollectionName(collectionName)
|
||||||
|
.setQuery(Query.newBuilder()
|
||||||
|
.setNearest(vectorInput(vector))
|
||||||
|
.build())
|
||||||
|
.setLimit(queryVectorBo.getMaxResults())
|
||||||
|
.setWithPayload(enable(true))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
List<ScoredPoint> results = client.queryAsync(request).get();
|
||||||
|
List<String> resultList = new ArrayList<>();
|
||||||
|
for (ScoredPoint point : results) {
|
||||||
|
JsonWithInt.Value textValue = point.getPayloadMap().get(TEXT_SEGMENT_KEY);
|
||||||
|
if (textValue != null && textValue.hasStringValue()) {
|
||||||
|
resultList.add(textValue.getStringValue());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return resultList;
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("Qdrant查询失败: {}", collectionName, e);
|
||||||
|
throw new ServiceException("Qdrant向量查询失败");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void removeById(String id, String modelName) {
|
||||||
|
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + id;
|
||||||
|
try (QdrantClient client = buildQdrantClient()) {
|
||||||
|
client.deleteCollectionAsync(collectionName).get();
|
||||||
|
log.info("Qdrant成功删除集合: {}", collectionName);
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("Qdrant删除集合失败: {}", collectionName, e);
|
||||||
|
throw new ServiceException("失败删除向量数据!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void removeByDocId(String docId, String kid) {
|
||||||
|
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + kid;
|
||||||
|
EmbeddingStore<TextSegment> embeddingStore = getQdrantStore(collectionName);
|
||||||
|
Filter filter = MetadataFilterBuilder.metadataKey(METADATA_DOC_ID_KEY).isEqualTo(docId);
|
||||||
|
embeddingStore.removeAll(filter);
|
||||||
|
log.info("Qdrant成功删除 docId={} 的所有向量数据", docId);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void removeByFid(String fid, String kid) {
|
||||||
|
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + kid;
|
||||||
|
EmbeddingStore<TextSegment> embeddingStore = getQdrantStore(collectionName);
|
||||||
|
Filter filter = MetadataFilterBuilder.metadataKey(METADATA_FID_KEY).isEqualTo(fid);
|
||||||
|
embeddingStore.removeAll(filter);
|
||||||
|
log.info("Qdrant成功删除 fid={} 的所有向量数据", fid);
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user