feat: 接入langchain4j操作向量库

This commit is contained in:
ageerle
2025-05-07 17:33:22 +08:00
parent 731f6ceb6e
commit 1a645c6e10
7 changed files with 133 additions and 441 deletions

View File

@@ -1,20 +0,0 @@
package org.ruoyi.service;
import java.util.List;
public interface EmbeddingService {
void storeEmbeddings(List<String> chunkList, String kid, String docId,List<String> fidList);
void removeByDocId(String kid,String docId);
void removeByKid(String kid);
List<Double> getQueryVector(String query, String kid);
void createSchema(String kid);
void removeByKidAndFid(String kid, String fid);
void saveFragment(String kid, String docId, String fid, String content);
}

View File

@@ -2,22 +2,18 @@ package org.ruoyi.service;
import java.util.List;
/**
* 向量存储
*/
public interface VectorStoreService {
void storeEmbeddings(List<String> chunkList, List<List<Double>> vectorList, String kid, String docId, List<String> fidList);
void storeEmbeddings(List<String> chunkList, String kid);
void removeByDocId(String kid, String docId);
void removeByDocId(String kid,String docId);
void removeByKid(String kid);
List<String> nearest(List<Double> queryVector, String kid);
List<String> getQueryVector(String query, String kid);
List<String> nearest(String query, String kid);
void newSchema(String kid);
void createSchema(String kid);
void removeByKidAndFid(String kid, String fid);
}

View File

@@ -1,64 +0,0 @@
package org.ruoyi.service.impl;
import lombok.AllArgsConstructor;
import org.ruoyi.service.EmbeddingService;
import org.ruoyi.service.VectorStoreService;
import org.ruoyi.service.VectorizationService;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List;
@Service
@AllArgsConstructor
public class EmbeddingServiceImpl implements EmbeddingService {
private final VectorStoreService vectorStore;
private final VectorizationService vectorization;
/**
* 保存向量数据库
* @param chunkList 文档按行切分的片段
* @param kid 知识库ID
* @param docId 文档ID
*/
@Override
public void storeEmbeddings(List<String> chunkList, String kid, String docId,List<String> fidList) {
List<List<Double>> vectorList = vectorization.batchVectorization(chunkList, kid);
vectorStore.storeEmbeddings(chunkList,vectorList,kid,docId,fidList);
}
@Override
public void removeByDocId(String kid,String docId) {
vectorStore.removeByDocId(kid,docId);
}
@Override
public void removeByKid(String kid) {
vectorStore.removeByKid(kid);
}
@Override
public List<Double> getQueryVector(String query, String kid) {
return vectorization.singleVectorization(query,kid);
}
@Override
public void createSchema(String kid) {
vectorStore.newSchema(kid);
}
@Override
public void removeByKidAndFid(String kid, String fid) {
vectorStore.removeByKidAndFid(kid,fid);
}
@Override
public void saveFragment(String kid, String docId, String fid, String content) {
List<String> chunkList = new ArrayList<>();
List<String> fidList = new ArrayList<>();
chunkList.add(content);
fidList.add(fid);
storeEmbeddings(chunkList,kid,docId,fidList);
}
}

View File

@@ -1,37 +1,25 @@
package org.ruoyi.service.impl;
import cn.hutool.core.lang.UUID;
import cn.hutool.json.JSONObject;
import com.google.gson.internal.LinkedTreeMap;
import io.weaviate.client.Config;
import io.weaviate.client.WeaviateClient;
import io.weaviate.client.base.Result;
import io.weaviate.client.v1.data.model.WeaviateObject;
import io.weaviate.client.v1.data.replication.model.ConsistencyLevel;
import io.weaviate.client.v1.filters.Operator;
import io.weaviate.client.v1.filters.WhereFilter;
import io.weaviate.client.v1.graphql.model.GraphQLResponse;
import io.weaviate.client.v1.graphql.query.argument.NearTextArgument;
import io.weaviate.client.v1.graphql.query.argument.NearVectorArgument;
import io.weaviate.client.v1.graphql.query.fields.Field;
import io.weaviate.client.v1.misc.model.Meta;
import io.weaviate.client.v1.misc.model.ReplicationConfig;
import io.weaviate.client.v1.misc.model.ShardingConfig;
import io.weaviate.client.v1.misc.model.VectorIndexConfig;
import io.weaviate.client.v1.schema.model.DataType;
import io.weaviate.client.v1.schema.model.Property;
import io.weaviate.client.v1.schema.model.Schema;
import io.weaviate.client.v1.schema.model.WeaviateClass;
import cn.hutool.core.util.RandomUtil;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
import dev.langchain4j.store.embedding.weaviate.WeaviateEmbeddingStore;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.ruoyi.common.core.service.ConfigService;
import org.ruoyi.domain.vo.KnowledgeInfoVo;
import org.ruoyi.service.IKnowledgeInfoService;
import org.ruoyi.service.VectorStoreService;
import org.ruoyi.service.IKnowledgeInfoService;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import org.testcontainers.weaviate.WeaviateContainer;
import java.util.ArrayList;
import java.util.HashMap;
@@ -54,6 +42,8 @@ public class WeaviateVectorStoreImpl implements VectorStoreService {
@Resource
private ConfigService configService;
private EmbeddingStore<TextSegment> embeddingStore;
@PostConstruct
public void loadConfig() {
this.protocol = configService.getConfigValue("weaviate", "protocol");
@@ -61,342 +51,94 @@ public class WeaviateVectorStoreImpl implements VectorStoreService {
this.className = configService.getConfigValue("weaviate", "classname");
}
public WeaviateClient getClient() {
Config config = new Config(protocol, host);
WeaviateClient client = new WeaviateClient(config);
return client;
}
public Result<Meta> getMeta() {
WeaviateClient client = getClient();
Result<Meta> meta = client.misc().metaGetter().run();
if (meta.getError() == null) {
System.out.printf("meta.hostname: %s\n", meta.getResult().getHostname());
System.out.printf("meta.version: %s\n", meta.getResult().getVersion());
System.out.printf("meta.modules: %s\n", meta.getResult().getModules());
} else {
System.out.printf("Error: %s\n", meta.getError().getMessages());
}
return meta;
}
public Result<Schema> getSchemas() {
WeaviateClient client = getClient();
Result<Schema> result = client.schema().getter().run();
if (result.hasErrors()) {
System.out.println(result.getError());
} else {
System.out.println(result.getResult());
}
return result;
}
public Result<Boolean> createSchema(String kid) {
WeaviateClient client = getClient();
VectorIndexConfig vectorIndexConfig = VectorIndexConfig.builder()
.distance("cosine")
.cleanupIntervalSeconds(300)
.efConstruction(128)
.maxConnections(64)
.vectorCacheMaxObjects(500000L)
.ef(-1)
.skip(false)
.dynamicEfFactor(8)
.dynamicEfMax(500)
.dynamicEfMin(100)
.flatSearchCutoff(40000)
@Override
public List<String> getQueryVector(String query, String kid) {
EmbeddingModel embeddingModel = OpenAiEmbeddingModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.baseUrl(System.getenv("OPENAI_BASE_URL"))
.modelName("text-embedding-3-small")
.build();
ShardingConfig shardingConfig = ShardingConfig.builder()
.desiredCount(3)
.desiredVirtualCount(128)
.function("murmur3")
.key("_id")
.strategy("hash")
.virtualPerPhysical(128)
Filter simpleFilter = new IsEqualTo("kid", kid);
Embedding queryEmbedding = embeddingModel.embed("What is your favourite sport?").content();
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
.queryEmbedding(queryEmbedding)
.maxResults(3)
// 添加过滤条件
.filter(simpleFilter)
.build();
List<EmbeddingMatch<TextSegment>> matches = embeddingStore.search(embeddingSearchRequest).matches();
ReplicationConfig replicationConfig = ReplicationConfig.builder()
.factor(1)
.build();
List<String> results = new ArrayList<>();
JSONObject classModuleConfigValue = new JSONObject();
classModuleConfigValue.put("vectorizeClassName", false);
JSONObject classModuleConfig = new JSONObject();
classModuleConfig.put("text2vec-transformers", classModuleConfigValue);
matches.forEach(embeddingMatch -> {
results.add(embeddingMatch.embedded().text());
});
JSONObject propertyModuleConfigValueSkipTrue = new JSONObject();
propertyModuleConfigValueSkipTrue.put("vectorizePropertyName", false);
propertyModuleConfigValueSkipTrue.put("skip", true);
JSONObject propertyModuleConfigSkipTrue = new JSONObject();
propertyModuleConfigSkipTrue.put("text2vec-transformers", propertyModuleConfigValueSkipTrue);
JSONObject propertyModuleConfigValueSkipFalse = new JSONObject();
propertyModuleConfigValueSkipFalse.put("vectorizePropertyName", false);
propertyModuleConfigValueSkipFalse.put("skip", false);
JSONObject propertyModuleConfigSkipFalse = new JSONObject();
propertyModuleConfigSkipFalse.put("text2vec-transformers", propertyModuleConfigValueSkipFalse);
WeaviateClass clazz = WeaviateClass.builder()
.className(className + kid)
.description("local knowledge")
.vectorIndexType("hnsw")
.vectorizer("text2vec-transformers")
.shardingConfig(shardingConfig)
.vectorIndexConfig(vectorIndexConfig)
.replicationConfig(replicationConfig)
.moduleConfig(classModuleConfig)
.properties(new ArrayList() {
{
add(Property.builder()
.dataType(new ArrayList() {
{
add(DataType.TEXT);
}
})
.name("content")
.description("The content of the local knowledge,for search")
.moduleConfig(propertyModuleConfigSkipFalse)
.build());
add(Property.builder()
.dataType(new ArrayList() {
{
add(DataType.TEXT);
}
})
.name("kid")
.description("The knowledge id of the local knowledge,for search")
.moduleConfig(propertyModuleConfigSkipTrue)
.build());
add(Property.builder()
.dataType(new ArrayList() {
{
add(DataType.TEXT);
}
})
.name("docId")
.description("The doc id of the local knowledge,for search")
.moduleConfig(propertyModuleConfigSkipTrue)
.build());
add(Property.builder()
.dataType(new ArrayList() {
{
add(DataType.TEXT);
}
})
.name("fid")
.description("The fragment id of the local knowledge,for search")
.moduleConfig(propertyModuleConfigSkipTrue)
.build());
add(Property.builder()
.dataType(new ArrayList() {
{
add(DataType.TEXT);
}
})
.name("uuid")
.description("The uuid id of the local knowledge fragment(same with id properties),for search")
.moduleConfig(propertyModuleConfigSkipTrue)
.build());
} })
.build();
Result<Boolean> result = client.schema().classCreator().withClass(clazz).run();
if (result.hasErrors()) {
System.out.println(result.getError());
}
System.out.println(result.getResult());
return result;
return results;
}
@Override
public void newSchema(String kid) {
createSchema(kid);
}
@Override
public void removeByKidAndFid(String kid, String fid) {
List<String> resultList = new ArrayList<>();
WeaviateClient client = getClient();
Field fieldId = Field.builder().name("uuid").build();
WhereFilter where = WhereFilter.builder()
.path(new String[]{"fid"})
.operator(Operator.Equal)
.valueString(fid)
public void createSchema(String kid) {
WeaviateContainer weaviate = new WeaviateContainer(protocol);
weaviate.start();
this.embeddingStore = WeaviateEmbeddingStore.builder()
.scheme("http")
.host(host)
.objectClass(className+kid)
.scheme(protocol)
.avoidDups(true)
.consistencyLevel("ALL")
.build();
Result<GraphQLResponse> result = client.graphQL().get()
.withClassName(className + kid)
.withFields(fieldId)
.withWhere(where)
.run();
LinkedTreeMap<String, Object> t = (LinkedTreeMap<String, Object>) result.getResult().getData();
LinkedTreeMap<String, ArrayList<LinkedTreeMap>> l = (LinkedTreeMap<String, ArrayList<LinkedTreeMap>>) t.get("Get");
ArrayList<LinkedTreeMap> m = l.get(className + kid);
for (LinkedTreeMap linkedTreeMap : m) {
String uuid = linkedTreeMap.get("uuid").toString();
resultList.add(uuid);
}
for (String uuid : resultList) {
Result<Boolean> deleteResult = client.data().deleter()
.withID(uuid)
.withClassName(className + kid)
.withConsistencyLevel(ConsistencyLevel.ALL) // default QUORUM
.run();
}
}
@Override
public void storeEmbeddings(List<String> chunkList, List<List<Double>> vectorList, String kid, String docId, List<String> fidList) {
WeaviateClient client = getClient();
for (int i = 0; i < Math.min(chunkList.size(), vectorList.size()); i++) {
List<Double> vector = vectorList.get(i);
Float[] vf = vector.stream().map(Double::floatValue).toArray(Float[]::new);
public void storeEmbeddings(List<String> chunkList,String kid) {
EmbeddingModel embeddingModel = OpenAiEmbeddingModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.baseUrl(System.getenv("OPENAI_BASE_URL"))
.modelName("text-embedding-3-small")
.build();
// 生成文档id
String docId = RandomUtil.randomString(10);
chunkList.forEach(chunk -> {
// 生成知识块id
String fid = RandomUtil.randomString(10);
Map<String, Object> dataSchema = new HashMap<>();
dataSchema.put("content", chunkList.get(i));
dataSchema.put("kid", kid);
dataSchema.put("docId", docId);
dataSchema.put("fid", fidList.get(i));
String uuid = UUID.randomUUID().toString();
dataSchema.put("uuid", uuid);
dataSchema.put("fid", fid);
TextSegment segment = TextSegment.from(chunk);
segment.metadata().putAll(dataSchema);
Embedding content = embeddingModel.embed(segment).content();
embeddingStore.add(content);
});
}
Result<WeaviateObject> result = client.data().creator()
.withClassName(className + kid)
.withID(uuid)
.withVector(vf)
.withProperties(dataSchema)
.run();
}
@Override
public void removeByKid(String kid) {
// 根据条件删除向量数据
Filter simpleFilter = new IsEqualTo("kid", kid);
embeddingStore.removeAll(simpleFilter);
}
@Override
public void removeByDocId(String kid, String docId) {
List<String> resultList = new ArrayList<>();
WeaviateClient client = getClient();
Field fieldId = Field.builder().name("uuid").build();
WhereFilter where = WhereFilter.builder()
.path(new String[]{"docId"})
.operator(Operator.Equal)
.valueString(docId)
.build();
Result<GraphQLResponse> result = client.graphQL().get()
.withClassName(className + kid)
.withFields(fieldId)
.withWhere(where)
.run();
LinkedTreeMap<String, Object> t = (LinkedTreeMap<String, Object>) result.getResult().getData();
LinkedTreeMap<String, ArrayList<LinkedTreeMap>> l = (LinkedTreeMap<String, ArrayList<LinkedTreeMap>>) t.get("Get");
ArrayList<LinkedTreeMap> m = l.get(className + kid);
for (LinkedTreeMap linkedTreeMap : m) {
String uuid = linkedTreeMap.get("uuid").toString();
resultList.add(uuid);
}
for (String uuid : resultList) {
Result<Boolean> deleteResult = client.data().deleter()
.withID(uuid)
.withClassName(className + kid)
.withConsistencyLevel(ConsistencyLevel.ALL) // default QUORUM
.run();
}
// 根据条件删除向量数据
Filter simpleFilterByDocId = new IsEqualTo("docId", docId);
embeddingStore.removeAll(simpleFilterByDocId);
}
@Override
public void removeByKid(String kid) {
WeaviateClient client = getClient();
Result<Boolean> result = client.schema().classDeleter().withClassName(className + kid).run();
if (result.hasErrors()) {
System.out.println("删除schema失败" + result.getError());
} else {
System.out.println("删除schema成功" + result.getResult());
}
log.info("drop schema by kid, result = {}", result);
public void removeByKidAndFid(String kid, String fid) {
// 根据条件删除向量数据
Filter simpleFilterByKid = new IsEqualTo("kid", kid);
Filter simpleFilterFid = new IsEqualTo("fid", fid);
Filter simpleFilterByAnd = Filter.and(simpleFilterFid, simpleFilterByKid);
embeddingStore.removeAll(simpleFilterByAnd);
}
@Override
public List<String> nearest(List<Double> queryVector, String kid) {
if (StringUtils.isBlank(kid)) {
return new ArrayList<String>();
}
List<String> resultList = new ArrayList<>();
Float[] vf = new Float[queryVector.size()];
for (int j = 0; j < queryVector.size(); j++) {
Double value = queryVector.get(j);
vf[j] = value.floatValue();
}
WeaviateClient client = getClient();
Field contentField = Field.builder().name("content").build();
Field _additional = Field.builder()
.name("_additional")
.fields(new Field[]{
Field.builder().name("distance").build()
}).build();
NearVectorArgument nearVector = NearVectorArgument.builder()
.vector(vf)
.distance(1.6f) // certainty = 1f - distance /2f
.build();
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
Result<GraphQLResponse> result = client.graphQL().get()
.withClassName(className + kid)
.withFields(contentField, _additional)
.withNearVector(nearVector)
.withLimit(knowledgeInfoVo.getRetrieveLimit())
.run();
LinkedTreeMap<String, Object> t = (LinkedTreeMap<String, Object>) result.getResult().getData();
LinkedTreeMap<String, ArrayList<LinkedTreeMap>> l = (LinkedTreeMap<String, ArrayList<LinkedTreeMap>>) t.get("Get");
ArrayList<LinkedTreeMap> m = l.get(className + kid);
for (LinkedTreeMap linkedTreeMap : m) {
String content = linkedTreeMap.get("content").toString();
resultList.add(content);
}
return resultList;
}
@Override
public List<String> nearest(String query, String kid) {
if (StringUtils.isBlank(kid)) {
return new ArrayList<String>();
}
List<String> resultList = new ArrayList<>();
WeaviateClient client = getClient();
Field contentField = Field.builder().name("content").build();
Field _additional = Field.builder()
.name("_additional")
.fields(new Field[]{
Field.builder().name("distance").build()
}).build();
NearTextArgument nearText = client.graphQL().arguments().nearTextArgBuilder()
.concepts(new String[]{query})
.distance(1.6f) // certainty = 1f - distance /2f
.build();
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
Result<GraphQLResponse> result = client.graphQL().get()
.withClassName(className + kid)
.withFields(contentField, _additional)
.withNearText(nearText)
.withLimit(knowledgeInfoVo.getRetrieveLimit())
.run();
LinkedTreeMap<String, Object> t = (LinkedTreeMap<String, Object>) result.getResult().getData();
LinkedTreeMap<String, ArrayList<LinkedTreeMap>> l = (LinkedTreeMap<String, ArrayList<LinkedTreeMap>>) t.get("Get");
ArrayList<LinkedTreeMap> m = l.get(className + kid);
for (LinkedTreeMap linkedTreeMap : m) {
String content = linkedTreeMap.get("content").toString();
resultList.add(content);
}
return resultList;
}
public Result<Boolean> deleteSchema(String kid) {
WeaviateClient client = getClient();
Result<Boolean> result = client.schema().classDeleter().withClassName(className + kid).run();
if (result.hasErrors()) {
System.out.println(result.getError());
} else {
System.out.println(result.getResult());
}
return result;
}
}