feat: 调整知识库模块

This commit is contained in:
ageerle
2025-04-09 17:41:29 +08:00
parent be6d027cad
commit 3be9005f95
424 changed files with 1584 additions and 10005 deletions

View File

@@ -5,11 +5,12 @@ import jakarta.servlet.http.HttpServletRequest;
import jakarta.validation.Valid;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.chat.service.chat.ISseService;
import org.ruoyi.common.chat.domain.request.ChatRequest;
import org.ruoyi.common.chat.domain.request.Dall3Request;
import org.ruoyi.common.chat.entity.Tts.TextToSpeech;
import org.ruoyi.common.chat.entity.files.UploadFileResponse;
import org.ruoyi.common.chat.entity.images.Item;
import org.ruoyi.common.chat.entity.whisper.WhisperResponse;
import org.ruoyi.common.core.domain.R;
import org.ruoyi.common.core.domain.model.LoginUser;
@@ -17,8 +18,10 @@ import org.ruoyi.common.core.exception.base.BaseException;
import org.ruoyi.common.mybatis.core.page.PageQuery;
import org.ruoyi.common.mybatis.core.page.TableDataInfo;
import org.ruoyi.common.satoken.utils.LoginHelper;
import org.ruoyi.system.domain.request.translation.TranslationRequest;
import org.ruoyi.system.service.ISseService;
import org.ruoyi.domain.bo.ChatMessageBo;
import org.ruoyi.domain.vo.ChatMessageVo;
import org.ruoyi.service.IChatMessageService;
import org.springframework.core.io.Resource;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Controller;
@@ -26,7 +29,6 @@ import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.List;
/**
* 描述:聊天管理
@@ -56,7 +58,6 @@ public class ChatController {
return sseService.sseChat(chatRequest,request);
}
/**
* 上传文件
*/
@@ -90,22 +91,6 @@ public class ChatController {
return sseService.textToSpeed(textToSpeech);
}
/**
 * Text translation endpoint ({@code POST /translation}).
 *
 * @param translationRequest translation request read from the request body
 * @return whatever {@code sseService.translation} returns for the request
 */
@PostMapping("/translation")
@ResponseBody
public String translation(@RequestBody TranslationRequest translationRequest) {
    final String translated = sseService.translation(translationRequest);
    return translated;
}
/**
 * Handles {@code POST /dall3}: validates the request body and wraps the items returned by
 * {@code sseService.dall3} in a success response.
 *
 * @param request validated request body
 * @return success wrapper around the returned items
 */
@PostMapping("/dall3")
@ResponseBody
public R<List<Item>> dall3(@RequestBody @Valid Dall3Request request) {
    final List<Item> items = sseService.dall3(request);
    return R.ok(items);
}
/**
* 聊天记录

View File

@@ -7,7 +7,7 @@ import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Request;
import org.apache.commons.lang3.math.NumberUtils;
import org.ruoyi.chat.dto.*;
import org.ruoyi.chat.domain.dto.*;
import org.ruoyi.chat.enums.ActionType;
import org.ruoyi.chat.util.MjOkHttpUtil;
import org.springframework.web.bind.annotation.PostMapping;

View File

@@ -7,7 +7,7 @@ import io.swagger.annotations.ApiParam;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Request;
import org.ruoyi.chat.dto.TaskConditionDTO;
import org.ruoyi.chat.domain.dto.TaskConditionDTO;
import org.ruoyi.chat.util.MjOkHttpUtil;
import org.springframework.web.bind.annotation.*;

View File

@@ -1,4 +1,4 @@
package org.ruoyi.chat.dto;
package org.ruoyi.chat.domain.dto;
import io.swagger.annotations.ApiModelProperty;
import lombok.Getter;

View File

@@ -1,4 +1,4 @@
package org.ruoyi.chat.dto;
package org.ruoyi.chat.domain.dto;
import io.swagger.annotations.ApiModel;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package org.ruoyi.chat.dto;
package org.ruoyi.chat.domain.dto;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;

View File

@@ -1,4 +1,4 @@
package org.ruoyi.chat.dto;
package org.ruoyi.chat.domain.dto;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;

View File

@@ -1,4 +1,4 @@
package org.ruoyi.chat.dto;
package org.ruoyi.chat.domain.dto;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;

View File

@@ -1,4 +1,4 @@
package org.ruoyi.chat.dto;
package org.ruoyi.chat.domain.dto;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;

View File

@@ -1,4 +1,4 @@
package org.ruoyi.chat.dto;
package org.ruoyi.chat.domain.dto;
import io.swagger.annotations.ApiModel;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package org.ruoyi.chat.dto;
package org.ruoyi.chat.domain.dto;
import io.swagger.annotations.ApiModel;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package org.ruoyi.chat.dto;
package org.ruoyi.chat.domain.dto;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;

View File

@@ -1,4 +1,4 @@
package org.ruoyi.chat.dto;
package org.ruoyi.chat.domain.dto;
import io.swagger.annotations.ApiModel;
import lombok.Data;

View File

@@ -0,0 +1,107 @@
package org.ruoyi.chat.listener;
import cn.hutool.core.collection.CollectionUtil;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.ResponseBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import org.jetbrains.annotations.NotNull;
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
import org.ruoyi.common.core.utils.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter;
import java.util.Objects;
/**
* 描述OpenAIEventSourceListener
*
* @author https:www.unfbx.com
* @date 2023-02-22
*/
@Slf4j
@RequiredArgsConstructor
@Component
public class SSEEventSourceListener extends EventSourceListener {
@Autowired(required = false)
public SSEEventSourceListener(ResponseBodyEmitter emitter) {
this.emitter = emitter;
}
private ResponseBodyEmitter emitter;
private StringBuilder stringBuffer;
private String modelName;
/**
* {@inheritDoc}
*/
@Override
public void onOpen(EventSource eventSource, Response response) {
log.info("OpenAI建立sse连接...");
}
/**
* {@inheritDoc}
*/
@SneakyThrows
@Override
public void onEvent(@NotNull EventSource eventSource, String id, String type, String data) {
try {
if ("[DONE]".equals(data)) {
//成功响应
emitter.complete();
// 扣除费用 (消耗字符 模型名称)
return;
}
// 解析返回内容
ObjectMapper mapper = new ObjectMapper();
ChatCompletionResponse completionResponse = mapper.readValue(data, ChatCompletionResponse.class);
if(completionResponse == null || CollectionUtil.isEmpty(completionResponse.getChoices())){
return;
}
Object content = completionResponse.getChoices().get(0).getDelta().getContent();
if(content == null){
content = completionResponse.getChoices().get(0).getDelta().getReasoningContent();
if(content == null) return;
}
if(StringUtils.isEmpty(modelName)){
modelName = completionResponse.getModel();
}
stringBuffer.append(content);
emitter.send(data);
} catch (Exception e) {
log.error("sse信息推送失败{}内容:{}",e.getMessage(),data);
eventSource.cancel();
}
}
@Override
public void onClosed(EventSource eventSource) {
log.info("OpenAI关闭sse连接...");
}
@SneakyThrows
@Override
public void onFailure(EventSource eventSource, Throwable t, Response response) {
if (Objects.isNull(response)) {
return;
}
ResponseBody body = response.body();
if (Objects.nonNull(body)) {
log.error("OpenAI sse连接异常data{},异常:{}", body.string(), t);
} else {
log.error("OpenAI sse连接异常data{},异常:{}", response, t);
}
eventSource.cancel();
}
}

View File

@@ -0,0 +1,43 @@
package org.ruoyi.chat.service.chat;
import org.ruoyi.domain.bo.ChatMessageBo;
/**
* Billing management service interface.
*
* @author ageerle
* @date 2025-04-08
*/
public interface IChatCostService {
/**
* Deducts balance according to the tokens consumed.
*
* @param chatMessageBo chat message to bill (presumably carries the token usage and the
* user to charge; confirm against the implementation)
*/
void deductToken(ChatMessageBo chatMessageBo);
/**
* Deducts an amount from a user's balance.
*
* @param userId id of the user to charge
* @param numberCost amount to deduct
*/
void deductUserBalance(Long userId, Double numberCost);
/**
* Deducts the fee of a task and saves a record of it.
*
* @param type task type
* @param prompt task description
* @param cost fee to deduct
*/
void taskDeduct(String type,String prompt, double cost);
/**
* Checks whether the user is a paying user.
* NOTE(review): the method returns nothing, so a failed check is presumably signalled
* by an exception; confirm against the implementation.
*/
void checkUserGrade();
}

View File

@@ -0,0 +1,65 @@
package org.ruoyi.chat.service.chat;
import jakarta.servlet.http.HttpServletRequest;
import org.ruoyi.common.chat.domain.request.ChatRequest;
import org.ruoyi.common.chat.entity.Tts.TextToSpeech;
import org.ruoyi.common.chat.entity.files.UploadFileResponse;
import org.ruoyi.common.chat.entity.whisper.WhisperResponse;
import org.springframework.core.io.Resource;
import org.springframework.http.ResponseEntity;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
/**
* User chat management service interface.
*
* @author ageerle
* @date 2025-04-08
*/
public interface ISseService {
/**
* Sends a client message to the server and streams the answer back.
*
* @param chatRequest request object
* @param request current HTTP request
* @return emitter the streamed answer is written to
*/
SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request);
/**
* Speech to text.
*
* @param file audio file
* @return transcription result
*/
WhisperResponse speechToTextTranscriptionsV2(MultipartFile file);
/**
* Text to speech.
*
* @param textToSpeech text information
* @return streamed audio
*/
ResponseEntity<Resource> textToSpeed(TextToSpeech textToSpeech);
/**
* Uploads a file to the API server.
*
* @param file file information
* @return information about the uploaded file
*/
UploadFileResponse upload(MultipartFile file);
/**
* Calls a local model through Ollama.
*
* @param chatRequest conversation information
* @return emitter streaming the returned content
*/
SseEmitter ollamaChat(ChatRequest chatRequest);
/**
* Enterprise application reply.
*
* @param prompt prompt text
* @return reply content
*/
String wxCpChat(String prompt);
}

View File

@@ -0,0 +1,292 @@
package org.ruoyi.chat.service.chat.impl;
import cn.dev33.satoken.stp.StpUtil;
import cn.hutool.extra.servlet.ServletUtil;
import com.google.protobuf.ServiceException;
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.models.chat.OllamaChatMessage;
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
import io.github.ollama4j.models.chat.OllamaChatRequestModel;
import io.github.ollama4j.models.generate.OllamaStreamHandler;
import jakarta.servlet.http.HttpServletRequest;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.ruoyi.chat.listener.SSEEventSourceListener;
import org.ruoyi.chat.service.chat.ISseService;
import org.ruoyi.common.chat.config.ChatConfig;
import org.ruoyi.common.chat.domain.request.ChatRequest;
import org.ruoyi.common.chat.entity.Tts.TextToSpeech;
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
import org.ruoyi.common.chat.entity.chat.Message;
import org.ruoyi.common.chat.entity.files.UploadFileResponse;
import org.ruoyi.common.chat.entity.whisper.WhisperResponse;
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
import org.ruoyi.common.core.utils.file.FileUtils;
import org.ruoyi.common.core.utils.file.MimeTypeUtils;
import org.ruoyi.common.redis.utils.RedisUtils;
import org.ruoyi.domain.vo.ChatModelVo;
import org.ruoyi.service.IChatModelService;
import org.springframework.core.io.InputStreamResource;
import org.springframework.core.io.Resource;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
@Service
@Slf4j
@RequiredArgsConstructor
public class SseServiceImpl implements ISseService {
private OpenAiStreamClient openAiStreamClient;
private final ChatConfig chatConfig;
private final IChatModelService chatModelService;
/**
 * Starts a streaming chat completion and returns the emitter the chunks are pushed to.
 * <p>
 * Logged-in users get a client built from the API host/key of the requested model. Anonymous
 * visitors are counted per client IP in Redis and served through the default client of
 * {@link ChatConfig}. Any failure is reported to the caller as an SSE {@code error} event.
 *
 * @param chatRequest chat request (messages, model, sampling parameters)
 * @param request     current HTTP request, used to identify anonymous visitors
 * @return emitter created with a timeout of {@code 0L}
 */
@Override
public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) {
    SseEmitter sseEmitter = new SseEmitter(0L);
    SSEEventSourceListener openAIEventSourceListener = new SSEEventSourceListener(sseEmitter);
    // Conversation message list
    List<Message> messages = chatRequest.getMessages();
    try {
        // Use a local reference for the call: the field is shared by all concurrent requests.
        OpenAiStreamClient client;
        if (StpUtil.isLogin()) {
            // Look the model up by name and build a client for its host/key.
            ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
            if (chatModelVo == null) {
                throw new ServiceException("未找到模型配置: " + chatRequest.getModel());
            }
            client = chatConfig.createOpenAiStreamClient(chatModelVo.getApiHost(), chatModelVo.getApiKey());
            // Still published to the field because sibling methods read it.
            openAiStreamClient = client;
        } else {
            enforceVisitorQuota(request);
            // Visitors must not depend on (or reuse) a client created for another user's request.
            client = chatConfig.getOpenAiStreamClient();
        }
        ChatCompletion completion = ChatCompletion
                .builder()
                .messages(messages)
                .model(chatRequest.getModel())
                .temperature(chatRequest.getTemperature())
                .topP(chatRequest.getTop_p())
                .stream(true)
                .build();
        client.streamChatCompletion(completion, openAIEventSourceListener);
    } catch (Exception e) {
        log.error("sse对话请求失败", e);
        sendErrorEvent(sseEmitter, e.getMessage());
    }
    return sseEmitter;
}

/**
 * Counts one request of an anonymous visitor and rejects it once the free quota is used up.
 * The counter lives in Redis under {@code visitor:<ip>} and is created with a 24 hour expiry.
 * <p>
 * NOTE(review): the increment is written back without an explicit expiry, exactly as before;
 * confirm that {@code RedisUtils.setCacheObject(key, value)} keeps the remaining TTL,
 * otherwise the counter never resets.
 *
 * @param request current HTTP request
 * @throws ServiceException when the visitor has no free requests left
 */
private void enforceVisitorQuota(HttpServletRequest request) throws ServiceException {
    final int maxFreeRequests = 5;
    String redisKey = "visitor:" + resolveClientIp(request);
    // Read once: a second read could observe an expired (null) counter.
    Integer used = RedisUtils.getCacheObject(redisKey);
    if (used == null) {
        RedisUtils.setCacheObject(redisKey, 0, Duration.ofSeconds(86400));
        return;
    }
    if (used >= maxFreeRequests) {
        throw new ServiceException("当日免费次数已用完");
    }
    RedisUtils.setCacheObject(redisKey, used + 1);
}

/**
 * Resolves the caller's IP from common proxy headers, falling back to the remote address.
 * Implemented on the jakarta request directly; the previous cast to
 * {@code javax.servlet.http.HttpServletRequest} could never succeed for a jakarta request.
 * <p>
 * NOTE(review): these headers are client-controlled unless a trusted proxy overwrites them.
 *
 * @param request current HTTP request
 * @return first usable address found
 */
private static String resolveClientIp(HttpServletRequest request) {
    String[] headerNames = {"X-Forwarded-For", "X-Real-IP", "Proxy-Client-IP", "WL-Proxy-Client-IP"};
    for (String headerName : headerNames) {
        String value = request.getHeader(headerName);
        if (value == null || value.isBlank()) {
            continue;
        }
        // A forwarded chain lists the originating client first.
        for (String candidate : value.split(",")) {
            String ip = candidate.trim();
            if (!ip.isEmpty() && !"unknown".equalsIgnoreCase(ip)) {
                return ip;
            }
        }
    }
    return request.getRemoteAddr();
}
/**
 * Sends an SSE event named {@code error} carrying the given message and then completes the
 * emitter.
 *
 * @param sseEmitter   emitter to notify and complete
 * @param errorMessage message to send; {@code null} (for example from an exception without a
 *                     message) is replaced by a generic text because a null payload cannot be
 *                     written
 */
private void sendErrorEvent(SseEmitter sseEmitter, String errorMessage) {
    String payload = errorMessage != null ? errorMessage : "未知错误";
    SseEmitter.SseEventBuilder event = SseEmitter.event()
            .name("error")
            .data(payload);
    try {
        sseEmitter.send(event);
    } catch (IOException | IllegalStateException e) {
        // IllegalStateException: the emitter was already completed by an earlier error.
        log.error("发送事件失败: {}", e.getMessage(), e);
    }
    sseEmitter.complete();
}
/**
 * Text to speech.
 *
 * @param textToSpeech text information
 * @return {@code 200} with the audio stream typed {@code audio/mpeg}, or {@code 404} when the
 *         client returned no body
 */
@Override
public ResponseEntity<Resource> textToSpeed(TextToSpeech textToSpeech) {
    // The field is only populated by sseChat/upload; fall back to the default client
    // instead of failing with a NullPointerException when neither ran yet.
    OpenAiStreamClient client = openAiStreamClient != null
            ? openAiStreamClient
            : chatConfig.getOpenAiStreamClient();
    ResponseBody body = client.textToSpeech(textToSpeech);
    if (body == null) {
        // No body: answer 404.
        return ResponseEntity.notFound().build();
    }
    // Hand the response stream to Spring as a resource.
    InputStreamResource resource = new InputStreamResource(body.byteStream());
    return ResponseEntity.ok()
            .contentType(MediaType.parseMediaType("audio/mpeg"))
            .body(resource);
}
/**
 * Speech to text.
 * <p>
 * The upload is copied to a uniquely named temp file that keeps only the extension of the
 * client-supplied name. The raw name is never used as a path, so it cannot escape the temp
 * directory or collide with a concurrent upload. The temp file is removed afterwards.
 *
 * @param file audio file
 * @return transcription result
 * @throws IllegalStateException when the file is empty or its extension is not allowed
 * @throws RuntimeException      when the upload cannot be written to disk
 */
@Override
public WhisperResponse speechToTextTranscriptionsV2(MultipartFile file) {
    if (file.isEmpty()) {
        throw new IllegalStateException("Cannot convert an empty MultipartFile");
    }
    if (!FileUtils.isValidFileExtention(file, MimeTypeUtils.AUDIO__EXTENSION)) {
        throw new IllegalStateException("File Extention not supported");
    }
    // Derive the extension from the last path segment of the client-supplied name only.
    String suffix = ".tmp";
    String originalName = file.getOriginalFilename();
    if (originalName != null) {
        int separator = Math.max(originalName.lastIndexOf('/'), originalName.lastIndexOf('\\'));
        String baseName = originalName.substring(separator + 1);
        int dot = baseName.lastIndexOf('.');
        if (dot >= 0) {
            suffix = baseName.substring(dot);
        }
    }
    File audioFile;
    try {
        audioFile = Files.createTempFile("speech-", suffix).toFile();
    } catch (IOException e) {
        throw new RuntimeException("Failed to convert MultipartFile to File", e);
    }
    // The field is only populated by sseChat/upload; fall back to the default client.
    OpenAiStreamClient client = openAiStreamClient != null
            ? openAiStreamClient
            : chatConfig.getOpenAiStreamClient();
    try {
        try (InputStream in = file.getInputStream();
             FileOutputStream out = new FileOutputStream(audioFile)) {
            in.transferTo(out);
        }
        return client.speechToTextTranscriptions(audioFile);
    } catch (IOException e) {
        throw new RuntimeException("Failed to convert MultipartFile to File", e);
    } finally {
        if (!audioFile.delete()) {
            log.warn("Failed to delete temp file {}", audioFile.getAbsolutePath());
        }
    }
}
/**
 * Uploads a file to the API server with purpose {@code fine-tune}, using the default client
 * from {@link ChatConfig} (which is also stored in the shared client field).
 *
 * @param file file to upload
 * @return information about the uploaded file
 * @throws IllegalStateException when the file is empty or its extension is not allowed
 */
@Override
public UploadFileResponse upload(MultipartFile file) {
    if (file.isEmpty()) {
        throw new IllegalStateException("Cannot upload an empty MultipartFile");
    }
    boolean extensionAllowed = FileUtils.isValidFileExtention(file, MimeTypeUtils.DEFAULT_ALLOWED_EXTENSION);
    if (!extensionAllowed) {
        throw new IllegalStateException("File Extention not supported");
    }
    openAiStreamClient = chatConfig.getOpenAiStreamClient();
    File localCopy = convertMultiPartToFile(file);
    return openAiStreamClient.uploadFile("fine-tune", localCopy);
}
/**
 * Copies an upload into a temp file that keeps the extension of the original file name
 * ({@code .tmp} when there is none).
 * <p>
 * I/O failures are raised instead of being printed and ignored. Previously a failed copy
 * returned a missing or half-written file to the caller.
 *
 * @param multipartFile upload to copy
 * @return the temp file holding the complete upload
 * @throws IllegalStateException when the temp file cannot be created or written
 */
private File convertMultiPartToFile(MultipartFile multipartFile) {
    // Default extension
    String extension = ".tmp";
    String originalFileName = multipartFile.getOriginalFilename();
    if (originalFileName != null) {
        // Look at the last path segment only, so the suffix can never contain a separator.
        int separator = Math.max(originalFileName.lastIndexOf('/'), originalFileName.lastIndexOf('\\'));
        String baseName = originalFileName.substring(separator + 1);
        int dot = baseName.lastIndexOf('.');
        if (dot >= 0) {
            extension = baseName.substring(dot);
        }
    }
    Path tempFile;
    try {
        tempFile = Files.createTempFile(null, extension);
    } catch (IOException e) {
        throw new IllegalStateException("Failed to create temp file for upload", e);
    }
    File file = tempFile.toFile();
    try (InputStream inputStream = multipartFile.getInputStream();
         FileOutputStream outputStream = new FileOutputStream(file)) {
        inputStream.transferTo(outputStream);
    } catch (IOException e) {
        // Do not leave a partial file behind.
        if (!file.delete()) {
            log.warn("Failed to delete temp file {}", file.getAbsolutePath());
        }
        throw new IllegalStateException("Failed to write upload to temp file", e);
    }
    return file;
}
/**
 * Calls a local model through Ollama and streams the answer.
 * <p>
 * The Ollama model name is the part of the requested model after {@code ollama-}; a name
 * without that prefix is used as is instead of failing with an index error. The API call runs
 * asynchronously and every new piece of the answer is pushed to the returned emitter.
 * <p>
 * NOTE(review): every message is sent with role USER regardless of its original role.
 *
 * @param chatRequest conversation information
 * @return emitter streaming the returned content
 * @throws IllegalArgumentException when no configuration exists for the requested model
 */
@Override
public SseEmitter ollamaChat(ChatRequest chatRequest) {
    String requestedModel = chatRequest.getModel();
    String[] parts = requestedModel.split("ollama-");
    String ollamaModel = parts.length > 1 ? parts[1] : requestedModel;
    ChatModelVo chatModelVo = chatModelService.selectModelByName(requestedModel);
    if (chatModelVo == null) {
        throw new IllegalArgumentException("未找到模型配置: " + requestedModel);
    }
    final SseEmitter emitter = new SseEmitter();
    String host = chatModelVo.getApiHost();
    List<Message> msgList = chatRequest.getMessages();
    List<OllamaChatMessage> messages = new ArrayList<>();
    for (Message message : msgList) {
        OllamaChatMessage ollamaChatMessage = new OllamaChatMessage();
        ollamaChatMessage.setRole(OllamaChatMessageRole.USER);
        ollamaChatMessage.setContent(message.getContent().toString());
        messages.add(ollamaChatMessage);
    }
    OllamaAPI api = new OllamaAPI(host);
    api.setRequestTimeoutSeconds(100);
    OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(ollamaModel);
    OllamaChatRequestModel requestModel = builder
            .withMessages(messages)
            .build();
    // Run the Ollama API call asynchronously
    CompletableFuture.runAsync(() -> {
        try {
            StringBuilder response = new StringBuilder();
            OllamaStreamHandler streamHandler = (s) -> {
                // The handler receives the text so far; forward only the new tail.
                String substr = s.substring(response.length());
                response.append(substr);
                log.debug("ollama chunk: {}", substr);
                try {
                    emitter.send(substr);
                } catch (IOException e) {
                    sendErrorEvent(emitter, e.getMessage());
                }
            };
            api.chat(requestModel, streamHandler);
            emitter.complete();
        } catch (Exception e) {
            log.error("ollama对话失败", e);
            sendErrorEvent(emitter, e.getMessage());
        }
    });
    return emitter;
}
/**
 * Enterprise application reply: sends the prompt as a single user message to
 * {@code gpt-4o-mini} (non-streaming) and returns the content of the first choice.
 *
 * @param prompt prompt text
 * @return reply content
 */
@Override
public String wxCpChat(String prompt) {
    List<Message> messageList = new ArrayList<>();
    Message message = Message.builder().role(Message.Role.USER).content(prompt).build();
    messageList.add(message);
    ChatCompletion chatCompletion = ChatCompletion
            .builder()
            .messages(messageList)
            .model("gpt-4o-mini")
            .stream(false)
            .build();
    // The field is only populated by sseChat/upload; fall back to the default client
    // instead of failing with a NullPointerException when neither ran yet.
    OpenAiStreamClient client = openAiStreamClient != null
            ? openAiStreamClient
            : chatConfig.getOpenAiStreamClient();
    ChatCompletionResponse chatCompletionResponse = client.chatCompletion(chatCompletion);
    return chatCompletionResponse.getChoices().get(0).getMessage().getContent().toString();
}
}

View File

@@ -0,0 +1,64 @@
package org.ruoyi.chat.service.knowledge.vectorizer;
import com.google.gson.Gson;
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
import jakarta.annotation.Resource;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.domain.vo.KnowledgeInfoVo;
import org.ruoyi.service.IKnowledgeInfoService;
import org.ruoyi.service.VectorizationService;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
 * {@link VectorizationService} that obtains embeddings from an Ollama server, using the
 * vector model configured on the knowledge base.
 */
@Component
@Slf4j
@RequiredArgsConstructor
public class BgeLargeVectorization implements VectorizationService {

    /** Ollama endpoint (hard-coded local server). */
    String host = "http://localhost:11434/";

    @Lazy
    @Resource
    private IKnowledgeInfoService knowledgeInfoService;

    /**
     * Embeds every chunk separately and returns one vector per chunk, in input order.
     * <p>
     * The previous implementation sent the JSON of the whole list as a single prompt and
     * returned exactly one vector regardless of how many chunks were passed, so callers pairing
     * chunks with vectors kept only the first row. Vectors stored by that version were computed
     * from different text and should be regenerated.
     *
     * @param chunkList texts to embed
     * @param kid       knowledge base id (numeric string)
     * @return one embedding per chunk
     * @throws IllegalArgumentException when the knowledge base does not exist
     * @throws RuntimeException         when the embedding call fails
     */
    @Override
    public List<List<Double>> batchVectorization(List<String> chunkList, String kid) {
        KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
        if (knowledgeInfoVo == null) {
            throw new IllegalArgumentException("知识库不存在: " + kid);
        }
        String model = knowledgeInfoVo.getVectorModel();
        OllamaAPI ollamaAPI = new OllamaAPI(host);
        List<List<Double>> vectorList = new ArrayList<>(chunkList.size());
        for (String chunk : chunkList) {
            try {
                vectorList.add(ollamaAPI.generateEmbeddings(new OllamaEmbeddingsRequestModel(model, chunk)));
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
        return vectorList;
    }

    /**
     * Embeds a single chunk.
     *
     * @param chunk text to embed
     * @param kid   knowledge base id (numeric string)
     * @return embedding of the chunk
     */
    @Override
    public List<Double> singleVectorization(String chunk, String kid) {
        List<String> chunkList = new ArrayList<>();
        chunkList.add(chunk);
        List<List<Double>> vectorList = batchVectorization(chunkList, kid);
        return vectorList.get(0);
    }

    /**
     * Manual smoke test against a local Ollama server; prints one embedding per sample text.
     */
    public static void main(String[] args) {
        OllamaAPI ollamaAPI = new OllamaAPI("http://localhost:11434/");
        List<String> chunkList = Arrays.asList("天很蓝", "海很深");
        for (String chunk : chunkList) {
            List<Double> doubleVector;
            try {
                doubleVector = ollamaAPI.generateEmbeddings(new OllamaEmbeddingsRequestModel("quentinz/bge-large-zh-v1.5", chunk));
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
            System.out.println("=== " + doubleVector + " 1===");
        }
    }
}

View File

@@ -0,0 +1,110 @@
package org.ruoyi.chat.service.knowledge.vectorizer;
import jakarta.annotation.Resource;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.common.chat.config.ChatConfig;
import org.ruoyi.common.chat.entity.embeddings.Embedding;
import org.ruoyi.common.chat.entity.embeddings.EmbeddingResponse;
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
import org.ruoyi.domain.vo.KnowledgeInfoVo;
import org.ruoyi.service.IKnowledgeInfoService;
import org.ruoyi.service.VectorizationService;
import org.ruoyi.system.domain.SysModel;
import org.ruoyi.system.service.ISysModelService;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Component;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
/**
 * {@link VectorizationService} backed by an OpenAI-compatible embeddings endpoint. The host and
 * key come from the model configuration named by the knowledge base's vector model.
 */
@Component
@Slf4j
@RequiredArgsConstructor
public class OpenAiVectorization implements VectorizationService {

    @Lazy
    @Resource
    private IKnowledgeInfoService knowledgeInfoService;

    @Lazy
    @Resource
    private ISysModelService sysModelService;

    /** Client created by the most recent {@link #batchVectorization} call. */
    @Getter
    private OpenAiStreamClient openAiStreamClient;

    private final ChatConfig chatConfig;

    /**
     * Embeds all chunks in one request.
     *
     * @param chunkList texts to embed
     * @param kid       knowledge base id (numeric string)
     * @return one vector per returned embedding; a list holding a single empty vector when the
     *         knowledge base does not exist
     * @throws IllegalStateException when the knowledge base's vector model has no configuration
     */
    @Override
    public List<List<Double>> batchVectorization(List<String> chunkList, String kid) {
        // Load the knowledge base
        KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
        if (knowledgeInfoVo == null) {
            log.warn("知识库不存在:请查检ID {}", kid);
            List<List<Double>> placeholder = new ArrayList<>();
            placeholder.add(new ArrayList<>());
            return placeholder;
        }
        SysModel sysModel = sysModelService.selectModelByName(knowledgeInfoVo.getVectorModel());
        if (sysModel == null) {
            // Fail with a clear message instead of a NullPointerException.
            throw new IllegalStateException("未找到向量模型配置: " + knowledgeInfoVo.getVectorModel());
        }
        // Use a local reference for the request: the field is shared between concurrent calls
        // and may be replaced by another thread before the request is sent.
        OpenAiStreamClient client = chatConfig.createOpenAiStreamClient(sysModel.getApiHost(), sysModel.getApiKey());
        openAiStreamClient = client;
        Embedding embedding = buildEmbedding(chunkList, knowledgeInfoVo);
        EmbeddingResponse embeddings = client.embeddings(embedding);
        // Convert the returned embedding data
        return processOpenAiEmbeddings(embeddings);
    }

    /**
     * Builds the embedding request for the chunks with the knowledge base's vector model.
     */
    private Embedding buildEmbedding(List<String> chunkList, KnowledgeInfoVo knowledgeInfoVo) {
        return Embedding.builder()
                .input(chunkList)
                .model(knowledgeInfoVo.getVectorModel())
                .build();
    }

    /**
     * Converts the returned embedding data into lists of doubles, keeping the response order.
     */
    private List<List<Double>> processOpenAiEmbeddings(EmbeddingResponse embeddings) {
        return embeddings.getData().stream()
                .map(item -> convertToDoubleList(item.getEmbedding()))
                .collect(Collectors.toList());
    }

    /**
     * Converts a list of BigDecimal to a list of Double.
     */
    private List<Double> convertToDoubleList(List<BigDecimal> vector) {
        return vector.stream()
                .map(BigDecimal::doubleValue)
                .collect(Collectors.toList());
    }

    /**
     * Embeds a single chunk.
     *
     * @param chunk text to embed
     * @param kid   knowledge base id (numeric string)
     * @return embedding of the chunk
     */
    @Override
    public List<Double> singleVectorization(String chunk, String kid) {
        List<String> chunkList = new ArrayList<>();
        chunkList.add(chunk);
        List<List<Double>> vectorList = batchVectorization(chunkList, kid);
        return vectorList.get(0);
    }
}

View File

@@ -0,0 +1,47 @@
package org.ruoyi.chat.service.knowledge.vectorizer;
import cn.hutool.core.util.StrUtil;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.domain.vo.KnowledgeInfoVo;
import org.ruoyi.service.IKnowledgeInfoService;
import org.ruoyi.service.VectorizationService;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Component;
/**
 * Selects the text vectorization implementation for a knowledge base.
 *
 * @author huangkh
 */
@Component
@Slf4j
public class VectorizationFactory {

    /** Model assumed when no knowledge base, or no vector model on it, is available. */
    private static final String DEFAULT_VECTOR_MODEL = "text-embedding-3-small";

    private final OpenAiVectorization openAiVectorization;

    private final BgeLargeVectorization bgeLargeVectorization;

    @Lazy
    @Resource
    private IKnowledgeInfoService knowledgeInfoService;

    public VectorizationFactory(OpenAiVectorization openAiVectorization, BgeLargeVectorization bgeLargeVectorization) {
        this.openAiVectorization = openAiVectorization;
        this.bgeLargeVectorization = bgeLargeVectorization;
    }

    /**
     * Returns the vectorizer matching the knowledge base's vector model.
     *
     * @param kid knowledge base id (numeric string); may be empty
     * @return the BGE vectorizer for {@code quentinz/bge-large-zh-v1.5}, otherwise the OpenAI one
     */
    public VectorizationService getEmbedding(String kid) {
        String vectorModel = resolveVectorModel(kid);
        if ("quentinz/bge-large-zh-v1.5".equals(vectorModel)) {
            return bgeLargeVectorization;
        }
        return openAiVectorization;
    }

    /**
     * Looks up the vector model configured for the knowledge base, or the default model.
     */
    private String resolveVectorModel(String kid) {
        if (!StrUtil.isNotEmpty(kid)) {
            return DEFAULT_VECTOR_MODEL;
        }
        KnowledgeInfoVo info = knowledgeInfoService.queryById(Long.valueOf(kid));
        boolean configured = info != null && StrUtil.isNotEmpty(info.getVectorModel());
        return configured ? info.getVectorModel() : DEFAULT_VECTOR_MODEL;
    }
}

View File

@@ -0,0 +1,397 @@
package org.ruoyi.chat.service.knowledge.vectorstore;
import io.milvus.client.MilvusServiceClient;
import io.milvus.grpc.DataType;
import io.milvus.grpc.DescribeIndexResponse;
import io.milvus.grpc.MutationResult;
import io.milvus.grpc.SearchResults;
import io.milvus.param.*;
import io.milvus.param.collection.*;
import io.milvus.param.dml.DeleteParam;
import io.milvus.param.dml.InsertParam;
import io.milvus.param.dml.SearchParam;
import io.milvus.param.index.CreateIndexParam;
import io.milvus.param.index.DescribeIndexParam;
import io.milvus.param.partition.CreatePartitionParam;
import io.milvus.response.QueryResultsWrapper;
import io.milvus.response.SearchResultsWrapper;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.ruoyi.common.core.service.ConfigService;
import org.ruoyi.service.VectorStoreService;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
@Service
@Slf4j
public class MilvusVectorStore implements VectorStoreService {
private volatile Integer dimension;
private volatile String collectionName;
private MilvusServiceClient milvusServiceClient;
@Resource
private ConfigService configService;
/**
 * Reads the vector dimension and the collection name prefix from the {@code milvus}
 * configuration group.
 */
@PostConstruct
public void loadConfig() {
    String configuredDimension = configService.getConfigValue("milvus", "dimension");
    this.dimension = Integer.parseInt(configuredDimension);
    String configuredCollection = configService.getConfigValue("milvus", "collection");
    this.collectionName = configuredCollection;
}
/**
 * Creates the Milvus client from the configured host and port, targeting the
 * {@code default} database.
 */
@PostConstruct
public void init() {
    final String host = configService.getConfigValue("milvus", "host");
    final String port = configService.getConfigValue("milvus", "port");
    ConnectParam connectParam = ConnectParam.newBuilder()
            .withHost(host)
            .withPort(Integer.parseInt(port))
            .withDatabaseName("default")
            .build();
    milvusServiceClient = new MilvusServiceClient(connectParam);
}
/**
 * Creates the collection {@code collectionName + kid} with the fields row_id (auto-generated
 * primary key), content, kid, docId, fid and the float vector fv, then requests an IVF_FLAT
 * index on fv.
 * <p>
 * The Milvus client reports failures through the returned status instead of throwing; both
 * results are now checked and logged rather than ignored.
 *
 * @param kid knowledge base id appended to the configured collection name
 */
private void createSchema(String kid) {
    final String fullCollectionName = collectionName + kid;
    FieldType primaryField = FieldType.newBuilder()
            .withName("row_id")
            .withDataType(DataType.Int64)
            .withPrimaryKey(true)
            .withAutoID(true)
            .build();
    FieldType contentField = FieldType.newBuilder()
            .withName("content")
            .withDataType(DataType.VarChar)
            .withMaxLength(1000)
            .build();
    FieldType kidField = FieldType.newBuilder()
            .withName("kid")
            .withDataType(DataType.VarChar)
            .withMaxLength(20)
            .build();
    FieldType docIdField = FieldType.newBuilder()
            .withName("docId")
            .withDataType(DataType.VarChar)
            .withMaxLength(20)
            .build();
    FieldType fidField = FieldType.newBuilder()
            .withName("fid")
            .withDataType(DataType.VarChar)
            .withMaxLength(20)
            .build();
    FieldType vectorField = FieldType.newBuilder()
            .withName("fv")
            .withDataType(DataType.FloatVector)
            .withDimension(dimension)
            .build();
    CreateCollectionParam createCollectionReq = CreateCollectionParam.newBuilder()
            .withCollectionName(fullCollectionName)
            .withDescription("local knowledge")
            .addFieldType(primaryField)
            .addFieldType(contentField)
            .addFieldType(kidField)
            .addFieldType(docIdField)
            .addFieldType(fidField)
            .addFieldType(vectorField)
            .build();
    R<RpcStatus> createResponse = milvusServiceClient.createCollection(createCollectionReq);
    if (createResponse.getStatus() != R.Status.Success.getCode()) {
        log.error("集合创建失败: {}", createResponse.getMessage());
    }
    // Create the index on the vector field
    IndexType indexType = IndexType.IVF_FLAT;
    String indexParam = "{\"nlist\":1024}";
    R<RpcStatus> indexResponse = milvusServiceClient.createIndex(
            CreateIndexParam.newBuilder()
                    .withCollectionName(fullCollectionName)
                    .withFieldName("fv")
                    .withIndexType(indexType)
                    .withMetricType(MetricType.IP)
                    .withExtraParam(indexParam)
                    .withSyncMode(Boolean.FALSE)
                    .build()
    );
    if (indexResponse.getStatus() != R.Status.Success.getCode()) {
        log.error("索引创建失败: {}", indexResponse.getMessage());
    }
}
/**
 * Creates the collection and vector index for a knowledge base.
 *
 * @param kid knowledge base id appended to the configured collection name
 */
@Override
public void newSchema(String kid) {
    this.createSchema(kid);
}
/**
 * Deletes the rows of a knowledge base collection whose {@code fid} equals the given value.
 * <p>
 * {@code fid} is a VarChar field, so the value is passed as a quoted string literal with
 * backslashes and double quotes escaped. The unquoted form was not a valid string comparison
 * and let the value alter the expression.
 *
 * @param kid knowledge base id appended to the configured collection name
 * @param fid fragment id to delete
 */
@Override
public void removeByKidAndFid(String kid, String fid) {
    String escapedFid = String.valueOf(fid)
            .replace("\\", "\\\\")
            .replace("\"", "\\\"");
    milvusServiceClient.delete(
            DeleteParam.newBuilder()
                    .withCollectionName(collectionName + kid)
                    .withExpr("fid == \"" + escapedFid + "\"")
                    .build()
    );
}
/**
 * Stores text chunks and their vectors in the collection {@code collectionName + kid},
 * creating the collection and its index on first use, and then loads the collection.
 * <p>
 * All inserted columns are cut to the same number of rows (the shortest of the three input
 * lists). Previously only the vectors were cut, so lists of different sizes always produced an
 * inconsistent insert.
 *
 * @param chunkList  text chunks
 * @param vectorList embeddings, one per chunk
 * @param kid        knowledge base id appended to the configured collection name
 * @param docId      document id, used as partition name when not blank
 * @param fidList    fragment ids, one per chunk
 */
@Override
public void storeEmbeddings(List<String> chunkList, List<List<Double>> vectorList, String kid, String docId, List<String> fidList) {
    String fullCollectionName = collectionName + kid;
    int rowCount = Math.min(chunkList.size(), Math.min(vectorList.size(), fidList.size()));
    if (rowCount == 0) {
        // Nothing to insert; also avoids reading the dimension from an empty vector list.
        log.warn("没有可写入的向量数据, 集合: {}", fullCollectionName);
        return;
    }
    if (rowCount < chunkList.size() || rowCount < vectorList.size() || rowCount < fidList.size()) {
        log.warn("分片、向量与fid数量不一致, 仅写入前 {} 行, 集合: {}", rowCount, fullCollectionName);
    }
    if (!ensureCollectionReady(fullCollectionName, vectorList.get(0).size())) {
        return;
    }

    // A blank or null document id means the default partition.
    boolean hasPartition = StringUtils.isNotBlank(docId);
    if (hasPartition) {
        milvusServiceClient.createPartition(
                CreatePartitionParam.newBuilder()
                        .withCollectionName(fullCollectionName)
                        .withPartitionName(docId)
                        .build()
        );
    }

    List<List<Float>> vectorFloatList = new ArrayList<>(rowCount);
    List<String> kidList = new ArrayList<>(rowCount);
    List<String> docIdList = new ArrayList<>(rowCount);
    for (int i = 0; i < rowCount; i++) {
        List<Double> vector = vectorList.get(i);
        List<Float> floats = new ArrayList<>(vector.size());
        for (Double value : vector) {
            floats.add(value.floatValue());
        }
        vectorFloatList.add(floats);
        kidList.add(kid);
        docIdList.add(docId);
    }

    List<InsertParam.Field> fields = new ArrayList<>();
    fields.add(new InsertParam.Field("content", new ArrayList<>(chunkList.subList(0, rowCount))));
    fields.add(new InsertParam.Field("kid", kidList));
    fields.add(new InsertParam.Field("docId", docIdList));
    fields.add(new InsertParam.Field("fid", new ArrayList<>(fidList.subList(0, rowCount))));
    fields.add(new InsertParam.Field("fv", vectorFloatList));
    InsertParam insertParam = InsertParam.newBuilder()
            .withCollectionName(fullCollectionName)
            .withPartitionName(hasPartition ? docId : "")
            .withFields(fields)
            .build();
    R<MutationResult> insert = milvusServiceClient.insert(insertParam);
    if (insert.getStatus() == R.Status.Success.getCode()) {
        log.info("插入成功,插入的行数: {}", insert.getData().getInsertCnt());
    } else {
        log.error("插入失败: {}", insert.getMessage());
    }

    // Milvus can only search vectors after the collection has been loaded into memory.
    LoadCollectionParam loadCollectionParam = LoadCollectionParam.newBuilder()
            .withCollectionName(fullCollectionName)
            .build();
    R<RpcStatus> loadResponse = milvusServiceClient.loadCollection(loadCollectionParam);
    if (loadResponse.getStatus() != R.Status.Success.getCode()) {
        log.error("加载集合 {} 到内存时出错:{}", fullCollectionName, loadResponse.getMessage());
    }
}

/**
 * Makes sure the collection exists, creating it together with its vector index when missing.
 *
 * @param fullCollectionName collection to check
 * @param vectorDimension    dimension used for the vector field of a new collection
 * @return {@code false} when the existence check, the creation or the index creation failed
 */
private boolean ensureCollectionReady(String fullCollectionName, int vectorDimension) {
    R<Boolean> existsResponse = milvusServiceClient.hasCollection(
            HasCollectionParam.newBuilder()
                    .withCollectionName(fullCollectionName)
                    .build());
    if (existsResponse.getStatus() != R.Status.Success.getCode()) {
        log.error("检查集合是否存在时出错: {}", existsResponse.getMessage());
        return false;
    }
    if (existsResponse.getData().booleanValue()) {
        return true;
    }
    return createCollectionWithIndex(fullCollectionName, vectorDimension);
}

/**
 * Creates the collection (id, fv, content, kid, docId, fid) and an IVF_FLAT index on fv.
 * <p>
 * NOTE(review): this schema names its primary key "id" while {@code createSchema} in this
 * class uses "row_id" with different field lengths; confirm which one is intended.
 *
 * @param fullCollectionName collection to create
 * @param vectorDimension    dimension of the vector field
 * @return {@code true} when both the collection and the index were created
 */
private boolean createCollectionWithIndex(String fullCollectionName, int vectorDimension) {
    // Milvus caps VarChar max_length at 65535. The former size (number of chunks * 1024)
    // was unrelated to chunk length and exceeded that cap from 64 chunks on.
    final int contentMaxLength = 65535;
    List<FieldType> fieldTypes = new ArrayList<>();
    fieldTypes.add(FieldType.newBuilder()
            .withName("id")
            .withDataType(DataType.Int64)
            .withPrimaryKey(true)
            .withAutoID(true)
            .build());
    fieldTypes.add(FieldType.newBuilder()
            .withName("fv")
            .withDataType(DataType.FloatVector)
            .withDimension(vectorDimension)
            .build());
    fieldTypes.add(FieldType.newBuilder()
            .withName("content")
            .withDataType(DataType.VarChar)
            .withMaxLength(contentMaxLength)
            .build());
    fieldTypes.add(FieldType.newBuilder()
            .withName("kid")
            .withDataType(DataType.VarChar)
            .withMaxLength(256)
            .build());
    fieldTypes.add(FieldType.newBuilder()
            .withName("docId")
            .withDataType(DataType.VarChar)
            .withMaxLength(256)
            .build());
    fieldTypes.add(FieldType.newBuilder()
            .withName("fid")
            .withDataType(DataType.VarChar)
            .withMaxLength(256)
            .build());
    CreateCollectionParam createCollectionParam = CreateCollectionParam.newBuilder()
            .withCollectionName(fullCollectionName)
            .withFieldTypes(fieldTypes)
            .build();
    R<RpcStatus> collection = milvusServiceClient.createCollection(createCollectionParam);
    if (collection.getStatus() != R.Status.Success.getCode()) {
        log.error("集合创建失败: {}", collection.getMessage());
        return false;
    }
    log.info("集合 {} 创建成功", fullCollectionName);
    CreateIndexParam createIndexParam = CreateIndexParam.newBuilder()
            .withCollectionName(fullCollectionName)
            .withFieldName("fv")
            .withIndexType(IndexType.IVF_FLAT)
            .withMetricType(MetricType.IP)
            .withExtraParam("{\"nlist\":1024}")
            .build();
    R<RpcStatus> indexResponse = milvusServiceClient.createIndex(createIndexParam);
    if (indexResponse.getStatus() != R.Status.Success.getCode()) {
        log.error("索引创建失败: {}", indexResponse.getMessage());
        return false;
    }
    log.info("索引创建成功");
    return true;
}
/**
 * Issues a delete against the partition named {@code docId} of the
 * collection {@code collectionName + kid}.
 *
 * <p>The expression {@code 1 == 1} is used as the delete filter, so the
 * request is scoped by the partition name only. The response returned by the
 * Milvus client is not inspected.
 *
 * @param kid   knowledge base id, appended to the collection name prefix
 * @param docId document id, used as the partition name
 */
@Override
public void removeByDocId(String kid, String docId) {
    DeleteParam deleteParam = DeleteParam.newBuilder()
            .withCollectionName(collectionName + kid)
            .withPartitionName(docId)
            .withExpr("1 == 1")
            .build();
    milvusServiceClient.delete(deleteParam);
}
/**
 * Drops the whole collection {@code collectionName + kid}.
 *
 * <p>The response returned by the Milvus client is not inspected.
 *
 * @param kid knowledge base id, appended to the collection name prefix
 */
@Override
public void removeByKid(String kid) {
    DropCollectionParam dropParam = DropCollectionParam.newBuilder()
            .withCollectionName(collectionName + kid)
            .build();
    milvusServiceClient.dropCollection(dropParam);
}
/**
 * Runs a vector similarity search against the collection
 * {@code collectionName + kid} and returns the {@code content} field of every
 * hit.
 *
 * <p>The search uses the inner-product metric on the {@code fv} field with a
 * top-K of 10. An empty list is returned when the collection does not exist,
 * when the existence check or the search fails, or when the search response
 * carries no data.
 *
 * @param queryVector query embedding; each component is narrowed to float
 * @param kid         knowledge base id, appended to the collection name prefix
 * @return the matched contents, possibly empty, never {@code null}
 */
@Override
public List<String> nearest(List<Double> queryVector, String kid) {
    String targetCollection = collectionName + kid;

    // Guard: nothing to search when the collection is missing or the check itself failed.
    R<Boolean> existsResponse = milvusServiceClient.hasCollection(
            HasCollectionParam.newBuilder()
                    .withCollectionName(targetCollection)
                    .build());
    boolean collectionUsable = existsResponse.getStatus() == R.Status.Success.getCode()
            && existsResponse.getData().booleanValue();
    if (!collectionUsable) {
        System.err.println("集合 " + targetCollection + " 不存在或检查集合存在性时出错。");
        return new ArrayList<>();
    }

    // Diagnostic only: the outcome of describeIndex does not influence the search.
    R<DescribeIndexResponse> indexResponse = milvusServiceClient.describeIndex(
            DescribeIndexParam.newBuilder()
                    .withCollectionName(targetCollection)
                    .build());
    if (indexResponse.getStatus() != R.Status.Success.getCode()) {
        System.err.println("获取索引失败: " + indexResponse.getMessage());
    } else {
        System.out.println("索引信息: " + indexResponse.getData().getIndexDescriptionsCount());
    }

    // The collection is not loaded here; storeEmbeddings issues loadCollection after inserting.

    List<Float> queryAsFloats = new ArrayList<>();
    for (Double component : queryVector) {
        queryAsFloats.add(component.floatValue());
    }
    List<List<Float>> searchVectors = new ArrayList<>();
    searchVectors.add(queryAsFloats);

    SearchParam searchParam = SearchParam.newBuilder()
            .withCollectionName(targetCollection)
            .withVectorFieldName("fv")
            .withVectors(searchVectors)
            .withMetricType(MetricType.IP)
            .withTopK(10)
            .withOutFields(Arrays.asList("content", "fv"))
            .withParams("{\"nprobe\":10, \"offset\":0}")
            .build();
    System.out.println("SearchParam: " + searchParam.toString());

    R<SearchResults> searchResponse = milvusServiceClient.search(searchParam);
    if (searchResponse.getStatus() != R.Status.Success.getCode()) {
        System.err.println("搜索操作失败: " + searchResponse.getMessage());
        return new ArrayList<>();
    }
    SearchResults searchResults = searchResponse.getData();
    if (searchResults == null) {
        System.err.println("搜索结果为空");
        return new ArrayList<>();
    }

    System.out.println(searchResults.getResults());
    List<QueryResultsWrapper.RowRecord> rows =
            new SearchResultsWrapper(searchResults.getResults()).getRowRecords();
    List<String> contents = new ArrayList<>();
    if (rows == null || rows.isEmpty()) {
        return contents;
    }
    for (QueryResultsWrapper.RowRecord row : rows) {
        contents.add(row.get("content").toString());
    }
    return contents;
}
/**
 * Text-based similarity search; not supported by the Milvus store.
 *
 * <p>This store only searches by vector (see the {@code List<Double>}
 * overload), so this method performs no lookup. It returns an empty list
 * rather than {@code null}, matching the failure behaviour of the vector
 * overload, so callers can iterate the result without a null check.
 *
 * @param query the query text (ignored)
 * @param kid   the knowledge base id (ignored)
 * @return always an empty, mutable list
 */
@Override
public List<String> nearest(String query, String kid) {
    return new ArrayList<>();
}
}

View File

@@ -0,0 +1,42 @@
package org.ruoyi.chat.service.knowledge.vectorstore;
import cn.hutool.core.util.StrUtil;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.ruoyi.domain.vo.KnowledgeInfoVo;
import org.ruoyi.mapper.KnowledgeInfoMapper;
import org.ruoyi.service.VectorStoreService;
import org.springframework.stereotype.Component;
/**
 * Picks the {@link VectorStoreService} implementation that backs a knowledge
 * base.
 *
 * <p>The backend name is read from the {@code vector} attribute of the
 * knowledge base record; {@code "weaviate"} is used when no id is given, the
 * record is missing, or the attribute is empty.
 */
@Component
@Slf4j
public class VectorStoreFactory {

    private final WeaviateVectorStore weaviateVectorStore;

    private final MilvusVectorStore milvusVectorStore;

    @Resource
    private KnowledgeInfoMapper knowledgeInfoMapper;

    public VectorStoreFactory(WeaviateVectorStore weaviateVectorStore, MilvusVectorStore milvusVectorStore) {
        this.weaviateVectorStore = weaviateVectorStore;
        this.milvusVectorStore = milvusVectorStore;
    }

    /**
     * Returns the vector store configured for the given knowledge base.
     *
     * @param kid knowledge base id as a decimal string; may be empty or {@code null}
     * @return the Weaviate or Milvus store, or {@code null} when the configured
     *         backend name is neither {@code "weaviate"} nor {@code "milvus"}
     * @throws NumberFormatException if {@code kid} is non-empty but not a valid long
     */
    public VectorStoreService getVectorStore(String kid) {
        switch (resolveVectorModel(kid)) {
            case "weaviate":
                return weaviateVectorStore;
            case "milvus":
                return milvusVectorStore;
            default:
                return null;
        }
    }

    /** Looks up the backend name for {@code kid}, defaulting to {@code "weaviate"}. */
    private String resolveVectorModel(String kid) {
        if (StrUtil.isEmpty(kid)) {
            return "weaviate";
        }
        KnowledgeInfoVo knowledgeInfo = knowledgeInfoMapper.selectVoById(Long.valueOf(kid));
        if (knowledgeInfo == null || StrUtil.isEmpty(knowledgeInfo.getVector())) {
            return "weaviate";
        }
        return knowledgeInfo.getVector();
    }
}

View File

@@ -0,0 +1,402 @@
package org.ruoyi.chat.service.knowledge.vectorstore;
import cn.hutool.core.lang.UUID;
import cn.hutool.json.JSONObject;
import com.google.gson.internal.LinkedTreeMap;
import io.weaviate.client.Config;
import io.weaviate.client.WeaviateClient;
import io.weaviate.client.base.Result;
import io.weaviate.client.v1.data.model.WeaviateObject;
import io.weaviate.client.v1.data.replication.model.ConsistencyLevel;
import io.weaviate.client.v1.filters.Operator;
import io.weaviate.client.v1.filters.WhereFilter;
import io.weaviate.client.v1.graphql.model.GraphQLResponse;
import io.weaviate.client.v1.graphql.query.argument.NearTextArgument;
import io.weaviate.client.v1.graphql.query.argument.NearVectorArgument;
import io.weaviate.client.v1.graphql.query.fields.Field;
import io.weaviate.client.v1.misc.model.Meta;
import io.weaviate.client.v1.misc.model.ReplicationConfig;
import io.weaviate.client.v1.misc.model.ShardingConfig;
import io.weaviate.client.v1.misc.model.VectorIndexConfig;
import io.weaviate.client.v1.schema.model.DataType;
import io.weaviate.client.v1.schema.model.Property;
import io.weaviate.client.v1.schema.model.Schema;
import io.weaviate.client.v1.schema.model.WeaviateClass;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.ruoyi.common.core.service.ConfigService;
import org.ruoyi.domain.vo.KnowledgeInfoVo;
import org.ruoyi.service.IKnowledgeInfoService;
import org.ruoyi.service.VectorStoreService;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@Service
@Slf4j
public class WeaviateVectorStore implements VectorStoreService {
private volatile String protocol;
private volatile String host;
private volatile String className;
@Lazy
@Resource
private IKnowledgeInfoService knowledgeInfoService;
@Lazy
@Resource
private ConfigService configService;
@PostConstruct
public void loadConfig() {
this.protocol = configService.getConfigValue("weaviate", "protocol");
this.host = configService.getConfigValue("weaviate", "host");
this.className = configService.getConfigValue("weaviate", "classname");
}
public WeaviateClient getClient() {
Config config = new Config(protocol, host);
WeaviateClient client = new WeaviateClient(config);
return client;
}
public Result<Meta> getMeta() {
WeaviateClient client = getClient();
Result<Meta> meta = client.misc().metaGetter().run();
if (meta.getError() == null) {
System.out.printf("meta.hostname: %s\n", meta.getResult().getHostname());
System.out.printf("meta.version: %s\n", meta.getResult().getVersion());
System.out.printf("meta.modules: %s\n", meta.getResult().getModules());
} else {
System.out.printf("Error: %s\n", meta.getError().getMessages());
}
return meta;
}
public Result<Schema> getSchemas() {
WeaviateClient client = getClient();
Result<Schema> result = client.schema().getter().run();
if (result.hasErrors()) {
System.out.println(result.getError());
} else {
System.out.println(result.getResult());
}
return result;
}
public Result<Boolean> createSchema(String kid) {
WeaviateClient client = getClient();
VectorIndexConfig vectorIndexConfig = VectorIndexConfig.builder()
.distance("cosine")
.cleanupIntervalSeconds(300)
.efConstruction(128)
.maxConnections(64)
.vectorCacheMaxObjects(500000L)
.ef(-1)
.skip(false)
.dynamicEfFactor(8)
.dynamicEfMax(500)
.dynamicEfMin(100)
.flatSearchCutoff(40000)
.build();
ShardingConfig shardingConfig = ShardingConfig.builder()
.desiredCount(3)
.desiredVirtualCount(128)
.function("murmur3")
.key("_id")
.strategy("hash")
.virtualPerPhysical(128)
.build();
ReplicationConfig replicationConfig = ReplicationConfig.builder()
.factor(1)
.build();
JSONObject classModuleConfigValue = new JSONObject();
classModuleConfigValue.put("vectorizeClassName", false);
JSONObject classModuleConfig = new JSONObject();
classModuleConfig.put("text2vec-transformers", classModuleConfigValue);
JSONObject propertyModuleConfigValueSkipTrue = new JSONObject();
propertyModuleConfigValueSkipTrue.put("vectorizePropertyName", false);
propertyModuleConfigValueSkipTrue.put("skip", true);
JSONObject propertyModuleConfigSkipTrue = new JSONObject();
propertyModuleConfigSkipTrue.put("text2vec-transformers", propertyModuleConfigValueSkipTrue);
JSONObject propertyModuleConfigValueSkipFalse = new JSONObject();
propertyModuleConfigValueSkipFalse.put("vectorizePropertyName", false);
propertyModuleConfigValueSkipFalse.put("skip", false);
JSONObject propertyModuleConfigSkipFalse = new JSONObject();
propertyModuleConfigSkipFalse.put("text2vec-transformers", propertyModuleConfigValueSkipFalse);
WeaviateClass clazz = WeaviateClass.builder()
.className(className + kid)
.description("local knowledge")
.vectorIndexType("hnsw")
.vectorizer("text2vec-transformers")
.shardingConfig(shardingConfig)
.vectorIndexConfig(vectorIndexConfig)
.replicationConfig(replicationConfig)
.moduleConfig(classModuleConfig)
.properties(new ArrayList() {
{
add(Property.builder()
.dataType(new ArrayList() {
{
add(DataType.TEXT);
}
})
.name("content")
.description("The content of the local knowledge,for search")
.moduleConfig(propertyModuleConfigSkipFalse)
.build());
add(Property.builder()
.dataType(new ArrayList() {
{
add(DataType.TEXT);
}
})
.name("kid")
.description("The knowledge id of the local knowledge,for search")
.moduleConfig(propertyModuleConfigSkipTrue)
.build());
add(Property.builder()
.dataType(new ArrayList() {
{
add(DataType.TEXT);
}
})
.name("docId")
.description("The doc id of the local knowledge,for search")
.moduleConfig(propertyModuleConfigSkipTrue)
.build());
add(Property.builder()
.dataType(new ArrayList() {
{
add(DataType.TEXT);
}
})
.name("fid")
.description("The fragment id of the local knowledge,for search")
.moduleConfig(propertyModuleConfigSkipTrue)
.build());
add(Property.builder()
.dataType(new ArrayList() {
{
add(DataType.TEXT);
}
})
.name("uuid")
.description("The uuid id of the local knowledge fragment(same with id properties),for search")
.moduleConfig(propertyModuleConfigSkipTrue)
.build());
} })
.build();
Result<Boolean> result = client.schema().classCreator().withClass(clazz).run();
if (result.hasErrors()) {
System.out.println(result.getError());
}
System.out.println(result.getResult());
return result;
}
@Override
public void newSchema(String kid) {
createSchema(kid);
}
@Override
public void removeByKidAndFid(String kid, String fid) {
List<String> resultList = new ArrayList<>();
WeaviateClient client = getClient();
Field fieldId = Field.builder().name("uuid").build();
WhereFilter where = WhereFilter.builder()
.path(new String[]{"fid"})
.operator(Operator.Equal)
.valueString(fid)
.build();
Result<GraphQLResponse> result = client.graphQL().get()
.withClassName(className + kid)
.withFields(fieldId)
.withWhere(where)
.run();
LinkedTreeMap<String, Object> t = (LinkedTreeMap<String, Object>) result.getResult().getData();
LinkedTreeMap<String, ArrayList<LinkedTreeMap>> l = (LinkedTreeMap<String, ArrayList<LinkedTreeMap>>) t.get("Get");
ArrayList<LinkedTreeMap> m = l.get(className + kid);
for (LinkedTreeMap linkedTreeMap : m) {
String uuid = linkedTreeMap.get("uuid").toString();
resultList.add(uuid);
}
for (String uuid : resultList) {
Result<Boolean> deleteResult = client.data().deleter()
.withID(uuid)
.withClassName(className + kid)
.withConsistencyLevel(ConsistencyLevel.ALL) // default QUORUM
.run();
}
}
@Override
public void storeEmbeddings(List<String> chunkList, List<List<Double>> vectorList, String kid, String docId, List<String> fidList) {
WeaviateClient client = getClient();
for (int i = 0; i < Math.min(chunkList.size(), vectorList.size()); i++) {
List<Double> vector = vectorList.get(i);
Float[] vf = vector.stream().map(Double::floatValue).toArray(Float[]::new);
Map<String, Object> dataSchema = new HashMap<>();
dataSchema.put("content", chunkList.get(i));
dataSchema.put("kid", kid);
dataSchema.put("docId", docId);
dataSchema.put("fid", fidList.get(i));
String uuid = UUID.randomUUID().toString();
dataSchema.put("uuid", uuid);
Result<WeaviateObject> result = client.data().creator()
.withClassName(className + kid)
.withID(uuid)
.withVector(vf)
.withProperties(dataSchema)
.run();
}
}
@Override
public void removeByDocId(String kid, String docId) {
List<String> resultList = new ArrayList<>();
WeaviateClient client = getClient();
Field fieldId = Field.builder().name("uuid").build();
WhereFilter where = WhereFilter.builder()
.path(new String[]{"docId"})
.operator(Operator.Equal)
.valueString(docId)
.build();
Result<GraphQLResponse> result = client.graphQL().get()
.withClassName(className + kid)
.withFields(fieldId)
.withWhere(where)
.run();
LinkedTreeMap<String, Object> t = (LinkedTreeMap<String, Object>) result.getResult().getData();
LinkedTreeMap<String, ArrayList<LinkedTreeMap>> l = (LinkedTreeMap<String, ArrayList<LinkedTreeMap>>) t.get("Get");
ArrayList<LinkedTreeMap> m = l.get(className + kid);
for (LinkedTreeMap linkedTreeMap : m) {
String uuid = linkedTreeMap.get("uuid").toString();
resultList.add(uuid);
}
for (String uuid : resultList) {
Result<Boolean> deleteResult = client.data().deleter()
.withID(uuid)
.withClassName(className + kid)
.withConsistencyLevel(ConsistencyLevel.ALL) // default QUORUM
.run();
}
}
@Override
public void removeByKid(String kid) {
WeaviateClient client = getClient();
Result<Boolean> result = client.schema().classDeleter().withClassName(className + kid).run();
if (result.hasErrors()) {
System.out.println("删除schema失败" + result.getError());
} else {
System.out.println("删除schema成功" + result.getResult());
}
log.info("drop schema by kid, result = {}", result);
}
@Override
public List<String> nearest(List<Double> queryVector, String kid) {
if (StringUtils.isBlank(kid)) {
return new ArrayList<String>();
}
List<String> resultList = new ArrayList<>();
Float[] vf = new Float[queryVector.size()];
for (int j = 0; j < queryVector.size(); j++) {
Double value = queryVector.get(j);
vf[j] = value.floatValue();
}
WeaviateClient client = getClient();
Field contentField = Field.builder().name("content").build();
Field _additional = Field.builder()
.name("_additional")
.fields(new Field[]{
Field.builder().name("distance").build()
}).build();
NearVectorArgument nearVector = NearVectorArgument.builder()
.vector(vf)
.distance(1.6f) // certainty = 1f - distance /2f
.build();
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
Result<GraphQLResponse> result = client.graphQL().get()
.withClassName(className + kid)
.withFields(contentField, _additional)
.withNearVector(nearVector)
.withLimit(knowledgeInfoVo.getRetrieveLimit())
.run();
LinkedTreeMap<String, Object> t = (LinkedTreeMap<String, Object>) result.getResult().getData();
LinkedTreeMap<String, ArrayList<LinkedTreeMap>> l = (LinkedTreeMap<String, ArrayList<LinkedTreeMap>>) t.get("Get");
ArrayList<LinkedTreeMap> m = l.get(className + kid);
for (LinkedTreeMap linkedTreeMap : m) {
String content = linkedTreeMap.get("content").toString();
resultList.add(content);
}
return resultList;
}
@Override
public List<String> nearest(String query, String kid) {
if (StringUtils.isBlank(kid)) {
return new ArrayList<String>();
}
List<String> resultList = new ArrayList<>();
WeaviateClient client = getClient();
Field contentField = Field.builder().name("content").build();
Field _additional = Field.builder()
.name("_additional")
.fields(new Field[]{
Field.builder().name("distance").build()
}).build();
NearTextArgument nearText = client.graphQL().arguments().nearTextArgBuilder()
.concepts(new String[]{query})
.distance(1.6f) // certainty = 1f - distance /2f
.build();
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
Result<GraphQLResponse> result = client.graphQL().get()
.withClassName(className + kid)
.withFields(contentField, _additional)
.withNearText(nearText)
.withLimit(knowledgeInfoVo.getRetrieveLimit())
.run();
LinkedTreeMap<String, Object> t = (LinkedTreeMap<String, Object>) result.getResult().getData();
LinkedTreeMap<String, ArrayList<LinkedTreeMap>> l = (LinkedTreeMap<String, ArrayList<LinkedTreeMap>>) t.get("Get");
ArrayList<LinkedTreeMap> m = l.get(className + kid);
for (LinkedTreeMap linkedTreeMap : m) {
String content = linkedTreeMap.get("content").toString();
resultList.add(content);
}
return resultList;
}
public Result<Boolean> deleteSchema(String kid) {
WeaviateClient client = getClient();
Result<Boolean> result = client.schema().classDeleter().withClassName(className + kid).run();
if (result.hasErrors()) {
System.out.println(result.getError());
} else {
System.out.println(result.getResult());
}
return result;
}
}