mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-04-04 23:37:32 +00:00
feat: 调整知识库模块
This commit is contained in:
@@ -5,11 +5,12 @@ import jakarta.servlet.http.HttpServletRequest;
|
||||
import jakarta.validation.Valid;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.chat.service.chat.ISseService;
|
||||
import org.ruoyi.common.chat.domain.request.ChatRequest;
|
||||
import org.ruoyi.common.chat.domain.request.Dall3Request;
|
||||
|
||||
import org.ruoyi.common.chat.entity.Tts.TextToSpeech;
|
||||
import org.ruoyi.common.chat.entity.files.UploadFileResponse;
|
||||
import org.ruoyi.common.chat.entity.images.Item;
|
||||
|
||||
import org.ruoyi.common.chat.entity.whisper.WhisperResponse;
|
||||
import org.ruoyi.common.core.domain.R;
|
||||
import org.ruoyi.common.core.domain.model.LoginUser;
|
||||
@@ -17,8 +18,10 @@ import org.ruoyi.common.core.exception.base.BaseException;
|
||||
import org.ruoyi.common.mybatis.core.page.PageQuery;
|
||||
import org.ruoyi.common.mybatis.core.page.TableDataInfo;
|
||||
import org.ruoyi.common.satoken.utils.LoginHelper;
|
||||
import org.ruoyi.system.domain.request.translation.TranslationRequest;
|
||||
import org.ruoyi.system.service.ISseService;
|
||||
import org.ruoyi.domain.bo.ChatMessageBo;
|
||||
|
||||
import org.ruoyi.domain.vo.ChatMessageVo;
|
||||
import org.ruoyi.service.IChatMessageService;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.stereotype.Controller;
|
||||
@@ -26,7 +29,6 @@ import org.springframework.web.bind.annotation.*;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 描述:聊天管理
|
||||
@@ -56,7 +58,6 @@ public class ChatController {
|
||||
return sseService.sseChat(chatRequest,request);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 上传文件
|
||||
*/
|
||||
@@ -90,22 +91,6 @@ public class ChatController {
|
||||
return sseService.textToSpeed(textToSpeech);
|
||||
}
|
||||
|
||||
/**
|
||||
* 文本翻译
|
||||
*
|
||||
* @param
|
||||
*/
|
||||
@PostMapping("/translation")
|
||||
@ResponseBody
|
||||
public String translation(@RequestBody TranslationRequest translationRequest) {
|
||||
return sseService.translation(translationRequest);
|
||||
}
|
||||
|
||||
@PostMapping("/dall3")
|
||||
@ResponseBody
|
||||
public R<List<Item>> dall3(@RequestBody @Valid Dall3Request request) {
|
||||
return R.ok(sseService.dall3(request));
|
||||
}
|
||||
|
||||
/**
|
||||
* 聊天记录
|
||||
|
||||
@@ -7,7 +7,7 @@ import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.Request;
|
||||
import org.apache.commons.lang3.math.NumberUtils;
|
||||
import org.ruoyi.chat.dto.*;
|
||||
import org.ruoyi.chat.domain.dto.*;
|
||||
import org.ruoyi.chat.enums.ActionType;
|
||||
import org.ruoyi.chat.util.MjOkHttpUtil;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
|
||||
@@ -7,7 +7,7 @@ import io.swagger.annotations.ApiParam;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.Request;
|
||||
import org.ruoyi.chat.dto.TaskConditionDTO;
|
||||
import org.ruoyi.chat.domain.dto.TaskConditionDTO;
|
||||
import org.ruoyi.chat.util.MjOkHttpUtil;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package org.ruoyi.chat.dto;
|
||||
package org.ruoyi.chat.domain.dto;
|
||||
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
import lombok.Getter;
|
||||
@@ -1,4 +1,4 @@
|
||||
package org.ruoyi.chat.dto;
|
||||
package org.ruoyi.chat.domain.dto;
|
||||
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import lombok.Data;
|
||||
@@ -1,4 +1,4 @@
|
||||
package org.ruoyi.chat.dto;
|
||||
package org.ruoyi.chat.domain.dto;
|
||||
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
@@ -1,4 +1,4 @@
|
||||
package org.ruoyi.chat.dto;
|
||||
package org.ruoyi.chat.domain.dto;
|
||||
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
@@ -1,4 +1,4 @@
|
||||
package org.ruoyi.chat.dto;
|
||||
package org.ruoyi.chat.domain.dto;
|
||||
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
@@ -1,4 +1,4 @@
|
||||
package org.ruoyi.chat.dto;
|
||||
package org.ruoyi.chat.domain.dto;
|
||||
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
@@ -1,4 +1,4 @@
|
||||
package org.ruoyi.chat.dto;
|
||||
package org.ruoyi.chat.domain.dto;
|
||||
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import lombok.Data;
|
||||
@@ -1,4 +1,4 @@
|
||||
package org.ruoyi.chat.dto;
|
||||
package org.ruoyi.chat.domain.dto;
|
||||
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import lombok.Data;
|
||||
@@ -1,4 +1,4 @@
|
||||
package org.ruoyi.chat.dto;
|
||||
package org.ruoyi.chat.domain.dto;
|
||||
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
@@ -1,4 +1,4 @@
|
||||
package org.ruoyi.chat.dto;
|
||||
package org.ruoyi.chat.domain.dto;
|
||||
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import lombok.Data;
|
||||
@@ -0,0 +1,107 @@
|
||||
package org.ruoyi.chat.listener;
|
||||
|
||||
|
||||
import cn.hutool.core.collection.CollectionUtil;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.Response;
|
||||
import okhttp3.ResponseBody;
|
||||
import okhttp3.sse.EventSource;
|
||||
import okhttp3.sse.EventSourceListener;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
|
||||
import org.ruoyi.common.core.utils.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* 描述:OpenAIEventSourceListener
|
||||
*
|
||||
* @author https:www.unfbx.com
|
||||
* @date 2023-02-22
|
||||
*/
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
@Component
|
||||
public class SSEEventSourceListener extends EventSourceListener {
|
||||
|
||||
@Autowired(required = false)
|
||||
public SSEEventSourceListener(ResponseBodyEmitter emitter) {
|
||||
this.emitter = emitter;
|
||||
}
|
||||
|
||||
private ResponseBodyEmitter emitter;
|
||||
|
||||
private StringBuilder stringBuffer;
|
||||
|
||||
private String modelName;
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public void onOpen(EventSource eventSource, Response response) {
|
||||
log.info("OpenAI建立sse连接...");
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@SneakyThrows
|
||||
@Override
|
||||
public void onEvent(@NotNull EventSource eventSource, String id, String type, String data) {
|
||||
try {
|
||||
if ("[DONE]".equals(data)) {
|
||||
//成功响应
|
||||
emitter.complete();
|
||||
|
||||
// 扣除费用 (消耗字符 模型名称)
|
||||
return;
|
||||
}
|
||||
// 解析返回内容
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
ChatCompletionResponse completionResponse = mapper.readValue(data, ChatCompletionResponse.class);
|
||||
if(completionResponse == null || CollectionUtil.isEmpty(completionResponse.getChoices())){
|
||||
return;
|
||||
}
|
||||
Object content = completionResponse.getChoices().get(0).getDelta().getContent();
|
||||
if(content == null){
|
||||
content = completionResponse.getChoices().get(0).getDelta().getReasoningContent();
|
||||
if(content == null) return;
|
||||
}
|
||||
if(StringUtils.isEmpty(modelName)){
|
||||
modelName = completionResponse.getModel();
|
||||
}
|
||||
stringBuffer.append(content);
|
||||
emitter.send(data);
|
||||
} catch (Exception e) {
|
||||
log.error("sse信息推送失败{}内容:{}",e.getMessage(),data);
|
||||
eventSource.cancel();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onClosed(EventSource eventSource) {
|
||||
log.info("OpenAI关闭sse连接...");
|
||||
}
|
||||
|
||||
@SneakyThrows
|
||||
@Override
|
||||
public void onFailure(EventSource eventSource, Throwable t, Response response) {
|
||||
if (Objects.isNull(response)) {
|
||||
return;
|
||||
}
|
||||
ResponseBody body = response.body();
|
||||
if (Objects.nonNull(body)) {
|
||||
log.error("OpenAI sse连接异常data:{},异常:{}", body.string(), t);
|
||||
} else {
|
||||
log.error("OpenAI sse连接异常data:{},异常:{}", response, t);
|
||||
}
|
||||
eventSource.cancel();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package org.ruoyi.chat.service.chat;
|
||||
|
||||
import org.ruoyi.domain.bo.ChatMessageBo;
|
||||
|
||||
/**
|
||||
* 计费管理Service接口
|
||||
*
|
||||
* @author ageerle
|
||||
* @date 2025-04-08
|
||||
*/
|
||||
public interface IChatCostService {
|
||||
|
||||
/**
|
||||
* 根据消耗的tokens扣除余额
|
||||
*
|
||||
* @param chatMessageBo
|
||||
* @return 结果
|
||||
*/
|
||||
|
||||
void deductToken(ChatMessageBo chatMessageBo);
|
||||
|
||||
/**
|
||||
* 扣除用户的余额
|
||||
*
|
||||
*/
|
||||
void deductUserBalance(Long userId, Double numberCost);
|
||||
|
||||
|
||||
/**
|
||||
* 扣除任务费用并且保存记录
|
||||
*
|
||||
* @param type 任务类型
|
||||
* @param prompt 任务描述
|
||||
* @param cost 扣除费用
|
||||
*/
|
||||
void taskDeduct(String type,String prompt, double cost);
|
||||
|
||||
|
||||
/**
|
||||
* 判断用户是否付费
|
||||
*/
|
||||
void checkUserGrade();
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package org.ruoyi.chat.service.chat;
|
||||
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
import org.ruoyi.common.chat.domain.request.ChatRequest;
|
||||
import org.ruoyi.common.chat.entity.Tts.TextToSpeech;
|
||||
import org.ruoyi.common.chat.entity.files.UploadFileResponse;
|
||||
import org.ruoyi.common.chat.entity.whisper.WhisperResponse;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
|
||||
/**
|
||||
* 用户聊天管理Service接口
|
||||
*
|
||||
* @author ageerle
|
||||
* @date 2025-04-08
|
||||
*/
|
||||
public interface ISseService {
|
||||
|
||||
/**
|
||||
* 客户端发送消息到服务端
|
||||
* @param chatRequest 请求对象
|
||||
*/
|
||||
SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request);
|
||||
|
||||
/**
|
||||
* 语音转文字
|
||||
* @param file 语音文件
|
||||
*/
|
||||
WhisperResponse speechToTextTranscriptionsV2(MultipartFile file);
|
||||
|
||||
/**
|
||||
* 文字转语音
|
||||
*
|
||||
* @param textToSpeech 文本信息
|
||||
* @return 流式语音
|
||||
*/
|
||||
ResponseEntity<Resource> textToSpeed(TextToSpeech textToSpeech);
|
||||
|
||||
|
||||
/**
|
||||
* 上传文件到api服务器
|
||||
*
|
||||
* @param file 文件信息
|
||||
* @return 返回文件信息
|
||||
*/
|
||||
UploadFileResponse upload(MultipartFile file);
|
||||
|
||||
|
||||
/**
|
||||
* 使用ollama调用本地模型
|
||||
* @param chatRequest 对话信息
|
||||
* @return 流式输出返回内容
|
||||
*/
|
||||
SseEmitter ollamaChat(ChatRequest chatRequest);
|
||||
|
||||
/**
|
||||
* 企业应用回复
|
||||
* @param prompt 提示词
|
||||
* @return 回复内容
|
||||
*/
|
||||
String wxCpChat(String prompt);
|
||||
}
|
||||
@@ -0,0 +1,292 @@
|
||||
package org.ruoyi.chat.service.chat.impl;
|
||||
|
||||
import cn.dev33.satoken.stp.StpUtil;
|
||||
import cn.hutool.extra.servlet.ServletUtil;
|
||||
import com.google.protobuf.ServiceException;
|
||||
import io.github.ollama4j.OllamaAPI;
|
||||
import io.github.ollama4j.models.chat.OllamaChatMessage;
|
||||
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
||||
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
|
||||
import io.github.ollama4j.models.chat.OllamaChatRequestModel;
|
||||
import io.github.ollama4j.models.generate.OllamaStreamHandler;
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.*;
|
||||
import org.ruoyi.chat.listener.SSEEventSourceListener;
|
||||
|
||||
import org.ruoyi.chat.service.chat.ISseService;
|
||||
import org.ruoyi.common.chat.config.ChatConfig;
|
||||
import org.ruoyi.common.chat.domain.request.ChatRequest;
|
||||
import org.ruoyi.common.chat.entity.Tts.TextToSpeech;
|
||||
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
|
||||
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
|
||||
|
||||
import org.ruoyi.common.chat.entity.chat.Message;
|
||||
import org.ruoyi.common.chat.entity.files.UploadFileResponse;
|
||||
import org.ruoyi.common.chat.entity.whisper.WhisperResponse;
|
||||
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
|
||||
import org.ruoyi.common.core.utils.file.FileUtils;
|
||||
import org.ruoyi.common.core.utils.file.MimeTypeUtils;
|
||||
|
||||
import org.ruoyi.common.redis.utils.RedisUtils;
|
||||
import org.ruoyi.domain.vo.ChatModelVo;
|
||||
import org.ruoyi.service.IChatModelService;
|
||||
import org.springframework.core.io.InputStreamResource;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.time.Duration;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
public class SseServiceImpl implements ISseService {
|
||||
|
||||
private OpenAiStreamClient openAiStreamClient;
|
||||
|
||||
private final ChatConfig chatConfig;
|
||||
|
||||
private final IChatModelService chatModelService;
|
||||
|
||||
@Override
|
||||
public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) {
|
||||
SseEmitter sseEmitter = new SseEmitter(0L);
|
||||
SSEEventSourceListener openAIEventSourceListener = new SSEEventSourceListener(sseEmitter);
|
||||
// 获取对话消息列表
|
||||
List<Message> messages = chatRequest.getMessages();
|
||||
try {
|
||||
if (StpUtil.isLogin()) {
|
||||
// 通过模型名称查询模型信息
|
||||
ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
|
||||
// 构建api请求客户端
|
||||
openAiStreamClient = chatConfig.createOpenAiStreamClient(chatModelVo.getApiHost(), chatModelVo.getApiKey());
|
||||
|
||||
// 模型设置默认提示词
|
||||
|
||||
// 是否开启联网查询
|
||||
}else {
|
||||
// 未登录用户限制对话次数,默认5次
|
||||
String clientIp = ServletUtil.getClientIP((javax.servlet.http.HttpServletRequest) request,"X-Forwarded-For");
|
||||
|
||||
int timeWindowInSeconds = 5;
|
||||
|
||||
String redisKey = "visitor:" + clientIp;
|
||||
int count = 0;
|
||||
|
||||
if (RedisUtils.getCacheObject(redisKey) == null) {
|
||||
// 当前访问次数
|
||||
RedisUtils.setCacheObject(redisKey, count, Duration.ofSeconds(86400));
|
||||
}else {
|
||||
count = RedisUtils.getCacheObject(redisKey);
|
||||
if (count >= timeWindowInSeconds) {
|
||||
throw new ServiceException("当日免费次数已用完");
|
||||
}
|
||||
count++;
|
||||
RedisUtils.setCacheObject(redisKey, count);
|
||||
}
|
||||
}
|
||||
|
||||
ChatCompletion completion = ChatCompletion
|
||||
.builder()
|
||||
.messages(messages)
|
||||
.model(chatRequest.getModel())
|
||||
.temperature(chatRequest.getTemperature())
|
||||
.topP(chatRequest.getTop_p())
|
||||
.stream(true)
|
||||
.build();
|
||||
openAiStreamClient.streamChatCompletion(completion, openAIEventSourceListener);
|
||||
// 保存消息记录 并扣除费用
|
||||
|
||||
} catch (Exception e) {
|
||||
String message = e.getMessage();
|
||||
sendErrorEvent(sseEmitter, message);
|
||||
return sseEmitter;
|
||||
}
|
||||
return sseEmitter;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 发送SSE错误事件的封装方法
|
||||
*
|
||||
* @param sseEmitter
|
||||
* @param errorMessage
|
||||
*/
|
||||
private void sendErrorEvent(SseEmitter sseEmitter, String errorMessage) {
|
||||
SseEmitter.SseEventBuilder event = SseEmitter.event()
|
||||
.name("error")
|
||||
.data(errorMessage);
|
||||
try {
|
||||
sseEmitter.send(event);
|
||||
} catch (IOException e) {
|
||||
log.error("发送事件失败: {}", e.getMessage());
|
||||
}
|
||||
sseEmitter.complete();
|
||||
}
|
||||
|
||||
/**
|
||||
* 文字转语音
|
||||
*/
|
||||
@Override
|
||||
public ResponseEntity<Resource> textToSpeed(TextToSpeech textToSpeech) {
|
||||
ResponseBody body = openAiStreamClient.textToSpeech(textToSpeech);
|
||||
if (body != null) {
|
||||
// 将ResponseBody转换为InputStreamResource
|
||||
InputStreamResource resource = new InputStreamResource(body.byteStream());
|
||||
|
||||
// 创建并返回ResponseEntity
|
||||
return ResponseEntity.ok()
|
||||
.contentType(MediaType.parseMediaType("audio/mpeg"))
|
||||
.body(resource);
|
||||
} else {
|
||||
// 如果ResponseBody为空,返回404状态码
|
||||
return ResponseEntity.notFound().build();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 语音转文字
|
||||
*/
|
||||
@Override
|
||||
public WhisperResponse speechToTextTranscriptionsV2(MultipartFile file) {
|
||||
// 确保文件不为空
|
||||
if (file.isEmpty()) {
|
||||
throw new IllegalStateException("Cannot convert an empty MultipartFile");
|
||||
}
|
||||
if (!FileUtils.isValidFileExtention(file, MimeTypeUtils.AUDIO__EXTENSION)) {
|
||||
throw new IllegalStateException("File Extention not supported");
|
||||
}
|
||||
// 创建一个文件对象
|
||||
File fileA = new File(System.getProperty("java.io.tmpdir") + File.separator + file.getOriginalFilename());
|
||||
try {
|
||||
// 将 MultipartFile 的内容写入文件
|
||||
file.transferTo(fileA);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to convert MultipartFile to File", e);
|
||||
}
|
||||
return openAiStreamClient.speechToTextTranscriptions(fileA);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public UploadFileResponse upload(MultipartFile file) {
|
||||
if (file.isEmpty()) {
|
||||
throw new IllegalStateException("Cannot upload an empty MultipartFile");
|
||||
}
|
||||
if (!FileUtils.isValidFileExtention(file, MimeTypeUtils.DEFAULT_ALLOWED_EXTENSION)) {
|
||||
throw new IllegalStateException("File Extention not supported");
|
||||
}
|
||||
openAiStreamClient = chatConfig.getOpenAiStreamClient();
|
||||
return openAiStreamClient.uploadFile("fine-tune", convertMultiPartToFile(file));
|
||||
}
|
||||
|
||||
private File convertMultiPartToFile(MultipartFile multipartFile) {
|
||||
File file = null;
|
||||
try {
|
||||
// 获取原始文件名
|
||||
String originalFileName = multipartFile.getOriginalFilename();
|
||||
// 默认扩展名
|
||||
String extension = ".tmp";
|
||||
// 尝试从原始文件名中获取扩展名
|
||||
if (originalFileName != null && originalFileName.contains(".")) {
|
||||
extension = originalFileName.substring(originalFileName.lastIndexOf("."));
|
||||
}
|
||||
|
||||
// 使用原始文件的扩展名创建临时文件
|
||||
Path tempFile = Files.createTempFile(null, extension);
|
||||
file = tempFile.toFile();
|
||||
|
||||
// 将MultipartFile的内容写入文件
|
||||
try (InputStream inputStream = multipartFile.getInputStream();
|
||||
FileOutputStream outputStream = new FileOutputStream(file)) {
|
||||
int read;
|
||||
byte[] bytes = new byte[1024];
|
||||
while ((read = inputStream.read(bytes)) != -1) {
|
||||
outputStream.write(bytes, 0, read);
|
||||
}
|
||||
} catch (IOException e) {
|
||||
// 处理文件写入异常
|
||||
e.printStackTrace();
|
||||
}
|
||||
} catch (IOException e) {
|
||||
// 处理临时文件创建异常
|
||||
e.printStackTrace();
|
||||
}
|
||||
return file;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SseEmitter ollamaChat(ChatRequest chatRequest) {
|
||||
String[] parts = chatRequest.getModel().split("ollama-");
|
||||
ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
|
||||
final SseEmitter emitter = new SseEmitter();
|
||||
String host = chatModelVo.getApiHost();
|
||||
List<Message> msgList = chatRequest.getMessages();
|
||||
List<OllamaChatMessage> messages = new ArrayList<>();
|
||||
|
||||
for (Message message : msgList) {
|
||||
OllamaChatMessage ollamaChatMessage = new OllamaChatMessage();
|
||||
ollamaChatMessage.setRole(OllamaChatMessageRole.USER);
|
||||
ollamaChatMessage.setContent(message.getContent().toString());
|
||||
messages.add(ollamaChatMessage);
|
||||
}
|
||||
OllamaAPI api = new OllamaAPI(host);
|
||||
api.setRequestTimeoutSeconds(100);
|
||||
OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(parts[1]);
|
||||
|
||||
OllamaChatRequestModel requestModel = builder
|
||||
.withMessages(messages)
|
||||
.build();
|
||||
|
||||
// 异步执行 OllAma API 调用
|
||||
CompletableFuture.runAsync(() -> {
|
||||
try {
|
||||
StringBuilder response = new StringBuilder();
|
||||
OllamaStreamHandler streamHandler = (s) -> {
|
||||
String substr = s.substring(response.length());
|
||||
response.append(substr);
|
||||
System.out.println(substr);
|
||||
try {
|
||||
emitter.send(substr);
|
||||
} catch (IOException e) {
|
||||
sendErrorEvent(emitter, e.getMessage());
|
||||
}
|
||||
};
|
||||
api.chat(requestModel, streamHandler);
|
||||
emitter.complete();
|
||||
} catch (Exception e) {
|
||||
sendErrorEvent(emitter, e.getMessage());
|
||||
}
|
||||
});
|
||||
return emitter;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String wxCpChat(String prompt) {
|
||||
List<Message> messageList = new ArrayList<>();
|
||||
Message message = Message.builder().role(Message.Role.USER).content(prompt).build();
|
||||
messageList.add(message);
|
||||
ChatCompletion chatCompletion = ChatCompletion
|
||||
.builder()
|
||||
.messages(messageList)
|
||||
.model("gpt-4o-mini")
|
||||
.stream(false)
|
||||
.build();
|
||||
ChatCompletionResponse chatCompletionResponse = openAiStreamClient.chatCompletion(chatCompletion);
|
||||
return chatCompletionResponse.getChoices().get(0).getMessage().getContent().toString();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package org.ruoyi.chat.service.knowledge.vectorizer;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import io.github.ollama4j.OllamaAPI;
|
||||
import io.github.ollama4j.models.embeddings.OllamaEmbeddingsRequestModel;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.domain.vo.KnowledgeInfoVo;
|
||||
import org.ruoyi.service.IKnowledgeInfoService;
|
||||
import org.ruoyi.service.VectorizationService;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
public class BgeLargeVectorization implements VectorizationService {
|
||||
|
||||
String host = "http://localhost:11434/";
|
||||
|
||||
@Lazy
|
||||
@Resource
|
||||
private IKnowledgeInfoService knowledgeInfoService;
|
||||
|
||||
@Override
|
||||
public List<List<Double>> batchVectorization(List<String> chunkList, String kid) {
|
||||
OllamaAPI ollamaAPI = new OllamaAPI(host);
|
||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
|
||||
List<Double> doubleVector;
|
||||
try {
|
||||
doubleVector = ollamaAPI.generateEmbeddings(new OllamaEmbeddingsRequestModel(knowledgeInfoVo.getVectorModel(), new Gson().toJson(chunkList)));
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
List<List<Double>> vectorList = new ArrayList<>();
|
||||
vectorList.add(doubleVector);
|
||||
return vectorList;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Double> singleVectorization(String chunk, String kid) {
|
||||
List<String> chunkList = new ArrayList<>();
|
||||
chunkList.add(chunk);
|
||||
List<List<Double>> vectorList = batchVectorization(chunkList, kid);
|
||||
return vectorList.get(0);
|
||||
}
|
||||
|
||||
public static void main(String[] args) {
|
||||
OllamaAPI ollamaAPI = new OllamaAPI("http://localhost:11434/");
|
||||
List<String> chunkList = Arrays.asList("天很蓝", "海很深");
|
||||
List<Double> doubleVector;
|
||||
try {
|
||||
doubleVector = ollamaAPI.generateEmbeddings(new OllamaEmbeddingsRequestModel("quentinz/bge-large-zh-v1.5", new Gson().toJson(chunkList)));
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
System.out.println("=== " + doubleVector + " 1===");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,110 @@
|
||||
package org.ruoyi.chat.service.knowledge.vectorizer;
|
||||
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.Getter;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.common.chat.config.ChatConfig;
|
||||
import org.ruoyi.common.chat.entity.embeddings.Embedding;
|
||||
import org.ruoyi.common.chat.entity.embeddings.EmbeddingResponse;
|
||||
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
|
||||
import org.ruoyi.domain.vo.KnowledgeInfoVo;
|
||||
import org.ruoyi.service.IKnowledgeInfoService;
|
||||
import org.ruoyi.service.VectorizationService;
|
||||
import org.ruoyi.system.domain.SysModel;
|
||||
import org.ruoyi.system.service.ISysModelService;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
public class OpenAiVectorization implements VectorizationService {
|
||||
|
||||
@Lazy
|
||||
@Resource
|
||||
private IKnowledgeInfoService knowledgeInfoService;
|
||||
|
||||
@Lazy
|
||||
@Resource
|
||||
private ISysModelService sysModelService;
|
||||
|
||||
@Getter
|
||||
private OpenAiStreamClient openAiStreamClient;
|
||||
|
||||
private final ChatConfig chatConfig;
|
||||
|
||||
@Override
|
||||
public List<List<Double>> batchVectorization(List<String> chunkList, String kid) {
|
||||
List<List<Double>> vectorList;
|
||||
// 获取知识库信息
|
||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
|
||||
if(knowledgeInfoVo == null){
|
||||
log.warn("知识库不存在:请查检ID {}",kid);
|
||||
vectorList=new ArrayList<>();
|
||||
vectorList.add(new ArrayList<>());
|
||||
return vectorList;
|
||||
}
|
||||
SysModel sysModel = sysModelService.selectModelByName(knowledgeInfoVo.getVectorModel());
|
||||
String apiHost= sysModel.getApiHost();
|
||||
String apiKey= sysModel.getApiKey();
|
||||
openAiStreamClient = chatConfig.createOpenAiStreamClient(apiHost,apiKey);
|
||||
|
||||
Embedding embedding = buildEmbedding(chunkList, knowledgeInfoVo);
|
||||
EmbeddingResponse embeddings = openAiStreamClient.embeddings(embedding);
|
||||
|
||||
// 处理 OpenAI 返回的嵌入数据
|
||||
vectorList = processOpenAiEmbeddings(embeddings);
|
||||
|
||||
return vectorList;
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 Embedding 对象
|
||||
*/
|
||||
private Embedding buildEmbedding(List<String> chunkList, KnowledgeInfoVo knowledgeInfoVo) {
|
||||
return Embedding.builder()
|
||||
.input(chunkList)
|
||||
.model(knowledgeInfoVo.getVectorModel())
|
||||
.build();
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理 OpenAI 返回的嵌入数据
|
||||
*/
|
||||
private List<List<Double>> processOpenAiEmbeddings(EmbeddingResponse embeddings) {
|
||||
List<List<Double>> vectorList = new ArrayList<>();
|
||||
|
||||
embeddings.getData().forEach(data -> {
|
||||
List<BigDecimal> vector = data.getEmbedding();
|
||||
List<Double> doubleVector = convertToDoubleList(vector);
|
||||
vectorList.add(doubleVector);
|
||||
});
|
||||
|
||||
return vectorList;
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 BigDecimal 转换为 Double 列表
|
||||
*/
|
||||
private List<Double> convertToDoubleList(List<BigDecimal> vector) {
|
||||
return vector.stream()
|
||||
.map(BigDecimal::doubleValue)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public List<Double> singleVectorization(String chunk, String kid) {
|
||||
List<String> chunkList = new ArrayList<>();
|
||||
chunkList.add(chunk);
|
||||
List<List<Double>> vectorList = batchVectorization(chunkList, kid);
|
||||
return vectorList.get(0);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
package org.ruoyi.chat.service.knowledge.vectorizer;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import org.ruoyi.domain.vo.KnowledgeInfoVo;
|
||||
import org.ruoyi.service.IKnowledgeInfoService;
|
||||
import org.ruoyi.service.VectorizationService;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
/**
|
||||
* 文本向量化
|
||||
* @author huangkh
|
||||
*/
|
||||
@Component
|
||||
@Slf4j
|
||||
public class VectorizationFactory {
|
||||
|
||||
private final OpenAiVectorization openAiVectorization;
|
||||
|
||||
private final BgeLargeVectorization bgeLargeVectorization;
|
||||
|
||||
@Lazy
|
||||
@Resource
|
||||
private IKnowledgeInfoService knowledgeInfoService;
|
||||
|
||||
public VectorizationFactory(OpenAiVectorization openAiVectorization, BgeLargeVectorization bgeLargeVectorization) {
|
||||
this.openAiVectorization = openAiVectorization;
|
||||
this.bgeLargeVectorization = bgeLargeVectorization;
|
||||
}
|
||||
|
||||
public VectorizationService getEmbedding(String kid){
|
||||
String vectorModel = "text-embedding-3-small";
|
||||
if (StrUtil.isNotEmpty(kid)) {
|
||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
|
||||
if (knowledgeInfoVo != null && StrUtil.isNotEmpty(knowledgeInfoVo.getVectorModel())) {
|
||||
vectorModel = knowledgeInfoVo.getVectorModel();
|
||||
}
|
||||
}
|
||||
return switch (vectorModel) {
|
||||
case "quentinz/bge-large-zh-v1.5" -> bgeLargeVectorization;
|
||||
default -> openAiVectorization;
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,397 @@
|
||||
package org.ruoyi.chat.service.knowledge.vectorstore;
|
||||
|
||||
import io.milvus.client.MilvusServiceClient;
|
||||
import io.milvus.grpc.DataType;
|
||||
import io.milvus.grpc.DescribeIndexResponse;
|
||||
import io.milvus.grpc.MutationResult;
|
||||
import io.milvus.grpc.SearchResults;
|
||||
import io.milvus.param.*;
|
||||
import io.milvus.param.collection.*;
|
||||
import io.milvus.param.dml.DeleteParam;
|
||||
import io.milvus.param.dml.InsertParam;
|
||||
import io.milvus.param.dml.SearchParam;
|
||||
import io.milvus.param.index.CreateIndexParam;
|
||||
import io.milvus.param.index.DescribeIndexParam;
|
||||
import io.milvus.param.partition.CreatePartitionParam;
|
||||
import io.milvus.response.QueryResultsWrapper;
|
||||
import io.milvus.response.SearchResultsWrapper;
|
||||
import jakarta.annotation.PostConstruct;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.ruoyi.common.core.service.ConfigService;
|
||||
|
||||
import org.ruoyi.service.VectorStoreService;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class MilvusVectorStore implements VectorStoreService {
|
||||
|
||||
private volatile Integer dimension;
|
||||
private volatile String collectionName;
|
||||
private MilvusServiceClient milvusServiceClient;
|
||||
|
||||
@Resource
|
||||
private ConfigService configService;
|
||||
|
||||
@PostConstruct
|
||||
public void loadConfig() {
|
||||
this.dimension = Integer.parseInt(configService.getConfigValue("milvus", "dimension"));
|
||||
this.collectionName = configService.getConfigValue("milvus", "collection");
|
||||
}
|
||||
|
||||
@PostConstruct
|
||||
public void init() {
|
||||
String milvusHost = configService.getConfigValue("milvus", "host");
|
||||
String milvausPort = configService.getConfigValue("milvus", "port");
|
||||
milvusServiceClient = new MilvusServiceClient(
|
||||
ConnectParam.newBuilder()
|
||||
.withHost(milvusHost)
|
||||
.withPort(Integer.parseInt(milvausPort))
|
||||
.withDatabaseName("default")
|
||||
.build()
|
||||
);
|
||||
}
|
||||
|
||||
private void createSchema(String kid) {
|
||||
FieldType primaryField = FieldType.newBuilder()
|
||||
.withName("row_id")
|
||||
.withDataType(DataType.Int64)
|
||||
.withPrimaryKey(true)
|
||||
.withAutoID(true)
|
||||
.build();
|
||||
FieldType contentField = FieldType.newBuilder()
|
||||
.withName("content")
|
||||
.withDataType(DataType.VarChar)
|
||||
.withMaxLength(1000)
|
||||
.build();
|
||||
FieldType kidField = FieldType.newBuilder()
|
||||
.withName("kid")
|
||||
.withDataType(DataType.VarChar)
|
||||
.withMaxLength(20)
|
||||
.build();
|
||||
FieldType docIdField = FieldType.newBuilder()
|
||||
.withName("docId")
|
||||
.withDataType(DataType.VarChar)
|
||||
.withMaxLength(20)
|
||||
.build();
|
||||
FieldType fidField = FieldType.newBuilder()
|
||||
.withName("fid")
|
||||
.withDataType(DataType.VarChar)
|
||||
.withMaxLength(20)
|
||||
.build();
|
||||
FieldType vectorField = FieldType.newBuilder()
|
||||
.withName("fv")
|
||||
.withDataType(DataType.FloatVector)
|
||||
.withDimension(dimension)
|
||||
.build();
|
||||
CreateCollectionParam createCollectionReq = CreateCollectionParam.newBuilder()
|
||||
.withCollectionName(collectionName + kid)
|
||||
.withDescription("local knowledge")
|
||||
.addFieldType(primaryField)
|
||||
.addFieldType(contentField)
|
||||
.addFieldType(kidField)
|
||||
.addFieldType(docIdField)
|
||||
.addFieldType(fidField)
|
||||
.addFieldType(vectorField)
|
||||
.build();
|
||||
milvusServiceClient.createCollection(createCollectionReq);
|
||||
|
||||
// 创建向量的索引
|
||||
IndexType INDEX_TYPE = IndexType.IVF_FLAT;
|
||||
String INDEX_PARAM = "{\"nlist\":1024}";
|
||||
milvusServiceClient.createIndex(
|
||||
CreateIndexParam.newBuilder()
|
||||
.withCollectionName(collectionName + kid)
|
||||
.withFieldName("fv")
|
||||
.withIndexType(INDEX_TYPE)
|
||||
.withMetricType(MetricType.IP)
|
||||
.withExtraParam(INDEX_PARAM)
|
||||
.withSyncMode(Boolean.FALSE)
|
||||
.build()
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void newSchema(String kid) {
|
||||
createSchema(kid);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeByKidAndFid(String kid, String fid) {
|
||||
milvusServiceClient.delete(
|
||||
DeleteParam.newBuilder()
|
||||
.withCollectionName(collectionName + kid)
|
||||
.withExpr("fid == " + fid)
|
||||
.build()
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void storeEmbeddings(List<String> chunkList, List<List<Double>> vectorList, String kid, String docId, List<String> fidList) {
|
||||
String fullCollectionName = collectionName + kid;
|
||||
|
||||
// 检查集合是否存在
|
||||
HasCollectionParam hasCollectionParam = HasCollectionParam.newBuilder()
|
||||
.withCollectionName(fullCollectionName)
|
||||
.build();
|
||||
R<Boolean> booleanR = milvusServiceClient.hasCollection(hasCollectionParam);
|
||||
|
||||
if (booleanR.getStatus() == R.Status.Success.getCode()) {
|
||||
boolean collectionExists = booleanR.getData().booleanValue();
|
||||
if (!collectionExists) {
|
||||
// 集合不存在,创建集合
|
||||
List<FieldType> fieldTypes = new ArrayList<>();
|
||||
// 假设这里定义 id 字段,根据实际情况修改
|
||||
FieldType idField = FieldType.newBuilder()
|
||||
.withName("id")
|
||||
.withDataType(DataType.Int64)
|
||||
.withPrimaryKey(true)
|
||||
.withAutoID(true)
|
||||
.build();
|
||||
fieldTypes.add(idField);
|
||||
|
||||
// 定义向量字段
|
||||
FieldType vectorField = FieldType.newBuilder()
|
||||
.withName("fv")
|
||||
.withDataType(DataType.FloatVector)
|
||||
.withDimension(vectorList.get(0).size())
|
||||
.build();
|
||||
fieldTypes.add(vectorField);
|
||||
|
||||
// 定义其他字段
|
||||
FieldType contentField = FieldType.newBuilder()
|
||||
.withName("content")
|
||||
.withDataType(DataType.VarChar)
|
||||
.withMaxLength(chunkList.size() * 1024) // 根据实际情况修改
|
||||
.build();
|
||||
fieldTypes.add(contentField);
|
||||
|
||||
FieldType kidField = FieldType.newBuilder()
|
||||
.withName("kid")
|
||||
.withDataType(DataType.VarChar)
|
||||
.withMaxLength(256) // 根据实际情况修改
|
||||
.build();
|
||||
fieldTypes.add(kidField);
|
||||
|
||||
FieldType docIdField = FieldType.newBuilder()
|
||||
.withName("docId")
|
||||
.withDataType(DataType.VarChar)
|
||||
.withMaxLength(256) // 根据实际情况修改
|
||||
.build();
|
||||
fieldTypes.add(docIdField);
|
||||
|
||||
FieldType fidField = FieldType.newBuilder()
|
||||
.withName("fid")
|
||||
.withDataType(DataType.VarChar)
|
||||
.withMaxLength(256) // 根据实际情况修改
|
||||
.build();
|
||||
fieldTypes.add(fidField);
|
||||
|
||||
CreateCollectionParam createCollectionParam = CreateCollectionParam.newBuilder()
|
||||
.withCollectionName(fullCollectionName)
|
||||
.withFieldTypes(fieldTypes)
|
||||
.build();
|
||||
|
||||
R<RpcStatus> collection = milvusServiceClient.createCollection(createCollectionParam);
|
||||
if (collection.getStatus() == R.Status.Success.getCode()) {
|
||||
System.out.println("集合 " + fullCollectionName + " 创建成功");
|
||||
|
||||
// 创建索引
|
||||
CreateIndexParam createIndexParam = CreateIndexParam.newBuilder()
|
||||
.withCollectionName(fullCollectionName)
|
||||
.withFieldName("fv") // 向量字段名
|
||||
.withIndexType(IndexType.IVF_FLAT) // 索引类型
|
||||
.withMetricType(MetricType.IP)
|
||||
.withExtraParam("{\"nlist\":1024}") // 索引参数
|
||||
.build();
|
||||
R<RpcStatus> indexResponse = milvusServiceClient.createIndex(createIndexParam);
|
||||
if (indexResponse.getStatus() == R.Status.Success.getCode()) {
|
||||
System.out.println("索引创建成功");
|
||||
} else {
|
||||
System.err.println("索引创建失败: " + indexResponse.getMessage());
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
System.err.println("集合创建失败: " + collection.getMessage());
|
||||
return;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
System.err.println("检查集合是否存在时出错: " + booleanR.getMessage());
|
||||
return;
|
||||
}
|
||||
|
||||
if (StringUtils.isNotBlank(docId)) {
|
||||
milvusServiceClient.createPartition(
|
||||
CreatePartitionParam.newBuilder()
|
||||
.withCollectionName(fullCollectionName)
|
||||
.withPartitionName(docId)
|
||||
.build()
|
||||
);
|
||||
}
|
||||
|
||||
List<List<Float>> vectorFloatList = new ArrayList<>();
|
||||
List<String> kidList = new ArrayList<>();
|
||||
List<String> docIdList = new ArrayList<>();
|
||||
for (int i = 0; i < Math.min(chunkList.size(), vectorList.size()); i++) {
|
||||
List<Double> vector = vectorList.get(i);
|
||||
List<Float> vfList = new ArrayList<>();
|
||||
for (int j = 0; j < vector.size(); j++) {
|
||||
Double value = vector.get(j);
|
||||
vfList.add(value.floatValue());
|
||||
}
|
||||
vectorFloatList.add(vfList);
|
||||
kidList.add(kid);
|
||||
docIdList.add(docId);
|
||||
}
|
||||
List<InsertParam.Field> fields = new ArrayList<>();
|
||||
fields.add(new InsertParam.Field("content", chunkList));
|
||||
fields.add(new InsertParam.Field("kid", kidList));
|
||||
fields.add(new InsertParam.Field("docId", docIdList));
|
||||
fields.add(new InsertParam.Field("fid", fidList));
|
||||
fields.add(new InsertParam.Field("fv", vectorFloatList));
|
||||
|
||||
InsertParam insertParam = InsertParam.newBuilder()
|
||||
.withCollectionName(fullCollectionName)
|
||||
.withPartitionName(docId)
|
||||
.withFields(fields)
|
||||
.build();
|
||||
System.out.println("=========================");
|
||||
|
||||
R<MutationResult> insert = milvusServiceClient.insert(insertParam);
|
||||
if (insert.getStatus() == R.Status.Success.getCode()) {
|
||||
System.out.println("插入成功,插入的行数: " + insert.getData().getInsertCnt());
|
||||
} else {
|
||||
System.err.println("插入失败: " + insert.getMessage());
|
||||
}
|
||||
System.out.println("=========================");
|
||||
// milvus在将数据装载到内存后才能进行向量计算.
|
||||
LoadCollectionParam loadCollectionParam = LoadCollectionParam.newBuilder()
|
||||
.withCollectionName(fullCollectionName)
|
||||
.build();
|
||||
R<RpcStatus> loadResponse = milvusServiceClient.loadCollection(loadCollectionParam);
|
||||
if (loadResponse.getStatus() != R.Status.Success.getCode()) {
|
||||
System.err.println("加载集合 " + fullCollectionName + " 到内存时出错:" + loadResponse.getMessage());
|
||||
}
|
||||
// milvusServiceClient.loadCollection(LoadCollectionParam.newBuilder().withCollectionName(fullCollectionName).build());
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void removeByDocId(String kid, String docId) {
|
||||
milvusServiceClient.delete(
|
||||
DeleteParam.newBuilder()
|
||||
.withCollectionName(collectionName + kid)
|
||||
.withExpr("1 == 1")
|
||||
.withPartitionName(docId)
|
||||
.build()
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeByKid(String kid) {
|
||||
milvusServiceClient.dropCollection(
|
||||
DropCollectionParam.newBuilder()
|
||||
.withCollectionName(collectionName + kid)
|
||||
.build()
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> nearest(List<Double> queryVector, String kid) {
|
||||
String fullCollectionName = collectionName + kid;
|
||||
|
||||
HasCollectionParam hasCollectionParam = HasCollectionParam.newBuilder()
|
||||
.withCollectionName(fullCollectionName)
|
||||
.build();
|
||||
|
||||
R<Boolean> booleanR = milvusServiceClient.hasCollection(hasCollectionParam);
|
||||
if (booleanR.getStatus() != R.Status.Success.getCode() || !booleanR.getData().booleanValue()) {
|
||||
System.err.println("集合 " + fullCollectionName + " 不存在或检查集合存在性时出错。");
|
||||
return new ArrayList<>();
|
||||
}
|
||||
|
||||
DescribeIndexParam describeIndexParam = DescribeIndexParam.newBuilder().withCollectionName(fullCollectionName).build();
|
||||
|
||||
R<DescribeIndexResponse> describeIndexResponseR = milvusServiceClient.describeIndex(describeIndexParam);
|
||||
|
||||
if (describeIndexResponseR.getStatus() == R.Status.Success.getCode()) {
|
||||
System.out.println("索引信息: " + describeIndexResponseR.getData().getIndexDescriptionsCount());
|
||||
} else {
|
||||
System.err.println("获取索引失败: " + describeIndexResponseR.getMessage());
|
||||
}
|
||||
|
||||
// // 加载集合到内存
|
||||
// LoadCollectionParam loadCollectionParam = LoadCollectionParam.newBuilder()
|
||||
// .withCollectionName(fullCollectionName)
|
||||
// .build();
|
||||
// R<RpcStatus> loadResponse = milvusServiceClient.loadCollection(loadCollectionParam);
|
||||
// if (loadResponse.getStatus() != R.Status.Success.getCode()) {
|
||||
// System.err.println("加载集合 " + fullCollectionName + " 到内存时出错:" + loadResponse.getMessage());
|
||||
// return new ArrayList<>();
|
||||
// }
|
||||
|
||||
List<String> search_output_fields = Arrays.asList("content", "fv");
|
||||
List<Float> fv = new ArrayList<>();
|
||||
for (int i = 0; i < queryVector.size(); i++) {
|
||||
fv.add(queryVector.get(i).floatValue());
|
||||
}
|
||||
List<List<Float>> vectors = new ArrayList<>();
|
||||
vectors.add(fv);
|
||||
String search_param = "{\"nprobe\":10, \"offset\":0}";
|
||||
SearchParam searchParam = SearchParam.newBuilder()
|
||||
.withCollectionName(collectionName + kid)
|
||||
.withMetricType(MetricType.IP)
|
||||
.withOutFields(search_output_fields)
|
||||
.withTopK(10)
|
||||
.withVectors(vectors)
|
||||
.withVectorFieldName("fv")
|
||||
.withParams(search_param)
|
||||
.build();
|
||||
System.out.println("SearchParam: " + searchParam.toString());
|
||||
R<SearchResults> respSearch = milvusServiceClient.search(searchParam);
|
||||
if (respSearch.getStatus() == R.Status.Success.getCode()) {
|
||||
SearchResults searchResults = respSearch.getData();
|
||||
if (searchResults != null) {
|
||||
System.out.println(searchResults.getResults());
|
||||
SearchResultsWrapper wrapperSearch = new SearchResultsWrapper(searchResults.getResults());
|
||||
List<QueryResultsWrapper.RowRecord> rowRecords = wrapperSearch.getRowRecords();
|
||||
|
||||
List<String> resultList = new ArrayList<>();
|
||||
if (rowRecords != null && !rowRecords.isEmpty()) {
|
||||
for (QueryResultsWrapper.RowRecord rowRecord : rowRecords) {
|
||||
String content = rowRecord.get("content").toString();
|
||||
resultList.add(content);
|
||||
}
|
||||
}
|
||||
return resultList;
|
||||
} else {
|
||||
System.err.println("搜索结果为空");
|
||||
}
|
||||
} else {
|
||||
System.err.println("搜索操作失败: " + respSearch.getMessage());
|
||||
}
|
||||
return new ArrayList<>();
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* milvus 不支持通过文本检索相似性
|
||||
*
|
||||
* @param query
|
||||
* @param kid
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public List<String> nearest(String query, String kid) {
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
package org.ruoyi.chat.service.knowledge.vectorstore;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.domain.vo.KnowledgeInfoVo;
|
||||
import org.ruoyi.mapper.KnowledgeInfoMapper;
|
||||
import org.ruoyi.service.VectorStoreService;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
public class VectorStoreFactory {
|
||||
|
||||
private final WeaviateVectorStore weaviateVectorStore;
|
||||
|
||||
private final MilvusVectorStore milvusVectorStore;
|
||||
|
||||
@Resource
|
||||
private KnowledgeInfoMapper knowledgeInfoMapper;
|
||||
|
||||
public VectorStoreFactory(WeaviateVectorStore weaviateVectorStore, MilvusVectorStore milvusVectorStore) {
|
||||
this.weaviateVectorStore = weaviateVectorStore;
|
||||
this.milvusVectorStore = milvusVectorStore;
|
||||
}
|
||||
|
||||
public VectorStoreService getVectorStore(String kid){
|
||||
String vectorModel = "weaviate";
|
||||
if (StrUtil.isNotEmpty(kid)) {
|
||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoMapper.selectVoById(Long.valueOf(kid));
|
||||
if (knowledgeInfoVo != null && StrUtil.isNotEmpty(knowledgeInfoVo.getVector())) {
|
||||
vectorModel = knowledgeInfoVo.getVector();
|
||||
}
|
||||
}
|
||||
if ("weaviate".equals(vectorModel)){
|
||||
return weaviateVectorStore;
|
||||
}else if ("milvus".equals(vectorModel)){
|
||||
return milvusVectorStore;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,402 @@
|
||||
package org.ruoyi.chat.service.knowledge.vectorstore;
|
||||
|
||||
import cn.hutool.core.lang.UUID;
|
||||
import cn.hutool.json.JSONObject;
|
||||
import com.google.gson.internal.LinkedTreeMap;
|
||||
import io.weaviate.client.Config;
|
||||
import io.weaviate.client.WeaviateClient;
|
||||
import io.weaviate.client.base.Result;
|
||||
import io.weaviate.client.v1.data.model.WeaviateObject;
|
||||
import io.weaviate.client.v1.data.replication.model.ConsistencyLevel;
|
||||
import io.weaviate.client.v1.filters.Operator;
|
||||
import io.weaviate.client.v1.filters.WhereFilter;
|
||||
import io.weaviate.client.v1.graphql.model.GraphQLResponse;
|
||||
import io.weaviate.client.v1.graphql.query.argument.NearTextArgument;
|
||||
import io.weaviate.client.v1.graphql.query.argument.NearVectorArgument;
|
||||
import io.weaviate.client.v1.graphql.query.fields.Field;
|
||||
import io.weaviate.client.v1.misc.model.Meta;
|
||||
import io.weaviate.client.v1.misc.model.ReplicationConfig;
|
||||
import io.weaviate.client.v1.misc.model.ShardingConfig;
|
||||
import io.weaviate.client.v1.misc.model.VectorIndexConfig;
|
||||
import io.weaviate.client.v1.schema.model.DataType;
|
||||
import io.weaviate.client.v1.schema.model.Property;
|
||||
import io.weaviate.client.v1.schema.model.Schema;
|
||||
import io.weaviate.client.v1.schema.model.WeaviateClass;
|
||||
import jakarta.annotation.PostConstruct;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.ruoyi.common.core.service.ConfigService;
|
||||
import org.ruoyi.domain.vo.KnowledgeInfoVo;
|
||||
import org.ruoyi.service.IKnowledgeInfoService;
|
||||
import org.ruoyi.service.VectorStoreService;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class WeaviateVectorStore implements VectorStoreService {
|
||||
|
||||
private volatile String protocol;
|
||||
private volatile String host;
|
||||
private volatile String className;
|
||||
|
||||
@Lazy
|
||||
@Resource
|
||||
private IKnowledgeInfoService knowledgeInfoService;
|
||||
|
||||
@Lazy
|
||||
@Resource
|
||||
private ConfigService configService;
|
||||
|
||||
@PostConstruct
|
||||
public void loadConfig() {
|
||||
this.protocol = configService.getConfigValue("weaviate", "protocol");
|
||||
this.host = configService.getConfigValue("weaviate", "host");
|
||||
this.className = configService.getConfigValue("weaviate", "classname");
|
||||
}
|
||||
|
||||
public WeaviateClient getClient() {
|
||||
Config config = new Config(protocol, host);
|
||||
WeaviateClient client = new WeaviateClient(config);
|
||||
return client;
|
||||
}
|
||||
|
||||
public Result<Meta> getMeta() {
|
||||
WeaviateClient client = getClient();
|
||||
Result<Meta> meta = client.misc().metaGetter().run();
|
||||
if (meta.getError() == null) {
|
||||
System.out.printf("meta.hostname: %s\n", meta.getResult().getHostname());
|
||||
System.out.printf("meta.version: %s\n", meta.getResult().getVersion());
|
||||
System.out.printf("meta.modules: %s\n", meta.getResult().getModules());
|
||||
} else {
|
||||
System.out.printf("Error: %s\n", meta.getError().getMessages());
|
||||
}
|
||||
return meta;
|
||||
}
|
||||
|
||||
public Result<Schema> getSchemas() {
|
||||
WeaviateClient client = getClient();
|
||||
Result<Schema> result = client.schema().getter().run();
|
||||
if (result.hasErrors()) {
|
||||
System.out.println(result.getError());
|
||||
} else {
|
||||
System.out.println(result.getResult());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
public Result<Boolean> createSchema(String kid) {
|
||||
WeaviateClient client = getClient();
|
||||
|
||||
VectorIndexConfig vectorIndexConfig = VectorIndexConfig.builder()
|
||||
.distance("cosine")
|
||||
.cleanupIntervalSeconds(300)
|
||||
.efConstruction(128)
|
||||
.maxConnections(64)
|
||||
.vectorCacheMaxObjects(500000L)
|
||||
.ef(-1)
|
||||
.skip(false)
|
||||
.dynamicEfFactor(8)
|
||||
.dynamicEfMax(500)
|
||||
.dynamicEfMin(100)
|
||||
.flatSearchCutoff(40000)
|
||||
.build();
|
||||
|
||||
ShardingConfig shardingConfig = ShardingConfig.builder()
|
||||
.desiredCount(3)
|
||||
.desiredVirtualCount(128)
|
||||
.function("murmur3")
|
||||
.key("_id")
|
||||
.strategy("hash")
|
||||
.virtualPerPhysical(128)
|
||||
.build();
|
||||
|
||||
ReplicationConfig replicationConfig = ReplicationConfig.builder()
|
||||
.factor(1)
|
||||
.build();
|
||||
|
||||
JSONObject classModuleConfigValue = new JSONObject();
|
||||
classModuleConfigValue.put("vectorizeClassName", false);
|
||||
JSONObject classModuleConfig = new JSONObject();
|
||||
classModuleConfig.put("text2vec-transformers", classModuleConfigValue);
|
||||
|
||||
JSONObject propertyModuleConfigValueSkipTrue = new JSONObject();
|
||||
propertyModuleConfigValueSkipTrue.put("vectorizePropertyName", false);
|
||||
propertyModuleConfigValueSkipTrue.put("skip", true);
|
||||
JSONObject propertyModuleConfigSkipTrue = new JSONObject();
|
||||
propertyModuleConfigSkipTrue.put("text2vec-transformers", propertyModuleConfigValueSkipTrue);
|
||||
|
||||
JSONObject propertyModuleConfigValueSkipFalse = new JSONObject();
|
||||
propertyModuleConfigValueSkipFalse.put("vectorizePropertyName", false);
|
||||
propertyModuleConfigValueSkipFalse.put("skip", false);
|
||||
JSONObject propertyModuleConfigSkipFalse = new JSONObject();
|
||||
propertyModuleConfigSkipFalse.put("text2vec-transformers", propertyModuleConfigValueSkipFalse);
|
||||
|
||||
WeaviateClass clazz = WeaviateClass.builder()
|
||||
.className(className + kid)
|
||||
.description("local knowledge")
|
||||
.vectorIndexType("hnsw")
|
||||
.vectorizer("text2vec-transformers")
|
||||
.shardingConfig(shardingConfig)
|
||||
.vectorIndexConfig(vectorIndexConfig)
|
||||
.replicationConfig(replicationConfig)
|
||||
.moduleConfig(classModuleConfig)
|
||||
.properties(new ArrayList() {
|
||||
{
|
||||
add(Property.builder()
|
||||
.dataType(new ArrayList() {
|
||||
{
|
||||
add(DataType.TEXT);
|
||||
}
|
||||
})
|
||||
.name("content")
|
||||
.description("The content of the local knowledge,for search")
|
||||
.moduleConfig(propertyModuleConfigSkipFalse)
|
||||
.build());
|
||||
add(Property.builder()
|
||||
.dataType(new ArrayList() {
|
||||
{
|
||||
add(DataType.TEXT);
|
||||
}
|
||||
})
|
||||
.name("kid")
|
||||
.description("The knowledge id of the local knowledge,for search")
|
||||
.moduleConfig(propertyModuleConfigSkipTrue)
|
||||
.build());
|
||||
add(Property.builder()
|
||||
.dataType(new ArrayList() {
|
||||
{
|
||||
add(DataType.TEXT);
|
||||
}
|
||||
})
|
||||
.name("docId")
|
||||
.description("The doc id of the local knowledge,for search")
|
||||
.moduleConfig(propertyModuleConfigSkipTrue)
|
||||
.build());
|
||||
add(Property.builder()
|
||||
.dataType(new ArrayList() {
|
||||
{
|
||||
add(DataType.TEXT);
|
||||
}
|
||||
})
|
||||
.name("fid")
|
||||
.description("The fragment id of the local knowledge,for search")
|
||||
.moduleConfig(propertyModuleConfigSkipTrue)
|
||||
.build());
|
||||
add(Property.builder()
|
||||
.dataType(new ArrayList() {
|
||||
{
|
||||
add(DataType.TEXT);
|
||||
}
|
||||
})
|
||||
.name("uuid")
|
||||
.description("The uuid id of the local knowledge fragment(same with id properties),for search")
|
||||
.moduleConfig(propertyModuleConfigSkipTrue)
|
||||
.build());
|
||||
} })
|
||||
.build();
|
||||
|
||||
Result<Boolean> result = client.schema().classCreator().withClass(clazz).run();
|
||||
if (result.hasErrors()) {
|
||||
System.out.println(result.getError());
|
||||
}
|
||||
System.out.println(result.getResult());
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void newSchema(String kid) {
|
||||
createSchema(kid);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeByKidAndFid(String kid, String fid) {
|
||||
List<String> resultList = new ArrayList<>();
|
||||
WeaviateClient client = getClient();
|
||||
Field fieldId = Field.builder().name("uuid").build();
|
||||
WhereFilter where = WhereFilter.builder()
|
||||
.path(new String[]{"fid"})
|
||||
.operator(Operator.Equal)
|
||||
.valueString(fid)
|
||||
.build();
|
||||
Result<GraphQLResponse> result = client.graphQL().get()
|
||||
.withClassName(className + kid)
|
||||
.withFields(fieldId)
|
||||
.withWhere(where)
|
||||
.run();
|
||||
LinkedTreeMap<String, Object> t = (LinkedTreeMap<String, Object>) result.getResult().getData();
|
||||
LinkedTreeMap<String, ArrayList<LinkedTreeMap>> l = (LinkedTreeMap<String, ArrayList<LinkedTreeMap>>) t.get("Get");
|
||||
ArrayList<LinkedTreeMap> m = l.get(className + kid);
|
||||
for (LinkedTreeMap linkedTreeMap : m) {
|
||||
String uuid = linkedTreeMap.get("uuid").toString();
|
||||
resultList.add(uuid);
|
||||
}
|
||||
for (String uuid : resultList) {
|
||||
Result<Boolean> deleteResult = client.data().deleter()
|
||||
.withID(uuid)
|
||||
.withClassName(className + kid)
|
||||
.withConsistencyLevel(ConsistencyLevel.ALL) // default QUORUM
|
||||
.run();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void storeEmbeddings(List<String> chunkList, List<List<Double>> vectorList, String kid, String docId, List<String> fidList) {
|
||||
WeaviateClient client = getClient();
|
||||
|
||||
for (int i = 0; i < Math.min(chunkList.size(), vectorList.size()); i++) {
|
||||
List<Double> vector = vectorList.get(i);
|
||||
Float[] vf = vector.stream().map(Double::floatValue).toArray(Float[]::new);
|
||||
|
||||
Map<String, Object> dataSchema = new HashMap<>();
|
||||
dataSchema.put("content", chunkList.get(i));
|
||||
dataSchema.put("kid", kid);
|
||||
dataSchema.put("docId", docId);
|
||||
dataSchema.put("fid", fidList.get(i));
|
||||
String uuid = UUID.randomUUID().toString();
|
||||
dataSchema.put("uuid", uuid);
|
||||
|
||||
Result<WeaviateObject> result = client.data().creator()
|
||||
.withClassName(className + kid)
|
||||
.withID(uuid)
|
||||
.withVector(vf)
|
||||
.withProperties(dataSchema)
|
||||
.run();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeByDocId(String kid, String docId) {
|
||||
List<String> resultList = new ArrayList<>();
|
||||
WeaviateClient client = getClient();
|
||||
Field fieldId = Field.builder().name("uuid").build();
|
||||
WhereFilter where = WhereFilter.builder()
|
||||
.path(new String[]{"docId"})
|
||||
.operator(Operator.Equal)
|
||||
.valueString(docId)
|
||||
.build();
|
||||
Result<GraphQLResponse> result = client.graphQL().get()
|
||||
.withClassName(className + kid)
|
||||
.withFields(fieldId)
|
||||
.withWhere(where)
|
||||
.run();
|
||||
LinkedTreeMap<String, Object> t = (LinkedTreeMap<String, Object>) result.getResult().getData();
|
||||
LinkedTreeMap<String, ArrayList<LinkedTreeMap>> l = (LinkedTreeMap<String, ArrayList<LinkedTreeMap>>) t.get("Get");
|
||||
ArrayList<LinkedTreeMap> m = l.get(className + kid);
|
||||
for (LinkedTreeMap linkedTreeMap : m) {
|
||||
String uuid = linkedTreeMap.get("uuid").toString();
|
||||
resultList.add(uuid);
|
||||
}
|
||||
for (String uuid : resultList) {
|
||||
Result<Boolean> deleteResult = client.data().deleter()
|
||||
.withID(uuid)
|
||||
.withClassName(className + kid)
|
||||
.withConsistencyLevel(ConsistencyLevel.ALL) // default QUORUM
|
||||
.run();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeByKid(String kid) {
|
||||
WeaviateClient client = getClient();
|
||||
Result<Boolean> result = client.schema().classDeleter().withClassName(className + kid).run();
|
||||
if (result.hasErrors()) {
|
||||
System.out.println("删除schema失败" + result.getError());
|
||||
} else {
|
||||
System.out.println("删除schema成功" + result.getResult());
|
||||
}
|
||||
log.info("drop schema by kid, result = {}", result);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> nearest(List<Double> queryVector, String kid) {
|
||||
if (StringUtils.isBlank(kid)) {
|
||||
return new ArrayList<String>();
|
||||
}
|
||||
List<String> resultList = new ArrayList<>();
|
||||
Float[] vf = new Float[queryVector.size()];
|
||||
for (int j = 0; j < queryVector.size(); j++) {
|
||||
Double value = queryVector.get(j);
|
||||
vf[j] = value.floatValue();
|
||||
}
|
||||
WeaviateClient client = getClient();
|
||||
Field contentField = Field.builder().name("content").build();
|
||||
Field _additional = Field.builder()
|
||||
.name("_additional")
|
||||
.fields(new Field[]{
|
||||
Field.builder().name("distance").build()
|
||||
}).build();
|
||||
NearVectorArgument nearVector = NearVectorArgument.builder()
|
||||
.vector(vf)
|
||||
.distance(1.6f) // certainty = 1f - distance /2f
|
||||
.build();
|
||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
|
||||
Result<GraphQLResponse> result = client.graphQL().get()
|
||||
.withClassName(className + kid)
|
||||
.withFields(contentField, _additional)
|
||||
.withNearVector(nearVector)
|
||||
.withLimit(knowledgeInfoVo.getRetrieveLimit())
|
||||
.run();
|
||||
LinkedTreeMap<String, Object> t = (LinkedTreeMap<String, Object>) result.getResult().getData();
|
||||
LinkedTreeMap<String, ArrayList<LinkedTreeMap>> l = (LinkedTreeMap<String, ArrayList<LinkedTreeMap>>) t.get("Get");
|
||||
ArrayList<LinkedTreeMap> m = l.get(className + kid);
|
||||
for (LinkedTreeMap linkedTreeMap : m) {
|
||||
String content = linkedTreeMap.get("content").toString();
|
||||
resultList.add(content);
|
||||
}
|
||||
return resultList;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> nearest(String query, String kid) {
|
||||
if (StringUtils.isBlank(kid)) {
|
||||
return new ArrayList<String>();
|
||||
}
|
||||
List<String> resultList = new ArrayList<>();
|
||||
WeaviateClient client = getClient();
|
||||
Field contentField = Field.builder().name("content").build();
|
||||
Field _additional = Field.builder()
|
||||
.name("_additional")
|
||||
.fields(new Field[]{
|
||||
Field.builder().name("distance").build()
|
||||
}).build();
|
||||
NearTextArgument nearText = client.graphQL().arguments().nearTextArgBuilder()
|
||||
.concepts(new String[]{query})
|
||||
.distance(1.6f) // certainty = 1f - distance /2f
|
||||
.build();
|
||||
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
|
||||
Result<GraphQLResponse> result = client.graphQL().get()
|
||||
.withClassName(className + kid)
|
||||
.withFields(contentField, _additional)
|
||||
.withNearText(nearText)
|
||||
.withLimit(knowledgeInfoVo.getRetrieveLimit())
|
||||
.run();
|
||||
LinkedTreeMap<String, Object> t = (LinkedTreeMap<String, Object>) result.getResult().getData();
|
||||
LinkedTreeMap<String, ArrayList<LinkedTreeMap>> l = (LinkedTreeMap<String, ArrayList<LinkedTreeMap>>) t.get("Get");
|
||||
ArrayList<LinkedTreeMap> m = l.get(className + kid);
|
||||
for (LinkedTreeMap linkedTreeMap : m) {
|
||||
String content = linkedTreeMap.get("content").toString();
|
||||
resultList.add(content);
|
||||
}
|
||||
return resultList;
|
||||
}
|
||||
|
||||
public Result<Boolean> deleteSchema(String kid) {
|
||||
WeaviateClient client = getClient();
|
||||
Result<Boolean> result = client.schema().classDeleter().withClassName(className + kid).run();
|
||||
if (result.hasErrors()) {
|
||||
System.out.println(result.getError());
|
||||
} else {
|
||||
System.out.println(result.getResult());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user