mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-03-18 23:23:43 +08:00
feat: Weaviate操作向量库功能优化
This commit is contained in:
@@ -0,0 +1,43 @@
|
||||
package org.ruoyi.domain.bo;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* 查询向量所需参数
|
||||
* @author ageer
|
||||
*/
|
||||
@Data
|
||||
public class QueryVectorBo {
|
||||
|
||||
/**
|
||||
* 查询内容
|
||||
*/
|
||||
private String query;
|
||||
|
||||
/**
|
||||
* 知识库kid
|
||||
*/
|
||||
private String kid;
|
||||
|
||||
/**
|
||||
* 查询向量返回条数
|
||||
*/
|
||||
private Integer maxResults;
|
||||
|
||||
/**
|
||||
* 模型名称
|
||||
*/
|
||||
private String modelName;
|
||||
|
||||
/**
|
||||
* 请求key
|
||||
*/
|
||||
private String apiKey;
|
||||
|
||||
/**
|
||||
* 请求地址
|
||||
*/
|
||||
private String baseUrl;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package org.ruoyi.domain.bo;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 保存向量所需参数
|
||||
* @author ageer
|
||||
*/
|
||||
@Data
|
||||
public class StoreEmbeddingBo {
|
||||
|
||||
/**
|
||||
* 切分文本块列表
|
||||
*/
|
||||
private List<String> chunkList;
|
||||
|
||||
/**
|
||||
* 知识库kid
|
||||
*/
|
||||
private String kid;
|
||||
|
||||
/**
|
||||
* 文档id
|
||||
*/
|
||||
private String docId;
|
||||
|
||||
/**
|
||||
* 知识块id列表
|
||||
*/
|
||||
private List<String> fids;
|
||||
|
||||
/**
|
||||
* 模型名称
|
||||
*/
|
||||
private String modelName;
|
||||
|
||||
/**
|
||||
* 请求key
|
||||
*/
|
||||
private String apiKey;
|
||||
|
||||
/**
|
||||
* 请求地址
|
||||
*/
|
||||
private String baseUrl;
|
||||
|
||||
}
|
||||
@@ -1,20 +1,23 @@
|
||||
package org.ruoyi.service;
|
||||
|
||||
import org.ruoyi.domain.bo.QueryVectorBo;
|
||||
import org.ruoyi.domain.bo.StoreEmbeddingBo;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* @author ageer
|
||||
* 向量库管理
|
||||
* @author ageer
|
||||
*/
|
||||
public interface VectorStoreService {
|
||||
|
||||
void storeEmbeddings(List<String> chunkList, String kid,String docId,List<String> fids);
|
||||
void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo);
|
||||
|
||||
void removeByDocId(String kid,String docId);
|
||||
|
||||
void removeByKid(String kid);
|
||||
|
||||
List<String> getQueryVector(String query, String kid);
|
||||
List<String> getQueryVector(QueryVectorBo queryVectorBo);
|
||||
|
||||
void createSchema(String kid);
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package org.ruoyi.service.impl;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
@@ -11,9 +12,12 @@ import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
|
||||
import dev.langchain4j.store.embedding.weaviate.WeaviateEmbeddingStore;
|
||||
import jakarta.annotation.PostConstruct;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.common.core.service.ConfigService;
|
||||
import org.ruoyi.domain.bo.QueryVectorBo;
|
||||
import org.ruoyi.domain.bo.StoreEmbeddingBo;
|
||||
import org.ruoyi.service.VectorStoreService;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@@ -23,9 +27,11 @@ import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Weaviate向量库管理
|
||||
* @author ageer
|
||||
* Weaviate 向量库管理
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
@@ -37,38 +43,7 @@ public class WeaviateVectorStoreImpl implements VectorStoreService {
|
||||
private final ConfigService configService;
|
||||
|
||||
@Override
|
||||
public List<String> getQueryVector(String query, String kid) {
|
||||
EmbeddingModel embeddingModel = OpenAiEmbeddingModel.builder()
|
||||
.apiKey("sk-xxx")
|
||||
.baseUrl("https://api.pandarobot.chat/v1/")
|
||||
.modelName(TEXT_EMBEDDING_3_SMALL)
|
||||
.build();
|
||||
|
||||
// Filter simpleFilter = new IsEqualTo("kid", kid);
|
||||
|
||||
// createSchema(kid);
|
||||
|
||||
Embedding queryEmbedding = embeddingModel.embed("聊天补全模型").content();
|
||||
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
|
||||
.queryEmbedding(queryEmbedding)
|
||||
.maxResults(2)
|
||||
// 添加过滤条件
|
||||
// .filter(simpleFilter)
|
||||
.build();
|
||||
List<EmbeddingMatch<TextSegment>> matches = embeddingStore.search(embeddingSearchRequest).matches();
|
||||
|
||||
|
||||
|
||||
List<String> results = new ArrayList<>();
|
||||
|
||||
matches.forEach(embeddingMatch -> {
|
||||
results.add(embeddingMatch.embedded().text());
|
||||
});
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
@Override
|
||||
@PostConstruct
|
||||
public void createSchema(String kid) {
|
||||
String protocol = configService.getConfigValue("weaviate", "protocol");
|
||||
String host = configService.getConfigValue("weaviate", "host");
|
||||
@@ -84,24 +59,42 @@ public class WeaviateVectorStoreImpl implements VectorStoreService {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void storeEmbeddings(List<String> chunkList,String kid,String docId,List<String> fids) {
|
||||
EmbeddingModel embeddingModel = OpenAiEmbeddingModel.builder()
|
||||
.apiKey("sk-xxxx")
|
||||
.baseUrl("https://api.pandarobot.chat/v1/")
|
||||
.modelName(TEXT_EMBEDDING_3_SMALL)
|
||||
.build();
|
||||
|
||||
chunkList.forEach(chunk -> {
|
||||
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getModelName(),
|
||||
storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl());
|
||||
for (int i = 0; i < storeEmbeddingBo.getChunkList().size(); i++) {
|
||||
Map<String, Object> dataSchema = new HashMap<>();
|
||||
dataSchema.put("kid", kid);
|
||||
dataSchema.put("docId", docId);
|
||||
dataSchema.put("fid", fids.get(0));
|
||||
Response<Embedding> response = embeddingModel.embed(chunk);
|
||||
dataSchema.put("kid", storeEmbeddingBo.getKid());
|
||||
dataSchema.put("docId", storeEmbeddingBo.getKid());
|
||||
dataSchema.put("fid", storeEmbeddingBo.getFids().get(i));
|
||||
Response<Embedding> response = embeddingModel.embed(storeEmbeddingBo.getChunkList().get(i));
|
||||
Embedding embedding = response.content();
|
||||
TextSegment segment = TextSegment.from(chunk);
|
||||
TextSegment segment = TextSegment.from(storeEmbeddingBo.getChunkList().get(i));
|
||||
segment.metadata().putAll(dataSchema);
|
||||
embeddingStore.add(embedding,segment);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getModelName(),
|
||||
queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl());
|
||||
Filter simpleFilter = new IsEqualTo("kid", queryVectorBo.getKid());
|
||||
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
|
||||
.queryEmbedding(queryEmbedding)
|
||||
.maxResults(queryVectorBo.getMaxResults())
|
||||
// 添加过滤条件
|
||||
.filter(simpleFilter)
|
||||
.build();
|
||||
List<EmbeddingMatch<TextSegment>> matches = embeddingStore.search(embeddingSearchRequest).matches();
|
||||
|
||||
List<String> results = new ArrayList<>();
|
||||
|
||||
matches.forEach(embeddingMatch -> {
|
||||
results.add(embeddingMatch.embedded().text());
|
||||
});
|
||||
return results;
|
||||
}
|
||||
|
||||
|
||||
@@ -128,4 +121,25 @@ public class WeaviateVectorStoreImpl implements VectorStoreService {
|
||||
embeddingStore.removeAll(simpleFilterByAnd);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取向量模型
|
||||
*/
|
||||
public EmbeddingModel getEmbeddingModel(String modelName,String apiKey,String baseUrl) {
|
||||
EmbeddingModel embeddingModel = OpenAiEmbeddingModel.builder().build();
|
||||
if(TEXT_EMBEDDING_3_SMALL.toString().equals(modelName)) {
|
||||
embeddingModel = OpenAiEmbeddingModel.builder()
|
||||
.apiKey(apiKey)
|
||||
.baseUrl(baseUrl)
|
||||
.modelName(TEXT_EMBEDDING_3_SMALL)
|
||||
.build();
|
||||
// TODO 添加枚举
|
||||
}else if("quentinz/bge-large-zh-v1.5".equals(modelName)) {
|
||||
embeddingModel = OllamaEmbeddingModel.builder()
|
||||
.baseUrl(baseUrl)
|
||||
.modelName(modelName)
|
||||
.build();
|
||||
}
|
||||
return embeddingModel;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user