feat: 支持插件功能

This commit is contained in:
ageerle
2025-03-11 17:32:47 +08:00
parent c98a6deaf6
commit 6a1b544545
47 changed files with 2865 additions and 230 deletions

View File

@@ -26,6 +26,12 @@
<artifactId>ruoyi-common-core</artifactId>
</dependency>
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<version>8.0.33</version>
</dependency>
<dependency>
<groupId>com.azure</groupId>
<artifactId>azure-ai-openai</artifactId>
@@ -92,5 +98,25 @@
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
</dependency>
<dependency>
<groupId>cn.bigmodel.openapi</groupId>
<artifactId>oapi-java-sdk</artifactId>
<version>release-V4-2.3.0</version>
</dependency>
<dependency>
<groupId>com.squareup.okhttp</groupId>
<artifactId>okhttp</artifactId>
<version>2.7.5</version>
<scope>compile</scope>
</dependency>
</dependencies>
</project>

View File

@@ -0,0 +1,73 @@
package org.ruoyi.common.chat.demo;
import cn.hutool.json.JSONUtil;
import lombok.Getter;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.ResponseBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
/**
* 描述: sse
*
* @author https:www.unfbx.com
* 2023-06-15
*/
@Slf4j
public class ConsoleEventSourceListenerV2 extends EventSourceListener {
@Getter
String args = "";
final CountDownLatch countDownLatch;
public ConsoleEventSourceListenerV2(CountDownLatch countDownLatch) {
this.countDownLatch = countDownLatch;
}
@Override
public void onOpen(EventSource eventSource, Response response) {
log.info("OpenAI建立sse连接...");
}
@Override
public void onEvent(EventSource eventSource, String id, String type, String data) {
log.info("OpenAI返回数据{}", data);
if (data.equals("[DONE]")) {
log.info("OpenAI返回数据结束了");
countDownLatch.countDown();
return;
}
ChatCompletionResponse chatCompletionResponse = JSONUtil.toBean(data, ChatCompletionResponse.class);
if(Objects.nonNull(chatCompletionResponse.getChoices().get(0).getDelta().getFunctionCall())){
args += chatCompletionResponse.getChoices().get(0).getDelta().getFunctionCall().getArguments();
}
}
@Override
public void onClosed(EventSource eventSource) {
log.info("OpenAI关闭sse连接...");
}
@SneakyThrows
@Override
public void onFailure(EventSource eventSource, Throwable t, Response response) {
if(Objects.isNull(response)){
log.error("OpenAI sse连接异常:{}", t);
eventSource.cancel();
return;
}
ResponseBody body = response.body();
if (Objects.nonNull(body)) {
log.error("OpenAI sse连接异常data{},异常:{}", body.string(), t);
} else {
log.error("OpenAI sse连接异常data{},异常:{}", response, t);
}
eventSource.cancel();
}
}

View File

@@ -0,0 +1,92 @@
package org.ruoyi.common.chat.demo;
import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.json.JSONUtil;
import lombok.Getter;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.ResponseBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
import org.ruoyi.common.chat.entity.chat.Message;
import org.ruoyi.common.chat.entity.chat.tool.ToolCallFunction;
import org.ruoyi.common.chat.entity.chat.tool.ToolCalls;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
/**
* 描述: demo测试实现类仅供思路参考
*
* @author https:www.unfbx.com
* 2023-11-12
*/
@Slf4j
public class ConsoleEventSourceListenerV3 extends EventSourceListener {
@Getter
List<ToolCalls> choices = new ArrayList<>();
@Getter
ToolCalls toolCalls = new ToolCalls();
@Getter
ToolCallFunction toolCallFunction = ToolCallFunction.builder().name("").arguments("").build();
final CountDownLatch countDownLatch;
public ConsoleEventSourceListenerV3(CountDownLatch countDownLatch) {
this.countDownLatch = countDownLatch;
}
@Override
public void onOpen(EventSource eventSource, Response response) {
log.info("OpenAI建立sse连接...");
}
@Override
public void onEvent(EventSource eventSource, String id, String type, String data) {
log.info("OpenAI返回数据{}", data);
if (data.equals("[DONE]")) {
log.info("OpenAI返回数据结束了");
return;
}
ChatCompletionResponse chatCompletionResponse = JSONUtil.toBean(data, ChatCompletionResponse.class);
Message delta = chatCompletionResponse.getChoices().get(0).getDelta();
if (CollectionUtil.isNotEmpty(delta.getToolCalls())) {
choices.addAll(delta.getToolCalls());
}
}
@Override
public void onClosed(EventSource eventSource) {
if(CollectionUtil.isNotEmpty(choices)){
toolCalls.setId(choices.get(0).getId());
toolCalls.setType(choices.get(0).getType());
choices.forEach(e -> {
toolCallFunction.setName(e.getFunction().getName());
toolCallFunction.setArguments(toolCallFunction.getArguments() + e.getFunction().getArguments());
toolCalls.setFunction(toolCallFunction);
});
}
log.info("OpenAI关闭sse连接...");
countDownLatch.countDown();
}
@SneakyThrows
@Override
public void onFailure(EventSource eventSource, Throwable t, Response response) {
if(Objects.isNull(response)){
log.error("OpenAI sse连接异常:{}", t);
eventSource.cancel();
return;
}
ResponseBody body = response.body();
if (Objects.nonNull(body)) {
log.error("OpenAI sse连接异常data{},异常:{}", body.string(), t);
} else {
log.error("OpenAI sse连接异常data{},异常:{}", response, t);
}
eventSource.cancel();
}
}

View File

@@ -0,0 +1,417 @@
package org.ruoyi.common.chat.demo;
import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson.JSONObject;
import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient;
import okhttp3.logging.HttpLoggingInterceptor;
import org.junit.Before;
import org.junit.Test;
import org.ruoyi.common.chat.entity.chat.*;
import org.ruoyi.common.chat.entity.chat.tool.ToolCallFunction;
import org.ruoyi.common.chat.entity.chat.tool.ToolCalls;
import org.ruoyi.common.chat.entity.chat.tool.Tools;
import org.ruoyi.common.chat.entity.chat.tool.ToolsFunction;
import org.ruoyi.common.chat.openai.OpenAiClient;
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
import org.ruoyi.common.chat.openai.function.KeyRandomStrategy;
import org.ruoyi.common.chat.openai.interceptor.DynamicKeyOpenAiAuthInterceptor;
import org.ruoyi.common.chat.openai.interceptor.OpenAILogger;
import org.ruoyi.common.chat.openai.interceptor.OpenAiResponseInterceptor;
import org.ruoyi.common.chat.openai.plugin.PluginAbstract;
import org.ruoyi.common.chat.plugin.CmdPlugin;
import org.ruoyi.common.chat.plugin.CmdReq;
import org.ruoyi.common.chat.sse.ConsoleEventSourceListener;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
/**
* 描述:
*
* @author ageerle@163.com
* date 2025/3/8
*/
@Slf4j
public class PluginTest {
private OpenAiClient openAiClient;
private OpenAiStreamClient openAiStreamClient;
@Before
public void before() {
//可以为null
// Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890));
HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor(new OpenAILogger());
//千万别再生产或者测试环境打开BODY级别日志
//生产或者测试环境建议设置为这三种级别NONE,BASIC,HEADERS,
httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS);
OkHttpClient okHttpClient = new OkHttpClient
.Builder()
// .proxy(proxy)
.addInterceptor(httpLoggingInterceptor)
.addInterceptor(new OpenAiResponseInterceptor())
.connectTimeout(10, TimeUnit.SECONDS)
.writeTimeout(30, TimeUnit.SECONDS)
.readTimeout(30, TimeUnit.SECONDS)
.build();
openAiClient = OpenAiClient.builder()
//支持多key传入请求时候随机选择
.apiKey(Arrays.asList("sk-xx"))
//自定义key的获取策略默认KeyRandomStrategy
//.keyStrategy(new KeyRandomStrategy())
.keyStrategy(new KeyRandomStrategy())
.okHttpClient(okHttpClient)
//自己做了代理就传代理地址,没有可不不传,(关注公众号回复openai ,获取免费的测试代理地址)
.apiHost("https://api.pandarobot.chat/")
.build();
openAiStreamClient = OpenAiStreamClient.builder()
//支持多key传入请求时候随机选择
.apiKey(Arrays.asList("sk-xx"))
//自定义key的获取策略默认KeyRandomStrategy
.keyStrategy(new KeyRandomStrategy())
.authInterceptor(new DynamicKeyOpenAiAuthInterceptor())
.okHttpClient(okHttpClient)
//自己做了代理就传代理地址,没有可不不传,(关注公众号回复openai ,获取免费的测试代理地址)
.apiHost("https://api.pandarobot.chat/")
.build();
}
@Test
public void chatFunction() {
//模型GPT_3_5_TURBO_16K_0613
Message message = Message.builder().role(Message.Role.USER).content("给我输出一个长度为2的中文词语并解释下词语对应物品的用途").build();
//属性一
JSONObject wordLength = new JSONObject();
wordLength.put("type", "number");
wordLength.put("description", "词语的长度");
//属性二
JSONObject language = new JSONObject();
language.put("type", "string");
language.put("enum", Arrays.asList("zh", "en"));
language.put("description", "语言类型例如zh代表中文、en代表英语");
//参数
JSONObject properties = new JSONObject();
properties.put("wordLength", wordLength);
properties.put("language", language);
Parameters parameters = Parameters.builder()
.type("object")
.properties(properties)
.required(Collections.singletonList("wordLength")).build();
Functions functions = Functions.builder()
.name("getOneWord")
.description("获取一个指定长度和语言类型的词语")
.parameters(parameters)
.build();
ChatCompletion chatCompletion = ChatCompletion
.builder()
.messages(Collections.singletonList(message))
.functions(Collections.singletonList(functions))
.functionCall("auto")
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
.build();
ChatCompletionResponse chatCompletionResponse = openAiClient.chatCompletion(chatCompletion);
ChatChoice chatChoice = chatCompletionResponse.getChoices().get(0);
log.info("构造的方法值:{}", chatChoice.getMessage().getFunctionCall());
log.info("构造的方法名称:{}", chatChoice.getMessage().getFunctionCall().getName());
log.info("构造的方法参数:{}", chatChoice.getMessage().getFunctionCall().getArguments());
WordParam wordParam = JSONUtil.toBean(chatChoice.getMessage().getFunctionCall().getArguments(), WordParam.class);
String oneWord = getOneWord(wordParam);
FunctionCall functionCall = FunctionCall.builder()
.arguments(chatChoice.getMessage().getFunctionCall().getArguments())
.name("getOneWord")
.build();
Message message2 = Message.builder().role(Message.Role.ASSISTANT).content("方法参数").functionCall(functionCall).build();
String content
= "{ " +
"\"wordLength\": \"3\", " +
"\"language\": \"zh\", " +
"\"word\": \"" + oneWord + "\"," +
"\"用途\": [\"直接吃\", \"做沙拉\", \"售卖\"]" +
"}";
Message message3 = Message.builder().role(Message.Role.FUNCTION).name("getOneWord").content(content).build();
List<Message> messageList = Arrays.asList(message, message2, message3);
ChatCompletion chatCompletionV2 = ChatCompletion
.builder()
.messages(messageList)
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
.build();
ChatCompletionResponse chatCompletionResponseV2 = openAiClient.chatCompletion(chatCompletionV2);
log.info("自定义的方法返回值:{}",chatCompletionResponseV2.getChoices().get(0).getMessage().getContent());
}
@Test
public void plugin() {
CmdPlugin plugin = new CmdPlugin(CmdReq.class);
// 插件名称
plugin.setName("命令行工具");
// 方法名称
plugin.setFunction("openCmd");
// 方法说明
plugin.setDescription("提供一个命令行指令,比如<记事本>,指令使用中文,以function返回结果为准");
PluginAbstract.Arg arg = new PluginAbstract.Arg();
// 参数名称
arg.setName("cmd");
// 参数说明
arg.setDescription("命令行指令");
// 参数类型
arg.setType("string");
arg.setRequired(true);
plugin.setArgs(Collections.singletonList(arg));
Message message2 = Message.builder().role(Message.Role.USER).content("帮我打开计算器,结合上下文判断指令是否执行成功,只用回复成功或者失败").build();
List<Message> messages = new ArrayList<>();
messages.add(message2);
//有四个重载方法,都可以使用
ChatCompletionResponse response = openAiClient.chatCompletionWithPlugin(messages,"gpt-4o-mini",plugin);
log.info("自定义的方法返回值:{}", response.getChoices().get(0).getMessage().getContent());
}
/**
* 自定义返回数据格式
*/
@Test
public void diyReturnDataModelChat() {
Message message = Message.builder().role(Message.Role.USER).content("随机输出10个单词使用json输出").build();
ChatCompletion chatCompletion = ChatCompletion
.builder()
.messages(Collections.singletonList(message))
.responseFormat(ResponseFormat.builder().type(ResponseFormat.Type.JSON_OBJECT.getName()).build())
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
.build();
ChatCompletionResponse chatCompletionResponse = openAiClient.chatCompletion(chatCompletion);
chatCompletionResponse.getChoices().forEach(e -> System.out.println(e.getMessage()));
}
@Test
public void streamPlugin() {
WeatherPlugin plugin = new WeatherPlugin(WeatherReq.class);
plugin.setName("知心天气");
plugin.setFunction("getLocationWeather");
plugin.setDescription("提供一个地址,方法将会获取该地址的天气的实时温度信息。");
PluginAbstract.Arg arg = new PluginAbstract.Arg();
arg.setName("location");
arg.setDescription("地名");
arg.setType("string");
arg.setRequired(true);
plugin.setArgs(Collections.singletonList(arg));
// Message message1 = Message.builder().role(Message.Role.USER).content("秦始皇统一了哪六国。").build();
Message message2 = Message.builder().role(Message.Role.USER).content("获取上海市的天气现在多少度然后再给出3个推荐的户外运动。").build();
List<Message> messages = new ArrayList<>();
// messages.add(message1);
messages.add(message2);
//默认模型GPT_3_5_TURBO_16K_0613
//有四个重载方法,都可以使用
openAiStreamClient.streamChatCompletionWithPlugin(messages, ChatCompletion.Model.GPT_4_1106_PREVIEW.getName(), new ConsoleEventSourceListener(), plugin);
CountDownLatch countDownLatch = new CountDownLatch(1);
try {
countDownLatch.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
/**
* tools使用示例
*/
@Test
public void toolsChat() {
Message message = Message.builder().role(Message.Role.USER).content("给我输出一个长度为2的中文词语并解释下词语对应物品的用途").build();
//属性一
JSONObject wordLength = new JSONObject();
wordLength.put("type", "number");
wordLength.put("description", "词语的长度");
//属性二
JSONObject language = new JSONObject();
language.put("type", "string");
language.put("enum", Arrays.asList("zh", "en"));
language.put("description", "语言类型例如zh代表中文、en代表英语");
//参数
JSONObject properties = new JSONObject();
properties.put("wordLength", wordLength);
properties.put("language", language);
Parameters parameters = Parameters.builder()
.type("object")
.properties(properties)
.required(Collections.singletonList("wordLength")).build();
Tools tools = Tools.builder()
.type(Tools.Type.FUNCTION.getName())
.function(ToolsFunction.builder().name("getOneWord").description("获取一个指定长度和语言类型的词语").parameters(parameters).build())
.build();
ChatCompletion chatCompletion = ChatCompletion
.builder()
.messages(Collections.singletonList(message))
.tools(Collections.singletonList(tools))
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
.build();
ChatCompletionResponse chatCompletionResponse = openAiClient.chatCompletion(chatCompletion);
ChatChoice chatChoice = chatCompletionResponse.getChoices().get(0);
log.info("构造的方法值:{}", chatChoice.getMessage().getToolCalls());
ToolCalls openAiReturnToolCalls = chatChoice.getMessage().getToolCalls().get(0);
WordParam wordParam = JSONUtil.toBean(openAiReturnToolCalls.getFunction().getArguments(), WordParam.class);
String oneWord = getOneWord(wordParam);
ToolCallFunction tcf = ToolCallFunction.builder().name("getOneWord").arguments(openAiReturnToolCalls.getFunction().getArguments()).build();
ToolCalls tc = ToolCalls.builder().id(openAiReturnToolCalls.getId()).type(ToolCalls.Type.FUNCTION.getName()).function(tcf).build();
//构造tool call
Message message2 = Message.builder().role(Message.Role.ASSISTANT).content("方法参数").toolCalls(Collections.singletonList(tc)).build();
String content
= "{ " +
"\"wordLength\": \"3\", " +
"\"language\": \"zh\", " +
"\"word\": \"" + oneWord + "\"," +
"\"用途\": [\"直接吃\", \"做沙拉\", \"售卖\"]" +
"}";
Message message3 = Message.builder().toolCallId(openAiReturnToolCalls.getId()).role(Message.Role.TOOL).name("getOneWord").content(content).build();
List<Message> messageList = Arrays.asList(message, message2, message3);
ChatCompletion chatCompletionV2 = ChatCompletion
.builder()
.messages(messageList)
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
.build();
ChatCompletionResponse chatCompletionResponseV2 = openAiClient.chatCompletion(chatCompletionV2);
log.info("自定义的方法返回值:{}", chatCompletionResponseV2.getChoices().get(0).getMessage().getContent());
}
/**
* tools流式输出使用示例
*/
@Test
public void streamToolsChat() {
CountDownLatch countDownLatch = new CountDownLatch(1);
ConsoleEventSourceListenerV3 eventSourceListener = new ConsoleEventSourceListenerV3(countDownLatch);
Message message = Message.builder().role(Message.Role.USER).content("给我输出一个长度为2的中文词语并解释下词语对应物品的用途").build();
//属性一
JSONObject wordLength = new JSONObject();
wordLength.put("type", "number");
wordLength.put("description", "词语的长度");
//属性二
JSONObject language = new JSONObject();
language.put("type", "string");
language.put("enum", Arrays.asList("zh", "en"));
language.put("description", "语言类型例如zh代表中文、en代表英语");
//参数
JSONObject properties = new JSONObject();
properties.put("wordLength", wordLength);
properties.put("language", language);
Parameters parameters = Parameters.builder()
.type("object")
.properties(properties)
.required(Collections.singletonList("wordLength")).build();
Tools tools = Tools.builder()
.type(Tools.Type.FUNCTION.getName())
.function(ToolsFunction.builder().name("getOneWord").description("获取一个指定长度和语言类型的词语").parameters(parameters).build())
.build();
ChatCompletion chatCompletion = ChatCompletion
.builder()
.messages(Collections.singletonList(message))
.tools(Collections.singletonList(tools))
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
.build();
openAiStreamClient.streamChatCompletion(chatCompletion, eventSourceListener);
try {
countDownLatch.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
ToolCalls openAiReturnToolCalls = eventSourceListener.getToolCalls();
WordParam wordParam = JSONUtil.toBean(openAiReturnToolCalls.getFunction().getArguments(), WordParam.class);
String oneWord = getOneWord(wordParam);
ToolCallFunction tcf = ToolCallFunction.builder().name("getOneWord").arguments(openAiReturnToolCalls.getFunction().getArguments()).build();
ToolCalls tc = ToolCalls.builder().id(openAiReturnToolCalls.getId()).type(ToolCalls.Type.FUNCTION.getName()).function(tcf).build();
//构造tool call
Message message2 = Message.builder().role(Message.Role.ASSISTANT).content("方法参数").toolCalls(Collections.singletonList(tc)).build();
String content
= "{ " +
"\"wordLength\": \"3\", " +
"\"language\": \"zh\", " +
"\"word\": \"" + oneWord + "\"," +
"\"用途\": [\"直接吃\", \"做沙拉\", \"售卖\"]" +
"}";
Message message3 = Message.builder().toolCallId(openAiReturnToolCalls.getId()).role(Message.Role.TOOL).name("getOneWord").content(content).build();
List<Message> messageList = Arrays.asList(message, message2, message3);
ChatCompletion chatCompletionV2 = ChatCompletion
.builder()
.messages(messageList)
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
.build();
CountDownLatch countDownLatch1 = new CountDownLatch(1);
openAiStreamClient.streamChatCompletion(chatCompletionV2, new ConsoleEventSourceListenerV3(countDownLatch));
try {
countDownLatch1.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
try {
countDownLatch1.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
@Data
@Builder
static class WordParam {
private int wordLength;
@Builder.Default
private String language = "zh";
}
/**
* 获取一个词语(根据语言和字符长度查询)
* @param wordParam
* @return
*/
public String getOneWord(WordParam wordParam) {
List<String> zh = Arrays.asList("大香蕉", "哈密瓜", "苹果");
List<String> en = Arrays.asList("apple", "banana", "cantaloupe");
if (wordParam.getLanguage().equals("zh")) {
for (String e : zh) {
if (e.length() == wordParam.getWordLength()) {
return e;
}
}
}
if (wordParam.getLanguage().equals("en")) {
for (String e : en) {
if (e.length() == wordParam.getWordLength()) {
return e;
}
}
}
return "西瓜";
}
}

View File

@@ -0,0 +1,24 @@
package org.ruoyi.common.chat.demo;
import org.ruoyi.common.chat.openai.plugin.PluginAbstract;
public class WeatherPlugin extends PluginAbstract<WeatherReq, WeatherResp> {
public WeatherPlugin(Class<?> r) {
super(r);
}
@Override
public WeatherResp func(WeatherReq args) {
WeatherResp weatherResp = new WeatherResp();
weatherResp.setTemp("25到28摄氏度");
weatherResp.setLevel(3);
return weatherResp;
}
@Override
public String content(WeatherResp weatherResp) {
return "当前天气温度:" + weatherResp.getTemp() + ",风力等级:" + weatherResp.getLevel();
}
}

View File

@@ -0,0 +1,13 @@
package org.ruoyi.common.chat.demo;
import lombok.Data;
import org.ruoyi.common.chat.openai.plugin.PluginParam;
@Data
public class WeatherReq extends PluginParam {
/**
* 城市
*/
private String location;
}

View File

@@ -0,0 +1,15 @@
package org.ruoyi.common.chat.demo;
import lombok.Data;
@Data
public class WeatherResp {
/**
* 温度
*/
private String temp;
/**
* 风力等级
*/
private Integer level;
}

View File

@@ -0,0 +1,122 @@
package org.ruoyi.common.chat.demo.zhipu;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.zhipu.oapi.ClientV4;
import com.zhipu.oapi.Constants;
import com.zhipu.oapi.service.v4.deserialize.MessageDeserializeFactory;
import com.zhipu.oapi.service.v4.model.*;
import io.reactivex.Flowable;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class AllToolsTest {
private final static Logger logger = LoggerFactory.getLogger(AllToolsTest.class);
private static final String API_SECRET_KEY = "28550a39d4cfaabbbf38df04dd3931f5.IUvfTThUf0xBF5l0";
private static final ClientV4 client = new ClientV4.Builder(API_SECRET_KEY)
.networkConfig(300, 100, 100, 100, TimeUnit.SECONDS)
.connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS))
.build();
private static final ObjectMapper mapper = MessageDeserializeFactory.defaultObjectMapper();
// 请自定义自己的业务id
private static final String requestIdTemplate = "mycompany-%d";
@Test
public void test1() throws JsonProcessingException {
List<ChatMessage> messages = new ArrayList<>();
ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "帮我查询北京天气");
messages.add(chatMessage);
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
// 函数调用参数构建部分
List<ChatTool> chatToolList = new ArrayList<>();
ChatTool chatTool = new ChatTool();
chatTool.setType("code_interpreter");
ObjectNode objectNode = mapper.createObjectNode();
objectNode.put("code", "北京天气");
// chatTool.set(chatFunction);
chatToolList.add(chatTool);
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
.model("glm-4-alltools")
.stream(Boolean.TRUE)
.invokeMethod(Constants.invokeMethod)
.messages(messages)
.tools(chatToolList)
.toolChoice("auto")
.requestId(requestId)
.build();
ModelApiResponse sseModelApiResp = client.invokeModelApi(chatCompletionRequest);
if (sseModelApiResp.isSuccess()) {
AtomicBoolean isFirst = new AtomicBoolean(true);
List<Choice> choices = new ArrayList<>();
AtomicReference<ChatMessageAccumulator> lastAccumulator = new AtomicReference<>();
mapStreamToAccumulator(sseModelApiResp.getFlowable())
.doOnNext(accumulator -> {
{
if (isFirst.getAndSet(false)) {
logger.info("Response: ");
}
if (accumulator.getDelta() != null && accumulator.getDelta().getTool_calls() != null) {
String jsonString = mapper.writeValueAsString(accumulator.getDelta().getTool_calls());
logger.info("tool_calls: {}", jsonString);
}
if (accumulator.getDelta() != null && accumulator.getDelta().getContent() != null) {
logger.info(accumulator.getDelta().getContent());
}
choices.add(accumulator.getChoice());
lastAccumulator.set(accumulator);
}
})
.doOnComplete(() -> System.out.println("Stream completed."))
.doOnError(throwable -> System.err.println("Error: " + throwable)) // Handle errors
.blockingSubscribe();// Use blockingSubscribe instead of blockingGet()
ChatMessageAccumulator chatMessageAccumulator = lastAccumulator.get();
ModelData data = new ModelData();
data.setChoices(choices);
if (chatMessageAccumulator != null) {
data.setUsage(chatMessageAccumulator.getUsage());
data.setId(chatMessageAccumulator.getId());
data.setCreated(chatMessageAccumulator.getCreated());
}
data.setRequestId(chatCompletionRequest.getRequestId());
sseModelApiResp.setFlowable(null);// 打印前置空
sseModelApiResp.setData(data);
}
logger.info("model output: {}", mapper.writeValueAsString(sseModelApiResp));
client.getConfig().getHttpClient().dispatcher().executorService().shutdown();
client.getConfig().getHttpClient().connectionPool().evictAll();
// List all active threads
for (Thread t : Thread.getAllStackTraces().keySet()) {
logger.info("Thread: " + t.getName() + " State: " + t.getState());
}
}
public static Flowable<ChatMessageAccumulator> mapStreamToAccumulator(Flowable<ModelData> flowable) {
return flowable.map(chunk -> {
return new ChatMessageAccumulator(chunk.getChoices().get(0).getDelta(), null, chunk.getChoices().get(0), chunk.getUsage(), chunk.getCreated(), chunk.getId());
});
}
}

View File

@@ -0,0 +1,246 @@
package org.ruoyi.common.chat.demo.zhipu;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.zhipu.oapi.ClientV4;
import com.zhipu.oapi.Constants;
import com.zhipu.oapi.service.v4.tools.*;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.PropertyNamingStrategy;
import com.zhipu.oapi.core.response.HttpxBinaryResponseContent;
import com.zhipu.oapi.service.v4.batchs.BatchCreateParams;
import com.zhipu.oapi.service.v4.batchs.BatchResponse;
import com.zhipu.oapi.service.v4.batchs.QueryBatchResponse;
import com.zhipu.oapi.service.v4.embedding.EmbeddingApiResponse;
import com.zhipu.oapi.service.v4.embedding.EmbeddingRequest;
import com.zhipu.oapi.service.v4.file.*;
import com.zhipu.oapi.service.v4.fine_turning.*;
import com.zhipu.oapi.service.v4.image.CreateImageRequest;
import com.zhipu.oapi.service.v4.image.ImageApiResponse;
import com.zhipu.oapi.service.v4.model.*;
import io.reactivex.Flowable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
public class WebSearchToolsTest {
private final static Logger logger = LoggerFactory.getLogger(WebSearchToolsTest.class);
private static final String API_SECRET_KEY = "xx";
private static final ClientV4 client = new ClientV4.Builder(API_SECRET_KEY)
.networkConfig(300, 100, 100, 100, TimeUnit.SECONDS)
.connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS))
.build();
private static final ObjectMapper mapper = new ObjectMapper();
// 请自定义自己的业务id
private static final String requestIdTemplate = "mycompany-%d";
@Test
public void test1() throws JsonProcessingException {
// json 转换 ArrayList<SearchChatMessage>
String jsonString = "[\n" +
" {\n" +
" \"content\": \"今天武汉天气怎么样\",\n" +
" \"role\": \"user\"\n" +
" }\n" +
" ]";
ArrayList<SearchChatMessage> messages = new ObjectMapper().readValue(jsonString, new TypeReference<ArrayList<SearchChatMessage>>() {
});
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
WebSearchParamsRequest chatCompletionRequest = WebSearchParamsRequest.builder()
.model("web-search-pro")
.stream(Boolean.TRUE)
.messages(messages)
.requestId(requestId)
.build();
WebSearchApiResponse webSearchApiResponse = client.webSearchProStreamingInvoke(chatCompletionRequest);
if (webSearchApiResponse.isSuccess()) {
AtomicBoolean isFirst = new AtomicBoolean(true);
List<ChoiceDelta> choices = new ArrayList<>();
AtomicReference<WebSearchPro> lastAccumulator = new AtomicReference<>();
webSearchApiResponse.getFlowable().map(result -> result)
.doOnNext(accumulator -> {
{
if (isFirst.getAndSet(false)) {
logger.info("Response: ");
}
ChoiceDelta delta = accumulator.getChoices().get(0).getDelta();
if (delta != null && delta.getToolCalls() != null) {
logger.info("tool_calls: {}", mapper.writeValueAsString(delta.getToolCalls()));
}
choices.add(delta);
lastAccumulator.set(accumulator);
}
})
.doOnComplete(() -> System.out.println("Stream completed."))
.doOnError(throwable -> System.err.println("Error: " + throwable)) // Handle errors
.blockingSubscribe();// Use blockingSubscribe instead of blockingGet()
WebSearchPro chatMessageAccumulator = lastAccumulator.get();
webSearchApiResponse.setFlowable(null);// 打印前置空
webSearchApiResponse.setData(chatMessageAccumulator);
}
logger.info("model output: {}", mapper.writeValueAsString(webSearchApiResponse));
client.getConfig().getHttpClient().dispatcher().executorService().shutdown();
client.getConfig().getHttpClient().connectionPool().evictAll();
// List all active threads
for (Thread t : Thread.getAllStackTraces().keySet()) {
logger.info("Thread: " + t.getName() + " State: " + t.getState());
}
}
@Test
public void test2() throws JsonProcessingException {
// json 转换 ArrayList<SearchChatMessage>
String jsonString = "[\n" +
" {\n" +
" \"content\": \"今天天气怎么样\",\n" +
" \"role\": \"user\"\n" +
" }\n" +
" ]";
ArrayList<SearchChatMessage> messages = new ObjectMapper().readValue(jsonString, new TypeReference<ArrayList<SearchChatMessage>>() {
});
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
WebSearchParamsRequest chatCompletionRequest = WebSearchParamsRequest.builder()
.model("web-search-pro")
.stream(Boolean.FALSE)
.messages(messages)
.requestId(requestId)
.build();
WebSearchApiResponse webSearchApiResponse = client.invokeWebSearchPro(chatCompletionRequest);
logger.info("model output: {}", mapper.writeValueAsString(webSearchApiResponse));
}
@Test
public void testFunctionSSE() throws JsonProcessingException {
List<ChatMessage> messages = new ArrayList<>();
ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "成都到北京要多久,天气如何");
messages.add(chatMessage);
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
// 函数调用参数构建部分
List<ChatTool> chatToolList = new ArrayList<>();
ChatTool chatTool = new ChatTool();
chatTool.setType(ChatToolType.FUNCTION.value());
ChatFunctionParameters chatFunctionParameters = new ChatFunctionParameters();
chatFunctionParameters.setType("object");
Map<String, Object> properties = new HashMap<>();
properties.put("location", new HashMap<String, Object>() {{
put("type", "string");
put("description", "城市,如:北京");
}});
properties.put("unit", new HashMap<String, Object>() {{
put("type", "string");
put("enum", new ArrayList<String>() {{
add("celsius");
add("fahrenheit");
}});
}});
chatFunctionParameters.setProperties(properties);
ChatFunction chatFunction = ChatFunction.builder()
.name("get_weather")
.description("Get the current weather of a location")
.parameters(chatFunctionParameters)
.build();
chatTool.setFunction(chatFunction);
chatToolList.add(chatTool);
HashMap<String, Object> extraJson = new HashMap<>();
extraJson.put("temperature", 0.5);
extraJson.put("max_tokens", 50);
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
.model(Constants.ModelChatGLM4)
.stream(Boolean.TRUE)
.messages(messages)
.requestId(requestId)
.tools(chatToolList)
.toolChoice("auto")
.extraJson(extraJson)
.build();
ModelApiResponse sseModelApiResp = client.invokeModelApi(chatCompletionRequest);
if (sseModelApiResp.isSuccess()) {
AtomicBoolean isFirst = new AtomicBoolean(true);
List<Choice> choices = new ArrayList<>();
ChatMessageAccumulator chatMessageAccumulator = mapStreamToAccumulator(sseModelApiResp.getFlowable())
.doOnNext(accumulator -> {
{
if (isFirst.getAndSet(false)) {
logger.info("Response: ");
}
if (accumulator.getDelta() != null && accumulator.getDelta().getTool_calls() != null) {
String jsonString = mapper.writeValueAsString(accumulator.getDelta().getTool_calls());
logger.info("tool_calls: {}", jsonString);
}
if (accumulator.getDelta() != null && accumulator.getDelta().getContent() != null) {
logger.info(accumulator.getDelta().getContent());
}
choices.add(accumulator.getChoice());
}
})
.doOnComplete(System.out::println)
.lastElement()
.blockingGet();
ModelData data = new ModelData();
data.setChoices(choices);
data.setUsage(chatMessageAccumulator.getUsage());
data.setId(chatMessageAccumulator.getId());
data.setCreated(chatMessageAccumulator.getCreated());
data.setRequestId(chatCompletionRequest.getRequestId());
sseModelApiResp.setFlowable(null);// 打印前置空
sseModelApiResp.setData(data);
}
logger.info("model output: {}", mapper.writeValueAsString(sseModelApiResp));
}
public static Flowable<ChatMessageAccumulator> mapStreamToAccumulator(Flowable<ModelData> flowable) {
return flowable.map(chunk -> {
return new ChatMessageAccumulator(chunk.getChoices().get(0).getDelta(), null, chunk.getChoices().get(0), chunk.getUsage(), chunk.getCreated(), chunk.getId());
});
}
}

View File

@@ -2,12 +2,11 @@ package org.ruoyi.common.chat.entity.chat;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.Getter;
import org.ruoyi.common.chat.entity.chat.tool.ToolCalls;
import java.io.Serializable;
import java.util.List;
/**
* 描述:
@@ -18,21 +17,10 @@ import java.io.Serializable;
@Data
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonIgnoreProperties(ignoreUnknown = true)
public class Message implements Serializable {
/**
* 目前支持四个中角色参考官网,进行情景输入:
* https://platform.openai.com/docs/guides/chat/introduction
*/
private String role;
public class Message extends BaseMessage implements Serializable {
private Object content;
private String name;
@JsonProperty("function_call")
private FunctionCall functionCall;
public static Builder builder() {
return new Builder();
}
@@ -41,44 +29,37 @@ public class Message implements Serializable {
* 构造函数
*
* @param role 角色
* @param content 描述主题信息
* @param name name
* @param content content
* @param functionCall functionCall
*/
public Message(String role, String content, String name, FunctionCall functionCall) {
this.role = role;
public Message(String role, String name, String content, List<ToolCalls> toolCalls, String toolCallId, FunctionCall functionCall) {
this.content = content;
this.name = name;
this.functionCall = functionCall;
super.setRole(role);
super.setName(name);
super.setToolCalls(toolCalls);
super.setToolCallId(toolCallId);
super.setFunctionCall(functionCall);
}
public Message() {
}
private Message(Builder builder) {
setRole(builder.role);
setContent(builder.content);
setName(builder.name);
setFunctionCall(builder.functionCall);
}
@Getter
@AllArgsConstructor
public enum Role {
SYSTEM("system"),
USER("user"),
ASSISTANT("assistant"),
FUNCTION("function"),
;
private String name;
super.setRole(builder.role);
super.setName(builder.name);
super.setFunctionCall(builder.functionCall);
super.setToolCalls(builder.toolCalls);
super.setToolCallId(builder.toolCallId);
}
public static final class Builder {
private String role;
private String content;
private String name;
private String toolCallId;
private List<ToolCalls> toolCalls;
private FunctionCall functionCall;
public Builder() {
@@ -109,6 +90,16 @@ public class Message implements Serializable {
return this;
}
public Builder toolCalls(List<ToolCalls> toolCalls) {
this.toolCalls = toolCalls;
return this;
}
public Builder toolCallId(String toolCallId) {
this.toolCallId = toolCallId;
return this;
}
public Message build() {
return new Message(this);
}

View File

@@ -24,5 +24,6 @@ public class ResponseFormat {
TEXT("text"),
;
private final String name;
}
}

View File

@@ -75,6 +75,8 @@ public class WebSocketEventListener extends EventSourceListener {
return;
}
ResponseBody body = response.body();
if (Objects.nonNull(body)) {
// 返回非流式回复内容
if(response.code() == OpenAIConst.SUCCEED_CODE){

View File

@@ -2,6 +2,7 @@ package org.ruoyi.common.chat.openai;
import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import io.reactivex.Single;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
@@ -12,9 +13,7 @@ import okhttp3.RequestBody;
import org.ruoyi.common.chat.constant.OpenAIConst;
import org.ruoyi.common.chat.entity.billing.BillingUsage;
import org.ruoyi.common.chat.entity.billing.Subscription;
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
import org.ruoyi.common.chat.entity.chat.Message;
import org.ruoyi.common.chat.entity.chat.*;
import org.ruoyi.common.chat.entity.common.DeleteResponse;
import org.ruoyi.common.chat.entity.common.OpenAiResponse;
import org.ruoyi.common.chat.entity.completions.Completion;
@@ -43,6 +42,8 @@ import org.ruoyi.common.chat.openai.function.KeyStrategyFunction;
import org.ruoyi.common.chat.openai.interceptor.DefaultOpenAiAuthInterceptor;
import org.ruoyi.common.chat.openai.interceptor.DynamicKeyOpenAiAuthInterceptor;
import org.ruoyi.common.chat.openai.interceptor.OpenAiAuthInterceptor;
import org.ruoyi.common.chat.openai.plugin.PluginAbstract;
import org.ruoyi.common.chat.openai.plugin.PluginParam;
import org.ruoyi.common.core.exception.base.BaseException;
import org.jetbrains.annotations.NotNull;
import retrofit2.Retrofit;
@@ -696,6 +697,90 @@ public class OpenAiClient {
return whisperResponse.blockingGet();
}
/**
* 插件问答简易版
* 默认取messages最后一个元素构建插件对话
* 默认模型ChatCompletion.Model.GPT_3_5_TURBO_16K_0613
*
* @param chatCompletion 参数
* @param plugin 插件
* @param <R> 插件自定义函数的请求值
* @param <T> 插件自定义函数的返回值
* @return ChatCompletionResponse
*/
public <R extends PluginParam, T> ChatCompletionResponse chatCompletionWithPlugin(ChatCompletion chatCompletion, PluginAbstract<R, T> plugin) {
if (Objects.isNull(plugin)) {
return this.chatCompletion(chatCompletion);
}
if (CollectionUtil.isEmpty(chatCompletion.getMessages())) {
throw new BaseException(CommonError.MESSAGE_NOT_NUL.msg());
}
List<Message> messages = chatCompletion.getMessages();
Functions functions = Functions.builder()
.name(plugin.getFunction())
.description(plugin.getDescription())
.parameters(plugin.getParameters())
.build();
//没有值,设置默认值
if (Objects.isNull(chatCompletion.getFunctionCall())) {
chatCompletion.setFunctionCall("auto");
}
//tip: 覆盖自己设置的functions参数使用plugin构造的functions
chatCompletion.setFunctions(Collections.singletonList(functions));
//调用OpenAi
ChatCompletionResponse functionCallChatCompletionResponse = this.chatCompletion(chatCompletion);
ChatChoice chatChoice = functionCallChatCompletionResponse.getChoices().get(0);
log.debug("构造的方法值:{}", chatChoice.getMessage().getFunctionCall());
R realFunctionParam = (R) JSONUtil.toBean(chatChoice.getMessage().getFunctionCall().getArguments(), plugin.getR());
T tq = plugin.func(realFunctionParam);
FunctionCall functionCall = FunctionCall.builder()
.arguments(chatChoice.getMessage().getFunctionCall().getArguments())
.name(plugin.getFunction())
.build();
messages.add(Message.builder().role(Message.Role.ASSISTANT).content("function_call").functionCall(functionCall).build());
messages.add(Message.builder().role(Message.Role.FUNCTION).name(plugin.getFunction()).content(plugin.content(tq)).build());
//设置第二次,请求的参数
chatCompletion.setFunctionCall(null);
chatCompletion.setFunctions(null);
ChatCompletionResponse chatCompletionResponse = this.chatCompletion(chatCompletion);
log.debug("自定义的方法返回值:{}", chatCompletionResponse.getChoices());
return chatCompletionResponse;
}
/**
* 插件问答简易版
* 默认取messages最后一个元素构建插件对话
* 默认模型ChatCompletion.Model.GPT_3_5_TURBO_16K_0613
*
* @param messages 问答参数
* @param plugin 插件
* @param <R> 插件自定义函数的请求值
* @param <T> 插件自定义函数的返回值
* @return ChatCompletionResponse
*/
public <R extends PluginParam, T> ChatCompletionResponse chatCompletionWithPlugin(List<Message> messages, PluginAbstract<R, T> plugin) {
return chatCompletionWithPlugin(messages, ChatCompletion.Model.GPT_3_5_TURBO_16K_0613.getName(), plugin);
}
/**
* 插件问答简易版
* 默认取messages最后一个元素构建插件对话
*
* @param messages 问答参数
* @param model 模型
* @param plugin 插件
* @param <R> 插件自定义函数的请求值
* @param <T> 插件自定义函数的返回值
* @return ChatCompletionResponse
*/
public <R extends PluginParam, T> ChatCompletionResponse chatCompletionWithPlugin(List<Message> messages, String model, PluginAbstract<R, T> plugin) {
ChatCompletion chatCompletion = ChatCompletion.builder().messages(messages).model(model).build();
return this.chatCompletionWithPlugin(chatCompletion, plugin);
}
/**
* 简易版 语音翻译:目前仅支持翻译为英文
*

View File

@@ -3,6 +3,7 @@ package org.ruoyi.common.chat.openai;
import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.http.ContentType;
import cn.hutool.json.JSONUtil;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.reactivex.Single;
import lombok.Getter;
@@ -17,10 +18,7 @@ import org.ruoyi.common.chat.entity.Tts.TextToSpeech;
import org.ruoyi.common.chat.entity.billing.BillingUsage;
import org.ruoyi.common.chat.entity.billing.KeyInfo;
import org.ruoyi.common.chat.entity.billing.Subscription;
import org.ruoyi.common.chat.entity.chat.BaseChatCompletion;
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
import org.ruoyi.common.chat.entity.chat.ChatCompletionWithPicture;
import org.ruoyi.common.chat.entity.chat.*;
import org.ruoyi.common.chat.entity.embeddings.Embedding;
import org.ruoyi.common.chat.entity.embeddings.EmbeddingResponse;
import org.ruoyi.common.chat.entity.files.UploadFileResponse;
@@ -29,6 +27,7 @@ import org.ruoyi.common.chat.entity.images.ImageResponse;
import org.ruoyi.common.chat.entity.models.Model;
import org.ruoyi.common.chat.entity.models.ModelResponse;
import org.ruoyi.common.chat.entity.whisper.Transcriptions;
import org.ruoyi.common.chat.entity.whisper.Translations;
import org.ruoyi.common.chat.entity.whisper.WhisperResponse;
import org.ruoyi.common.chat.openai.exception.CommonError;
import org.ruoyi.common.chat.openai.function.KeyRandomStrategy;
@@ -36,6 +35,10 @@ import org.ruoyi.common.chat.openai.function.KeyStrategyFunction;
import org.ruoyi.common.chat.openai.interceptor.DefaultOpenAiAuthInterceptor;
import org.ruoyi.common.chat.openai.interceptor.DynamicKeyOpenAiAuthInterceptor;
import org.ruoyi.common.chat.openai.interceptor.OpenAiAuthInterceptor;
import org.ruoyi.common.chat.openai.plugin.PluginAbstract;
import org.ruoyi.common.chat.openai.plugin.PluginParam;
import org.ruoyi.common.chat.sse.DefaultPluginListener;
import org.ruoyi.common.chat.sse.PluginListener;
import org.ruoyi.common.core.exception.base.BaseException;
import org.jetbrains.annotations.NotNull;
import retrofit2.Call;
@@ -186,6 +189,93 @@ public class OpenAiStreamClient {
}
}
/**
* 插件问答简易版
* 默认取messages最后一个元素构建插件对话
* 默认模型ChatCompletion.Model.GPT_3_5_TURBO_16K_0613
*
* @param chatCompletion 参数
* @param eventSourceListener sse监听器
* @param pluginEventSourceListener 插件sse监听器收集function call返回信息
* @param plugin 插件
* @param <R> 插件自定义函数的请求值
* @param <T> 插件自定义函数的返回值
*/
public <R extends PluginParam, T> void streamChatCompletionWithPlugin(ChatCompletion chatCompletion, EventSourceListener eventSourceListener, PluginListener pluginEventSourceListener, PluginAbstract<R, T> plugin) {
if (Objects.isNull(plugin)) {
this.streamChatCompletion(chatCompletion, eventSourceListener);
return;
}
if (CollectionUtil.isEmpty(chatCompletion.getMessages())) {
throw new BaseException(CommonError.MESSAGE_NOT_NUL.msg());
}
Functions functions = Functions.builder()
.name(plugin.getFunction())
.description(plugin.getDescription())
.parameters(plugin.getParameters())
.build();
//没有值,设置默认值
if (Objects.isNull(chatCompletion.getFunctionCall())) {
chatCompletion.setFunctionCall("auto");
}
//tip: 覆盖自己设置的functions参数使用plugin构造的functions
chatCompletion.setFunctions(Collections.singletonList(functions));
//调用OpenAi
if (Objects.isNull(pluginEventSourceListener)) {
pluginEventSourceListener = new DefaultPluginListener(this, eventSourceListener, plugin, chatCompletion);
}
this.streamChatCompletion(chatCompletion, pluginEventSourceListener);
}
/**
* 插件问答简易版
* 默认取messages最后一个元素构建插件对话
* 默认模型ChatCompletion.Model.GPT_3_5_TURBO_16K_0613
*
* @param chatCompletion 参数
* @param eventSourceListener sse监听器
* @param plugin 插件
* @param <R> 插件自定义函数的请求值
* @param <T> 插件自定义函数的返回值
*/
public <R extends PluginParam, T> void streamChatCompletionWithPlugin(ChatCompletion chatCompletion, EventSourceListener eventSourceListener, PluginAbstract<R, T> plugin) {
PluginListener pluginEventSourceListener = new DefaultPluginListener(this, eventSourceListener, plugin, chatCompletion);
this.streamChatCompletionWithPlugin(chatCompletion, eventSourceListener, pluginEventSourceListener, plugin);
}
/**
* 插件问答简易版
* 默认取messages最后一个元素构建插件对话
* 默认模型ChatCompletion.Model.GPT_3_5_TURBO_16K_0613
*
* @param messages 问答参数
* @param eventSourceListener sse监听器
* @param plugin 插件
* @param <R> 插件自定义函数的请求值
* @param <T> 插件自定义函数的返回值
*/
public <R extends PluginParam, T> void streamChatCompletionWithPlugin(List<Message> messages, EventSourceListener eventSourceListener, PluginAbstract<R, T> plugin) {
this.streamChatCompletionWithPlugin(messages, ChatCompletion.Model.GPT_3_5_TURBO_16K_0613.getName(), eventSourceListener, plugin);
}
/**
* 插件问答简易版
* 默认取messages最后一个元素构建插件对话
*
* @param messages 问答参数
* @param model 模型
* @param eventSourceListener eventSourceListener
* @param plugin 插件
* @param <R> 插件自定义函数的请求值
* @param <T> 插件自定义函数的返回值
*/
public <R extends PluginParam, T> void streamChatCompletionWithPlugin(List<Message> messages, String model, EventSourceListener eventSourceListener, PluginAbstract<R, T> plugin) {
ChatCompletion chatCompletion = ChatCompletion.builder().messages(messages).model(model).build();
this.streamChatCompletionWithPlugin(chatCompletion, eventSourceListener, plugin);
}
/**
* 根据描述生成图片
@@ -418,6 +508,95 @@ public class OpenAiStreamClient {
}
}
/**
* 插件问答简易版
* 默认取messages最后一个元素构建插件对话
* 默认模型ChatCompletion.Model.GPT_3_5_TURBO_16K_0613
*
* @param chatCompletion 参数
* @param plugin 插件
* @param <R> 插件自定义函数的请求值
* @param <T> 插件自定义函数的返回值
* @return ChatCompletionResponse
*/
public <R extends PluginParam, T> ChatCompletionResponse chatCompletionWithPlugin(ChatCompletion chatCompletion, PluginAbstract<R, T> plugin) {
if (Objects.isNull(plugin)) {
return this.chatCompletion(chatCompletion);
}
if (CollectionUtil.isEmpty(chatCompletion.getMessages())) {
throw new BaseException(CommonError.MESSAGE_NOT_NUL.msg());
}
List<Message> messages = chatCompletion.getMessages();
Functions functions = Functions.builder()
.name(plugin.getFunction())
.description(plugin.getDescription())
.parameters(plugin.getParameters())
.build();
//没有值,设置默认值
if (Objects.isNull(chatCompletion.getFunctionCall())) {
chatCompletion.setFunctionCall("auto");
}
//tip: 覆盖自己设置的functions参数使用plugin构造的functions
chatCompletion.setFunctions(Collections.singletonList(functions));
//调用OpenAi
ChatCompletionResponse functionCallChatCompletionResponse = this.chatCompletion(chatCompletion);
ChatChoice chatChoice = functionCallChatCompletionResponse.getChoices().get(0);
log.debug("构造的方法值:{}", chatChoice.getMessage().getFunctionCall());
R realFunctionParam = (R) JSONUtil.toBean(chatChoice.getMessage().getFunctionCall().getArguments(), plugin.getR());
T tq = plugin.func(realFunctionParam);
FunctionCall functionCall = FunctionCall.builder()
.arguments(chatChoice.getMessage().getFunctionCall().getArguments())
.name(plugin.getFunction())
.build();
messages.add(Message.builder().role(Message.Role.ASSISTANT).content("function_call").functionCall(functionCall).build());
messages.add(Message.builder().role(Message.Role.FUNCTION).name(plugin.getFunction()).content(plugin.content(tq)).build());
//设置第二次,请求的参数
chatCompletion.setFunctionCall(null);
chatCompletion.setFunctions(null);
ChatCompletionResponse chatCompletionResponse = this.chatCompletion(chatCompletion);
log.debug("自定义的方法返回值:{}", chatCompletionResponse.getChoices());
return chatCompletionResponse;
}
/**
* 插件问答简易版
* 默认取messages最后一个元素构建插件对话
* 默认模型ChatCompletion.Model.GPT_3_5_TURBO_16K_0613
*
* @param messages 问答参数
* @param plugin 插件
* @param <R> 插件自定义函数的请求值
* @param <T> 插件自定义函数的返回值
* @return ChatCompletionResponse
*/
public <R extends PluginParam, T> ChatCompletionResponse chatCompletionWithPlugin(List<Message> messages, PluginAbstract<R, T> plugin) {
return chatCompletionWithPlugin(messages, ChatCompletion.Model.GPT_3_5_TURBO_16K_0613.getName(), plugin);
}
/**
* 插件问答简易版
* 默认取messages最后一个元素构建插件对话
*
* @param messages 问答参数
* @param model 模型
* @param plugin 插件
* @param <R> 插件自定义函数的请求值
* @param <T> 插件自定义函数的返回值
* @return ChatCompletionResponse
*/
public <R extends PluginParam, T> ChatCompletionResponse chatCompletionWithPlugin(List<Message> messages, String model, PluginAbstract<R, T> plugin) {
ChatCompletion chatCompletion = ChatCompletion.builder().messages(messages).model(model).build();
return this.chatCompletionWithPlugin(chatCompletion, plugin);
}
/**
* 构造

View File

@@ -7,6 +7,7 @@ package org.ruoyi.common.chat.openai.exception;
* 2023-02-11
*/
public enum CommonError implements IError {
MESSAGE_NOT_NUL(500, "Message 不能为空"),
API_KEYS_NOT_NUL(500, "API KEYS 不能为空"),
NO_ACTIVE_API_KEYS(500, "没有可用的API KEYS"),
SYS_ERROR(500, "系统繁忙"),
@@ -19,8 +20,8 @@ public enum CommonError implements IError {
;
private int code;
private String msg;
private final int code;
private final String msg;
CommonError(int code, String msg) {
this.code = code;

View File

@@ -0,0 +1,88 @@
package org.ruoyi.common.chat.openai.plugin;
import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.json.JSONObject;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Data;
import org.ruoyi.common.chat.entity.chat.Parameters;
import java.util.List;
import java.util.stream.Collectors;
@Data
@AllArgsConstructor
public abstract class PluginAbstract<R extends PluginParam, T> {
private Class<?> R;
private String name;
private String function;
private String description;
private List<Arg> args;
private List<String> required;
private Parameters parameters;
public PluginAbstract(Class<?> r) {
R = r;
}
public void setRequired(List<String> required) {
if (CollectionUtil.isEmpty(required)) {
this.required = this.getArgs().stream().filter(e -> e.isRequired()).map(Arg::getName).collect(Collectors.toList());
return;
}
this.required = required;
}
private void setRequired() {
if (CollectionUtil.isEmpty(required)) {
this.required = this.getArgs().stream().filter(e -> e.isRequired()).map(Arg::getName).collect(Collectors.toList());
}
}
private void setParameters() {
JSONObject properties = new JSONObject();
args.forEach(e -> {
JSONObject param = new JSONObject();
param.putOpt("type", e.getType());
param.putOpt("enum", e.getEnumDictValue());
param.putOpt("description", e.getDescription());
properties.putOpt(e.getName(), param);
});
this.parameters = Parameters.builder()
.type("object")
.properties(properties)
.required(this.getRequired())
.build();
}
public void setArgs(List<Arg> args) {
this.args = args;
setRequired();
setParameters();
}
@Data
public static class Arg {
private String name;
private String type;
private String description;
@JsonIgnore
private boolean enumDict;
@JsonProperty("enum")
private List<String> enumDictValue;
@JsonIgnore
private boolean required;
}
public abstract T func(R args);
public abstract String content(T t);
}

View File

@@ -0,0 +1,7 @@
package org.ruoyi.common.chat.openai.plugin;
import lombok.Data;
@Data
public class PluginParam {
}

View File

@@ -0,0 +1,36 @@
package org.ruoyi.common.chat.plugin;
import org.ruoyi.common.chat.openai.plugin.PluginAbstract;
import java.io.IOException;
public class CmdPlugin extends PluginAbstract<CmdReq, CmdResp> {
public CmdPlugin(Class<?> r) {
super(r);
}
@Override
public CmdResp func(CmdReq args) {
try {
if("计算器".equals(args.getCmd())){
Runtime.getRuntime().exec("calc");
}else if("记事本".equals(args.getCmd())){
Runtime.getRuntime().exec("notepad");
}else if("命令行".equals(args.getCmd())){
String [] cmd={"cmd","/C","start copy exel exe2"};
Runtime.getRuntime().exec(cmd);
}
} catch (IOException e) {
throw new RuntimeException("指令执行失败");
}
CmdResp resp = new CmdResp();
resp.setResult(args.getCmd()+"指令执行成功!");
return resp;
}
@Override
public String content(CmdResp resp) {
return resp.getResult();
}
}

View File

@@ -0,0 +1,13 @@
package org.ruoyi.common.chat.plugin;
import lombok.Data;
import org.ruoyi.common.chat.openai.plugin.PluginParam;
@Data
public class CmdReq extends PluginParam {
/**
* 指令
*/
private String cmd;
}

View File

@@ -0,0 +1,12 @@
package org.ruoyi.common.chat.plugin;
import lombok.Data;
@Data
public class CmdResp {
/**
* 返回结果
*/
private String result;
}

View File

@@ -0,0 +1,88 @@
package org.ruoyi.common.chat.plugin;
import org.ruoyi.common.chat.openai.plugin.PluginAbstract;
import java.sql.*;
/**
* @author ageer
*/
public class SqlPlugin extends PluginAbstract<SqlReq, SqlResp> {
public SqlPlugin(Class<?> r) {
super(r);
}
@Override
public SqlResp func(SqlReq args) {
SqlResp resp = new SqlResp();
resp.setUserBalance(getBalance(args.getUsername()));
return resp;
}
@Override
public String content(SqlResp resp) {
return "用户余额:"+resp.getUserBalance();
}
public String getBalance(String userName) {
// MySQL 8.0 以下版本 - JDBC 驱动名及数据库 URL
String JDBC_DRIVER = "com.mysql.cj.jdbc.Driver";
String DB_URL = "jdbc:mysql://43.139.70.230:3306/ry-vue";
// 数据库的用户名与密码,需要根据自己的设置
String USER = "ry-vue";
String PASS = "BXZiGsY35K523Xfx";
Connection conn = null;
Statement stmt = null;
String balance = "0.1";
try{
// 注册 JDBC 驱动
Class.forName(JDBC_DRIVER);
// 打开链接
System.out.println("连接数据库...");
conn = DriverManager.getConnection(DB_URL,USER,PASS);
// 执行查询
System.out.println(" 实例化Statement对象...");
stmt = conn.createStatement();
String sql;
sql = "SELECT user_balance FROM sys_user where user_name ='" + userName + "'";
ResultSet rs = stmt.executeQuery(sql);
// 展开结果集数据库
while(rs.next()){
// 通过字段检索
balance = rs.getString("user_balance");
// 输出数据
System.out.print("余额: " + balance);
System.out.print("\n");
}
// 完成后关闭
rs.close();
stmt.close();
conn.close();
}catch(SQLException se){
// 处理 JDBC 错误
se.printStackTrace();
}catch(Exception e){
// 处理 Class.forName 错误
e.printStackTrace();
}finally{
// 关闭资源
try{
if(stmt!=null) stmt.close();
}catch(SQLException se2){
}// 什么都不做
try{
if(conn!=null) conn.close();
}catch(SQLException se){
se.printStackTrace();
}
}
return balance;
}
}

View File

@@ -0,0 +1,13 @@
package org.ruoyi.common.chat.plugin;
import lombok.Data;
import org.ruoyi.common.chat.openai.plugin.PluginParam;
@Data
public class SqlReq extends PluginParam {
/**
* 用户名称
*/
private String username;
}

View File

@@ -0,0 +1,12 @@
package org.ruoyi.common.chat.plugin;
import lombok.Data;
@Data
public class SqlResp {
/**
* 用户余额
*/
private String userBalance;
}

View File

@@ -0,0 +1,56 @@
package org.ruoyi.common.chat.sse;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.ResponseBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import java.util.Objects;
/**
* 描述: sse
*
* @author https:www.unfbx.com
* 2023-02-28
*/
@Slf4j
public class ConsoleEventSourceListener extends EventSourceListener {
@Override
public void onOpen(EventSource eventSource, Response response) {
log.info("OpenAI建立sse连接...");
}
@Override
public void onEvent(EventSource eventSource, String id, String type, String data) {
log.info("OpenAI返回数据{}", data);
if ("[DONE]".equals(data)) {
log.info("OpenAI返回数据结束了");
return;
}
}
@Override
public void onClosed(EventSource eventSource) {
log.info("OpenAI关闭sse连接...");
}
@SneakyThrows
@Override
public void onFailure(EventSource eventSource, Throwable t, Response response) {
if(Objects.isNull(response)){
log.error("OpenAI sse连接异常:{}", t);
eventSource.cancel();
return;
}
ResponseBody body = response.body();
if (Objects.nonNull(body)) {
log.error("OpenAI sse连接异常data{},异常:{}", body.string(), t);
} else {
log.error("OpenAI sse连接异常data{},异常:{}", response, t);
}
eventSource.cancel();
}
}

View File

@@ -0,0 +1,22 @@
package org.ruoyi.common.chat.sse;
import lombok.extern.slf4j.Slf4j;
import okhttp3.sse.EventSourceListener;
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
import org.ruoyi.common.chat.openai.plugin.PluginAbstract;
/**
* 描述: 插件开发返回信息收集sse监听器
*
* @author https:www.unfbx.com
* 2023-08-18
*/
@Slf4j
public class DefaultPluginListener extends PluginListener {
public DefaultPluginListener(OpenAiStreamClient client, EventSourceListener eventSourceListener, PluginAbstract plugin, ChatCompletion chatCompletion) {
super(client, eventSourceListener, plugin, chatCompletion);
}
}

View File

@@ -0,0 +1,126 @@
package org.ruoyi.common.chat.sse;
import cn.hutool.json.JSONUtil;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.ResponseBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import org.jetbrains.annotations.NotNull;
import org.ruoyi.common.chat.constant.OpenAIConst;
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
import org.ruoyi.common.chat.entity.chat.FunctionCall;
import org.ruoyi.common.chat.entity.chat.Message;
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
import org.ruoyi.common.chat.openai.plugin.PluginAbstract;
import org.ruoyi.common.chat.openai.plugin.PluginParam;
import java.util.Objects;
/**
* 描述: 插件开发返回信息收集sse监听器
*
* @author https:www.unfbx.com
* 2023-08-18
*/
@Slf4j
public abstract class PluginListener<R extends PluginParam, T> extends EventSourceListener {
/**
* openAi插件构建的参数
*/
private String arguments = "";
/**
* 获取openAi插件构建的参数
*
* @return arguments
*/
private String getArguments() {
return this.arguments;
}
private OpenAiStreamClient client;
private EventSourceListener eventSourceListener;
private PluginAbstract<R, T> plugin;
private ChatCompletion chatCompletion;
/**
* 构造方法必备四个元素
*
* @param client OpenAiStreamClient
* @param eventSourceListener 处理真实第二次sse请求的自定义监听
* @param plugin 插件信息
* @param chatCompletion 请求参数
*/
public PluginListener(OpenAiStreamClient client, EventSourceListener eventSourceListener, PluginAbstract<R, T> plugin, ChatCompletion chatCompletion) {
this.client = client;
this.eventSourceListener = eventSourceListener;
this.plugin = plugin;
this.chatCompletion = chatCompletion;
}
/**
* sse关闭后处理第二次请求方法
*/
public void onClosedAfter() {
log.debug("构造的方法值:{}", getArguments());
R realFunctionParam = (R) JSONUtil.toBean(getArguments(), plugin.getR());
T tq = plugin.func(realFunctionParam);
FunctionCall functionCall = FunctionCall.builder()
.arguments(getArguments())
.name(plugin.getFunction())
.build();
chatCompletion.getMessages().add(Message.builder().role(Message.Role.ASSISTANT).content("function_call").functionCall(functionCall).build());
chatCompletion.getMessages().add(Message.builder().role(Message.Role.FUNCTION).name(plugin.getFunction()).content(plugin.content(tq)).build());
//设置第二次,请求的参数
chatCompletion.setFunctionCall(null);
chatCompletion.setFunctions(null);
client.streamChatCompletion(chatCompletion, eventSourceListener);
}
@SneakyThrows
@Override
public final void onEvent(@NotNull EventSource eventSource, String id, String type, String data) {
log.debug("插件开发返回信息收集sse监听器返回数据{}", data);
if ("[DONE]".equals(data)) {
log.debug("插件开发返回信息收集sse监听器返回数据结束了");
return;
}
ChatCompletionResponse chatCompletionResponse = JSONUtil.toBean(data, ChatCompletionResponse.class);
if (Objects.nonNull(chatCompletionResponse.getChoices().get(0).getDelta().getFunctionCall())) {
this.arguments += chatCompletionResponse.getChoices().get(0).getDelta().getFunctionCall().getArguments();
}
}
@Override
public final void onClosed(EventSource eventSource) {
log.debug("插件开发返回信息收集sse监听器关闭连接...");
this.onClosedAfter();
}
@Override
public void onOpen(EventSource eventSource, Response response) {
log.debug("插件开发返回信息收集sse监听器建立连接...");
}
@SneakyThrows
@Override
public void onFailure(EventSource eventSource, Throwable t, Response response) {
if (Objects.isNull(response)) {
log.error("插件开发返回信息收集sse监听器,连接异常:{}", t);
eventSource.cancel();
return;
}
ResponseBody body = response.body();
if (Objects.nonNull(body)) {
log.error("插件开发返回信息收集sse监听器,连接异常data{},异常:{}", body.string(), t);
} else {
log.error("插件开发返回信息收集sse监听器,连接异常data{},异常:{}", response, t);
}
eventSource.cancel();
}
}

View File

@@ -15,4 +15,11 @@ public interface UserService {
*/
String selectUserNameById(Long userId);
/**
* 通过用户名称查询余额
*
* @param userName
* @return
*/
String selectUserByName(String userName);
}