diff --git a/ruoyi-admin/src/main/java/org/ruoyi/controller/KnowledgeController.java b/ruoyi-admin/src/main/java/org/ruoyi/controller/KnowledgeController.java index 1768a697..6f9f06b5 100644 --- a/ruoyi-admin/src/main/java/org/ruoyi/controller/KnowledgeController.java +++ b/ruoyi-admin/src/main/java/org/ruoyi/controller/KnowledgeController.java @@ -73,11 +73,9 @@ public class KnowledgeController extends BaseController { */ @PostMapping("/send") public SseEmitter send(@RequestBody @Valid ChatRequest chatRequest) { - openAiStreamClient = chatConfig.getOpenAiStreamClient(); SseEmitter sseEmitter = new SseEmitter(0L); SSEEventSourceListener openAIEventSourceListener = new SSEEventSourceListener(sseEmitter); - List messages = chatRequest.getMessages(); String content = messages.get(messages.size() - 1).getContent().toString(); List nearestList; @@ -89,8 +87,6 @@ public class KnowledgeController extends BaseController { } Message userMessage = Message.builder().content(content + (nearestList.size() > 0 ? "\n\n注意:回答问题时,须严格根据我给你的系统上下文内容原文进行回答,请不要自己发挥,回答时保持原来文本的段落层级" : "") ).role(Message.Role.USER).build(); messages.add(userMessage); - - ChatCompletion completion = ChatCompletion .builder() .messages(messages) @@ -104,7 +100,6 @@ public class KnowledgeController extends BaseController { return sseEmitter; } - /** * 根据用户信息查询本地知识库 */ @@ -117,8 +112,6 @@ public class KnowledgeController extends BaseController { return knowledgeInfoService.queryPageList(bo, pageQuery); } - - /** * 新增知识库 */ @@ -190,10 +183,9 @@ public class KnowledgeController extends BaseController { * 删除知识库附件 * */ - @PostMapping("attach/remove/{kid}") - public R removeAttach(@NotEmpty(message = "主键不能为空") - @PathVariable String kid) { - attachService.removeKnowledgeAttach(kid); + @PostMapping("attach/remove/{docId}") + public R removeAttach(@NotEmpty(message = "主键不能为空") @PathVariable String docId) { + attachService.removeKnowledgeAttach(docId); return R.ok(); } diff --git a/ruoyi-common/ruoyi-common-chat/pom.xml b/ruoyi-common/ruoyi-common-chat/pom.xml index 8b5ac816..aba4a0f6 100644 --- a/ruoyi-common/ruoyi-common-chat/pom.xml +++ b/ruoyi-common/ruoyi-common-chat/pom.xml @@ -26,6 +26,12 @@ ruoyi-common-core + + mysql + mysql-connector-java + 8.0.33 + + com.azure azure-ai-openai @@ -92,5 +98,25 @@ + + junit + junit + + + junit + junit + + + + cn.bigmodel.openapi + oapi-java-sdk + release-V4-2.3.0 + + + com.squareup.okhttp + okhttp + 2.7.5 + compile + diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/ConsoleEventSourceListenerV2.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/ConsoleEventSourceListenerV2.java new file mode 100644 index 00000000..21bb3d5b --- /dev/null +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/ConsoleEventSourceListenerV2.java @@ -0,0 +1,73 @@ +package org.ruoyi.common.chat.demo; + +import cn.hutool.json.JSONUtil; + +import lombok.Getter; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import okhttp3.Response; +import okhttp3.ResponseBody; +import okhttp3.sse.EventSource; +import okhttp3.sse.EventSourceListener; +import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse; + +import java.util.Objects; +import java.util.concurrent.CountDownLatch; + +/** + * 描述: sse + * + * @author https:www.unfbx.com + * 2023-06-15 + */ +@Slf4j +public class ConsoleEventSourceListenerV2 extends EventSourceListener { + @Getter + String args = ""; + final CountDownLatch countDownLatch; + + public ConsoleEventSourceListenerV2(CountDownLatch countDownLatch) { + this.countDownLatch = countDownLatch; + } + + @Override + public void onOpen(EventSource eventSource, Response response) { + log.info("OpenAI建立sse连接..."); + } + + @Override + public void onEvent(EventSource eventSource, String id, String type, String data) { + log.info("OpenAI返回数据:{}", data); + if (data.equals("[DONE]")) { + log.info("OpenAI返回数据结束了"); + countDownLatch.countDown(); + return; + } + ChatCompletionResponse chatCompletionResponse = JSONUtil.toBean(data, ChatCompletionResponse.class); + if(Objects.nonNull(chatCompletionResponse.getChoices().get(0).getDelta().getFunctionCall())){ + args += chatCompletionResponse.getChoices().get(0).getDelta().getFunctionCall().getArguments(); + } + } + + @Override + public void onClosed(EventSource eventSource) { + log.info("OpenAI关闭sse连接..."); + } + + @SneakyThrows + @Override + public void onFailure(EventSource eventSource, Throwable t, Response response) { + if(Objects.isNull(response)){ + log.error("OpenAI sse连接异常:{}", t); + eventSource.cancel(); + return; + } + ResponseBody body = response.body(); + if (Objects.nonNull(body)) { + log.error("OpenAI sse连接异常data:{},异常:{}", body.string(), t); + } else { + log.error("OpenAI sse连接异常data:{},异常:{}", response, t); + } + eventSource.cancel(); + } +} diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/ConsoleEventSourceListenerV3.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/ConsoleEventSourceListenerV3.java new file mode 100644 index 00000000..22661efe --- /dev/null +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/ConsoleEventSourceListenerV3.java @@ -0,0 +1,92 @@ +package org.ruoyi.common.chat.demo; + +import cn.hutool.core.collection.CollectionUtil; +import cn.hutool.json.JSONUtil; +import lombok.Getter; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import okhttp3.Response; +import okhttp3.ResponseBody; +import okhttp3.sse.EventSource; +import okhttp3.sse.EventSourceListener; +import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse; +import org.ruoyi.common.chat.entity.chat.Message; +import org.ruoyi.common.chat.entity.chat.tool.ToolCallFunction; +import org.ruoyi.common.chat.entity.chat.tool.ToolCalls; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.CountDownLatch; + +/** + * 描述: demo测试实现类,仅供思路参考 + * + * @author https:www.unfbx.com + * 2023-11-12 + */ +@Slf4j +public class ConsoleEventSourceListenerV3 extends EventSourceListener { + @Getter + List choices = new ArrayList<>(); + @Getter + ToolCalls toolCalls = new ToolCalls(); + @Getter + ToolCallFunction toolCallFunction = ToolCallFunction.builder().name("").arguments("").build(); + final CountDownLatch countDownLatch; + + public ConsoleEventSourceListenerV3(CountDownLatch countDownLatch) { + this.countDownLatch = countDownLatch; + } + + @Override + public void onOpen(EventSource eventSource, Response response) { + log.info("OpenAI建立sse连接..."); + } + + @Override + public void onEvent(EventSource eventSource, String id, String type, String data) { + log.info("OpenAI返回数据:{}", data); + if (data.equals("[DONE]")) { + log.info("OpenAI返回数据结束了"); + return; + } + ChatCompletionResponse chatCompletionResponse = JSONUtil.toBean(data, ChatCompletionResponse.class); + Message delta = chatCompletionResponse.getChoices().get(0).getDelta(); + if (CollectionUtil.isNotEmpty(delta.getToolCalls())) { + choices.addAll(delta.getToolCalls()); + } + } + + @Override + public void onClosed(EventSource eventSource) { + if(CollectionUtil.isNotEmpty(choices)){ + toolCalls.setId(choices.get(0).getId()); + toolCalls.setType(choices.get(0).getType()); + choices.forEach(e -> { + toolCallFunction.setName(e.getFunction().getName()); + toolCallFunction.setArguments(toolCallFunction.getArguments() + e.getFunction().getArguments()); + toolCalls.setFunction(toolCallFunction); + }); + } + log.info("OpenAI关闭sse连接..."); + countDownLatch.countDown(); + } + + @SneakyThrows + @Override + public void onFailure(EventSource eventSource, Throwable t, Response response) { + if(Objects.isNull(response)){ + log.error("OpenAI sse连接异常:{}", t); + eventSource.cancel(); + return; + } + ResponseBody body = response.body(); + if (Objects.nonNull(body)) { + log.error("OpenAI sse连接异常data:{},异常:{}", body.string(), t); + } else { + log.error("OpenAI sse连接异常data:{},异常:{}", response, t); + } + eventSource.cancel(); + } +} diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/PluginTest.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/PluginTest.java new file mode 100644 index 00000000..eeda6d4b --- /dev/null +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/PluginTest.java @@ -0,0 +1,417 @@ +package org.ruoyi.common.chat.demo; + +import cn.hutool.json.JSONUtil; +import com.alibaba.fastjson.JSONObject; +import lombok.Builder; +import lombok.Data; +import lombok.extern.slf4j.Slf4j; +import okhttp3.OkHttpClient; +import okhttp3.logging.HttpLoggingInterceptor; +import org.junit.Before; +import org.junit.Test; +import org.ruoyi.common.chat.entity.chat.*; +import org.ruoyi.common.chat.entity.chat.tool.ToolCallFunction; +import org.ruoyi.common.chat.entity.chat.tool.ToolCalls; +import org.ruoyi.common.chat.entity.chat.tool.Tools; +import org.ruoyi.common.chat.entity.chat.tool.ToolsFunction; +import org.ruoyi.common.chat.openai.OpenAiClient; +import org.ruoyi.common.chat.openai.OpenAiStreamClient; +import org.ruoyi.common.chat.openai.function.KeyRandomStrategy; +import org.ruoyi.common.chat.openai.interceptor.DynamicKeyOpenAiAuthInterceptor; +import org.ruoyi.common.chat.openai.interceptor.OpenAILogger; +import org.ruoyi.common.chat.openai.interceptor.OpenAiResponseInterceptor; +import org.ruoyi.common.chat.openai.plugin.PluginAbstract; +import org.ruoyi.common.chat.plugin.CmdPlugin; +import org.ruoyi.common.chat.plugin.CmdReq; +import org.ruoyi.common.chat.sse.ConsoleEventSourceListener; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +/** + * 描述: + * + * @author ageerle@163.com + * date 2025/3/8 + */ +@Slf4j +public class PluginTest { + + private OpenAiClient openAiClient; + private OpenAiStreamClient openAiStreamClient; + + @Before + public void before() { + //可以为null +// Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890)); + HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor(new OpenAILogger()); + //!!!!千万别再生产或者测试环境打开BODY级别日志!!!! + //!!!生产或者测试环境建议设置为这三种级别:NONE,BASIC,HEADERS,!!! + httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS); + OkHttpClient okHttpClient = new OkHttpClient + .Builder() +// .proxy(proxy) + .addInterceptor(httpLoggingInterceptor) + .addInterceptor(new OpenAiResponseInterceptor()) + .connectTimeout(10, TimeUnit.SECONDS) + .writeTimeout(30, TimeUnit.SECONDS) + .readTimeout(30, TimeUnit.SECONDS) + .build(); + openAiClient = OpenAiClient.builder() + //支持多key传入,请求时候随机选择 + .apiKey(Arrays.asList("sk-xx")) + //自定义key的获取策略:默认KeyRandomStrategy + //.keyStrategy(new KeyRandomStrategy()) + .keyStrategy(new KeyRandomStrategy()) + .okHttpClient(okHttpClient) + //自己做了代理就传代理地址,没有可不不传,(关注公众号回复:openai ,获取免费的测试代理地址) + .apiHost("https://api.pandarobot.chat/") + .build(); + + openAiStreamClient = OpenAiStreamClient.builder() + //支持多key传入,请求时候随机选择 + .apiKey(Arrays.asList("sk-xx")) + //自定义key的获取策略:默认KeyRandomStrategy + .keyStrategy(new KeyRandomStrategy()) + .authInterceptor(new DynamicKeyOpenAiAuthInterceptor()) + .okHttpClient(okHttpClient) + //自己做了代理就传代理地址,没有可不不传,(关注公众号回复:openai ,获取免费的测试代理地址) + .apiHost("https://api.pandarobot.chat/") + .build(); + } + + + @Test + public void chatFunction() { + //模型:GPT_3_5_TURBO_16K_0613 + Message message = Message.builder().role(Message.Role.USER).content("给我输出一个长度为2的中文词语,并解释下词语对应物品的用途").build(); + //属性一 + JSONObject wordLength = new JSONObject(); + wordLength.put("type", "number"); + wordLength.put("description", "词语的长度"); + //属性二 + JSONObject language = new JSONObject(); + language.put("type", "string"); + language.put("enum", Arrays.asList("zh", "en")); + language.put("description", "语言类型,例如:zh代表中文、en代表英语"); + //参数 + JSONObject properties = new JSONObject(); + properties.put("wordLength", wordLength); + properties.put("language", language); + + Parameters parameters = Parameters.builder() + .type("object") + .properties(properties) + .required(Collections.singletonList("wordLength")).build(); + Functions functions = Functions.builder() + .name("getOneWord") + .description("获取一个指定长度和语言类型的词语") + .parameters(parameters) + .build(); + + ChatCompletion chatCompletion = ChatCompletion + .builder() + .messages(Collections.singletonList(message)) + .functions(Collections.singletonList(functions)) + .functionCall("auto") + .model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName()) + .build(); + ChatCompletionResponse chatCompletionResponse = openAiClient.chatCompletion(chatCompletion); + + ChatChoice chatChoice = chatCompletionResponse.getChoices().get(0); + log.info("构造的方法值:{}", chatChoice.getMessage().getFunctionCall()); + log.info("构造的方法名称:{}", chatChoice.getMessage().getFunctionCall().getName()); + log.info("构造的方法参数:{}", chatChoice.getMessage().getFunctionCall().getArguments()); + WordParam wordParam = JSONUtil.toBean(chatChoice.getMessage().getFunctionCall().getArguments(), WordParam.class); + String oneWord = getOneWord(wordParam); + + FunctionCall functionCall = FunctionCall.builder() + .arguments(chatChoice.getMessage().getFunctionCall().getArguments()) + .name("getOneWord") + .build(); + Message message2 = Message.builder().role(Message.Role.ASSISTANT).content("方法参数").functionCall(functionCall).build(); + String content + = "{ " + + "\"wordLength\": \"3\", " + + "\"language\": \"zh\", " + + "\"word\": \"" + oneWord + "\"," + + "\"用途\": [\"直接吃\", \"做沙拉\", \"售卖\"]" + + "}"; + Message message3 = Message.builder().role(Message.Role.FUNCTION).name("getOneWord").content(content).build(); + List messageList = Arrays.asList(message, message2, message3); + ChatCompletion chatCompletionV2 = ChatCompletion + .builder() + .messages(messageList) + .model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName()) + .build(); + ChatCompletionResponse chatCompletionResponseV2 = openAiClient.chatCompletion(chatCompletionV2); + log.info("自定义的方法返回值:{}",chatCompletionResponseV2.getChoices().get(0).getMessage().getContent()); + } + + + @Test + public void plugin() { + CmdPlugin plugin = new CmdPlugin(CmdReq.class); + // 插件名称 + plugin.setName("命令行工具"); + // 方法名称 + plugin.setFunction("openCmd"); + // 方法说明 + plugin.setDescription("提供一个命令行指令,比如<记事本>,指令使用中文,以function返回结果为准"); + + PluginAbstract.Arg arg = new PluginAbstract.Arg(); + // 参数名称 + arg.setName("cmd"); + // 参数说明 + arg.setDescription("命令行指令"); + // 参数类型 + arg.setType("string"); + arg.setRequired(true); + plugin.setArgs(Collections.singletonList(arg)); + + Message message2 = Message.builder().role(Message.Role.USER).content("帮我打开计算器,结合上下文判断指令是否执行成功,只用回复成功或者失败").build(); + List messages = new ArrayList<>(); + messages.add(message2); + //有四个重载方法,都可以使用 + ChatCompletionResponse response = openAiClient.chatCompletionWithPlugin(messages,"gpt-4o-mini",plugin); + log.info("自定义的方法返回值:{}", response.getChoices().get(0).getMessage().getContent()); + } + + /** + * 自定义返回数据格式 + */ + @Test + public void diyReturnDataModelChat() { + Message message = Message.builder().role(Message.Role.USER).content("随机输出10个单词,使用json输出").build(); + ChatCompletion chatCompletion = ChatCompletion + .builder() + .messages(Collections.singletonList(message)) + .responseFormat(ResponseFormat.builder().type(ResponseFormat.Type.JSON_OBJECT.getName()).build()) + .model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName()) + .build(); + ChatCompletionResponse chatCompletionResponse = openAiClient.chatCompletion(chatCompletion); + chatCompletionResponse.getChoices().forEach(e -> System.out.println(e.getMessage())); + } + + @Test + public void streamPlugin() { + WeatherPlugin plugin = new WeatherPlugin(WeatherReq.class); + plugin.setName("知心天气"); + plugin.setFunction("getLocationWeather"); + plugin.setDescription("提供一个地址,方法将会获取该地址的天气的实时温度信息。"); + PluginAbstract.Arg arg = new PluginAbstract.Arg(); + arg.setName("location"); + arg.setDescription("地名"); + arg.setType("string"); + arg.setRequired(true); + plugin.setArgs(Collections.singletonList(arg)); + +// Message message1 = Message.builder().role(Message.Role.USER).content("秦始皇统一了哪六国。").build(); + Message message2 = Message.builder().role(Message.Role.USER).content("获取上海市的天气现在多少度,然后再给出3个推荐的户外运动。").build(); + List messages = new ArrayList<>(); +// messages.add(message1); + messages.add(message2); + //默认模型:GPT_3_5_TURBO_16K_0613 + //有四个重载方法,都可以使用 + openAiStreamClient.streamChatCompletionWithPlugin(messages, ChatCompletion.Model.GPT_4_1106_PREVIEW.getName(), new ConsoleEventSourceListener(), plugin); + CountDownLatch countDownLatch = new CountDownLatch(1); + try { + countDownLatch.await(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + + /** + * tools使用示例 + */ + @Test + public void toolsChat() { + Message message = Message.builder().role(Message.Role.USER).content("给我输出一个长度为2的中文词语,并解释下词语对应物品的用途").build(); + //属性一 + JSONObject wordLength = new JSONObject(); + wordLength.put("type", "number"); + wordLength.put("description", "词语的长度"); + //属性二 + JSONObject language = new JSONObject(); + language.put("type", "string"); + language.put("enum", Arrays.asList("zh", "en")); + language.put("description", "语言类型,例如:zh代表中文、en代表英语"); + //参数 + JSONObject properties = new JSONObject(); + properties.put("wordLength", wordLength); + properties.put("language", language); + Parameters parameters = Parameters.builder() + .type("object") + .properties(properties) + .required(Collections.singletonList("wordLength")).build(); + Tools tools = Tools.builder() + .type(Tools.Type.FUNCTION.getName()) + .function(ToolsFunction.builder().name("getOneWord").description("获取一个指定长度和语言类型的词语").parameters(parameters).build()) + .build(); + + ChatCompletion chatCompletion = ChatCompletion + .builder() + .messages(Collections.singletonList(message)) + .tools(Collections.singletonList(tools)) + .model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName()) + .build(); + ChatCompletionResponse chatCompletionResponse = openAiClient.chatCompletion(chatCompletion); + + ChatChoice chatChoice = chatCompletionResponse.getChoices().get(0); + log.info("构造的方法值:{}", chatChoice.getMessage().getToolCalls()); + + ToolCalls openAiReturnToolCalls = chatChoice.getMessage().getToolCalls().get(0); + WordParam wordParam = JSONUtil.toBean(openAiReturnToolCalls.getFunction().getArguments(), WordParam.class); + String oneWord = getOneWord(wordParam); + + + ToolCallFunction tcf = ToolCallFunction.builder().name("getOneWord").arguments(openAiReturnToolCalls.getFunction().getArguments()).build(); + ToolCalls tc = ToolCalls.builder().id(openAiReturnToolCalls.getId()).type(ToolCalls.Type.FUNCTION.getName()).function(tcf).build(); + //构造tool call + Message message2 = Message.builder().role(Message.Role.ASSISTANT).content("方法参数").toolCalls(Collections.singletonList(tc)).build(); + String content + = "{ " + + "\"wordLength\": \"3\", " + + "\"language\": \"zh\", " + + "\"word\": \"" + oneWord + "\"," + + "\"用途\": [\"直接吃\", \"做沙拉\", \"售卖\"]" + + "}"; + Message message3 = Message.builder().toolCallId(openAiReturnToolCalls.getId()).role(Message.Role.TOOL).name("getOneWord").content(content).build(); + List messageList = Arrays.asList(message, message2, message3); + ChatCompletion chatCompletionV2 = ChatCompletion + .builder() + .messages(messageList) + .model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName()) + .build(); + ChatCompletionResponse chatCompletionResponseV2 = openAiClient.chatCompletion(chatCompletionV2); + log.info("自定义的方法返回值:{}", chatCompletionResponseV2.getChoices().get(0).getMessage().getContent()); + + } + + /** + * tools流式输出使用示例 + */ + @Test + public void streamToolsChat() { + + CountDownLatch countDownLatch = new CountDownLatch(1); + ConsoleEventSourceListenerV3 eventSourceListener = new ConsoleEventSourceListenerV3(countDownLatch); + + Message message = Message.builder().role(Message.Role.USER).content("给我输出一个长度为2的中文词语,并解释下词语对应物品的用途").build(); + //属性一 + JSONObject wordLength = new JSONObject(); + wordLength.put("type", "number"); + wordLength.put("description", "词语的长度"); + //属性二 + JSONObject language = new JSONObject(); + language.put("type", "string"); + language.put("enum", Arrays.asList("zh", "en")); + language.put("description", "语言类型,例如:zh代表中文、en代表英语"); + //参数 + JSONObject properties = new JSONObject(); + properties.put("wordLength", wordLength); + properties.put("language", language); + Parameters parameters = Parameters.builder() + .type("object") + .properties(properties) + .required(Collections.singletonList("wordLength")).build(); + Tools tools = Tools.builder() + .type(Tools.Type.FUNCTION.getName()) + .function(ToolsFunction.builder().name("getOneWord").description("获取一个指定长度和语言类型的词语").parameters(parameters).build()) + .build(); + + ChatCompletion chatCompletion = ChatCompletion + .builder() + .messages(Collections.singletonList(message)) + .tools(Collections.singletonList(tools)) + .model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName()) + .build(); + openAiStreamClient.streamChatCompletion(chatCompletion, eventSourceListener); + + try { + countDownLatch.await(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + + ToolCalls openAiReturnToolCalls = eventSourceListener.getToolCalls(); + WordParam wordParam = JSONUtil.toBean(openAiReturnToolCalls.getFunction().getArguments(), WordParam.class); + String oneWord = getOneWord(wordParam); + + + ToolCallFunction tcf = ToolCallFunction.builder().name("getOneWord").arguments(openAiReturnToolCalls.getFunction().getArguments()).build(); + ToolCalls tc = ToolCalls.builder().id(openAiReturnToolCalls.getId()).type(ToolCalls.Type.FUNCTION.getName()).function(tcf).build(); + //构造tool call + Message message2 = Message.builder().role(Message.Role.ASSISTANT).content("方法参数").toolCalls(Collections.singletonList(tc)).build(); + String content + = "{ " + + "\"wordLength\": \"3\", " + + "\"language\": \"zh\", " + + "\"word\": \"" + oneWord + "\"," + + "\"用途\": [\"直接吃\", \"做沙拉\", \"售卖\"]" + + "}"; + Message message3 = Message.builder().toolCallId(openAiReturnToolCalls.getId()).role(Message.Role.TOOL).name("getOneWord").content(content).build(); + List messageList = Arrays.asList(message, message2, message3); + ChatCompletion chatCompletionV2 = ChatCompletion + .builder() + .messages(messageList) + .model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName()) + .build(); + + + CountDownLatch countDownLatch1 = new CountDownLatch(1); + openAiStreamClient.streamChatCompletion(chatCompletionV2, new ConsoleEventSourceListenerV3(countDownLatch)); + try { + countDownLatch1.await(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + try { + countDownLatch1.await(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + + } + + + @Data + @Builder + static class WordParam { + private int wordLength; + @Builder.Default + private String language = "zh"; + } + + + /** + * 获取一个词语(根据语言和字符长度查询) + * @param wordParam + * @return + */ + public String getOneWord(WordParam wordParam) { + + List zh = Arrays.asList("大香蕉", "哈密瓜", "苹果"); + List en = Arrays.asList("apple", "banana", "cantaloupe"); + if (wordParam.getLanguage().equals("zh")) { + for (String e : zh) { + if (e.length() == wordParam.getWordLength()) { + return e; + } + } + } + if (wordParam.getLanguage().equals("en")) { + for (String e : en) { + if (e.length() == wordParam.getWordLength()) { + return e; + } + } + } + return "西瓜"; + } +} diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/WeatherPlugin.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/WeatherPlugin.java new file mode 100644 index 00000000..787e5983 --- /dev/null +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/WeatherPlugin.java @@ -0,0 +1,24 @@ +package org.ruoyi.common.chat.demo; + + +import org.ruoyi.common.chat.openai.plugin.PluginAbstract; + +public class WeatherPlugin extends PluginAbstract { + + public WeatherPlugin(Class r) { + super(r); + } + + @Override + public WeatherResp func(WeatherReq args) { + WeatherResp weatherResp = new WeatherResp(); + weatherResp.setTemp("25到28摄氏度"); + weatherResp.setLevel(3); + return weatherResp; + } + + @Override + public String content(WeatherResp weatherResp) { + return "当前天气温度:" + weatherResp.getTemp() + ",风力等级:" + weatherResp.getLevel(); + } +} diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/WeatherReq.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/WeatherReq.java new file mode 100644 index 00000000..b0670e5b --- /dev/null +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/WeatherReq.java @@ -0,0 +1,13 @@ +package org.ruoyi.common.chat.demo; + + +import lombok.Data; +import org.ruoyi.common.chat.openai.plugin.PluginParam; + +@Data +public class WeatherReq extends PluginParam { + /** + * 城市 + */ + private String location; +} diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/WeatherResp.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/WeatherResp.java new file mode 100644 index 00000000..360c56c8 --- /dev/null +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/WeatherResp.java @@ -0,0 +1,15 @@ +package org.ruoyi.common.chat.demo; + +import lombok.Data; + +@Data +public class WeatherResp { + /** + * 温度 + */ + private String temp; + /** + * 风力等级 + */ + private Integer level; +} diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/zhipu/AllToolsTest.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/zhipu/AllToolsTest.java new file mode 100644 index 00000000..805402c5 --- /dev/null +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/zhipu/AllToolsTest.java @@ -0,0 +1,122 @@ +package org.ruoyi.common.chat.demo.zhipu; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.zhipu.oapi.ClientV4; +import com.zhipu.oapi.Constants; +import com.zhipu.oapi.service.v4.deserialize.MessageDeserializeFactory; +import com.zhipu.oapi.service.v4.model.*; +import io.reactivex.Flowable; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +public class AllToolsTest { + + private final static Logger logger = LoggerFactory.getLogger(AllToolsTest.class); + private static final String API_SECRET_KEY = "28550a39d4cfaabbbf38df04dd3931f5.IUvfTThUf0xBF5l0"; + + private static final ClientV4 client = new ClientV4.Builder(API_SECRET_KEY) + .networkConfig(300, 100, 100, 100, TimeUnit.SECONDS) + .connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS)) + .build(); + private static final ObjectMapper mapper = MessageDeserializeFactory.defaultObjectMapper(); + // 请自定义自己的业务id + private static final String requestIdTemplate = "mycompany-%d"; + + + @Test + public void test1() throws JsonProcessingException { + + + List messages = new ArrayList<>(); + ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "帮我查询北京天气"); + messages.add(chatMessage); + String requestId = String.format(requestIdTemplate, System.currentTimeMillis()); + // 函数调用参数构建部分 + List chatToolList = new ArrayList<>(); + ChatTool chatTool = new ChatTool(); + + chatTool.setType("code_interpreter"); + ObjectNode objectNode = mapper.createObjectNode(); + objectNode.put("code", "北京天气"); +// chatTool.set(chatFunction); + chatToolList.add(chatTool); + + + ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() + .model("glm-4-alltools") + .stream(Boolean.TRUE) + .invokeMethod(Constants.invokeMethod) + .messages(messages) + .tools(chatToolList) + .toolChoice("auto") + .requestId(requestId) + .build(); + ModelApiResponse sseModelApiResp = client.invokeModelApi(chatCompletionRequest); + if (sseModelApiResp.isSuccess()) { + AtomicBoolean isFirst = new AtomicBoolean(true); + List choices = new ArrayList<>(); + AtomicReference lastAccumulator = new AtomicReference<>(); + + mapStreamToAccumulator(sseModelApiResp.getFlowable()) + .doOnNext(accumulator -> { + { + if (isFirst.getAndSet(false)) { + logger.info("Response: "); + } + if (accumulator.getDelta() != null && accumulator.getDelta().getTool_calls() != null) { + String jsonString = mapper.writeValueAsString(accumulator.getDelta().getTool_calls()); + logger.info("tool_calls: {}", jsonString); + } + if (accumulator.getDelta() != null && accumulator.getDelta().getContent() != null) { + logger.info(accumulator.getDelta().getContent()); + } + choices.add(accumulator.getChoice()); + lastAccumulator.set(accumulator); + + } + }) + .doOnComplete(() -> System.out.println("Stream completed.")) + .doOnError(throwable -> System.err.println("Error: " + throwable)) // Handle errors + .blockingSubscribe();// Use blockingSubscribe instead of blockingGet() + + ChatMessageAccumulator chatMessageAccumulator = lastAccumulator.get(); + ModelData data = new ModelData(); + data.setChoices(choices); + if (chatMessageAccumulator != null) { + data.setUsage(chatMessageAccumulator.getUsage()); + data.setId(chatMessageAccumulator.getId()); + data.setCreated(chatMessageAccumulator.getCreated()); + } + data.setRequestId(chatCompletionRequest.getRequestId()); + sseModelApiResp.setFlowable(null);// 打印前置空 + sseModelApiResp.setData(data); + } + logger.info("model output: {}", mapper.writeValueAsString(sseModelApiResp)); + client.getConfig().getHttpClient().dispatcher().executorService().shutdown(); + + client.getConfig().getHttpClient().connectionPool().evictAll(); + // List all active threads + for (Thread t : Thread.getAllStackTraces().keySet()) { + logger.info("Thread: " + t.getName() + " State: " + t.getState()); + } + } + + + + public static Flowable mapStreamToAccumulator(Flowable flowable) { + return flowable.map(chunk -> { + return new ChatMessageAccumulator(chunk.getChoices().get(0).getDelta(), null, chunk.getChoices().get(0), chunk.getUsage(), chunk.getCreated(), chunk.getId()); + }); + } +} diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/zhipu/V4Test.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/zhipu/V4Test.java new file mode 100644 index 00000000..afd664e8 --- /dev/null +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/zhipu/V4Test.java @@ -0,0 +1,646 @@ +package org.ruoyi.common.chat.demo.zhipu; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.PropertyNamingStrategy; +import com.zhipu.oapi.ClientV4; +import com.zhipu.oapi.Constants; +import com.zhipu.oapi.core.response.HttpxBinaryResponseContent; +import com.zhipu.oapi.service.v4.batchs.BatchCreateParams; +import com.zhipu.oapi.service.v4.batchs.BatchResponse; +import com.zhipu.oapi.service.v4.batchs.QueryBatchResponse; +import com.zhipu.oapi.service.v4.embedding.EmbeddingApiResponse; +import com.zhipu.oapi.service.v4.embedding.EmbeddingRequest; +import com.zhipu.oapi.service.v4.file.*; +import com.zhipu.oapi.service.v4.fine_turning.*; +import com.zhipu.oapi.service.v4.image.CreateImageRequest; +import com.zhipu.oapi.service.v4.image.ImageApiResponse; +import com.zhipu.oapi.service.v4.model.*; +import io.reactivex.Flowable; + +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + + +public class V4Test { + + private final static Logger logger = LoggerFactory.getLogger(V4Test.class); + private static final String API_SECRET_KEY = "28550a39d4cfaabbbf38df04dd3931f5.IUvfTThUf0xBF5l0"; + + + private static final ClientV4 client = new ClientV4.Builder(API_SECRET_KEY) + .enableTokenCache() + .networkConfig(300, 100, 100, 100, TimeUnit.SECONDS) + .connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS)) + .build(); + + // 请自定义自己的业务id + private static final String requestIdTemplate = "mycompany-%d"; + + private static final ObjectMapper mapper = new ObjectMapper(); + + + public static ObjectMapper defaultObjectMapper() { + ObjectMapper mapper = new ObjectMapper(); + mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL); + mapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE); + return mapper; + } + + @Test + public void test() { + + } + + /** + * sse-V4:function调用 + */ + @Test + public void testFunctionSSE() throws JsonProcessingException { + List messages = new ArrayList<>(); + ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "成都到北京要多久,天气如何"); + messages.add(chatMessage); + String requestId = String.format(requestIdTemplate, System.currentTimeMillis()); + // 函数调用参数构建部分 + List chatToolList = new ArrayList<>(); + ChatTool chatTool = new ChatTool(); + + chatTool.setType(ChatToolType.FUNCTION.value()); + ChatFunctionParameters chatFunctionParameters = new ChatFunctionParameters(); + chatFunctionParameters.setType("object"); + Map properties = new HashMap<>(); + properties.put("location", new HashMap() {{ + put("type", "string"); + put("description", "城市,如:北京"); + }}); + properties.put("unit", new HashMap() {{ + put("type", "string"); + put("enum", new ArrayList() {{ + add("celsius"); + add("fahrenheit"); + }}); + }}); + chatFunctionParameters.setProperties(properties); + ChatFunction chatFunction = ChatFunction.builder() + .name("get_weather") + .description("Get the current weather of a location") + .parameters(chatFunctionParameters) + .build(); + chatTool.setFunction(chatFunction); + chatToolList.add(chatTool); + HashMap extraJson = new HashMap<>(); + extraJson.put("temperature", 0.5); + extraJson.put("max_tokens", 50); + + ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() + .model(Constants.ModelChatGLM4) + .stream(Boolean.TRUE) + .messages(messages) + .requestId(requestId) + .tools(chatToolList) + .toolChoice("auto") + .extraJson(extraJson) + .build(); + ModelApiResponse sseModelApiResp = client.invokeModelApi(chatCompletionRequest); + if (sseModelApiResp.isSuccess()) { + AtomicBoolean isFirst = new AtomicBoolean(true); + List choices = new ArrayList<>(); + ChatMessageAccumulator chatMessageAccumulator = mapStreamToAccumulator(sseModelApiResp.getFlowable()) + .doOnNext(accumulator -> { + { + if (isFirst.getAndSet(false)) { + logger.info("Response: "); + } + if (accumulator.getDelta() != null && accumulator.getDelta().getTool_calls() != null) { + String jsonString = mapper.writeValueAsString(accumulator.getDelta().getTool_calls()); + logger.info("tool_calls: {}", jsonString); + } + if (accumulator.getDelta() != null && accumulator.getDelta().getContent() != null) { + logger.info(accumulator.getDelta().getContent()); + } + choices.add(accumulator.getChoice()); + } + }) + .doOnComplete(System.out::println) + .lastElement() + .blockingGet(); + + + ModelData data = new ModelData(); + data.setChoices(choices); + data.setUsage(chatMessageAccumulator.getUsage()); + data.setId(chatMessageAccumulator.getId()); + data.setCreated(chatMessageAccumulator.getCreated()); + data.setRequestId(chatCompletionRequest.getRequestId()); + sseModelApiResp.setFlowable(null);// 打印前置空 + sseModelApiResp.setData(data); + } + logger.info("model output: {}", mapper.writeValueAsString(sseModelApiResp)); + } + + + /** + * sse-V4:非function调用 + */ + @Test + public void testNonFunctionSSE() throws JsonProcessingException { + List messages = new ArrayList<>(); + ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "ChatGLM和你哪个更强大"); + messages.add(chatMessage); + HashMap extraJson = new HashMap<>(); + extraJson.put("temperature", 0.5); + extraJson.put("max_tokens", 3); + + String requestId = String.format(requestIdTemplate, System.currentTimeMillis()); + ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() + .model(Constants.ModelChatGLM4) + .stream(Boolean.TRUE) + .messages(messages) + .requestId(requestId) + .extraJson(extraJson) + .build(); + ModelApiResponse sseModelApiResp = client.invokeModelApi(chatCompletionRequest); + // stream 处理方法 + if (sseModelApiResp.isSuccess()) { + AtomicBoolean isFirst = new AtomicBoolean(true); + List choices = new ArrayList<>(); + ChatMessageAccumulator chatMessageAccumulator = mapStreamToAccumulator(sseModelApiResp.getFlowable()) + .doOnNext(accumulator -> { + { + if (isFirst.getAndSet(false)) { + logger.info("Response: "); + } + if (accumulator.getDelta() != null && accumulator.getDelta().getTool_calls() != null) { + String jsonString = mapper.writeValueAsString(accumulator.getDelta().getTool_calls()); + logger.info("tool_calls: {}", jsonString); + } + if (accumulator.getDelta() != null && accumulator.getDelta().getContent() != null) { + logger.info("accumulator.getDelta().getContent(): {}", accumulator.getDelta().getContent()); + } + choices.add(accumulator.getChoice()); + } + }) + .doOnComplete(System.out::println) + .lastElement() + .blockingGet(); + + + ModelData data = new ModelData(); + data.setChoices(choices); + data.setUsage(chatMessageAccumulator.getUsage()); + data.setId(chatMessageAccumulator.getId()); + data.setCreated(chatMessageAccumulator.getCreated()); + data.setRequestId(chatCompletionRequest.getRequestId()); + sseModelApiResp.setFlowable(null);// 打印前置空 + sseModelApiResp.setData(data); + } + logger.info("model output: {}", mapper.writeValueAsString(sseModelApiResp)); + } + + + /** + * V4-同步function调用 + */ + @Test + public void testFunctionInvoke() { + List messages = new ArrayList<>(); + ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "你可以做什么"); + messages.add(chatMessage); + String requestId = String.format(requestIdTemplate, System.currentTimeMillis()); + // 函数调用参数构建部分 + List chatToolList = new ArrayList<>(); + ChatTool chatTool = new ChatTool(); + chatTool.setType(ChatToolType.FUNCTION.value()); + ChatFunctionParameters chatFunctionParameters = new ChatFunctionParameters(); + chatFunctionParameters.setType("object"); + Map properties = new HashMap<>(); + properties.put("location", new HashMap() {{ + put("type", "string"); + put("description", "城市,如:北京"); + }}); + properties.put("unit", new HashMap() {{ + put("type", "string"); + put("enum", new ArrayList() {{ + add("celsius"); + add("fahrenheit"); + }}); + }}); + chatFunctionParameters.setProperties(properties); + ChatFunction chatFunction = ChatFunction.builder() + .name("get_weather") + .description("Get the current weather of a location") + .parameters(chatFunctionParameters) + .build(); + chatTool.setFunction(chatFunction); + + + ChatTool chatTool1 = new ChatTool(); + chatTool1.setType(ChatToolType.WEB_SEARCH.value()); + WebSearch webSearch = new WebSearch(); + webSearch.setSearch_query("清华的升学率"); + webSearch.setSearch_result(true); + webSearch.setEnable(false); + chatTool1.setWeb_search(webSearch); + + chatToolList.add(chatTool); + chatToolList.add(chatTool1); + + ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() + .model(Constants.ModelChatGLM4) + .stream(Boolean.FALSE) + .invokeMethod(Constants.invokeMethod) + .messages(messages) + .requestId(requestId) + .tools(chatToolList) + .toolChoice("auto") + .build(); + ModelApiResponse invokeModelApiResp = client.invokeModelApi(chatCompletionRequest); + try { + logger.info("model output: {}", mapper.writeValueAsString(invokeModelApiResp)); + } catch (JsonProcessingException e) { + logger.error("model output error", e); + } + } + + + /** + * V4-同步非function调用 + */ + @Test + public void testNonFunctionInvoke() throws JsonProcessingException { + List messages = new ArrayList<>(); + ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "ChatGLM和你哪个更强大"); + messages.add(chatMessage); + String requestId = String.format(requestIdTemplate, System.currentTimeMillis()); + + + HashMap extraJson = new HashMap<>(); + extraJson.put("temperature", 0.5); + extraJson.put("max_tokens", 3); + ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() + .model(Constants.ModelChatGLM4) + .stream(Boolean.FALSE) + .invokeMethod(Constants.invokeMethod) + .messages(messages) + .requestId(requestId) + .extraJson(extraJson) + .build(); + ModelApiResponse invokeModelApiResp = client.invokeModelApi(chatCompletionRequest); + logger.info("model output: {}", mapper.writeValueAsString(invokeModelApiResp)); + } + + + /** + * V4-同步非function调用 + */ + @Test + public void testCharGlmInvoke() throws JsonProcessingException { + List messages = new ArrayList<>(); + ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "ChatGLM和你哪个更强大"); + messages.add(chatMessage); + String requestId = String.format(requestIdTemplate, System.currentTimeMillis()); + + + HashMap extraJson = new HashMap<>(); + extraJson.put("temperature", 0.5); + + ChatMeta meta = new ChatMeta(); + meta.setUser_info("我是陆星辰,是一个男性,是一位知名导演,也是苏梦远的合作导演。我擅长拍摄音乐题材的电影。苏梦远对我的态度是尊敬的,并视我为良师益友。"); + meta.setBot_info("苏梦远,本名苏远心,是一位当红的国内女歌手及演员。在参加选秀节目后,凭借独特的嗓音及出众的舞台魅力迅速成名,进入娱乐圈。她外表美丽动人,但真正的魅力在于她的才华和勤奋。苏梦远是音乐学院毕业的优秀生,善于创作,拥有多首热门原创歌曲。除了音乐方面的成就,她还热衷于慈善事业,积极参加公益活动,用实际行动传递正能量。在工作中,她对待工作非常敬业,拍戏时总是全身心投入角色,赢得了业内人士的赞誉和粉丝的喜爱。虽然在娱乐圈,但她始终保持低调、谦逊的态度,深得同行尊重。在表达时,苏梦远喜欢使用“我们”和“一起”,强调团队精神。"); + meta.setBot_name("苏梦远"); + meta.setUser_name("陆星辰"); + + ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() + .model(Constants.ModelCharGLM3) + .stream(Boolean.FALSE) + .invokeMethod(Constants.invokeMethod) + .messages(messages) + .requestId(requestId) + .meta(meta) + .extraJson(extraJson) + .build(); + ModelApiResponse invokeModelApiResp = client.invokeModelApi(chatCompletionRequest); + logger.info("model output: {}", mapper.writeValueAsString(invokeModelApiResp)); + } + + /** + * V4异步调用 + */ + @Test + public void testAsyncInvoke() throws JsonProcessingException { + String taskId = getAsyncTaskId(); + testQueryResult(taskId); + } + +// + + /** + * 文生图 + */ + @Test + public void testCreateImage() throws JsonProcessingException { + CreateImageRequest createImageRequest = new CreateImageRequest(); + createImageRequest.setModel(Constants.ModelCogView); + createImageRequest.setPrompt("Futuristic cloud data center, showcasing advanced technologgy and a high-tech atmosp\n" + + "here. The image should depict a spacious, well-lit interior with rows of server racks, glo\n" + + "wing lights, and digital displays. Include abstract representattions of data streams and\n" + + "onnectivity, symbolizing the essence of cloud computing. Thee style should be modern a\n" + + "nd sleek, with a focus on creating a sense of innovaticon and cutting-edge technology\n" + + "The overall ambiance should convey the power and effciency of cloud services in a visu\n" + + "ally engaging way."); + createImageRequest.setRequestId("test11111111111111"); + ImageApiResponse imageApiResponse = client.createImage(createImageRequest); + logger.info("imageApiResponse: {}", mapper.writeValueAsString(imageApiResponse)); + } + +// +// /** +// * 图生文 +// */ +// @Test +// public void testImageToWord() throws JsonProcessingException { +// List messages = new ArrayList<>(); +// List> contentList = new ArrayList<>(); +// Map textMap = new HashMap<>(); +// textMap.put("type", "text"); +// textMap.put("text", "图里有什么"); +// Map typeMap = new HashMap<>(); +// typeMap.put("type", "image_url"); +// Map urlMap = new HashMap<>(); +// urlMap.put("url", "https://sfile.chatglm.cn/testpath/275ae5b6-5390-51ca-a81a-60332d1a7cac_0.png"); +// typeMap.put("image_url", urlMap); +// contentList.add(textMap); +// contentList.add(typeMap); +// ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), contentList); +// messages.add(chatMessage); +// String requestId = String.format(requestIdTemplate, System.currentTimeMillis()); +// +// +// ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() +// .model(Constants.ModelChatGLM4V) +// .stream(Boolean.FALSE) +// .invokeMethod(Constants.invokeMethod) +// .messages(messages) +// .requestId(requestId) +// .build(); +// ModelApiResponse modelApiResponse = client.invokeModelApi(chatCompletionRequest); +// logger.info("model output: {}", mapper.writeValueAsString(modelApiResponse)); +// } +// + + /** + * 向量模型V4 + */ + @Test + public void testEmbeddings() throws JsonProcessingException { + EmbeddingRequest embeddingRequest = new EmbeddingRequest(); + embeddingRequest.setInput("hello world"); + embeddingRequest.setModel(Constants.ModelEmbedding2); + EmbeddingApiResponse apiResponse = client.invokeEmbeddingsApi(embeddingRequest); + logger.info("model output: {}", mapper.writeValueAsString(apiResponse)); + } + + + /** + * V4微调上传数据集 + */ + @Test + public void testUploadFile() throws JsonProcessingException { + String filePath = "demo.jsonl"; + + String path = ClassLoader.getSystemResource(filePath).getPath(); + String purpose = "fine-tune"; + UploadFileRequest request = UploadFileRequest.builder() + .purpose(purpose) + .filePath(path) + .build(); + + FileApiResponse fileApiResponse = client.invokeUploadFileApi(request); + logger.info("model output: {}", mapper.writeValueAsString(fileApiResponse)); + } + + + /** + * 微调V4-查询上传文件列表 + */ + @Test + public void testQueryUploadFileList() throws JsonProcessingException { + QueryFilesRequest queryFilesRequest = new QueryFilesRequest(); + QueryFileApiResponse queryFileApiResponse = client.queryFilesApi(queryFilesRequest); + logger.info("model output: {}", mapper.writeValueAsString(queryFileApiResponse)); + } + + @Test + public void testFileContent() throws IOException { + try { + + HttpxBinaryResponseContent httpxBinaryResponseContent = client.fileContent("20240514_ea19d21b-d256-4586-b0df-e80a45e3c286"); + String filePath = "demo_output.jsonl"; + String resourcePath = V4Test.class.getClassLoader().getResource("").getPath(); + + httpxBinaryResponseContent.streamToFile(resourcePath + "1" + filePath, 1000); + + } catch (IOException e) { + logger.error("file content error", e); + } + } + +//// @Test +//// public void deletedFile() throws IOException { +//// FileDelResponse fileDelResponse = client.deletedFile("20240514_ea19d21b-d256-4586-b0df-e80a45e3c286"); +//// +//// logger.info("model output: {}", mapper.writeValueAsString(fileDelResponse)); +//// +//// } +// +// + + /** + * 微调V4-创建微调任务 + */ + @Test + public void testCreateFineTuningJob() throws JsonProcessingException { + FineTuningJobRequest request = new FineTuningJobRequest(); + String requestId = String.format(requestIdTemplate, System.currentTimeMillis()); + request.setRequestId(requestId); + request.setModel("chatglm3-6b"); + request.setTraining_file("file-20240118082608327-kp8qr"); + CreateFineTuningJobApiResponse createFineTuningJobApiResponse = client.createFineTuningJob(request); + logger.info("model output: {}", mapper.writeValueAsString(createFineTuningJobApiResponse)); + } + + + /** + * 微调V4-查询微调任务 + */ + @Test + public void testRetrieveFineTuningJobs() throws JsonProcessingException { + QueryFineTuningJobRequest queryFineTuningJobRequest = new QueryFineTuningJobRequest(); + queryFineTuningJobRequest.setJobId("ftjob-20240429172916475-fb7r9"); +// queryFineTuningJobRequest.setLimit(1); +// queryFineTuningJobRequest.setAfter(1); + QueryFineTuningJobApiResponse queryFineTuningJobApiResponse = client.retrieveFineTuningJobs(queryFineTuningJobRequest); + logger.info("model output: {}", mapper.writeValueAsString(queryFineTuningJobApiResponse)); + } + + + /** + * 微调V4-查询微调任务 + */ + @Test + public void testFueryFineTuningJobsEvents() throws JsonProcessingException { + QueryFineTuningJobRequest queryFineTuningJobRequest = new QueryFineTuningJobRequest(); + queryFineTuningJobRequest.setJobId("ftjob-20240429172916475-fb7r9"); + + QueryFineTuningEventApiResponse queryFineTuningEventApiResponse = client.queryFineTuningJobsEvents(queryFineTuningJobRequest); + logger.info("model output: {}", mapper.writeValueAsString(queryFineTuningEventApiResponse)); + } + + + /** + * testQueryPersonalFineTuningJobs V4-查询个人微调任务 + */ + @Test + public void testQueryPersonalFineTuningJobs() throws JsonProcessingException { + QueryPersonalFineTuningJobRequest queryPersonalFineTuningJobRequest = new QueryPersonalFineTuningJobRequest(); + queryPersonalFineTuningJobRequest.setLimit(1); + QueryPersonalFineTuningJobApiResponse queryPersonalFineTuningJobApiResponse = client.queryPersonalFineTuningJobs(queryPersonalFineTuningJobRequest); + logger.info("model output: {}", mapper.writeValueAsString(queryPersonalFineTuningJobApiResponse)); + } + + + @Test + public void testBatchesCreate() { + BatchCreateParams batchCreateParams = new BatchCreateParams( + "24h", + "/v4/chat/completions", + "20240514_ea19d21b-d256-4586-b0df-e80a45e3c286", + new HashMap() {{ + put("key1", "value1"); + put("key2", "value2"); + }} + ); + + BatchResponse batchResponse = client.batchesCreate(batchCreateParams); + logger.info("output: {}", batchResponse); +// output: BatchResponse(code=200, msg=调用成功, success=true, data=Batch(id=batch_1791021399316246528, completionWindow=24h, createdAt=1715847751822, endpoint=/v4/chat/completions, inputFileId=20240514_ea19d21b-d256-4586-b0df-e80a45e3c286, object=batch, status=validating, cancelledAt=null, cancellingAt=null, completedAt=null, errorFileId=null, errors=null, expiredAt=null, expiresAt=null, failedAt=null, finalizingAt=null, inProgressAt=null, metadata={key1=value1, key2=value2}, outputFileId=null, requestCounts=BatchRequestCounts(completed=0, failed=0, total=0), error=null)) + } + + @Test + public void testDeleteFineTuningJob() { + FineTuningJobIdRequest request = FineTuningJobIdRequest.builder().jobId("test").build(); + QueryFineTuningJobApiResponse queryFineTuningJobApiResponse = client.deleteFineTuningJob(request); + logger.info("output: {}", queryFineTuningJobApiResponse); + + } + + @Test + public void testCancelFineTuningJob() { + FineTuningJobIdRequest request = FineTuningJobIdRequest.builder().jobId("test").build(); + QueryFineTuningJobApiResponse queryFineTuningJobApiResponse = client.cancelFineTuningJob(request); + logger.info("output: {}", queryFineTuningJobApiResponse); + + } + + @Test + public void testBatchesRetrieve() { + BatchResponse batchResponse = client.batchesRetrieve("batch_1791021399316246528"); + logger.info("output: {}", batchResponse); + + } + + @Test + public void testDeleteFineTuningModel() { + FineTuningJobModelRequest request = FineTuningJobModelRequest.builder().fineTunedModel("test").build(); + + FineTunedModelsStatusResponse fineTunedModelsStatusResponse = client.deleteFineTuningModel(request); + logger.info("output: {}", fineTunedModelsStatusResponse); +// output: BatchResponse(code=200, msg=调用成功, success=true, data=Batch(id=batch_1791021399316246528, completionWindow=24h, createdAt=1715847752000, endpoint=/v4/chat/completions, inputFileId=20240514_ea19d21b-d256-4586-b0df-e80a45e3c286, object=batch, status=validating, cancelledAt=null, cancellingAt=null, completedAt=null, errorFileId=, errors=null, expiredAt=null, expiresAt=null, failedAt=null, finalizingAt=null, inProgressAt=null, metadata={key1=value1, key2=value2}, outputFileId=, requestCounts=BatchRequestCounts(completed=0, failed=0, total=0), error=null)) + + } + + @Test + public void testBatchesList() { + QueryBatchRequest queryBatchRequest = new QueryBatchRequest(); + queryBatchRequest.setLimit(10); + QueryBatchResponse queryBatchResponse = client.batchesList(queryBatchRequest); + logger.info("output: {}", queryBatchResponse); +// output: QueryBatchResponse(code=200, msg=调用成功, success=true, data=BatchPage(object=list, data=[Batch(id=batch_1790291013237211136, completionWindow=24h, createdAt=1715673614000, endpoint=/v4/chat/completions, inputFileId=20240514_ea19d21b-d256-4586-b0df-e80a45e3c286, object=batch, status=completed, cancelledAt=null, cancellingAt=1715673699000, completedAt=null, errorFileId=, errors=null, expiredAt=null, expiresAt=null, failedAt=null, finalizingAt=null, inProgressAt=null, metadata={description=job test}, outputFileId=, requestCounts=BatchRequestCounts(completed=0, failed=0, total=0), error=null), Batch(id=batch_1790292763050508288, completionWindow=24h, createdAt=1715674031000, endpoint=/v4/chat/completions, inputFileId=20240514_ea19d21b-d256-4586-b0df-e80a45e3c286, object=batch, status=completed, cancelledAt=null, cancellingAt=null, completedAt=1715766416000, errorFileId=, errors=null, expiredAt=null, expiresAt=null, failedAt=null, finalizingAt=1715754569000, inProgressAt=null, metadata={description=job test}, outputFileId=1715766415_e5a77222855a406ca8a082de28549c99, requestCounts=BatchRequestCounts(completed=2, failed=0, total=2), error=null), Batch(id=batch_1791021114887909376, completionWindow=24h, createdAt=1715847684000, endpoint=/v4/chat/completions, inputFileId=20240514_ea19d21b-d256-4586-b0df-e80a45e3c286, object=batch, status=validating, cancelledAt=null, cancellingAt=null, completedAt=null, errorFileId=, errors=null, expiredAt=null, expiresAt=null, failedAt=null, finalizingAt=null, inProgressAt=null, metadata={key1=value1, key2=value2}, outputFileId=, requestCounts=BatchRequestCounts(completed=0, failed=0, total=0), error=null), Batch(id=batch_1791021399316246528, completionWindow=24h, createdAt=1715847752000, endpoint=/v4/chat/completions, inputFileId=20240514_ea19d21b-d256-4586-b0df-e80a45e3c286, object=batch, status=validating, cancelledAt=null, cancellingAt=null, completedAt=null, errorFileId=, errors=null, expiredAt=null, expiresAt=null, failedAt=null, finalizingAt=null, inProgressAt=null, metadata={key1=value1, key2=value2}, outputFileId=, requestCounts=BatchRequestCounts(completed=0, failed=0, total=0), error=null)], error=null)) + + } + + @Test + public void testBatchesCancel() throws JsonProcessingException { + getAsyncTaskId(); + } + + private static String getAsyncTaskId() throws JsonProcessingException { + List messages = new ArrayList<>(); + ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "ChatGLM和你哪个更强大"); + messages.add(chatMessage); + String requestId = String.format(requestIdTemplate, System.currentTimeMillis()); + // 函数调用参数构建部分 + List chatToolList = new ArrayList<>(); + ChatTool chatTool = new ChatTool(); + chatTool.setType(ChatToolType.FUNCTION.value()); + ChatFunctionParameters chatFunctionParameters = new ChatFunctionParameters(); + chatFunctionParameters.setType("object"); + Map properties = new HashMap<>(); + properties.put("location", new HashMap() {{ + put("type", "string"); + put("description", "城市,如:北京"); + }}); + properties.put("unit", new HashMap() {{ + put("type", "string"); + put("enum", new ArrayList() {{ + add("celsius"); + add("fahrenheit"); + }}); + }}); + chatFunctionParameters.setProperties(properties); + ChatFunction chatFunction = ChatFunction.builder() + .name("get_weather") + .description("Get the current weather of a location") + .parameters(chatFunctionParameters) + .build(); + chatTool.setFunction(chatFunction); + chatToolList.add(chatTool); + ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() + .model(Constants.ModelChatGLM4) + .stream(Boolean.FALSE) + .invokeMethod(Constants.invokeMethodAsync) + .messages(messages) + .requestId(requestId) + .tools(chatToolList) + .toolChoice("auto") + .build(); + ModelApiResponse invokeModelApiResp = client.invokeModelApi(chatCompletionRequest); + logger.info("model output: {}", mapper.writeValueAsString(invokeModelApiResp)); + return invokeModelApiResp.getData().getId(); + } + + + private static void testQueryResult(String taskId) throws JsonProcessingException { + QueryModelResultRequest request = new QueryModelResultRequest(); + request.setTaskId(taskId); + QueryModelResultResponse queryResultResp = client.queryModelResult(request); + logger.info("model output {}", mapper.writeValueAsString(queryResultResp)); + } + + public static Flowable mapStreamToAccumulator(Flowable flowable) { + return flowable.map(chunk -> { + return new ChatMessageAccumulator(chunk.getChoices().get(0).getDelta(), null, chunk.getChoices().get(0), chunk.getUsage(), chunk.getCreated(), chunk.getId()); + }); + } +} diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/zhipu/WebSearchToolsTest.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/zhipu/WebSearchToolsTest.java new file mode 100644 index 00000000..eca586fa --- /dev/null +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/zhipu/WebSearchToolsTest.java @@ -0,0 +1,246 @@ +package org.ruoyi.common.chat.demo.zhipu; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.zhipu.oapi.ClientV4; +import com.zhipu.oapi.Constants; +import com.zhipu.oapi.service.v4.tools.*; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + + + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.PropertyNamingStrategy; +import com.zhipu.oapi.core.response.HttpxBinaryResponseContent; +import com.zhipu.oapi.service.v4.batchs.BatchCreateParams; +import com.zhipu.oapi.service.v4.batchs.BatchResponse; +import com.zhipu.oapi.service.v4.batchs.QueryBatchResponse; +import com.zhipu.oapi.service.v4.embedding.EmbeddingApiResponse; +import com.zhipu.oapi.service.v4.embedding.EmbeddingRequest; +import com.zhipu.oapi.service.v4.file.*; +import com.zhipu.oapi.service.v4.fine_turning.*; +import com.zhipu.oapi.service.v4.image.CreateImageRequest; +import com.zhipu.oapi.service.v4.image.ImageApiResponse; +import com.zhipu.oapi.service.v4.model.*; +import io.reactivex.Flowable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + + +public class WebSearchToolsTest { + + private final static Logger logger = LoggerFactory.getLogger(WebSearchToolsTest.class); + private static final String API_SECRET_KEY = "xx"; + + private static final ClientV4 client = new ClientV4.Builder(API_SECRET_KEY) + .networkConfig(300, 100, 100, 100, TimeUnit.SECONDS) + .connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS)) + .build(); + private static final ObjectMapper mapper = new ObjectMapper(); + // 请自定义自己的业务id + private static final String requestIdTemplate = "mycompany-%d"; + + + @Test + public void test1() throws JsonProcessingException { + +// json 转换 ArrayList + String jsonString = "[\n" + + " {\n" + + " \"content\": \"今天武汉天气怎么样\",\n" + + " \"role\": \"user\"\n" + + " }\n" + + " ]"; + + ArrayList messages = new ObjectMapper().readValue(jsonString, new TypeReference>() { + }); + + + String requestId = String.format(requestIdTemplate, System.currentTimeMillis()); + WebSearchParamsRequest chatCompletionRequest = WebSearchParamsRequest.builder() + .model("web-search-pro") + .stream(Boolean.TRUE) + .messages(messages) + .requestId(requestId) + .build(); + WebSearchApiResponse webSearchApiResponse = client.webSearchProStreamingInvoke(chatCompletionRequest); + if (webSearchApiResponse.isSuccess()) { + AtomicBoolean isFirst = new AtomicBoolean(true); + List choices = new ArrayList<>(); + AtomicReference lastAccumulator = new AtomicReference<>(); + + webSearchApiResponse.getFlowable().map(result -> result) + .doOnNext(accumulator -> { + { + if (isFirst.getAndSet(false)) { + logger.info("Response: "); + } + ChoiceDelta delta = accumulator.getChoices().get(0).getDelta(); + if (delta != null && delta.getToolCalls() != null) { + logger.info("tool_calls: {}", mapper.writeValueAsString(delta.getToolCalls())); + } + choices.add(delta); + lastAccumulator.set(accumulator); + + } + }) + .doOnComplete(() -> System.out.println("Stream completed.")) + .doOnError(throwable -> System.err.println("Error: " + throwable)) // Handle errors + .blockingSubscribe();// Use blockingSubscribe instead of blockingGet() + + WebSearchPro chatMessageAccumulator = lastAccumulator.get(); + + webSearchApiResponse.setFlowable(null);// 打印前置空 + webSearchApiResponse.setData(chatMessageAccumulator); + } + logger.info("model output: {}", mapper.writeValueAsString(webSearchApiResponse)); + client.getConfig().getHttpClient().dispatcher().executorService().shutdown(); + + client.getConfig().getHttpClient().connectionPool().evictAll(); + // List all active threads + for (Thread t : Thread.getAllStackTraces().keySet()) { + logger.info("Thread: " + t.getName() + " State: " + t.getState()); + } + + } + + + @Test + public void test2() throws JsonProcessingException { + +// json 转换 ArrayList + String jsonString = "[\n" + + " {\n" + + " \"content\": \"今天天气怎么样\",\n" + + " \"role\": \"user\"\n" + + " }\n" + + " ]"; + + ArrayList messages = new ObjectMapper().readValue(jsonString, new TypeReference>() { + }); + + + String requestId = String.format(requestIdTemplate, System.currentTimeMillis()); + WebSearchParamsRequest chatCompletionRequest = WebSearchParamsRequest.builder() + .model("web-search-pro") + .stream(Boolean.FALSE) + .messages(messages) + .requestId(requestId) + .build(); + WebSearchApiResponse webSearchApiResponse = client.invokeWebSearchPro(chatCompletionRequest); + + logger.info("model output: {}", mapper.writeValueAsString(webSearchApiResponse)); + + } + + + @Test + public void testFunctionSSE() throws JsonProcessingException { + List messages = new ArrayList<>(); + ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "成都到北京要多久,天气如何"); + messages.add(chatMessage); + String requestId = String.format(requestIdTemplate, System.currentTimeMillis()); + // 函数调用参数构建部分 + List chatToolList = new ArrayList<>(); + ChatTool chatTool = new ChatTool(); + + chatTool.setType(ChatToolType.FUNCTION.value()); + ChatFunctionParameters chatFunctionParameters = new ChatFunctionParameters(); + chatFunctionParameters.setType("object"); + Map properties = new HashMap<>(); + properties.put("location", new HashMap() {{ + put("type", "string"); + put("description", "城市,如:北京"); + }}); + properties.put("unit", new HashMap() {{ + put("type", "string"); + put("enum", new ArrayList() {{ + add("celsius"); + add("fahrenheit"); + }}); + }}); + chatFunctionParameters.setProperties(properties); + ChatFunction chatFunction = ChatFunction.builder() + .name("get_weather") + .description("Get the current weather of a location") + .parameters(chatFunctionParameters) + .build(); + chatTool.setFunction(chatFunction); + chatToolList.add(chatTool); + HashMap extraJson = new HashMap<>(); + extraJson.put("temperature", 0.5); + extraJson.put("max_tokens", 50); + + ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() + .model(Constants.ModelChatGLM4) + .stream(Boolean.TRUE) + .messages(messages) + .requestId(requestId) + .tools(chatToolList) + .toolChoice("auto") + .extraJson(extraJson) + .build(); + ModelApiResponse sseModelApiResp = client.invokeModelApi(chatCompletionRequest); + if (sseModelApiResp.isSuccess()) { + AtomicBoolean isFirst = new AtomicBoolean(true); + List choices = new ArrayList<>(); + ChatMessageAccumulator chatMessageAccumulator = mapStreamToAccumulator(sseModelApiResp.getFlowable()) + .doOnNext(accumulator -> { + { + if (isFirst.getAndSet(false)) { + logger.info("Response: "); + } + if (accumulator.getDelta() != null && accumulator.getDelta().getTool_calls() != null) { + String jsonString = mapper.writeValueAsString(accumulator.getDelta().getTool_calls()); + logger.info("tool_calls: {}", jsonString); + } + if (accumulator.getDelta() != null && accumulator.getDelta().getContent() != null) { + logger.info(accumulator.getDelta().getContent()); + } + choices.add(accumulator.getChoice()); + } + }) + .doOnComplete(System.out::println) + .lastElement() + .blockingGet(); + + + ModelData data = new ModelData(); + data.setChoices(choices); + data.setUsage(chatMessageAccumulator.getUsage()); + data.setId(chatMessageAccumulator.getId()); + data.setCreated(chatMessageAccumulator.getCreated()); + data.setRequestId(chatCompletionRequest.getRequestId()); + sseModelApiResp.setFlowable(null);// 打印前置空 + sseModelApiResp.setData(data); + } + logger.info("model output: {}", mapper.writeValueAsString(sseModelApiResp)); + } + + public static Flowable mapStreamToAccumulator(Flowable flowable) { + return flowable.map(chunk -> { + return new ChatMessageAccumulator(chunk.getChoices().get(0).getDelta(), null, chunk.getChoices().get(0), chunk.getUsage(), chunk.getCreated(), chunk.getId()); + }); + } + +} diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/chat/Message.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/chat/Message.java index 9637569c..86faa475 100644 --- a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/chat/Message.java +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/chat/Message.java @@ -2,12 +2,11 @@ package org.ruoyi.common.chat.entity.chat; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonProperty; -import lombok.AllArgsConstructor; import lombok.Data; -import lombok.Getter; +import org.ruoyi.common.chat.entity.chat.tool.ToolCalls; import java.io.Serializable; +import java.util.List; /** * 描述: @@ -18,21 +17,10 @@ import java.io.Serializable; @Data @JsonInclude(JsonInclude.Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) -public class Message implements Serializable { - - /** - * 目前支持四个中角色参考官网,进行情景输入: - * https://platform.openai.com/docs/guides/chat/introduction - */ - private String role; +public class Message extends BaseMessage implements Serializable { private Object content; - private String name; - - @JsonProperty("function_call") - private FunctionCall functionCall; - public static Builder builder() { return new Builder(); } @@ -41,44 +29,37 @@ public class Message implements Serializable { * 构造函数 * * @param role 角色 - * @param content 描述主题信息 * @param name name + * @param content content * @param functionCall functionCall */ - public Message(String role, String content, String name, FunctionCall functionCall) { - this.role = role; + public Message(String role, String name, String content, List toolCalls, String toolCallId, FunctionCall functionCall) { this.content = content; - this.name = name; - this.functionCall = functionCall; + super.setRole(role); + super.setName(name); + super.setToolCalls(toolCalls); + super.setToolCallId(toolCallId); + super.setFunctionCall(functionCall); } public Message() { } private Message(Builder builder) { - setRole(builder.role); setContent(builder.content); - setName(builder.name); - setFunctionCall(builder.functionCall); - } - - - @Getter - @AllArgsConstructor - public enum Role { - - SYSTEM("system"), - USER("user"), - ASSISTANT("assistant"), - FUNCTION("function"), - ; - private String name; + super.setRole(builder.role); + super.setName(builder.name); + super.setFunctionCall(builder.functionCall); + super.setToolCalls(builder.toolCalls); + super.setToolCallId(builder.toolCallId); } public static final class Builder { private String role; private String content; private String name; + private String toolCallId; + private List toolCalls; private FunctionCall functionCall; public Builder() { @@ -109,6 +90,16 @@ public class Message implements Serializable { return this; } + public Builder toolCalls(List toolCalls) { + this.toolCalls = toolCalls; + return this; + } + + public Builder toolCallId(String toolCallId) { + this.toolCallId = toolCallId; + return this; + } + public Message build() { return new Message(this); } diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/chat/ResponseFormat.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/chat/ResponseFormat.java index 5c154844..d1c691b6 100644 --- a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/chat/ResponseFormat.java +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/chat/ResponseFormat.java @@ -24,5 +24,6 @@ public class ResponseFormat { TEXT("text"), ; private final String name; + } } diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/listener/WebSocketEventListener.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/listener/WebSocketEventListener.java index 2cc3529b..08088c47 100644 --- a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/listener/WebSocketEventListener.java +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/listener/WebSocketEventListener.java @@ -75,6 +75,8 @@ public class WebSocketEventListener extends EventSourceListener { return; } ResponseBody body = response.body(); + + if (Objects.nonNull(body)) { // 返回非流式回复内容 if(response.code() == OpenAIConst.SUCCEED_CODE){ diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/openai/OpenAiClient.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/openai/OpenAiClient.java index 0ee97662..ccc52147 100644 --- a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/openai/OpenAiClient.java +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/openai/OpenAiClient.java @@ -2,6 +2,7 @@ package org.ruoyi.common.chat.openai; import cn.hutool.core.collection.CollectionUtil; import cn.hutool.core.util.StrUtil; +import cn.hutool.json.JSONUtil; import io.reactivex.Single; import lombok.Getter; import lombok.extern.slf4j.Slf4j; @@ -12,9 +13,7 @@ import okhttp3.RequestBody; import org.ruoyi.common.chat.constant.OpenAIConst; import org.ruoyi.common.chat.entity.billing.BillingUsage; import org.ruoyi.common.chat.entity.billing.Subscription; -import org.ruoyi.common.chat.entity.chat.ChatCompletion; -import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse; -import org.ruoyi.common.chat.entity.chat.Message; +import org.ruoyi.common.chat.entity.chat.*; import org.ruoyi.common.chat.entity.common.DeleteResponse; import org.ruoyi.common.chat.entity.common.OpenAiResponse; import org.ruoyi.common.chat.entity.completions.Completion; @@ -43,6 +42,8 @@ import org.ruoyi.common.chat.openai.function.KeyStrategyFunction; import org.ruoyi.common.chat.openai.interceptor.DefaultOpenAiAuthInterceptor; import org.ruoyi.common.chat.openai.interceptor.DynamicKeyOpenAiAuthInterceptor; import org.ruoyi.common.chat.openai.interceptor.OpenAiAuthInterceptor; +import org.ruoyi.common.chat.openai.plugin.PluginAbstract; +import org.ruoyi.common.chat.openai.plugin.PluginParam; import org.ruoyi.common.core.exception.base.BaseException; import org.jetbrains.annotations.NotNull; import retrofit2.Retrofit; @@ -696,6 +697,90 @@ public class OpenAiClient { return whisperResponse.blockingGet(); } + /** + * 插件问答简易版 + * 默认取messages最后一个元素构建插件对话 + * 默认模型:ChatCompletion.Model.GPT_3_5_TURBO_16K_0613 + * + * @param chatCompletion 参数 + * @param plugin 插件 + * @param 插件自定义函数的请求值 + * @param 插件自定义函数的返回值 + * @return ChatCompletionResponse + */ + public ChatCompletionResponse chatCompletionWithPlugin(ChatCompletion chatCompletion, PluginAbstract plugin) { + if (Objects.isNull(plugin)) { + return this.chatCompletion(chatCompletion); + } + if (CollectionUtil.isEmpty(chatCompletion.getMessages())) { + throw new BaseException(CommonError.MESSAGE_NOT_NUL.msg()); + } + List messages = chatCompletion.getMessages(); + Functions functions = Functions.builder() + .name(plugin.getFunction()) + .description(plugin.getDescription()) + .parameters(plugin.getParameters()) + .build(); + //没有值,设置默认值 + if (Objects.isNull(chatCompletion.getFunctionCall())) { + chatCompletion.setFunctionCall("auto"); + } + //tip: 覆盖自己设置的functions参数,使用plugin构造的functions + chatCompletion.setFunctions(Collections.singletonList(functions)); + //调用OpenAi + ChatCompletionResponse functionCallChatCompletionResponse = this.chatCompletion(chatCompletion); + ChatChoice chatChoice = functionCallChatCompletionResponse.getChoices().get(0); + log.debug("构造的方法值:{}", chatChoice.getMessage().getFunctionCall()); + + R realFunctionParam = (R) JSONUtil.toBean(chatChoice.getMessage().getFunctionCall().getArguments(), plugin.getR()); + T tq = plugin.func(realFunctionParam); + + FunctionCall functionCall = FunctionCall.builder() + .arguments(chatChoice.getMessage().getFunctionCall().getArguments()) + .name(plugin.getFunction()) + .build(); + messages.add(Message.builder().role(Message.Role.ASSISTANT).content("function_call").functionCall(functionCall).build()); + messages.add(Message.builder().role(Message.Role.FUNCTION).name(plugin.getFunction()).content(plugin.content(tq)).build()); + //设置第二次,请求的参数 + chatCompletion.setFunctionCall(null); + chatCompletion.setFunctions(null); + + ChatCompletionResponse chatCompletionResponse = this.chatCompletion(chatCompletion); + log.debug("自定义的方法返回值:{}", chatCompletionResponse.getChoices()); + return chatCompletionResponse; + } + + /** + * 插件问答简易版 + * 默认取messages最后一个元素构建插件对话 + * 默认模型:ChatCompletion.Model.GPT_3_5_TURBO_16K_0613 + * + * @param messages 问答参数 + * @param plugin 插件 + * @param 插件自定义函数的请求值 + * @param 插件自定义函数的返回值 + * @return ChatCompletionResponse + */ + public ChatCompletionResponse chatCompletionWithPlugin(List messages, PluginAbstract plugin) { + return chatCompletionWithPlugin(messages, ChatCompletion.Model.GPT_3_5_TURBO_16K_0613.getName(), plugin); + } + + /** + * 插件问答简易版 + * 默认取messages最后一个元素构建插件对话 + * + * @param messages 问答参数 + * @param model 模型 + * @param plugin 插件 + * @param 插件自定义函数的请求值 + * @param 插件自定义函数的返回值 + * @return ChatCompletionResponse + */ + public ChatCompletionResponse chatCompletionWithPlugin(List messages, String model, PluginAbstract plugin) { + ChatCompletion chatCompletion = ChatCompletion.builder().messages(messages).model(model).build(); + return this.chatCompletionWithPlugin(chatCompletion, plugin); + } + /** * 简易版 语音翻译:目前仅支持翻译为英文 * diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/openai/OpenAiStreamClient.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/openai/OpenAiStreamClient.java index cad1bbbf..e7992022 100644 --- a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/openai/OpenAiStreamClient.java +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/openai/OpenAiStreamClient.java @@ -3,6 +3,7 @@ package org.ruoyi.common.chat.openai; import cn.hutool.core.collection.CollectionUtil; import cn.hutool.core.util.StrUtil; import cn.hutool.http.ContentType; +import cn.hutool.json.JSONUtil; import com.fasterxml.jackson.databind.ObjectMapper; import io.reactivex.Single; import lombok.Getter; @@ -17,10 +18,7 @@ import org.ruoyi.common.chat.entity.Tts.TextToSpeech; import org.ruoyi.common.chat.entity.billing.BillingUsage; import org.ruoyi.common.chat.entity.billing.KeyInfo; import org.ruoyi.common.chat.entity.billing.Subscription; -import org.ruoyi.common.chat.entity.chat.BaseChatCompletion; -import org.ruoyi.common.chat.entity.chat.ChatCompletion; -import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse; -import org.ruoyi.common.chat.entity.chat.ChatCompletionWithPicture; +import org.ruoyi.common.chat.entity.chat.*; import org.ruoyi.common.chat.entity.embeddings.Embedding; import org.ruoyi.common.chat.entity.embeddings.EmbeddingResponse; import org.ruoyi.common.chat.entity.files.UploadFileResponse; @@ -29,6 +27,7 @@ import org.ruoyi.common.chat.entity.images.ImageResponse; import org.ruoyi.common.chat.entity.models.Model; import org.ruoyi.common.chat.entity.models.ModelResponse; import org.ruoyi.common.chat.entity.whisper.Transcriptions; +import org.ruoyi.common.chat.entity.whisper.Translations; import org.ruoyi.common.chat.entity.whisper.WhisperResponse; import org.ruoyi.common.chat.openai.exception.CommonError; import org.ruoyi.common.chat.openai.function.KeyRandomStrategy; @@ -36,6 +35,10 @@ import org.ruoyi.common.chat.openai.function.KeyStrategyFunction; import org.ruoyi.common.chat.openai.interceptor.DefaultOpenAiAuthInterceptor; import org.ruoyi.common.chat.openai.interceptor.DynamicKeyOpenAiAuthInterceptor; import org.ruoyi.common.chat.openai.interceptor.OpenAiAuthInterceptor; +import org.ruoyi.common.chat.openai.plugin.PluginAbstract; +import org.ruoyi.common.chat.openai.plugin.PluginParam; +import org.ruoyi.common.chat.sse.DefaultPluginListener; +import org.ruoyi.common.chat.sse.PluginListener; import org.ruoyi.common.core.exception.base.BaseException; import org.jetbrains.annotations.NotNull; import retrofit2.Call; @@ -186,6 +189,93 @@ public class OpenAiStreamClient { } } + /** + * 插件问答简易版 + * 默认取messages最后一个元素构建插件对话 + * 默认模型:ChatCompletion.Model.GPT_3_5_TURBO_16K_0613 + * + * @param chatCompletion 参数 + * @param eventSourceListener sse监听器 + * @param pluginEventSourceListener 插件sse监听器,收集function call返回信息 + * @param plugin 插件 + * @param 插件自定义函数的请求值 + * @param 插件自定义函数的返回值 + */ + public void streamChatCompletionWithPlugin(ChatCompletion chatCompletion, EventSourceListener eventSourceListener, PluginListener pluginEventSourceListener, PluginAbstract plugin) { + if (Objects.isNull(plugin)) { + this.streamChatCompletion(chatCompletion, eventSourceListener); + return; + } + if (CollectionUtil.isEmpty(chatCompletion.getMessages())) { + throw new BaseException(CommonError.MESSAGE_NOT_NUL.msg()); + } + Functions functions = Functions.builder() + .name(plugin.getFunction()) + .description(plugin.getDescription()) + .parameters(plugin.getParameters()) + .build(); + //没有值,设置默认值 + if (Objects.isNull(chatCompletion.getFunctionCall())) { + chatCompletion.setFunctionCall("auto"); + } + //tip: 覆盖自己设置的functions参数,使用plugin构造的functions + chatCompletion.setFunctions(Collections.singletonList(functions)); + //调用OpenAi + if (Objects.isNull(pluginEventSourceListener)) { + pluginEventSourceListener = new DefaultPluginListener(this, eventSourceListener, plugin, chatCompletion); + } + this.streamChatCompletion(chatCompletion, pluginEventSourceListener); + } + + + /** + * 插件问答简易版 + * 默认取messages最后一个元素构建插件对话 + * 默认模型:ChatCompletion.Model.GPT_3_5_TURBO_16K_0613 + * + * @param chatCompletion 参数 + * @param eventSourceListener sse监听器 + * @param plugin 插件 + * @param 插件自定义函数的请求值 + * @param 插件自定义函数的返回值 + */ + public void streamChatCompletionWithPlugin(ChatCompletion chatCompletion, EventSourceListener eventSourceListener, PluginAbstract plugin) { + PluginListener pluginEventSourceListener = new DefaultPluginListener(this, eventSourceListener, plugin, chatCompletion); + this.streamChatCompletionWithPlugin(chatCompletion, eventSourceListener, pluginEventSourceListener, plugin); + } + + + /** + * 插件问答简易版 + * 默认取messages最后一个元素构建插件对话 + * 默认模型:ChatCompletion.Model.GPT_3_5_TURBO_16K_0613 + * + * @param messages 问答参数 + * @param eventSourceListener sse监听器 + * @param plugin 插件 + * @param 插件自定义函数的请求值 + * @param 插件自定义函数的返回值 + */ + public void streamChatCompletionWithPlugin(List messages, EventSourceListener eventSourceListener, PluginAbstract plugin) { + this.streamChatCompletionWithPlugin(messages, ChatCompletion.Model.GPT_3_5_TURBO_16K_0613.getName(), eventSourceListener, plugin); + } + + /** + * 插件问答简易版 + * 默认取messages最后一个元素构建插件对话 + * + * @param messages 问答参数 + * @param model 模型 + * @param eventSourceListener eventSourceListener + * @param plugin 插件 + * @param 插件自定义函数的请求值 + * @param 插件自定义函数的返回值 + */ + public void streamChatCompletionWithPlugin(List messages, String model, EventSourceListener eventSourceListener, PluginAbstract plugin) { + ChatCompletion chatCompletion = ChatCompletion.builder().messages(messages).model(model).build(); + this.streamChatCompletionWithPlugin(chatCompletion, eventSourceListener, plugin); + } + /** * 根据描述生成图片 @@ -418,6 +508,95 @@ public class OpenAiStreamClient { } } + /** + * 插件问答简易版 + * 默认取messages最后一个元素构建插件对话 + * 默认模型:ChatCompletion.Model.GPT_3_5_TURBO_16K_0613 + * + * @param chatCompletion 参数 + * @param plugin 插件 + * @param 插件自定义函数的请求值 + * @param 插件自定义函数的返回值 + * @return ChatCompletionResponse + */ + public ChatCompletionResponse chatCompletionWithPlugin(ChatCompletion chatCompletion, PluginAbstract plugin) { + if (Objects.isNull(plugin)) { + return this.chatCompletion(chatCompletion); + } + if (CollectionUtil.isEmpty(chatCompletion.getMessages())) { + throw new BaseException(CommonError.MESSAGE_NOT_NUL.msg()); + } + List messages = chatCompletion.getMessages(); + Functions functions = Functions.builder() + .name(plugin.getFunction()) + .description(plugin.getDescription()) + .parameters(plugin.getParameters()) + .build(); + //没有值,设置默认值 + if (Objects.isNull(chatCompletion.getFunctionCall())) { + chatCompletion.setFunctionCall("auto"); + } + //tip: 覆盖自己设置的functions参数,使用plugin构造的functions + chatCompletion.setFunctions(Collections.singletonList(functions)); + //调用OpenAi + ChatCompletionResponse functionCallChatCompletionResponse = this.chatCompletion(chatCompletion); + ChatChoice chatChoice = functionCallChatCompletionResponse.getChoices().get(0); + log.debug("构造的方法值:{}", chatChoice.getMessage().getFunctionCall()); + + R realFunctionParam = (R) JSONUtil.toBean(chatChoice.getMessage().getFunctionCall().getArguments(), plugin.getR()); + T tq = plugin.func(realFunctionParam); + + FunctionCall functionCall = FunctionCall.builder() + .arguments(chatChoice.getMessage().getFunctionCall().getArguments()) + .name(plugin.getFunction()) + .build(); + messages.add(Message.builder().role(Message.Role.ASSISTANT).content("function_call").functionCall(functionCall).build()); + messages.add(Message.builder().role(Message.Role.FUNCTION).name(plugin.getFunction()).content(plugin.content(tq)).build()); + //设置第二次,请求的参数 + chatCompletion.setFunctionCall(null); + chatCompletion.setFunctions(null); + + ChatCompletionResponse chatCompletionResponse = this.chatCompletion(chatCompletion); + log.debug("自定义的方法返回值:{}", chatCompletionResponse.getChoices()); + return chatCompletionResponse; + } + + /** + * 插件问答简易版 + * 默认取messages最后一个元素构建插件对话 + * 默认模型:ChatCompletion.Model.GPT_3_5_TURBO_16K_0613 + * + * @param messages 问答参数 + * @param plugin 插件 + * @param 插件自定义函数的请求值 + * @param 插件自定义函数的返回值 + * @return ChatCompletionResponse + */ + public ChatCompletionResponse chatCompletionWithPlugin(List messages, PluginAbstract plugin) { + return chatCompletionWithPlugin(messages, ChatCompletion.Model.GPT_3_5_TURBO_16K_0613.getName(), plugin); + } + + /** + * 插件问答简易版 + * 默认取messages最后一个元素构建插件对话 + * + * @param messages 问答参数 + * @param model 模型 + * @param plugin 插件 + * @param 插件自定义函数的请求值 + * @param 插件自定义函数的返回值 + * @return ChatCompletionResponse + */ + public ChatCompletionResponse chatCompletionWithPlugin(List messages, String model, PluginAbstract plugin) { + ChatCompletion chatCompletion = ChatCompletion.builder().messages(messages).model(model).build(); + return this.chatCompletionWithPlugin(chatCompletion, plugin); + } + + + + + + /** * 构造 diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/openai/exception/CommonError.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/openai/exception/CommonError.java index 9bd65b82..e59e9cdc 100644 --- a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/openai/exception/CommonError.java +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/openai/exception/CommonError.java @@ -7,6 +7,7 @@ package org.ruoyi.common.chat.openai.exception; * 2023-02-11 */ public enum CommonError implements IError { + MESSAGE_NOT_NUL(500, "Message 不能为空"), API_KEYS_NOT_NUL(500, "API KEYS 不能为空"), NO_ACTIVE_API_KEYS(500, "没有可用的API KEYS"), SYS_ERROR(500, "系统繁忙"), @@ -19,8 +20,8 @@ public enum CommonError implements IError { ; - private int code; - private String msg; + private final int code; + private final String msg; CommonError(int code, String msg) { this.code = code; diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/openai/plugin/PluginAbstract.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/openai/plugin/PluginAbstract.java new file mode 100644 index 00000000..ddb2045e --- /dev/null +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/openai/plugin/PluginAbstract.java @@ -0,0 +1,88 @@ +package org.ruoyi.common.chat.openai.plugin; + +import cn.hutool.core.collection.CollectionUtil; +import cn.hutool.json.JSONObject; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.AllArgsConstructor; +import lombok.Data; +import org.ruoyi.common.chat.entity.chat.Parameters; + +import java.util.List; +import java.util.stream.Collectors; + +@Data +@AllArgsConstructor +public abstract class PluginAbstract { + + private Class R; + + private String name; + + private String function; + + private String description; + + private List args; + + private List required; + + private Parameters parameters; + + public PluginAbstract(Class r) { + R = r; + } + + public void setRequired(List required) { + if (CollectionUtil.isEmpty(required)) { + this.required = this.getArgs().stream().filter(e -> e.isRequired()).map(Arg::getName).collect(Collectors.toList()); + return; + } + this.required = required; + } + + private void setRequired() { + if (CollectionUtil.isEmpty(required)) { + this.required = this.getArgs().stream().filter(e -> e.isRequired()).map(Arg::getName).collect(Collectors.toList()); + } + } + + private void setParameters() { + JSONObject properties = new JSONObject(); + args.forEach(e -> { + JSONObject param = new JSONObject(); + param.putOpt("type", e.getType()); + param.putOpt("enum", e.getEnumDictValue()); + param.putOpt("description", e.getDescription()); + properties.putOpt(e.getName(), param); + }); + this.parameters = Parameters.builder() + .type("object") + .properties(properties) + .required(this.getRequired()) + .build(); + } + + public void setArgs(List args) { + this.args = args; + setRequired(); + setParameters(); + } + + @Data + public static class Arg { + private String name; + private String type; + private String description; + @JsonIgnore + private boolean enumDict; + @JsonProperty("enum") + private List enumDictValue; + @JsonIgnore + private boolean required; + } + + public abstract T func(R args); + + public abstract String content(T t); +} diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/openai/plugin/PluginParam.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/openai/plugin/PluginParam.java new file mode 100644 index 00000000..c5a5909d --- /dev/null +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/openai/plugin/PluginParam.java @@ -0,0 +1,7 @@ +package org.ruoyi.common.chat.openai.plugin; + +import lombok.Data; + +@Data +public class PluginParam { +} diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/plugin/CmdPlugin.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/plugin/CmdPlugin.java new file mode 100644 index 00000000..428d6e0e --- /dev/null +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/plugin/CmdPlugin.java @@ -0,0 +1,36 @@ +package org.ruoyi.common.chat.plugin; + +import org.ruoyi.common.chat.openai.plugin.PluginAbstract; + +import java.io.IOException; + +public class CmdPlugin extends PluginAbstract { + + public CmdPlugin(Class r) { + super(r); + } + + @Override + public CmdResp func(CmdReq args) { + try { + if("计算器".equals(args.getCmd())){ + Runtime.getRuntime().exec("calc"); + }else if("记事本".equals(args.getCmd())){ + Runtime.getRuntime().exec("notepad"); + }else if("命令行".equals(args.getCmd())){ + String [] cmd={"cmd","/C","start copy exel exe2"}; + Runtime.getRuntime().exec(cmd); + } + } catch (IOException e) { + throw new RuntimeException("指令执行失败"); + } + CmdResp resp = new CmdResp(); + resp.setResult(args.getCmd()+"指令执行成功!"); + return resp; + } + + @Override + public String content(CmdResp resp) { + return resp.getResult(); + } +} diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/plugin/CmdReq.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/plugin/CmdReq.java new file mode 100644 index 00000000..a275150b --- /dev/null +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/plugin/CmdReq.java @@ -0,0 +1,13 @@ +package org.ruoyi.common.chat.plugin; + + +import lombok.Data; +import org.ruoyi.common.chat.openai.plugin.PluginParam; + +@Data +public class CmdReq extends PluginParam { + /** + * 指令 + */ + private String cmd; +} diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/plugin/CmdResp.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/plugin/CmdResp.java new file mode 100644 index 00000000..4e101393 --- /dev/null +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/plugin/CmdResp.java @@ -0,0 +1,12 @@ +package org.ruoyi.common.chat.plugin; + +import lombok.Data; + +@Data +public class CmdResp { + + /** + * 返回结果 + */ + private String result; +} diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/plugin/SqlPlugin.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/plugin/SqlPlugin.java new file mode 100644 index 00000000..a40734d7 --- /dev/null +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/plugin/SqlPlugin.java @@ -0,0 +1,88 @@ +package org.ruoyi.common.chat.plugin; + +import org.ruoyi.common.chat.openai.plugin.PluginAbstract; + +import java.sql.*; + +/** + * @author ageer + */ +public class SqlPlugin extends PluginAbstract { + + public SqlPlugin(Class r) { + super(r); + } + + + + @Override + public SqlResp func(SqlReq args) { + SqlResp resp = new SqlResp(); + resp.setUserBalance(getBalance(args.getUsername())); + return resp; + } + + @Override + public String content(SqlResp resp) { + return "用户余额:"+resp.getUserBalance(); + } + + + public String getBalance(String userName) { + // MySQL 8.0 以下版本 - JDBC 驱动名及数据库 URL + String JDBC_DRIVER = "com.mysql.cj.jdbc.Driver"; + String DB_URL = "jdbc:mysql://43.139.70.230:3306/ry-vue"; + // 数据库的用户名与密码,需要根据自己的设置 + String USER = "ry-vue"; + String PASS = "BXZiGsY35K523Xfx"; + Connection conn = null; + Statement stmt = null; + String balance = "0.1"; + + try{ + // 注册 JDBC 驱动 + Class.forName(JDBC_DRIVER); + + // 打开链接 + System.out.println("连接数据库..."); + conn = DriverManager.getConnection(DB_URL,USER,PASS); + + // 执行查询 + System.out.println(" 实例化Statement对象..."); + stmt = conn.createStatement(); + String sql; + sql = "SELECT user_balance FROM sys_user where user_name ='" + userName + "'"; + ResultSet rs = stmt.executeQuery(sql); + // 展开结果集数据库 + while(rs.next()){ + // 通过字段检索 + balance = rs.getString("user_balance"); + // 输出数据 + System.out.print("余额: " + balance); + System.out.print("\n"); + } + // 完成后关闭 + rs.close(); + stmt.close(); + conn.close(); + }catch(SQLException se){ + // 处理 JDBC 错误 + se.printStackTrace(); + }catch(Exception e){ + // 处理 Class.forName 错误 + e.printStackTrace(); + }finally{ + // 关闭资源 + try{ + if(stmt!=null) stmt.close(); + }catch(SQLException se2){ + }// 什么都不做 + try{ + if(conn!=null) conn.close(); + }catch(SQLException se){ + se.printStackTrace(); + } + } + return balance; + } +} diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/plugin/SqlReq.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/plugin/SqlReq.java new file mode 100644 index 00000000..481ba72c --- /dev/null +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/plugin/SqlReq.java @@ -0,0 +1,13 @@ +package org.ruoyi.common.chat.plugin; + + +import lombok.Data; +import org.ruoyi.common.chat.openai.plugin.PluginParam; + +@Data +public class SqlReq extends PluginParam { + /** + * 用户名称 + */ + private String username; +} diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/plugin/SqlResp.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/plugin/SqlResp.java new file mode 100644 index 00000000..b84b555b --- /dev/null +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/plugin/SqlResp.java @@ -0,0 +1,12 @@ +package org.ruoyi.common.chat.plugin; + +import lombok.Data; + +@Data +public class SqlResp { + + /** + * 用户余额 + */ + private String userBalance; +} diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/sse/ConsoleEventSourceListener.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/sse/ConsoleEventSourceListener.java new file mode 100644 index 00000000..c05fa19f --- /dev/null +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/sse/ConsoleEventSourceListener.java @@ -0,0 +1,56 @@ +package org.ruoyi.common.chat.sse; + +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import okhttp3.Response; +import okhttp3.ResponseBody; +import okhttp3.sse.EventSource; +import okhttp3.sse.EventSourceListener; + +import java.util.Objects; + +/** + * 描述: sse + * + * @author https:www.unfbx.com + * 2023-02-28 + */ +@Slf4j +public class ConsoleEventSourceListener extends EventSourceListener { + + @Override + public void onOpen(EventSource eventSource, Response response) { + log.info("OpenAI建立sse连接..."); + } + + @Override + public void onEvent(EventSource eventSource, String id, String type, String data) { + log.info("OpenAI返回数据:{}", data); + if ("[DONE]".equals(data)) { + log.info("OpenAI返回数据结束了"); + return; + } + } + + @Override + public void onClosed(EventSource eventSource) { + log.info("OpenAI关闭sse连接..."); + } + + @SneakyThrows + @Override + public void onFailure(EventSource eventSource, Throwable t, Response response) { + if(Objects.isNull(response)){ + log.error("OpenAI sse连接异常:{}", t); + eventSource.cancel(); + return; + } + ResponseBody body = response.body(); + if (Objects.nonNull(body)) { + log.error("OpenAI sse连接异常data:{},异常:{}", body.string(), t); + } else { + log.error("OpenAI sse连接异常data:{},异常:{}", response, t); + } + eventSource.cancel(); + } +} diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/sse/DefaultPluginListener.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/sse/DefaultPluginListener.java new file mode 100644 index 00000000..ab6fcf1e --- /dev/null +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/sse/DefaultPluginListener.java @@ -0,0 +1,22 @@ +package org.ruoyi.common.chat.sse; + +import lombok.extern.slf4j.Slf4j; + +import okhttp3.sse.EventSourceListener; +import org.ruoyi.common.chat.entity.chat.ChatCompletion; +import org.ruoyi.common.chat.openai.OpenAiStreamClient; +import org.ruoyi.common.chat.openai.plugin.PluginAbstract; + +/** + * 描述: 插件开发返回信息收集sse监听器 + * + * @author https:www.unfbx.com + * 2023-08-18 + */ +@Slf4j +public class DefaultPluginListener extends PluginListener { + + public DefaultPluginListener(OpenAiStreamClient client, EventSourceListener eventSourceListener, PluginAbstract plugin, ChatCompletion chatCompletion) { + super(client, eventSourceListener, plugin, chatCompletion); + } +} diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/sse/PluginListener.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/sse/PluginListener.java new file mode 100644 index 00000000..6701a251 --- /dev/null +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/sse/PluginListener.java @@ -0,0 +1,126 @@ +package org.ruoyi.common.chat.sse; + +import cn.hutool.json.JSONUtil; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import okhttp3.Response; +import okhttp3.ResponseBody; +import okhttp3.sse.EventSource; +import okhttp3.sse.EventSourceListener; +import org.jetbrains.annotations.NotNull; +import org.ruoyi.common.chat.constant.OpenAIConst; +import org.ruoyi.common.chat.entity.chat.ChatCompletion; +import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse; +import org.ruoyi.common.chat.entity.chat.FunctionCall; +import org.ruoyi.common.chat.entity.chat.Message; +import org.ruoyi.common.chat.openai.OpenAiStreamClient; +import org.ruoyi.common.chat.openai.plugin.PluginAbstract; +import org.ruoyi.common.chat.openai.plugin.PluginParam; + +import java.util.Objects; + +/** + * 描述: 插件开发返回信息收集sse监听器 + * + * @author https:www.unfbx.com + * 2023-08-18 + */ +@Slf4j +public abstract class PluginListener extends EventSourceListener { + /** + * openAi插件构建的参数 + */ + private String arguments = ""; + + /** + * 获取openAi插件构建的参数 + * + * @return arguments + */ + private String getArguments() { + return this.arguments; + } + + private OpenAiStreamClient client; + private EventSourceListener eventSourceListener; + private PluginAbstract plugin; + private ChatCompletion chatCompletion; + + /** + * 构造方法必备四个元素 + * + * @param client OpenAiStreamClient + * @param eventSourceListener 处理真实第二次sse请求的自定义监听 + * @param plugin 插件信息 + * @param chatCompletion 请求参数 + */ + public PluginListener(OpenAiStreamClient client, EventSourceListener eventSourceListener, PluginAbstract plugin, ChatCompletion chatCompletion) { + this.client = client; + this.eventSourceListener = eventSourceListener; + this.plugin = plugin; + this.chatCompletion = chatCompletion; + } + + /** + * sse关闭后处理,第二次请求方法 + */ + public void onClosedAfter() { + log.debug("构造的方法值:{}", getArguments()); + + R realFunctionParam = (R) JSONUtil.toBean(getArguments(), plugin.getR()); + T tq = plugin.func(realFunctionParam); + + FunctionCall functionCall = FunctionCall.builder() + .arguments(getArguments()) + .name(plugin.getFunction()) + .build(); + chatCompletion.getMessages().add(Message.builder().role(Message.Role.ASSISTANT).content("function_call").functionCall(functionCall).build()); + chatCompletion.getMessages().add(Message.builder().role(Message.Role.FUNCTION).name(plugin.getFunction()).content(plugin.content(tq)).build()); + //设置第二次,请求的参数 + chatCompletion.setFunctionCall(null); + chatCompletion.setFunctions(null); + client.streamChatCompletion(chatCompletion, eventSourceListener); + } + + @SneakyThrows + @Override + public final void onEvent(@NotNull EventSource eventSource, String id, String type, String data) { + log.debug("插件开发返回信息收集sse监听器返回数据:{}", data); + if ("[DONE]".equals(data)) { + log.debug("插件开发返回信息收集sse监听器返回数据结束了"); + return; + } + ChatCompletionResponse chatCompletionResponse = JSONUtil.toBean(data, ChatCompletionResponse.class); + if (Objects.nonNull(chatCompletionResponse.getChoices().get(0).getDelta().getFunctionCall())) { + this.arguments += chatCompletionResponse.getChoices().get(0).getDelta().getFunctionCall().getArguments(); + } + } + + @Override + public final void onClosed(EventSource eventSource) { + log.debug("插件开发返回信息收集sse监听器关闭连接..."); + this.onClosedAfter(); + } + + @Override + public void onOpen(EventSource eventSource, Response response) { + log.debug("插件开发返回信息收集sse监听器建立连接..."); + } + + @SneakyThrows + @Override + public void onFailure(EventSource eventSource, Throwable t, Response response) { + if (Objects.isNull(response)) { + log.error("插件开发返回信息收集sse监听器,连接异常:{}", t); + eventSource.cancel(); + return; + } + ResponseBody body = response.body(); + if (Objects.nonNull(body)) { + log.error("插件开发返回信息收集sse监听器,连接异常data:{},异常:{}", body.string(), t); + } else { + log.error("插件开发返回信息收集sse监听器,连接异常data:{},异常:{}", response, t); + } + eventSource.cancel(); + } +} diff --git a/ruoyi-common/ruoyi-common-core/src/main/java/org/ruoyi/common/core/service/UserService.java b/ruoyi-common/ruoyi-common-core/src/main/java/org/ruoyi/common/core/service/UserService.java index 5e354b7f..152a418d 100644 --- a/ruoyi-common/ruoyi-common-core/src/main/java/org/ruoyi/common/core/service/UserService.java +++ b/ruoyi-common/ruoyi-common-core/src/main/java/org/ruoyi/common/core/service/UserService.java @@ -15,4 +15,11 @@ public interface UserService { */ String selectUserNameById(Long userId); + /** + * 通过用户名称查询余额 + * + * @param userName + * @return + */ + String selectUserByName(String userName); } diff --git a/ruoyi-modules/ruoyi-knowledge/pom.xml b/ruoyi-modules/ruoyi-knowledge/pom.xml index dcba1990..46576f82 100644 --- a/ruoyi-modules/ruoyi-knowledge/pom.xml +++ b/ruoyi-modules/ruoyi-knowledge/pom.xml @@ -83,6 +83,9 @@ runtime true + + + com.mysql mysql-connector-j diff --git a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorstore/MilvusVectorStore.java b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorstore/MilvusVectorStore.java index 82cd8960..c083ef47 100644 --- a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorstore/MilvusVectorStore.java +++ b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorstore/MilvusVectorStore.java @@ -41,13 +41,13 @@ public class MilvusVectorStore implements VectorStore{ @Resource private ConfigService configService; - // @PostConstruct + @PostConstruct public void loadConfig() { this.dimension = Integer.parseInt(configService.getConfigValue("milvus", "dimension")); this.collectionName = configService.getConfigValue("milvus", "collection"); } - //@PostConstruct + @PostConstruct public void init(){ String milvusHost = configService.getConfigValue("milvus", "host"); String milvausPort = configService.getConfigValue("milvus", "port"); diff --git a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorstore/VectorStore.java b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorstore/VectorStore.java index 6852cfd7..6be022d4 100644 --- a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorstore/VectorStore.java +++ b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorstore/VectorStore.java @@ -6,11 +6,16 @@ import java.util.List; * 向量存储 */ public interface VectorStore { - void storeEmbeddings(List chunkList,List> vectorList, String kid, String docId,List fidList); - void removeByDocId(String kid,String docId); + + void storeEmbeddings(List chunkList, List> vectorList, String kid, String docId, List fidList); + + void removeByDocId(String kid, String docId); + void removeByKid(String kid); - List nearest(List queryVector,String kid); - List nearest(String query,String kid); + + List nearest(List queryVector, String kid); + + List nearest(String query, String kid); void newSchema(String kid); diff --git a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorstore/VectorStoreFactory.java b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorstore/VectorStoreFactory.java index 1cfb6b3a..5acf28b6 100644 --- a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorstore/VectorStoreFactory.java +++ b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorstore/VectorStoreFactory.java @@ -1,32 +1,35 @@ package org.ruoyi.knowledge.chain.vectorstore; +import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; -import org.springframework.beans.factory.annotation.Value; +import org.ruoyi.knowledge.domain.vo.KnowledgeInfoVo; +import org.ruoyi.knowledge.mapper.KnowledgeInfoMapper; import org.springframework.stereotype.Component; @Component @Slf4j public class VectorStoreFactory { - private final String type = "weaviate"; - private final WeaviateVectorStore weaviateVectorStore; private final MilvusVectorStore milvusVectorStore; + @Resource + private KnowledgeInfoMapper knowledgeInfoMapper; + public VectorStoreFactory(WeaviateVectorStore weaviateVectorStore, MilvusVectorStore milvusVectorStore) { this.weaviateVectorStore = weaviateVectorStore; this.milvusVectorStore = milvusVectorStore; } public VectorStore getVectorStore(String kid){ -// if ("weaviate".equals(type)){ -// return weaviateVectorStore; -// }else if ("milvus".equals(type)){ -// return milvusVectorStore; -// } -// -// return null; - return weaviateVectorStore; + KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoMapper.selectVoById(Long.valueOf(kid)); + String vectorModel = knowledgeInfoVo.getVector(); + if ("weaviate".equals(vectorModel)){ + return weaviateVectorStore; + }else if ("milvus".equals(vectorModel)){ + return milvusVectorStore; + } + return null; } } diff --git a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorstore/VectorStoreWrapper.java b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorstore/VectorStoreWrapper.java index eb6b1f4d..e1daa619 100644 --- a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorstore/VectorStoreWrapper.java +++ b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorstore/VectorStoreWrapper.java @@ -11,19 +11,20 @@ import java.util.List; @Slf4j @Primary @AllArgsConstructor -public class VectorStoreWrapper implements VectorStore{ +public class VectorStoreWrapper implements VectorStore { private final VectorStoreFactory vectorStoreFactory; + @Override public void storeEmbeddings(List chunkList, List> vectorList, String kid, String docId, List fidList) { VectorStore vectorStore = vectorStoreFactory.getVectorStore(kid); - vectorStore.storeEmbeddings(chunkList, vectorList, kid, docId, fidList); + vectorStore.storeEmbeddings(chunkList, vectorList, kid, docId, fidList); } @Override public void removeByDocId(String kid, String docId) { VectorStore vectorStore = vectorStoreFactory.getVectorStore(kid); - vectorStore.removeByDocId(kid,docId); + vectorStore.removeByDocId(kid, docId); } @Override @@ -35,7 +36,7 @@ public class VectorStoreWrapper implements VectorStore{ @Override public List nearest(List queryVector, String kid) { VectorStore vectorStore = vectorStoreFactory.getVectorStore(kid); - return vectorStore.nearest(queryVector,kid); + return vectorStore.nearest(queryVector, kid); } @Override diff --git a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/domain/KnowledgeInfo.java b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/domain/KnowledgeInfo.java index fb8407df..292bdb0f 100644 --- a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/domain/KnowledgeInfo.java +++ b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/domain/KnowledgeInfo.java @@ -45,7 +45,7 @@ public class KnowledgeInfo implements Serializable { private String kname; /** - * 知识库名称 + * 是否公开知识库(0 否 1是) */ private String share; diff --git a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/service/IKnowledgeAttachService.java b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/service/IKnowledgeAttachService.java index 04698880..6ed856ee 100644 --- a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/service/IKnowledgeAttachService.java +++ b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/service/IKnowledgeAttachService.java @@ -49,8 +49,6 @@ public interface IKnowledgeAttachService { /** * 删除知识附件 - * - * @return */ - void removeKnowledgeAttach(String kid); + void removeKnowledgeAttach(String docId); } diff --git a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/service/IKnowledgeInfoService.java b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/service/IKnowledgeInfoService.java index 121f8fb2..14e8e972 100644 --- a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/service/IKnowledgeInfoService.java +++ b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/service/IKnowledgeInfoService.java @@ -2,13 +2,10 @@ package org.ruoyi.knowledge.service; import org.ruoyi.common.mybatis.core.page.PageQuery; import org.ruoyi.common.mybatis.core.page.TableDataInfo; -import org.ruoyi.knowledge.domain.KnowledgeAttach; -import org.ruoyi.knowledge.domain.bo.KnowledgeAttachBo; import org.ruoyi.knowledge.domain.bo.KnowledgeInfoBo; import org.ruoyi.knowledge.domain.req.KnowledgeInfoUploadRequest; import org.ruoyi.knowledge.domain.vo.KnowledgeInfoVo; -import java.util.Collection; import java.util.List; /** @@ -40,7 +37,6 @@ public interface IKnowledgeInfoService { */ Boolean updateByBo(KnowledgeInfoBo bo); - /** * 新增知识库 */ diff --git a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/service/impl/EmbeddingServiceImpl.java b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/service/impl/EmbeddingServiceImpl.java index 73503a45..96e568a2 100644 --- a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/service/impl/EmbeddingServiceImpl.java +++ b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/service/impl/EmbeddingServiceImpl.java @@ -40,8 +40,7 @@ public class EmbeddingServiceImpl implements EmbeddingService { @Override public List getQueryVector(String query, String kid) { - List queryVector = vectorization.singleVectorization(query,kid); - return queryVector; + return vectorization.singleVectorization(query,kid); } @Override diff --git a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/service/impl/KnowledgeAttachServiceImpl.java b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/service/impl/KnowledgeAttachServiceImpl.java index 9bea5dd4..8ae5340b 100644 --- a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/service/impl/KnowledgeAttachServiceImpl.java +++ b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/service/impl/KnowledgeAttachServiceImpl.java @@ -126,13 +126,9 @@ public class KnowledgeAttachServiceImpl implements IKnowledgeAttachService { } @Override - public void removeKnowledgeAttach(String kid) { - LoginUser loginUser = LoginHelper.getLoginUser(); + public void removeKnowledgeAttach(String docId) { Map map = new HashMap<>(); - map.put("kid",kid); - List knowledgeInfoList = knowledgeInfoMapper.selectVoByMap(map); - knowledgeInfoService.check(knowledgeInfoList); - + map.put("doc_id",docId); baseMapper.deleteByMap(map); fragmentMapper.deleteByMap(map); } diff --git a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/service/impl/KnowledgeInfoServiceImpl.java b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/service/impl/KnowledgeInfoServiceImpl.java index 3bd9390d..23f255b0 100644 --- a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/service/impl/KnowledgeInfoServiceImpl.java +++ b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/service/impl/KnowledgeInfoServiceImpl.java @@ -1,14 +1,10 @@ package org.ruoyi.knowledge.service.impl; +import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.RandomUtil; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.baomidou.mybatisplus.core.toolkit.Wrappers; import com.baomidou.mybatisplus.extension.plugins.pagination.Page; -import io.github.ollama4j.OllamaAPI; -import io.github.ollama4j.models.chat.OllamaChatMessageRole; -import io.github.ollama4j.models.chat.OllamaChatRequestBuilder; -import io.github.ollama4j.models.chat.OllamaChatRequestModel; -import io.github.ollama4j.models.chat.OllamaChatResult; import lombok.RequiredArgsConstructor; import org.ruoyi.common.core.domain.model.LoginUser; import org.ruoyi.common.core.utils.MapstructUtils; @@ -30,6 +26,7 @@ import org.ruoyi.knowledge.mapper.KnowledgeInfoMapper; import org.ruoyi.knowledge.service.EmbeddingService; import org.ruoyi.knowledge.service.IKnowledgeInfoService; import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; import org.springframework.web.multipart.MultipartFile; import java.io.IOException; @@ -41,8 +38,8 @@ import java.util.*; * @author Lion Li * @date 2024-10-21 */ -@RequiredArgsConstructor @Service +@RequiredArgsConstructor public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService { private final KnowledgeInfoMapper baseMapper; @@ -110,9 +107,8 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService { //TODO 做一些数据校验,如唯一约束 } - - @Override + @Transactional(rollbackFor = Exception.class) public void saveOne(KnowledgeInfoBo bo) { KnowledgeInfo knowledgeInfo = MapstructUtils.convert(bo, KnowledgeInfo.class); if (StringUtils.isBlank(bo.getKid())){ @@ -122,7 +118,7 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService { knowledgeInfo.setUid(LoginHelper.getLoginUser().getUserId()); } baseMapper.insert(knowledgeInfo); - embeddingService.createSchema(kid); + embeddingService.createSchema(String.valueOf(knowledgeInfo.getId())); }else { baseMapper.updateById(knowledgeInfo); } @@ -148,19 +144,23 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService { try { content = resourceLoader.getContent(file.getInputStream()); chunkList = resourceLoader.getChunkList(content, kid); - for (int i = 0; i < chunkList.size(); i++) { - String fid = RandomUtil.randomString(16); - fids.add(fid); - KnowledgeFragment knowledgeFragment = new KnowledgeFragment(); - knowledgeFragment.setKid(kid); - knowledgeFragment.setDocId(docId); - knowledgeFragment.setFid(fid); - knowledgeFragment.setIdx(i); - // String text = convertTextBlockToPretrainData(chunkList.get(i)); - knowledgeFragment.setContent(chunkList.get(i)); - knowledgeFragment.setCreateTime(new Date()); - fragmentMapper.insert(knowledgeFragment); + List knowledgeFragmentList = new ArrayList<>(); + if (CollUtil.isNotEmpty(chunkList)) { + for (int i = 0; i < chunkList.size(); i++) { + String fid = RandomUtil.randomString(16); + fids.add(fid); + KnowledgeFragment knowledgeFragment = new KnowledgeFragment(); + knowledgeFragment.setKid(kid); + knowledgeFragment.setDocId(docId); + knowledgeFragment.setFid(fid); + knowledgeFragment.setIdx(i); + // String text = convertTextBlockToPretrainData(chunkList.get(i)); + knowledgeFragment.setContent(chunkList.get(i)); + knowledgeFragment.setCreateTime(new Date()); + knowledgeFragmentList.add(knowledgeFragment); + } } + fragmentMapper.insertBatch(knowledgeFragmentList); } catch (IOException e) { e.printStackTrace(); } @@ -171,19 +171,21 @@ public class KnowledgeInfoServiceImpl implements IKnowledgeInfoService { } @Override + @Transactional(rollbackFor = Exception.class) public void removeKnowledge(String id) { - Map map = new HashMap<>(); map.put("kid",id); List knowledgeInfoList = baseMapper.selectVoByMap(map); check(knowledgeInfoList); - // 删除知识库 - baseMapper.deleteByMap(map); + // 删除向量库信息 + knowledgeInfoList.forEach(knowledgeInfoVo -> { + embeddingService.removeByKid(String.valueOf(knowledgeInfoVo.getId())); + }); // 删除附件和知识片段 fragmentMapper.deleteByMap(map); attachMapper.deleteByMap(map); - // 删除向量库信息 - embeddingService.removeByKid(id); + // 删除知识库 + baseMapper.deleteByMap(map); } @Override diff --git a/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/controller/system/SysProfileController.java b/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/controller/system/SysProfileController.java index aee284d0..c6d7e3c2 100644 --- a/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/controller/system/SysProfileController.java +++ b/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/controller/system/SysProfileController.java @@ -12,6 +12,7 @@ import org.ruoyi.common.log.enums.BusinessType; import org.ruoyi.common.satoken.utils.LoginHelper; import org.ruoyi.common.web.core.BaseController; import org.ruoyi.system.domain.bo.SysUserBo; +import org.ruoyi.system.domain.bo.SysUserPasswordBo; import org.ruoyi.system.domain.bo.SysUserProfileBo; import org.ruoyi.system.domain.vo.AvatarVo; import org.ruoyi.system.domain.vo.ProfileVo; @@ -75,23 +76,20 @@ public class SysProfileController extends BaseController { /** * 重置密码 - * - * @param newPassword 旧密码 - * @param oldPassword 新密码 */ @Log(title = "个人信息", businessType = BusinessType.UPDATE) @PutMapping("/updatePwd") - public R updatePwd(String oldPassword, String newPassword) { + public R updatePwd(@Validated @RequestBody SysUserPasswordBo bo) { SysUserVo user = userService.selectUserById(LoginHelper.getUserId()); String password = user.getPassword(); - if (!BCrypt.checkpw(oldPassword, password)) { + if (!BCrypt.checkpw(bo.getOldPassword(), password)) { return R.fail("修改密码失败,旧密码错误"); } - if (BCrypt.checkpw(newPassword, password)) { + if (BCrypt.checkpw(bo.getNewPassword(), password)) { return R.fail("新密码不能与旧密码相同"); } - if (userService.resetUserPwd(user.getUserId(), BCrypt.hashpw(newPassword)) > 0) { + if (userService.resetUserPwd(user.getUserId(), BCrypt.hashpw(bo.getNewPassword())) > 0) { return R.ok(); } return R.fail("修改密码异常,请联系管理员"); diff --git a/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/domain/bo/SysUserPasswordBo.java b/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/domain/bo/SysUserPasswordBo.java new file mode 100644 index 00000000..493ba0aa --- /dev/null +++ b/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/domain/bo/SysUserPasswordBo.java @@ -0,0 +1,32 @@ +package org.ruoyi.system.domain.bo; +import jakarta.validation.constraints.NotBlank; +import lombok.Data; + +import java.io.Serial; +import java.io.Serializable; + + +/** + * 描述:用户密码修改bo + * + * @author ageerle@163.com + * date 2025/3/9 + */ +@Data +public class SysUserPasswordBo implements Serializable { + + @Serial + private static final long serialVersionUID = 1L; + + /** + * 旧密码 + */ + @NotBlank(message = "旧密码不能为空") + private String oldPassword; + + /** + * 新密码 + */ + @NotBlank(message = "新密码不能为空") + private String newPassword; +} diff --git a/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/listener/SSEEventSourceListener.java b/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/listener/SSEEventSourceListener.java index 787cba4c..bdf4f347 100644 --- a/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/listener/SSEEventSourceListener.java +++ b/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/listener/SSEEventSourceListener.java @@ -65,7 +65,7 @@ public class SSEEventSourceListener extends EventSourceListener { @Override public void onEvent(@NotNull EventSource eventSource, String id, String type, String data) { try { - if (data.equals("[DONE]")) { + if ("[DONE]".equals(data)) { //成功响应 emitter.complete(); if(StringUtils.isNotEmpty(modelName)){ diff --git a/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/plugin/WebSearchPlugin.java b/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/plugin/WebSearchPlugin.java new file mode 100644 index 00000000..63b290a3 --- /dev/null +++ b/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/plugin/WebSearchPlugin.java @@ -0,0 +1,212 @@ +package org.ruoyi.system.plugin; + + +import cn.hutool.json.JSONUtil; +import com.alibaba.fastjson.JSONObject; +import lombok.Builder; +import lombok.Data; +import okhttp3.OkHttpClient; +import okhttp3.logging.HttpLoggingInterceptor; +import org.junit.Before; +import org.junit.Test; +import org.ruoyi.common.chat.demo.ConsoleEventSourceListenerV3; +import org.ruoyi.common.chat.entity.chat.ChatCompletion; +import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse; +import org.ruoyi.common.chat.entity.chat.Message; +import org.ruoyi.common.chat.entity.chat.Parameters; +import org.ruoyi.common.chat.entity.chat.tool.ToolCallFunction; +import org.ruoyi.common.chat.entity.chat.tool.ToolCalls; +import org.ruoyi.common.chat.entity.chat.tool.Tools; +import org.ruoyi.common.chat.entity.chat.tool.ToolsFunction; +import org.ruoyi.common.chat.openai.OpenAiClient; +import org.ruoyi.common.chat.openai.OpenAiStreamClient; +import org.ruoyi.common.chat.openai.function.KeyRandomStrategy; +import org.ruoyi.common.chat.openai.interceptor.DynamicKeyOpenAiAuthInterceptor; +import org.ruoyi.common.chat.openai.interceptor.OpenAILogger; +import org.ruoyi.common.chat.openai.interceptor.OpenAiResponseInterceptor; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +public class WebSearchPlugin { + + private OpenAiClient openAiClient; + private OpenAiStreamClient openAiStreamClient; + + @Before + public void before() { + //可以为null +// Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890)); + HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor(new OpenAILogger()); + //!!!!千万别再生产或者测试环境打开BODY级别日志!!!! + //!!!生产或者测试环境建议设置为这三种级别:NONE,BASIC,HEADERS,!!! + httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS); + OkHttpClient okHttpClient = new OkHttpClient + .Builder() +// .proxy(proxy) + .addInterceptor(httpLoggingInterceptor) + .addInterceptor(new OpenAiResponseInterceptor()) + .connectTimeout(10, TimeUnit.SECONDS) + .writeTimeout(30, TimeUnit.SECONDS) + .readTimeout(30, TimeUnit.SECONDS) + .build(); + openAiClient = OpenAiClient.builder() + //支持多key传入,请求时候随机选择 + .apiKey(Arrays.asList("xx")) + //自定义key的获取策略:默认KeyRandomStrategy + //.keyStrategy(new KeyRandomStrategy()) + .keyStrategy(new KeyRandomStrategy()) + .okHttpClient(okHttpClient) + //自己做了代理就传代理地址,没有可不不传,(关注公众号回复:openai ,获取免费的测试代理地址) + .apiHost("https://open.bigmodel.cn/") + .build(); + + openAiStreamClient = OpenAiStreamClient.builder() + //支持多key传入,请求时候随机选择 + .apiKey(Arrays.asList("xx")) + //自定义key的获取策略:默认KeyRandomStrategy + .keyStrategy(new KeyRandomStrategy()) + .authInterceptor(new DynamicKeyOpenAiAuthInterceptor()) + .okHttpClient(okHttpClient) + //自己做了代理就传代理地址,没有可不不传,(关注公众号回复:openai ,获取免费的测试代理地址) + .apiHost("https://open.bigmodel.cn/") + .build(); + } + + @Test + public void test() { + Message message = Message.builder().role(Message.Role.USER).content("今天武汉天气怎么样").build(); + ChatCompletion chatCompletion = ChatCompletion + .builder() + .messages(Collections.singletonList(message)) +// .tools(Collections.singletonList(tools)) + .model("web-search-pro") + .build(); + ChatCompletionResponse chatCompletionResponse = openAiStreamClient.chatCompletion(chatCompletion); + + System.out.printf("chatCompletionResponse=%s\n", JSONUtil.toJsonStr(chatCompletionResponse)); + } + + @Test + public void streamToolsChat() { + + CountDownLatch countDownLatch = new CountDownLatch(1); + ConsoleEventSourceListenerV3 eventSourceListener = new ConsoleEventSourceListenerV3(countDownLatch); + + Message message = Message.builder().role(Message.Role.USER).content("给我输出一个长度为2的中文词语,并解释下词语对应物品的用途").build(); + //属性一 + JSONObject wordLength = new JSONObject(); + wordLength.put("type", "number"); + wordLength.put("description", "词语的长度"); + //属性二 + JSONObject language = new JSONObject(); + language.put("type", "string"); + language.put("enum", Arrays.asList("zh", "en")); + language.put("description", "语言类型,例如:zh代表中文、en代表英语"); + //参数 + JSONObject properties = new JSONObject(); + properties.put("wordLength", wordLength); + properties.put("language", language); + Parameters parameters = Parameters.builder() + .type("object") + .properties(properties) + .required(Collections.singletonList("wordLength")).build(); + Tools tools = Tools.builder() + .type(Tools.Type.FUNCTION.getName()) + .function(ToolsFunction.builder().name("getOneWord").description("获取一个指定长度和语言类型的词语").parameters(parameters).build()) + .build(); + + ChatCompletion chatCompletion = ChatCompletion + .builder() + .messages(Collections.singletonList(message)) + .tools(Collections.singletonList(tools)) + .model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName()) + .build(); + openAiStreamClient.streamChatCompletion(chatCompletion, eventSourceListener); + + try { + countDownLatch.await(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + + ToolCalls openAiReturnToolCalls = eventSourceListener.getToolCalls(); + WordParam wordParam = JSONUtil.toBean(openAiReturnToolCalls.getFunction().getArguments(), WordParam.class); + String oneWord = getOneWord(wordParam); + + + ToolCallFunction tcf = ToolCallFunction.builder().name("getOneWord").arguments(openAiReturnToolCalls.getFunction().getArguments()).build(); + ToolCalls tc = ToolCalls.builder().id(openAiReturnToolCalls.getId()).type(ToolCalls.Type.FUNCTION.getName()).function(tcf).build(); + //构造tool call + Message message2 = Message.builder().role(Message.Role.ASSISTANT).content("方法参数").toolCalls(Collections.singletonList(tc)).build(); + String content + = "{ " + + "\"wordLength\": \"3\", " + + "\"language\": \"zh\", " + + "\"word\": \"" + oneWord + "\"," + + "\"用途\": [\"直接吃\", \"做沙拉\", \"售卖\"]" + + "}"; + Message message3 = Message.builder().toolCallId(openAiReturnToolCalls.getId()).role(Message.Role.TOOL).name("getOneWord").content(content).build(); + List messageList = Arrays.asList(message, message2, message3); + ChatCompletion chatCompletionV2 = ChatCompletion + .builder() + .messages(messageList) + .model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName()) + .build(); + + + CountDownLatch countDownLatch1 = new CountDownLatch(1); + openAiStreamClient.streamChatCompletion(chatCompletionV2, new ConsoleEventSourceListenerV3(countDownLatch)); + try { + countDownLatch1.await(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + try { + countDownLatch1.await(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + + } + + @Data + @Builder + static class WordParam { + private int wordLength; + @Builder.Default + private String language = "zh"; + } + + + /** + * 获取一个词语(根据语言和字符长度查询) + * @param wordParam + * @return + */ + public String getOneWord(WordParam wordParam) { + + List zh = Arrays.asList("大香蕉", "哈密瓜", "苹果"); + List en = Arrays.asList("apple", "banana", "cantaloupe"); + if (wordParam.getLanguage().equals("zh")) { + for (String e : zh) { + if (e.length() == wordParam.getWordLength()) { + return e; + } + } + } + if (wordParam.getLanguage().equals("en")) { + for (String e : en) { + if (e.length() == wordParam.getWordLength()) { + return e; + } + } + } + return "西瓜"; + } + + +} diff --git a/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/service/impl/SseServiceImpl.java b/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/service/impl/SseServiceImpl.java index e2f02b58..a2cfeffe 100644 --- a/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/service/impl/SseServiceImpl.java +++ b/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/service/impl/SseServiceImpl.java @@ -3,16 +3,11 @@ package org.ruoyi.system.service.impl; import cn.dev33.satoken.stp.StpUtil; import cn.hutool.core.collection.CollectionUtil; import com.alibaba.fastjson.JSONObject; -import com.azure.ai.openai.OpenAIClient; -import com.azure.ai.openai.OpenAIClientBuilder; -import com.azure.ai.openai.models.*; -import com.azure.core.credential.AzureKeyCredential; import io.github.ollama4j.OllamaAPI; import io.github.ollama4j.models.chat.OllamaChatMessageRole; import io.github.ollama4j.models.chat.OllamaChatRequestBuilder; import io.github.ollama4j.models.chat.OllamaChatRequestModel; import io.github.ollama4j.models.generate.OllamaStreamHandler; -import io.github.ollama4j.utils.Options; import jakarta.servlet.http.HttpServletRequest; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; @@ -33,6 +28,12 @@ import org.ruoyi.common.chat.entity.images.Item; import org.ruoyi.common.chat.entity.images.ResponseFormat; import org.ruoyi.common.chat.entity.whisper.WhisperResponse; import org.ruoyi.common.chat.openai.OpenAiStreamClient; +import org.ruoyi.common.chat.openai.plugin.PluginAbstract; +import org.ruoyi.common.chat.plugin.CmdPlugin; +import org.ruoyi.common.chat.plugin.CmdReq; +import org.ruoyi.common.chat.plugin.SqlPlugin; +import org.ruoyi.common.chat.plugin.SqlReq; +import org.ruoyi.common.chat.sse.ConsoleEventSourceListener; import org.ruoyi.common.chat.utils.TikTokensUtil; import org.ruoyi.common.core.domain.model.LoginUser; import org.ruoyi.common.core.exception.base.BaseException; @@ -43,12 +44,10 @@ import org.ruoyi.system.domain.bo.ChatMessageBo; import org.ruoyi.system.domain.bo.SysModelBo; import org.ruoyi.system.domain.request.translation.TranslationRequest; import org.ruoyi.system.domain.vo.SysModelVo; -import org.ruoyi.system.domain.vo.SysUserVo; import org.ruoyi.system.listener.SSEEventSourceListener; import org.ruoyi.system.service.*; import org.springframework.core.io.InputStreamResource; import org.springframework.core.io.Resource; -import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.stereotype.Service; @@ -63,10 +62,10 @@ import java.net.URLEncoder; import java.nio.file.Files; import java.nio.file.Path; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.concurrent.CompletableFuture; -import io.github.ollama4j.utils.OptionsBuilder; @Service @Slf4j @@ -89,9 +88,6 @@ public class SseServiceImpl implements ISseService { static final OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().build(); - private final ISysPackagePlanService sysPackagePlanService; - - @Override public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) { openAiStreamClient = chatConfig.getOpenAiStreamClient(); @@ -101,12 +97,7 @@ public class SseServiceImpl implements ISseService { List messages = chatRequest.getMessages(); try { if (StpUtil.isLogin()) { - SysUserVo sysUserVo = userService.selectUserById(getUserId()); -// if (!checkModel(sysUserVo.getUserPlan(), chatRequest.getModel())) { -// throw new BaseException("当前套餐不支持此模型!"); -// } LocalCache.CACHE.put("userId", getUserId()); - Object content = messages.get(messages.size() - 1).getContent(); String chatString = ""; @@ -161,36 +152,23 @@ public class SseServiceImpl implements ISseService { } } } - -// else { -// -// // 初始请求次数 -// int number = 1; -// // 获取请求IP -// String realIp = getClientIpAddress(request); -// // 根据IP获取次数 -// Integer requestNumber = RedisUtils.getCacheObject(realIp); -// if (requestNumber == null) { -// // 记录ip使用次数 -// RedisUtils.setCacheObject(realIp, number); -// } else { -// String configValue = configService.getConfigValue("mail", "free"); -// if (requestNumber > Integer.parseInt(configValue)) { -// throw new BaseException("剩余次数不足,请充值后使用"); -// } -// RedisUtils.setCacheObject(realIp, requestNumber + 1); -// } -// -// } - ChatCompletion completion = ChatCompletion - .builder() - .messages(messages) - .model(chatRequest.getModel()) - .temperature(chatRequest.getTemperature()) - .topP(chatRequest.getTop_p()) - .stream(true) - .build(); - openAiStreamClient.streamChatCompletion(completion, openAIEventSourceListener); + if("openCmd".equals(chatRequest.getModel())) { + sseEmitter.send(cmdPlugin(messages)); + sseEmitter.complete(); + }else if ("sqlPlugin".equals(chatRequest.getModel())){ + sseEmitter.send(sqlPlugin(messages)); + sseEmitter.complete(); + } else { + ChatCompletion completion = ChatCompletion + .builder() + .messages(messages) + .model(chatRequest.getModel()) + .temperature(chatRequest.getTemperature()) + .topP(chatRequest.getTop_p()) + .stream(true) + .build(); + openAiStreamClient.streamChatCompletion(completion, openAIEventSourceListener); + } } catch (Exception e) { String message = e.getMessage(); sendErrorEvent(sseEmitter, message); @@ -199,32 +177,51 @@ public class SseServiceImpl implements ISseService { return sseEmitter; } + public String cmdPlugin(List messages) { + CmdPlugin plugin = new CmdPlugin(CmdReq.class); + // 插件名称 + plugin.setName("命令行工具"); + // 方法名称 + plugin.setFunction("openCmd"); + // 方法说明 + plugin.setDescription("提供一个命令行指令,比如<记事本>,指令使用中文"); -// /** -// * 查当前用户是否可以调用此模型 -// * -// * @param planId -// * @return -// */ -// public Boolean checkModel(String planId, String modelName) { -// SysPackagePlanBo sysPackagePlanBo = new SysPackagePlanBo(); -// if (modelName.startsWith("gpt-4-gizmo")) { -// modelName = "gpt-4-gizmo"; -// } -// if (StringUtils.isEmpty(planId)) { -// sysPackagePlanBo.setName("Visitor"); -// } else if ("Visitor".equals(planId) || "Free".equals(planId)) { -// sysPackagePlanBo.setName(planId); -// } else { -// // sysPackagePlanBo.setId(Long.valueOf(planId)); -// return true; -// } -// -// SysPackagePlanVo sysPackagePlanVo = sysPackagePlanService.queryList(sysPackagePlanBo).get(0); -// // 将字符串转换为数组 -// String[] array = sysPackagePlanVo.getPlanDetail().split(","); -// return Arrays.asList(array).contains(modelName); -// } + PluginAbstract.Arg arg = new PluginAbstract.Arg(); + // 参数名称 + arg.setName("cmd"); + // 参数说明 + arg.setDescription("命令行指令"); + // 参数类型 + arg.setType("string"); + arg.setRequired(true); + plugin.setArgs(Collections.singletonList(arg)); + //有四个重载方法,都可以使用 + ChatCompletionResponse response = openAiStreamClient.chatCompletionWithPlugin(messages,"gpt-4o-mini",plugin); + return response.getChoices().get(0).getMessage().getContent().toString(); + } + + public String sqlPlugin(List messages) { + SqlPlugin plugin = new SqlPlugin(SqlReq.class); + // 插件名称 + plugin.setName("数据库查询插件"); + // 方法名称 + plugin.setFunction("sqlPlugin"); + // 方法说明 + plugin.setDescription("提供一个用户名称查询余额信息"); + + PluginAbstract.Arg arg = new PluginAbstract.Arg(); + // 参数名称 + arg.setName("username"); + // 参数说明 + arg.setDescription("用户名称"); + // 参数类型 + arg.setType("string"); + arg.setRequired(true); + plugin.setArgs(Collections.singletonList(arg)); + //有四个重载方法,都可以使用 + ChatCompletionResponse response = openAiStreamClient.chatCompletionWithPlugin(messages,"gpt-4o-mini",plugin); + return response.getChoices().get(0).getMessage().getContent().toString(); + } /** * 根据次数扣除余额 @@ -295,25 +292,6 @@ public class SseServiceImpl implements ISseService { @Override public String chat(ChatRequest chatRequest, String userId) { -// chatService.deductUserBalance(Long.valueOf(userId), 0.01); -// // 保存消息记录 -// ChatMessageBo chatMessageBo = new ChatMessageBo(); -// chatMessageBo.setUserId(Long.valueOf(userId)); -// chatMessageBo.setModelName(ChatCompletion.Model.GPT_3_5_TURBO.getName()); -// chatMessageBo.setContent(chatRequest.getPrompt()); -// chatMessageBo.setDeductCost(0.01); -// chatMessageBo.setTotalTokens(0); -// chatMessageService.insertByBo(chatMessageBo); -// -// openAiStreamClient = chatConfig.getOpenAiStreamClient(); -// Message message = Message.builder().role(Message.Role.USER).content(chatRequest.getPrompt()).build(); -// ChatCompletion chatCompletion = ChatCompletion -// .builder() -// .messages(Collections.singletonList(message)) -// .model(chatRequest.getModel()) -// .build(); -// ChatCompletionResponse chatCompletionResponse = openAiStreamClient.chatCompletion(chatCompletion); -// return chatCompletionResponse.getChoices().get(0).getMessage().getContent(); return null; } @@ -540,7 +518,8 @@ public class SseServiceImpl implements ISseService { @Override public String translation(TranslationRequest translationRequest) { - + // 翻译模型固定为gpt-4o-mini + translationRequest.setModel("gpt-4o-mini"); ChatMessageBo chatMessageBo = new ChatMessageBo(); chatMessageBo.setUserId(getUserId()); chatMessageBo.setModelName(translationRequest.getModel()); @@ -557,17 +536,12 @@ public class SseServiceImpl implements ISseService { "\n" + "请将用户输入词语翻译成{" + translationRequest.getTargetLanguage() + "}\n" + "\n" + - "让我们一步一步来思考\n" + "==示例输出==\n" + + "**原文** : <这里显示要翻译的原文信息>\n" + "**翻译** : <这里显示翻译成英语的结果>\n" + - "\n" + - "**造句** : What's the weather like today? Use the 'Weather Query' plugin to find out instantly! <造一个英语句子>\n" + - "\n" + - "**同义词** : Add-on、Extension、Module <这里显示1-3个英文的同义词>\n" + - "\n" + "==示例结束==\n" + "\n" + - "注意:请严格按示例进行输出").build(); + "注意:请严格按示例进行输出,返回markdown格式").build(); messageList.add(sysMessage); Message message = Message.builder().role(Message.Role.USER).content(translationRequest.getPrompt()).build(); messageList.add(message); @@ -646,4 +620,6 @@ public class SseServiceImpl implements ISseService { ChatCompletionResponse chatCompletionResponse = openAiStreamClient.chatCompletion(chatCompletion); return chatCompletionResponse.getChoices().get(0).getMessage().getContent().toString(); } + + } diff --git a/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/service/impl/SysUserServiceImpl.java b/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/service/impl/SysUserServiceImpl.java index 39e51ea1..7b895da4 100644 --- a/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/service/impl/SysUserServiceImpl.java +++ b/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/service/impl/SysUserServiceImpl.java @@ -557,4 +557,11 @@ public class SysUserServiceImpl implements ISysUserService, UserService { .select(SysUser::getUserName).eq(SysUser::getUserId, userId)); return ObjectUtil.isNull(sysUser) ? null : sysUser.getUserName(); } + + @Override + public String selectUserByName(String userName) { + SysUser sysUser = baseMapper.selectOne(new LambdaQueryWrapper() + .eq(SysUser::getUserName, userName)); + return ObjectUtil.isNull(sysUser) ? null : sysUser.getUserBalance().toString(); + } } diff --git a/script/sql/update/update20250307.sql b/script/sql/update/update20250307.sql index ba83d699..65c6d8cf 100644 --- a/script/sql/update/update20250307.sql +++ b/script/sql/update/update20250307.sql @@ -8,8 +8,10 @@ ADD COLUMN `vector` varchar(50) NULL COMMENT '向量库' AFTER `text_block_size` ADD COLUMN `vector_model` varchar(50) NULL COMMENT '向量模型' AFTER `vector`; - -INSERT INTO `chat_config` (`id`, `category`, `config_name`, `config_value`, `config_dict`, `create_dept`, `create_time`, `create_by`, `update_by`, `update_time`, `remark`, `version`, `del_flag`, `update_ip`, `tenant_id`) VALUES (1897610056458412050, 'weaviate', 'protocol', 'http', '协议', 103, '2025-03-06 21:10:02', '1', '1', '2025-03-06 21:10:31', NULL, NULL, '0', NULL, 0); -INSERT INTO `chat_config` (`id`, `category`, `config_name`, `config_value`, `config_dict`, `create_dept`, `create_time`, `create_by`, `update_by`, `update_time`, `remark`, `version`, `del_flag`, `update_ip`, `tenant_id`) VALUES (1897610056458412051, 'weaviate', 'host', '127.0.0.1:6038', '地址', 103, '2025-03-06 21:10:02', '1', '1', '2025-03-06 21:10:31', NULL, NULL, '0', NULL, 0); -INSERT INTO `chat_config` (`id`, `category`, `config_name`, `config_value`, `config_dict`, `create_dept`, `create_time`, `create_by`, `update_by`, `update_time`, `remark`, `version`, `del_flag`, `update_ip`, `tenant_id`) VALUES (1897610056458412052, 'weaviate', 'classname', 'LocalKnowledge', '分类名称', 103, '2025-03-06 21:10:02', '1', '1', '2025-03-06 21:10:31', NULL, NULL, '0', NULL, 0); - +INSERT INTO `ruoyi-ai`.`chat_config` (`id`, `category`, `config_name`, `config_value`, `config_dict`, `create_dept`, `create_time`, `create_by`, `update_by`, `update_time`, `remark`, `version`, `del_flag`, `update_ip`, `tenant_id`) VALUES (1897610056458412050, 'weaviate', 'protocol', 'http', '协议', 103, '2025-03-06 21:10:02', '1', '1', '2025-03-06 21:10:31', NULL, NULL, '0', NULL, 0); +INSERT INTO `ruoyi-ai`.`chat_config` (`id`, `category`, `config_name`, `config_value`, `config_dict`, `create_dept`, `create_time`, `create_by`, `update_by`, `update_time`, `remark`, `version`, `del_flag`, `update_ip`, `tenant_id`) VALUES (1897610056458412051, 'weaviate', 'host', '127.0.0.1:6038', '地址', 103, '2025-03-06 21:10:02', '1', '1', '2025-03-06 21:10:31', NULL, NULL, '0', NULL, 0); +INSERT INTO `ruoyi-ai`.`chat_config` (`id`, `category`, `config_name`, `config_value`, `config_dict`, `create_dept`, `create_time`, `create_by`, `update_by`, `update_time`, `remark`, `version`, `del_flag`, `update_ip`, `tenant_id`) VALUES (1897610056458412052, 'weaviate', 'classname', 'LocalKnowledge', '分类名称', 103, '2025-03-06 21:10:02', '1', '1', '2025-03-06 21:10:31', NULL, NULL, '0', NULL, 0); +INSERT INTO `ruoyi-ai`.`chat_config` (`id`, `category`, `config_name`, `config_value`, `config_dict`, `create_dept`, `create_time`, `create_by`, `update_by`, `update_time`, `remark`, `version`, `del_flag`, `update_ip`, `tenant_id`) VALUES (1897610056458412053, 'milvus', 'host', '127.0.0.1', '地址', 103, '2025-03-06 21:10:02', '1', '1', '2025-03-06 21:10:31', NULL, NULL, '0', NULL, 0); +INSERT INTO `ruoyi-ai`.`chat_config` (`id`, `category`, `config_name`, `config_value`, `config_dict`, `create_dept`, `create_time`, `create_by`, `update_by`, `update_time`, `remark`, `version`, `del_flag`, `update_ip`, `tenant_id`) VALUES (1897610056458412054, 'milvus', 'port', '19530', '端口', 103, '2025-03-06 21:10:02', '1', '1', '2025-03-06 21:10:31', NULL, NULL, '0', NULL, 0); +INSERT INTO `ruoyi-ai`.`chat_config` (`id`, `category`, `config_name`, `config_value`, `config_dict`, `create_dept`, `create_time`, `create_by`, `update_by`, `update_time`, `remark`, `version`, `del_flag`, `update_ip`, `tenant_id`) VALUES (1897610056458412055, 'milvus', 'dimension', '1536', '维度', 103, '2025-03-06 21:10:02', '1', '1', '2025-03-06 21:10:31', NULL, NULL, '0', NULL, 0); +INSERT INTO `ruoyi-ai`.`chat_config` (`id`, `category`, `config_name`, `config_value`, `config_dict`, `create_dept`, `create_time`, `create_by`, `update_by`, `update_time`, `remark`, `version`, `del_flag`, `update_ip`, `tenant_id`) VALUES (1897610056458412056, 'milvus', 'collection', 'LocalKnowledge', '分类名称', 103, '2025-03-06 21:10:02', '1', '1', '2025-03-06 21:10:31', NULL, NULL, '0', NULL, 0);