mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-04-16 21:33:40 +00:00
Compare commits
4 Commits
731f6ceb6e
...
980df20752
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
980df20752 | ||
|
|
aa92d232bb | ||
|
|
81c0bb5738 | ||
|
|
1a645c6e10 |
@@ -16,8 +16,21 @@
|
||||
<maven.compiler.source>17</maven.compiler.source>
|
||||
<maven.compiler.target>17</maven.compiler.target>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<langchain4j.version>1.0.0-beta4</langchain4j.version>
|
||||
</properties>
|
||||
|
||||
<dependencyManagement>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-bom</artifactId>
|
||||
<version>${langchain4j.version}</version>
|
||||
<type>pom</type>
|
||||
<scope>import</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</dependencyManagement>
|
||||
|
||||
<dependencies>
|
||||
|
||||
<!-- pdf解析器 -->
|
||||
@@ -34,17 +47,60 @@
|
||||
<version>1.0.79</version>
|
||||
</dependency>
|
||||
|
||||
<!-- milvus java sdk -->
|
||||
<dependency>
|
||||
<groupId>io.milvus</groupId>
|
||||
<artifactId>milvus-sdk-java</artifactId>
|
||||
<version>2.3.2</version>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j</artifactId>
|
||||
</dependency>
|
||||
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-weaviate</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>io.weaviate</groupId>
|
||||
<artifactId>client</artifactId>
|
||||
<version>4.0.0</version>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId>
|
||||
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.testcontainers</groupId>
|
||||
<artifactId>weaviate</artifactId>
|
||||
<version>1.19.6</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-open-ai</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-ollama</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-milvus</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.testcontainers</groupId>
|
||||
<artifactId>milvus</artifactId>
|
||||
<version>1.19.6</version>
|
||||
</dependency>
|
||||
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-qdrant</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.testcontainers</groupId>
|
||||
<artifactId>qdrant</artifactId>
|
||||
<version>1.19.6</version>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
package org.ruoyi.domain.bo;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* 查询向量所需参数
|
||||
* @author ageer
|
||||
*/
|
||||
@Data
|
||||
public class QueryVectorBo {
|
||||
|
||||
/**
|
||||
* 查询内容
|
||||
*/
|
||||
private String query;
|
||||
|
||||
/**
|
||||
* 知识库kid
|
||||
*/
|
||||
private String kid;
|
||||
|
||||
/**
|
||||
* 查询向量返回条数
|
||||
*/
|
||||
private Integer maxResults;
|
||||
|
||||
/**
|
||||
* 模型名称
|
||||
*/
|
||||
private String modelName;
|
||||
|
||||
/**
|
||||
* 请求key
|
||||
*/
|
||||
private String apiKey;
|
||||
|
||||
/**
|
||||
* 请求地址
|
||||
*/
|
||||
private String baseUrl;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package org.ruoyi.domain.bo;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 保存向量所需参数
|
||||
* @author ageer
|
||||
*/
|
||||
@Data
|
||||
public class StoreEmbeddingBo {
|
||||
|
||||
/**
|
||||
* 切分文本块列表
|
||||
*/
|
||||
private List<String> chunkList;
|
||||
|
||||
/**
|
||||
* 知识库kid
|
||||
*/
|
||||
private String kid;
|
||||
|
||||
/**
|
||||
* 文档id
|
||||
*/
|
||||
private String docId;
|
||||
|
||||
/**
|
||||
* 知识块id列表
|
||||
*/
|
||||
private List<String> fids;
|
||||
|
||||
/**
|
||||
* 模型名称
|
||||
*/
|
||||
private String modelName;
|
||||
|
||||
/**
|
||||
* 请求key
|
||||
*/
|
||||
private String apiKey;
|
||||
|
||||
/**
|
||||
* 请求地址
|
||||
*/
|
||||
private String baseUrl;
|
||||
|
||||
}
|
||||
@@ -1,20 +0,0 @@
|
||||
package org.ruoyi.service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
 * Embedding operations for knowledge bases, addressed by knowledge base id ({@code kid}).
 */
public interface EmbeddingService {

    /**
     * Stores the given text chunks for a document of a knowledge base.
     *
     * @param chunkList text chunks to store
     * @param kid       knowledge base id
     * @param docId     document id
     * @param fidList   fragment ids passed along with the chunks
     */
    void storeEmbeddings(List<String> chunkList, String kid, String docId, List<String> fidList);

    /**
     * Removes the data of one document.
     *
     * @param kid   knowledge base id
     * @param docId document id
     */
    void removeByDocId(String kid, String docId);

    /**
     * Removes the data of a knowledge base.
     *
     * @param kid knowledge base id
     */
    void removeByKid(String kid);

    /**
     * Produces the vector for a query string.
     *
     * @param query query text
     * @param kid   knowledge base id
     * @return the query vector
     */
    List<Double> getQueryVector(String query, String kid);

    /**
     * Creates the schema for a knowledge base.
     *
     * @param kid knowledge base id
     */
    void createSchema(String kid);

    /**
     * Removes the data of one fragment.
     *
     * @param kid knowledge base id
     * @param fid fragment id
     */
    void removeByKidAndFid(String kid, String fid);

    /**
     * Saves a single fragment.
     *
     * @param kid     knowledge base id
     * @param docId   document id
     * @param fid     fragment id
     * @param content fragment text
     */
    void saveFragment(String kid, String docId, String fid, String content);
}
|
||||
@@ -1,23 +1,26 @@
|
||||
package org.ruoyi.service;
|
||||
|
||||
import org.ruoyi.domain.bo.QueryVectorBo;
|
||||
import org.ruoyi.domain.bo.StoreEmbeddingBo;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Vector storage
|
||||
* Vector store management
|
||||
* NOTE(review): this span appears to interleave the pre-change and post-change
* declarations of the interface (removeByDocId is declared twice) - confirm which
* set is current before relying on it.
* @author ageer
|
||||
*/
|
||||
public interface VectorStoreService {
|
||||
|
||||
void storeEmbeddings(List<String> chunkList, List<List<Double>> vectorList, String kid, String docId, List<String> fidList);
|
||||
void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo);
|
||||
|
||||
void removeByDocId(String kid, String docId);
|
||||
void removeByDocId(String kid,String docId);
|
||||
|
||||
void removeByKid(String kid);
|
||||
|
||||
List<String> nearest(List<Double> queryVector, String kid);
|
||||
List<String> getQueryVector(QueryVectorBo queryVectorBo);
|
||||
|
||||
List<String> nearest(String query, String kid);
|
||||
|
||||
void newSchema(String kid);
|
||||
void createSchema(String kid,String modelName);
|
||||
|
||||
void removeByKidAndFid(String kid, String fid);
|
||||
|
||||
}
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
package org.ruoyi.service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
 * Text vectorization.
 */
public interface VectorizationService {

    /**
     * Vectorizes a list of text chunks.
     *
     * @param chunkList text chunks to vectorize
     * @param kid       knowledge base id
     * @return one vector per chunk
     */
    List<List<Double>> batchVectorization(List<String> chunkList, String kid);

    /**
     * Vectorizes a single text chunk.
     *
     * @param chunk text to vectorize
     * @param kid   knowledge base id
     * @return the vector of the chunk
     */
    List<Double> singleVectorization(String chunk, String kid);
}
|
||||
@@ -1,64 +0,0 @@
|
||||
package org.ruoyi.service.impl;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import org.ruoyi.service.EmbeddingService;
|
||||
import org.ruoyi.service.VectorStoreService;
|
||||
import org.ruoyi.service.VectorizationService;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Service
|
||||
@AllArgsConstructor
|
||||
public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
|
||||
private final VectorStoreService vectorStore;
|
||||
private final VectorizationService vectorization;
|
||||
|
||||
/**
|
||||
* 保存向量数据库
|
||||
* @param chunkList 文档按行切分的片段
|
||||
* @param kid 知识库ID
|
||||
* @param docId 文档ID
|
||||
*/
|
||||
@Override
|
||||
public void storeEmbeddings(List<String> chunkList, String kid, String docId,List<String> fidList) {
|
||||
List<List<Double>> vectorList = vectorization.batchVectorization(chunkList, kid);
|
||||
vectorStore.storeEmbeddings(chunkList,vectorList,kid,docId,fidList);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeByDocId(String kid,String docId) {
|
||||
vectorStore.removeByDocId(kid,docId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeByKid(String kid) {
|
||||
vectorStore.removeByKid(kid);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Double> getQueryVector(String query, String kid) {
|
||||
return vectorization.singleVectorization(query,kid);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void createSchema(String kid) {
|
||||
vectorStore.newSchema(kid);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeByKidAndFid(String kid, String fid) {
|
||||
vectorStore.removeByKidAndFid(kid,fid);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void saveFragment(String kid, String docId, String fid, String content) {
|
||||
List<String> chunkList = new ArrayList<>();
|
||||
List<String> fidList = new ArrayList<>();
|
||||
chunkList.add(content);
|
||||
fidList.add(fid);
|
||||
storeEmbeddings(chunkList,kid,docId,fidList);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,167 @@
|
||||
package org.ruoyi.service.impl;
|
||||
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
|
||||
import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.qdrant.QdrantEmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.weaviate.WeaviateEmbeddingStore;
|
||||
import jakarta.annotation.PostConstruct;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.common.core.service.ConfigService;
|
||||
import org.ruoyi.domain.bo.QueryVectorBo;
|
||||
import org.ruoyi.domain.bo.StoreEmbeddingBo;
|
||||
import org.ruoyi.service.VectorStoreService;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import static dev.langchain4j.model.openai.OpenAiEmbeddingModelName.TEXT_EMBEDDING_3_SMALL;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* 向量库管理
|
||||
* @author ageer
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
public class VectorStoreServiceImpl implements VectorStoreService {
|
||||
|
||||
private EmbeddingStore<TextSegment> embeddingStore;
|
||||
|
||||
private final ConfigService configService;
|
||||
|
||||
@Override
|
||||
@PostConstruct
|
||||
public void createSchema(String kid,String modelName) {
|
||||
if(modelName.equals("weaviate")){
|
||||
String protocol = configService.getConfigValue("weaviate", "protocol");
|
||||
String host = configService.getConfigValue("weaviate", "host");
|
||||
String className = configService.getConfigValue("weaviate", "classname");
|
||||
this.embeddingStore = WeaviateEmbeddingStore.builder()
|
||||
.scheme(protocol)
|
||||
.host(host)
|
||||
.objectClass(className+kid)
|
||||
.scheme(protocol)
|
||||
.avoidDups(true)
|
||||
.consistencyLevel("ALL")
|
||||
.build();
|
||||
}else if(modelName.equals("milvus")){
|
||||
String uri = configService.getConfigValue("milvus", "host");
|
||||
String collection = configService.getConfigValue("milvus", "collection");
|
||||
String dimension = configService.getConfigValue("milvus", "dimension");
|
||||
this.embeddingStore = MilvusEmbeddingStore.builder()
|
||||
.uri(uri)
|
||||
.collectionName(collection+kid)
|
||||
.dimension(Integer.parseInt(dimension))
|
||||
.build();
|
||||
}else if(modelName.equals("qdrant")){
|
||||
String host = configService.getConfigValue("qdrant", "host");
|
||||
String port = configService.getConfigValue("qdrant", "port");
|
||||
String collectionName = configService.getConfigValue("qdrant", "collectionName");
|
||||
this.embeddingStore = QdrantEmbeddingStore.builder()
|
||||
.host(host)
|
||||
.port(Integer.parseInt(port))
|
||||
.collectionName(collectionName)
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getModelName(),
|
||||
storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl());
|
||||
for (int i = 0; i < storeEmbeddingBo.getChunkList().size(); i++) {
|
||||
Map<String, Object> dataSchema = new HashMap<>();
|
||||
dataSchema.put("kid", storeEmbeddingBo.getKid());
|
||||
dataSchema.put("docId", storeEmbeddingBo.getKid());
|
||||
dataSchema.put("fid", storeEmbeddingBo.getFids().get(i));
|
||||
Response<Embedding> response = embeddingModel.embed(storeEmbeddingBo.getChunkList().get(i));
|
||||
Embedding embedding = response.content();
|
||||
TextSegment segment = TextSegment.from(storeEmbeddingBo.getChunkList().get(i));
|
||||
segment.metadata().putAll(dataSchema);
|
||||
embeddingStore.add(embedding,segment);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
|
||||
EmbeddingModel embeddingModel = getEmbeddingModel(queryVectorBo.getModelName(),
|
||||
queryVectorBo.getApiKey(), queryVectorBo.getBaseUrl());
|
||||
Filter simpleFilter = new IsEqualTo("kid", queryVectorBo.getKid());
|
||||
Embedding queryEmbedding = embeddingModel.embed(queryVectorBo.getQuery()).content();
|
||||
EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder()
|
||||
.queryEmbedding(queryEmbedding)
|
||||
.maxResults(queryVectorBo.getMaxResults())
|
||||
// 添加过滤条件
|
||||
.filter(simpleFilter)
|
||||
.build();
|
||||
List<EmbeddingMatch<TextSegment>> matches = embeddingStore.search(embeddingSearchRequest).matches();
|
||||
|
||||
List<String> results = new ArrayList<>();
|
||||
|
||||
matches.forEach(embeddingMatch -> {
|
||||
results.add(embeddingMatch.embedded().text());
|
||||
});
|
||||
return results;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void removeByKid(String kid) {
|
||||
// 根据条件删除向量数据
|
||||
Filter simpleFilter = new IsEqualTo("kid", kid);
|
||||
embeddingStore.removeAll(simpleFilter);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeByDocId(String kid, String docId) {
|
||||
// 根据条件删除向量数据
|
||||
Filter simpleFilterByDocId = new IsEqualTo("docId", docId);
|
||||
embeddingStore.removeAll(simpleFilterByDocId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeByKidAndFid(String kid, String fid) {
|
||||
// 根据条件删除向量数据
|
||||
Filter simpleFilterByKid = new IsEqualTo("kid", kid);
|
||||
Filter simpleFilterFid = new IsEqualTo("fid", fid);
|
||||
Filter simpleFilterByAnd = Filter.and(simpleFilterFid, simpleFilterByKid);
|
||||
embeddingStore.removeAll(simpleFilterByAnd);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取向量模型
|
||||
*/
|
||||
public EmbeddingModel getEmbeddingModel(String modelName,String apiKey,String baseUrl) {
|
||||
EmbeddingModel embeddingModel = OpenAiEmbeddingModel.builder().build();
|
||||
if(TEXT_EMBEDDING_3_SMALL.toString().equals(modelName)) {
|
||||
embeddingModel = OpenAiEmbeddingModel.builder()
|
||||
.apiKey(apiKey)
|
||||
.baseUrl(baseUrl)
|
||||
.modelName(TEXT_EMBEDDING_3_SMALL)
|
||||
.build();
|
||||
// TODO 添加枚举
|
||||
}else if("quentinz/bge-large-zh-v1.5".equals(modelName)) {
|
||||
embeddingModel = OllamaEmbeddingModel.builder()
|
||||
.baseUrl(baseUrl)
|
||||
.modelName(modelName)
|
||||
.build();
|
||||
}
|
||||
return embeddingModel;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,402 +0,0 @@
|
||||
package org.ruoyi.service.impl;
|
||||
|
||||
import cn.hutool.core.lang.UUID;
|
||||
import cn.hutool.json.JSONObject;
|
||||
import com.google.gson.internal.LinkedTreeMap;
|
||||
import io.weaviate.client.Config;
|
||||
import io.weaviate.client.WeaviateClient;
|
||||
import io.weaviate.client.base.Result;
|
||||
import io.weaviate.client.v1.data.model.WeaviateObject;
|
||||
import io.weaviate.client.v1.data.replication.model.ConsistencyLevel;
|
||||
import io.weaviate.client.v1.filters.Operator;
|
||||
import io.weaviate.client.v1.filters.WhereFilter;
|
||||
import io.weaviate.client.v1.graphql.model.GraphQLResponse;
|
||||
import io.weaviate.client.v1.graphql.query.argument.NearTextArgument;
|
||||
import io.weaviate.client.v1.graphql.query.argument.NearVectorArgument;
|
||||
import io.weaviate.client.v1.graphql.query.fields.Field;
|
||||
import io.weaviate.client.v1.misc.model.Meta;
|
||||
import io.weaviate.client.v1.misc.model.ReplicationConfig;
|
||||
import io.weaviate.client.v1.misc.model.ShardingConfig;
|
||||
import io.weaviate.client.v1.misc.model.VectorIndexConfig;
|
||||
import io.weaviate.client.v1.schema.model.DataType;
|
||||
import io.weaviate.client.v1.schema.model.Property;
|
||||
import io.weaviate.client.v1.schema.model.Schema;
|
||||
import io.weaviate.client.v1.schema.model.WeaviateClass;
|
||||
import jakarta.annotation.PostConstruct;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.ruoyi.common.core.service.ConfigService;
|
||||
import org.ruoyi.domain.vo.KnowledgeInfoVo;
|
||||
import org.ruoyi.service.IKnowledgeInfoService;
|
||||
import org.ruoyi.service.VectorStoreService;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class WeaviateVectorStoreImpl implements VectorStoreService {
|
||||
|
||||
private volatile String protocol;
|
||||
private volatile String host;
|
||||
private volatile String className;
|
||||
|
||||
@Lazy
|
||||
@Resource
|
||||
private IKnowledgeInfoService knowledgeInfoService;
|
||||
|
||||
@Lazy
|
||||
@Resource
|
||||
private ConfigService configService;
|
||||
|
||||
@PostConstruct
|
||||
public void loadConfig() {
|
||||
this.protocol = configService.getConfigValue("weaviate", "protocol");
|
||||
this.host = configService.getConfigValue("weaviate", "host");
|
||||
this.className = configService.getConfigValue("weaviate", "classname");
|
||||
}
|
||||
|
||||
public WeaviateClient getClient() {
|
||||
Config config = new Config(protocol, host);
|
||||
WeaviateClient client = new WeaviateClient(config);
|
||||
return client;
|
||||
}
|
||||
|
||||
public Result<Meta> getMeta() {
|
||||
WeaviateClient client = getClient();
|
||||
Result<Meta> meta = client.misc().metaGetter().run();
|
||||
if (meta.getError() == null) {
|
||||
System.out.printf("meta.hostname: %s\n", meta.getResult().getHostname());
|
||||
System.out.printf("meta.version: %s\n", meta.getResult().getVersion());
|
||||
System.out.printf("meta.modules: %s\n", meta.getResult().getModules());
|
||||
} else {
|
||||
System.out.printf("Error: %s\n", meta.getError().getMessages());
|
||||
}
|
||||
return meta;
|
||||
}
|
||||
|
||||
public Result<Schema> getSchemas() {
|
||||
WeaviateClient client = getClient();
|
||||
Result<Schema> result = client.schema().getter().run();
|
||||
if (result.hasErrors()) {
|
||||
System.out.println(result.getError());
|
||||
} else {
|
||||
System.out.println(result.getResult());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
public Result<Boolean> createSchema(String kid) {
|
||||
WeaviateClient client = getClient();
|
||||
|
||||
VectorIndexConfig vectorIndexConfig = VectorIndexConfig.builder()
|
||||
.distance("cosine")
|
||||
.cleanupIntervalSeconds(300)
|
||||
.efConstruction(128)
|
||||
.maxConnections(64)
|
||||
.vectorCacheMaxObjects(500000L)
|
||||
.ef(-1)
|
||||
.skip(false)
|
||||
.dynamicEfFactor(8)
|
||||
.dynamicEfMax(500)
|
||||
.dynamicEfMin(100)
|
||||
.flatSearchCutoff(40000)
|
||||
.build();
|
||||
|
||||
ShardingConfig shardingConfig = ShardingConfig.builder()
|
||||
.desiredCount(3)
|
||||
.desiredVirtualCount(128)
|
||||
.function("murmur3")
|
||||
.key("_id")
|
||||
.strategy("hash")
|
||||
.virtualPerPhysical(128)
|
||||
.build();
|
||||
|
||||
ReplicationConfig replicationConfig = ReplicationConfig.builder()
|
||||
.factor(1)
|
||||
.build();
|
||||
|
||||
JSONObject classModuleConfigValue = new JSONObject();
|
||||
classModuleConfigValue.put("vectorizeClassName", false);
|
||||
JSONObject classModuleConfig = new JSONObject();
|
||||
classModuleConfig.put("text2vec-transformers", classModuleConfigValue);
|
||||
|
||||
JSONObject propertyModuleConfigValueSkipTrue = new JSONObject();
|
||||
propertyModuleConfigValueSkipTrue.put("vectorizePropertyName", false);
|
||||
propertyModuleConfigValueSkipTrue.put("skip", true);
|
||||
JSONObject propertyModuleConfigSkipTrue = new JSONObject();
|
||||
propertyModuleConfigSkipTrue.put("text2vec-transformers", propertyModuleConfigValueSkipTrue);
|
||||
|
||||
JSONObject propertyModuleConfigValueSkipFalse = new JSONObject();
|
||||
propertyModuleConfigValueSkipFalse.put("vectorizePropertyName", false);
|
||||
propertyModuleConfigValueSkipFalse.put("skip", false);
|
||||
JSONObject propertyModuleConfigSkipFalse = new JSONObject();
|
||||
propertyModuleConfigSkipFalse.put("text2vec-transformers", propertyModuleConfigValueSkipFalse);
|
||||
|
||||
WeaviateClass clazz = WeaviateClass.builder()
|
||||
.className(className + kid)
|
||||
.description("local knowledge")
|
||||
.vectorIndexType("hnsw")
|
||||
.vectorizer("text2vec-transformers")
|
||||
.shardingConfig(shardingConfig)
|
||||
.vectorIndexConfig(vectorIndexConfig)
|
||||
.replicationConfig(replicationConfig)
|
||||
.moduleConfig(classModuleConfig)
|
||||
.properties(new ArrayList() {
|
||||
{
|
||||
add(Property.builder()
|
||||
.dataType(new ArrayList() {
|
||||
{
|
||||
add(DataType.TEXT);
|
||||
}
|
||||
})
|
||||
.name("content")
|
||||
.description("The content of the local knowledge,for search")
|
||||
.moduleConfig(propertyModuleConfigSkipFalse)
|
||||
.build());
|
||||
add(Property.builder()
|
||||
.dataType(new ArrayList() {
|
||||
{
|
||||
add(DataType.TEXT);
|
||||
}
|
||||
})
|
||||
.name("kid")
|
||||
.description("The knowledge id of the local knowledge,for search")
|
||||
.moduleConfig(propertyModuleConfigSkipTrue)
|
||||
.build());
|
||||
add(Property.builder()
|
||||
.dataType(new ArrayList() {
|
||||
{
|
||||
add(DataType.TEXT);
|
||||
}
|
||||
})
|
||||
.name("docId")
|
||||
.description("The doc id of the local knowledge,for search")
|
||||
.moduleConfig(propertyModuleConfigSkipTrue)
|
||||
.build());
|
||||
add(Property.builder()
|
||||
.dataType(new ArrayList() {
|
||||
{
|
||||
add(DataType.TEXT);
|
||||
}
|
||||
})
|
||||
.name("fid")
|
||||
.description("The fragment id of the local knowledge,for search")
|
||||
.moduleConfig(propertyModuleConfigSkipTrue)
|
||||
.build());
|
||||
add(Property.builder()
|
||||
.dataType(new ArrayList() {
|
||||
{
|
||||
add(DataType.TEXT);
|
||||
}
|
||||
})
|
||||
.name("uuid")
|
||||
.description("The uuid id of the local knowledge fragment(same with id properties),for search")
|
||||
.moduleConfig(propertyModuleConfigSkipTrue)
|
||||
.build());
|
||||
} })
|
||||
.build();
|
||||
|
||||
Result<Boolean> result = client.schema().classCreator().withClass(clazz).run();
|
||||
if (result.hasErrors()) {
|
||||
System.out.println(result.getError());
|
||||
}
|
||||
System.out.println(result.getResult());
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void newSchema(String kid) {
|
||||
createSchema(kid);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeByKidAndFid(String kid, String fid) {
|
||||
List<String> resultList = new ArrayList<>();
|
||||
WeaviateClient client = getClient();
|
||||
Field fieldId = Field.builder().name("uuid").build();
|
||||
WhereFilter where = WhereFilter.builder()
|
||||
.path(new String[]{"fid"})
|
||||
.operator(Operator.Equal)
|
||||
.valueString(fid)
|
||||
.build();
|
||||
Result<GraphQLResponse> result = client.graphQL().get()
|
||||
.withClassName(className + kid)
|
||||
.withFields(fieldId)
|
||||
.withWhere(where)
|
||||
.run();
|
||||
LinkedTreeMap<String, Object> t = (LinkedTreeMap<String, Object>) result.getResult().getData();
|
||||
LinkedTreeMap<String, ArrayList<LinkedTreeMap>> l = (LinkedTreeMap<String, ArrayList<LinkedTreeMap>>) t.get("Get");
|
||||
ArrayList<LinkedTreeMap> m = l.get(className + kid);
|
||||
for (LinkedTreeMap linkedTreeMap : m) {
|
||||
String uuid = linkedTreeMap.get("uuid").toString();
|
||||
resultList.add(uuid);
|
||||
}
|
||||
for (String uuid : resultList) {
|
||||
Result<Boolean> deleteResult = client.data().deleter()
|
||||
.withID(uuid)
|
||||
.withClassName(className + kid)
|
||||
.withConsistencyLevel(ConsistencyLevel.ALL) // default QUORUM
|
||||
.run();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void storeEmbeddings(List<String> chunkList, List<List<Double>> vectorList, String kid, String docId, List<String> fidList) {
|
||||
WeaviateClient client = getClient();
|
||||
|
||||
for (int i = 0; i < Math.min(chunkList.size(), vectorList.size()); i++) {
|
||||
List<Double> vector = vectorList.get(i);
|
||||
Float[] vf = vector.stream().map(Double::floatValue).toArray(Float[]::new);
|
||||
|
||||
Map<String, Object> dataSchema = new HashMap<>();
|
||||
dataSchema.put("content", chunkList.get(i));
|
||||
dataSchema.put("kid", kid);
|
||||
dataSchema.put("docId", docId);
|
||||
dataSchema.put("fid", fidList.get(i));
|
||||
String uuid = UUID.randomUUID().toString();
|
||||
dataSchema.put("uuid", uuid);
|
||||
|
||||
Result<WeaviateObject> result = client.data().creator()
|
||||
.withClassName(className + kid)
|
||||
.withID(uuid)
|
||||
.withVector(vf)
|
||||
.withProperties(dataSchema)
|
||||
.run();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeByDocId(String kid, String docId) {
|
||||
List<String> resultList = new ArrayList<>();
|
||||
WeaviateClient client = getClient();
|
||||
Field fieldId = Field.builder().name("uuid").build();
|
||||
WhereFilter where = WhereFilter.builder()
|
||||
.path(new String[]{"docId"})
|
||||
.operator(Operator.Equal)
|
||||
.valueString(docId)
|
||||
.build();
|
||||
Result<GraphQLResponse> result = client.graphQL().get()
|
||||
.withClassName(className + kid)
|
||||
.withFields(fieldId)
|
||||
.withWhere(where)
|
||||
.run();
|
||||
LinkedTreeMap<String, Object> t = (LinkedTreeMap<String, Object>) result.getResult().getData();
|
||||
LinkedTreeMap<String, ArrayList<LinkedTreeMap>> l = (LinkedTreeMap<String, ArrayList<LinkedTreeMap>>) t.get("Get");
|
||||
ArrayList<LinkedTreeMap> m = l.get(className + kid);
|
||||
for (LinkedTreeMap linkedTreeMap : m) {
|
||||
String uuid = linkedTreeMap.get("uuid").toString();
|
||||
resultList.add(uuid);
|
||||
}
|
||||
for (String uuid : resultList) {
|
||||
Result<Boolean> deleteResult = client.data().deleter()
|
||||
.withID(uuid)
|
||||
.withClassName(className + kid)
|
||||
.withConsistencyLevel(ConsistencyLevel.ALL) // default QUORUM
|
||||
.run();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeByKid(String kid) {
|
||||
WeaviateClient client = getClient();
|
||||
Result<Boolean> result = client.schema().classDeleter().withClassName(className + kid).run();
|
||||
if (result.hasErrors()) {
|
||||
System.out.println("删除schema失败" + result.getError());
|
||||
} else {
|
||||
System.out.println("删除schema成功" + result.getResult());
|
||||
}
|
||||
log.info("drop schema by kid, result = {}", result);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> nearest(List<Double> queryVector, String kid) {
|
||||
if (StringUtils.isBlank(kid)) {
|
||||
return new ArrayList<String>();
|
||||
}
|
||||
List<String> resultList = new ArrayList<>();
|
||||
Float[] vf = new Float[queryVector.size()];
|
||||
for (int j = 0; j < queryVector.size(); j++) {
|
||||
Double value = queryVector.get(j);
|
||||
vf[j] = value.floatValue();
|
||||
}
|
||||
WeaviateClient client = getClient();
|
||||
Field contentField = Field.builder().name("content").build();
|
||||
Field _additional = Field.builder()
|
||||
.name("_additional")
|
||||
.fields(new Field[]{
|
||||
Field.builder().name("distance").build()
|
||||
}).build();
|
||||
NearVectorArgument nearVector = NearVectorArgument.builder()
|
||||
.vector(vf)
|
||||
.distance(1.6f) // certainty = 1f - distance /2f
|
||||
.build();
|
||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
|
||||
Result<GraphQLResponse> result = client.graphQL().get()
|
||||
.withClassName(className + kid)
|
||||
.withFields(contentField, _additional)
|
||||
.withNearVector(nearVector)
|
||||
.withLimit(knowledgeInfoVo.getRetrieveLimit())
|
||||
.run();
|
||||
LinkedTreeMap<String, Object> t = (LinkedTreeMap<String, Object>) result.getResult().getData();
|
||||
LinkedTreeMap<String, ArrayList<LinkedTreeMap>> l = (LinkedTreeMap<String, ArrayList<LinkedTreeMap>>) t.get("Get");
|
||||
ArrayList<LinkedTreeMap> m = l.get(className + kid);
|
||||
for (LinkedTreeMap linkedTreeMap : m) {
|
||||
String content = linkedTreeMap.get("content").toString();
|
||||
resultList.add(content);
|
||||
}
|
||||
return resultList;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> nearest(String query, String kid) {
|
||||
if (StringUtils.isBlank(kid)) {
|
||||
return new ArrayList<String>();
|
||||
}
|
||||
List<String> resultList = new ArrayList<>();
|
||||
WeaviateClient client = getClient();
|
||||
Field contentField = Field.builder().name("content").build();
|
||||
Field _additional = Field.builder()
|
||||
.name("_additional")
|
||||
.fields(new Field[]{
|
||||
Field.builder().name("distance").build()
|
||||
}).build();
|
||||
NearTextArgument nearText = client.graphQL().arguments().nearTextArgBuilder()
|
||||
.concepts(new String[]{query})
|
||||
.distance(1.6f) // certainty = 1f - distance /2f
|
||||
.build();
|
||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
|
||||
Result<GraphQLResponse> result = client.graphQL().get()
|
||||
.withClassName(className + kid)
|
||||
.withFields(contentField, _additional)
|
||||
.withNearText(nearText)
|
||||
.withLimit(knowledgeInfoVo.getRetrieveLimit())
|
||||
.run();
|
||||
LinkedTreeMap<String, Object> t = (LinkedTreeMap<String, Object>) result.getResult().getData();
|
||||
LinkedTreeMap<String, ArrayList<LinkedTreeMap>> l = (LinkedTreeMap<String, ArrayList<LinkedTreeMap>>) t.get("Get");
|
||||
ArrayList<LinkedTreeMap> m = l.get(className + kid);
|
||||
for (LinkedTreeMap linkedTreeMap : m) {
|
||||
String content = linkedTreeMap.get("content").toString();
|
||||
resultList.add(content);
|
||||
}
|
||||
return resultList;
|
||||
}
|
||||
|
||||
public Result<Boolean> deleteSchema(String kid) {
|
||||
WeaviateClient client = getClient();
|
||||
Result<Boolean> result = client.schema().classDeleter().withClassName(className + kid).run();
|
||||
if (result.hasErrors()) {
|
||||
System.out.println(result.getError());
|
||||
} else {
|
||||
System.out.println(result.getResult());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
@@ -1,49 +0,0 @@
|
||||
package org.ruoyi.chat.factory;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import org.ruoyi.chat.service.knowledge.BgeLargeVectorizationImpl;
|
||||
import org.ruoyi.chat.service.knowledge.OpenAiVectorizationImpl;
|
||||
import org.ruoyi.domain.vo.KnowledgeInfoVo;
|
||||
import org.ruoyi.service.IKnowledgeInfoService;
|
||||
import org.ruoyi.service.VectorizationService;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
/**
|
||||
* 文本向量化
|
||||
* @author huangkh
|
||||
*/
|
||||
@Component
|
||||
@Slf4j
|
||||
public class VectorizationFactory {
|
||||
|
||||
private final OpenAiVectorizationImpl openAiVectorization;
|
||||
|
||||
private final BgeLargeVectorizationImpl bgeLargeVectorization;
|
||||
|
||||
@Lazy
|
||||
@Resource
|
||||
private IKnowledgeInfoService knowledgeInfoService;
|
||||
|
||||
public VectorizationFactory(OpenAiVectorizationImpl openAiVectorization, BgeLargeVectorizationImpl bgeLargeVectorization) {
|
||||
this.openAiVectorization = openAiVectorization;
|
||||
this.bgeLargeVectorization = bgeLargeVectorization;
|
||||
}
|
||||
|
||||
public VectorizationService getEmbedding(String kid){
|
||||
String vectorModel = "text-embedding-3-small";
|
||||
if (StrUtil.isNotEmpty(kid)) {
|
||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
|
||||
if (knowledgeInfoVo != null && StrUtil.isNotEmpty(knowledgeInfoVo.getVectorModel())) {
|
||||
vectorModel = knowledgeInfoVo.getVectorModel();
|
||||
}
|
||||
}
|
||||
return switch (vectorModel) {
|
||||
case "quentinz/bge-large-zh-v1.5" -> bgeLargeVectorization;
|
||||
default -> openAiVectorization;
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -24,13 +24,12 @@ import org.ruoyi.common.core.utils.StringUtils;
|
||||
import org.ruoyi.common.core.utils.file.FileUtils;
|
||||
import org.ruoyi.common.core.utils.file.MimeTypeUtils;
|
||||
import org.ruoyi.common.redis.utils.RedisUtils;
|
||||
import org.ruoyi.domain.ChatSession;
|
||||
import org.ruoyi.domain.bo.ChatSessionBo;
|
||||
import org.ruoyi.domain.bo.QueryVectorBo;
|
||||
import org.ruoyi.domain.vo.ChatModelVo;
|
||||
import org.ruoyi.service.EmbeddingService;
|
||||
import org.ruoyi.service.VectorStoreService;
|
||||
import org.ruoyi.service.IChatModelService;
|
||||
import org.ruoyi.service.IChatSessionService;
|
||||
import org.ruoyi.service.VectorStoreService;
|
||||
import org.springframework.core.io.InputStreamResource;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.http.MediaType;
|
||||
@@ -56,9 +55,7 @@ public class SseServiceImpl implements ISseService {
|
||||
|
||||
private final OpenAiStreamClient openAiStreamClient;
|
||||
|
||||
private final EmbeddingService embeddingService;
|
||||
|
||||
private final VectorStoreService vectorStore;
|
||||
private final VectorStoreService vectorStoreService;
|
||||
|
||||
private final IChatCostService chatCostService;
|
||||
|
||||
@@ -170,7 +167,10 @@ public class SseServiceImpl implements ISseService {
|
||||
// 获取对话消息列表
|
||||
List<Message> messages = chatRequest.getMessages();
|
||||
String sysPrompt = chatModelVo.getSystemPrompt();
|
||||
|
||||
|
||||
if(StringUtils.isEmpty(sysPrompt)){
|
||||
// TODO 系统默认提示词,后续会增加提示词管理
|
||||
sysPrompt ="你是一个由RuoYI-AI开发的人工智能助手,名字叫熊猫助手。你擅长中英文对话,能够理解并处理各种问题,提供安全、有帮助、准确的回答。" +
|
||||
"当前时间:"+ DateUtils.getDate()+
|
||||
"#注意:回复之前注意结合上下文和工具返回内容进行回复。";
|
||||
@@ -184,13 +184,20 @@ public class SseServiceImpl implements ISseService {
|
||||
if(StringUtils.isNotEmpty(chatRequest.getKid())){
|
||||
List<Message> knMessages = new ArrayList<>();
|
||||
String content = messages.get(messages.size() - 1).getContent().toString();
|
||||
List<String> nearestList;
|
||||
List<Double> queryVector = embeddingService.getQueryVector(content, chatRequest.getKid());
|
||||
nearestList = vectorStore.nearest(queryVector, chatRequest.getKid());
|
||||
QueryVectorBo queryVectorBo = new QueryVectorBo();
|
||||
queryVectorBo.setQuery(content);
|
||||
queryVectorBo.setKid(chatRequest.getKid());
|
||||
queryVectorBo.setApiKey(chatModelVo.getApiKey());
|
||||
queryVectorBo.setBaseUrl(chatModelVo.getApiHost());
|
||||
queryVectorBo.setModelName(chatModelVo.getModelName());
|
||||
// TODO 查询向量返回条数,这里应该查询知识库配置
|
||||
queryVectorBo.setMaxResults(3);
|
||||
List<String> nearestList = vectorStoreService.getQueryVector(queryVectorBo);
|
||||
for (String prompt : nearestList) {
|
||||
Message userMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
|
||||
knMessages.add(userMessage);
|
||||
}
|
||||
// TODO 提示词,这里应该查询知识库配置
|
||||
Message userMessage = Message.builder().content(content + (!nearestList.isEmpty() ? "\n\n注意:回答问题时,须严格根据我给你的系统上下文内容原文进行回答,请不要自己发挥,回答时保持原来文本的段落层级" : "")).role(Message.Role.USER).build();
|
||||
knMessages.add(userMessage);
|
||||
messages.addAll(knMessages);
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
package org.ruoyi.chat.service.knowledge;
|
||||
|
||||
import io.github.ollama4j.OllamaAPI;
|
||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.common.core.exception.ServiceException;
|
||||
import org.ruoyi.domain.vo.ChatModelVo;
|
||||
import org.ruoyi.domain.vo.KnowledgeInfoVo;
|
||||
import org.ruoyi.service.IChatModelService;
|
||||
import org.ruoyi.service.IKnowledgeInfoService;
|
||||
import org.ruoyi.service.VectorizationService;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* @author ageer
|
||||
*/
|
||||
@Component
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
public class BgeLargeVectorizationImpl implements VectorizationService {
|
||||
|
||||
@Lazy
|
||||
@Resource
|
||||
private IKnowledgeInfoService knowledgeInfoService;
|
||||
|
||||
@Lazy
|
||||
@Resource
|
||||
private final IChatModelService chatModelService;
|
||||
|
||||
@Override
|
||||
public List<List<Double>> batchVectorization(List<String> chunkList, String kid) {
|
||||
|
||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
|
||||
|
||||
ChatModelVo chatModelVo = chatModelService.selectModelByName(knowledgeInfoVo.getVectorModel());
|
||||
|
||||
OllamaAPI api = new OllamaAPI(chatModelVo.getApiHost());
|
||||
|
||||
List<Double> doubleVector;
|
||||
List<List<Double>> vectorList = new ArrayList<>();
|
||||
try {
|
||||
for (String chunk : chunkList) {
|
||||
doubleVector = api.generateEmbeddings(new OllamaEmbeddingsRequestModel(knowledgeInfoVo.getVectorModel(), chunk));
|
||||
vectorList.add(doubleVector);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new ServiceException("文本向量化异常:"+e.getMessage());
|
||||
}
|
||||
return vectorList;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Double> singleVectorization(String chunk, String kid) {
|
||||
List<String> chunkList = new ArrayList<>();
|
||||
chunkList.add(chunk);
|
||||
List<List<Double>> vectorList = batchVectorization(chunkList, kid);
|
||||
return vectorList.get(0);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package org.ruoyi.chat.service.knowledge;
|
||||
import cn.hutool.core.collection.CollUtil;
|
||||
import cn.hutool.core.util.RandomUtil;
|
||||
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
||||
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
|
||||
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
@@ -14,17 +15,23 @@ import org.ruoyi.common.core.utils.StringUtils;
|
||||
import org.ruoyi.common.satoken.utils.LoginHelper;
|
||||
import org.ruoyi.core.page.PageQuery;
|
||||
import org.ruoyi.core.page.TableDataInfo;
|
||||
import org.ruoyi.domain.ChatModel;
|
||||
import org.ruoyi.domain.KnowledgeAttach;
|
||||
import org.ruoyi.domain.KnowledgeFragment;
|
||||
import org.ruoyi.domain.KnowledgeInfo;
|
||||
import org.ruoyi.domain.bo.KnowledgeInfoBo;
|
||||
import org.ruoyi.domain.bo.KnowledgeInfoUploadBo;
|
||||
import org.ruoyi.domain.bo.StoreEmbeddingBo;
|
||||
import org.ruoyi.domain.vo.ChatModelVo;
|
||||
import org.ruoyi.domain.vo.KnowledgeInfoVo;
|
||||
import org.ruoyi.mapper.KnowledgeAttachMapper;
|
||||
import org.ruoyi.mapper.KnowledgeFragmentMapper;
|
||||
import org.ruoyi.mapper.KnowledgeInfoMapper;
|
||||
import org.ruoyi.service.EmbeddingService;
|
||||
import org.ruoyi.service.IChatModelService;
|
||||
import org.ruoyi.service.VectorStoreService;
|
||||
import org.ruoyi.service.IKnowledgeInfoService;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
@@ -42,9 +49,10 @@ import java.util.*;
|
||||
@Service
|
||||
public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(KnowledgeInfoServiceImpl.class);
|
||||
private final KnowledgeInfoMapper baseMapper;
|
||||
|
||||
private final EmbeddingService embeddingService;
|
||||
private final VectorStoreService vectorStoreService;
|
||||
|
||||
private final ResourceLoaderFactory resourceLoaderFactory;
|
||||
|
||||
@@ -52,6 +60,8 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
|
||||
|
||||
private final KnowledgeAttachMapper attachMapper;
|
||||
|
||||
private final IChatModelService chatModelService;
|
||||
|
||||
/**
|
||||
* 查询知识库
|
||||
*/
|
||||
@@ -150,7 +160,9 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
|
||||
knowledgeInfo.setUid(LoginHelper.getLoginUser().getUserId());
|
||||
}
|
||||
baseMapper.insert(knowledgeInfo);
|
||||
embeddingService.createSchema(String.valueOf(knowledgeInfo.getId()));
|
||||
if (knowledgeInfo != null) {
|
||||
vectorStoreService.createSchema(String.valueOf(knowledgeInfo.getId()),bo.getVector());
|
||||
}
|
||||
}else {
|
||||
baseMapper.updateById(knowledgeInfo);
|
||||
}
|
||||
@@ -165,7 +177,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
|
||||
check(knowledgeInfoList);
|
||||
// 删除向量库信息
|
||||
knowledgeInfoList.forEach(knowledgeInfoVo -> {
|
||||
embeddingService.removeByKid(String.valueOf(knowledgeInfoVo.getId()));
|
||||
vectorStoreService.removeByKid(String.valueOf(knowledgeInfoVo.getId()));
|
||||
});
|
||||
// 删除附件和知识片段
|
||||
fragmentMapper.deleteByMap(map);
|
||||
@@ -197,7 +209,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
|
||||
List<KnowledgeFragment> knowledgeFragmentList = new ArrayList<>();
|
||||
if (CollUtil.isNotEmpty(chunkList)) {
|
||||
for (int i = 0; i < chunkList.size(); i++) {
|
||||
String fid = RandomUtil.randomString(16);
|
||||
String fid = RandomUtil.randomString(10);
|
||||
fids.add(fid);
|
||||
KnowledgeFragment knowledgeFragment = new KnowledgeFragment();
|
||||
knowledgeFragment.setKid(kid);
|
||||
@@ -211,15 +223,36 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
|
||||
}
|
||||
fragmentMapper.insertBatch(knowledgeFragmentList);
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
log.error("保存知识库信息失败!{}", e.getMessage());
|
||||
}
|
||||
knowledgeAttach.setContent(content);
|
||||
knowledgeAttach.setCreateTime(new Date());
|
||||
attachMapper.insert(knowledgeAttach);
|
||||
embeddingService.storeEmbeddings(chunkList,kid,docId,fids);
|
||||
|
||||
// 通过kid查询知识库信息
|
||||
KnowledgeInfoVo knowledgeInfoVo = baseMapper.selectVoOne(Wrappers.<KnowledgeInfo>lambdaQuery()
|
||||
.eq(KnowledgeInfo::getKid, kid));
|
||||
|
||||
// 通过向量模型查询模型信息
|
||||
ChatModelVo chatModelVo = chatModelService.selectModelByName(knowledgeInfoVo.getVectorModel());
|
||||
|
||||
StoreEmbeddingBo storeEmbeddingBo = new StoreEmbeddingBo();
|
||||
storeEmbeddingBo.setKid(kid);
|
||||
storeEmbeddingBo.setDocId(docId);
|
||||
storeEmbeddingBo.setFids(fids);
|
||||
storeEmbeddingBo.setChunkList(chunkList);
|
||||
storeEmbeddingBo.setModelName(knowledgeInfoVo.getVectorModel());
|
||||
storeEmbeddingBo.setApiKey(chatModelVo.getApiKey());
|
||||
storeEmbeddingBo.setBaseUrl(chatModelVo.getApiHost());
|
||||
vectorStoreService.storeEmbeddings(storeEmbeddingBo);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 检查用户是否有删除知识库权限
|
||||
*
|
||||
* @param knowledgeInfoList 知识库列表
|
||||
*/
|
||||
public void check(List<KnowledgeInfoVo> knowledgeInfoList){
|
||||
LoginUser loginUser = LoginHelper.getLoginUser();
|
||||
for (KnowledgeInfoVo knowledgeInfoVo : knowledgeInfoList) {
|
||||
|
||||
@@ -1,107 +0,0 @@
|
||||
package org.ruoyi.chat.service.knowledge;
|
||||
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.Getter;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.chat.config.ChatConfig;
|
||||
import org.ruoyi.common.chat.entity.embeddings.Embedding;
|
||||
import org.ruoyi.common.chat.entity.embeddings.EmbeddingResponse;
|
||||
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
|
||||
import org.ruoyi.domain.vo.ChatModelVo;
|
||||
import org.ruoyi.domain.vo.KnowledgeInfoVo;
|
||||
import org.ruoyi.service.IChatModelService;
|
||||
import org.ruoyi.service.IKnowledgeInfoService;
|
||||
import org.ruoyi.service.VectorizationService;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
public class OpenAiVectorizationImpl implements VectorizationService {
|
||||
|
||||
@Lazy
|
||||
@Resource
|
||||
private IKnowledgeInfoService knowledgeInfoService;
|
||||
|
||||
@Lazy
|
||||
@Resource
|
||||
private IChatModelService chatModelService;
|
||||
|
||||
@Getter
|
||||
private OpenAiStreamClient openAiStreamClient;
|
||||
|
||||
private final ChatConfig chatConfig;
|
||||
|
||||
@Override
|
||||
public List<List<Double>> batchVectorization(List<String> chunkList, String kid) {
|
||||
List<List<Double>> vectorList;
|
||||
// 获取知识库信息
|
||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
|
||||
if(knowledgeInfoVo == null){
|
||||
log.warn("知识库不存在:请查检ID {}",kid);
|
||||
vectorList=new ArrayList<>();
|
||||
vectorList.add(new ArrayList<>());
|
||||
return vectorList;
|
||||
}
|
||||
ChatModelVo chatModelVo = chatModelService.selectModelByName(knowledgeInfoVo.getVectorModel());
|
||||
String apiHost= chatModelVo.getApiHost();
|
||||
String apiKey= chatModelVo.getApiKey();
|
||||
openAiStreamClient = ChatConfig.createOpenAiStreamClient(apiHost,apiKey);
|
||||
Embedding embedding = buildEmbedding(chunkList, knowledgeInfoVo);
|
||||
EmbeddingResponse embeddings = openAiStreamClient.embeddings(embedding);
|
||||
// 处理 OpenAI 返回的嵌入数据
|
||||
vectorList = processOpenAiEmbeddings(embeddings);
|
||||
return vectorList;
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 Embedding 对象
|
||||
*/
|
||||
private Embedding buildEmbedding(List<String> chunkList, KnowledgeInfoVo knowledgeInfoVo) {
|
||||
return Embedding.builder()
|
||||
.input(chunkList)
|
||||
.model(knowledgeInfoVo.getVectorModel())
|
||||
.build();
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理 OpenAI 返回的嵌入数据
|
||||
*/
|
||||
private List<List<Double>> processOpenAiEmbeddings(EmbeddingResponse embeddings) {
|
||||
List<List<Double>> vectorList = new ArrayList<>();
|
||||
|
||||
embeddings.getData().forEach(data -> {
|
||||
List<BigDecimal> vector = data.getEmbedding();
|
||||
List<Double> doubleVector = convertToDoubleList(vector);
|
||||
vectorList.add(doubleVector);
|
||||
});
|
||||
|
||||
return vectorList;
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 BigDecimal 转换为 Double 列表
|
||||
*/
|
||||
private List<Double> convertToDoubleList(List<BigDecimal> vector) {
|
||||
return vector.stream()
|
||||
.map(BigDecimal::doubleValue)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public List<Double> singleVectorization(String chunk, String kid) {
|
||||
List<String> chunkList = new ArrayList<>();
|
||||
chunkList.add(chunk);
|
||||
List<List<Double>> vectorList = batchVectorization(chunkList, kid);
|
||||
return vectorList.get(0);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
package org.ruoyi.chat.service.knowledge;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.chat.factory.VectorizationFactory;
|
||||
import org.ruoyi.service.VectorizationService;
|
||||
import org.springframework.context.annotation.Primary;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
@Primary
|
||||
@AllArgsConstructor
|
||||
public class VectorizationWrapper implements VectorizationService {
|
||||
|
||||
private final VectorizationFactory vectorizationFactory;
|
||||
@Override
|
||||
public List<List<Double>> batchVectorization(List<String> chunkList, String kid) {
|
||||
VectorizationService embedding = vectorizationFactory.getEmbedding(kid);
|
||||
return embedding.batchVectorization(chunkList, kid);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Double> singleVectorization(String chunk, String kid) {
|
||||
VectorizationService embedding = vectorizationFactory.getEmbedding(kid);
|
||||
return embedding.singleVectorization(chunk, kid);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user