mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-03-28 20:13:43 +08:00
Merge pull request #281 from Anush008/main
feat: Adds support for Qdrant vector search
This commit is contained in:
@@ -275,7 +275,7 @@ warm-flow:
|
||||
|
||||
# 向量库配置
|
||||
vector-store:
|
||||
# 向量存储类型 可选(weaviate/milvus)
|
||||
# 向量存储类型 可选(weaviate/milvus/qdrant)
|
||||
# 如需修改向量库类型,请修改此配置值!
|
||||
type: milvus
|
||||
# Weaviate配置
|
||||
@@ -287,3 +287,10 @@ vector-store:
|
||||
milvus:
|
||||
url: http://localhost:19530
|
||||
collectionname: LocalKnowledge
|
||||
# Qdrant配置
|
||||
qdrant:
|
||||
host: localhost
|
||||
port: 6334
|
||||
collectionname: LocalKnowledge
|
||||
api-key:
|
||||
use-tls: false
|
||||
|
||||
@@ -91,6 +91,12 @@
|
||||
<version>${langchain4j.community.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-qdrant</artifactId>
|
||||
<version>${langchain4j.community.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-mcp</artifactId>
|
||||
|
||||
@@ -59,4 +59,37 @@ public class VectorStoreProperties {
|
||||
*/
|
||||
private String collectionname;
|
||||
}
|
||||
|
||||
/**
|
||||
* Qdrant配置
|
||||
*/
|
||||
private Qdrant qdrant = new Qdrant();
|
||||
|
||||
@Data
|
||||
public static class Qdrant {
|
||||
/**
|
||||
* 主机地址
|
||||
*/
|
||||
private String host = "localhost";
|
||||
|
||||
/**
|
||||
* gRPC端口
|
||||
*/
|
||||
private int port = 6334;
|
||||
|
||||
/**
|
||||
* 集合名称
|
||||
*/
|
||||
private String collectionname = "LocalKnowledge";
|
||||
|
||||
/**
|
||||
* API密钥(可选)
|
||||
*/
|
||||
private String apiKey;
|
||||
|
||||
/**
|
||||
* 是否启用TLS
|
||||
*/
|
||||
private boolean useTls = false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.config.VectorStoreProperties;
|
||||
import org.ruoyi.service.vector.VectorStoreService;
|
||||
import org.ruoyi.service.vector.impl.MilvusVectorStoreStrategy;
|
||||
import org.ruoyi.service.vector.impl.QdrantVectorStoreStrategy;
|
||||
import org.ruoyi.service.vector.impl.WeaviateVectorStoreStrategy;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@@ -27,6 +28,7 @@ public class VectorStoreStrategyFactory {
|
||||
private final VectorStoreProperties vectorStoreProperties;
|
||||
private final WeaviateVectorStoreStrategy weaviateStrategy;
|
||||
private final MilvusVectorStoreStrategy milvusStrategy;
|
||||
private final QdrantVectorStoreStrategy qdrantStrategy;
|
||||
|
||||
private Map<String, VectorStoreService> strategies;
|
||||
|
||||
@@ -35,6 +37,7 @@ public class VectorStoreStrategyFactory {
|
||||
strategies = new HashMap<>();
|
||||
strategies.put("weaviate", weaviateStrategy);
|
||||
strategies.put("milvus", milvusStrategy);
|
||||
strategies.put("qdrant", qdrantStrategy);
|
||||
log.info("向量库策略工厂初始化完成,支持的策略: {}", strategies.keySet());
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,204 @@
|
||||
package org.ruoyi.service.vector.impl;
|
||||
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder;
|
||||
import dev.langchain4j.store.embedding.qdrant.QdrantEmbeddingStore;
|
||||
import io.qdrant.client.QdrantClient;
|
||||
import io.qdrant.client.QdrantGrpcClient;
|
||||
import io.qdrant.client.grpc.Collections.Distance;
|
||||
import io.qdrant.client.grpc.Collections.VectorParams;
|
||||
import io.qdrant.client.grpc.JsonWithInt;
|
||||
import io.qdrant.client.grpc.Points.DenseVector;
|
||||
import io.qdrant.client.grpc.Points.Query;
|
||||
import io.qdrant.client.grpc.Points.QueryPoints;
|
||||
import io.qdrant.client.grpc.Points.ScoredPoint;
|
||||
import io.qdrant.client.grpc.Points.VectorInput;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.common.chat.service.chat.IChatModelService;
|
||||
import org.ruoyi.common.core.exception.ServiceException;
|
||||
import org.ruoyi.config.VectorStoreProperties;
|
||||
import org.ruoyi.domain.bo.vector.QueryVectorBo;
|
||||
import org.ruoyi.domain.bo.vector.StoreEmbeddingBo;
|
||||
import org.ruoyi.factory.EmbeddingModelFactory;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import static io.qdrant.client.VectorInputFactory.vectorInput;
|
||||
import static io.qdrant.client.WithPayloadSelectorFactory.enable;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
/**
|
||||
* Qdrant向量库策略实现
|
||||
*/
|
||||
@Slf4j
|
||||
@Component
|
||||
public class QdrantVectorStoreStrategy extends AbstractVectorStoreStrategy {
|
||||
|
||||
private static final String VECTOR_STORE_TYPE = "qdrant";
|
||||
private static final String TEXT_SEGMENT_KEY = "text_segment";
|
||||
private static final String METADATA_FID_KEY = "fid";
|
||||
private static final String METADATA_KID_KEY = "kid";
|
||||
private static final String METADATA_DOC_ID_KEY = "doc_id";
|
||||
|
||||
public QdrantVectorStoreStrategy(VectorStoreProperties vectorStoreProperties,
|
||||
IChatModelService chatModelService,
|
||||
EmbeddingModelFactory embeddingModelFactory) {
|
||||
super(vectorStoreProperties, embeddingModelFactory, chatModelService);
|
||||
}
|
||||
|
||||
private EmbeddingStore<TextSegment> getQdrantStore(String collectionName) {
|
||||
VectorStoreProperties.Qdrant cfg = vectorStoreProperties.getQdrant();
|
||||
QdrantEmbeddingStore.Builder builder = QdrantEmbeddingStore.builder()
|
||||
.host(cfg.getHost())
|
||||
.port(cfg.getPort())
|
||||
.collectionName(collectionName)
|
||||
.useTls(cfg.isUseTls());
|
||||
if (cfg.getApiKey() != null && !cfg.getApiKey().isEmpty()) {
|
||||
builder.apiKey(cfg.getApiKey());
|
||||
}
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
private QdrantClient buildQdrantClient() {
|
||||
VectorStoreProperties.Qdrant cfg = vectorStoreProperties.getQdrant();
|
||||
QdrantGrpcClient.Builder grpcBuilder = QdrantGrpcClient.newBuilder(cfg.getHost(), cfg.getPort(), cfg.isUseTls());
|
||||
if (cfg.getApiKey() != null && !cfg.getApiKey().isEmpty()) {
|
||||
grpcBuilder.withApiKey(cfg.getApiKey());
|
||||
}
|
||||
return new QdrantClient(grpcBuilder.build());
|
||||
}
|
||||
|
||||
private int getModelDimension(String modelName) {
|
||||
return chatModelService.selectModelByName(modelName).getModelDimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getVectorStoreType() {
|
||||
return VECTOR_STORE_TYPE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void createSchema(String kid, String modelName) {
|
||||
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + kid;
|
||||
int dimension = getModelDimension(modelName);
|
||||
try (QdrantClient client = buildQdrantClient()) {
|
||||
Boolean exists = client.collectionExistsAsync(collectionName).get();
|
||||
if (!exists) {
|
||||
VectorParams params = VectorParams.newBuilder()
|
||||
.setSize(dimension)
|
||||
.setDistance(Distance.Cosine)
|
||||
.build();
|
||||
client.createCollectionAsync(collectionName, params).get();
|
||||
log.info("Qdrant集合创建成功: {}, dimension: {}", collectionName, dimension);
|
||||
} else {
|
||||
log.info("Qdrant集合已存在: {}", collectionName);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("Qdrant集合创建失败: {}", collectionName, e);
|
||||
throw new ServiceException("Qdrant集合创建失败: " + collectionName);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName());
|
||||
List<String> chunkList = storeEmbeddingBo.getChunkList();
|
||||
List<String> fidList = storeEmbeddingBo.getFids();
|
||||
String kid = storeEmbeddingBo.getKid();
|
||||
String docId = storeEmbeddingBo.getDocId();
|
||||
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + kid;
|
||||
|
||||
EmbeddingStore<TextSegment> embeddingStore = getQdrantStore(collectionName);
|
||||
|
||||
log.info("Qdrant向量存储条数记录: {}", chunkList.size());
|
||||
long startTime = System.currentTimeMillis();
|
||||
|
||||
IntStream.range(0, chunkList.size()).forEach(i -> {
|
||||
String text = chunkList.get(i);
|
||||
String fid = fidList.get(i);
|
||||
Metadata metadata = new Metadata();
|
||||
metadata.put(METADATA_FID_KEY, fid);
|
||||
metadata.put(METADATA_KID_KEY, kid);
|
||||
metadata.put(METADATA_DOC_ID_KEY, docId);
|
||||
TextSegment textSegment = TextSegment.from(text, metadata);
|
||||
Embedding embedding = embeddingModel.embed(text).content();
|
||||
embeddingStore.add(embedding, textSegment);
|
||||
});
|
||||
|
||||
long endTime = System.currentTimeMillis();
|
||||
log.info("Qdrant向量存储完成消耗时间:{}秒", (endTime - startTime) / 1000);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName());
|
||||
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + queryVectorBo.getKid();
|
||||
|
||||
List<Float> vector = new ArrayList<>();
|
||||
for (float f : queryEmbedding.vector()) {
|
||||
vector.add(f);
|
||||
}
|
||||
|
||||
try (QdrantClient client = buildQdrantClient()) {
|
||||
QueryPoints request = QueryPoints.newBuilder()
|
||||
.setCollectionName(collectionName)
|
||||
.setQuery(Query.newBuilder()
|
||||
.setNearest(vectorInput(vector))
|
||||
.build())
|
||||
.setLimit(queryVectorBo.getMaxResults())
|
||||
.setWithPayload(enable(true))
|
||||
.build();
|
||||
|
||||
List<ScoredPoint> results = client.queryAsync(request).get();
|
||||
List<String> resultList = new ArrayList<>();
|
||||
for (ScoredPoint point : results) {
|
||||
JsonWithInt.Value textValue = point.getPayloadMap().get(TEXT_SEGMENT_KEY);
|
||||
if (textValue != null && textValue.hasStringValue()) {
|
||||
resultList.add(textValue.getStringValue());
|
||||
}
|
||||
}
|
||||
return resultList;
|
||||
} catch (Exception e) {
|
||||
log.error("Qdrant查询失败: {}", collectionName, e);
|
||||
throw new ServiceException("Qdrant向量查询失败");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeById(String id, String modelName) {
|
||||
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + id;
|
||||
try (QdrantClient client = buildQdrantClient()) {
|
||||
client.deleteCollectionAsync(collectionName).get();
|
||||
log.info("Qdrant成功删除集合: {}", collectionName);
|
||||
} catch (Exception e) {
|
||||
log.error("Qdrant删除集合失败: {}", collectionName, e);
|
||||
throw new ServiceException("失败删除向量数据!");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeByDocId(String docId, String kid) {
|
||||
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + kid;
|
||||
EmbeddingStore<TextSegment> embeddingStore = getQdrantStore(collectionName);
|
||||
Filter filter = MetadataFilterBuilder.metadataKey(METADATA_DOC_ID_KEY).isEqualTo(docId);
|
||||
embeddingStore.removeAll(filter);
|
||||
log.info("Qdrant成功删除 docId={} 的所有向量数据", docId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeByFid(String fid, String kid) {
|
||||
String collectionName = vectorStoreProperties.getQdrant().getCollectionname() + kid;
|
||||
EmbeddingStore<TextSegment> embeddingStore = getQdrantStore(collectionName);
|
||||
Filter filter = MetadataFilterBuilder.metadataKey(METADATA_FID_KEY).isEqualTo(fid);
|
||||
embeddingStore.removeAll(filter);
|
||||
log.info("Qdrant成功删除 fid={} 的所有向量数据", fid);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user