Merge pull request #222 from Cyclones-Y/main

通过策略模式扩展milvus向量库
This commit is contained in:
ageerle
2025-10-12 19:04:35 +08:00
committed by GitHub
13 changed files with 860 additions and 226 deletions

View File

@@ -328,3 +328,17 @@ spring:
servers-configuration: classpath:mcp-server.json
request-timeout: 300s
--- # 向量库配置
vector-store:
# 向量存储类型 (weaviate/milvus)
type: weaviate
# Weaviate配置
weaviate:
protocol: http
host: 127.0.0.1:6038
classname: LocalKnowledge
# Milvus配置
milvus:
url: http://localhost:19530
collectionname: LocalKnowledge

View File

@@ -0,0 +1,62 @@
package org.ruoyi.common.core.config;
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;
/**
* Configuration properties for the vector store, bound from the
* {@code vector-store} prefix of the application configuration.
*
* <p>Every field carries a default, so a missing key falls back to the value
* declared here.
*
* @author ageer
*/
@Data
@Component
@ConfigurationProperties(prefix = "vector-store")
public class VectorStoreProperties {
/**
* Vector store type; defaults to {@code "weaviate"}.
*/
private String type = "weaviate";
/**
* Weaviate connection settings.
*/
private Weaviate weaviate = new Weaviate();
/**
* Milvus connection settings.
*/
private Milvus milvus = new Milvus();
/**
* Settings bound from {@code vector-store.weaviate}.
*/
@Data
public static class Weaviate {
/**
* Protocol (scheme) used to reach the server; defaults to {@code "http"}.
*/
private String protocol = "http";
/**
* Server address in {@code host:port} form; defaults to {@code "localhost:8080"}.
*/
private String host = "localhost:8080";
/**
* Class name; defaults to {@code "Document"}.
*/
private String classname = "Document";
}
/**
* Settings bound from {@code vector-store.milvus}.
*/
@Data
public static class Milvus {
/**
* Connection URL; defaults to {@code "http://localhost:19530"}.
*/
private String url = "http://localhost:19530";
/**
* Collection name; defaults to {@code "knowledge_base"}.
*/
private String collectionname = "knowledge_base";
}
}

View File

@@ -74,6 +74,12 @@
<version>1.19.6</version>
</dependency>
<dependency>
<groupId>io.milvus</groupId>
<artifactId>milvus-sdk-java</artifactId>
<version>2.6.4</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>

View File

@@ -1,5 +1,6 @@
package org.ruoyi.service;
import org.ruoyi.common.core.exception.ServiceException;
import org.ruoyi.domain.bo.QueryVectorBo;
import org.ruoyi.domain.bo.StoreEmbeddingBo;
@@ -11,15 +12,15 @@ import java.util.List;
*/
public interface VectorStoreService {
// NOTE(review): this listing comes from a diff hunk without +/- markers, so a changed
// signature appears twice in a row (presumably the previous form first, then the new one) — confirm.

// Stores the embeddings for the chunks carried by the request.
void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo);
void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) throws ServiceException;
// Runs a vector query for the request and returns the matching texts.
List<String> getQueryVector(QueryVectorBo queryVectorBo);
// Creates the storage schema for the knowledge base identified by kid.
void createSchema(String kid,String modelName);
void createSchema(String vectorModelName, String kid,String modelName);
// Removes the vector data associated with the given id.
void removeById(String id,String modelName);
void removeById(String id,String modelName) throws ServiceException;
// Removes the vectors of knowledge base kid that belong to document docId.
void removeByDocId(String docId, String kid);
void removeByDocId(String docId, String kid) throws ServiceException;
// Removes the vectors of knowledge base kid that carry fragment id fid.
void removeByFid(String fid, String kid);
void removeByFid(String fid, String kid) throws ServiceException;
}

View File

@@ -1,36 +1,14 @@
package org.ruoyi.service.impl;
import cn.hutool.json.JSONObject;
import com.google.protobuf.ServiceException;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.weaviate.WeaviateEmbeddingStore;
import io.weaviate.client.Config;
import io.weaviate.client.WeaviateClient;
import io.weaviate.client.base.Result;
import io.weaviate.client.v1.batch.api.ObjectsBatchDeleter;
import io.weaviate.client.v1.batch.model.BatchDeleteResponse;
import io.weaviate.client.v1.filters.Operator;
import io.weaviate.client.v1.filters.WhereFilter;
import io.weaviate.client.v1.graphql.model.GraphQLResponse;
import io.weaviate.client.v1.schema.model.Property;
import io.weaviate.client.v1.schema.model.Schema;
import io.weaviate.client.v1.schema.model.WeaviateClass;
import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.core.service.ConfigService;
import org.ruoyi.domain.bo.QueryVectorBo;
import org.ruoyi.domain.bo.StoreEmbeddingBo;
import org.ruoyi.embedding.BaseEmbedModelService;
import org.ruoyi.embedding.EmbeddingModelFactory;
import org.ruoyi.service.VectorStoreService;
import org.ruoyi.service.strategy.VectorStoreStrategy;
import org.ruoyi.service.strategy.VectorStoreStrategyFactory;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Service;
import java.util.*;
import java.util.stream.Collectors;
@@ -41,210 +19,62 @@ import java.util.stream.Collectors;
* @author ageer
*/
@Service
@Primary
@Slf4j
@RequiredArgsConstructor
public class VectorStoreServiceImpl implements VectorStoreService {
private final ConfigService configService;
private final VectorStoreStrategyFactory strategyFactory;
// private EmbeddingStore<TextSegment> embeddingStore;
private WeaviateClient client;
private final EmbeddingModelFactory embeddingModelFactory;
/**
* 获取当前配置的向量库策略
*/
private VectorStoreStrategy getCurrentStrategy() {
return strategyFactory.getStrategy();
}
@Override
public void createSchema(String kid, String modelName) {
String protocol = configService.getConfigValue("weaviate", "protocol");
String host = configService.getConfigValue("weaviate", "host");
String className = configService.getConfigValue("weaviate", "classname")+kid;
// 创建 Weaviate 客户端
client= new WeaviateClient(new Config(protocol, host));
// 检查类是否存在,如果不存在就创建 schema
Result<Schema> schemaResult = client.schema().getter().run();
Schema schema = schemaResult.getResult();
boolean classExists = false;
for (WeaviateClass weaviateClass : schema.getClasses()) {
if (weaviateClass.getClassName().equals(className)) {
classExists = true;
break;
}
}
if (!classExists) {
// 类不存在,创建 schema
WeaviateClass build = WeaviateClass.builder()
.className(className)
.vectorizer("none")
.properties(
List.of(Property.builder().name("text").dataType(Collections.singletonList("text")).build(),
Property.builder().name("fid").dataType(Collections.singletonList("text")).build(),
Property.builder().name("kid").dataType(Collections.singletonList("text")).build(),
Property.builder().name("docId").dataType(Collections.singletonList("text")).build())
)
.build();
Result<Boolean> createResult = client.schema().classCreator().withClass(build).run();
if (createResult.hasErrors()) {
log.error("Schema 创建失败: {}", createResult.getError());
} else {
log.info("Schema 创建成功: {}", className);
}
}
// embeddingStore = WeaviateEmbeddingStore.builder()
// .scheme(protocol)
// .host(host)
// .objectClass(className)
// .scheme(protocol)
// .avoidDups(true)
// .consistencyLevel("ALL")
// .build();
public void createSchema(String vectorModelName, String kid, String modelName) {
log.info("创建向量库schema: vectorModelName={}, kid={}, modelName={}", vectorModelName, kid, modelName);
VectorStoreStrategy strategy = getCurrentStrategy();
strategy.createSchema(vectorModelName, kid, modelName);
}
@Override
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
createSchema(storeEmbeddingBo.getKid(), storeEmbeddingBo.getVectorModelName());
BaseEmbedModelService model = embeddingModelFactory.createModel(storeEmbeddingBo.getEmbeddingModelId());
List<String> chunkList = storeEmbeddingBo.getChunkList();
List<String> fidList = storeEmbeddingBo.getFids();
String kid = storeEmbeddingBo.getKid();
String docId = storeEmbeddingBo.getDocId();
long startTime = System.currentTimeMillis();
for (int i = 0; i < chunkList.size(); i++) {
String text = chunkList.get(i);
String fid = fidList.get(i);
Embedding embedding = model.embed(text).content();
Map<String, Object> properties = Map.of(
"text", text,
"fid",fid,
"kid", kid,
"docId", docId
);
Float[] vector = toObjectArray(embedding.vector());
client.data().creator()
.withClassName("LocalKnowledge" + kid) // 注意替换成实际类名
.withProperties(properties)
.withVector(vector)
.run();
}
long endTime = System.currentTimeMillis();
log.info("向量存储完成消耗时间:"+ (endTime-startTime)/1000+"");
log.info("存储向量数据: kid={}, docId={}, 数据条数={}",
storeEmbeddingBo.getKid(), storeEmbeddingBo.getDocId(), storeEmbeddingBo.getChunkList().size());
VectorStoreStrategy strategy = getCurrentStrategy();
strategy.storeEmbeddings(storeEmbeddingBo);
}
private static Float[] toObjectArray(float[] primitive) {
Float[] result = new Float[primitive.length];
for (int i = 0; i < primitive.length; i++) {
result[i] = primitive[i]; // 自动装箱
}
return result;
}
@Override
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
createSchema(queryVectorBo.getKid(), queryVectorBo.getVectorModelName());
BaseEmbedModelService model = embeddingModelFactory.createModel(queryVectorBo.getEmbeddingModelId());
Embedding queryEmbedding = model.embed(queryVectorBo.getQuery()).content();
float[] vector = queryEmbedding.vector();
List<String> vectorStrings = new ArrayList<>();
for (float v : vector) {
vectorStrings.add(String.valueOf(v));
}
String vectorStr = String.join(",", vectorStrings);
String className = configService.getConfigValue("weaviate", "classname") ;
// 构建 GraphQL 查询
String graphQLQuery = String.format(
"{\n" +
" Get {\n" +
" %s(nearVector: {vector: [%s]} limit: %d) {\n" +
" text\n" +
" fid\n" +
" kid\n" +
" docId\n" +
" _additional {\n" +
" distance\n" +
" id\n" +
" }\n" +
" }\n" +
" }\n" +
"}",
className+ queryVectorBo.getKid(),
vectorStr,
queryVectorBo.getMaxResults()
);
Result<GraphQLResponse> result = client.graphQL().raw().withQuery(graphQLQuery).run();
List<String> resultList = new ArrayList<>();
if (result != null && !result.hasErrors()) {
Object data = result.getResult().getData();
JSONObject entries = new JSONObject(data);
Map<String, cn.hutool.json.JSONArray> entriesMap = entries.get("Get", Map.class);
cn.hutool.json.JSONArray objects = entriesMap.get(className + queryVectorBo.getKid());
if(objects.isEmpty()){
return resultList;
}
for (Object object : objects) {
Map<String, String> map = (Map<String, String>) object;
String content = map.get("text");
resultList.add( content);
}
return resultList;
} else {
log.error("GraphQL 查询失败: {}", result.getError());
return resultList;
}
log.info("查询向量数据: kid={}, query={}, maxResults={}",
queryVectorBo.getKid(), queryVectorBo.getQuery(), queryVectorBo.getMaxResults());
VectorStoreStrategy strategy = getCurrentStrategy();
return strategy.getQueryVector(queryVectorBo);
}
@Override
@SneakyThrows
public void removeById(String id, String modelName) {
String protocol = configService.getConfigValue("weaviate", "protocol");
String host = configService.getConfigValue("weaviate", "host");
String className = configService.getConfigValue("weaviate", "classname");
String finalClassName = className + id;
WeaviateClient client = new WeaviateClient(new Config(protocol, host));
Result<Boolean> result = client.schema().classDeleter().withClassName(finalClassName).run();
if (result.hasErrors()) {
log.error("失败删除向量: " + result.getError());
throw new ServiceException("失败删除向量数据!");
} else {
log.info("成功删除向量数据: " + result.getResult());
}
log.info("根据ID删除向量数据: id={}, modelName={}", id, modelName);
VectorStoreStrategy strategy = getCurrentStrategy();
strategy.removeById(id, modelName);
}
@Override
public void removeByDocId(String docId, String kid) {
String className = configService.getConfigValue("weaviate", "classname") + kid;
// 构建 Where 条件
WhereFilter whereFilter = WhereFilter.builder()
.path("docId")
.operator(Operator.Equal)
.valueText(docId)
.build();
ObjectsBatchDeleter deleter = client.batch().objectsBatchDeleter();
Result<BatchDeleteResponse> result = deleter.withClassName(className)
.withWhere(whereFilter)
.run();
if (result != null && !result.hasErrors()) {
log.info("成功删除 docId={} 的所有向量数据", docId);
} else {
log.error("删除失败: {}", result.getError());
}
log.info("根据docId删除向量数据: docId={}, kid={}", docId, kid);
VectorStoreStrategy strategy = getCurrentStrategy();
strategy.removeByDocId(docId, kid);
}
@Override
public void removeByFid(String fid, String kid) {
String className = configService.getConfigValue("weaviate", "classname") + kid;
// 构建 Where 条件
WhereFilter whereFilter = WhereFilter.builder()
.path("fid")
.operator(Operator.Equal)
.valueText(fid)
.build();
ObjectsBatchDeleter deleter = client.batch().objectsBatchDeleter();
Result<BatchDeleteResponse> result = deleter.withClassName(className)
.withWhere(whereFilter)
.run();
if (result != null && !result.hasErrors()) {
log.info("成功删除 fid={} 的所有向量数据", fid);
} else {
log.error("删除失败: {}", result.getError());
}
log.info("根据fid删除向量数据: fid={}, kid={}", fid, kid);
VectorStoreStrategy strategy = getCurrentStrategy();
strategy.removeByFid(fid, kid);
}
}

View File

@@ -0,0 +1,62 @@
package org.ruoyi.service.strategy;
import org.ruoyi.common.core.exception.ServiceException;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.core.config.VectorStoreProperties;
/**
 * Shared base class for vector-store strategies.
 *
 * <p>Holds the bound {@link VectorStoreProperties} and offers the helpers every
 * backend needs: embedding-model lookup and primitive-to-boxed vector conversion.
 *
 * @author ageer
 */
@Slf4j
@RequiredArgsConstructor
public abstract class AbstractVectorStoreStrategy implements VectorStoreStrategy {

    /** Vector-store settings, available to subclasses. */
    protected final VectorStoreProperties vectorStoreProperties;

    /**
     * Builds the embedding model that corresponds to {@code modelName}.
     *
     * @param modelName embedding model identifier; only the two names checked below are supported
     * @param apiKey    API key, used only for the OpenAI-compatible model
     * @param baseUrl   endpoint of the model service
     * @return a freshly built embedding model
     * @throws ServiceException when {@code modelName} matches none of the supported models
     */
    @SneakyThrows
    protected EmbeddingModel getEmbeddingModel(String modelName, String apiKey, String baseUrl) {
        // Served through Ollama; no API key is passed.
        if ("quentinz/bge-large-zh-v1.5".equals(modelName)) {
            return OllamaEmbeddingModel.builder()
                    .baseUrl(baseUrl)
                    .modelName(modelName)
                    .build();
        }
        // Served through an OpenAI-compatible endpoint.
        if ("baai/bge-m3".equals(modelName)) {
            return OpenAiEmbeddingModel.builder()
                    .apiKey(apiKey)
                    .baseUrl(baseUrl)
                    .modelName(modelName)
                    .build();
        }
        throw new ServiceException("未找到对应向量化模型!");
    }

    /**
     * Converts a primitive {@code float[]} into a boxed {@code Float[]} of the same length.
     *
     * @param primitive source values; must not be {@code null}
     * @return a new array holding the boxed values in the same order
     */
    protected static Float[] toObjectArray(float[] primitive) {
        Float[] boxed = new Float[primitive.length];
        int position = 0;
        for (float value : primitive) {
            boxed[position++] = value; // autoboxing
        }
        return boxed;
    }

    /**
     * Returns the identifier of the vector store this strategy targets.
     */
    public abstract String getVectorStoreType();
}

View File

@@ -0,0 +1,18 @@
package org.ruoyi.service.strategy;
import org.ruoyi.service.VectorStoreService;
/**
* Strategy interface for a vector-store backend.
*
* <p>Extends {@link VectorStoreService} so the store/query/remove operations
* are not declared a second time; the only addition is the type identifier.
*
* @author ageer
*/
public interface VectorStoreStrategy extends VectorStoreService {
/**
* Returns the identifier of the vector store this strategy targets.
*
* @return the vector-store type, e.g. {@code weaviate} or {@code milvus}
*/
String getVectorStoreType();
}

View File

@@ -0,0 +1,74 @@
package org.ruoyi.service.strategy;
import jakarta.annotation.PostConstruct;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.core.config.VectorStoreProperties;
import org.ruoyi.service.strategy.impl.MilvusVectorStoreStrategy;
import org.ruoyi.service.strategy.impl.WeaviateVectorStoreStrategy;
import org.springframework.stereotype.Component;
import java.util.HashMap;
import java.util.Map;
/**
 * Factory that selects the vector-store strategy to use.
 *
 * <p>The default strategy comes from {@code vector-store.type}. Type names are
 * matched case-insensitively and ignoring surrounding whitespace; an empty or
 * unknown type falls back to Weaviate.
 *
 * @author ageer
 */
@Slf4j
@Component
@RequiredArgsConstructor
public class VectorStoreStrategyFactory {

    /** Strategy key used when no type, or an unknown type, is configured. */
    private static final String DEFAULT_TYPE = "weaviate";

    private final VectorStoreProperties vectorStoreProperties;
    private final WeaviateVectorStoreStrategy weaviateStrategy;
    private final MilvusVectorStoreStrategy milvusStrategy;

    /** Lower-case type key to strategy; filled once in {@link #init()}. */
    private Map<String, VectorStoreStrategy> strategies;

    /**
     * Registers the available strategies under their type keys.
     */
    @PostConstruct
    public void init() {
        strategies = new HashMap<>();
        strategies.put("weaviate", weaviateStrategy);
        strategies.put("milvus", milvusStrategy);
        log.info("向量库策略工厂初始化完成,支持的策略: {}", strategies.keySet());
    }

    /**
     * Returns the strategy selected by the {@code vector-store.type} property.
     *
     * @return the configured strategy, or the Weaviate strategy when the
     *         property is blank or names an unknown type
     */
    public VectorStoreStrategy getStrategy() {
        String configuredType = vectorStoreProperties.getType();
        String key = normalize(configuredType);
        if (key.isEmpty()) {
            key = DEFAULT_TYPE; // nothing configured: use the default backend
        }
        VectorStoreStrategy strategy = strategies.get(key);
        if (strategy == null) {
            log.warn("未找到向量库策略: {}, 使用默认策略: weaviate", configuredType);
            key = DEFAULT_TYPE;
            strategy = strategies.get(DEFAULT_TYPE);
        }
        // Log the key that was actually resolved, not the raw configured value.
        log.debug("使用向量库策略: {}", key);
        return strategy;
    }

    /**
     * Returns the strategy registered for {@code type}.
     *
     * @param type vector-store type; may be {@code null} or blank
     * @return the matching strategy, or the result of {@link #getStrategy()}
     *         when {@code type} is blank or unknown
     */
    public VectorStoreStrategy getStrategy(String type) {
        String key = normalize(type);
        if (key.isEmpty()) {
            return getStrategy();
        }
        VectorStoreStrategy strategy = strategies.get(key);
        if (strategy == null) {
            log.warn("未找到向量库策略: {}, 使用默认策略", type);
            return getStrategy();
        }
        return strategy;
    }

    /**
     * Turns a raw type name into a lookup key: trimmed and lower-cased with
     * {@code Locale.ROOT}, so the result does not depend on the JVM default
     * locale (e.g. the Turkish dotless i would otherwise break "MILVUS").
     *
     * @return the normalized key, or an empty string for {@code null}
     */
    private static String normalize(String type) {
        return type == null ? "" : type.trim().toLowerCase(java.util.Locale.ROOT);
    }
}

View File

@@ -0,0 +1,337 @@
package org.ruoyi.service.strategy.impl;
import org.ruoyi.common.core.exception.ServiceException;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.embedding.EmbeddingModel;
import io.milvus.client.MilvusServiceClient;
import io.milvus.common.clientenum.ConsistencyLevelEnum;
import io.milvus.grpc.*;
import io.milvus.param.*;
import io.milvus.param.collection.*;
import io.milvus.param.dml.DeleteParam;
import io.milvus.param.dml.InsertParam;
import io.milvus.param.dml.SearchParam;
import io.milvus.param.index.CreateIndexParam;
import io.milvus.param.index.DescribeIndexParam;
import io.milvus.response.DescCollResponseWrapper;
import io.milvus.response.SearchResultsWrapper;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.core.config.VectorStoreProperties;
import org.ruoyi.domain.bo.QueryVectorBo;
import org.ruoyi.domain.bo.StoreEmbeddingBo;
import org.ruoyi.service.strategy.AbstractVectorStoreStrategy;
import org.springframework.stereotype.Component;
import java.util.*;
/**
* Milvus向量库策略实现
*
* @author ageer
*/
@Slf4j
@Component
public class MilvusVectorStoreStrategy extends AbstractVectorStoreStrategy {
private MilvusServiceClient milvusClient;
public MilvusVectorStoreStrategy(VectorStoreProperties vectorStoreProperties) {
super(vectorStoreProperties);
}
@Override
public String getVectorStoreType() {
return "milvus";
}
@Override
public void createSchema(String vectorModelName, String kid, String modelName) {
String url = vectorStoreProperties.getMilvus().getUrl();
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid;
// 创建Milvus客户端连接
ConnectParam connectParam = ConnectParam.newBuilder()
.withUri(url)
.build();
milvusClient = new MilvusServiceClient(connectParam);
// 检查集合是否存在
HasCollectionParam hasCollectionParam = HasCollectionParam.newBuilder()
.withCollectionName(collectionName)
.build();
R<Boolean> hasCollectionResponse = milvusClient.hasCollection(hasCollectionParam);
if (hasCollectionResponse.getStatus() != R.Status.Success.getCode()) {
log.error("检查集合是否存在失败: {}", hasCollectionResponse.getMessage());
return;
}
if (!hasCollectionResponse.getData()) {
// 创建字段
List<FieldType> fields = new ArrayList<>();
// ID字段 (主键)
fields.add(FieldType.newBuilder()
.withName("id")
.withDataType(DataType.Int64)
.withPrimaryKey(true)
.withAutoID(true)
.build());
// 文本字段
fields.add(FieldType.newBuilder()
.withName("text")
.withDataType(DataType.VarChar)
.withMaxLength(65535)
.build());
// fid字段
fields.add(FieldType.newBuilder()
.withName("fid")
.withDataType(DataType.VarChar)
.withMaxLength(255)
.build());
// kid字段
fields.add(FieldType.newBuilder()
.withName("kid")
.withDataType(DataType.VarChar)
.withMaxLength(255)
.build());
// docId字段
fields.add(FieldType.newBuilder()
.withName("docId")
.withDataType(DataType.VarChar)
.withMaxLength(255)
.build());
// 向量字段
fields.add(FieldType.newBuilder()
.withName("vector")
.withDataType(DataType.FloatVector)
.withDimension(1024) // 根据实际embedding维度调整
.build());
// 创建集合
CreateCollectionParam createCollectionParam = CreateCollectionParam.newBuilder()
.withCollectionName(collectionName)
.withDescription("Knowledge base collection for " + kid)
.withShardsNum(2)
.withFieldTypes(fields)
.build();
R<RpcStatus> createCollectionResponse = milvusClient.createCollection(createCollectionParam);
if (createCollectionResponse.getStatus() != R.Status.Success.getCode()) {
log.error("创建集合失败: {}", createCollectionResponse.getMessage());
return;
}
// 创建索引
CreateIndexParam createIndexParam = CreateIndexParam.newBuilder()
.withCollectionName(collectionName)
.withFieldName("vector")
.withIndexType(IndexType.IVF_FLAT)
.withMetricType(MetricType.L2)
.withExtraParam("{\"nlist\":1024}")
.build();
R<RpcStatus> createIndexResponse = milvusClient.createIndex(createIndexParam);
if (createIndexResponse.getStatus() != R.Status.Success.getCode()) {
log.error("创建索引失败: {}", createIndexResponse.getMessage());
} else {
log.info("Milvus集合和索引创建成功: {}", collectionName);
}
} else {
log.info("Milvus集合已存在: {}", collectionName);
}
}
@Override
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
createSchema(storeEmbeddingBo.getVectorModelName(), storeEmbeddingBo.getKid(), storeEmbeddingBo.getVectorModelName());
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(),
storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl());
List<String> chunkList = storeEmbeddingBo.getChunkList();
List<String> fidList = storeEmbeddingBo.getFids();
String kid = storeEmbeddingBo.getKid();
String docId = storeEmbeddingBo.getDocId();
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid;
log.info("Milvus向量存储条数记录: " + chunkList.size());
long startTime = System.currentTimeMillis();
// 准备批量插入数据
List<InsertParam.Field> fields = new ArrayList<>();
List<String> textList = new ArrayList<>();
List<String> fidListData = new ArrayList<>();
List<String> kidList = new ArrayList<>();
List<String> docIdList = new ArrayList<>();
List<List<Float>> vectorList = new ArrayList<>();
for (int i = 0; i < chunkList.size(); i++) {
String text = chunkList.get(i);
String fid = fidList.get(i);
Embedding embedding = embeddingModel.embed(text).content();
textList.add(text);
fidListData.add(fid);
kidList.add(kid);
docIdList.add(docId);
List<Float> vector = new ArrayList<>();
for (float f : embedding.vector()) {
vector.add(f);
}
vectorList.add(vector);
}
// 构建字段数据
fields.add(new InsertParam.Field("text", textList));
fields.add(new InsertParam.Field("fid", fidListData));
fields.add(new InsertParam.Field("kid", kidList));
fields.add(new InsertParam.Field("docId", docIdList));
fields.add(new InsertParam.Field("vector", vectorList));
// 执行插入
InsertParam insertParam = InsertParam.newBuilder()
.withCollectionName(collectionName)
.withFields(fields)
.build();
R<MutationResult> insertResponse = milvusClient.insert(insertParam);
if (insertResponse.getStatus() != R.Status.Success.getCode()) {
log.error("Milvus向量存储失败: {}", insertResponse.getMessage());
throw new ServiceException("Milvus向量存储失败");
} else {
log.info("Milvus向量存储成功插入条数: {}", insertResponse.getData().getInsertCnt());
}
long endTime = System.currentTimeMillis();
log.info("Milvus向量存储完成消耗时间" + (endTime - startTime) / 1000 + "");
}
@Override
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
createSchema(queryVectorBo.getVectorModelName(), queryVectorBo.getKid(), queryVectorBo.getVectorModelName());
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(),
queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl());
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + queryVectorBo.getKid();
List<String> resultList = new ArrayList<>();
// 加载集合到内存
LoadCollectionParam loadCollectionParam = LoadCollectionParam.newBuilder()
.withCollectionName(collectionName)
.build();
milvusClient.loadCollection(loadCollectionParam);
// 准备查询向量
List<List<Float>> searchVectors = new ArrayList<>();
List<Float> queryVector = new ArrayList<>();
for (float f : queryEmbedding.vector()) {
queryVector.add(f);
}
searchVectors.add(queryVector);
// 构建搜索参数
SearchParam searchParam = SearchParam.newBuilder()
.withCollectionName(collectionName)
.withMetricType(MetricType.L2)
.withOutFields(Arrays.asList("text", "fid", "kid", "docId"))
.withTopK(queryVectorBo.getMaxResults())
.withVectors(searchVectors)
.withVectorFieldName("vector")
.withParams("{\"nprobe\":10}")
.build();
R<SearchResults> searchResponse = milvusClient.search(searchParam);
if (searchResponse.getStatus() != R.Status.Success.getCode()) {
log.error("Milvus查询失败: {}", searchResponse.getMessage());
return resultList;
}
SearchResultsWrapper wrapper = new SearchResultsWrapper(searchResponse.getData().getResults());
// 遍历搜索结果
for (int i = 0; i < wrapper.getIDScore(0).size(); i++) {
SearchResultsWrapper.IDScore idScore = wrapper.getIDScore(0).get(i);
// 获取text字段数据
List<?> textFieldData = wrapper.getFieldData("text", 0);
if (textFieldData != null && i < textFieldData.size()) {
Object textObj = textFieldData.get(i);
if (textObj != null) {
resultList.add(textObj.toString());
log.debug("找到相似文本ID: {}, 距离: {}, 内容: {}",
idScore.getLongID(), idScore.getScore(), textObj.toString());
}
}
}
return resultList;
}
@Override
@SneakyThrows
public void removeById(String id, String modelName) {
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + id;
// 删除整个集合
DropCollectionParam dropCollectionParam = DropCollectionParam.newBuilder()
.withCollectionName(collectionName)
.build();
R<RpcStatus> dropResponse = milvusClient.dropCollection(dropCollectionParam);
if (dropResponse.getStatus() != R.Status.Success.getCode()) {
log.error("Milvus集合删除失败: {}", dropResponse.getMessage());
throw new ServiceException("Milvus集合删除失败");
} else {
log.info("Milvus集合删除成功: {}", collectionName);
}
}
@Override
public void removeByDocId(String docId, String kid) {
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid;
String expr = "docId == \"" + docId + "\"";
DeleteParam deleteParam = DeleteParam.newBuilder()
.withCollectionName(collectionName)
.withExpr(expr)
.build();
R<MutationResult> deleteResponse = milvusClient.delete(deleteParam);
if (deleteResponse.getStatus() != R.Status.Success.getCode()) {
log.error("Milvus删除失败: {}", deleteResponse.getMessage());
throw new ServiceException("Milvus删除失败");
} else {
log.info("Milvus成功删除 docId={} 的所有向量数据,删除条数: {}", docId, deleteResponse.getData().getDeleteCnt());
}
}
@Override
public void removeByFid(String fid, String kid) {
String collectionName = vectorStoreProperties.getMilvus().getCollectionname() + kid;
String expr = "fid == \"" + fid + "\"";
DeleteParam deleteParam = DeleteParam.newBuilder()
.withCollectionName(collectionName)
.withExpr(expr)
.build();
R<MutationResult> deleteResponse = milvusClient.delete(deleteParam);
if (deleteResponse.getStatus() != R.Status.Success.getCode()) {
log.error("Milvus删除失败: {}", deleteResponse.getMessage());
throw new ServiceException("Milvus删除失败");
} else {
log.info("Milvus成功删除 fid={} 的所有向量数据,删除条数: {}", fid, deleteResponse.getData().getDeleteCnt());
}
}
}

View File

@@ -0,0 +1,233 @@
package org.ruoyi.service.strategy.impl;
import cn.hutool.json.JSONObject;
import org.ruoyi.common.core.exception.ServiceException;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.embedding.EmbeddingModel;
import io.weaviate.client.Config;
import io.weaviate.client.WeaviateClient;
import io.weaviate.client.base.Result;
import io.weaviate.client.v1.batch.api.ObjectsBatchDeleter;
import io.weaviate.client.v1.batch.model.BatchDeleteResponse;
import io.weaviate.client.v1.filters.Operator;
import io.weaviate.client.v1.filters.WhereFilter;
import io.weaviate.client.v1.graphql.model.GraphQLResponse;
import io.weaviate.client.v1.schema.model.Property;
import io.weaviate.client.v1.schema.model.Schema;
import io.weaviate.client.v1.schema.model.WeaviateClass;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.core.config.VectorStoreProperties;
import org.ruoyi.domain.bo.QueryVectorBo;
import org.ruoyi.domain.bo.StoreEmbeddingBo;
import org.ruoyi.service.strategy.AbstractVectorStoreStrategy;
import org.springframework.stereotype.Component;
import java.util.*;
/**
* Weaviate向量库策略实现
*
* @author ageer
*/
@Slf4j
@Component
public class WeaviateVectorStoreStrategy extends AbstractVectorStoreStrategy {
private WeaviateClient client;
public WeaviateVectorStoreStrategy(VectorStoreProperties vectorStoreProperties) {
super(vectorStoreProperties);
}
@Override
public String getVectorStoreType() {
return "weaviate";
}
@Override
public void createSchema(String vectorModelName, String kid, String modelName) {
String protocol = vectorStoreProperties.getWeaviate().getProtocol();
String host = vectorStoreProperties.getWeaviate().getHost();
String className = vectorStoreProperties.getWeaviate().getClassname() + kid;
// 创建 Weaviate 客户端
client = new WeaviateClient(new Config(protocol, host));
// 检查类是否存在,如果不存在就创建 schema
Result<Schema> schemaResult = client.schema().getter().run();
Schema schema = schemaResult.getResult();
boolean classExists = false;
for (WeaviateClass weaviateClass : schema.getClasses()) {
if (weaviateClass.getClassName().equals(className)) {
classExists = true;
break;
}
}
if (!classExists) {
// 类不存在,创建 schema
WeaviateClass build = WeaviateClass.builder()
.className(className)
.vectorizer("none")
.properties(
List.of(Property.builder().name("text").dataType(Collections.singletonList("text")).build(),
Property.builder().name("fid").dataType(Collections.singletonList("text")).build(),
Property.builder().name("kid").dataType(Collections.singletonList("text")).build(),
Property.builder().name("docId").dataType(Collections.singletonList("text")).build())
)
.build();
Result<Boolean> createResult = client.schema().classCreator().withClass(build).run();
if (createResult.hasErrors()) {
log.error("Schema 创建失败: {}", createResult.getError());
} else {
log.info("Schema 创建成功: {}", className);
}
}
}
@Override
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
createSchema(storeEmbeddingBo.getVectorModelName(), storeEmbeddingBo.getKid(), storeEmbeddingBo.getVectorModelName());
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(),
storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl());
List<String> chunkList = storeEmbeddingBo.getChunkList();
List<String> fidList = storeEmbeddingBo.getFids();
String kid = storeEmbeddingBo.getKid();
String docId = storeEmbeddingBo.getDocId();
log.info("向量存储条数记录: " + chunkList.size());
long startTime = System.currentTimeMillis();
for (int i = 0; i < chunkList.size(); i++) {
String text = chunkList.get(i);
String fid = fidList.get(i);
Embedding embedding = embeddingModel.embed(text).content();
Map<String, Object> properties = Map.of(
"text", text,
"fid", fid,
"kid", kid,
"docId", docId
);
Float[] vector = toObjectArray(embedding.vector());
client.data().creator()
.withClassName("LocalKnowledge" + kid)
.withProperties(properties)
.withVector(vector)
.run();
}
long endTime = System.currentTimeMillis();
log.info("向量存储完成消耗时间:" + (endTime - startTime) / 1000 + "");
}
@Override
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
createSchema(queryVectorBo.getVectorModelName(), queryVectorBo.getKid(), queryVectorBo.getVectorModelName());
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getEmbeddingModelName(),
queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl());
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
float[] vector = queryEmbedding.vector();
List<String> vectorStrings = new ArrayList<>();
for (float v : vector) {
vectorStrings.add(String.valueOf(v));
}
String vectorStr = String.join(",", vectorStrings);
String className = vectorStoreProperties.getWeaviate().getClassname();
// 构建 GraphQL 查询
String graphQLQuery = String.format(
"{\n" +
" Get {\n" +
" %s(nearVector: {vector: [%s]} limit: %d) {\n" +
" text\n" +
" fid\n" +
" kid\n" +
" docId\n" +
" _additional {\n" +
" distance\n" +
" id\n" +
" }\n" +
" }\n" +
" }\n" +
"}",
className + queryVectorBo.getKid(),
vectorStr,
queryVectorBo.getMaxResults()
);
Result<GraphQLResponse> result = client.graphQL().raw().withQuery(graphQLQuery).run();
List<String> resultList = new ArrayList<>();
if (result != null && !result.hasErrors()) {
Object data = result.getResult().getData();
JSONObject entries = new JSONObject(data);
Map<String, cn.hutool.json.JSONArray> entriesMap = entries.get("Get", Map.class);
cn.hutool.json.JSONArray objects = entriesMap.get(className + queryVectorBo.getKid());
if (objects.isEmpty()) {
return resultList;
}
for (Object object : objects) {
Map<String, String> map = (Map<String, String>) object;
String content = map.get("text");
resultList.add(content);
}
return resultList;
} else {
log.error("GraphQL 查询失败: {}", result.getError());
return resultList;
}
}
/**
 * Deletes the Weaviate class whose name is the configured class name followed by {@code id}.
 *
 * <p>A client is built from the configured protocol and host for this call rather than
 * using the shared one.
 *
 * @param id        suffix appended to the configured class name; presumably the knowledge
 *                  base id, as in the sibling methods — verify against callers
 * @param modelName not read by this implementation
 * @throws ServiceException if the delete result reports errors
 */
@Override
@SneakyThrows
public void removeById(String id, String modelName) {
    WeaviateClient client = new WeaviateClient(new Config(
            vectorStoreProperties.getWeaviate().getProtocol(),
            vectorStoreProperties.getWeaviate().getHost()));
    String targetClass = vectorStoreProperties.getWeaviate().getClassname() + id;

    Result<Boolean> deletion = client.schema().classDeleter().withClassName(targetClass).run();
    if (!deletion.hasErrors()) {
        log.info("成功删除向量数据: " + deletion.getResult());
        return;
    }
    // Surface the failure to the caller after logging the Weaviate error detail.
    log.error("失败删除向量: " + deletion.getError());
    throw new ServiceException("失败删除向量数据!");
}
/**
 * Batch-deletes every object in the knowledge base's Weaviate class whose {@code docId}
 * property equals the given value.
 *
 * <p>A missing or failed delete result is logged and not rethrown.
 *
 * @param docId value the {@code docId} property must equal
 * @param kid   knowledge base id; appended to the configured class name
 */
@Override
public void removeByDocId(String docId, String kid) {
    String className = vectorStoreProperties.getWeaviate().getClassname() + kid;
    // 构建 Where 条件 (build the where filter: docId == given value)
    WhereFilter whereFilter = WhereFilter.builder()
            .path("docId")
            .operator(Operator.Equal)
            .valueText(docId)
            .build();
    Result<BatchDeleteResponse> result = client.batch().objectsBatchDeleter()
            .withClassName(className)
            .withWhere(whereFilter)
            .run();
    // Handle null explicitly: the original called result.getError() even when result was null.
    if (result == null || result.hasErrors()) {
        log.error("删除失败: {}", result == null ? null : result.getError());
        return;
    }
    log.info("成功删除 docId={} 的所有向量数据", docId);
}
/**
 * Batch-deletes every object in the knowledge base's Weaviate class whose {@code fid}
 * property equals the given value.
 *
 * <p>A missing or failed delete result is logged and not rethrown.
 *
 * @param fid value the {@code fid} property must equal
 * @param kid knowledge base id; appended to the configured class name
 */
@Override
public void removeByFid(String fid, String kid) {
    String className = vectorStoreProperties.getWeaviate().getClassname() + kid;
    // 构建 Where 条件 (build the where filter: fid == given value)
    WhereFilter whereFilter = WhereFilter.builder()
            .path("fid")
            .operator(Operator.Equal)
            .valueText(fid)
            .build();
    Result<BatchDeleteResponse> result = client.batch().objectsBatchDeleter()
            .withClassName(className)
            .withWhere(whereFilter)
            .run();
    // Handle null explicitly: the original called result.getError() even when result was null.
    if (result == null || result.hasErrors()) {
        log.error("删除失败: {}", result == null ? null : result.getError());
        return;
    }
    log.info("成功删除 fid={} 的所有向量数据", fid);
}
}

View File

@@ -45,6 +45,18 @@ public class ChatMessageController extends BaseController {
return chatMessageService.queryPageList(bo, pageQuery);
}
/**
 * Lists chat messages for one session.
 *
 * @param sessionId session id taken from the request path; must not be null
 * @param pageQuery paging parameters forwarded to the service
 * @return the page returned by {@code chatMessageService.queryPageList} for a filter that
 *         carries only the session id
 */
@GetMapping("/listBySession/{sessionId}")
public TableDataInfo<ChatMessageVo> listBySession(
        @NotNull(message = "会话ID不能为空") @PathVariable Long sessionId,
        PageQuery pageQuery) {
    ChatMessageBo sessionFilter = new ChatMessageBo();
    sessionFilter.setSessionId(sessionId);
    return chatMessageService.queryPageList(sessionFilter, pageQuery);
}
/**
* 导出聊天消息列表
*/

View File

@@ -100,24 +100,7 @@ public class SseServiceImpl implements ISseService {
// 设置用户id
chatRequest.setUserId(LoginHelper.getUserId());
//待优化的地方 这里请前端提交send的时候传递uuid进来或者sessionId
//待优化的地方 这里请前端提交send的时候传递uuid进来或者sessionId
//待优化的地方 这里请前端提交send的时候传递uuid进来或者sessionId
{
// 设置会话id
if (chatRequest.getUuid() == null) {
//暂时随机生成会话id
chatRequest.setSessionId(System.currentTimeMillis());
} else {
//这里或许需要修改一下这里应该用uuid 或者 前端传递 sessionId
chatRequest.setSessionId(chatRequest.getUuid());
}
}
chatRequest.setUserId(chatCostService.getUserId());
// 设置会话id
if (chatRequest.getSessionId() == null) {
ChatSessionBo chatSessionBo = new ChatSessionBo();
chatSessionBo.setUserId(chatCostService.getUserId());

View File

@@ -216,7 +216,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
}
baseMapper.insert(knowledgeInfo);
if (knowledgeInfo != null) {
vectorStoreService.createSchema(String.valueOf(knowledgeInfo.getId()),
vectorStoreService.createSchema(knowledgeInfo.getVectorModelName(),String.valueOf(knowledgeInfo.getId()),
bo.getVectorModelName());
}
} else {
@@ -258,6 +258,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
knowledgeAttach.setDocType(fileName.substring(fileName.lastIndexOf(".") + 1));
String content = "";
ResourceLoader resourceLoader = resourceLoaderFactory.getLoaderByFileType(knowledgeAttach.getDocType());
// 文档分段入库
List<String> fids = new ArrayList<>();
try {
content = resourceLoader.getContent(file.getInputStream());
@@ -265,6 +266,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
List<KnowledgeFragment> knowledgeFragmentList = new ArrayList<>();
if (CollUtil.isNotEmpty(chunkList)) {
for (int i = 0; i < chunkList.size(); i++) {
// 生成知识片段ID
String fid = RandomUtil.randomString(10);
fids.add(fid);
KnowledgeFragment knowledgeFragment = new KnowledgeFragment();