mirror of
https://gitcode.com/ageerle/ruoyi-ai.git
synced 2026-04-16 21:33:40 +00:00
feat: 重构模块
This commit is contained in:
@@ -1,56 +0,0 @@
|
||||
package org.ruoyi.common.chat.config;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import okhttp3.OkHttpClient;
|
||||
import okhttp3.logging.HttpLoggingInterceptor;
|
||||
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
|
||||
import org.ruoyi.common.chat.openai.function.KeyRandomStrategy;
|
||||
import org.ruoyi.common.chat.openai.interceptor.OpenAILogger;
|
||||
import org.ruoyi.common.core.service.ConfigService;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* Chat配置类
|
||||
*
|
||||
* @date: 2023/5/16
|
||||
*/
|
||||
@Configuration
|
||||
@RequiredArgsConstructor
|
||||
public class ChatConfig {
|
||||
|
||||
@Getter
|
||||
private OpenAiStreamClient openAiStreamClient;
|
||||
|
||||
private final ConfigService configService;
|
||||
|
||||
// 重启才会生效
|
||||
@Bean
|
||||
public OpenAiStreamClient openAiStreamClient() {
|
||||
String apiHost = configService.getConfigValue("chat", "apiHost");
|
||||
String apiKey = configService.getConfigValue("chat", "apiKey");
|
||||
openAiStreamClient = createOpenAiStreamClient(apiHost,apiKey);
|
||||
return openAiStreamClient;
|
||||
}
|
||||
|
||||
public OpenAiStreamClient createOpenAiStreamClient(String apiHost, String apiKey) {
|
||||
HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor(new OpenAILogger());
|
||||
httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS);
|
||||
OkHttpClient okHttpClient = new OkHttpClient.Builder()
|
||||
.addInterceptor(httpLoggingInterceptor)
|
||||
.connectTimeout(30, TimeUnit.SECONDS)
|
||||
.writeTimeout(600, TimeUnit.SECONDS)
|
||||
.readTimeout(600, TimeUnit.SECONDS)
|
||||
.build();
|
||||
return OpenAiStreamClient.builder()
|
||||
.apiHost(apiHost)
|
||||
.apiKey(Collections.singletonList(apiKey))
|
||||
.keyStrategy(new KeyRandomStrategy())
|
||||
.okHttpClient(okHttpClient)
|
||||
.build();
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,9 @@ import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
@Data
|
||||
public class WebSocketProperties {
|
||||
|
||||
/**
|
||||
* 是否开启
|
||||
*/
|
||||
private Boolean enabled;
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
package org.ruoyi.common.chat.demo;
|
||||
|
||||
import cn.hutool.json.JSONUtil;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.Response;
|
||||
import okhttp3.ResponseBody;
|
||||
import okhttp3.sse.EventSource;
|
||||
import okhttp3.sse.EventSourceListener;
|
||||
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
|
||||
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
|
||||
/**
|
||||
* 描述: sse
|
||||
*
|
||||
* @author https:www.unfbx.com
|
||||
* 2023-06-15
|
||||
*/
|
||||
@Slf4j
|
||||
public class ConsoleEventSourceListenerV2 extends EventSourceListener {
|
||||
@Getter
|
||||
String args = "";
|
||||
final CountDownLatch countDownLatch;
|
||||
|
||||
public ConsoleEventSourceListenerV2(CountDownLatch countDownLatch) {
|
||||
this.countDownLatch = countDownLatch;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onOpen(EventSource eventSource, Response response) {
|
||||
log.info("OpenAI建立sse连接...");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onEvent(EventSource eventSource, String id, String type, String data) {
|
||||
log.info("OpenAI返回数据:{}", data);
|
||||
if (data.equals("[DONE]")) {
|
||||
log.info("OpenAI返回数据结束了");
|
||||
countDownLatch.countDown();
|
||||
return;
|
||||
}
|
||||
ChatCompletionResponse chatCompletionResponse = JSONUtil.toBean(data, ChatCompletionResponse.class);
|
||||
if(Objects.nonNull(chatCompletionResponse.getChoices().get(0).getDelta().getFunctionCall())){
|
||||
args += chatCompletionResponse.getChoices().get(0).getDelta().getFunctionCall().getArguments();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onClosed(EventSource eventSource) {
|
||||
log.info("OpenAI关闭sse连接...");
|
||||
}
|
||||
|
||||
@SneakyThrows
|
||||
@Override
|
||||
public void onFailure(EventSource eventSource, Throwable t, Response response) {
|
||||
if(Objects.isNull(response)){
|
||||
log.error("OpenAI sse连接异常:{}", t);
|
||||
eventSource.cancel();
|
||||
return;
|
||||
}
|
||||
ResponseBody body = response.body();
|
||||
if (Objects.nonNull(body)) {
|
||||
log.error("OpenAI sse连接异常data:{},异常:{}", body.string(), t);
|
||||
} else {
|
||||
log.error("OpenAI sse连接异常data:{},异常:{}", response, t);
|
||||
}
|
||||
eventSource.cancel();
|
||||
}
|
||||
}
|
||||
@@ -1,92 +0,0 @@
|
||||
package org.ruoyi.common.chat.demo;
|
||||
|
||||
import cn.hutool.core.collection.CollectionUtil;
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import lombok.Getter;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.Response;
|
||||
import okhttp3.ResponseBody;
|
||||
import okhttp3.sse.EventSource;
|
||||
import okhttp3.sse.EventSourceListener;
|
||||
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
|
||||
import org.ruoyi.common.chat.entity.chat.Message;
|
||||
import org.ruoyi.common.chat.entity.chat.tool.ToolCallFunction;
|
||||
import org.ruoyi.common.chat.entity.chat.tool.ToolCalls;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
|
||||
/**
|
||||
* 描述: demo测试实现类,仅供思路参考
|
||||
*
|
||||
* @author https:www.unfbx.com
|
||||
* 2023-11-12
|
||||
*/
|
||||
@Slf4j
|
||||
public class ConsoleEventSourceListenerV3 extends EventSourceListener {
|
||||
@Getter
|
||||
List<ToolCalls> choices = new ArrayList<>();
|
||||
@Getter
|
||||
ToolCalls toolCalls = new ToolCalls();
|
||||
@Getter
|
||||
ToolCallFunction toolCallFunction = ToolCallFunction.builder().name("").arguments("").build();
|
||||
final CountDownLatch countDownLatch;
|
||||
|
||||
public ConsoleEventSourceListenerV3(CountDownLatch countDownLatch) {
|
||||
this.countDownLatch = countDownLatch;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onOpen(EventSource eventSource, Response response) {
|
||||
log.info("OpenAI建立sse连接...");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onEvent(EventSource eventSource, String id, String type, String data) {
|
||||
log.info("OpenAI返回数据:{}", data);
|
||||
if (data.equals("[DONE]")) {
|
||||
log.info("OpenAI返回数据结束了");
|
||||
return;
|
||||
}
|
||||
ChatCompletionResponse chatCompletionResponse = JSONUtil.toBean(data, ChatCompletionResponse.class);
|
||||
Message delta = chatCompletionResponse.getChoices().get(0).getDelta();
|
||||
if (CollectionUtil.isNotEmpty(delta.getToolCalls())) {
|
||||
choices.addAll(delta.getToolCalls());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onClosed(EventSource eventSource) {
|
||||
if(CollectionUtil.isNotEmpty(choices)){
|
||||
toolCalls.setId(choices.get(0).getId());
|
||||
toolCalls.setType(choices.get(0).getType());
|
||||
choices.forEach(e -> {
|
||||
toolCallFunction.setName(e.getFunction().getName());
|
||||
toolCallFunction.setArguments(toolCallFunction.getArguments() + e.getFunction().getArguments());
|
||||
toolCalls.setFunction(toolCallFunction);
|
||||
});
|
||||
}
|
||||
log.info("OpenAI关闭sse连接...");
|
||||
countDownLatch.countDown();
|
||||
}
|
||||
|
||||
@SneakyThrows
|
||||
@Override
|
||||
public void onFailure(EventSource eventSource, Throwable t, Response response) {
|
||||
if(Objects.isNull(response)){
|
||||
log.error("OpenAI sse连接异常:{}", t);
|
||||
eventSource.cancel();
|
||||
return;
|
||||
}
|
||||
ResponseBody body = response.body();
|
||||
if (Objects.nonNull(body)) {
|
||||
log.error("OpenAI sse连接异常data:{},异常:{}", body.string(), t);
|
||||
} else {
|
||||
log.error("OpenAI sse连接异常data:{},异常:{}", response, t);
|
||||
}
|
||||
eventSource.cancel();
|
||||
}
|
||||
}
|
||||
@@ -1,417 +0,0 @@
|
||||
package org.ruoyi.common.chat.demo;
|
||||
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.OkHttpClient;
|
||||
import okhttp3.logging.HttpLoggingInterceptor;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.ruoyi.common.chat.entity.chat.*;
|
||||
import org.ruoyi.common.chat.entity.chat.tool.ToolCallFunction;
|
||||
import org.ruoyi.common.chat.entity.chat.tool.ToolCalls;
|
||||
import org.ruoyi.common.chat.entity.chat.tool.Tools;
|
||||
import org.ruoyi.common.chat.entity.chat.tool.ToolsFunction;
|
||||
import org.ruoyi.common.chat.openai.OpenAiClient;
|
||||
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
|
||||
import org.ruoyi.common.chat.openai.function.KeyRandomStrategy;
|
||||
import org.ruoyi.common.chat.openai.interceptor.DynamicKeyOpenAiAuthInterceptor;
|
||||
import org.ruoyi.common.chat.openai.interceptor.OpenAILogger;
|
||||
import org.ruoyi.common.chat.openai.interceptor.OpenAiResponseInterceptor;
|
||||
import org.ruoyi.common.chat.openai.plugin.PluginAbstract;
|
||||
import org.ruoyi.common.chat.plugin.CmdPlugin;
|
||||
import org.ruoyi.common.chat.plugin.CmdReq;
|
||||
import org.ruoyi.common.chat.sse.ConsoleEventSourceListener;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* 描述:
|
||||
*
|
||||
* @author ageerle@163.com
|
||||
* date 2025/3/8
|
||||
*/
|
||||
@Slf4j
|
||||
public class PluginTest {
|
||||
|
||||
private OpenAiClient openAiClient;
|
||||
private OpenAiStreamClient openAiStreamClient;
|
||||
|
||||
@Before
|
||||
public void before() {
|
||||
//可以为null
|
||||
// Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890));
|
||||
HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor(new OpenAILogger());
|
||||
//!!!!千万别再生产或者测试环境打开BODY级别日志!!!!
|
||||
//!!!生产或者测试环境建议设置为这三种级别:NONE,BASIC,HEADERS,!!!
|
||||
httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS);
|
||||
OkHttpClient okHttpClient = new OkHttpClient
|
||||
.Builder()
|
||||
// .proxy(proxy)
|
||||
.addInterceptor(httpLoggingInterceptor)
|
||||
.addInterceptor(new OpenAiResponseInterceptor())
|
||||
.connectTimeout(10, TimeUnit.SECONDS)
|
||||
.writeTimeout(30, TimeUnit.SECONDS)
|
||||
.readTimeout(30, TimeUnit.SECONDS)
|
||||
.build();
|
||||
openAiClient = OpenAiClient.builder()
|
||||
//支持多key传入,请求时候随机选择
|
||||
.apiKey(Arrays.asList("sk-xx"))
|
||||
//自定义key的获取策略:默认KeyRandomStrategy
|
||||
//.keyStrategy(new KeyRandomStrategy())
|
||||
.keyStrategy(new KeyRandomStrategy())
|
||||
.okHttpClient(okHttpClient)
|
||||
//自己做了代理就传代理地址,没有可不不传,(关注公众号回复:openai ,获取免费的测试代理地址)
|
||||
.apiHost("https://api.pandarobot.chat/")
|
||||
.build();
|
||||
|
||||
openAiStreamClient = OpenAiStreamClient.builder()
|
||||
//支持多key传入,请求时候随机选择
|
||||
.apiKey(Arrays.asList("sk-xx"))
|
||||
//自定义key的获取策略:默认KeyRandomStrategy
|
||||
.keyStrategy(new KeyRandomStrategy())
|
||||
.authInterceptor(new DynamicKeyOpenAiAuthInterceptor())
|
||||
.okHttpClient(okHttpClient)
|
||||
//自己做了代理就传代理地址,没有可不不传,(关注公众号回复:openai ,获取免费的测试代理地址)
|
||||
.apiHost("https://api.pandarobot.chat/")
|
||||
.build();
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void chatFunction() {
|
||||
//模型:GPT_3_5_TURBO_16K_0613
|
||||
Message message = Message.builder().role(Message.Role.USER).content("给我输出一个长度为2的中文词语,并解释下词语对应物品的用途").build();
|
||||
//属性一
|
||||
JSONObject wordLength = new JSONObject();
|
||||
wordLength.put("type", "number");
|
||||
wordLength.put("description", "词语的长度");
|
||||
//属性二
|
||||
JSONObject language = new JSONObject();
|
||||
language.put("type", "string");
|
||||
language.put("enum", Arrays.asList("zh", "en"));
|
||||
language.put("description", "语言类型,例如:zh代表中文、en代表英语");
|
||||
//参数
|
||||
JSONObject properties = new JSONObject();
|
||||
properties.put("wordLength", wordLength);
|
||||
properties.put("language", language);
|
||||
|
||||
Parameters parameters = Parameters.builder()
|
||||
.type("object")
|
||||
.properties(properties)
|
||||
.required(Collections.singletonList("wordLength")).build();
|
||||
Functions functions = Functions.builder()
|
||||
.name("getOneWord")
|
||||
.description("获取一个指定长度和语言类型的词语")
|
||||
.parameters(parameters)
|
||||
.build();
|
||||
|
||||
ChatCompletion chatCompletion = ChatCompletion
|
||||
.builder()
|
||||
.messages(Collections.singletonList(message))
|
||||
.functions(Collections.singletonList(functions))
|
||||
.functionCall("auto")
|
||||
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
|
||||
.build();
|
||||
ChatCompletionResponse chatCompletionResponse = openAiClient.chatCompletion(chatCompletion);
|
||||
|
||||
ChatChoice chatChoice = chatCompletionResponse.getChoices().get(0);
|
||||
log.info("构造的方法值:{}", chatChoice.getMessage().getFunctionCall());
|
||||
log.info("构造的方法名称:{}", chatChoice.getMessage().getFunctionCall().getName());
|
||||
log.info("构造的方法参数:{}", chatChoice.getMessage().getFunctionCall().getArguments());
|
||||
WordParam wordParam = JSONUtil.toBean(chatChoice.getMessage().getFunctionCall().getArguments(), WordParam.class);
|
||||
String oneWord = getOneWord(wordParam);
|
||||
|
||||
FunctionCall functionCall = FunctionCall.builder()
|
||||
.arguments(chatChoice.getMessage().getFunctionCall().getArguments())
|
||||
.name("getOneWord")
|
||||
.build();
|
||||
Message message2 = Message.builder().role(Message.Role.ASSISTANT).content("方法参数").functionCall(functionCall).build();
|
||||
String content
|
||||
= "{ " +
|
||||
"\"wordLength\": \"3\", " +
|
||||
"\"language\": \"zh\", " +
|
||||
"\"word\": \"" + oneWord + "\"," +
|
||||
"\"用途\": [\"直接吃\", \"做沙拉\", \"售卖\"]" +
|
||||
"}";
|
||||
Message message3 = Message.builder().role(Message.Role.FUNCTION).name("getOneWord").content(content).build();
|
||||
List<Message> messageList = Arrays.asList(message, message2, message3);
|
||||
ChatCompletion chatCompletionV2 = ChatCompletion
|
||||
.builder()
|
||||
.messages(messageList)
|
||||
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
|
||||
.build();
|
||||
ChatCompletionResponse chatCompletionResponseV2 = openAiClient.chatCompletion(chatCompletionV2);
|
||||
log.info("自定义的方法返回值:{}",chatCompletionResponseV2.getChoices().get(0).getMessage().getContent());
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void plugin() {
|
||||
CmdPlugin plugin = new CmdPlugin(CmdReq.class);
|
||||
// 插件名称
|
||||
plugin.setName("命令行工具");
|
||||
// 方法名称
|
||||
plugin.setFunction("openCmd");
|
||||
// 方法说明
|
||||
plugin.setDescription("提供一个命令行指令,比如<记事本>,指令使用中文,以function返回结果为准");
|
||||
|
||||
PluginAbstract.Arg arg = new PluginAbstract.Arg();
|
||||
// 参数名称
|
||||
arg.setName("cmd");
|
||||
// 参数说明
|
||||
arg.setDescription("命令行指令");
|
||||
// 参数类型
|
||||
arg.setType("string");
|
||||
arg.setRequired(true);
|
||||
plugin.setArgs(Collections.singletonList(arg));
|
||||
|
||||
Message message2 = Message.builder().role(Message.Role.USER).content("帮我打开计算器,结合上下文判断指令是否执行成功,只用回复成功或者失败").build();
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(message2);
|
||||
//有四个重载方法,都可以使用
|
||||
ChatCompletionResponse response = openAiClient.chatCompletionWithPlugin(messages,"gpt-4o-mini",plugin);
|
||||
log.info("自定义的方法返回值:{}", response.getChoices().get(0).getMessage().getContent());
|
||||
}
|
||||
|
||||
/**
|
||||
* 自定义返回数据格式
|
||||
*/
|
||||
@Test
|
||||
public void diyReturnDataModelChat() {
|
||||
Message message = Message.builder().role(Message.Role.USER).content("随机输出10个单词,使用json输出").build();
|
||||
ChatCompletion chatCompletion = ChatCompletion
|
||||
.builder()
|
||||
.messages(Collections.singletonList(message))
|
||||
.responseFormat(ResponseFormat.builder().type(ResponseFormat.Type.JSON_OBJECT.getName()).build())
|
||||
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
|
||||
.build();
|
||||
ChatCompletionResponse chatCompletionResponse = openAiClient.chatCompletion(chatCompletion);
|
||||
chatCompletionResponse.getChoices().forEach(e -> System.out.println(e.getMessage()));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void streamPlugin() {
|
||||
WeatherPlugin plugin = new WeatherPlugin(WeatherReq.class);
|
||||
plugin.setName("知心天气");
|
||||
plugin.setFunction("getLocationWeather");
|
||||
plugin.setDescription("提供一个地址,方法将会获取该地址的天气的实时温度信息。");
|
||||
PluginAbstract.Arg arg = new PluginAbstract.Arg();
|
||||
arg.setName("location");
|
||||
arg.setDescription("地名");
|
||||
arg.setType("string");
|
||||
arg.setRequired(true);
|
||||
plugin.setArgs(Collections.singletonList(arg));
|
||||
|
||||
// Message message1 = Message.builder().role(Message.Role.USER).content("秦始皇统一了哪六国。").build();
|
||||
Message message2 = Message.builder().role(Message.Role.USER).content("获取上海市的天气现在多少度,然后再给出3个推荐的户外运动。").build();
|
||||
List<Message> messages = new ArrayList<>();
|
||||
// messages.add(message1);
|
||||
messages.add(message2);
|
||||
//默认模型:GPT_3_5_TURBO_16K_0613
|
||||
//有四个重载方法,都可以使用
|
||||
openAiStreamClient.streamChatCompletionWithPlugin(messages, ChatCompletion.Model.GPT_4_1106_PREVIEW.getName(), new ConsoleEventSourceListener(), plugin);
|
||||
CountDownLatch countDownLatch = new CountDownLatch(1);
|
||||
try {
|
||||
countDownLatch.await();
|
||||
} catch (InterruptedException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* tools使用示例
|
||||
*/
|
||||
@Test
|
||||
public void toolsChat() {
|
||||
Message message = Message.builder().role(Message.Role.USER).content("给我输出一个长度为2的中文词语,并解释下词语对应物品的用途").build();
|
||||
//属性一
|
||||
JSONObject wordLength = new JSONObject();
|
||||
wordLength.put("type", "number");
|
||||
wordLength.put("description", "词语的长度");
|
||||
//属性二
|
||||
JSONObject language = new JSONObject();
|
||||
language.put("type", "string");
|
||||
language.put("enum", Arrays.asList("zh", "en"));
|
||||
language.put("description", "语言类型,例如:zh代表中文、en代表英语");
|
||||
//参数
|
||||
JSONObject properties = new JSONObject();
|
||||
properties.put("wordLength", wordLength);
|
||||
properties.put("language", language);
|
||||
Parameters parameters = Parameters.builder()
|
||||
.type("object")
|
||||
.properties(properties)
|
||||
.required(Collections.singletonList("wordLength")).build();
|
||||
Tools tools = Tools.builder()
|
||||
.type(Tools.Type.FUNCTION.getName())
|
||||
.function(ToolsFunction.builder().name("getOneWord").description("获取一个指定长度和语言类型的词语").parameters(parameters).build())
|
||||
.build();
|
||||
|
||||
ChatCompletion chatCompletion = ChatCompletion
|
||||
.builder()
|
||||
.messages(Collections.singletonList(message))
|
||||
.tools(Collections.singletonList(tools))
|
||||
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
|
||||
.build();
|
||||
ChatCompletionResponse chatCompletionResponse = openAiClient.chatCompletion(chatCompletion);
|
||||
|
||||
ChatChoice chatChoice = chatCompletionResponse.getChoices().get(0);
|
||||
log.info("构造的方法值:{}", chatChoice.getMessage().getToolCalls());
|
||||
|
||||
ToolCalls openAiReturnToolCalls = chatChoice.getMessage().getToolCalls().get(0);
|
||||
WordParam wordParam = JSONUtil.toBean(openAiReturnToolCalls.getFunction().getArguments(), WordParam.class);
|
||||
String oneWord = getOneWord(wordParam);
|
||||
|
||||
|
||||
ToolCallFunction tcf = ToolCallFunction.builder().name("getOneWord").arguments(openAiReturnToolCalls.getFunction().getArguments()).build();
|
||||
ToolCalls tc = ToolCalls.builder().id(openAiReturnToolCalls.getId()).type(ToolCalls.Type.FUNCTION.getName()).function(tcf).build();
|
||||
//构造tool call
|
||||
Message message2 = Message.builder().role(Message.Role.ASSISTANT).content("方法参数").toolCalls(Collections.singletonList(tc)).build();
|
||||
String content
|
||||
= "{ " +
|
||||
"\"wordLength\": \"3\", " +
|
||||
"\"language\": \"zh\", " +
|
||||
"\"word\": \"" + oneWord + "\"," +
|
||||
"\"用途\": [\"直接吃\", \"做沙拉\", \"售卖\"]" +
|
||||
"}";
|
||||
Message message3 = Message.builder().toolCallId(openAiReturnToolCalls.getId()).role(Message.Role.TOOL).name("getOneWord").content(content).build();
|
||||
List<Message> messageList = Arrays.asList(message, message2, message3);
|
||||
ChatCompletion chatCompletionV2 = ChatCompletion
|
||||
.builder()
|
||||
.messages(messageList)
|
||||
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
|
||||
.build();
|
||||
ChatCompletionResponse chatCompletionResponseV2 = openAiClient.chatCompletion(chatCompletionV2);
|
||||
log.info("自定义的方法返回值:{}", chatCompletionResponseV2.getChoices().get(0).getMessage().getContent());
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* tools流式输出使用示例
|
||||
*/
|
||||
@Test
|
||||
public void streamToolsChat() {
|
||||
|
||||
CountDownLatch countDownLatch = new CountDownLatch(1);
|
||||
ConsoleEventSourceListenerV3 eventSourceListener = new ConsoleEventSourceListenerV3(countDownLatch);
|
||||
|
||||
Message message = Message.builder().role(Message.Role.USER).content("给我输出一个长度为2的中文词语,并解释下词语对应物品的用途").build();
|
||||
//属性一
|
||||
JSONObject wordLength = new JSONObject();
|
||||
wordLength.put("type", "number");
|
||||
wordLength.put("description", "词语的长度");
|
||||
//属性二
|
||||
JSONObject language = new JSONObject();
|
||||
language.put("type", "string");
|
||||
language.put("enum", Arrays.asList("zh", "en"));
|
||||
language.put("description", "语言类型,例如:zh代表中文、en代表英语");
|
||||
//参数
|
||||
JSONObject properties = new JSONObject();
|
||||
properties.put("wordLength", wordLength);
|
||||
properties.put("language", language);
|
||||
Parameters parameters = Parameters.builder()
|
||||
.type("object")
|
||||
.properties(properties)
|
||||
.required(Collections.singletonList("wordLength")).build();
|
||||
Tools tools = Tools.builder()
|
||||
.type(Tools.Type.FUNCTION.getName())
|
||||
.function(ToolsFunction.builder().name("getOneWord").description("获取一个指定长度和语言类型的词语").parameters(parameters).build())
|
||||
.build();
|
||||
|
||||
ChatCompletion chatCompletion = ChatCompletion
|
||||
.builder()
|
||||
.messages(Collections.singletonList(message))
|
||||
.tools(Collections.singletonList(tools))
|
||||
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
|
||||
.build();
|
||||
openAiStreamClient.streamChatCompletion(chatCompletion, eventSourceListener);
|
||||
|
||||
try {
|
||||
countDownLatch.await();
|
||||
} catch (InterruptedException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
|
||||
ToolCalls openAiReturnToolCalls = eventSourceListener.getToolCalls();
|
||||
WordParam wordParam = JSONUtil.toBean(openAiReturnToolCalls.getFunction().getArguments(), WordParam.class);
|
||||
String oneWord = getOneWord(wordParam);
|
||||
|
||||
|
||||
ToolCallFunction tcf = ToolCallFunction.builder().name("getOneWord").arguments(openAiReturnToolCalls.getFunction().getArguments()).build();
|
||||
ToolCalls tc = ToolCalls.builder().id(openAiReturnToolCalls.getId()).type(ToolCalls.Type.FUNCTION.getName()).function(tcf).build();
|
||||
//构造tool call
|
||||
Message message2 = Message.builder().role(Message.Role.ASSISTANT).content("方法参数").toolCalls(Collections.singletonList(tc)).build();
|
||||
String content
|
||||
= "{ " +
|
||||
"\"wordLength\": \"3\", " +
|
||||
"\"language\": \"zh\", " +
|
||||
"\"word\": \"" + oneWord + "\"," +
|
||||
"\"用途\": [\"直接吃\", \"做沙拉\", \"售卖\"]" +
|
||||
"}";
|
||||
Message message3 = Message.builder().toolCallId(openAiReturnToolCalls.getId()).role(Message.Role.TOOL).name("getOneWord").content(content).build();
|
||||
List<Message> messageList = Arrays.asList(message, message2, message3);
|
||||
ChatCompletion chatCompletionV2 = ChatCompletion
|
||||
.builder()
|
||||
.messages(messageList)
|
||||
.model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName())
|
||||
.build();
|
||||
|
||||
|
||||
CountDownLatch countDownLatch1 = new CountDownLatch(1);
|
||||
openAiStreamClient.streamChatCompletion(chatCompletionV2, new ConsoleEventSourceListenerV3(countDownLatch));
|
||||
try {
|
||||
countDownLatch1.await();
|
||||
} catch (InterruptedException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
try {
|
||||
countDownLatch1.await();
|
||||
} catch (InterruptedException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
static class WordParam {
|
||||
private int wordLength;
|
||||
@Builder.Default
|
||||
private String language = "zh";
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 获取一个词语(根据语言和字符长度查询)
|
||||
* @param wordParam
|
||||
* @return
|
||||
*/
|
||||
public String getOneWord(WordParam wordParam) {
|
||||
|
||||
List<String> zh = Arrays.asList("大香蕉", "哈密瓜", "苹果");
|
||||
List<String> en = Arrays.asList("apple", "banana", "cantaloupe");
|
||||
if (wordParam.getLanguage().equals("zh")) {
|
||||
for (String e : zh) {
|
||||
if (e.length() == wordParam.getWordLength()) {
|
||||
return e;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (wordParam.getLanguage().equals("en")) {
|
||||
for (String e : en) {
|
||||
if (e.length() == wordParam.getWordLength()) {
|
||||
return e;
|
||||
}
|
||||
}
|
||||
}
|
||||
return "西瓜";
|
||||
}
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
package org.ruoyi.common.chat.demo;
|
||||
|
||||
|
||||
import org.ruoyi.common.chat.openai.plugin.PluginAbstract;
|
||||
|
||||
public class WeatherPlugin extends PluginAbstract<WeatherReq, WeatherResp> {
|
||||
|
||||
public WeatherPlugin(Class<?> r) {
|
||||
super(r);
|
||||
}
|
||||
|
||||
@Override
|
||||
public WeatherResp func(WeatherReq args) {
|
||||
WeatherResp weatherResp = new WeatherResp();
|
||||
weatherResp.setTemp("25到28摄氏度");
|
||||
weatherResp.setLevel(3);
|
||||
return weatherResp;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String content(WeatherResp weatherResp) {
|
||||
return "当前天气温度:" + weatherResp.getTemp() + ",风力等级:" + weatherResp.getLevel();
|
||||
}
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
package org.ruoyi.common.chat.demo;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
import org.ruoyi.common.chat.openai.plugin.PluginParam;
|
||||
|
||||
@Data
|
||||
public class WeatherReq extends PluginParam {
|
||||
/**
|
||||
* 城市
|
||||
*/
|
||||
private String location;
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
package org.ruoyi.common.chat.demo;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class WeatherResp {
|
||||
/**
|
||||
* 温度
|
||||
*/
|
||||
private String temp;
|
||||
/**
|
||||
* 风力等级
|
||||
*/
|
||||
private Integer level;
|
||||
}
|
||||
@@ -1,223 +0,0 @@
|
||||
package org.ruoyi.common.chat.demo;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.zhipu.oapi.ClientV4;
|
||||
import com.zhipu.oapi.Constants;
|
||||
import com.zhipu.oapi.service.v4.tools.*;
|
||||
import org.junit.Test;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
|
||||
import com.zhipu.oapi.service.v4.model.*;
|
||||
import io.reactivex.Flowable;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
|
||||
public class WebSearchToolsTest {
|
||||
|
||||
private final static Logger logger = LoggerFactory.getLogger(WebSearchToolsTest.class);
|
||||
private static final String API_SECRET_KEY = "xx";
|
||||
|
||||
private static final ClientV4 client = new ClientV4.Builder(API_SECRET_KEY)
|
||||
.networkConfig(300, 100, 100, 100, TimeUnit.SECONDS)
|
||||
.connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS))
|
||||
.build();
|
||||
private static final ObjectMapper mapper = new ObjectMapper();
|
||||
// 请自定义自己的业务id
|
||||
private static final String requestIdTemplate = "mycompany-%d";
|
||||
|
||||
|
||||
@Test
|
||||
public void test1() throws JsonProcessingException {
|
||||
|
||||
// json 转换 ArrayList<SearchChatMessage>
|
||||
String jsonString = "[\n" +
|
||||
" {\n" +
|
||||
" \"content\": \"今天武汉天气怎么样\",\n" +
|
||||
" \"role\": \"user\"\n" +
|
||||
" }\n" +
|
||||
" ]";
|
||||
|
||||
ArrayList<SearchChatMessage> messages = new ObjectMapper().readValue(jsonString, new TypeReference<ArrayList<SearchChatMessage>>() {
|
||||
});
|
||||
|
||||
|
||||
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
|
||||
WebSearchParamsRequest chatCompletionRequest = WebSearchParamsRequest.builder()
|
||||
.model("web-search-pro")
|
||||
.stream(Boolean.TRUE)
|
||||
.messages(messages)
|
||||
.requestId(requestId)
|
||||
.build();
|
||||
WebSearchApiResponse webSearchApiResponse = client.webSearchProStreamingInvoke(chatCompletionRequest);
|
||||
if (webSearchApiResponse.isSuccess()) {
|
||||
AtomicBoolean isFirst = new AtomicBoolean(true);
|
||||
List<ChoiceDelta> choices = new ArrayList<>();
|
||||
AtomicReference<WebSearchPro> lastAccumulator = new AtomicReference<>();
|
||||
|
||||
webSearchApiResponse.getFlowable().map(result -> result)
|
||||
.doOnNext(accumulator -> {
|
||||
{
|
||||
if (isFirst.getAndSet(false)) {
|
||||
logger.info("Response: ");
|
||||
}
|
||||
ChoiceDelta delta = accumulator.getChoices().get(0).getDelta();
|
||||
if (delta != null && delta.getToolCalls() != null) {
|
||||
logger.info("tool_calls: {}", mapper.writeValueAsString(delta.getToolCalls()));
|
||||
}
|
||||
choices.add(delta);
|
||||
lastAccumulator.set(accumulator);
|
||||
|
||||
}
|
||||
})
|
||||
.doOnComplete(() -> System.out.println("Stream completed."))
|
||||
.doOnError(throwable -> System.err.println("Error: " + throwable)) // Handle errors
|
||||
.blockingSubscribe();// Use blockingSubscribe instead of blockingGet()
|
||||
|
||||
WebSearchPro chatMessageAccumulator = lastAccumulator.get();
|
||||
|
||||
webSearchApiResponse.setFlowable(null);// 打印前置空
|
||||
webSearchApiResponse.setData(chatMessageAccumulator);
|
||||
}
|
||||
logger.info("model output: {}", mapper.writeValueAsString(webSearchApiResponse));
|
||||
client.getConfig().getHttpClient().dispatcher().executorService().shutdown();
|
||||
|
||||
client.getConfig().getHttpClient().connectionPool().evictAll();
|
||||
// List all active threads
|
||||
for (Thread t : Thread.getAllStackTraces().keySet()) {
|
||||
logger.info("Thread: " + t.getName() + " State: " + t.getState());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void test2() throws JsonProcessingException {
|
||||
|
||||
// json 转换 ArrayList<SearchChatMessage>
|
||||
String jsonString = "[\n" +
|
||||
" {\n" +
|
||||
" \"content\": \"今天天气怎么样\",\n" +
|
||||
" \"role\": \"user\"\n" +
|
||||
" }\n" +
|
||||
" ]";
|
||||
|
||||
ArrayList<SearchChatMessage> messages = new ObjectMapper().readValue(jsonString, new TypeReference<ArrayList<SearchChatMessage>>() {
|
||||
});
|
||||
|
||||
|
||||
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
|
||||
WebSearchParamsRequest chatCompletionRequest = WebSearchParamsRequest.builder()
|
||||
.model("web-search-pro")
|
||||
.stream(Boolean.FALSE)
|
||||
.messages(messages)
|
||||
.requestId(requestId)
|
||||
.build();
|
||||
WebSearchApiResponse webSearchApiResponse = client.invokeWebSearchPro(chatCompletionRequest);
|
||||
|
||||
logger.info("model output: {}", mapper.writeValueAsString(webSearchApiResponse));
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testFunctionSSE() throws JsonProcessingException {
|
||||
List<ChatMessage> messages = new ArrayList<>();
|
||||
ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "成都到北京要多久,天气如何");
|
||||
messages.add(chatMessage);
|
||||
String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
|
||||
// 函数调用参数构建部分
|
||||
List<ChatTool> chatToolList = new ArrayList<>();
|
||||
ChatTool chatTool = new ChatTool();
|
||||
|
||||
chatTool.setType(ChatToolType.FUNCTION.value());
|
||||
ChatFunctionParameters chatFunctionParameters = new ChatFunctionParameters();
|
||||
chatFunctionParameters.setType("object");
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
properties.put("location", new HashMap<String, Object>() {{
|
||||
put("type", "string");
|
||||
put("description", "城市,如:北京");
|
||||
}});
|
||||
properties.put("unit", new HashMap<String, Object>() {{
|
||||
put("type", "string");
|
||||
put("enum", new ArrayList<String>() {{
|
||||
add("celsius");
|
||||
add("fahrenheit");
|
||||
}});
|
||||
}});
|
||||
chatFunctionParameters.setProperties(properties);
|
||||
ChatFunction chatFunction = ChatFunction.builder()
|
||||
.name("get_weather")
|
||||
.description("Get the current weather of a location")
|
||||
.parameters(chatFunctionParameters)
|
||||
.build();
|
||||
chatTool.setFunction(chatFunction);
|
||||
chatToolList.add(chatTool);
|
||||
HashMap<String, Object> extraJson = new HashMap<>();
|
||||
extraJson.put("temperature", 0.5);
|
||||
extraJson.put("max_tokens", 50);
|
||||
|
||||
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
|
||||
.model(Constants.ModelChatGLM4)
|
||||
.stream(Boolean.TRUE)
|
||||
.messages(messages)
|
||||
.requestId(requestId)
|
||||
.tools(chatToolList)
|
||||
.toolChoice("auto")
|
||||
.extraJson(extraJson)
|
||||
.build();
|
||||
ModelApiResponse sseModelApiResp = client.invokeModelApi(chatCompletionRequest);
|
||||
if (sseModelApiResp.isSuccess()) {
|
||||
AtomicBoolean isFirst = new AtomicBoolean(true);
|
||||
List<Choice> choices = new ArrayList<>();
|
||||
ChatMessageAccumulator chatMessageAccumulator = mapStreamToAccumulator(sseModelApiResp.getFlowable())
|
||||
.doOnNext(accumulator -> {
|
||||
{
|
||||
if (isFirst.getAndSet(false)) {
|
||||
logger.info("Response: ");
|
||||
}
|
||||
if (accumulator.getDelta() != null && accumulator.getDelta().getTool_calls() != null) {
|
||||
String jsonString = mapper.writeValueAsString(accumulator.getDelta().getTool_calls());
|
||||
logger.info("tool_calls: {}", jsonString);
|
||||
}
|
||||
if (accumulator.getDelta() != null && accumulator.getDelta().getContent() != null) {
|
||||
logger.info(accumulator.getDelta().getContent());
|
||||
}
|
||||
choices.add(accumulator.getChoice());
|
||||
}
|
||||
})
|
||||
.doOnComplete(System.out::println)
|
||||
.lastElement()
|
||||
.blockingGet();
|
||||
|
||||
|
||||
ModelData data = new ModelData();
|
||||
data.setChoices(choices);
|
||||
data.setUsage(chatMessageAccumulator.getUsage());
|
||||
data.setId(chatMessageAccumulator.getId());
|
||||
data.setCreated(chatMessageAccumulator.getCreated());
|
||||
data.setRequestId(chatCompletionRequest.getRequestId());
|
||||
sseModelApiResp.setFlowable(null);// 打印前置空
|
||||
sseModelApiResp.setData(data);
|
||||
}
|
||||
logger.info("model output: {}", mapper.writeValueAsString(sseModelApiResp));
|
||||
}
|
||||
|
||||
public static Flowable<ChatMessageAccumulator> mapStreamToAccumulator(Flowable<ModelData> flowable) {
|
||||
return flowable.map(chunk -> {
|
||||
return new ChatMessageAccumulator(chunk.getChoices().get(0).getDelta(), null, chunk.getChoices().get(0), chunk.getUsage(), chunk.getCreated(), chunk.getId());
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
package org.ruoyi.common.chat.domain.request;
|
||||
|
||||
import org.ruoyi.common.chat.entity.chat.Message;
|
||||
import jakarta.validation.constraints.NotEmpty;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 描述:
|
||||
*
|
||||
* @author https:www.unfbx.com
|
||||
* @sine 2023-04-08
|
||||
*/
|
||||
@Data
|
||||
public class ChatRequest {
|
||||
|
||||
|
||||
private String frequency_penalty;
|
||||
|
||||
private String max_tokens;
|
||||
|
||||
@NotEmpty(message = "对话消息不能为空")
|
||||
List<Message> messages;
|
||||
|
||||
@NotEmpty(message = "传入的模型不能为空")
|
||||
private String model;
|
||||
|
||||
private String presence_penalty;
|
||||
|
||||
private String stream;
|
||||
|
||||
private double temperature;
|
||||
|
||||
private double top_p = 1;
|
||||
|
||||
/**
|
||||
* 知识库id
|
||||
*/
|
||||
private String kid;
|
||||
|
||||
private String userId;
|
||||
|
||||
/**
|
||||
* 1 联网搜索
|
||||
*/
|
||||
private int chat_type;
|
||||
|
||||
/**
|
||||
* 应用ID
|
||||
*/
|
||||
private String appId;
|
||||
//
|
||||
|
||||
//
|
||||
// /**
|
||||
// * gpt的默认设置
|
||||
// */
|
||||
// private String systemMessage = "";
|
||||
//
|
||||
//
|
||||
//
|
||||
// private double temperature = 0.2;
|
||||
//
|
||||
// /**
|
||||
// * 上下文的条数
|
||||
// */
|
||||
// private Integer contentNumber = 10;
|
||||
//
|
||||
// /**
|
||||
// * 是否携带上下文
|
||||
// */
|
||||
// private Boolean usingContext = Boolean.TRUE;
|
||||
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
package org.ruoyi.common.chat.domain.request;
|
||||
|
||||
import jakarta.validation.constraints.NotEmpty;
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* 描述:
|
||||
*
|
||||
* @author https:www.unfbx.com
|
||||
* @sine 2023-04-08
|
||||
*/
|
||||
@Data
|
||||
public class Dall3Request {
|
||||
|
||||
@NotEmpty(message = "传入的模型不能为空")
|
||||
private String model;
|
||||
|
||||
@NotEmpty(message = "提示词不能为空")
|
||||
private String prompt;
|
||||
|
||||
/** 图片大小 */
|
||||
@NotEmpty(message = "图片大小不能为空")
|
||||
private String size ;
|
||||
|
||||
/** 图片质量 */
|
||||
@NotEmpty(message = "图片质量不能为空")
|
||||
private String quality;
|
||||
|
||||
/** 图片风格 */
|
||||
@NotEmpty(message = "图片风格不能为空")
|
||||
private String style;
|
||||
|
||||
}
|
||||
@@ -2,7 +2,6 @@ package org.ruoyi.common.chat.handler;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.ruoyi.common.chat.config.LocalCache;
|
||||
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
|
||||
@@ -12,7 +11,6 @@ import org.ruoyi.common.chat.listener.WebSocketEventListener;
|
||||
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
|
||||
import org.ruoyi.common.chat.utils.WebSocketUtils;
|
||||
import org.ruoyi.common.core.utils.SpringUtils;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.web.socket.*;
|
||||
import org.springframework.web.socket.handler.AbstractWebSocketHandler;
|
||||
|
||||
|
||||
@@ -1,198 +0,0 @@
|
||||
package org.ruoyi.common.chat.localModels;
|
||||
|
||||
import io.micrometer.common.util.StringUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.OkHttpClient;
|
||||
import org.ruoyi.common.chat.entity.models.LocalModelsSearchRequest;
|
||||
import org.ruoyi.common.chat.entity.models.LocalModelsSearchResponse;
|
||||
import org.springframework.stereotype.Service;
|
||||
import retrofit2.Call;
|
||||
import retrofit2.Callback;
|
||||
import retrofit2.Response;
|
||||
import retrofit2.Retrofit;
|
||||
import retrofit2.converter.jackson.JacksonConverterFactory;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class LocalModelsofitClient {
|
||||
private static final String BASE_URL = "http://127.0.0.1:5000"; // Flask 服务的 URL
|
||||
private static Retrofit retrofit = null;
|
||||
|
||||
// 获取 Retrofit 实例
|
||||
public static Retrofit getRetrofitInstance() {
|
||||
if (retrofit == null) {
|
||||
OkHttpClient client = new OkHttpClient.Builder()
|
||||
.build();
|
||||
|
||||
retrofit = new Retrofit.Builder()
|
||||
.baseUrl(BASE_URL)
|
||||
.client(client)
|
||||
.addConverterFactory(JacksonConverterFactory.create()) // 使用 Jackson 处理 JSON 转换
|
||||
.build();
|
||||
}
|
||||
return retrofit;
|
||||
}
|
||||
|
||||
/**
|
||||
* 向 Flask 服务发送文本向量化请求
|
||||
*
|
||||
* @param queries 查询文本列表
|
||||
* @param modelName 模型名称
|
||||
* @param delimiter 文本分隔符
|
||||
* @param topK 返回的结果数
|
||||
* @param blockSize 文本块大小
|
||||
* @param overlapChars 重叠字符数
|
||||
* @return 返回计算得到的 Top K 嵌入向量列表
|
||||
*/
|
||||
|
||||
public static List<List<Double>> getTopKEmbeddings(
|
||||
List<String> queries,
|
||||
String modelName,
|
||||
String delimiter,
|
||||
int topK,
|
||||
int blockSize,
|
||||
int overlapChars) {
|
||||
|
||||
modelName = (!StringUtils.isEmpty(modelName)) ? modelName : "msmarco-distilbert-base-tas-b"; // 默认模型名称
|
||||
delimiter = (!StringUtils.isEmpty(delimiter) ) ? delimiter : "."; // 默认分隔符
|
||||
topK = (topK > 0) ? topK : 3; // 默认返回 3 个结果
|
||||
blockSize = (blockSize > 0) ? blockSize : 500; // 默认文本块大小为 500
|
||||
overlapChars = (overlapChars > 0) ? overlapChars : 50; // 默认重叠字符数为 50
|
||||
|
||||
// 创建 Retrofit 实例
|
||||
Retrofit retrofit = getRetrofitInstance();
|
||||
|
||||
// 创建 SearchService 接口
|
||||
SearchService service = retrofit.create(SearchService.class);
|
||||
|
||||
// 创建请求对象 LocalModelsSearchRequest
|
||||
LocalModelsSearchRequest request = new LocalModelsSearchRequest(
|
||||
queries, // 查询文本列表
|
||||
modelName, // 模型名称
|
||||
delimiter, // 文本分隔符
|
||||
topK, // 返回的结果数
|
||||
blockSize, // 文本块大小
|
||||
overlapChars // 重叠字符数
|
||||
);
|
||||
|
||||
final CountDownLatch latch = new CountDownLatch(1); // 创建一个 CountDownLatch
|
||||
final List<List<Double>>[] topKEmbeddings = new List[]{null}; // 使用数组来存储结果(因为 Java 不支持直接修改 List)
|
||||
|
||||
// 发起异步请求
|
||||
service.vectorize(request).enqueue(new Callback<LocalModelsSearchResponse>() {
|
||||
@Override
|
||||
public void onResponse(Call<LocalModelsSearchResponse> call, Response<LocalModelsSearchResponse> response) {
|
||||
if (response.isSuccessful()) {
|
||||
LocalModelsSearchResponse searchResponse = response.body();
|
||||
if (searchResponse != null) {
|
||||
topKEmbeddings[0] = searchResponse.getTopKEmbeddings().get(0); // 获取结果
|
||||
log.info("Successfully retrieved embeddings");
|
||||
} else {
|
||||
log.error("Response body is null");
|
||||
}
|
||||
} else {
|
||||
log.error("Request failed. HTTP error code: " + response.code());
|
||||
}
|
||||
latch.countDown(); // 请求完成,减少计数
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(Call<LocalModelsSearchResponse> call, Throwable t) {
|
||||
t.printStackTrace();
|
||||
log.error("Request failed: ", t);
|
||||
latch.countDown(); // 请求失败,减少计数
|
||||
}
|
||||
});
|
||||
|
||||
try {
|
||||
latch.await(); // 等待请求完成
|
||||
} catch (InterruptedException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
|
||||
return topKEmbeddings[0]; // 返回结果
|
||||
}
|
||||
|
||||
// public static void main(String[] args) {
|
||||
// // 示例调用
|
||||
// List<String> queries = Arrays.asList("What is artificial intelligence?", "AI is transforming industries.");
|
||||
// String modelName = "msmarco-distilbert-base-tas-b";
|
||||
// String delimiter = ".";
|
||||
// int topK = 3;
|
||||
// int blockSize = 500;
|
||||
// int overlapChars = 50;
|
||||
//
|
||||
// List<List<Double>> topKEmbeddings = getTopKEmbeddings(queries, modelName, delimiter, topK, blockSize, overlapChars);
|
||||
//
|
||||
// // 打印结果
|
||||
// if (topKEmbeddings != null) {
|
||||
// System.out.println("Top K embeddings: ");
|
||||
// for (List<Double> embedding : topKEmbeddings) {
|
||||
// System.out.println(embedding);
|
||||
// }
|
||||
// } else {
|
||||
// System.out.println("No embeddings returned.");
|
||||
// }
|
||||
// }
|
||||
|
||||
|
||||
// public static void main(String[] args) {
|
||||
// // 创建 Retrofit 实例
|
||||
// Retrofit retrofit = LocalModelsofitClient.getRetrofitInstance();
|
||||
//
|
||||
// // 创建 SearchService 接口
|
||||
// SearchService service = retrofit.create(SearchService.class);
|
||||
//
|
||||
// // 创建请求对象 LocalModelsSearchRequest
|
||||
// LocalModelsSearchRequest request = new LocalModelsSearchRequest(
|
||||
// Arrays.asList("What is artificial intelligence?", "AI is transforming industries."), // 查询文本列表
|
||||
// "msmarco-distilbert-base-tas-b", // 模型名称
|
||||
// ".", // 分隔符
|
||||
// 3, // 返回的结果数
|
||||
// 500, // 文本块大小
|
||||
// 50 // 重叠字符数
|
||||
// );
|
||||
//
|
||||
// // 发起请求
|
||||
// service.vectorize(request).enqueue(new Callback<LocalModelsSearchResponse>() {
|
||||
// @Override
|
||||
// public void onResponse(Call<LocalModelsSearchResponse> call, Response<LocalModelsSearchResponse> response) {
|
||||
// if (response.isSuccessful()) {
|
||||
// LocalModelsSearchResponse searchResponse = response.body();
|
||||
// System.out.println("Response Body: " + response.body()); // Print the whole response body for debugging
|
||||
//
|
||||
// if (searchResponse != null) {
|
||||
// // If the response is not null, process it.
|
||||
// // Example: Extract the embeddings and print them
|
||||
// List<List<List<Double>>> topKEmbeddings = searchResponse.getTopKEmbeddings();
|
||||
// if (topKEmbeddings != null) {
|
||||
// // Print the Top K embeddings
|
||||
//
|
||||
// } else {
|
||||
// System.err.println("Top K embeddings are null");
|
||||
// }
|
||||
//
|
||||
// // If there is more information you want to process, handle it here
|
||||
//
|
||||
// } else {
|
||||
// System.err.println("Response body is null");
|
||||
// }
|
||||
// } else {
|
||||
// System.err.println("Request failed. HTTP error code: " + response.code());
|
||||
// log.error("Failed to retrieve data. HTTP error code: " + response.code());
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// @Override
|
||||
// public void onFailure(Call<LocalModelsSearchResponse> call, Throwable t) {
|
||||
// // 请求失败,打印错误
|
||||
// t.printStackTrace();
|
||||
// log.error("Request failed: ", t);
|
||||
// }
|
||||
// });
|
||||
// }
|
||||
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package org.ruoyi.common.chat.localModels;
|
||||
|
||||
|
||||
|
||||
import org.ruoyi.common.chat.entity.models.LocalModelsSearchRequest;
|
||||
import org.ruoyi.common.chat.entity.models.LocalModelsSearchResponse;
|
||||
import retrofit2.Call;
|
||||
import retrofit2.http.Body;
|
||||
import retrofit2.http.POST;
|
||||
/**
|
||||
* @program: RUOYIAI
|
||||
* @ClassName SearchService
|
||||
* @description: 请求模型
|
||||
* @author: hejh
|
||||
* @create: 2025-03-15 17:27
|
||||
* @Version 1.0
|
||||
**/
|
||||
|
||||
|
||||
public interface SearchService {
|
||||
@POST("/vectorize") // 与 Flask 服务中的路由匹配
|
||||
Call<LocalModelsSearchResponse> vectorize(@Body LocalModelsSearchRequest request);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
package org.ruoyi.common.chat.plugin;
|
||||
|
||||
import org.ruoyi.common.chat.openai.plugin.PluginAbstract;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class CmdPlugin extends PluginAbstract<CmdReq, CmdResp> {
|
||||
|
||||
public CmdPlugin(Class<?> r) {
|
||||
super(r);
|
||||
}
|
||||
|
||||
@Override
|
||||
public CmdResp func(CmdReq args) {
|
||||
try {
|
||||
if("计算器".equals(args.getCmd())){
|
||||
Runtime.getRuntime().exec("calc");
|
||||
}else if("记事本".equals(args.getCmd())){
|
||||
Runtime.getRuntime().exec("notepad");
|
||||
}else if("命令行".equals(args.getCmd())){
|
||||
String [] cmd={"cmd","/C","start copy exel exe2"};
|
||||
Runtime.getRuntime().exec(cmd);
|
||||
}
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("指令执行失败");
|
||||
}
|
||||
CmdResp resp = new CmdResp();
|
||||
resp.setResult(args.getCmd()+"指令执行成功!");
|
||||
return resp;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String content(CmdResp resp) {
|
||||
return resp.getResult();
|
||||
}
|
||||
}
|
||||
@@ -2,31 +2,39 @@ package org.ruoyi.common.chat.request;
|
||||
|
||||
import jakarta.validation.constraints.NotEmpty;
|
||||
import lombok.Data;
|
||||
import org.ruoyi.common.chat.entity.chat.Content;
|
||||
import org.ruoyi.common.chat.entity.chat.Message;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 描述:
|
||||
* 描述:对话请求对象
|
||||
*
|
||||
* @author https:www.unfbx.com
|
||||
* @author ageerle
|
||||
* @sine 2023-04-08
|
||||
*/
|
||||
@Data
|
||||
public class ChatRequest {
|
||||
|
||||
@NotEmpty(message = "传入的模型不能为空")
|
||||
private String model;
|
||||
|
||||
@NotEmpty(message = "对话消息不能为空")
|
||||
List<Message> messages;
|
||||
|
||||
List<Content> imageContent;
|
||||
@NotEmpty(message = "传入的模型不能为空")
|
||||
private String model;
|
||||
|
||||
/**
|
||||
* 提示词
|
||||
*/
|
||||
private String prompt;
|
||||
|
||||
private String userId;
|
||||
/**
|
||||
* 是否开启流式对话
|
||||
*/
|
||||
private Boolean stream = Boolean.TRUE;
|
||||
|
||||
/**
|
||||
* 是否开启联网搜索(0关闭 1开启)
|
||||
*/
|
||||
private Boolean search = Boolean.FALSE;
|
||||
|
||||
/**
|
||||
* 知识库id
|
||||
@@ -34,13 +42,14 @@ public class ChatRequest {
|
||||
private String kid;
|
||||
|
||||
/**
|
||||
* gpt的默认设置
|
||||
* 用户id
|
||||
*/
|
||||
private String systemMessage = "";
|
||||
private String userId;
|
||||
|
||||
private double top_p = 1;
|
||||
|
||||
private double temperature = 0.2;
|
||||
/**
|
||||
* 应用ID
|
||||
*/
|
||||
private String appId;
|
||||
|
||||
/**
|
||||
* 上下文的条数
|
||||
@@ -52,4 +61,5 @@ public class ChatRequest {
|
||||
*/
|
||||
private Boolean usingContext = Boolean.TRUE;
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -28,7 +28,6 @@ public class ConsoleEventSourceListener extends EventSourceListener {
|
||||
log.info("OpenAI返回数据:{}", data);
|
||||
if ("[DONE]".equals(data)) {
|
||||
log.info("OpenAI返回数据结束了");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ import okhttp3.ResponseBody;
|
||||
import okhttp3.sse.EventSource;
|
||||
import okhttp3.sse.EventSourceListener;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.ruoyi.common.chat.constant.OpenAIConst;
|
||||
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
|
||||
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
|
||||
import org.ruoyi.common.chat.entity.chat.FunctionCall;
|
||||
|
||||
Reference in New Issue
Block a user