mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-04-11 10:37:20 +00:00
feat: 支持插件功能
This commit is contained in:
@@ -73,11 +73,9 @@ public class KnowledgeController extends BaseController {
|
|||||||
*/
|
*/
|
||||||
@PostMapping("/send")
|
@PostMapping("/send")
|
||||||
public SseEmitter send(@RequestBody @Valid ChatRequest chatRequest) {
|
public SseEmitter send(@RequestBody @Valid ChatRequest chatRequest) {
|
||||||
|
|
||||||
openAiStreamClient = chatConfig.getOpenAiStreamClient();
|
openAiStreamClient = chatConfig.getOpenAiStreamClient();
|
||||||
SseEmitter sseEmitter = new SseEmitter(0L);
|
SseEmitter sseEmitter = new SseEmitter(0L);
|
||||||
SSEEventSourceListener openAIEventSourceListener = new SSEEventSourceListener(sseEmitter);
|
SSEEventSourceListener openAIEventSourceListener = new SSEEventSourceListener(sseEmitter);
|
||||||
|
|
||||||
List<Message> messages = chatRequest.getMessages();
|
List<Message> messages = chatRequest.getMessages();
|
||||||
String content = messages.get(messages.size() - 1).getContent().toString();
|
String content = messages.get(messages.size() - 1).getContent().toString();
|
||||||
List<String> nearestList;
|
List<String> nearestList;
|
||||||
@@ -89,8 +87,6 @@ public class KnowledgeController extends BaseController {
|
|||||||
}
|
}
|
||||||
Message userMessage = Message.builder().content(content + (nearestList.size() > 0 ? "\n\n注意:回答问题时,须严格根据我给你的系统上下文内容原文进行回答,请不要自己发挥,回答时保持原来文本的段落层级" : "") ).role(Message.Role.USER).build();
|
Message userMessage = Message.builder().content(content + (nearestList.size() > 0 ? "\n\n注意:回答问题时,须严格根据我给你的系统上下文内容原文进行回答,请不要自己发挥,回答时保持原来文本的段落层级" : "") ).role(Message.Role.USER).build();
|
||||||
messages.add(userMessage);
|
messages.add(userMessage);
|
||||||
|
|
||||||
|
|
||||||
ChatCompletion completion = ChatCompletion
|
ChatCompletion completion = ChatCompletion
|
||||||
.builder()
|
.builder()
|
||||||
.messages(messages)
|
.messages(messages)
|
||||||
@@ -104,7 +100,6 @@ public class KnowledgeController extends BaseController {
|
|||||||
return sseEmitter;
|
return sseEmitter;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 根据用户信息查询本地知识库
|
* 根据用户信息查询本地知识库
|
||||||
*/
|
*/
|
||||||
@@ -117,8 +112,6 @@ public class KnowledgeController extends BaseController {
|
|||||||
return knowledgeInfoService.queryPageList(bo, pageQuery);
|
return knowledgeInfoService.queryPageList(bo, pageQuery);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 新增知识库
|
* 新增知识库
|
||||||
*/
|
*/
|
||||||
@@ -190,10 +183,9 @@ public class KnowledgeController extends BaseController {
|
|||||||
* 删除知识库附件
|
* 删除知识库附件
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
@PostMapping("attach/remove/{kid}")
|
@PostMapping("attach/remove/{docId}")
|
||||||
public R<Void> removeAttach(@NotEmpty(message = "主键不能为空")
|
public R<Void> removeAttach(@NotEmpty(message = "主键不能为空") @PathVariable String docId) {
|
||||||
@PathVariable String kid) {
|
attachService.removeKnowledgeAttach(docId);
|
||||||
attachService.removeKnowledgeAttach(kid);
|
|
||||||
return R.ok();
|
return R.ok();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,12 @@
|
|||||||
<artifactId>ruoyi-common-core</artifactId>
|
<artifactId>ruoyi-common-core</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>mysql</groupId>
|
||||||
|
<artifactId>mysql-connector-java</artifactId>
|
||||||
|
<version>8.0.33</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.azure</groupId>
|
<groupId>com.azure</groupId>
|
||||||
<artifactId>azure-ai-openai</artifactId>
|
<artifactId>azure-ai-openai</artifactId>
|
||||||
@@ -92,5 +98,25 @@
|
|||||||
</exclusion>
|
</exclusion>
|
||||||
</exclusions>
|
</exclusions>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>junit</groupId>
|
||||||
|
<artifactId>junit</artifactId>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>junit</groupId>
|
||||||
|
<artifactId>junit</artifactId>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>cn.bigmodel.openapi</groupId>
|
||||||
|
<artifactId>oapi-java-sdk</artifactId>
|
||||||
|
<version>release-V4-2.3.0</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.squareup.okhttp</groupId>
|
||||||
|
<artifactId>okhttp</artifactId>
|
||||||
|
<version>2.7.5</version>
|
||||||
|
<scope>compile</scope>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
</project>
|
</project>
|
||||||
|
|||||||
@@ -0,0 +1,73 @@
|
|||||||
|
package org.ruoyi.common.chat.demo;
|
||||||
|
|
||||||
|
import cn.hutool.json.JSONUtil;
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.SneakyThrows;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import okhttp3.Response;
|
||||||
|
import okhttp3.ResponseBody;
|
||||||
|
import okhttp3.sse.EventSource;
|
||||||
|
import okhttp3.sse.EventSourceListener;
|
||||||
|
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
|
||||||
|
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.concurrent.CountDownLatch;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 描述: sse
|
||||||
|
*
|
||||||
|
* @author https:www.unfbx.com
|
||||||
|
* 2023-06-15
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class ConsoleEventSourceListenerV2 extends EventSourceListener {
|
||||||
|
@Getter
|
||||||
|
String args = "";
|
||||||
|
final CountDownLatch countDownLatch;
|
||||||
|
|
||||||
|
public ConsoleEventSourceListenerV2(CountDownLatch countDownLatch) {
|
||||||
|
this.countDownLatch = countDownLatch;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onOpen(EventSource eventSource, Response response) {
|
||||||
|
log.info("OpenAI建立sse连接...");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onEvent(EventSource eventSource, String id, String type, String data) {
|
||||||
|
log.info("OpenAI返回数据:{}", data);
|
||||||
|
if (data.equals("[DONE]")) {
|
||||||
|
log.info("OpenAI返回数据结束了");
|
||||||
|
countDownLatch.countDown();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
ChatCompletionResponse chatCompletionResponse = JSONUtil.toBean(data, ChatCompletionResponse.class);
|
||||||
|
if(Objects.nonNull(chatCompletionResponse.getChoices().get(0).getDelta().getFunctionCall())){
|
||||||
|
args += chatCompletionResponse.getChoices().get(0).getDelta().getFunctionCall().getArguments();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onClosed(EventSource eventSource) {
|
||||||
|
log.info("OpenAI关闭sse连接...");
|
||||||
|
}
|
||||||
|
|
||||||
|
@SneakyThrows
|
||||||
|
@Override
|
||||||
|
public void onFailure(EventSource eventSource, Throwable t, Response response) {
|
||||||
|
if(Objects.isNull(response)){
|
||||||
|
log.error("OpenAI sse连接异常:{}", t);
|
||||||
|
eventSource.cancel();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
ResponseBody body = response.body();
|
||||||
|
if (Objects.nonNull(body)) {
|
||||||
|
log.error("OpenAI sse连接异常data:{},异常:{}", body.string(), t);
|
||||||
|
} else {
|
||||||
|
log.error("OpenAI sse连接异常data:{},异常:{}", response, t);
|
||||||
|
}
|
||||||
|
eventSource.cancel();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,92 @@
|
|||||||
|
package org.ruoyi.common.chat.demo;
|
||||||
|
|
||||||
|
import cn.hutool.core.collection.CollectionUtil;
|
||||||
|
import cn.hutool.json.JSONUtil;
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.SneakyThrows;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import okhttp3.Response;
|
||||||
|
import okhttp3.ResponseBody;
|
||||||
|
import okhttp3.sse.EventSource;
|
||||||
|
import okhttp3.sse.EventSourceListener;
|
||||||
|
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
|
||||||
|
import org.ruoyi.common.chat.entity.chat.Message;
|
||||||
|
import org.ruoyi.common.chat.entity.chat.tool.ToolCallFunction;
|
||||||
|
import org.ruoyi.common.chat.entity.chat.tool.ToolCalls;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.concurrent.CountDownLatch;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 描述: demo测试实现类,仅供思路参考
|
||||||
|
*
|
||||||
|
* @author https:www.unfbx.com
|
||||||
|
* 2023-11-12
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class ConsoleEventSourceListenerV3 extends EventSourceListener {
|
||||||
|
@Getter
|
||||||
|
List<ToolCalls> choices = new ArrayList<>();
|
||||||
|
@Getter
|
||||||
|
ToolCalls toolCalls = new ToolCalls();
|
||||||
|
@Getter
|
||||||
|
ToolCallFunction toolCallFunction = ToolCallFunction.builder().name("").arguments("").build();
|
||||||
|
final CountDownLatch countDownLatch;
|
||||||
|
|
||||||
|
public ConsoleEventSourceListenerV3(CountDownLatch countDownLatch) {
|
||||||
|
this.countDownLatch = countDownLatch;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onOpen(EventSource eventSource, Response response) {
|
||||||
|
log.info("OpenAI建立sse连接...");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onEvent(EventSource eventSource, String id, String type, String data) {
|
||||||
|
log.info("OpenAI返回数据:{}", data);
|
||||||
|
if (data.equals("[DONE]")) {
|
||||||
|
log.info("OpenAI返回数据结束了");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
ChatCompletionResponse chatCompletionResponse = JSONUtil.toBean(data, ChatCompletionResponse.class);
|
||||||
|
Message delta = chatCompletionResponse.getChoices().get(0).getDelta();
|
||||||
|
if (CollectionUtil.isNotEmpty(delta.getToolCalls())) {
|
||||||
|
choices.addAll(delta.getToolCalls());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onClosed(EventSource eventSource) {
|
||||||
|
if(CollectionUtil.isNotEmpty(choices)){
|
||||||
|
toolCalls.setId(choices.get(0).getId());
|
||||||
|
toolCalls.setType(choices.get(0).getType());
|
||||||
|
choices.forEach(e -> {
|
||||||
|
toolCallFunction.setName(e.getFunction().getName());
|
||||||
|
toolCallFunction.setArguments(toolCallFunction.getArguments() + e.getFunction().getArguments());
|
||||||
|
toolCalls.setFunction(toolCallFunction);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
log.info("OpenAI关闭sse连接...");
|
||||||
|
countDownLatch.countDown();
|
||||||
|
}
|
||||||
|
|
||||||
|
@SneakyThrows
|
||||||
|
@Override
|
||||||
|
public void onFailure(EventSource eventSource, Throwable t, Response response) {
|
||||||
|
if(Objects.isNull(response)){
|
||||||
|
log.error("OpenAI sse连接异常:{}", t);
|
||||||
|
eventSource.cancel();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
ResponseBody body = response.body();
|
||||||
|
if (Objects.nonNull(body)) {
|
||||||
|
log.error("OpenAI sse连接异常data:{},异常:{}", body.string(), t);
|
||||||
|
} else {
|
||||||
|
log.error("OpenAI sse连接异常data:{},异常:{}", response, t);
|
||||||
|
}
|
||||||
|
eventSource.cancel();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,417 @@
|
|||||||
|
package org.ruoyi.common.chat.demo;
|
||||||
|
|
||||||
|
import cn.hutool.json.JSONUtil;
|
||||||
|
import com.alibaba.fastjson.JSONObject;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import okhttp3.OkHttpClient;
|
||||||
|
import okhttp3.logging.HttpLoggingInterceptor;
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.ruoyi.common.chat.entity.chat.*;
|
||||||
|
import org.ruoyi.common.chat.entity.chat.tool.ToolCallFunction;
|
||||||
|
import org.ruoyi.common.chat.entity.chat.tool.ToolCalls;
|
||||||
|
import org.ruoyi.common.chat.entity.chat.tool.Tools;
|
||||||
|
import org.ruoyi.common.chat.entity.chat.tool.ToolsFunction;
|
||||||
|
import org.ruoyi.common.chat.openai.OpenAiClient;
|
||||||
|
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
|
||||||
|
import org.ruoyi.common.chat.openai.function.KeyRandomStrategy;
|
||||||
|
import org.ruoyi.common.chat.openai.interceptor.DynamicKeyOpenAiAuthInterceptor;
|
||||||
|
import org.ruoyi.common.chat.openai.interceptor.OpenAILogger;
|
||||||
|
import org.ruoyi.common.chat.openai.interceptor.OpenAiResponseInterceptor;
|
||||||
|
import org.ruoyi.common.chat.openai.plugin.PluginAbstract;
|
||||||
|
import org.ruoyi.common.chat.plugin.CmdPlugin;
|
||||||
|
import org.ruoyi.common.chat.plugin.CmdReq;
|
||||||
|
import org.ruoyi.common.chat.sse.ConsoleEventSourceListener;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.CountDownLatch;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 描述:
|
||||||
|
*
|
||||||
|
* @author ageerle@163.com
|
||||||
|
* date 2025/3/8
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class PluginTest {
|
||||||
|
|
||||||
|
private OpenAiClient openAiClient;
|
||||||
|
private OpenAiStreamClient openAiStreamClient;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void before() {
|
||||||
|
//可以为null
|
||||||
|
// Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890));
|
||||||
|
HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor(new OpenAILogger());
|
||||||
|
//!!!!千万别再生产或者测试环境打开BODY级别日志!!!!
|
||||||
|
//!!!生产或者测试环境建议设置为这三种级别:NONE,BASIC,HEADERS,!!!
|
||||||
|
httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS);
|
||||||
|
OkHttpClient okHttpClient = new OkHttpClient
|
||||||
|
.Builder()
|
||||||
|
// .proxy(proxy)
|
||||||
|
.addInterceptor(httpLoggingInterceptor)
|
||||||
|
.addInterceptor(new OpenAiResponseInterceptor())
|
||||||
|
.connectTimeout(10, TimeUnit.SECONDS)
|
||||||
|
.writeTimeout(30, TimeUnit.SECONDS)
|
||||||
|
.readTimeout(30, TimeUnit.SECONDS)
|
||||||
|
.build();
|
||||||
|
openAiClient = OpenAiClient.builder()
|
||||||
|
//支持多key传入,请求时候随机选择
|
||||||
|
.apiKey(Arrays.asList("sk-xx"))
|
||||||
|
//自定义key的获取策略:默认KeyRandomStrategy
|
||||||
|
//.keyStrategy(new KeyRandomStrategy())
|
||||||
|
.keyStrategy(new KeyRandomStrategy())
|
||||||
|
.okHttpClient(okHttpClient)
|
||||||
|
//自己做了代理就传代理地址,没有可不不传,(关注公众号回复:openai ,获取免费的测试代理地址)
|
||||||
|
.apiHost("https://api.pandarobot.chat/")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
openAiStreamClient = OpenAiStreamClient.builder()
|
||||||
|
//支持多key传入,请求时候随机选择
|
||||||
|
.apiKey(Arrays.asList("sk-xx"))
|
||||||
|
//自定义key的获取策略:默认KeyRandomStrategy
|
||||||
|
.keyStrategy(new KeyRandomStrategy())
|
||||||
|
.authInterceptor(new DynamicKeyOpenAiAuthInterceptor())
|
||||||
|
.okHttpClient(okHttpClient)
|
||||||
|
//自己做了代理就传代理地址,没有可不不传,(关注公众号回复:openai ,获取免费的测试代理地址)
|
||||||
|
.apiHost("https://api.pandarobot.chat/")
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void chatFunction() {
|
||||||
|
//模型:GPT_3_5_TURBO_16K_0613
|
||||||
|
Message message = Message.builder().role(Message.Role.USER).content("给我输出一个长度为2的中文词语,并解释下词语对应物品的用途").build();
|
||||||
|
//属性一
|
||||||
|
JSONObject wordLength = new JSONObject();
|
||||||
|
wordLength.put("type", "number");
|
||||||
|
wordLength.put("description", "词语的长度");
|
||||||
|
//属性二
|
||||||
|
JSONObject language = new JSONObject();
|
||||||
|
language.put("type", "string");
|
||||||
|
language.put("enum", Arrays.asList("zh", "en"));
|
||||||
|
language.put("description", "语言类型,例如:zh代表中文、en代表英语");
|
||||||
|
//参数
|
||||||
|
JSONObject properties = new JSONObject();
|
||||||
|
properties.put("wordLength", wordLength);
|
||||||
|
properties.put("language", language);
|
||||||
|
|
||||||
|
Parameters parameters = Parameters.builder()
|
||||||
|
.type("object")
|
||||||
|
.properties(properties)
|
||||||
|
.required(Collections.singletonList("wordLength")).build();
|
||||||
|
Functions functions = Functions.builder()
|
||||||
|
.name("getOneWord")
|
||||||
|
.description("获取一个指定长度和语言类型的词语")
|
||||||
|
.parameters(parameters)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
ChatCompletion chatCompletion = ChatCompletion
|
||||||
|
.builder()
|
||||||
|
.messages(Collections.singletonList(message))
|
||||||
|
.functions(Collections.singletonList(functions))
|
||||||
|
.functionCall("auto")
|
||||||
|
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
|
||||||
|
.build();
|
||||||
|
ChatCompletionResponse chatCompletionResponse = openAiClient.chatCompletion(chatCompletion);
|
||||||
|
|
||||||
|
ChatChoice chatChoice = chatCompletionResponse.getChoices().get(0);
|
||||||
|
log.info("构造的方法值:{}", chatChoice.getMessage().getFunctionCall());
|
||||||
|
log.info("构造的方法名称:{}", chatChoice.getMessage().getFunctionCall().getName());
|
||||||
|
log.info("构造的方法参数:{}", chatChoice.getMessage().getFunctionCall().getArguments());
|
||||||
|
WordParam wordParam = JSONUtil.toBean(chatChoice.getMessage().getFunctionCall().getArguments(), WordParam.class);
|
||||||
|
String oneWord = getOneWord(wordParam);
|
||||||
|
|
||||||
|
FunctionCall functionCall = FunctionCall.builder()
|
||||||
|
.arguments(chatChoice.getMessage().getFunctionCall().getArguments())
|
||||||
|
.name("getOneWord")
|
||||||
|
.build();
|
||||||
|
Message message2 = Message.builder().role(Message.Role.ASSISTANT).content("方法参数").functionCall(functionCall).build();
|
||||||
|
String content
|
||||||
|
= "{ " +
|
||||||
|
"\"wordLength\": \"3\", " +
|
||||||
|
"\"language\": \"zh\", " +
|
||||||
|
"\"word\": \"" + oneWord + "\"," +
|
||||||
|
"\"用途\": [\"直接吃\", \"做沙拉\", \"售卖\"]" +
|
||||||
|
"}";
|
||||||
|
Message message3 = Message.builder().role(Message.Role.FUNCTION).name("getOneWord").content(content).build();
|
||||||
|
List<Message> messageList = Arrays.asList(message, message2, message3);
|
||||||
|
ChatCompletion chatCompletionV2 = ChatCompletion
|
||||||
|
.builder()
|
||||||
|
.messages(messageList)
|
||||||
|
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
|
||||||
|
.build();
|
||||||
|
ChatCompletionResponse chatCompletionResponseV2 = openAiClient.chatCompletion(chatCompletionV2);
|
||||||
|
log.info("自定义的方法返回值:{}",chatCompletionResponseV2.getChoices().get(0).getMessage().getContent());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void plugin() {
|
||||||
|
CmdPlugin plugin = new CmdPlugin(CmdReq.class);
|
||||||
|
// 插件名称
|
||||||
|
plugin.setName("命令行工具");
|
||||||
|
// 方法名称
|
||||||
|
plugin.setFunction("openCmd");
|
||||||
|
// 方法说明
|
||||||
|
plugin.setDescription("提供一个命令行指令,比如<记事本>,指令使用中文,以function返回结果为准");
|
||||||
|
|
||||||
|
PluginAbstract.Arg arg = new PluginAbstract.Arg();
|
||||||
|
// 参数名称
|
||||||
|
arg.setName("cmd");
|
||||||
|
// 参数说明
|
||||||
|
arg.setDescription("命令行指令");
|
||||||
|
// 参数类型
|
||||||
|
arg.setType("string");
|
||||||
|
arg.setRequired(true);
|
||||||
|
plugin.setArgs(Collections.singletonList(arg));
|
||||||
|
|
||||||
|
Message message2 = Message.builder().role(Message.Role.USER).content("帮我打开计算器,结合上下文判断指令是否执行成功,只用回复成功或者失败").build();
|
||||||
|
List<Message> messages = new ArrayList<>();
|
||||||
|
messages.add(message2);
|
||||||
|
//有四个重载方法,都可以使用
|
||||||
|
ChatCompletionResponse response = openAiClient.chatCompletionWithPlugin(messages,"gpt-4o-mini",plugin);
|
||||||
|
log.info("自定义的方法返回值:{}", response.getChoices().get(0).getMessage().getContent());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 自定义返回数据格式
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void diyReturnDataModelChat() {
|
||||||
|
Message message = Message.builder().role(Message.Role.USER).content("随机输出10个单词,使用json输出").build();
|
||||||
|
ChatCompletion chatCompletion = ChatCompletion
|
||||||
|
.builder()
|
||||||
|
.messages(Collections.singletonList(message))
|
||||||
|
.responseFormat(ResponseFormat.builder().type(ResponseFormat.Type.JSON_OBJECT.getName()).build())
|
||||||
|
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
|
||||||
|
.build();
|
||||||
|
ChatCompletionResponse chatCompletionResponse = openAiClient.chatCompletion(chatCompletion);
|
||||||
|
chatCompletionResponse.getChoices().forEach(e -> System.out.println(e.getMessage()));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void streamPlugin() {
|
||||||
|
WeatherPlugin plugin = new WeatherPlugin(WeatherReq.class);
|
||||||
|
plugin.setName("知心天气");
|
||||||
|
plugin.setFunction("getLocationWeather");
|
||||||
|
plugin.setDescription("提供一个地址,方法将会获取该地址的天气的实时温度信息。");
|
||||||
|
PluginAbstract.Arg arg = new PluginAbstract.Arg();
|
||||||
|
arg.setName("location");
|
||||||
|
arg.setDescription("地名");
|
||||||
|
arg.setType("string");
|
||||||
|
arg.setRequired(true);
|
||||||
|
plugin.setArgs(Collections.singletonList(arg));
|
||||||
|
|
||||||
|
// Message message1 = Message.builder().role(Message.Role.USER).content("秦始皇统一了哪六国。").build();
|
||||||
|
Message message2 = Message.builder().role(Message.Role.USER).content("获取上海市的天气现在多少度,然后再给出3个推荐的户外运动。").build();
|
||||||
|
List<Message> messages = new ArrayList<>();
|
||||||
|
// messages.add(message1);
|
||||||
|
messages.add(message2);
|
||||||
|
//默认模型:GPT_3_5_TURBO_16K_0613
|
||||||
|
//有四个重载方法,都可以使用
|
||||||
|
openAiStreamClient.streamChatCompletionWithPlugin(messages, ChatCompletion.Model.GPT_4_1106_PREVIEW.getName(), new ConsoleEventSourceListener(), plugin);
|
||||||
|
CountDownLatch countDownLatch = new CountDownLatch(1);
|
||||||
|
try {
|
||||||
|
countDownLatch.await();
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* tools使用示例
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void toolsChat() {
|
||||||
|
Message message = Message.builder().role(Message.Role.USER).content("给我输出一个长度为2的中文词语,并解释下词语对应物品的用途").build();
|
||||||
|
//属性一
|
||||||
|
JSONObject wordLength = new JSONObject();
|
||||||
|
wordLength.put("type", "number");
|
||||||
|
wordLength.put("description", "词语的长度");
|
||||||
|
//属性二
|
||||||
|
JSONObject language = new JSONObject();
|
||||||
|
language.put("type", "string");
|
||||||
|
language.put("enum", Arrays.asList("zh", "en"));
|
||||||
|
language.put("description", "语言类型,例如:zh代表中文、en代表英语");
|
||||||
|
//参数
|
||||||
|
JSONObject properties = new JSONObject();
|
||||||
|
properties.put("wordLength", wordLength);
|
||||||
|
properties.put("language", language);
|
||||||
|
Parameters parameters = Parameters.builder()
|
||||||
|
.type("object")
|
||||||
|
.properties(properties)
|
||||||
|
.required(Collections.singletonList("wordLength")).build();
|
||||||
|
Tools tools = Tools.builder()
|
||||||
|
.type(Tools.Type.FUNCTION.getName())
|
||||||
|
.function(ToolsFunction.builder().name("getOneWord").description("获取一个指定长度和语言类型的词语").parameters(parameters).build())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
ChatCompletion chatCompletion = ChatCompletion
|
||||||
|
.builder()
|
||||||
|
.messages(Collections.singletonList(message))
|
||||||
|
.tools(Collections.singletonList(tools))
|
||||||
|
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
|
||||||
|
.build();
|
||||||
|
ChatCompletionResponse chatCompletionResponse = openAiClient.chatCompletion(chatCompletion);
|
||||||
|
|
||||||
|
ChatChoice chatChoice = chatCompletionResponse.getChoices().get(0);
|
||||||
|
log.info("构造的方法值:{}", chatChoice.getMessage().getToolCalls());
|
||||||
|
|
||||||
|
ToolCalls openAiReturnToolCalls = chatChoice.getMessage().getToolCalls().get(0);
|
||||||
|
WordParam wordParam = JSONUtil.toBean(openAiReturnToolCalls.getFunction().getArguments(), WordParam.class);
|
||||||
|
String oneWord = getOneWord(wordParam);
|
||||||
|
|
||||||
|
|
||||||
|
ToolCallFunction tcf = ToolCallFunction.builder().name("getOneWord").arguments(openAiReturnToolCalls.getFunction().getArguments()).build();
|
||||||
|
ToolCalls tc = ToolCalls.builder().id(openAiReturnToolCalls.getId()).type(ToolCalls.Type.FUNCTION.getName()).function(tcf).build();
|
||||||
|
//构造tool call
|
||||||
|
Message message2 = Message.builder().role(Message.Role.ASSISTANT).content("方法参数").toolCalls(Collections.singletonList(tc)).build();
|
||||||
|
String content
|
||||||
|
= "{ " +
|
||||||
|
"\"wordLength\": \"3\", " +
|
||||||
|
"\"language\": \"zh\", " +
|
||||||
|
"\"word\": \"" + oneWord + "\"," +
|
||||||
|
"\"用途\": [\"直接吃\", \"做沙拉\", \"售卖\"]" +
|
||||||
|
"}";
|
||||||
|
Message message3 = Message.builder().toolCallId(openAiReturnToolCalls.getId()).role(Message.Role.TOOL).name("getOneWord").content(content).build();
|
||||||
|
List<Message> messageList = Arrays.asList(message, message2, message3);
|
||||||
|
ChatCompletion chatCompletionV2 = ChatCompletion
|
||||||
|
.builder()
|
||||||
|
.messages(messageList)
|
||||||
|
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
|
||||||
|
.build();
|
||||||
|
ChatCompletionResponse chatCompletionResponseV2 = openAiClient.chatCompletion(chatCompletionV2);
|
||||||
|
log.info("自定义的方法返回值:{}", chatCompletionResponseV2.getChoices().get(0).getMessage().getContent());
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* tools流式输出使用示例
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void streamToolsChat() {
|
||||||
|
|
||||||
|
CountDownLatch countDownLatch = new CountDownLatch(1);
|
||||||
|
ConsoleEventSourceListenerV3 eventSourceListener = new ConsoleEventSourceListenerV3(countDownLatch);
|
||||||
|
|
||||||
|
Message message = Message.builder().role(Message.Role.USER).content("给我输出一个长度为2的中文词语,并解释下词语对应物品的用途").build();
|
||||||
|
//属性一
|
||||||
|
JSONObject wordLength = new JSONObject();
|
||||||
|
wordLength.put("type", "number");
|
||||||
|
wordLength.put("description", "词语的长度");
|
||||||
|
//属性二
|
||||||
|
JSONObject language = new JSONObject();
|
||||||
|
language.put("type", "string");
|
||||||
|
language.put("enum", Arrays.asList("zh", "en"));
|
||||||
|
language.put("description", "语言类型,例如:zh代表中文、en代表英语");
|
||||||
|
//参数
|
||||||
|
JSONObject properties = new JSONObject();
|
||||||
|
properties.put("wordLength", wordLength);
|
||||||
|
properties.put("language", language);
|
||||||
|
Parameters parameters = Parameters.builder()
|
||||||
|
.type("object")
|
||||||
|
.properties(properties)
|
||||||
|
.required(Collections.singletonList("wordLength")).build();
|
||||||
|
Tools tools = Tools.builder()
|
||||||
|
.type(Tools.Type.FUNCTION.getName())
|
||||||
|
.function(ToolsFunction.builder().name("getOneWord").description("获取一个指定长度和语言类型的词语").parameters(parameters).build())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
ChatCompletion chatCompletion = ChatCompletion
|
||||||
|
.builder()
|
||||||
|
.messages(Collections.singletonList(message))
|
||||||
|
.tools(Collections.singletonList(tools))
|
||||||
|
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
|
||||||
|
.build();
|
||||||
|
openAiStreamClient.streamChatCompletion(chatCompletion, eventSourceListener);
|
||||||
|
|
||||||
|
try {
|
||||||
|
countDownLatch.await();
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
|
||||||
|
ToolCalls openAiReturnToolCalls = eventSourceListener.getToolCalls();
|
||||||
|
WordParam wordParam = JSONUtil.toBean(openAiReturnToolCalls.getFunction().getArguments(), WordParam.class);
|
||||||
|
String oneWord = getOneWord(wordParam);
|
||||||
|
|
||||||
|
|
||||||
|
ToolCallFunction tcf = ToolCallFunction.builder().name("getOneWord").arguments(openAiReturnToolCalls.getFunction().getArguments()).build();
|
||||||
|
ToolCalls tc = ToolCalls.builder().id(openAiReturnToolCalls.getId()).type(ToolCalls.Type.FUNCTION.getName()).function(tcf).build();
|
||||||
|
//构造tool call
|
||||||
|
Message message2 = Message.builder().role(Message.Role.ASSISTANT).content("方法参数").toolCalls(Collections.singletonList(tc)).build();
|
||||||
|
String content
|
||||||
|
= "{ " +
|
||||||
|
"\"wordLength\": \"3\", " +
|
||||||
|
"\"language\": \"zh\", " +
|
||||||
|
"\"word\": \"" + oneWord + "\"," +
|
||||||
|
"\"用途\": [\"直接吃\", \"做沙拉\", \"售卖\"]" +
|
||||||
|
"}";
|
||||||
|
Message message3 = Message.builder().toolCallId(openAiReturnToolCalls.getId()).role(Message.Role.TOOL).name("getOneWord").content(content).build();
|
||||||
|
List<Message> messageList = Arrays.asList(message, message2, message3);
|
||||||
|
ChatCompletion chatCompletionV2 = ChatCompletion
|
||||||
|
.builder()
|
||||||
|
.messages(messageList)
|
||||||
|
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
|
||||||
|
CountDownLatch countDownLatch1 = new CountDownLatch(1);
|
||||||
|
openAiStreamClient.streamChatCompletion(chatCompletionV2, new ConsoleEventSourceListenerV3(countDownLatch));
|
||||||
|
try {
|
||||||
|
countDownLatch1.await();
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
countDownLatch1.await();
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
static class WordParam {
|
||||||
|
private int wordLength;
|
||||||
|
@Builder.Default
|
||||||
|
private String language = "zh";
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取一个词语(根据语言和字符长度查询)
|
||||||
|
* @param wordParam
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public String getOneWord(WordParam wordParam) {
|
||||||
|
|
||||||
|
List<String> zh = Arrays.asList("大香蕉", "哈密瓜", "苹果");
|
||||||
|
List<String> en = Arrays.asList("apple", "banana", "cantaloupe");
|
||||||
|
if (wordParam.getLanguage().equals("zh")) {
|
||||||
|
for (String e : zh) {
|
||||||
|
if (e.length() == wordParam.getWordLength()) {
|
||||||
|
return e;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (wordParam.getLanguage().equals("en")) {
|
||||||
|
for (String e : en) {
|
||||||
|
if (e.length() == wordParam.getWordLength()) {
|
||||||
|
return e;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "西瓜";
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
package org.ruoyi.common.chat.demo;
|
||||||
|
|
||||||
|
|
||||||
|
import org.ruoyi.common.chat.openai.plugin.PluginAbstract;
|
||||||
|
|
||||||
|
public class WeatherPlugin extends PluginAbstract<WeatherReq, WeatherResp> {
|
||||||
|
|
||||||
|
public WeatherPlugin(Class<?> r) {
|
||||||
|
super(r);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public WeatherResp func(WeatherReq args) {
|
||||||
|
WeatherResp weatherResp = new WeatherResp();
|
||||||
|
weatherResp.setTemp("25到28摄氏度");
|
||||||
|
weatherResp.setLevel(3);
|
||||||
|
return weatherResp;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String content(WeatherResp weatherResp) {
|
||||||
|
return "当前天气温度:" + weatherResp.getTemp() + ",风力等级:" + weatherResp.getLevel();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
package org.ruoyi.common.chat.demo;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
import org.ruoyi.common.chat.openai.plugin.PluginParam;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class WeatherReq extends PluginParam {
|
||||||
|
/**
|
||||||
|
* 城市
|
||||||
|
*/
|
||||||
|
private String location;
|
||||||
|
}
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
package org.ruoyi.common.chat.demo;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class WeatherResp {
|
||||||
|
/**
|
||||||
|
* 温度
|
||||||
|
*/
|
||||||
|
private String temp;
|
||||||
|
/**
|
||||||
|
* 风力等级
|
||||||
|
*/
|
||||||
|
private Integer level;
|
||||||
|
}
|
||||||
@@ -0,0 +1,122 @@
|
|||||||
|
package org.ruoyi.common.chat.demo.zhipu;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||||
|
import com.zhipu.oapi.ClientV4;
|
||||||
|
import com.zhipu.oapi.Constants;
|
||||||
|
import com.zhipu.oapi.service.v4.deserialize.MessageDeserializeFactory;
|
||||||
|
import com.zhipu.oapi.service.v4.model.*;
|
||||||
|
import io.reactivex.Flowable;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
|
||||||
|
public class AllToolsTest {
|
||||||
|
|
||||||
|
private final static Logger logger = LoggerFactory.getLogger(AllToolsTest.class);
|
||||||
|
private static final String API_SECRET_KEY = "28550a39d4cfaabbbf38df04dd3931f5.IUvfTThUf0xBF5l0";
|
||||||
|
|
||||||
|
private static final ClientV4 client = new ClientV4.Builder(API_SECRET_KEY)
|
||||||
|
.networkConfig(300, 100, 100, 100, TimeUnit.SECONDS)
|
||||||
|
.connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS))
|
||||||
|
.build();
|
||||||
|
private static final ObjectMapper mapper = MessageDeserializeFactory.defaultObjectMapper();
|
||||||
|
// 请自定义自己的业务id
|
||||||
|
private static final String requestIdTemplate = "mycompany-%d";
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void test1() throws JsonProcessingException {
|
||||||
|
|
||||||
|
|
||||||
|
List<ChatMessage> messages = new ArrayList<>();
|
||||||
|
ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "帮我查询北京天气");
|
||||||
|
messages.add(chatMessage);
|
||||||
|
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
|
||||||
|
// 函数调用参数构建部分
|
||||||
|
List<ChatTool> chatToolList = new ArrayList<>();
|
||||||
|
ChatTool chatTool = new ChatTool();
|
||||||
|
|
||||||
|
chatTool.setType("code_interpreter");
|
||||||
|
ObjectNode objectNode = mapper.createObjectNode();
|
||||||
|
objectNode.put("code", "北京天气");
|
||||||
|
// chatTool.set(chatFunction);
|
||||||
|
chatToolList.add(chatTool);
|
||||||
|
|
||||||
|
|
||||||
|
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
|
||||||
|
.model("glm-4-alltools")
|
||||||
|
.stream(Boolean.TRUE)
|
||||||
|
.invokeMethod(Constants.invokeMethod)
|
||||||
|
.messages(messages)
|
||||||
|
.tools(chatToolList)
|
||||||
|
.toolChoice("auto")
|
||||||
|
.requestId(requestId)
|
||||||
|
.build();
|
||||||
|
ModelApiResponse sseModelApiResp = client.invokeModelApi(chatCompletionRequest);
|
||||||
|
if (sseModelApiResp.isSuccess()) {
|
||||||
|
AtomicBoolean isFirst = new AtomicBoolean(true);
|
||||||
|
List<Choice> choices = new ArrayList<>();
|
||||||
|
AtomicReference<ChatMessageAccumulator> lastAccumulator = new AtomicReference<>();
|
||||||
|
|
||||||
|
mapStreamToAccumulator(sseModelApiResp.getFlowable())
|
||||||
|
.doOnNext(accumulator -> {
|
||||||
|
{
|
||||||
|
if (isFirst.getAndSet(false)) {
|
||||||
|
logger.info("Response: ");
|
||||||
|
}
|
||||||
|
if (accumulator.getDelta() != null && accumulator.getDelta().getTool_calls() != null) {
|
||||||
|
String jsonString = mapper.writeValueAsString(accumulator.getDelta().getTool_calls());
|
||||||
|
logger.info("tool_calls: {}", jsonString);
|
||||||
|
}
|
||||||
|
if (accumulator.getDelta() != null && accumulator.getDelta().getContent() != null) {
|
||||||
|
logger.info(accumulator.getDelta().getContent());
|
||||||
|
}
|
||||||
|
choices.add(accumulator.getChoice());
|
||||||
|
lastAccumulator.set(accumulator);
|
||||||
|
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.doOnComplete(() -> System.out.println("Stream completed."))
|
||||||
|
.doOnError(throwable -> System.err.println("Error: " + throwable)) // Handle errors
|
||||||
|
.blockingSubscribe();// Use blockingSubscribe instead of blockingGet()
|
||||||
|
|
||||||
|
ChatMessageAccumulator chatMessageAccumulator = lastAccumulator.get();
|
||||||
|
ModelData data = new ModelData();
|
||||||
|
data.setChoices(choices);
|
||||||
|
if (chatMessageAccumulator != null) {
|
||||||
|
data.setUsage(chatMessageAccumulator.getUsage());
|
||||||
|
data.setId(chatMessageAccumulator.getId());
|
||||||
|
data.setCreated(chatMessageAccumulator.getCreated());
|
||||||
|
}
|
||||||
|
data.setRequestId(chatCompletionRequest.getRequestId());
|
||||||
|
sseModelApiResp.setFlowable(null);// 打印前置空
|
||||||
|
sseModelApiResp.setData(data);
|
||||||
|
}
|
||||||
|
logger.info("model output: {}", mapper.writeValueAsString(sseModelApiResp));
|
||||||
|
client.getConfig().getHttpClient().dispatcher().executorService().shutdown();
|
||||||
|
|
||||||
|
client.getConfig().getHttpClient().connectionPool().evictAll();
|
||||||
|
// List all active threads
|
||||||
|
for (Thread t : Thread.getAllStackTraces().keySet()) {
|
||||||
|
logger.info("Thread: " + t.getName() + " State: " + t.getState());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
public static Flowable<ChatMessageAccumulator> mapStreamToAccumulator(Flowable<ModelData> flowable) {
|
||||||
|
return flowable.map(chunk -> {
|
||||||
|
return new ChatMessageAccumulator(chunk.getChoices().get(0).getDelta(), null, chunk.getChoices().get(0), chunk.getUsage(), chunk.getCreated(), chunk.getId());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,246 @@
|
|||||||
|
package org.ruoyi.common.chat.demo.zhipu;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import com.fasterxml.jackson.core.type.TypeReference;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import com.zhipu.oapi.ClientV4;
|
||||||
|
import com.zhipu.oapi.Constants;
|
||||||
|
import com.zhipu.oapi.service.v4.tools.*;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import com.fasterxml.jackson.databind.DeserializationFeature;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import com.fasterxml.jackson.databind.PropertyNamingStrategy;
|
||||||
|
import com.zhipu.oapi.core.response.HttpxBinaryResponseContent;
|
||||||
|
import com.zhipu.oapi.service.v4.batchs.BatchCreateParams;
|
||||||
|
import com.zhipu.oapi.service.v4.batchs.BatchResponse;
|
||||||
|
import com.zhipu.oapi.service.v4.batchs.QueryBatchResponse;
|
||||||
|
import com.zhipu.oapi.service.v4.embedding.EmbeddingApiResponse;
|
||||||
|
import com.zhipu.oapi.service.v4.embedding.EmbeddingRequest;
|
||||||
|
import com.zhipu.oapi.service.v4.file.*;
|
||||||
|
import com.zhipu.oapi.service.v4.fine_turning.*;
|
||||||
|
import com.zhipu.oapi.service.v4.image.CreateImageRequest;
|
||||||
|
import com.zhipu.oapi.service.v4.image.ImageApiResponse;
|
||||||
|
import com.zhipu.oapi.service.v4.model.*;
|
||||||
|
import io.reactivex.Flowable;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
|
|
||||||
|
|
||||||
|
public class WebSearchToolsTest {
|
||||||
|
|
||||||
|
private final static Logger logger = LoggerFactory.getLogger(WebSearchToolsTest.class);
|
||||||
|
private static final String API_SECRET_KEY = "xx";
|
||||||
|
|
||||||
|
private static final ClientV4 client = new ClientV4.Builder(API_SECRET_KEY)
|
||||||
|
.networkConfig(300, 100, 100, 100, TimeUnit.SECONDS)
|
||||||
|
.connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS))
|
||||||
|
.build();
|
||||||
|
private static final ObjectMapper mapper = new ObjectMapper();
|
||||||
|
// 请自定义自己的业务id
|
||||||
|
private static final String requestIdTemplate = "mycompany-%d";
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void test1() throws JsonProcessingException {
|
||||||
|
|
||||||
|
// json 转换 ArrayList<SearchChatMessage>
|
||||||
|
String jsonString = "[\n" +
|
||||||
|
" {\n" +
|
||||||
|
" \"content\": \"今天武汉天气怎么样\",\n" +
|
||||||
|
" \"role\": \"user\"\n" +
|
||||||
|
" }\n" +
|
||||||
|
" ]";
|
||||||
|
|
||||||
|
ArrayList<SearchChatMessage> messages = new ObjectMapper().readValue(jsonString, new TypeReference<ArrayList<SearchChatMessage>>() {
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
|
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
|
||||||
|
WebSearchParamsRequest chatCompletionRequest = WebSearchParamsRequest.builder()
|
||||||
|
.model("web-search-pro")
|
||||||
|
.stream(Boolean.TRUE)
|
||||||
|
.messages(messages)
|
||||||
|
.requestId(requestId)
|
||||||
|
.build();
|
||||||
|
WebSearchApiResponse webSearchApiResponse = client.webSearchProStreamingInvoke(chatCompletionRequest);
|
||||||
|
if (webSearchApiResponse.isSuccess()) {
|
||||||
|
AtomicBoolean isFirst = new AtomicBoolean(true);
|
||||||
|
List<ChoiceDelta> choices = new ArrayList<>();
|
||||||
|
AtomicReference<WebSearchPro> lastAccumulator = new AtomicReference<>();
|
||||||
|
|
||||||
|
webSearchApiResponse.getFlowable().map(result -> result)
|
||||||
|
.doOnNext(accumulator -> {
|
||||||
|
{
|
||||||
|
if (isFirst.getAndSet(false)) {
|
||||||
|
logger.info("Response: ");
|
||||||
|
}
|
||||||
|
ChoiceDelta delta = accumulator.getChoices().get(0).getDelta();
|
||||||
|
if (delta != null && delta.getToolCalls() != null) {
|
||||||
|
logger.info("tool_calls: {}", mapper.writeValueAsString(delta.getToolCalls()));
|
||||||
|
}
|
||||||
|
choices.add(delta);
|
||||||
|
lastAccumulator.set(accumulator);
|
||||||
|
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.doOnComplete(() -> System.out.println("Stream completed."))
|
||||||
|
.doOnError(throwable -> System.err.println("Error: " + throwable)) // Handle errors
|
||||||
|
.blockingSubscribe();// Use blockingSubscribe instead of blockingGet()
|
||||||
|
|
||||||
|
WebSearchPro chatMessageAccumulator = lastAccumulator.get();
|
||||||
|
|
||||||
|
webSearchApiResponse.setFlowable(null);// 打印前置空
|
||||||
|
webSearchApiResponse.setData(chatMessageAccumulator);
|
||||||
|
}
|
||||||
|
logger.info("model output: {}", mapper.writeValueAsString(webSearchApiResponse));
|
||||||
|
client.getConfig().getHttpClient().dispatcher().executorService().shutdown();
|
||||||
|
|
||||||
|
client.getConfig().getHttpClient().connectionPool().evictAll();
|
||||||
|
// List all active threads
|
||||||
|
for (Thread t : Thread.getAllStackTraces().keySet()) {
|
||||||
|
logger.info("Thread: " + t.getName() + " State: " + t.getState());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void test2() throws JsonProcessingException {
|
||||||
|
|
||||||
|
// json 转换 ArrayList<SearchChatMessage>
|
||||||
|
String jsonString = "[\n" +
|
||||||
|
" {\n" +
|
||||||
|
" \"content\": \"今天天气怎么样\",\n" +
|
||||||
|
" \"role\": \"user\"\n" +
|
||||||
|
" }\n" +
|
||||||
|
" ]";
|
||||||
|
|
||||||
|
ArrayList<SearchChatMessage> messages = new ObjectMapper().readValue(jsonString, new TypeReference<ArrayList<SearchChatMessage>>() {
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
|
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
|
||||||
|
WebSearchParamsRequest chatCompletionRequest = WebSearchParamsRequest.builder()
|
||||||
|
.model("web-search-pro")
|
||||||
|
.stream(Boolean.FALSE)
|
||||||
|
.messages(messages)
|
||||||
|
.requestId(requestId)
|
||||||
|
.build();
|
||||||
|
WebSearchApiResponse webSearchApiResponse = client.invokeWebSearchPro(chatCompletionRequest);
|
||||||
|
|
||||||
|
logger.info("model output: {}", mapper.writeValueAsString(webSearchApiResponse));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testFunctionSSE() throws JsonProcessingException {
|
||||||
|
List<ChatMessage> messages = new ArrayList<>();
|
||||||
|
ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "成都到北京要多久,天气如何");
|
||||||
|
messages.add(chatMessage);
|
||||||
|
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
|
||||||
|
// 函数调用参数构建部分
|
||||||
|
List<ChatTool> chatToolList = new ArrayList<>();
|
||||||
|
ChatTool chatTool = new ChatTool();
|
||||||
|
|
||||||
|
chatTool.setType(ChatToolType.FUNCTION.value());
|
||||||
|
ChatFunctionParameters chatFunctionParameters = new ChatFunctionParameters();
|
||||||
|
chatFunctionParameters.setType("object");
|
||||||
|
Map<String, Object> properties = new HashMap<>();
|
||||||
|
properties.put("location", new HashMap<String, Object>() {{
|
||||||
|
put("type", "string");
|
||||||
|
put("description", "城市,如:北京");
|
||||||
|
}});
|
||||||
|
properties.put("unit", new HashMap<String, Object>() {{
|
||||||
|
put("type", "string");
|
||||||
|
put("enum", new ArrayList<String>() {{
|
||||||
|
add("celsius");
|
||||||
|
add("fahrenheit");
|
||||||
|
}});
|
||||||
|
}});
|
||||||
|
chatFunctionParameters.setProperties(properties);
|
||||||
|
ChatFunction chatFunction = ChatFunction.builder()
|
||||||
|
.name("get_weather")
|
||||||
|
.description("Get the current weather of a location")
|
||||||
|
.parameters(chatFunctionParameters)
|
||||||
|
.build();
|
||||||
|
chatTool.setFunction(chatFunction);
|
||||||
|
chatToolList.add(chatTool);
|
||||||
|
HashMap<String, Object> extraJson = new HashMap<>();
|
||||||
|
extraJson.put("temperature", 0.5);
|
||||||
|
extraJson.put("max_tokens", 50);
|
||||||
|
|
||||||
|
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
|
||||||
|
.model(Constants.ModelChatGLM4)
|
||||||
|
.stream(Boolean.TRUE)
|
||||||
|
.messages(messages)
|
||||||
|
.requestId(requestId)
|
||||||
|
.tools(chatToolList)
|
||||||
|
.toolChoice("auto")
|
||||||
|
.extraJson(extraJson)
|
||||||
|
.build();
|
||||||
|
ModelApiResponse sseModelApiResp = client.invokeModelApi(chatCompletionRequest);
|
||||||
|
if (sseModelApiResp.isSuccess()) {
|
||||||
|
AtomicBoolean isFirst = new AtomicBoolean(true);
|
||||||
|
List<Choice> choices = new ArrayList<>();
|
||||||
|
ChatMessageAccumulator chatMessageAccumulator = mapStreamToAccumulator(sseModelApiResp.getFlowable())
|
||||||
|
.doOnNext(accumulator -> {
|
||||||
|
{
|
||||||
|
if (isFirst.getAndSet(false)) {
|
||||||
|
logger.info("Response: ");
|
||||||
|
}
|
||||||
|
if (accumulator.getDelta() != null && accumulator.getDelta().getTool_calls() != null) {
|
||||||
|
String jsonString = mapper.writeValueAsString(accumulator.getDelta().getTool_calls());
|
||||||
|
logger.info("tool_calls: {}", jsonString);
|
||||||
|
}
|
||||||
|
if (accumulator.getDelta() != null && accumulator.getDelta().getContent() != null) {
|
||||||
|
logger.info(accumulator.getDelta().getContent());
|
||||||
|
}
|
||||||
|
choices.add(accumulator.getChoice());
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.doOnComplete(System.out::println)
|
||||||
|
.lastElement()
|
||||||
|
.blockingGet();
|
||||||
|
|
||||||
|
|
||||||
|
ModelData data = new ModelData();
|
||||||
|
data.setChoices(choices);
|
||||||
|
data.setUsage(chatMessageAccumulator.getUsage());
|
||||||
|
data.setId(chatMessageAccumulator.getId());
|
||||||
|
data.setCreated(chatMessageAccumulator.getCreated());
|
||||||
|
data.setRequestId(chatCompletionRequest.getRequestId());
|
||||||
|
sseModelApiResp.setFlowable(null);// 打印前置空
|
||||||
|
sseModelApiResp.setData(data);
|
||||||
|
}
|
||||||
|
logger.info("model output: {}", mapper.writeValueAsString(sseModelApiResp));
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Flowable<ChatMessageAccumulator> mapStreamToAccumulator(Flowable<ModelData> flowable) {
|
||||||
|
return flowable.map(chunk -> {
|
||||||
|
return new ChatMessageAccumulator(chunk.getChoices().get(0).getDelta(), null, chunk.getChoices().get(0), chunk.getUsage(), chunk.getCreated(), chunk.getId());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -2,12 +2,11 @@ package org.ruoyi.common.chat.entity.chat;
|
|||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.Getter;
|
import org.ruoyi.common.chat.entity.chat.tool.ToolCalls;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 描述:
|
* 描述:
|
||||||
@@ -18,21 +17,10 @@ import java.io.Serializable;
|
|||||||
@Data
|
@Data
|
||||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||||
@JsonIgnoreProperties(ignoreUnknown = true)
|
@JsonIgnoreProperties(ignoreUnknown = true)
|
||||||
public class Message implements Serializable {
|
public class Message extends BaseMessage implements Serializable {
|
||||||
|
|
||||||
/**
|
|
||||||
* 目前支持四个中角色参考官网,进行情景输入:
|
|
||||||
* https://platform.openai.com/docs/guides/chat/introduction
|
|
||||||
*/
|
|
||||||
private String role;
|
|
||||||
|
|
||||||
private Object content;
|
private Object content;
|
||||||
|
|
||||||
private String name;
|
|
||||||
|
|
||||||
@JsonProperty("function_call")
|
|
||||||
private FunctionCall functionCall;
|
|
||||||
|
|
||||||
public static Builder builder() {
|
public static Builder builder() {
|
||||||
return new Builder();
|
return new Builder();
|
||||||
}
|
}
|
||||||
@@ -41,44 +29,37 @@ public class Message implements Serializable {
|
|||||||
* 构造函数
|
* 构造函数
|
||||||
*
|
*
|
||||||
* @param role 角色
|
* @param role 角色
|
||||||
* @param content 描述主题信息
|
|
||||||
* @param name name
|
* @param name name
|
||||||
|
* @param content content
|
||||||
* @param functionCall functionCall
|
* @param functionCall functionCall
|
||||||
*/
|
*/
|
||||||
public Message(String role, String content, String name, FunctionCall functionCall) {
|
public Message(String role, String name, String content, List<ToolCalls> toolCalls, String toolCallId, FunctionCall functionCall) {
|
||||||
this.role = role;
|
|
||||||
this.content = content;
|
this.content = content;
|
||||||
this.name = name;
|
super.setRole(role);
|
||||||
this.functionCall = functionCall;
|
super.setName(name);
|
||||||
|
super.setToolCalls(toolCalls);
|
||||||
|
super.setToolCallId(toolCallId);
|
||||||
|
super.setFunctionCall(functionCall);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Message() {
|
public Message() {
|
||||||
}
|
}
|
||||||
|
|
||||||
private Message(Builder builder) {
|
private Message(Builder builder) {
|
||||||
setRole(builder.role);
|
|
||||||
setContent(builder.content);
|
setContent(builder.content);
|
||||||
setName(builder.name);
|
super.setRole(builder.role);
|
||||||
setFunctionCall(builder.functionCall);
|
super.setName(builder.name);
|
||||||
}
|
super.setFunctionCall(builder.functionCall);
|
||||||
|
super.setToolCalls(builder.toolCalls);
|
||||||
|
super.setToolCallId(builder.toolCallId);
|
||||||
@Getter
|
|
||||||
@AllArgsConstructor
|
|
||||||
public enum Role {
|
|
||||||
|
|
||||||
SYSTEM("system"),
|
|
||||||
USER("user"),
|
|
||||||
ASSISTANT("assistant"),
|
|
||||||
FUNCTION("function"),
|
|
||||||
;
|
|
||||||
private String name;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static final class Builder {
|
public static final class Builder {
|
||||||
private String role;
|
private String role;
|
||||||
private String content;
|
private String content;
|
||||||
private String name;
|
private String name;
|
||||||
|
private String toolCallId;
|
||||||
|
private List<ToolCalls> toolCalls;
|
||||||
private FunctionCall functionCall;
|
private FunctionCall functionCall;
|
||||||
|
|
||||||
public Builder() {
|
public Builder() {
|
||||||
@@ -109,6 +90,16 @@ public class Message implements Serializable {
|
|||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Builder toolCalls(List<ToolCalls> toolCalls) {
|
||||||
|
this.toolCalls = toolCalls;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder toolCallId(String toolCallId) {
|
||||||
|
this.toolCallId = toolCallId;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
public Message build() {
|
public Message build() {
|
||||||
return new Message(this);
|
return new Message(this);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,5 +24,6 @@ public class ResponseFormat {
|
|||||||
TEXT("text"),
|
TEXT("text"),
|
||||||
;
|
;
|
||||||
private final String name;
|
private final String name;
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -75,6 +75,8 @@ public class WebSocketEventListener extends EventSourceListener {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
ResponseBody body = response.body();
|
ResponseBody body = response.body();
|
||||||
|
|
||||||
|
|
||||||
if (Objects.nonNull(body)) {
|
if (Objects.nonNull(body)) {
|
||||||
// 返回非流式回复内容
|
// 返回非流式回复内容
|
||||||
if(response.code() == OpenAIConst.SUCCEED_CODE){
|
if(response.code() == OpenAIConst.SUCCEED_CODE){
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package org.ruoyi.common.chat.openai;
|
|||||||
|
|
||||||
import cn.hutool.core.collection.CollectionUtil;
|
import cn.hutool.core.collection.CollectionUtil;
|
||||||
import cn.hutool.core.util.StrUtil;
|
import cn.hutool.core.util.StrUtil;
|
||||||
|
import cn.hutool.json.JSONUtil;
|
||||||
import io.reactivex.Single;
|
import io.reactivex.Single;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@@ -12,9 +13,7 @@ import okhttp3.RequestBody;
|
|||||||
import org.ruoyi.common.chat.constant.OpenAIConst;
|
import org.ruoyi.common.chat.constant.OpenAIConst;
|
||||||
import org.ruoyi.common.chat.entity.billing.BillingUsage;
|
import org.ruoyi.common.chat.entity.billing.BillingUsage;
|
||||||
import org.ruoyi.common.chat.entity.billing.Subscription;
|
import org.ruoyi.common.chat.entity.billing.Subscription;
|
||||||
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
|
import org.ruoyi.common.chat.entity.chat.*;
|
||||||
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
|
|
||||||
import org.ruoyi.common.chat.entity.chat.Message;
|
|
||||||
import org.ruoyi.common.chat.entity.common.DeleteResponse;
|
import org.ruoyi.common.chat.entity.common.DeleteResponse;
|
||||||
import org.ruoyi.common.chat.entity.common.OpenAiResponse;
|
import org.ruoyi.common.chat.entity.common.OpenAiResponse;
|
||||||
import org.ruoyi.common.chat.entity.completions.Completion;
|
import org.ruoyi.common.chat.entity.completions.Completion;
|
||||||
@@ -43,6 +42,8 @@ import org.ruoyi.common.chat.openai.function.KeyStrategyFunction;
|
|||||||
import org.ruoyi.common.chat.openai.interceptor.DefaultOpenAiAuthInterceptor;
|
import org.ruoyi.common.chat.openai.interceptor.DefaultOpenAiAuthInterceptor;
|
||||||
import org.ruoyi.common.chat.openai.interceptor.DynamicKeyOpenAiAuthInterceptor;
|
import org.ruoyi.common.chat.openai.interceptor.DynamicKeyOpenAiAuthInterceptor;
|
||||||
import org.ruoyi.common.chat.openai.interceptor.OpenAiAuthInterceptor;
|
import org.ruoyi.common.chat.openai.interceptor.OpenAiAuthInterceptor;
|
||||||
|
import org.ruoyi.common.chat.openai.plugin.PluginAbstract;
|
||||||
|
import org.ruoyi.common.chat.openai.plugin.PluginParam;
|
||||||
import org.ruoyi.common.core.exception.base.BaseException;
|
import org.ruoyi.common.core.exception.base.BaseException;
|
||||||
import org.jetbrains.annotations.NotNull;
|
import org.jetbrains.annotations.NotNull;
|
||||||
import retrofit2.Retrofit;
|
import retrofit2.Retrofit;
|
||||||
@@ -696,6 +697,90 @@ public class OpenAiClient {
|
|||||||
return whisperResponse.blockingGet();
|
return whisperResponse.blockingGet();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 插件问答简易版
|
||||||
|
* 默认取messages最后一个元素构建插件对话
|
||||||
|
* 默认模型:ChatCompletion.Model.GPT_3_5_TURBO_16K_0613
|
||||||
|
*
|
||||||
|
* @param chatCompletion 参数
|
||||||
|
* @param plugin 插件
|
||||||
|
* @param <R> 插件自定义函数的请求值
|
||||||
|
* @param <T> 插件自定义函数的返回值
|
||||||
|
* @return ChatCompletionResponse
|
||||||
|
*/
|
||||||
|
public <R extends PluginParam, T> ChatCompletionResponse chatCompletionWithPlugin(ChatCompletion chatCompletion, PluginAbstract<R, T> plugin) {
|
||||||
|
if (Objects.isNull(plugin)) {
|
||||||
|
return this.chatCompletion(chatCompletion);
|
||||||
|
}
|
||||||
|
if (CollectionUtil.isEmpty(chatCompletion.getMessages())) {
|
||||||
|
throw new BaseException(CommonError.MESSAGE_NOT_NUL.msg());
|
||||||
|
}
|
||||||
|
List<Message> messages = chatCompletion.getMessages();
|
||||||
|
Functions functions = Functions.builder()
|
||||||
|
.name(plugin.getFunction())
|
||||||
|
.description(plugin.getDescription())
|
||||||
|
.parameters(plugin.getParameters())
|
||||||
|
.build();
|
||||||
|
//没有值,设置默认值
|
||||||
|
if (Objects.isNull(chatCompletion.getFunctionCall())) {
|
||||||
|
chatCompletion.setFunctionCall("auto");
|
||||||
|
}
|
||||||
|
//tip: 覆盖自己设置的functions参数,使用plugin构造的functions
|
||||||
|
chatCompletion.setFunctions(Collections.singletonList(functions));
|
||||||
|
//调用OpenAi
|
||||||
|
ChatCompletionResponse functionCallChatCompletionResponse = this.chatCompletion(chatCompletion);
|
||||||
|
ChatChoice chatChoice = functionCallChatCompletionResponse.getChoices().get(0);
|
||||||
|
log.debug("构造的方法值:{}", chatChoice.getMessage().getFunctionCall());
|
||||||
|
|
||||||
|
R realFunctionParam = (R) JSONUtil.toBean(chatChoice.getMessage().getFunctionCall().getArguments(), plugin.getR());
|
||||||
|
T tq = plugin.func(realFunctionParam);
|
||||||
|
|
||||||
|
FunctionCall functionCall = FunctionCall.builder()
|
||||||
|
.arguments(chatChoice.getMessage().getFunctionCall().getArguments())
|
||||||
|
.name(plugin.getFunction())
|
||||||
|
.build();
|
||||||
|
messages.add(Message.builder().role(Message.Role.ASSISTANT).content("function_call").functionCall(functionCall).build());
|
||||||
|
messages.add(Message.builder().role(Message.Role.FUNCTION).name(plugin.getFunction()).content(plugin.content(tq)).build());
|
||||||
|
//设置第二次,请求的参数
|
||||||
|
chatCompletion.setFunctionCall(null);
|
||||||
|
chatCompletion.setFunctions(null);
|
||||||
|
|
||||||
|
ChatCompletionResponse chatCompletionResponse = this.chatCompletion(chatCompletion);
|
||||||
|
log.debug("自定义的方法返回值:{}", chatCompletionResponse.getChoices());
|
||||||
|
return chatCompletionResponse;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 插件问答简易版
|
||||||
|
* 默认取messages最后一个元素构建插件对话
|
||||||
|
* 默认模型:ChatCompletion.Model.GPT_3_5_TURBO_16K_0613
|
||||||
|
*
|
||||||
|
* @param messages 问答参数
|
||||||
|
* @param plugin 插件
|
||||||
|
* @param <R> 插件自定义函数的请求值
|
||||||
|
* @param <T> 插件自定义函数的返回值
|
||||||
|
* @return ChatCompletionResponse
|
||||||
|
*/
|
||||||
|
public <R extends PluginParam, T> ChatCompletionResponse chatCompletionWithPlugin(List<Message> messages, PluginAbstract<R, T> plugin) {
|
||||||
|
return chatCompletionWithPlugin(messages, ChatCompletion.Model.GPT_3_5_TURBO_16K_0613.getName(), plugin);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 插件问答简易版
|
||||||
|
* 默认取messages最后一个元素构建插件对话
|
||||||
|
*
|
||||||
|
* @param messages 问答参数
|
||||||
|
* @param model 模型
|
||||||
|
* @param plugin 插件
|
||||||
|
* @param <R> 插件自定义函数的请求值
|
||||||
|
* @param <T> 插件自定义函数的返回值
|
||||||
|
* @return ChatCompletionResponse
|
||||||
|
*/
|
||||||
|
public <R extends PluginParam, T> ChatCompletionResponse chatCompletionWithPlugin(List<Message> messages, String model, PluginAbstract<R, T> plugin) {
|
||||||
|
ChatCompletion chatCompletion = ChatCompletion.builder().messages(messages).model(model).build();
|
||||||
|
return this.chatCompletionWithPlugin(chatCompletion, plugin);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 简易版 语音翻译:目前仅支持翻译为英文
|
* 简易版 语音翻译:目前仅支持翻译为英文
|
||||||
*
|
*
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package org.ruoyi.common.chat.openai;
|
|||||||
import cn.hutool.core.collection.CollectionUtil;
|
import cn.hutool.core.collection.CollectionUtil;
|
||||||
import cn.hutool.core.util.StrUtil;
|
import cn.hutool.core.util.StrUtil;
|
||||||
import cn.hutool.http.ContentType;
|
import cn.hutool.http.ContentType;
|
||||||
|
import cn.hutool.json.JSONUtil;
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
import io.reactivex.Single;
|
import io.reactivex.Single;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
@@ -17,10 +18,7 @@ import org.ruoyi.common.chat.entity.Tts.TextToSpeech;
|
|||||||
import org.ruoyi.common.chat.entity.billing.BillingUsage;
|
import org.ruoyi.common.chat.entity.billing.BillingUsage;
|
||||||
import org.ruoyi.common.chat.entity.billing.KeyInfo;
|
import org.ruoyi.common.chat.entity.billing.KeyInfo;
|
||||||
import org.ruoyi.common.chat.entity.billing.Subscription;
|
import org.ruoyi.common.chat.entity.billing.Subscription;
|
||||||
import org.ruoyi.common.chat.entity.chat.BaseChatCompletion;
|
import org.ruoyi.common.chat.entity.chat.*;
|
||||||
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
|
|
||||||
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
|
|
||||||
import org.ruoyi.common.chat.entity.chat.ChatCompletionWithPicture;
|
|
||||||
import org.ruoyi.common.chat.entity.embeddings.Embedding;
|
import org.ruoyi.common.chat.entity.embeddings.Embedding;
|
||||||
import org.ruoyi.common.chat.entity.embeddings.EmbeddingResponse;
|
import org.ruoyi.common.chat.entity.embeddings.EmbeddingResponse;
|
||||||
import org.ruoyi.common.chat.entity.files.UploadFileResponse;
|
import org.ruoyi.common.chat.entity.files.UploadFileResponse;
|
||||||
@@ -29,6 +27,7 @@ import org.ruoyi.common.chat.entity.images.ImageResponse;
|
|||||||
import org.ruoyi.common.chat.entity.models.Model;
|
import org.ruoyi.common.chat.entity.models.Model;
|
||||||
import org.ruoyi.common.chat.entity.models.ModelResponse;
|
import org.ruoyi.common.chat.entity.models.ModelResponse;
|
||||||
import org.ruoyi.common.chat.entity.whisper.Transcriptions;
|
import org.ruoyi.common.chat.entity.whisper.Transcriptions;
|
||||||
|
import org.ruoyi.common.chat.entity.whisper.Translations;
|
||||||
import org.ruoyi.common.chat.entity.whisper.WhisperResponse;
|
import org.ruoyi.common.chat.entity.whisper.WhisperResponse;
|
||||||
import org.ruoyi.common.chat.openai.exception.CommonError;
|
import org.ruoyi.common.chat.openai.exception.CommonError;
|
||||||
import org.ruoyi.common.chat.openai.function.KeyRandomStrategy;
|
import org.ruoyi.common.chat.openai.function.KeyRandomStrategy;
|
||||||
@@ -36,6 +35,10 @@ import org.ruoyi.common.chat.openai.function.KeyStrategyFunction;
|
|||||||
import org.ruoyi.common.chat.openai.interceptor.DefaultOpenAiAuthInterceptor;
|
import org.ruoyi.common.chat.openai.interceptor.DefaultOpenAiAuthInterceptor;
|
||||||
import org.ruoyi.common.chat.openai.interceptor.DynamicKeyOpenAiAuthInterceptor;
|
import org.ruoyi.common.chat.openai.interceptor.DynamicKeyOpenAiAuthInterceptor;
|
||||||
import org.ruoyi.common.chat.openai.interceptor.OpenAiAuthInterceptor;
|
import org.ruoyi.common.chat.openai.interceptor.OpenAiAuthInterceptor;
|
||||||
|
import org.ruoyi.common.chat.openai.plugin.PluginAbstract;
|
||||||
|
import org.ruoyi.common.chat.openai.plugin.PluginParam;
|
||||||
|
import org.ruoyi.common.chat.sse.DefaultPluginListener;
|
||||||
|
import org.ruoyi.common.chat.sse.PluginListener;
|
||||||
import org.ruoyi.common.core.exception.base.BaseException;
|
import org.ruoyi.common.core.exception.base.BaseException;
|
||||||
import org.jetbrains.annotations.NotNull;
|
import org.jetbrains.annotations.NotNull;
|
||||||
import retrofit2.Call;
|
import retrofit2.Call;
|
||||||
@@ -186,6 +189,93 @@ public class OpenAiStreamClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 插件问答简易版
|
||||||
|
* 默认取messages最后一个元素构建插件对话
|
||||||
|
* 默认模型:ChatCompletion.Model.GPT_3_5_TURBO_16K_0613
|
||||||
|
*
|
||||||
|
* @param chatCompletion 参数
|
||||||
|
* @param eventSourceListener sse监听器
|
||||||
|
* @param pluginEventSourceListener 插件sse监听器,收集function call返回信息
|
||||||
|
* @param plugin 插件
|
||||||
|
* @param <R> 插件自定义函数的请求值
|
||||||
|
* @param <T> 插件自定义函数的返回值
|
||||||
|
*/
|
||||||
|
public <R extends PluginParam, T> void streamChatCompletionWithPlugin(ChatCompletion chatCompletion, EventSourceListener eventSourceListener, PluginListener pluginEventSourceListener, PluginAbstract<R, T> plugin) {
|
||||||
|
if (Objects.isNull(plugin)) {
|
||||||
|
this.streamChatCompletion(chatCompletion, eventSourceListener);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (CollectionUtil.isEmpty(chatCompletion.getMessages())) {
|
||||||
|
throw new BaseException(CommonError.MESSAGE_NOT_NUL.msg());
|
||||||
|
}
|
||||||
|
Functions functions = Functions.builder()
|
||||||
|
.name(plugin.getFunction())
|
||||||
|
.description(plugin.getDescription())
|
||||||
|
.parameters(plugin.getParameters())
|
||||||
|
.build();
|
||||||
|
//没有值,设置默认值
|
||||||
|
if (Objects.isNull(chatCompletion.getFunctionCall())) {
|
||||||
|
chatCompletion.setFunctionCall("auto");
|
||||||
|
}
|
||||||
|
//tip: 覆盖自己设置的functions参数,使用plugin构造的functions
|
||||||
|
chatCompletion.setFunctions(Collections.singletonList(functions));
|
||||||
|
//调用OpenAi
|
||||||
|
if (Objects.isNull(pluginEventSourceListener)) {
|
||||||
|
pluginEventSourceListener = new DefaultPluginListener(this, eventSourceListener, plugin, chatCompletion);
|
||||||
|
}
|
||||||
|
this.streamChatCompletion(chatCompletion, pluginEventSourceListener);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 插件问答简易版
|
||||||
|
* 默认取messages最后一个元素构建插件对话
|
||||||
|
* 默认模型:ChatCompletion.Model.GPT_3_5_TURBO_16K_0613
|
||||||
|
*
|
||||||
|
* @param chatCompletion 参数
|
||||||
|
* @param eventSourceListener sse监听器
|
||||||
|
* @param plugin 插件
|
||||||
|
* @param <R> 插件自定义函数的请求值
|
||||||
|
* @param <T> 插件自定义函数的返回值
|
||||||
|
*/
|
||||||
|
public <R extends PluginParam, T> void streamChatCompletionWithPlugin(ChatCompletion chatCompletion, EventSourceListener eventSourceListener, PluginAbstract<R, T> plugin) {
|
||||||
|
PluginListener pluginEventSourceListener = new DefaultPluginListener(this, eventSourceListener, plugin, chatCompletion);
|
||||||
|
this.streamChatCompletionWithPlugin(chatCompletion, eventSourceListener, pluginEventSourceListener, plugin);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 插件问答简易版
|
||||||
|
* 默认取messages最后一个元素构建插件对话
|
||||||
|
* 默认模型:ChatCompletion.Model.GPT_3_5_TURBO_16K_0613
|
||||||
|
*
|
||||||
|
* @param messages 问答参数
|
||||||
|
* @param eventSourceListener sse监听器
|
||||||
|
* @param plugin 插件
|
||||||
|
* @param <R> 插件自定义函数的请求值
|
||||||
|
* @param <T> 插件自定义函数的返回值
|
||||||
|
*/
|
||||||
|
public <R extends PluginParam, T> void streamChatCompletionWithPlugin(List<Message> messages, EventSourceListener eventSourceListener, PluginAbstract<R, T> plugin) {
|
||||||
|
this.streamChatCompletionWithPlugin(messages, ChatCompletion.Model.GPT_3_5_TURBO_16K_0613.getName(), eventSourceListener, plugin);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 插件问答简易版
|
||||||
|
* 默认取messages最后一个元素构建插件对话
|
||||||
|
*
|
||||||
|
* @param messages 问答参数
|
||||||
|
* @param model 模型
|
||||||
|
* @param eventSourceListener eventSourceListener
|
||||||
|
* @param plugin 插件
|
||||||
|
* @param <R> 插件自定义函数的请求值
|
||||||
|
* @param <T> 插件自定义函数的返回值
|
||||||
|
*/
|
||||||
|
public <R extends PluginParam, T> void streamChatCompletionWithPlugin(List<Message> messages, String model, EventSourceListener eventSourceListener, PluginAbstract<R, T> plugin) {
|
||||||
|
ChatCompletion chatCompletion = ChatCompletion.builder().messages(messages).model(model).build();
|
||||||
|
this.streamChatCompletionWithPlugin(chatCompletion, eventSourceListener, plugin);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 根据描述生成图片
|
* 根据描述生成图片
|
||||||
@@ -418,6 +508,95 @@ public class OpenAiStreamClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 插件问答简易版
|
||||||
|
* 默认取messages最后一个元素构建插件对话
|
||||||
|
* 默认模型:ChatCompletion.Model.GPT_3_5_TURBO_16K_0613
|
||||||
|
*
|
||||||
|
* @param chatCompletion 参数
|
||||||
|
* @param plugin 插件
|
||||||
|
* @param <R> 插件自定义函数的请求值
|
||||||
|
* @param <T> 插件自定义函数的返回值
|
||||||
|
* @return ChatCompletionResponse
|
||||||
|
*/
|
||||||
|
public <R extends PluginParam, T> ChatCompletionResponse chatCompletionWithPlugin(ChatCompletion chatCompletion, PluginAbstract<R, T> plugin) {
|
||||||
|
if (Objects.isNull(plugin)) {
|
||||||
|
return this.chatCompletion(chatCompletion);
|
||||||
|
}
|
||||||
|
if (CollectionUtil.isEmpty(chatCompletion.getMessages())) {
|
||||||
|
throw new BaseException(CommonError.MESSAGE_NOT_NUL.msg());
|
||||||
|
}
|
||||||
|
List<Message> messages = chatCompletion.getMessages();
|
||||||
|
Functions functions = Functions.builder()
|
||||||
|
.name(plugin.getFunction())
|
||||||
|
.description(plugin.getDescription())
|
||||||
|
.parameters(plugin.getParameters())
|
||||||
|
.build();
|
||||||
|
//没有值,设置默认值
|
||||||
|
if (Objects.isNull(chatCompletion.getFunctionCall())) {
|
||||||
|
chatCompletion.setFunctionCall("auto");
|
||||||
|
}
|
||||||
|
//tip: 覆盖自己设置的functions参数,使用plugin构造的functions
|
||||||
|
chatCompletion.setFunctions(Collections.singletonList(functions));
|
||||||
|
//调用OpenAi
|
||||||
|
ChatCompletionResponse functionCallChatCompletionResponse = this.chatCompletion(chatCompletion);
|
||||||
|
ChatChoice chatChoice = functionCallChatCompletionResponse.getChoices().get(0);
|
||||||
|
log.debug("构造的方法值:{}", chatChoice.getMessage().getFunctionCall());
|
||||||
|
|
||||||
|
R realFunctionParam = (R) JSONUtil.toBean(chatChoice.getMessage().getFunctionCall().getArguments(), plugin.getR());
|
||||||
|
T tq = plugin.func(realFunctionParam);
|
||||||
|
|
||||||
|
FunctionCall functionCall = FunctionCall.builder()
|
||||||
|
.arguments(chatChoice.getMessage().getFunctionCall().getArguments())
|
||||||
|
.name(plugin.getFunction())
|
||||||
|
.build();
|
||||||
|
messages.add(Message.builder().role(Message.Role.ASSISTANT).content("function_call").functionCall(functionCall).build());
|
||||||
|
messages.add(Message.builder().role(Message.Role.FUNCTION).name(plugin.getFunction()).content(plugin.content(tq)).build());
|
||||||
|
//设置第二次,请求的参数
|
||||||
|
chatCompletion.setFunctionCall(null);
|
||||||
|
chatCompletion.setFunctions(null);
|
||||||
|
|
||||||
|
ChatCompletionResponse chatCompletionResponse = this.chatCompletion(chatCompletion);
|
||||||
|
log.debug("自定义的方法返回值:{}", chatCompletionResponse.getChoices());
|
||||||
|
return chatCompletionResponse;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 插件问答简易版
|
||||||
|
* 默认取messages最后一个元素构建插件对话
|
||||||
|
* 默认模型:ChatCompletion.Model.GPT_3_5_TURBO_16K_0613
|
||||||
|
*
|
||||||
|
* @param messages 问答参数
|
||||||
|
* @param plugin 插件
|
||||||
|
* @param <R> 插件自定义函数的请求值
|
||||||
|
* @param <T> 插件自定义函数的返回值
|
||||||
|
* @return ChatCompletionResponse
|
||||||
|
*/
|
||||||
|
public <R extends PluginParam, T> ChatCompletionResponse chatCompletionWithPlugin(List<Message> messages, PluginAbstract<R, T> plugin) {
|
||||||
|
return chatCompletionWithPlugin(messages, ChatCompletion.Model.GPT_3_5_TURBO_16K_0613.getName(), plugin);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 插件问答简易版
|
||||||
|
* 默认取messages最后一个元素构建插件对话
|
||||||
|
*
|
||||||
|
* @param messages 问答参数
|
||||||
|
* @param model 模型
|
||||||
|
* @param plugin 插件
|
||||||
|
* @param <R> 插件自定义函数的请求值
|
||||||
|
* @param <T> 插件自定义函数的返回值
|
||||||
|
* @return ChatCompletionResponse
|
||||||
|
*/
|
||||||
|
public <R extends PluginParam, T> ChatCompletionResponse chatCompletionWithPlugin(List<Message> messages, String model, PluginAbstract<R, T> plugin) {
|
||||||
|
ChatCompletion chatCompletion = ChatCompletion.builder().messages(messages).model(model).build();
|
||||||
|
return this.chatCompletionWithPlugin(chatCompletion, plugin);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 构造
|
* 构造
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ package org.ruoyi.common.chat.openai.exception;
|
|||||||
* 2023-02-11
|
* 2023-02-11
|
||||||
*/
|
*/
|
||||||
public enum CommonError implements IError {
|
public enum CommonError implements IError {
|
||||||
|
MESSAGE_NOT_NUL(500, "Message 不能为空"),
|
||||||
API_KEYS_NOT_NUL(500, "API KEYS 不能为空"),
|
API_KEYS_NOT_NUL(500, "API KEYS 不能为空"),
|
||||||
NO_ACTIVE_API_KEYS(500, "没有可用的API KEYS"),
|
NO_ACTIVE_API_KEYS(500, "没有可用的API KEYS"),
|
||||||
SYS_ERROR(500, "系统繁忙"),
|
SYS_ERROR(500, "系统繁忙"),
|
||||||
@@ -19,8 +20,8 @@ public enum CommonError implements IError {
|
|||||||
;
|
;
|
||||||
|
|
||||||
|
|
||||||
private int code;
|
private final int code;
|
||||||
private String msg;
|
private final String msg;
|
||||||
|
|
||||||
CommonError(int code, String msg) {
|
CommonError(int code, String msg) {
|
||||||
this.code = code;
|
this.code = code;
|
||||||
|
|||||||
@@ -0,0 +1,88 @@
|
|||||||
|
package org.ruoyi.common.chat.openai.plugin;
|
||||||
|
|
||||||
|
import cn.hutool.core.collection.CollectionUtil;
|
||||||
|
import cn.hutool.json.JSONObject;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Data;
|
||||||
|
import org.ruoyi.common.chat.entity.chat.Parameters;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@AllArgsConstructor
|
||||||
|
public abstract class PluginAbstract<R extends PluginParam, T> {
|
||||||
|
|
||||||
|
private Class<?> R;
|
||||||
|
|
||||||
|
private String name;
|
||||||
|
|
||||||
|
private String function;
|
||||||
|
|
||||||
|
private String description;
|
||||||
|
|
||||||
|
private List<Arg> args;
|
||||||
|
|
||||||
|
private List<String> required;
|
||||||
|
|
||||||
|
private Parameters parameters;
|
||||||
|
|
||||||
|
public PluginAbstract(Class<?> r) {
|
||||||
|
R = r;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setRequired(List<String> required) {
|
||||||
|
if (CollectionUtil.isEmpty(required)) {
|
||||||
|
this.required = this.getArgs().stream().filter(e -> e.isRequired()).map(Arg::getName).collect(Collectors.toList());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
this.required = required;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void setRequired() {
|
||||||
|
if (CollectionUtil.isEmpty(required)) {
|
||||||
|
this.required = this.getArgs().stream().filter(e -> e.isRequired()).map(Arg::getName).collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void setParameters() {
|
||||||
|
JSONObject properties = new JSONObject();
|
||||||
|
args.forEach(e -> {
|
||||||
|
JSONObject param = new JSONObject();
|
||||||
|
param.putOpt("type", e.getType());
|
||||||
|
param.putOpt("enum", e.getEnumDictValue());
|
||||||
|
param.putOpt("description", e.getDescription());
|
||||||
|
properties.putOpt(e.getName(), param);
|
||||||
|
});
|
||||||
|
this.parameters = Parameters.builder()
|
||||||
|
.type("object")
|
||||||
|
.properties(properties)
|
||||||
|
.required(this.getRequired())
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setArgs(List<Arg> args) {
|
||||||
|
this.args = args;
|
||||||
|
setRequired();
|
||||||
|
setParameters();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public static class Arg {
|
||||||
|
private String name;
|
||||||
|
private String type;
|
||||||
|
private String description;
|
||||||
|
@JsonIgnore
|
||||||
|
private boolean enumDict;
|
||||||
|
@JsonProperty("enum")
|
||||||
|
private List<String> enumDictValue;
|
||||||
|
@JsonIgnore
|
||||||
|
private boolean required;
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract T func(R args);
|
||||||
|
|
||||||
|
public abstract String content(T t);
|
||||||
|
}
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
package org.ruoyi.common.chat.openai.plugin;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class PluginParam {
|
||||||
|
}
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
package org.ruoyi.common.chat.plugin;
|
||||||
|
|
||||||
|
import org.ruoyi.common.chat.openai.plugin.PluginAbstract;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
public class CmdPlugin extends PluginAbstract<CmdReq, CmdResp> {
|
||||||
|
|
||||||
|
public CmdPlugin(Class<?> r) {
|
||||||
|
super(r);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public CmdResp func(CmdReq args) {
|
||||||
|
try {
|
||||||
|
if("计算器".equals(args.getCmd())){
|
||||||
|
Runtime.getRuntime().exec("calc");
|
||||||
|
}else if("记事本".equals(args.getCmd())){
|
||||||
|
Runtime.getRuntime().exec("notepad");
|
||||||
|
}else if("命令行".equals(args.getCmd())){
|
||||||
|
String [] cmd={"cmd","/C","start copy exel exe2"};
|
||||||
|
Runtime.getRuntime().exec(cmd);
|
||||||
|
}
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new RuntimeException("指令执行失败");
|
||||||
|
}
|
||||||
|
CmdResp resp = new CmdResp();
|
||||||
|
resp.setResult(args.getCmd()+"指令执行成功!");
|
||||||
|
return resp;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String content(CmdResp resp) {
|
||||||
|
return resp.getResult();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
package org.ruoyi.common.chat.plugin;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
import org.ruoyi.common.chat.openai.plugin.PluginParam;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class CmdReq extends PluginParam {
|
||||||
|
/**
|
||||||
|
* 指令
|
||||||
|
*/
|
||||||
|
private String cmd;
|
||||||
|
}
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
package org.ruoyi.common.chat.plugin;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class CmdResp {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 返回结果
|
||||||
|
*/
|
||||||
|
private String result;
|
||||||
|
}
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
package org.ruoyi.common.chat.plugin;
|
||||||
|
|
||||||
|
import org.ruoyi.common.chat.openai.plugin.PluginAbstract;
|
||||||
|
|
||||||
|
import java.sql.*;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @author ageer
|
||||||
|
*/
|
||||||
|
public class SqlPlugin extends PluginAbstract<SqlReq, SqlResp> {
|
||||||
|
|
||||||
|
public SqlPlugin(Class<?> r) {
|
||||||
|
super(r);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public SqlResp func(SqlReq args) {
|
||||||
|
SqlResp resp = new SqlResp();
|
||||||
|
resp.setUserBalance(getBalance(args.getUsername()));
|
||||||
|
return resp;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String content(SqlResp resp) {
|
||||||
|
return "用户余额:"+resp.getUserBalance();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public String getBalance(String userName) {
|
||||||
|
// MySQL 8.0 以下版本 - JDBC 驱动名及数据库 URL
|
||||||
|
String JDBC_DRIVER = "com.mysql.cj.jdbc.Driver";
|
||||||
|
String DB_URL = "jdbc:mysql://43.139.70.230:3306/ry-vue";
|
||||||
|
// 数据库的用户名与密码,需要根据自己的设置
|
||||||
|
String USER = "ry-vue";
|
||||||
|
String PASS = "BXZiGsY35K523Xfx";
|
||||||
|
Connection conn = null;
|
||||||
|
Statement stmt = null;
|
||||||
|
String balance = "0.1";
|
||||||
|
|
||||||
|
try{
|
||||||
|
// 注册 JDBC 驱动
|
||||||
|
Class.forName(JDBC_DRIVER);
|
||||||
|
|
||||||
|
// 打开链接
|
||||||
|
System.out.println("连接数据库...");
|
||||||
|
conn = DriverManager.getConnection(DB_URL,USER,PASS);
|
||||||
|
|
||||||
|
// 执行查询
|
||||||
|
System.out.println(" 实例化Statement对象...");
|
||||||
|
stmt = conn.createStatement();
|
||||||
|
String sql;
|
||||||
|
sql = "SELECT user_balance FROM sys_user where user_name ='" + userName + "'";
|
||||||
|
ResultSet rs = stmt.executeQuery(sql);
|
||||||
|
// 展开结果集数据库
|
||||||
|
while(rs.next()){
|
||||||
|
// 通过字段检索
|
||||||
|
balance = rs.getString("user_balance");
|
||||||
|
// 输出数据
|
||||||
|
System.out.print("余额: " + balance);
|
||||||
|
System.out.print("\n");
|
||||||
|
}
|
||||||
|
// 完成后关闭
|
||||||
|
rs.close();
|
||||||
|
stmt.close();
|
||||||
|
conn.close();
|
||||||
|
}catch(SQLException se){
|
||||||
|
// 处理 JDBC 错误
|
||||||
|
se.printStackTrace();
|
||||||
|
}catch(Exception e){
|
||||||
|
// 处理 Class.forName 错误
|
||||||
|
e.printStackTrace();
|
||||||
|
}finally{
|
||||||
|
// 关闭资源
|
||||||
|
try{
|
||||||
|
if(stmt!=null) stmt.close();
|
||||||
|
}catch(SQLException se2){
|
||||||
|
}// 什么都不做
|
||||||
|
try{
|
||||||
|
if(conn!=null) conn.close();
|
||||||
|
}catch(SQLException se){
|
||||||
|
se.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return balance;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
package org.ruoyi.common.chat.plugin;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
import org.ruoyi.common.chat.openai.plugin.PluginParam;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class SqlReq extends PluginParam {
|
||||||
|
/**
|
||||||
|
* 用户名称
|
||||||
|
*/
|
||||||
|
private String username;
|
||||||
|
}
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
package org.ruoyi.common.chat.plugin;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class SqlResp {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 用户余额
|
||||||
|
*/
|
||||||
|
private String userBalance;
|
||||||
|
}
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
package org.ruoyi.common.chat.sse;
|
||||||
|
|
||||||
|
import lombok.SneakyThrows;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import okhttp3.Response;
|
||||||
|
import okhttp3.ResponseBody;
|
||||||
|
import okhttp3.sse.EventSource;
|
||||||
|
import okhttp3.sse.EventSourceListener;
|
||||||
|
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 描述: sse
|
||||||
|
*
|
||||||
|
* @author https:www.unfbx.com
|
||||||
|
* 2023-02-28
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class ConsoleEventSourceListener extends EventSourceListener {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onOpen(EventSource eventSource, Response response) {
|
||||||
|
log.info("OpenAI建立sse连接...");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onEvent(EventSource eventSource, String id, String type, String data) {
|
||||||
|
log.info("OpenAI返回数据:{}", data);
|
||||||
|
if ("[DONE]".equals(data)) {
|
||||||
|
log.info("OpenAI返回数据结束了");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onClosed(EventSource eventSource) {
|
||||||
|
log.info("OpenAI关闭sse连接...");
|
||||||
|
}
|
||||||
|
|
||||||
|
@SneakyThrows
|
||||||
|
@Override
|
||||||
|
public void onFailure(EventSource eventSource, Throwable t, Response response) {
|
||||||
|
if(Objects.isNull(response)){
|
||||||
|
log.error("OpenAI sse连接异常:{}", t);
|
||||||
|
eventSource.cancel();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
ResponseBody body = response.body();
|
||||||
|
if (Objects.nonNull(body)) {
|
||||||
|
log.error("OpenAI sse连接异常data:{},异常:{}", body.string(), t);
|
||||||
|
} else {
|
||||||
|
log.error("OpenAI sse连接异常data:{},异常:{}", response, t);
|
||||||
|
}
|
||||||
|
eventSource.cancel();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
package org.ruoyi.common.chat.sse;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
|
import okhttp3.sse.EventSourceListener;
|
||||||
|
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
|
||||||
|
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
|
||||||
|
import org.ruoyi.common.chat.openai.plugin.PluginAbstract;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 描述: 插件开发返回信息收集sse监听器
|
||||||
|
*
|
||||||
|
* @author https:www.unfbx.com
|
||||||
|
* 2023-08-18
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class DefaultPluginListener extends PluginListener {
|
||||||
|
|
||||||
|
public DefaultPluginListener(OpenAiStreamClient client, EventSourceListener eventSourceListener, PluginAbstract plugin, ChatCompletion chatCompletion) {
|
||||||
|
super(client, eventSourceListener, plugin, chatCompletion);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,126 @@
|
|||||||
|
package org.ruoyi.common.chat.sse;
|
||||||
|
|
||||||
|
import cn.hutool.json.JSONUtil;
|
||||||
|
import lombok.SneakyThrows;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import okhttp3.Response;
|
||||||
|
import okhttp3.ResponseBody;
|
||||||
|
import okhttp3.sse.EventSource;
|
||||||
|
import okhttp3.sse.EventSourceListener;
|
||||||
|
import org.jetbrains.annotations.NotNull;
|
||||||
|
import org.ruoyi.common.chat.constant.OpenAIConst;
|
||||||
|
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
|
||||||
|
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
|
||||||
|
import org.ruoyi.common.chat.entity.chat.FunctionCall;
|
||||||
|
import org.ruoyi.common.chat.entity.chat.Message;
|
||||||
|
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
|
||||||
|
import org.ruoyi.common.chat.openai.plugin.PluginAbstract;
|
||||||
|
import org.ruoyi.common.chat.openai.plugin.PluginParam;
|
||||||
|
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 描述: 插件开发返回信息收集sse监听器
|
||||||
|
*
|
||||||
|
* @author https:www.unfbx.com
|
||||||
|
* 2023-08-18
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public abstract class PluginListener<R extends PluginParam, T> extends EventSourceListener {
|
||||||
|
/**
|
||||||
|
* openAi插件构建的参数
|
||||||
|
*/
|
||||||
|
private String arguments = "";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取openAi插件构建的参数
|
||||||
|
*
|
||||||
|
* @return arguments
|
||||||
|
*/
|
||||||
|
private String getArguments() {
|
||||||
|
return this.arguments;
|
||||||
|
}
|
||||||
|
|
||||||
|
private OpenAiStreamClient client;
|
||||||
|
private EventSourceListener eventSourceListener;
|
||||||
|
private PluginAbstract<R, T> plugin;
|
||||||
|
private ChatCompletion chatCompletion;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 构造方法必备四个元素
|
||||||
|
*
|
||||||
|
* @param client OpenAiStreamClient
|
||||||
|
* @param eventSourceListener 处理真实第二次sse请求的自定义监听
|
||||||
|
* @param plugin 插件信息
|
||||||
|
* @param chatCompletion 请求参数
|
||||||
|
*/
|
||||||
|
public PluginListener(OpenAiStreamClient client, EventSourceListener eventSourceListener, PluginAbstract<R, T> plugin, ChatCompletion chatCompletion) {
|
||||||
|
this.client = client;
|
||||||
|
this.eventSourceListener = eventSourceListener;
|
||||||
|
this.plugin = plugin;
|
||||||
|
this.chatCompletion = chatCompletion;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* sse关闭后处理,第二次请求方法
|
||||||
|
*/
|
||||||
|
public void onClosedAfter() {
|
||||||
|
log.debug("构造的方法值:{}", getArguments());
|
||||||
|
|
||||||
|
R realFunctionParam = (R) JSONUtil.toBean(getArguments(), plugin.getR());
|
||||||
|
T tq = plugin.func(realFunctionParam);
|
||||||
|
|
||||||
|
FunctionCall functionCall = FunctionCall.builder()
|
||||||
|
.arguments(getArguments())
|
||||||
|
.name(plugin.getFunction())
|
||||||
|
.build();
|
||||||
|
chatCompletion.getMessages().add(Message.builder().role(Message.Role.ASSISTANT).content("function_call").functionCall(functionCall).build());
|
||||||
|
chatCompletion.getMessages().add(Message.builder().role(Message.Role.FUNCTION).name(plugin.getFunction()).content(plugin.content(tq)).build());
|
||||||
|
//设置第二次,请求的参数
|
||||||
|
chatCompletion.setFunctionCall(null);
|
||||||
|
chatCompletion.setFunctions(null);
|
||||||
|
client.streamChatCompletion(chatCompletion, eventSourceListener);
|
||||||
|
}
|
||||||
|
|
||||||
|
@SneakyThrows
|
||||||
|
@Override
|
||||||
|
public final void onEvent(@NotNull EventSource eventSource, String id, String type, String data) {
|
||||||
|
log.debug("插件开发返回信息收集sse监听器返回数据:{}", data);
|
||||||
|
if ("[DONE]".equals(data)) {
|
||||||
|
log.debug("插件开发返回信息收集sse监听器返回数据结束了");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
ChatCompletionResponse chatCompletionResponse = JSONUtil.toBean(data, ChatCompletionResponse.class);
|
||||||
|
if (Objects.nonNull(chatCompletionResponse.getChoices().get(0).getDelta().getFunctionCall())) {
|
||||||
|
this.arguments += chatCompletionResponse.getChoices().get(0).getDelta().getFunctionCall().getArguments();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public final void onClosed(EventSource eventSource) {
|
||||||
|
log.debug("插件开发返回信息收集sse监听器关闭连接...");
|
||||||
|
this.onClosedAfter();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onOpen(EventSource eventSource, Response response) {
|
||||||
|
log.debug("插件开发返回信息收集sse监听器建立连接...");
|
||||||
|
}
|
||||||
|
|
||||||
|
@SneakyThrows
|
||||||
|
@Override
|
||||||
|
public void onFailure(EventSource eventSource, Throwable t, Response response) {
|
||||||
|
if (Objects.isNull(response)) {
|
||||||
|
log.error("插件开发返回信息收集sse监听器,连接异常:{}", t);
|
||||||
|
eventSource.cancel();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
ResponseBody body = response.body();
|
||||||
|
if (Objects.nonNull(body)) {
|
||||||
|
log.error("插件开发返回信息收集sse监听器,连接异常data:{},异常:{}", body.string(), t);
|
||||||
|
} else {
|
||||||
|
log.error("插件开发返回信息收集sse监听器,连接异常data:{},异常:{}", response, t);
|
||||||
|
}
|
||||||
|
eventSource.cancel();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -15,4 +15,11 @@ public interface UserService {
|
|||||||
*/
|
*/
|
||||||
String selectUserNameById(Long userId);
|
String selectUserNameById(Long userId);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 通过用户名称查询余额
|
||||||
|
*
|
||||||
|
* @param userName
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
String selectUserByName(String userName);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -83,6 +83,9 @@
|
|||||||
<scope>runtime</scope>
|
<scope>runtime</scope>
|
||||||
<optional>true</optional>
|
<optional>true</optional>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.mysql</groupId>
|
<groupId>com.mysql</groupId>
|
||||||
<artifactId>mysql-connector-j</artifactId>
|
<artifactId>mysql-connector-j</artifactId>
|
||||||
|
|||||||
@@ -41,13 +41,13 @@ public class MilvusVectorStore implements VectorStore{
|
|||||||
@Resource
|
@Resource
|
||||||
private ConfigService configService;
|
private ConfigService configService;
|
||||||
|
|
||||||
// @PostConstruct
|
@PostConstruct
|
||||||
public void loadConfig() {
|
public void loadConfig() {
|
||||||
this.dimension = Integer.parseInt(configService.getConfigValue("milvus", "dimension"));
|
this.dimension = Integer.parseInt(configService.getConfigValue("milvus", "dimension"));
|
||||||
this.collectionName = configService.getConfigValue("milvus", "collection");
|
this.collectionName = configService.getConfigValue("milvus", "collection");
|
||||||
}
|
}
|
||||||
|
|
||||||
//@PostConstruct
|
@PostConstruct
|
||||||
public void init(){
|
public void init(){
|
||||||
String milvusHost = configService.getConfigValue("milvus", "host");
|
String milvusHost = configService.getConfigValue("milvus", "host");
|
||||||
String milvausPort = configService.getConfigValue("milvus", "port");
|
String milvausPort = configService.getConfigValue("milvus", "port");
|
||||||
|
|||||||
@@ -6,11 +6,16 @@ import java.util.List;
|
|||||||
* 向量存储
|
* 向量存储
|
||||||
*/
|
*/
|
||||||
public interface VectorStore {
|
public interface VectorStore {
|
||||||
void storeEmbeddings(List<String> chunkList,List<List<Double>> vectorList, String kid, String docId,List<String> fidList);
|
|
||||||
void removeByDocId(String kid,String docId);
|
void storeEmbeddings(List<String> chunkList, List<List<Double>> vectorList, String kid, String docId, List<String> fidList);
|
||||||
|
|
||||||
|
void removeByDocId(String kid, String docId);
|
||||||
|
|
||||||
void removeByKid(String kid);
|
void removeByKid(String kid);
|
||||||
List<String> nearest(List<Double> queryVector,String kid);
|
|
||||||
List<String> nearest(String query,String kid);
|
List<String> nearest(List<Double> queryVector, String kid);
|
||||||
|
|
||||||
|
List<String> nearest(String query, String kid);
|
||||||
|
|
||||||
void newSchema(String kid);
|
void newSchema(String kid);
|
||||||
|
|
||||||
|
|||||||
@@ -1,32 +1,35 @@
|
|||||||
package org.ruoyi.knowledge.chain.vectorstore;
|
package org.ruoyi.knowledge.chain.vectorstore;
|
||||||
|
|
||||||
|
import jakarta.annotation.Resource;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
import org.ruoyi.knowledge.domain.vo.KnowledgeInfoVo;
|
||||||
|
import org.ruoyi.knowledge.mapper.KnowledgeInfoMapper;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
@Component
|
@Component
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class VectorStoreFactory {
|
public class VectorStoreFactory {
|
||||||
|
|
||||||
private final String type = "weaviate";
|
|
||||||
|
|
||||||
private final WeaviateVectorStore weaviateVectorStore;
|
private final WeaviateVectorStore weaviateVectorStore;
|
||||||
|
|
||||||
private final MilvusVectorStore milvusVectorStore;
|
private final MilvusVectorStore milvusVectorStore;
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private KnowledgeInfoMapper knowledgeInfoMapper;
|
||||||
|
|
||||||
public VectorStoreFactory(WeaviateVectorStore weaviateVectorStore, MilvusVectorStore milvusVectorStore) {
|
public VectorStoreFactory(WeaviateVectorStore weaviateVectorStore, MilvusVectorStore milvusVectorStore) {
|
||||||
this.weaviateVectorStore = weaviateVectorStore;
|
this.weaviateVectorStore = weaviateVectorStore;
|
||||||
this.milvusVectorStore = milvusVectorStore;
|
this.milvusVectorStore = milvusVectorStore;
|
||||||
}
|
}
|
||||||
|
|
||||||
public VectorStore getVectorStore(String kid){
|
public VectorStore getVectorStore(String kid){
|
||||||
// if ("weaviate".equals(type)){
|
KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoMapper.selectVoById(Long.valueOf(kid));
|
||||||
// return weaviateVectorStore;
|
String vectorModel = knowledgeInfoVo.getVector();
|
||||||
// }else if ("milvus".equals(type)){
|
if ("weaviate".equals(vectorModel)){
|
||||||
// return milvusVectorStore;
|
return weaviateVectorStore;
|
||||||
// }
|
}else if ("milvus".equals(vectorModel)){
|
||||||
//
|
return milvusVectorStore;
|
||||||
// return null;
|
}
|
||||||
return weaviateVectorStore;
|
return null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,19 +11,20 @@ import java.util.List;
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
@Primary
|
@Primary
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
public class VectorStoreWrapper implements VectorStore{
|
public class VectorStoreWrapper implements VectorStore {
|
||||||
|
|
||||||
private final VectorStoreFactory vectorStoreFactory;
|
private final VectorStoreFactory vectorStoreFactory;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void storeEmbeddings(List<String> chunkList, List<List<Double>> vectorList, String kid, String docId, List<String> fidList) {
|
public void storeEmbeddings(List<String> chunkList, List<List<Double>> vectorList, String kid, String docId, List<String> fidList) {
|
||||||
VectorStore vectorStore = vectorStoreFactory.getVectorStore(kid);
|
VectorStore vectorStore = vectorStoreFactory.getVectorStore(kid);
|
||||||
vectorStore.storeEmbeddings(chunkList, vectorList, kid, docId, fidList);
|
vectorStore.storeEmbeddings(chunkList, vectorList, kid, docId, fidList);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void removeByDocId(String kid, String docId) {
|
public void removeByDocId(String kid, String docId) {
|
||||||
VectorStore vectorStore = vectorStoreFactory.getVectorStore(kid);
|
VectorStore vectorStore = vectorStoreFactory.getVectorStore(kid);
|
||||||
vectorStore.removeByDocId(kid,docId);
|
vectorStore.removeByDocId(kid, docId);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -35,7 +36,7 @@ public class VectorStoreWrapper implements VectorStore{
|
|||||||
@Override
|
@Override
|
||||||
public List<String> nearest(List<Double> queryVector, String kid) {
|
public List<String> nearest(List<Double> queryVector, String kid) {
|
||||||
VectorStore vectorStore = vectorStoreFactory.getVectorStore(kid);
|
VectorStore vectorStore = vectorStoreFactory.getVectorStore(kid);
|
||||||
return vectorStore.nearest(queryVector,kid);
|
return vectorStore.nearest(queryVector, kid);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ public class KnowledgeInfo implements Serializable {
|
|||||||
private String kname;
|
private String kname;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 知识库名称
|
* 是否公开知识库(0 否 1是)
|
||||||
*/
|
*/
|
||||||
private String share;
|
private String share;
|
||||||
|
|
||||||
|
|||||||
@@ -49,8 +49,6 @@ public interface IKnowledgeAttachService {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* 删除知识附件
|
* 删除知识附件
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
void removeKnowledgeAttach(String kid);
|
void removeKnowledgeAttach(String docId);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,13 +2,10 @@ package org.ruoyi.knowledge.service;
|
|||||||
|
|
||||||
import org.ruoyi.common.mybatis.core.page.PageQuery;
|
import org.ruoyi.common.mybatis.core.page.PageQuery;
|
||||||
import org.ruoyi.common.mybatis.core.page.TableDataInfo;
|
import org.ruoyi.common.mybatis.core.page.TableDataInfo;
|
||||||
import org.ruoyi.knowledge.domain.KnowledgeAttach;
|
|
||||||
import org.ruoyi.knowledge.domain.bo.KnowledgeAttachBo;
|
|
||||||
import org.ruoyi.knowledge.domain.bo.KnowledgeInfoBo;
|
import org.ruoyi.knowledge.domain.bo.KnowledgeInfoBo;
|
||||||
import org.ruoyi.knowledge.domain.req.KnowledgeInfoUploadRequest;
|
import org.ruoyi.knowledge.domain.req.KnowledgeInfoUploadRequest;
|
||||||
import org.ruoyi.knowledge.domain.vo.KnowledgeInfoVo;
|
import org.ruoyi.knowledge.domain.vo.KnowledgeInfoVo;
|
||||||
|
|
||||||
import java.util.Collection;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -40,7 +37,6 @@ public interface IKnowledgeInfoService {
|
|||||||
*/
|
*/
|
||||||
Boolean updateByBo(KnowledgeInfoBo bo);
|
Boolean updateByBo(KnowledgeInfoBo bo);
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 新增知识库
|
* 新增知识库
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -40,8 +40,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<Double> getQueryVector(String query, String kid) {
|
public List<Double> getQueryVector(String query, String kid) {
|
||||||
List<Double> queryVector = vectorization.singleVectorization(query,kid);
|
return vectorization.singleVectorization(query,kid);
|
||||||
return queryVector;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -126,13 +126,9 @@ public class KnowledgeAttachServiceImpl implements IKnowledgeAttachService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void removeKnowledgeAttach(String kid) {
|
public void removeKnowledgeAttach(String docId) {
|
||||||
LoginUser loginUser = LoginHelper.getLoginUser();
|
|
||||||
Map<String,Object> map = new HashMap<>();
|
Map<String,Object> map = new HashMap<>();
|
||||||
map.put("kid",kid);
|
map.put("doc_id",docId);
|
||||||
List<KnowledgeInfoVo> knowledgeInfoList = knowledgeInfoMapper.selectVoByMap(map);
|
|
||||||
knowledgeInfoService.check(knowledgeInfoList);
|
|
||||||
|
|
||||||
baseMapper.deleteByMap(map);
|
baseMapper.deleteByMap(map);
|
||||||
fragmentMapper.deleteByMap(map);
|
fragmentMapper.deleteByMap(map);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,14 +1,10 @@
|
|||||||
package org.ruoyi.knowledge.service.impl;
|
package org.ruoyi.knowledge.service.impl;
|
||||||
|
|
||||||
|
import cn.hutool.core.collection.CollUtil;
|
||||||
import cn.hutool.core.util.RandomUtil;
|
import cn.hutool.core.util.RandomUtil;
|
||||||
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
||||||
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
|
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
|
||||||
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
||||||
import io.github.ollama4j.OllamaAPI;
|
|
||||||
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
|
||||||
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
|
|
||||||
import io.github.ollama4j.models.chat.OllamaChatRequestModel;
|
|
||||||
import io.github.ollama4j.models.chat.OllamaChatResult;
|
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import org.ruoyi.common.core.domain.model.LoginUser;
|
import org.ruoyi.common.core.domain.model.LoginUser;
|
||||||
import org.ruoyi.common.core.utils.MapstructUtils;
|
import org.ruoyi.common.core.utils.MapstructUtils;
|
||||||
@@ -30,6 +26,7 @@ import org.ruoyi.knowledge.mapper.KnowledgeInfoMapper;
|
|||||||
import org.ruoyi.knowledge.service.EmbeddingService;
|
import org.ruoyi.knowledge.service.EmbeddingService;
|
||||||
import org.ruoyi.knowledge.service.IKnowledgeInfoService;
|
import org.ruoyi.knowledge.service.IKnowledgeInfoService;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
import org.springframework.web.multipart.MultipartFile;
|
import org.springframework.web.multipart.MultipartFile;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
@@ -41,8 +38,8 @@ import java.util.*;
|
|||||||
* @author Lion Li
|
* @author Lion Li
|
||||||
* @date 2024-10-21
|
* @date 2024-10-21
|
||||||
*/
|
*/
|
||||||
@RequiredArgsConstructor
|
|
||||||
@Service
|
@Service
|
||||||
|
@RequiredArgsConstructor
|
||||||
public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
|
public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
|
||||||
|
|
||||||
private final KnowledgeInfoMapper baseMapper;
|
private final KnowledgeInfoMapper baseMapper;
|
||||||
@@ -110,9 +107,8 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
|
|||||||
//TODO 做一些数据校验,如唯一约束
|
//TODO 做一些数据校验,如唯一约束
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@Transactional(rollbackFor = Exception.class)
|
||||||
public void saveOne(KnowledgeInfoBo bo) {
|
public void saveOne(KnowledgeInfoBo bo) {
|
||||||
KnowledgeInfo knowledgeInfo = MapstructUtils.convert(bo, KnowledgeInfo.class);
|
KnowledgeInfo knowledgeInfo = MapstructUtils.convert(bo, KnowledgeInfo.class);
|
||||||
if (StringUtils.isBlank(bo.getKid())){
|
if (StringUtils.isBlank(bo.getKid())){
|
||||||
@@ -122,7 +118,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
|
|||||||
knowledgeInfo.setUid(LoginHelper.getLoginUser().getUserId());
|
knowledgeInfo.setUid(LoginHelper.getLoginUser().getUserId());
|
||||||
}
|
}
|
||||||
baseMapper.insert(knowledgeInfo);
|
baseMapper.insert(knowledgeInfo);
|
||||||
embeddingService.createSchema(kid);
|
embeddingService.createSchema(String.valueOf(knowledgeInfo.getId()));
|
||||||
}else {
|
}else {
|
||||||
baseMapper.updateById(knowledgeInfo);
|
baseMapper.updateById(knowledgeInfo);
|
||||||
}
|
}
|
||||||
@@ -148,19 +144,23 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
|
|||||||
try {
|
try {
|
||||||
content = resourceLoader.getContent(file.getInputStream());
|
content = resourceLoader.getContent(file.getInputStream());
|
||||||
chunkList = resourceLoader.getChunkList(content, kid);
|
chunkList = resourceLoader.getChunkList(content, kid);
|
||||||
for (int i = 0; i < chunkList.size(); i++) {
|
List<KnowledgeFragment> knowledgeFragmentList = new ArrayList<>();
|
||||||
String fid = RandomUtil.randomString(16);
|
if (CollUtil.isNotEmpty(chunkList)) {
|
||||||
fids.add(fid);
|
for (int i = 0; i < chunkList.size(); i++) {
|
||||||
KnowledgeFragment knowledgeFragment = new KnowledgeFragment();
|
String fid = RandomUtil.randomString(16);
|
||||||
knowledgeFragment.setKid(kid);
|
fids.add(fid);
|
||||||
knowledgeFragment.setDocId(docId);
|
KnowledgeFragment knowledgeFragment = new KnowledgeFragment();
|
||||||
knowledgeFragment.setFid(fid);
|
knowledgeFragment.setKid(kid);
|
||||||
knowledgeFragment.setIdx(i);
|
knowledgeFragment.setDocId(docId);
|
||||||
// String text = convertTextBlockToPretrainData(chunkList.get(i));
|
knowledgeFragment.setFid(fid);
|
||||||
knowledgeFragment.setContent(chunkList.get(i));
|
knowledgeFragment.setIdx(i);
|
||||||
knowledgeFragment.setCreateTime(new Date());
|
// String text = convertTextBlockToPretrainData(chunkList.get(i));
|
||||||
fragmentMapper.insert(knowledgeFragment);
|
knowledgeFragment.setContent(chunkList.get(i));
|
||||||
|
knowledgeFragment.setCreateTime(new Date());
|
||||||
|
knowledgeFragmentList.add(knowledgeFragment);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
fragmentMapper.insertBatch(knowledgeFragmentList);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
e.printStackTrace();
|
e.printStackTrace();
|
||||||
}
|
}
|
||||||
@@ -171,19 +171,21 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@Transactional(rollbackFor = Exception.class)
|
||||||
public void removeKnowledge(String id) {
|
public void removeKnowledge(String id) {
|
||||||
|
|
||||||
Map<String,Object> map = new HashMap<>();
|
Map<String,Object> map = new HashMap<>();
|
||||||
map.put("kid",id);
|
map.put("kid",id);
|
||||||
List<KnowledgeInfoVo> knowledgeInfoList = baseMapper.selectVoByMap(map);
|
List<KnowledgeInfoVo> knowledgeInfoList = baseMapper.selectVoByMap(map);
|
||||||
check(knowledgeInfoList);
|
check(knowledgeInfoList);
|
||||||
// 删除知识库
|
// 删除向量库信息
|
||||||
baseMapper.deleteByMap(map);
|
knowledgeInfoList.forEach(knowledgeInfoVo -> {
|
||||||
|
embeddingService.removeByKid(String.valueOf(knowledgeInfoVo.getId()));
|
||||||
|
});
|
||||||
// 删除附件和知识片段
|
// 删除附件和知识片段
|
||||||
fragmentMapper.deleteByMap(map);
|
fragmentMapper.deleteByMap(map);
|
||||||
attachMapper.deleteByMap(map);
|
attachMapper.deleteByMap(map);
|
||||||
// 删除向量库信息
|
// 删除知识库
|
||||||
embeddingService.removeByKid(id);
|
baseMapper.deleteByMap(map);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import org.ruoyi.common.log.enums.BusinessType;
|
|||||||
import org.ruoyi.common.satoken.utils.LoginHelper;
|
import org.ruoyi.common.satoken.utils.LoginHelper;
|
||||||
import org.ruoyi.common.web.core.BaseController;
|
import org.ruoyi.common.web.core.BaseController;
|
||||||
import org.ruoyi.system.domain.bo.SysUserBo;
|
import org.ruoyi.system.domain.bo.SysUserBo;
|
||||||
|
import org.ruoyi.system.domain.bo.SysUserPasswordBo;
|
||||||
import org.ruoyi.system.domain.bo.SysUserProfileBo;
|
import org.ruoyi.system.domain.bo.SysUserProfileBo;
|
||||||
import org.ruoyi.system.domain.vo.AvatarVo;
|
import org.ruoyi.system.domain.vo.AvatarVo;
|
||||||
import org.ruoyi.system.domain.vo.ProfileVo;
|
import org.ruoyi.system.domain.vo.ProfileVo;
|
||||||
@@ -75,23 +76,20 @@ public class SysProfileController extends BaseController {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* 重置密码
|
* 重置密码
|
||||||
*
|
|
||||||
* @param newPassword 旧密码
|
|
||||||
* @param oldPassword 新密码
|
|
||||||
*/
|
*/
|
||||||
@Log(title = "个人信息", businessType = BusinessType.UPDATE)
|
@Log(title = "个人信息", businessType = BusinessType.UPDATE)
|
||||||
@PutMapping("/updatePwd")
|
@PutMapping("/updatePwd")
|
||||||
public R<Void> updatePwd(String oldPassword, String newPassword) {
|
public R<Void> updatePwd(@Validated @RequestBody SysUserPasswordBo bo) {
|
||||||
SysUserVo user = userService.selectUserById(LoginHelper.getUserId());
|
SysUserVo user = userService.selectUserById(LoginHelper.getUserId());
|
||||||
String password = user.getPassword();
|
String password = user.getPassword();
|
||||||
if (!BCrypt.checkpw(oldPassword, password)) {
|
if (!BCrypt.checkpw(bo.getOldPassword(), password)) {
|
||||||
return R.fail("修改密码失败,旧密码错误");
|
return R.fail("修改密码失败,旧密码错误");
|
||||||
}
|
}
|
||||||
if (BCrypt.checkpw(newPassword, password)) {
|
if (BCrypt.checkpw(bo.getNewPassword(), password)) {
|
||||||
return R.fail("新密码不能与旧密码相同");
|
return R.fail("新密码不能与旧密码相同");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (userService.resetUserPwd(user.getUserId(), BCrypt.hashpw(newPassword)) > 0) {
|
if (userService.resetUserPwd(user.getUserId(), BCrypt.hashpw(bo.getNewPassword())) > 0) {
|
||||||
return R.ok();
|
return R.ok();
|
||||||
}
|
}
|
||||||
return R.fail("修改密码异常,请联系管理员");
|
return R.fail("修改密码异常,请联系管理员");
|
||||||
|
|||||||
@@ -0,0 +1,32 @@
|
|||||||
|
package org.ruoyi.system.domain.bo;
|
||||||
|
import jakarta.validation.constraints.NotBlank;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.io.Serial;
|
||||||
|
import java.io.Serializable;
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 描述:用户密码修改bo
|
||||||
|
*
|
||||||
|
* @author ageerle@163.com
|
||||||
|
* date 2025/3/9
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
public class SysUserPasswordBo implements Serializable {
|
||||||
|
|
||||||
|
@Serial
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 旧密码
|
||||||
|
*/
|
||||||
|
@NotBlank(message = "旧密码不能为空")
|
||||||
|
private String oldPassword;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 新密码
|
||||||
|
*/
|
||||||
|
@NotBlank(message = "新密码不能为空")
|
||||||
|
private String newPassword;
|
||||||
|
}
|
||||||
@@ -65,7 +65,7 @@ public class SSEEventSourceListener extends EventSourceListener {
|
|||||||
@Override
|
@Override
|
||||||
public void onEvent(@NotNull EventSource eventSource, String id, String type, String data) {
|
public void onEvent(@NotNull EventSource eventSource, String id, String type, String data) {
|
||||||
try {
|
try {
|
||||||
if (data.equals("[DONE]")) {
|
if ("[DONE]".equals(data)) {
|
||||||
//成功响应
|
//成功响应
|
||||||
emitter.complete();
|
emitter.complete();
|
||||||
if(StringUtils.isNotEmpty(modelName)){
|
if(StringUtils.isNotEmpty(modelName)){
|
||||||
|
|||||||
@@ -0,0 +1,212 @@
|
|||||||
|
package org.ruoyi.system.plugin;
|
||||||
|
|
||||||
|
|
||||||
|
import cn.hutool.json.JSONUtil;
|
||||||
|
import com.alibaba.fastjson.JSONObject;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import okhttp3.OkHttpClient;
|
||||||
|
import okhttp3.logging.HttpLoggingInterceptor;
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.ruoyi.common.chat.demo.ConsoleEventSourceListenerV3;
|
||||||
|
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
|
||||||
|
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
|
||||||
|
import org.ruoyi.common.chat.entity.chat.Message;
|
||||||
|
import org.ruoyi.common.chat.entity.chat.Parameters;
|
||||||
|
import org.ruoyi.common.chat.entity.chat.tool.ToolCallFunction;
|
||||||
|
import org.ruoyi.common.chat.entity.chat.tool.ToolCalls;
|
||||||
|
import org.ruoyi.common.chat.entity.chat.tool.Tools;
|
||||||
|
import org.ruoyi.common.chat.entity.chat.tool.ToolsFunction;
|
||||||
|
import org.ruoyi.common.chat.openai.OpenAiClient;
|
||||||
|
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
|
||||||
|
import org.ruoyi.common.chat.openai.function.KeyRandomStrategy;
|
||||||
|
import org.ruoyi.common.chat.openai.interceptor.DynamicKeyOpenAiAuthInterceptor;
|
||||||
|
import org.ruoyi.common.chat.openai.interceptor.OpenAILogger;
|
||||||
|
import org.ruoyi.common.chat.openai.interceptor.OpenAiResponseInterceptor;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.CountDownLatch;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
|
public class WebSearchPlugin {
|
||||||
|
|
||||||
|
private OpenAiClient openAiClient;
|
||||||
|
private OpenAiStreamClient openAiStreamClient;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void before() {
|
||||||
|
//可以为null
|
||||||
|
// Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890));
|
||||||
|
HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor(new OpenAILogger());
|
||||||
|
//!!!!千万别再生产或者测试环境打开BODY级别日志!!!!
|
||||||
|
//!!!生产或者测试环境建议设置为这三种级别:NONE,BASIC,HEADERS,!!!
|
||||||
|
httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS);
|
||||||
|
OkHttpClient okHttpClient = new OkHttpClient
|
||||||
|
.Builder()
|
||||||
|
// .proxy(proxy)
|
||||||
|
.addInterceptor(httpLoggingInterceptor)
|
||||||
|
.addInterceptor(new OpenAiResponseInterceptor())
|
||||||
|
.connectTimeout(10, TimeUnit.SECONDS)
|
||||||
|
.writeTimeout(30, TimeUnit.SECONDS)
|
||||||
|
.readTimeout(30, TimeUnit.SECONDS)
|
||||||
|
.build();
|
||||||
|
openAiClient = OpenAiClient.builder()
|
||||||
|
//支持多key传入,请求时候随机选择
|
||||||
|
.apiKey(Arrays.asList("xx"))
|
||||||
|
//自定义key的获取策略:默认KeyRandomStrategy
|
||||||
|
//.keyStrategy(new KeyRandomStrategy())
|
||||||
|
.keyStrategy(new KeyRandomStrategy())
|
||||||
|
.okHttpClient(okHttpClient)
|
||||||
|
//自己做了代理就传代理地址,没有可不不传,(关注公众号回复:openai ,获取免费的测试代理地址)
|
||||||
|
.apiHost("https://open.bigmodel.cn/")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
openAiStreamClient = OpenAiStreamClient.builder()
|
||||||
|
//支持多key传入,请求时候随机选择
|
||||||
|
.apiKey(Arrays.asList("xx"))
|
||||||
|
//自定义key的获取策略:默认KeyRandomStrategy
|
||||||
|
.keyStrategy(new KeyRandomStrategy())
|
||||||
|
.authInterceptor(new DynamicKeyOpenAiAuthInterceptor())
|
||||||
|
.okHttpClient(okHttpClient)
|
||||||
|
//自己做了代理就传代理地址,没有可不不传,(关注公众号回复:openai ,获取免费的测试代理地址)
|
||||||
|
.apiHost("https://open.bigmodel.cn/")
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void test() {
|
||||||
|
Message message = Message.builder().role(Message.Role.USER).content("今天武汉天气怎么样").build();
|
||||||
|
ChatCompletion chatCompletion = ChatCompletion
|
||||||
|
.builder()
|
||||||
|
.messages(Collections.singletonList(message))
|
||||||
|
// .tools(Collections.singletonList(tools))
|
||||||
|
.model("web-search-pro")
|
||||||
|
.build();
|
||||||
|
ChatCompletionResponse chatCompletionResponse = openAiStreamClient.chatCompletion(chatCompletion);
|
||||||
|
|
||||||
|
System.out.printf("chatCompletionResponse=%s\n", JSONUtil.toJsonStr(chatCompletionResponse));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void streamToolsChat() {
|
||||||
|
|
||||||
|
CountDownLatch countDownLatch = new CountDownLatch(1);
|
||||||
|
ConsoleEventSourceListenerV3 eventSourceListener = new ConsoleEventSourceListenerV3(countDownLatch);
|
||||||
|
|
||||||
|
Message message = Message.builder().role(Message.Role.USER).content("给我输出一个长度为2的中文词语,并解释下词语对应物品的用途").build();
|
||||||
|
//属性一
|
||||||
|
JSONObject wordLength = new JSONObject();
|
||||||
|
wordLength.put("type", "number");
|
||||||
|
wordLength.put("description", "词语的长度");
|
||||||
|
//属性二
|
||||||
|
JSONObject language = new JSONObject();
|
||||||
|
language.put("type", "string");
|
||||||
|
language.put("enum", Arrays.asList("zh", "en"));
|
||||||
|
language.put("description", "语言类型,例如:zh代表中文、en代表英语");
|
||||||
|
//参数
|
||||||
|
JSONObject properties = new JSONObject();
|
||||||
|
properties.put("wordLength", wordLength);
|
||||||
|
properties.put("language", language);
|
||||||
|
Parameters parameters = Parameters.builder()
|
||||||
|
.type("object")
|
||||||
|
.properties(properties)
|
||||||
|
.required(Collections.singletonList("wordLength")).build();
|
||||||
|
Tools tools = Tools.builder()
|
||||||
|
.type(Tools.Type.FUNCTION.getName())
|
||||||
|
.function(ToolsFunction.builder().name("getOneWord").description("获取一个指定长度和语言类型的词语").parameters(parameters).build())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
ChatCompletion chatCompletion = ChatCompletion
|
||||||
|
.builder()
|
||||||
|
.messages(Collections.singletonList(message))
|
||||||
|
.tools(Collections.singletonList(tools))
|
||||||
|
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
|
||||||
|
.build();
|
||||||
|
openAiStreamClient.streamChatCompletion(chatCompletion, eventSourceListener);
|
||||||
|
|
||||||
|
try {
|
||||||
|
countDownLatch.await();
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
|
||||||
|
ToolCalls openAiReturnToolCalls = eventSourceListener.getToolCalls();
|
||||||
|
WordParam wordParam = JSONUtil.toBean(openAiReturnToolCalls.getFunction().getArguments(), WordParam.class);
|
||||||
|
String oneWord = getOneWord(wordParam);
|
||||||
|
|
||||||
|
|
||||||
|
ToolCallFunction tcf = ToolCallFunction.builder().name("getOneWord").arguments(openAiReturnToolCalls.getFunction().getArguments()).build();
|
||||||
|
ToolCalls tc = ToolCalls.builder().id(openAiReturnToolCalls.getId()).type(ToolCalls.Type.FUNCTION.getName()).function(tcf).build();
|
||||||
|
//构造tool call
|
||||||
|
Message message2 = Message.builder().role(Message.Role.ASSISTANT).content("方法参数").toolCalls(Collections.singletonList(tc)).build();
|
||||||
|
String content
|
||||||
|
= "{ " +
|
||||||
|
"\"wordLength\": \"3\", " +
|
||||||
|
"\"language\": \"zh\", " +
|
||||||
|
"\"word\": \"" + oneWord + "\"," +
|
||||||
|
"\"用途\": [\"直接吃\", \"做沙拉\", \"售卖\"]" +
|
||||||
|
"}";
|
||||||
|
Message message3 = Message.builder().toolCallId(openAiReturnToolCalls.getId()).role(Message.Role.TOOL).name("getOneWord").content(content).build();
|
||||||
|
List<Message> messageList = Arrays.asList(message, message2, message3);
|
||||||
|
ChatCompletion chatCompletionV2 = ChatCompletion
|
||||||
|
.builder()
|
||||||
|
.messages(messageList)
|
||||||
|
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
|
||||||
|
CountDownLatch countDownLatch1 = new CountDownLatch(1);
|
||||||
|
openAiStreamClient.streamChatCompletion(chatCompletionV2, new ConsoleEventSourceListenerV3(countDownLatch));
|
||||||
|
try {
|
||||||
|
countDownLatch1.await();
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
countDownLatch1.await();
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
static class WordParam {
|
||||||
|
private int wordLength;
|
||||||
|
@Builder.Default
|
||||||
|
private String language = "zh";
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取一个词语(根据语言和字符长度查询)
|
||||||
|
* @param wordParam
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public String getOneWord(WordParam wordParam) {
|
||||||
|
|
||||||
|
List<String> zh = Arrays.asList("大香蕉", "哈密瓜", "苹果");
|
||||||
|
List<String> en = Arrays.asList("apple", "banana", "cantaloupe");
|
||||||
|
if (wordParam.getLanguage().equals("zh")) {
|
||||||
|
for (String e : zh) {
|
||||||
|
if (e.length() == wordParam.getWordLength()) {
|
||||||
|
return e;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (wordParam.getLanguage().equals("en")) {
|
||||||
|
for (String e : en) {
|
||||||
|
if (e.length() == wordParam.getWordLength()) {
|
||||||
|
return e;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "西瓜";
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
@@ -3,16 +3,11 @@ package org.ruoyi.system.service.impl;
|
|||||||
import cn.dev33.satoken.stp.StpUtil;
|
import cn.dev33.satoken.stp.StpUtil;
|
||||||
import cn.hutool.core.collection.CollectionUtil;
|
import cn.hutool.core.collection.CollectionUtil;
|
||||||
import com.alibaba.fastjson.JSONObject;
|
import com.alibaba.fastjson.JSONObject;
|
||||||
import com.azure.ai.openai.OpenAIClient;
|
|
||||||
import com.azure.ai.openai.OpenAIClientBuilder;
|
|
||||||
import com.azure.ai.openai.models.*;
|
|
||||||
import com.azure.core.credential.AzureKeyCredential;
|
|
||||||
import io.github.ollama4j.OllamaAPI;
|
import io.github.ollama4j.OllamaAPI;
|
||||||
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
||||||
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
|
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
|
||||||
import io.github.ollama4j.models.chat.OllamaChatRequestModel;
|
import io.github.ollama4j.models.chat.OllamaChatRequestModel;
|
||||||
import io.github.ollama4j.models.generate.OllamaStreamHandler;
|
import io.github.ollama4j.models.generate.OllamaStreamHandler;
|
||||||
import io.github.ollama4j.utils.Options;
|
|
||||||
import jakarta.servlet.http.HttpServletRequest;
|
import jakarta.servlet.http.HttpServletRequest;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@@ -33,6 +28,12 @@ import org.ruoyi.common.chat.entity.images.Item;
|
|||||||
import org.ruoyi.common.chat.entity.images.ResponseFormat;
|
import org.ruoyi.common.chat.entity.images.ResponseFormat;
|
||||||
import org.ruoyi.common.chat.entity.whisper.WhisperResponse;
|
import org.ruoyi.common.chat.entity.whisper.WhisperResponse;
|
||||||
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
|
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
|
||||||
|
import org.ruoyi.common.chat.openai.plugin.PluginAbstract;
|
||||||
|
import org.ruoyi.common.chat.plugin.CmdPlugin;
|
||||||
|
import org.ruoyi.common.chat.plugin.CmdReq;
|
||||||
|
import org.ruoyi.common.chat.plugin.SqlPlugin;
|
||||||
|
import org.ruoyi.common.chat.plugin.SqlReq;
|
||||||
|
import org.ruoyi.common.chat.sse.ConsoleEventSourceListener;
|
||||||
import org.ruoyi.common.chat.utils.TikTokensUtil;
|
import org.ruoyi.common.chat.utils.TikTokensUtil;
|
||||||
import org.ruoyi.common.core.domain.model.LoginUser;
|
import org.ruoyi.common.core.domain.model.LoginUser;
|
||||||
import org.ruoyi.common.core.exception.base.BaseException;
|
import org.ruoyi.common.core.exception.base.BaseException;
|
||||||
@@ -43,12 +44,10 @@ import org.ruoyi.system.domain.bo.ChatMessageBo;
|
|||||||
import org.ruoyi.system.domain.bo.SysModelBo;
|
import org.ruoyi.system.domain.bo.SysModelBo;
|
||||||
import org.ruoyi.system.domain.request.translation.TranslationRequest;
|
import org.ruoyi.system.domain.request.translation.TranslationRequest;
|
||||||
import org.ruoyi.system.domain.vo.SysModelVo;
|
import org.ruoyi.system.domain.vo.SysModelVo;
|
||||||
import org.ruoyi.system.domain.vo.SysUserVo;
|
|
||||||
import org.ruoyi.system.listener.SSEEventSourceListener;
|
import org.ruoyi.system.listener.SSEEventSourceListener;
|
||||||
import org.ruoyi.system.service.*;
|
import org.ruoyi.system.service.*;
|
||||||
import org.springframework.core.io.InputStreamResource;
|
import org.springframework.core.io.InputStreamResource;
|
||||||
import org.springframework.core.io.Resource;
|
import org.springframework.core.io.Resource;
|
||||||
import org.springframework.http.HttpStatus;
|
|
||||||
import org.springframework.http.MediaType;
|
import org.springframework.http.MediaType;
|
||||||
import org.springframework.http.ResponseEntity;
|
import org.springframework.http.ResponseEntity;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
@@ -63,10 +62,10 @@ import java.net.URLEncoder;
|
|||||||
import java.nio.file.Files;
|
import java.nio.file.Files;
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.CompletableFuture;
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
|
||||||
import io.github.ollama4j.utils.OptionsBuilder;
|
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@@ -89,9 +88,6 @@ public class SseServiceImpl implements ISseService {
|
|||||||
|
|
||||||
static final OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().build();
|
static final OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().build();
|
||||||
|
|
||||||
private final ISysPackagePlanService sysPackagePlanService;
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) {
|
public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) {
|
||||||
openAiStreamClient = chatConfig.getOpenAiStreamClient();
|
openAiStreamClient = chatConfig.getOpenAiStreamClient();
|
||||||
@@ -101,12 +97,7 @@ public class SseServiceImpl implements ISseService {
|
|||||||
List<Message> messages = chatRequest.getMessages();
|
List<Message> messages = chatRequest.getMessages();
|
||||||
try {
|
try {
|
||||||
if (StpUtil.isLogin()) {
|
if (StpUtil.isLogin()) {
|
||||||
SysUserVo sysUserVo = userService.selectUserById(getUserId());
|
|
||||||
// if (!checkModel(sysUserVo.getUserPlan(), chatRequest.getModel())) {
|
|
||||||
// throw new BaseException("当前套餐不支持此模型!");
|
|
||||||
// }
|
|
||||||
LocalCache.CACHE.put("userId", getUserId());
|
LocalCache.CACHE.put("userId", getUserId());
|
||||||
|
|
||||||
Object content = messages.get(messages.size() - 1).getContent();
|
Object content = messages.get(messages.size() - 1).getContent();
|
||||||
|
|
||||||
String chatString = "";
|
String chatString = "";
|
||||||
@@ -161,36 +152,23 @@ public class SseServiceImpl implements ISseService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if("openCmd".equals(chatRequest.getModel())) {
|
||||||
// else {
|
sseEmitter.send(cmdPlugin(messages));
|
||||||
//
|
sseEmitter.complete();
|
||||||
// // 初始请求次数
|
}else if ("sqlPlugin".equals(chatRequest.getModel())){
|
||||||
// int number = 1;
|
sseEmitter.send(sqlPlugin(messages));
|
||||||
// // 获取请求IP
|
sseEmitter.complete();
|
||||||
// String realIp = getClientIpAddress(request);
|
} else {
|
||||||
// // 根据IP获取次数
|
ChatCompletion completion = ChatCompletion
|
||||||
// Integer requestNumber = RedisUtils.getCacheObject(realIp);
|
.builder()
|
||||||
// if (requestNumber == null) {
|
.messages(messages)
|
||||||
// // 记录ip使用次数
|
.model(chatRequest.getModel())
|
||||||
// RedisUtils.setCacheObject(realIp, number);
|
.temperature(chatRequest.getTemperature())
|
||||||
// } else {
|
.topP(chatRequest.getTop_p())
|
||||||
// String configValue = configService.getConfigValue("mail", "free");
|
.stream(true)
|
||||||
// if (requestNumber > Integer.parseInt(configValue)) {
|
.build();
|
||||||
// throw new BaseException("剩余次数不足,请充值后使用");
|
openAiStreamClient.streamChatCompletion(completion, openAIEventSourceListener);
|
||||||
// }
|
}
|
||||||
// RedisUtils.setCacheObject(realIp, requestNumber + 1);
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// }
|
|
||||||
ChatCompletion completion = ChatCompletion
|
|
||||||
.builder()
|
|
||||||
.messages(messages)
|
|
||||||
.model(chatRequest.getModel())
|
|
||||||
.temperature(chatRequest.getTemperature())
|
|
||||||
.topP(chatRequest.getTop_p())
|
|
||||||
.stream(true)
|
|
||||||
.build();
|
|
||||||
openAiStreamClient.streamChatCompletion(completion, openAIEventSourceListener);
|
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
String message = e.getMessage();
|
String message = e.getMessage();
|
||||||
sendErrorEvent(sseEmitter, message);
|
sendErrorEvent(sseEmitter, message);
|
||||||
@@ -199,32 +177,51 @@ public class SseServiceImpl implements ISseService {
|
|||||||
return sseEmitter;
|
return sseEmitter;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public String cmdPlugin(List<Message> messages) {
|
||||||
|
CmdPlugin plugin = new CmdPlugin(CmdReq.class);
|
||||||
|
// 插件名称
|
||||||
|
plugin.setName("命令行工具");
|
||||||
|
// 方法名称
|
||||||
|
plugin.setFunction("openCmd");
|
||||||
|
// 方法说明
|
||||||
|
plugin.setDescription("提供一个命令行指令,比如<记事本>,指令使用中文");
|
||||||
|
|
||||||
// /**
|
PluginAbstract.Arg arg = new PluginAbstract.Arg();
|
||||||
// * 查当前用户是否可以调用此模型
|
// 参数名称
|
||||||
// *
|
arg.setName("cmd");
|
||||||
// * @param planId
|
// 参数说明
|
||||||
// * @return
|
arg.setDescription("命令行指令");
|
||||||
// */
|
// 参数类型
|
||||||
// public Boolean checkModel(String planId, String modelName) {
|
arg.setType("string");
|
||||||
// SysPackagePlanBo sysPackagePlanBo = new SysPackagePlanBo();
|
arg.setRequired(true);
|
||||||
// if (modelName.startsWith("gpt-4-gizmo")) {
|
plugin.setArgs(Collections.singletonList(arg));
|
||||||
// modelName = "gpt-4-gizmo";
|
//有四个重载方法,都可以使用
|
||||||
// }
|
ChatCompletionResponse response = openAiStreamClient.chatCompletionWithPlugin(messages,"gpt-4o-mini",plugin);
|
||||||
// if (StringUtils.isEmpty(planId)) {
|
return response.getChoices().get(0).getMessage().getContent().toString();
|
||||||
// sysPackagePlanBo.setName("Visitor");
|
}
|
||||||
// } else if ("Visitor".equals(planId) || "Free".equals(planId)) {
|
|
||||||
// sysPackagePlanBo.setName(planId);
|
public String sqlPlugin(List<Message> messages) {
|
||||||
// } else {
|
SqlPlugin plugin = new SqlPlugin(SqlReq.class);
|
||||||
// // sysPackagePlanBo.setId(Long.valueOf(planId));
|
// 插件名称
|
||||||
// return true;
|
plugin.setName("数据库查询插件");
|
||||||
// }
|
// 方法名称
|
||||||
//
|
plugin.setFunction("sqlPlugin");
|
||||||
// SysPackagePlanVo sysPackagePlanVo = sysPackagePlanService.queryList(sysPackagePlanBo).get(0);
|
// 方法说明
|
||||||
// // 将字符串转换为数组
|
plugin.setDescription("提供一个用户名称查询余额信息");
|
||||||
// String[] array = sysPackagePlanVo.getPlanDetail().split(",");
|
|
||||||
// return Arrays.asList(array).contains(modelName);
|
PluginAbstract.Arg arg = new PluginAbstract.Arg();
|
||||||
// }
|
// 参数名称
|
||||||
|
arg.setName("username");
|
||||||
|
// 参数说明
|
||||||
|
arg.setDescription("用户名称");
|
||||||
|
// 参数类型
|
||||||
|
arg.setType("string");
|
||||||
|
arg.setRequired(true);
|
||||||
|
plugin.setArgs(Collections.singletonList(arg));
|
||||||
|
//有四个重载方法,都可以使用
|
||||||
|
ChatCompletionResponse response = openAiStreamClient.chatCompletionWithPlugin(messages,"gpt-4o-mini",plugin);
|
||||||
|
return response.getChoices().get(0).getMessage().getContent().toString();
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 根据次数扣除余额
|
* 根据次数扣除余额
|
||||||
@@ -295,25 +292,6 @@ public class SseServiceImpl implements ISseService {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String chat(ChatRequest chatRequest, String userId) {
|
public String chat(ChatRequest chatRequest, String userId) {
|
||||||
// chatService.deductUserBalance(Long.valueOf(userId), 0.01);
|
|
||||||
// // 保存消息记录
|
|
||||||
// ChatMessageBo chatMessageBo = new ChatMessageBo();
|
|
||||||
// chatMessageBo.setUserId(Long.valueOf(userId));
|
|
||||||
// chatMessageBo.setModelName(ChatCompletion.Model.GPT_3_5_TURBO.getName());
|
|
||||||
// chatMessageBo.setContent(chatRequest.getPrompt());
|
|
||||||
// chatMessageBo.setDeductCost(0.01);
|
|
||||||
// chatMessageBo.setTotalTokens(0);
|
|
||||||
// chatMessageService.insertByBo(chatMessageBo);
|
|
||||||
//
|
|
||||||
// openAiStreamClient = chatConfig.getOpenAiStreamClient();
|
|
||||||
// Message message = Message.builder().role(Message.Role.USER).content(chatRequest.getPrompt()).build();
|
|
||||||
// ChatCompletion chatCompletion = ChatCompletion
|
|
||||||
// .builder()
|
|
||||||
// .messages(Collections.singletonList(message))
|
|
||||||
// .model(chatRequest.getModel())
|
|
||||||
// .build();
|
|
||||||
// ChatCompletionResponse chatCompletionResponse = openAiStreamClient.chatCompletion(chatCompletion);
|
|
||||||
// return chatCompletionResponse.getChoices().get(0).getMessage().getContent();
|
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -540,7 +518,8 @@ public class SseServiceImpl implements ISseService {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String translation(TranslationRequest translationRequest) {
|
public String translation(TranslationRequest translationRequest) {
|
||||||
|
// 翻译模型固定为gpt-4o-mini
|
||||||
|
translationRequest.setModel("gpt-4o-mini");
|
||||||
ChatMessageBo chatMessageBo = new ChatMessageBo();
|
ChatMessageBo chatMessageBo = new ChatMessageBo();
|
||||||
chatMessageBo.setUserId(getUserId());
|
chatMessageBo.setUserId(getUserId());
|
||||||
chatMessageBo.setModelName(translationRequest.getModel());
|
chatMessageBo.setModelName(translationRequest.getModel());
|
||||||
@@ -557,17 +536,12 @@ public class SseServiceImpl implements ISseService {
|
|||||||
"\n" +
|
"\n" +
|
||||||
"请将用户输入词语翻译成{" + translationRequest.getTargetLanguage() + "}\n" +
|
"请将用户输入词语翻译成{" + translationRequest.getTargetLanguage() + "}\n" +
|
||||||
"\n" +
|
"\n" +
|
||||||
"让我们一步一步来思考\n" +
|
|
||||||
"==示例输出==\n" +
|
"==示例输出==\n" +
|
||||||
|
"**原文** : <这里显示要翻译的原文信息>\n" +
|
||||||
"**翻译** : <这里显示翻译成英语的结果>\n" +
|
"**翻译** : <这里显示翻译成英语的结果>\n" +
|
||||||
"\n" +
|
|
||||||
"**造句** : What's the weather like today? Use the 'Weather Query' plugin to find out instantly! <造一个英语句子>\n" +
|
|
||||||
"\n" +
|
|
||||||
"**同义词** : Add-on、Extension、Module <这里显示1-3个英文的同义词>\n" +
|
|
||||||
"\n" +
|
|
||||||
"==示例结束==\n" +
|
"==示例结束==\n" +
|
||||||
"\n" +
|
"\n" +
|
||||||
"注意:请严格按示例进行输出").build();
|
"注意:请严格按示例进行输出,返回markdown格式").build();
|
||||||
messageList.add(sysMessage);
|
messageList.add(sysMessage);
|
||||||
Message message = Message.builder().role(Message.Role.USER).content(translationRequest.getPrompt()).build();
|
Message message = Message.builder().role(Message.Role.USER).content(translationRequest.getPrompt()).build();
|
||||||
messageList.add(message);
|
messageList.add(message);
|
||||||
@@ -646,4 +620,6 @@ public class SseServiceImpl implements ISseService {
|
|||||||
ChatCompletionResponse chatCompletionResponse = openAiStreamClient.chatCompletion(chatCompletion);
|
ChatCompletionResponse chatCompletionResponse = openAiStreamClient.chatCompletion(chatCompletion);
|
||||||
return chatCompletionResponse.getChoices().get(0).getMessage().getContent().toString();
|
return chatCompletionResponse.getChoices().get(0).getMessage().getContent().toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -557,4 +557,11 @@ public class SysUserServiceImpl implements ISysUserService, UserService {
|
|||||||
.select(SysUser::getUserName).eq(SysUser::getUserId, userId));
|
.select(SysUser::getUserName).eq(SysUser::getUserId, userId));
|
||||||
return ObjectUtil.isNull(sysUser) ? null : sysUser.getUserName();
|
return ObjectUtil.isNull(sysUser) ? null : sysUser.getUserName();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String selectUserByName(String userName) {
|
||||||
|
SysUser sysUser = baseMapper.selectOne(new LambdaQueryWrapper<SysUser>()
|
||||||
|
.eq(SysUser::getUserName, userName));
|
||||||
|
return ObjectUtil.isNull(sysUser) ? null : sysUser.getUserBalance().toString();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,8 +8,10 @@ ADD COLUMN `vector` varchar(50) NULL COMMENT '向量库' AFTER `text_block_size`
|
|||||||
ADD COLUMN `vector_model` varchar(50) NULL COMMENT '向量模型' AFTER `vector`;
|
ADD COLUMN `vector_model` varchar(50) NULL COMMENT '向量模型' AFTER `vector`;
|
||||||
|
|
||||||
|
|
||||||
|
INSERT INTO `ruoyi-ai`.`chat_config` (`id`, `category`, `config_name`, `config_value`, `config_dict`, `create_dept`, `create_time`, `create_by`, `update_by`, `update_time`, `remark`, `version`, `del_flag`, `update_ip`, `tenant_id`) VALUES (1897610056458412050, 'weaviate', 'protocol', 'http', '协议', 103, '2025-03-06 21:10:02', '1', '1', '2025-03-06 21:10:31', NULL, NULL, '0', NULL, 0);
|
||||||
INSERT INTO `chat_config` (`id`, `category`, `config_name`, `config_value`, `config_dict`, `create_dept`, `create_time`, `create_by`, `update_by`, `update_time`, `remark`, `version`, `del_flag`, `update_ip`, `tenant_id`) VALUES (1897610056458412050, 'weaviate', 'protocol', 'http', '协议', 103, '2025-03-06 21:10:02', '1', '1', '2025-03-06 21:10:31', NULL, NULL, '0', NULL, 0);
|
INSERT INTO `ruoyi-ai`.`chat_config` (`id`, `category`, `config_name`, `config_value`, `config_dict`, `create_dept`, `create_time`, `create_by`, `update_by`, `update_time`, `remark`, `version`, `del_flag`, `update_ip`, `tenant_id`) VALUES (1897610056458412051, 'weaviate', 'host', '127.0.0.1:6038', '地址', 103, '2025-03-06 21:10:02', '1', '1', '2025-03-06 21:10:31', NULL, NULL, '0', NULL, 0);
|
||||||
INSERT INTO `chat_config` (`id`, `category`, `config_name`, `config_value`, `config_dict`, `create_dept`, `create_time`, `create_by`, `update_by`, `update_time`, `remark`, `version`, `del_flag`, `update_ip`, `tenant_id`) VALUES (1897610056458412051, 'weaviate', 'host', '127.0.0.1:6038', '地址', 103, '2025-03-06 21:10:02', '1', '1', '2025-03-06 21:10:31', NULL, NULL, '0', NULL, 0);
|
INSERT INTO `ruoyi-ai`.`chat_config` (`id`, `category`, `config_name`, `config_value`, `config_dict`, `create_dept`, `create_time`, `create_by`, `update_by`, `update_time`, `remark`, `version`, `del_flag`, `update_ip`, `tenant_id`) VALUES (1897610056458412052, 'weaviate', 'classname', 'LocalKnowledge', '分类名称', 103, '2025-03-06 21:10:02', '1', '1', '2025-03-06 21:10:31', NULL, NULL, '0', NULL, 0);
|
||||||
INSERT INTO `chat_config` (`id`, `category`, `config_name`, `config_value`, `config_dict`, `create_dept`, `create_time`, `create_by`, `update_by`, `update_time`, `remark`, `version`, `del_flag`, `update_ip`, `tenant_id`) VALUES (1897610056458412052, 'weaviate', 'classname', 'LocalKnowledge', '分类名称', 103, '2025-03-06 21:10:02', '1', '1', '2025-03-06 21:10:31', NULL, NULL, '0', NULL, 0);
|
INSERT INTO `ruoyi-ai`.`chat_config` (`id`, `category`, `config_name`, `config_value`, `config_dict`, `create_dept`, `create_time`, `create_by`, `update_by`, `update_time`, `remark`, `version`, `del_flag`, `update_ip`, `tenant_id`) VALUES (1897610056458412053, 'milvus', 'host', '127.0.0.1', '地址', 103, '2025-03-06 21:10:02', '1', '1', '2025-03-06 21:10:31', NULL, NULL, '0', NULL, 0);
|
||||||
|
INSERT INTO `ruoyi-ai`.`chat_config` (`id`, `category`, `config_name`, `config_value`, `config_dict`, `create_dept`, `create_time`, `create_by`, `update_by`, `update_time`, `remark`, `version`, `del_flag`, `update_ip`, `tenant_id`) VALUES (1897610056458412054, 'milvus', 'port', '19530', '端口', 103, '2025-03-06 21:10:02', '1', '1', '2025-03-06 21:10:31', NULL, NULL, '0', NULL, 0);
|
||||||
|
INSERT INTO `ruoyi-ai`.`chat_config` (`id`, `category`, `config_name`, `config_value`, `config_dict`, `create_dept`, `create_time`, `create_by`, `update_by`, `update_time`, `remark`, `version`, `del_flag`, `update_ip`, `tenant_id`) VALUES (1897610056458412055, 'milvus', 'dimension', '1536', '维度', 103, '2025-03-06 21:10:02', '1', '1', '2025-03-06 21:10:31', NULL, NULL, '0', NULL, 0);
|
||||||
|
INSERT INTO `ruoyi-ai`.`chat_config` (`id`, `category`, `config_name`, `config_value`, `config_dict`, `create_dept`, `create_time`, `create_by`, `update_by`, `update_time`, `remark`, `version`, `del_flag`, `update_ip`, `tenant_id`) VALUES (1897610056458412056, 'milvus', 'collection', 'LocalKnowledge', '分类名称', 103, '2025-03-06 21:10:02', '1', '1', '2025-03-06 21:10:31', NULL, NULL, '0', NULL, 0);
|
||||||
|
|||||||
Reference in New Issue
Block a user