diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/zhipu/WebSearchToolsTest.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/WebSearchToolsTest.java similarity index 89% rename from ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/zhipu/WebSearchToolsTest.java rename to ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/WebSearchToolsTest.java index eca586fa..34871f5c 100644 --- a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/zhipu/WebSearchToolsTest.java +++ b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/WebSearchToolsTest.java @@ -1,4 +1,4 @@ -package org.ruoyi.common.chat.demo.zhipu; +package org.ruoyi.common.chat.demo; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; @@ -17,34 +17,11 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; - -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.DeserializationFeature; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.PropertyNamingStrategy; -import com.zhipu.oapi.core.response.HttpxBinaryResponseContent; -import com.zhipu.oapi.service.v4.batchs.BatchCreateParams; -import com.zhipu.oapi.service.v4.batchs.BatchResponse; -import com.zhipu.oapi.service.v4.batchs.QueryBatchResponse; -import com.zhipu.oapi.service.v4.embedding.EmbeddingApiResponse; -import com.zhipu.oapi.service.v4.embedding.EmbeddingRequest; -import com.zhipu.oapi.service.v4.file.*; -import com.zhipu.oapi.service.v4.fine_turning.*; -import com.zhipu.oapi.service.v4.image.CreateImageRequest; -import com.zhipu.oapi.service.v4.image.ImageApiResponse; import com.zhipu.oapi.service.v4.model.*; import io.reactivex.Flowable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import java.io.IOException; -import java.util.ArrayList; import java.util.HashMap; -import java.util.List; import java.util.Map; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; public class WebSearchToolsTest { diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/zhipu/AllToolsTest.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/zhipu/AllToolsTest.java deleted file mode 100644 index 805402c5..00000000 --- a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/zhipu/AllToolsTest.java +++ /dev/null @@ -1,122 +0,0 @@ -package org.ruoyi.common.chat.demo.zhipu; - -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; - -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; - -import com.fasterxml.jackson.databind.node.ObjectNode; -import com.zhipu.oapi.ClientV4; -import com.zhipu.oapi.Constants; -import com.zhipu.oapi.service.v4.deserialize.MessageDeserializeFactory; -import com.zhipu.oapi.service.v4.model.*; -import io.reactivex.Flowable; -import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - - -public class AllToolsTest { - - private final static Logger logger = LoggerFactory.getLogger(AllToolsTest.class); - private static final String API_SECRET_KEY = "28550a39d4cfaabbbf38df04dd3931f5.IUvfTThUf0xBF5l0"; - - private static final ClientV4 client = new ClientV4.Builder(API_SECRET_KEY) - .networkConfig(300, 100, 100, 100, TimeUnit.SECONDS) - .connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS)) - .build(); - private static final ObjectMapper mapper = MessageDeserializeFactory.defaultObjectMapper(); - // 请自定义自己的业务id - private static final String requestIdTemplate = "mycompany-%d"; - - - @Test - public void test1() throws JsonProcessingException { - - - List messages = new ArrayList<>(); - ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "帮我查询北京天气"); - messages.add(chatMessage); - String requestId = String.format(requestIdTemplate, System.currentTimeMillis()); - // 函数调用参数构建部分 - List chatToolList = new ArrayList<>(); - ChatTool chatTool = new ChatTool(); - - chatTool.setType("code_interpreter"); - ObjectNode objectNode = mapper.createObjectNode(); - objectNode.put("code", "北京天气"); -// chatTool.set(chatFunction); - chatToolList.add(chatTool); - - - ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() - .model("glm-4-alltools") - .stream(Boolean.TRUE) - .invokeMethod(Constants.invokeMethod) - .messages(messages) - .tools(chatToolList) - .toolChoice("auto") - .requestId(requestId) - .build(); - ModelApiResponse sseModelApiResp = client.invokeModelApi(chatCompletionRequest); - if (sseModelApiResp.isSuccess()) { - AtomicBoolean isFirst = new AtomicBoolean(true); - List choices = new ArrayList<>(); - AtomicReference lastAccumulator = new AtomicReference<>(); - - mapStreamToAccumulator(sseModelApiResp.getFlowable()) - .doOnNext(accumulator -> { - { - if (isFirst.getAndSet(false)) { - logger.info("Response: "); - } - if (accumulator.getDelta() != null && accumulator.getDelta().getTool_calls() != null) { - String jsonString = mapper.writeValueAsString(accumulator.getDelta().getTool_calls()); - logger.info("tool_calls: {}", jsonString); - } - if (accumulator.getDelta() != null && accumulator.getDelta().getContent() != null) { - logger.info(accumulator.getDelta().getContent()); - } - choices.add(accumulator.getChoice()); - lastAccumulator.set(accumulator); - - } - }) - .doOnComplete(() -> System.out.println("Stream completed.")) - .doOnError(throwable -> System.err.println("Error: " + throwable)) // Handle errors - .blockingSubscribe();// Use blockingSubscribe instead of blockingGet() - - ChatMessageAccumulator chatMessageAccumulator = lastAccumulator.get(); - ModelData data = new ModelData(); - data.setChoices(choices); - if (chatMessageAccumulator != null) { - data.setUsage(chatMessageAccumulator.getUsage()); - data.setId(chatMessageAccumulator.getId()); - data.setCreated(chatMessageAccumulator.getCreated()); - } - data.setRequestId(chatCompletionRequest.getRequestId()); - sseModelApiResp.setFlowable(null);// 打印前置空 - sseModelApiResp.setData(data); - } - logger.info("model output: {}", mapper.writeValueAsString(sseModelApiResp)); - client.getConfig().getHttpClient().dispatcher().executorService().shutdown(); - - client.getConfig().getHttpClient().connectionPool().evictAll(); - // List all active threads - for (Thread t : Thread.getAllStackTraces().keySet()) { - logger.info("Thread: " + t.getName() + " State: " + t.getState()); - } - } - - - - public static Flowable mapStreamToAccumulator(Flowable flowable) { - return flowable.map(chunk -> { - return new ChatMessageAccumulator(chunk.getChoices().get(0).getDelta(), null, chunk.getChoices().get(0), chunk.getUsage(), chunk.getCreated(), chunk.getId()); - }); - } -} diff --git a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/zhipu/V4Test.java b/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/zhipu/V4Test.java deleted file mode 100644 index afd664e8..00000000 --- a/ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/demo/zhipu/V4Test.java +++ /dev/null @@ -1,646 +0,0 @@ -package org.ruoyi.common.chat.demo.zhipu; - -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.DeserializationFeature; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.PropertyNamingStrategy; -import com.zhipu.oapi.ClientV4; -import com.zhipu.oapi.Constants; -import com.zhipu.oapi.core.response.HttpxBinaryResponseContent; -import com.zhipu.oapi.service.v4.batchs.BatchCreateParams; -import com.zhipu.oapi.service.v4.batchs.BatchResponse; -import com.zhipu.oapi.service.v4.batchs.QueryBatchResponse; -import com.zhipu.oapi.service.v4.embedding.EmbeddingApiResponse; -import com.zhipu.oapi.service.v4.embedding.EmbeddingRequest; -import com.zhipu.oapi.service.v4.file.*; -import com.zhipu.oapi.service.v4.fine_turning.*; -import com.zhipu.oapi.service.v4.image.CreateImageRequest; -import com.zhipu.oapi.service.v4.image.ImageApiResponse; -import com.zhipu.oapi.service.v4.model.*; -import io.reactivex.Flowable; - -import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import java.io.IOException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; - - -public class V4Test { - - private final static Logger logger = LoggerFactory.getLogger(V4Test.class); - private static final String API_SECRET_KEY = "28550a39d4cfaabbbf38df04dd3931f5.IUvfTThUf0xBF5l0"; - - - private static final ClientV4 client = new ClientV4.Builder(API_SECRET_KEY) - .enableTokenCache() - .networkConfig(300, 100, 100, 100, TimeUnit.SECONDS) - .connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS)) - .build(); - - // 请自定义自己的业务id - private static final String requestIdTemplate = "mycompany-%d"; - - private static final ObjectMapper mapper = new ObjectMapper(); - - - public static ObjectMapper defaultObjectMapper() { - ObjectMapper mapper = new ObjectMapper(); - mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); - mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL); - mapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE); - return mapper; - } - - @Test - public void test() { - - } - - /** - * sse-V4:function调用 - */ - @Test - public void testFunctionSSE() throws JsonProcessingException { - List messages = new ArrayList<>(); - ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "成都到北京要多久,天气如何"); - messages.add(chatMessage); - String requestId = String.format(requestIdTemplate, System.currentTimeMillis()); - // 函数调用参数构建部分 - List chatToolList = new ArrayList<>(); - ChatTool chatTool = new ChatTool(); - - chatTool.setType(ChatToolType.FUNCTION.value()); - ChatFunctionParameters chatFunctionParameters = new ChatFunctionParameters(); - chatFunctionParameters.setType("object"); - Map properties = new HashMap<>(); - properties.put("location", new HashMap() {{ - put("type", "string"); - put("description", "城市,如:北京"); - }}); - properties.put("unit", new HashMap() {{ - put("type", "string"); - put("enum", new ArrayList() {{ - add("celsius"); - add("fahrenheit"); - }}); - }}); - chatFunctionParameters.setProperties(properties); - ChatFunction chatFunction = ChatFunction.builder() - .name("get_weather") - .description("Get the current weather of a location") - .parameters(chatFunctionParameters) - .build(); - chatTool.setFunction(chatFunction); - chatToolList.add(chatTool); - HashMap extraJson = new HashMap<>(); - extraJson.put("temperature", 0.5); - extraJson.put("max_tokens", 50); - - ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() - .model(Constants.ModelChatGLM4) - .stream(Boolean.TRUE) - .messages(messages) - .requestId(requestId) - .tools(chatToolList) - .toolChoice("auto") - .extraJson(extraJson) - .build(); - ModelApiResponse sseModelApiResp = client.invokeModelApi(chatCompletionRequest); - if (sseModelApiResp.isSuccess()) { - AtomicBoolean isFirst = new AtomicBoolean(true); - List choices = new ArrayList<>(); - ChatMessageAccumulator chatMessageAccumulator = mapStreamToAccumulator(sseModelApiResp.getFlowable()) - .doOnNext(accumulator -> { - { - if (isFirst.getAndSet(false)) { - logger.info("Response: "); - } - if (accumulator.getDelta() != null && accumulator.getDelta().getTool_calls() != null) { - String jsonString = mapper.writeValueAsString(accumulator.getDelta().getTool_calls()); - logger.info("tool_calls: {}", jsonString); - } - if (accumulator.getDelta() != null && accumulator.getDelta().getContent() != null) { - logger.info(accumulator.getDelta().getContent()); - } - choices.add(accumulator.getChoice()); - } - }) - .doOnComplete(System.out::println) - .lastElement() - .blockingGet(); - - - ModelData data = new ModelData(); - data.setChoices(choices); - data.setUsage(chatMessageAccumulator.getUsage()); - data.setId(chatMessageAccumulator.getId()); - data.setCreated(chatMessageAccumulator.getCreated()); - data.setRequestId(chatCompletionRequest.getRequestId()); - sseModelApiResp.setFlowable(null);// 打印前置空 - sseModelApiResp.setData(data); - } - logger.info("model output: {}", mapper.writeValueAsString(sseModelApiResp)); - } - - - /** - * sse-V4:非function调用 - */ - @Test - public void testNonFunctionSSE() throws JsonProcessingException { - List messages = new ArrayList<>(); - ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "ChatGLM和你哪个更强大"); - messages.add(chatMessage); - HashMap extraJson = new HashMap<>(); - extraJson.put("temperature", 0.5); - extraJson.put("max_tokens", 3); - - String requestId = String.format(requestIdTemplate, System.currentTimeMillis()); - ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() - .model(Constants.ModelChatGLM4) - .stream(Boolean.TRUE) - .messages(messages) - .requestId(requestId) - .extraJson(extraJson) - .build(); - ModelApiResponse sseModelApiResp = client.invokeModelApi(chatCompletionRequest); - // stream 处理方法 - if (sseModelApiResp.isSuccess()) { - AtomicBoolean isFirst = new AtomicBoolean(true); - List choices = new ArrayList<>(); - ChatMessageAccumulator chatMessageAccumulator = mapStreamToAccumulator(sseModelApiResp.getFlowable()) - .doOnNext(accumulator -> { - { - if (isFirst.getAndSet(false)) { - logger.info("Response: "); - } - if (accumulator.getDelta() != null && accumulator.getDelta().getTool_calls() != null) { - String jsonString = mapper.writeValueAsString(accumulator.getDelta().getTool_calls()); - logger.info("tool_calls: {}", jsonString); - } - if (accumulator.getDelta() != null && accumulator.getDelta().getContent() != null) { - logger.info("accumulator.getDelta().getContent(): {}", accumulator.getDelta().getContent()); - } - choices.add(accumulator.getChoice()); - } - }) - .doOnComplete(System.out::println) - .lastElement() - .blockingGet(); - - - ModelData data = new ModelData(); - data.setChoices(choices); - data.setUsage(chatMessageAccumulator.getUsage()); - data.setId(chatMessageAccumulator.getId()); - data.setCreated(chatMessageAccumulator.getCreated()); - data.setRequestId(chatCompletionRequest.getRequestId()); - sseModelApiResp.setFlowable(null);// 打印前置空 - sseModelApiResp.setData(data); - } - logger.info("model output: {}", mapper.writeValueAsString(sseModelApiResp)); - } - - - /** - * V4-同步function调用 - */ - @Test - public void testFunctionInvoke() { - List messages = new ArrayList<>(); - ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "你可以做什么"); - messages.add(chatMessage); - String requestId = String.format(requestIdTemplate, System.currentTimeMillis()); - // 函数调用参数构建部分 - List chatToolList = new ArrayList<>(); - ChatTool chatTool = new ChatTool(); - chatTool.setType(ChatToolType.FUNCTION.value()); - ChatFunctionParameters chatFunctionParameters = new ChatFunctionParameters(); - chatFunctionParameters.setType("object"); - Map properties = new HashMap<>(); - properties.put("location", new HashMap() {{ - put("type", "string"); - put("description", "城市,如:北京"); - }}); - properties.put("unit", new HashMap() {{ - put("type", "string"); - put("enum", new ArrayList() {{ - add("celsius"); - add("fahrenheit"); - }}); - }}); - chatFunctionParameters.setProperties(properties); - ChatFunction chatFunction = ChatFunction.builder() - .name("get_weather") - .description("Get the current weather of a location") - .parameters(chatFunctionParameters) - .build(); - chatTool.setFunction(chatFunction); - - - ChatTool chatTool1 = new ChatTool(); - chatTool1.setType(ChatToolType.WEB_SEARCH.value()); - WebSearch webSearch = new WebSearch(); - webSearch.setSearch_query("清华的升学率"); - webSearch.setSearch_result(true); - webSearch.setEnable(false); - chatTool1.setWeb_search(webSearch); - - chatToolList.add(chatTool); - chatToolList.add(chatTool1); - - ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() - .model(Constants.ModelChatGLM4) - .stream(Boolean.FALSE) - .invokeMethod(Constants.invokeMethod) - .messages(messages) - .requestId(requestId) - .tools(chatToolList) - .toolChoice("auto") - .build(); - ModelApiResponse invokeModelApiResp = client.invokeModelApi(chatCompletionRequest); - try { - logger.info("model output: {}", mapper.writeValueAsString(invokeModelApiResp)); - } catch (JsonProcessingException e) { - logger.error("model output error", e); - } - } - - - /** - * V4-同步非function调用 - */ - @Test - public void testNonFunctionInvoke() throws JsonProcessingException { - List messages = new ArrayList<>(); - ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "ChatGLM和你哪个更强大"); - messages.add(chatMessage); - String requestId = String.format(requestIdTemplate, System.currentTimeMillis()); - - - HashMap extraJson = new HashMap<>(); - extraJson.put("temperature", 0.5); - extraJson.put("max_tokens", 3); - ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() - .model(Constants.ModelChatGLM4) - .stream(Boolean.FALSE) - .invokeMethod(Constants.invokeMethod) - .messages(messages) - .requestId(requestId) - .extraJson(extraJson) - .build(); - ModelApiResponse invokeModelApiResp = client.invokeModelApi(chatCompletionRequest); - logger.info("model output: {}", mapper.writeValueAsString(invokeModelApiResp)); - } - - - /** - * V4-同步非function调用 - */ - @Test - public void testCharGlmInvoke() throws JsonProcessingException { - List messages = new ArrayList<>(); - ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "ChatGLM和你哪个更强大"); - messages.add(chatMessage); - String requestId = String.format(requestIdTemplate, System.currentTimeMillis()); - - - HashMap extraJson = new HashMap<>(); - extraJson.put("temperature", 0.5); - - ChatMeta meta = new ChatMeta(); - meta.setUser_info("我是陆星辰,是一个男性,是一位知名导演,也是苏梦远的合作导演。我擅长拍摄音乐题材的电影。苏梦远对我的态度是尊敬的,并视我为良师益友。"); - meta.setBot_info("苏梦远,本名苏远心,是一位当红的国内女歌手及演员。在参加选秀节目后,凭借独特的嗓音及出众的舞台魅力迅速成名,进入娱乐圈。她外表美丽动人,但真正的魅力在于她的才华和勤奋。苏梦远是音乐学院毕业的优秀生,善于创作,拥有多首热门原创歌曲。除了音乐方面的成就,她还热衷于慈善事业,积极参加公益活动,用实际行动传递正能量。在工作中,她对待工作非常敬业,拍戏时总是全身心投入角色,赢得了业内人士的赞誉和粉丝的喜爱。虽然在娱乐圈,但她始终保持低调、谦逊的态度,深得同行尊重。在表达时,苏梦远喜欢使用“我们”和“一起”,强调团队精神。"); - meta.setBot_name("苏梦远"); - meta.setUser_name("陆星辰"); - - ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() - .model(Constants.ModelCharGLM3) - .stream(Boolean.FALSE) - .invokeMethod(Constants.invokeMethod) - .messages(messages) - .requestId(requestId) - .meta(meta) - .extraJson(extraJson) - .build(); - ModelApiResponse invokeModelApiResp = client.invokeModelApi(chatCompletionRequest); - logger.info("model output: {}", mapper.writeValueAsString(invokeModelApiResp)); - } - - /** - * V4异步调用 - */ - @Test - public void testAsyncInvoke() throws JsonProcessingException { - String taskId = getAsyncTaskId(); - testQueryResult(taskId); - } - -// - - /** - * 文生图 - */ - @Test - public void testCreateImage() throws JsonProcessingException { - CreateImageRequest createImageRequest = new CreateImageRequest(); - createImageRequest.setModel(Constants.ModelCogView); - createImageRequest.setPrompt("Futuristic cloud data center, showcasing advanced technologgy and a high-tech atmosp\n" + - "here. The image should depict a spacious, well-lit interior with rows of server racks, glo\n" + - "wing lights, and digital displays. Include abstract representattions of data streams and\n" + - "onnectivity, symbolizing the essence of cloud computing. Thee style should be modern a\n" + - "nd sleek, with a focus on creating a sense of innovaticon and cutting-edge technology\n" + - "The overall ambiance should convey the power and effciency of cloud services in a visu\n" + - "ally engaging way."); - createImageRequest.setRequestId("test11111111111111"); - ImageApiResponse imageApiResponse = client.createImage(createImageRequest); - logger.info("imageApiResponse: {}", mapper.writeValueAsString(imageApiResponse)); - } - -// -// /** -// * 图生文 -// */ -// @Test -// public void testImageToWord() throws JsonProcessingException { -// List messages = new ArrayList<>(); -// List> contentList = new ArrayList<>(); -// Map textMap = new HashMap<>(); -// textMap.put("type", "text"); -// textMap.put("text", "图里有什么"); -// Map typeMap = new HashMap<>(); -// typeMap.put("type", "image_url"); -// Map urlMap = new HashMap<>(); -// urlMap.put("url", "https://sfile.chatglm.cn/testpath/275ae5b6-5390-51ca-a81a-60332d1a7cac_0.png"); -// typeMap.put("image_url", urlMap); -// contentList.add(textMap); -// contentList.add(typeMap); -// ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), contentList); -// messages.add(chatMessage); -// String requestId = String.format(requestIdTemplate, System.currentTimeMillis()); -// -// -// ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() -// .model(Constants.ModelChatGLM4V) -// .stream(Boolean.FALSE) -// .invokeMethod(Constants.invokeMethod) -// .messages(messages) -// .requestId(requestId) -// .build(); -// ModelApiResponse modelApiResponse = client.invokeModelApi(chatCompletionRequest); -// logger.info("model output: {}", mapper.writeValueAsString(modelApiResponse)); -// } -// - - /** - * 向量模型V4 - */ - @Test - public void testEmbeddings() throws JsonProcessingException { - EmbeddingRequest embeddingRequest = new EmbeddingRequest(); - embeddingRequest.setInput("hello world"); - embeddingRequest.setModel(Constants.ModelEmbedding2); - EmbeddingApiResponse apiResponse = client.invokeEmbeddingsApi(embeddingRequest); - logger.info("model output: {}", mapper.writeValueAsString(apiResponse)); - } - - - /** - * V4微调上传数据集 - */ - @Test - public void testUploadFile() throws JsonProcessingException { - String filePath = "demo.jsonl"; - - String path = ClassLoader.getSystemResource(filePath).getPath(); - String purpose = "fine-tune"; - UploadFileRequest request = UploadFileRequest.builder() - .purpose(purpose) - .filePath(path) - .build(); - - FileApiResponse fileApiResponse = client.invokeUploadFileApi(request); - logger.info("model output: {}", mapper.writeValueAsString(fileApiResponse)); - } - - - /** - * 微调V4-查询上传文件列表 - */ - @Test - public void testQueryUploadFileList() throws JsonProcessingException { - QueryFilesRequest queryFilesRequest = new QueryFilesRequest(); - QueryFileApiResponse queryFileApiResponse = client.queryFilesApi(queryFilesRequest); - logger.info("model output: {}", mapper.writeValueAsString(queryFileApiResponse)); - } - - @Test - public void testFileContent() throws IOException { - try { - - HttpxBinaryResponseContent httpxBinaryResponseContent = client.fileContent("20240514_ea19d21b-d256-4586-b0df-e80a45e3c286"); - String filePath = "demo_output.jsonl"; - String resourcePath = V4Test.class.getClassLoader().getResource("").getPath(); - - httpxBinaryResponseContent.streamToFile(resourcePath + "1" + filePath, 1000); - - } catch (IOException e) { - logger.error("file content error", e); - } - } - -//// @Test -//// public void deletedFile() throws IOException { -//// FileDelResponse fileDelResponse = client.deletedFile("20240514_ea19d21b-d256-4586-b0df-e80a45e3c286"); -//// -//// logger.info("model output: {}", mapper.writeValueAsString(fileDelResponse)); -//// -//// } -// -// - - /** - * 微调V4-创建微调任务 - */ - @Test - public void testCreateFineTuningJob() throws JsonProcessingException { - FineTuningJobRequest request = new FineTuningJobRequest(); - String requestId = String.format(requestIdTemplate, System.currentTimeMillis()); - request.setRequestId(requestId); - request.setModel("chatglm3-6b"); - request.setTraining_file("file-20240118082608327-kp8qr"); - CreateFineTuningJobApiResponse createFineTuningJobApiResponse = client.createFineTuningJob(request); - logger.info("model output: {}", mapper.writeValueAsString(createFineTuningJobApiResponse)); - } - - - /** - * 微调V4-查询微调任务 - */ - @Test - public void testRetrieveFineTuningJobs() throws JsonProcessingException { - QueryFineTuningJobRequest queryFineTuningJobRequest = new QueryFineTuningJobRequest(); - queryFineTuningJobRequest.setJobId("ftjob-20240429172916475-fb7r9"); -// queryFineTuningJobRequest.setLimit(1); -// queryFineTuningJobRequest.setAfter(1); - QueryFineTuningJobApiResponse queryFineTuningJobApiResponse = client.retrieveFineTuningJobs(queryFineTuningJobRequest); - logger.info("model output: {}", mapper.writeValueAsString(queryFineTuningJobApiResponse)); - } - - - /** - * 微调V4-查询微调任务 - */ - @Test - public void testFueryFineTuningJobsEvents() throws JsonProcessingException { - QueryFineTuningJobRequest queryFineTuningJobRequest = new QueryFineTuningJobRequest(); - queryFineTuningJobRequest.setJobId("ftjob-20240429172916475-fb7r9"); - - QueryFineTuningEventApiResponse queryFineTuningEventApiResponse = client.queryFineTuningJobsEvents(queryFineTuningJobRequest); - logger.info("model output: {}", mapper.writeValueAsString(queryFineTuningEventApiResponse)); - } - - - /** - * testQueryPersonalFineTuningJobs V4-查询个人微调任务 - */ - @Test - public void testQueryPersonalFineTuningJobs() throws JsonProcessingException { - QueryPersonalFineTuningJobRequest queryPersonalFineTuningJobRequest = new QueryPersonalFineTuningJobRequest(); - queryPersonalFineTuningJobRequest.setLimit(1); - QueryPersonalFineTuningJobApiResponse queryPersonalFineTuningJobApiResponse = client.queryPersonalFineTuningJobs(queryPersonalFineTuningJobRequest); - logger.info("model output: {}", mapper.writeValueAsString(queryPersonalFineTuningJobApiResponse)); - } - - - @Test - public void testBatchesCreate() { - BatchCreateParams batchCreateParams = new BatchCreateParams( - "24h", - "/v4/chat/completions", - "20240514_ea19d21b-d256-4586-b0df-e80a45e3c286", - new HashMap() {{ - put("key1", "value1"); - put("key2", "value2"); - }} - ); - - BatchResponse batchResponse = client.batchesCreate(batchCreateParams); - logger.info("output: {}", batchResponse); -// output: BatchResponse(code=200, msg=调用成功, success=true, data=Batch(id=batch_1791021399316246528, completionWindow=24h, createdAt=1715847751822, endpoint=/v4/chat/completions, inputFileId=20240514_ea19d21b-d256-4586-b0df-e80a45e3c286, object=batch, status=validating, cancelledAt=null, cancellingAt=null, completedAt=null, errorFileId=null, errors=null, expiredAt=null, expiresAt=null, failedAt=null, finalizingAt=null, inProgressAt=null, metadata={key1=value1, key2=value2}, outputFileId=null, requestCounts=BatchRequestCounts(completed=0, failed=0, total=0), error=null)) - } - - @Test - public void testDeleteFineTuningJob() { - FineTuningJobIdRequest request = FineTuningJobIdRequest.builder().jobId("test").build(); - QueryFineTuningJobApiResponse queryFineTuningJobApiResponse = client.deleteFineTuningJob(request); - logger.info("output: {}", queryFineTuningJobApiResponse); - - } - - @Test - public void testCancelFineTuningJob() { - FineTuningJobIdRequest request = FineTuningJobIdRequest.builder().jobId("test").build(); - QueryFineTuningJobApiResponse queryFineTuningJobApiResponse = client.cancelFineTuningJob(request); - logger.info("output: {}", queryFineTuningJobApiResponse); - - } - - @Test - public void testBatchesRetrieve() { - BatchResponse batchResponse = client.batchesRetrieve("batch_1791021399316246528"); - logger.info("output: {}", batchResponse); - - } - - @Test - public void testDeleteFineTuningModel() { - FineTuningJobModelRequest request = FineTuningJobModelRequest.builder().fineTunedModel("test").build(); - - FineTunedModelsStatusResponse fineTunedModelsStatusResponse = client.deleteFineTuningModel(request); - logger.info("output: {}", fineTunedModelsStatusResponse); -// output: BatchResponse(code=200, msg=调用成功, success=true, data=Batch(id=batch_1791021399316246528, completionWindow=24h, createdAt=1715847752000, endpoint=/v4/chat/completions, inputFileId=20240514_ea19d21b-d256-4586-b0df-e80a45e3c286, object=batch, status=validating, cancelledAt=null, cancellingAt=null, completedAt=null, errorFileId=, errors=null, expiredAt=null, expiresAt=null, failedAt=null, finalizingAt=null, inProgressAt=null, metadata={key1=value1, key2=value2}, outputFileId=, requestCounts=BatchRequestCounts(completed=0, failed=0, total=0), error=null)) - - } - - @Test - public void testBatchesList() { - QueryBatchRequest queryBatchRequest = new QueryBatchRequest(); - queryBatchRequest.setLimit(10); - QueryBatchResponse queryBatchResponse = client.batchesList(queryBatchRequest); - logger.info("output: {}", queryBatchResponse); -// output: QueryBatchResponse(code=200, msg=调用成功, success=true, data=BatchPage(object=list, data=[Batch(id=batch_1790291013237211136, completionWindow=24h, createdAt=1715673614000, endpoint=/v4/chat/completions, inputFileId=20240514_ea19d21b-d256-4586-b0df-e80a45e3c286, object=batch, status=completed, cancelledAt=null, cancellingAt=1715673699000, completedAt=null, errorFileId=, errors=null, expiredAt=null, expiresAt=null, failedAt=null, finalizingAt=null, inProgressAt=null, metadata={description=job test}, outputFileId=, requestCounts=BatchRequestCounts(completed=0, failed=0, total=0), error=null), Batch(id=batch_1790292763050508288, completionWindow=24h, createdAt=1715674031000, endpoint=/v4/chat/completions, inputFileId=20240514_ea19d21b-d256-4586-b0df-e80a45e3c286, object=batch, status=completed, cancelledAt=null, cancellingAt=null, completedAt=1715766416000, errorFileId=, errors=null, expiredAt=null, expiresAt=null, failedAt=null, finalizingAt=1715754569000, inProgressAt=null, metadata={description=job test}, outputFileId=1715766415_e5a77222855a406ca8a082de28549c99, requestCounts=BatchRequestCounts(completed=2, failed=0, total=2), error=null), Batch(id=batch_1791021114887909376, completionWindow=24h, createdAt=1715847684000, endpoint=/v4/chat/completions, inputFileId=20240514_ea19d21b-d256-4586-b0df-e80a45e3c286, object=batch, status=validating, cancelledAt=null, cancellingAt=null, completedAt=null, errorFileId=, errors=null, expiredAt=null, expiresAt=null, failedAt=null, finalizingAt=null, inProgressAt=null, metadata={key1=value1, key2=value2}, outputFileId=, requestCounts=BatchRequestCounts(completed=0, failed=0, total=0), error=null), Batch(id=batch_1791021399316246528, completionWindow=24h, createdAt=1715847752000, endpoint=/v4/chat/completions, inputFileId=20240514_ea19d21b-d256-4586-b0df-e80a45e3c286, object=batch, status=validating, cancelledAt=null, cancellingAt=null, completedAt=null, errorFileId=, errors=null, expiredAt=null, expiresAt=null, failedAt=null, finalizingAt=null, inProgressAt=null, metadata={key1=value1, key2=value2}, outputFileId=, requestCounts=BatchRequestCounts(completed=0, failed=0, total=0), error=null)], error=null)) - - } - - @Test - public void testBatchesCancel() throws JsonProcessingException { - getAsyncTaskId(); - } - - private static String getAsyncTaskId() throws JsonProcessingException { - List messages = new ArrayList<>(); - ChatMessage chatMessage = new ChatMessage(ChatMessageRole.USER.value(), "ChatGLM和你哪个更强大"); - messages.add(chatMessage); - String requestId = String.format(requestIdTemplate, System.currentTimeMillis()); - // 函数调用参数构建部分 - List chatToolList = new ArrayList<>(); - ChatTool chatTool = new ChatTool(); - chatTool.setType(ChatToolType.FUNCTION.value()); - ChatFunctionParameters chatFunctionParameters = new ChatFunctionParameters(); - chatFunctionParameters.setType("object"); - Map properties = new HashMap<>(); - properties.put("location", new HashMap() {{ - put("type", "string"); - put("description", "城市,如:北京"); - }}); - properties.put("unit", new HashMap() {{ - put("type", "string"); - put("enum", new ArrayList() {{ - add("celsius"); - add("fahrenheit"); - }}); - }}); - chatFunctionParameters.setProperties(properties); - ChatFunction chatFunction = ChatFunction.builder() - .name("get_weather") - .description("Get the current weather of a location") - .parameters(chatFunctionParameters) - .build(); - chatTool.setFunction(chatFunction); - chatToolList.add(chatTool); - ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() - .model(Constants.ModelChatGLM4) - .stream(Boolean.FALSE) - .invokeMethod(Constants.invokeMethodAsync) - .messages(messages) - .requestId(requestId) - .tools(chatToolList) - .toolChoice("auto") - .build(); - ModelApiResponse invokeModelApiResp = client.invokeModelApi(chatCompletionRequest); - logger.info("model output: {}", mapper.writeValueAsString(invokeModelApiResp)); - return invokeModelApiResp.getData().getId(); - } - - - private static void testQueryResult(String taskId) throws JsonProcessingException { - QueryModelResultRequest request = new QueryModelResultRequest(); - request.setTaskId(taskId); - QueryModelResultResponse queryResultResp = client.queryModelResult(request); - logger.info("model output {}", mapper.writeValueAsString(queryResultResp)); - } - - public static Flowable mapStreamToAccumulator(Flowable flowable) { - return flowable.map(chunk -> { - return new ChatMessageAccumulator(chunk.getChoices().get(0).getDelta(), null, chunk.getChoices().get(0), chunk.getUsage(), chunk.getCreated(), chunk.getId()); - }); - } -} diff --git a/ruoyi-modules/ruoyi-fusion/src/main/java/org/ruoyi/fusion/controller/ChatController.java b/ruoyi-modules/ruoyi-fusion/src/main/java/org/ruoyi/fusion/controller/ChatController.java index dc5592f3..713c6f1b 100644 --- a/ruoyi-modules/ruoyi-fusion/src/main/java/org/ruoyi/fusion/controller/ChatController.java +++ b/ruoyi-modules/ruoyi-fusion/src/main/java/org/ruoyi/fusion/controller/ChatController.java @@ -17,7 +17,6 @@ import org.ruoyi.common.core.exception.base.BaseException; import org.ruoyi.common.mybatis.core.page.PageQuery; import org.ruoyi.common.mybatis.core.page.TableDataInfo; import org.ruoyi.common.satoken.utils.LoginHelper; -import org.ruoyi.knowledge.service.EmbeddingService; import org.ruoyi.system.domain.bo.ChatMessageBo; import org.ruoyi.system.domain.request.translation.TranslationRequest; import org.ruoyi.system.domain.vo.ChatMessageVo; @@ -48,7 +47,6 @@ public class ChatController { private final IChatMessageService chatMessageService; - private final EmbeddingService embeddingService; /** * 聊天接口 */ diff --git a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/split/CharacterTextSplitter.java b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/split/CharacterTextSplitter.java index cf86b34a..4b8f4af1 100644 --- a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/split/CharacterTextSplitter.java +++ b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/split/CharacterTextSplitter.java @@ -2,6 +2,7 @@ package org.ruoyi.knowledge.chain.split; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; +import org.ruoyi.common.core.utils.StringUtils; import org.ruoyi.knowledge.domain.vo.KnowledgeInfoVo; import org.ruoyi.knowledge.service.IKnowledgeInfoService; import org.springframework.context.annotation.Lazy; @@ -29,7 +30,7 @@ public class CharacterTextSplitter implements TextSplitter { int textBlockSize = knowledgeInfoVo.getTextBlockSize(); int overlapChar = knowledgeInfoVo.getOverlapChar(); List chunkList = new ArrayList<>(); - if (content.contains(knowledgeSeparator)) { + if (content.contains(knowledgeSeparator) && StringUtils.isNotBlank(knowledgeSeparator)) { // 按自定义分隔符切分 String[] chunks = content.split(knowledgeSeparator); chunkList.addAll(Arrays.asList(chunks)); diff --git a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorstore/VectorStoreFactory.java b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorstore/VectorStoreFactory.java index 5acf28b6..4471fca9 100644 --- a/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorstore/VectorStoreFactory.java +++ b/ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorstore/VectorStoreFactory.java @@ -1,5 +1,6 @@ package org.ruoyi.knowledge.chain.vectorstore; +import cn.hutool.core.util.StrUtil; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.ruoyi.knowledge.domain.vo.KnowledgeInfoVo; @@ -23,8 +24,13 @@ public class VectorStoreFactory { } public VectorStore getVectorStore(String kid){ - KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoMapper.selectVoById(Long.valueOf(kid)); - String vectorModel = knowledgeInfoVo.getVector(); + String vectorModel = "weaviate"; + if (StrUtil.isNotEmpty(kid)) { + KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoMapper.selectVoById(Long.valueOf(kid)); + if (knowledgeInfoVo != null && StrUtil.isNotEmpty(knowledgeInfoVo.getVector())) { + vectorModel = knowledgeInfoVo.getVector(); + } + } if ("weaviate".equals(vectorModel)){ return weaviateVectorStore; }else if ("milvus".equals(vectorModel)){ diff --git a/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/plugin/WebSearchPlugin.java b/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/plugin/WebSearchPlugin.java deleted file mode 100644 index 63b290a3..00000000 --- a/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/plugin/WebSearchPlugin.java +++ /dev/null @@ -1,212 +0,0 @@ -package org.ruoyi.system.plugin; - - -import cn.hutool.json.JSONUtil; -import com.alibaba.fastjson.JSONObject; -import lombok.Builder; -import lombok.Data; -import okhttp3.OkHttpClient; -import okhttp3.logging.HttpLoggingInterceptor; -import org.junit.Before; -import org.junit.Test; -import org.ruoyi.common.chat.demo.ConsoleEventSourceListenerV3; -import org.ruoyi.common.chat.entity.chat.ChatCompletion; -import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse; -import org.ruoyi.common.chat.entity.chat.Message; -import org.ruoyi.common.chat.entity.chat.Parameters; -import org.ruoyi.common.chat.entity.chat.tool.ToolCallFunction; -import org.ruoyi.common.chat.entity.chat.tool.ToolCalls; -import org.ruoyi.common.chat.entity.chat.tool.Tools; -import org.ruoyi.common.chat.entity.chat.tool.ToolsFunction; -import org.ruoyi.common.chat.openai.OpenAiClient; -import org.ruoyi.common.chat.openai.OpenAiStreamClient; -import org.ruoyi.common.chat.openai.function.KeyRandomStrategy; -import org.ruoyi.common.chat.openai.interceptor.DynamicKeyOpenAiAuthInterceptor; -import org.ruoyi.common.chat.openai.interceptor.OpenAILogger; -import org.ruoyi.common.chat.openai.interceptor.OpenAiResponseInterceptor; - -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; - -public class WebSearchPlugin { - - private OpenAiClient openAiClient; - private OpenAiStreamClient openAiStreamClient; - - @Before - public void before() { - //可以为null -// Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890)); - HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor(new OpenAILogger()); - //!!!!千万别再生产或者测试环境打开BODY级别日志!!!! - //!!!生产或者测试环境建议设置为这三种级别:NONE,BASIC,HEADERS,!!! - httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS); - OkHttpClient okHttpClient = new OkHttpClient - .Builder() -// .proxy(proxy) - .addInterceptor(httpLoggingInterceptor) - .addInterceptor(new OpenAiResponseInterceptor()) - .connectTimeout(10, TimeUnit.SECONDS) - .writeTimeout(30, TimeUnit.SECONDS) - .readTimeout(30, TimeUnit.SECONDS) - .build(); - openAiClient = OpenAiClient.builder() - //支持多key传入,请求时候随机选择 - .apiKey(Arrays.asList("xx")) - //自定义key的获取策略:默认KeyRandomStrategy - //.keyStrategy(new KeyRandomStrategy()) - .keyStrategy(new KeyRandomStrategy()) - .okHttpClient(okHttpClient) - //自己做了代理就传代理地址,没有可不不传,(关注公众号回复:openai ,获取免费的测试代理地址) - .apiHost("https://open.bigmodel.cn/") - .build(); - - openAiStreamClient = OpenAiStreamClient.builder() - //支持多key传入,请求时候随机选择 - .apiKey(Arrays.asList("xx")) - //自定义key的获取策略:默认KeyRandomStrategy - .keyStrategy(new KeyRandomStrategy()) - .authInterceptor(new DynamicKeyOpenAiAuthInterceptor()) - .okHttpClient(okHttpClient) - //自己做了代理就传代理地址,没有可不不传,(关注公众号回复:openai ,获取免费的测试代理地址) - .apiHost("https://open.bigmodel.cn/") - .build(); - } - - @Test - public void test() { - Message message = Message.builder().role(Message.Role.USER).content("今天武汉天气怎么样").build(); - ChatCompletion chatCompletion = ChatCompletion - .builder() - .messages(Collections.singletonList(message)) -// .tools(Collections.singletonList(tools)) - .model("web-search-pro") - .build(); - ChatCompletionResponse chatCompletionResponse = openAiStreamClient.chatCompletion(chatCompletion); - - System.out.printf("chatCompletionResponse=%s\n", JSONUtil.toJsonStr(chatCompletionResponse)); - } - - @Test - public void streamToolsChat() { - - CountDownLatch countDownLatch = new CountDownLatch(1); - ConsoleEventSourceListenerV3 eventSourceListener = new ConsoleEventSourceListenerV3(countDownLatch); - - Message message = Message.builder().role(Message.Role.USER).content("给我输出一个长度为2的中文词语,并解释下词语对应物品的用途").build(); - //属性一 - JSONObject wordLength = new JSONObject(); - wordLength.put("type", "number"); - wordLength.put("description", "词语的长度"); - //属性二 - JSONObject language = new JSONObject(); - language.put("type", "string"); - language.put("enum", Arrays.asList("zh", "en")); - language.put("description", "语言类型,例如:zh代表中文、en代表英语"); - //参数 - JSONObject properties = new JSONObject(); - properties.put("wordLength", wordLength); - properties.put("language", language); - Parameters parameters = Parameters.builder() - .type("object") - .properties(properties) - .required(Collections.singletonList("wordLength")).build(); - Tools tools = Tools.builder() - .type(Tools.Type.FUNCTION.getName()) - .function(ToolsFunction.builder().name("getOneWord").description("获取一个指定长度和语言类型的词语").parameters(parameters).build()) - .build(); - - ChatCompletion chatCompletion = ChatCompletion - .builder() - .messages(Collections.singletonList(message)) - .tools(Collections.singletonList(tools)) - .model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName()) - .build(); - openAiStreamClient.streamChatCompletion(chatCompletion, eventSourceListener); - - try { - countDownLatch.await(); - } catch (InterruptedException e) { - e.printStackTrace(); - } - - ToolCalls openAiReturnToolCalls = eventSourceListener.getToolCalls(); - WordParam wordParam = JSONUtil.toBean(openAiReturnToolCalls.getFunction().getArguments(), WordParam.class); - String oneWord = getOneWord(wordParam); - - - ToolCallFunction tcf = ToolCallFunction.builder().name("getOneWord").arguments(openAiReturnToolCalls.getFunction().getArguments()).build(); - ToolCalls tc = ToolCalls.builder().id(openAiReturnToolCalls.getId()).type(ToolCalls.Type.FUNCTION.getName()).function(tcf).build(); - //构造tool call - Message message2 = Message.builder().role(Message.Role.ASSISTANT).content("方法参数").toolCalls(Collections.singletonList(tc)).build(); - String content - = "{ " + - "\"wordLength\": \"3\", " + - "\"language\": \"zh\", " + - "\"word\": \"" + oneWord + "\"," + - "\"用途\": [\"直接吃\", \"做沙拉\", \"售卖\"]" + - "}"; - Message message3 = Message.builder().toolCallId(openAiReturnToolCalls.getId()).role(Message.Role.TOOL).name("getOneWord").content(content).build(); - List messageList = Arrays.asList(message, message2, message3); - ChatCompletion chatCompletionV2 = ChatCompletion - .builder() - .messages(messageList) - .model(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName()) - .build(); - - - CountDownLatch countDownLatch1 = new CountDownLatch(1); - openAiStreamClient.streamChatCompletion(chatCompletionV2, new ConsoleEventSourceListenerV3(countDownLatch)); - try { - countDownLatch1.await(); - } catch (InterruptedException e) { - e.printStackTrace(); - } - try { - countDownLatch1.await(); - } catch (InterruptedException e) { - e.printStackTrace(); - } - - } - - @Data - @Builder - static class WordParam { - private int wordLength; - @Builder.Default - private String language = "zh"; - } - - - /** - * 获取一个词语(根据语言和字符长度查询) - * @param wordParam - * @return - */ - public String getOneWord(WordParam wordParam) { - - List zh = Arrays.asList("大香蕉", "哈密瓜", "苹果"); - List en = Arrays.asList("apple", "banana", "cantaloupe"); - if (wordParam.getLanguage().equals("zh")) { - for (String e : zh) { - if (e.length() == wordParam.getWordLength()) { - return e; - } - } - } - if (wordParam.getLanguage().equals("en")) { - for (String e : en) { - if (e.length() == wordParam.getWordLength()) { - return e; - } - } - } - return "西瓜"; - } - - -} diff --git a/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/service/ISysModelService.java b/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/service/ISysModelService.java index 3e0a6c36..ac5182c8 100644 --- a/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/service/ISysModelService.java +++ b/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/service/ISysModelService.java @@ -2,6 +2,7 @@ package org.ruoyi.system.service; import org.ruoyi.common.mybatis.core.page.PageQuery; import org.ruoyi.common.mybatis.core.page.TableDataInfo; +import org.ruoyi.system.domain.SysModel; import org.ruoyi.system.domain.bo.SysModelBo; import org.ruoyi.system.domain.vo.SysModelVo; @@ -45,4 +46,9 @@ public interface ISysModelService { * 校验并批量删除系统模型信息 */ Boolean deleteWithValidByIds(Collection ids, Boolean isValid); + + /** + * 根据模型名称查询模型 + */ + SysModel selectModelByName(String modelName); } diff --git a/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/service/impl/SseServiceImpl.java b/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/service/impl/SseServiceImpl.java index a2cfeffe..c822e3d2 100644 --- a/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/service/impl/SseServiceImpl.java +++ b/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/service/impl/SseServiceImpl.java @@ -1,8 +1,10 @@ package org.ruoyi.system.service.impl; import cn.dev33.satoken.stp.StpUtil; -import cn.hutool.core.collection.CollectionUtil; import com.alibaba.fastjson.JSONObject; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.zhipu.oapi.ClientV4; +import com.zhipu.oapi.service.v4.tools.*; import io.github.ollama4j.OllamaAPI; import io.github.ollama4j.models.chat.OllamaChatMessageRole; import io.github.ollama4j.models.chat.OllamaChatRequestBuilder; @@ -17,10 +19,7 @@ import org.ruoyi.common.chat.config.LocalCache; import org.ruoyi.common.chat.domain.request.ChatRequest; import org.ruoyi.common.chat.domain.request.Dall3Request; import org.ruoyi.common.chat.entity.Tts.TextToSpeech; -import org.ruoyi.common.chat.entity.chat.ChatCompletion; -import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse; -import org.ruoyi.common.chat.entity.chat.Content; -import org.ruoyi.common.chat.entity.chat.Message; +import org.ruoyi.common.chat.entity.chat.*; import org.ruoyi.common.chat.entity.files.UploadFileResponse; import org.ruoyi.common.chat.entity.images.Image; import org.ruoyi.common.chat.entity.images.ImageResponse; @@ -33,17 +32,15 @@ import org.ruoyi.common.chat.plugin.CmdPlugin; import org.ruoyi.common.chat.plugin.CmdReq; import org.ruoyi.common.chat.plugin.SqlPlugin; import org.ruoyi.common.chat.plugin.SqlReq; -import org.ruoyi.common.chat.sse.ConsoleEventSourceListener; import org.ruoyi.common.chat.utils.TikTokensUtil; import org.ruoyi.common.core.domain.model.LoginUser; import org.ruoyi.common.core.exception.base.BaseException; import org.ruoyi.common.core.service.ConfigService; import org.ruoyi.common.core.utils.StringUtils; import org.ruoyi.common.satoken.utils.LoginHelper; +import org.ruoyi.system.domain.SysModel; import org.ruoyi.system.domain.bo.ChatMessageBo; -import org.ruoyi.system.domain.bo.SysModelBo; import org.ruoyi.system.domain.request.translation.TranslationRequest; -import org.ruoyi.system.domain.vo.SysModelVo; import org.ruoyi.system.listener.SSEEventSourceListener; import org.ruoyi.system.service.*; import org.springframework.core.io.InputStreamResource; @@ -65,6 +62,9 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; @Service @@ -76,18 +76,21 @@ public class SseServiceImpl implements ISseService { private final ChatConfig chatConfig; + private final IChatCostService chatService; private final IChatMessageService chatMessageService; private final ISysModelService sysModelService; - private final ISysUserService userService; - private final ConfigService configService; static final OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().build(); + private static final String requestIdTemplate = "mycompany-%d"; + + private static final ObjectMapper mapper = new ObjectMapper(); + @Override public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) { openAiStreamClient = chatConfig.getOpenAiStreamClient(); @@ -96,11 +99,10 @@ public class SseServiceImpl implements ISseService { // 获取对话消息列表 List messages = chatRequest.getMessages(); try { + String chatString = null; if (StpUtil.isLogin()) { LocalCache.CACHE.put("userId", getUserId()); Object content = messages.get(messages.size() - 1).getContent(); - - String chatString = ""; if (content instanceof List listContent) { if (!listContent.isEmpty() && listContent.get(0) instanceof Content) { chatString = ((Content) listContent.get(0)).getText(); @@ -123,39 +125,89 @@ public class SseServiceImpl implements ISseService { throw new BaseException("文本不合规,请修改!"); } } - //根据模型名称查询模型信息 - SysModelBo sysModelBo = new SysModelBo(); + String model = chatRequest.getModel(); // 如果是gpts系列模型 if (chatRequest.getModel().startsWith("gpt-4-gizmo")) { - sysModelBo.setModelName("gpt-4-gizmo"); - } else { - sysModelBo.setModelName(chatRequest.getModel()); + model = "gpt-4-gizmo"; } - List sysModelList = sysModelService.queryList(sysModelBo); - - if (CollectionUtil.isEmpty(sysModelList)) { + SysModel sysModel = sysModelService.selectModelByName(model); + if (sysModel != null) { // 如果模型不存在默认使用token扣费方式 processByToken(chatRequest.getModel(), chatString, chatMessageBo); } else { - openAiStreamClient = chatConfig.createOpenAiStreamClient(sysModelList.get(0).getApiHost(), sysModelList.get(0).getApiKey()); + openAiStreamClient = chatConfig.createOpenAiStreamClient(sysModel.getApiHost(), sysModel.getApiKey()); // 模型设置默认提示词 - SysModelVo firstModel = sysModelList.get(0); - if (StringUtils.isNotEmpty(firstModel.getSystemPrompt())) { - Message sysMessage = Message.builder().content(firstModel.getSystemPrompt()).role(Message.Role.SYSTEM).build(); + + if (StringUtils.isNotEmpty(sysModel.getSystemPrompt())) { + Message sysMessage = Message.builder().content(sysModel.getSystemPrompt()).role(Message.Role.SYSTEM).build(); messages.add(sysMessage); } // 计费类型: 1 token扣费 2 次数扣费 - if ("2".equals(firstModel.getModelType())) { - processByModelPrice(firstModel, chatMessageBo); + if ("2".equals(sysModel.getModelType())) { + processByModelPrice(sysModel, chatMessageBo); } else { - processByToken(chatRequest.getModel(), chatString, chatMessageBo); + processByToken(chatRequest.getModel(), chatString, chatMessageBo); } } } - if("openCmd".equals(chatRequest.getModel())) { + String configValue = configService.getConfigValue("zhipu", "key"); + // 添加联网信息 + if(StringUtils.isNotEmpty(configValue)){ + ClientV4 client = new ClientV4.Builder(configValue) + .networkConfig(300, 100, 100, 100, TimeUnit.SECONDS) + .connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS)) + .build(); + + SearchChatMessage jsonNodes = new SearchChatMessage(); + jsonNodes.setRole(Message.Role.USER.getName()); + jsonNodes.setContent(chatString); + + String requestId = String.format(requestIdTemplate, System.currentTimeMillis()); + WebSearchParamsRequest chatCompletionRequest = WebSearchParamsRequest.builder() + .model("web-search-pro") + .stream(Boolean.TRUE) + .messages(Collections.singletonList(jsonNodes)) + .requestId(requestId) + .build(); + WebSearchApiResponse webSearchApiResponse = client.webSearchProStreamingInvoke(chatCompletionRequest); + List choices = new ArrayList<>(); + if (webSearchApiResponse.isSuccess()) { + AtomicBoolean isFirst = new AtomicBoolean(true); + + AtomicReference lastAccumulator = new AtomicReference<>(); + + webSearchApiResponse.getFlowable().map(result -> result) + .doOnNext(accumulator -> { + { + if (isFirst.getAndSet(false)) { + log.info("Response: "); + } + ChoiceDelta delta = accumulator.getChoices().get(0).getDelta(); + if (delta != null && delta.getToolCalls() != null) { + log.info("tool_calls: {}", mapper.writeValueAsString(delta.getToolCalls())); + } + choices.add(delta); + } + }) + .doOnComplete(() -> System.out.println("Stream completed.")) + .doOnError(throwable -> System.err.println("Error: " + throwable)) + .blockingSubscribe(); + + WebSearchPro chatMessageAccumulator = lastAccumulator.get(); + + webSearchApiResponse.setFlowable(null);// 打印前置空 + webSearchApiResponse.setData(chatMessageAccumulator); + } + + + Message message = Message.builder().role(Message.Role.ASSISTANT).content(choices.get(1).getToolCalls().toString()).build(); + messages.add(message); + } + + if ("openCmd".equals(chatRequest.getModel())) { sseEmitter.send(cmdPlugin(messages)); sseEmitter.complete(); - }else if ("sqlPlugin".equals(chatRequest.getModel())){ + } else if ("sqlPlugin".equals(chatRequest.getModel())) { sseEmitter.send(sqlPlugin(messages)); sseEmitter.complete(); } else { @@ -229,7 +281,7 @@ public class SseServiceImpl implements ISseService { * @param model 模型信息 * @param chatMessageBo 对话信息 */ - private void processByModelPrice(SysModelVo model, ChatMessageBo chatMessageBo) { + private void processByModelPrice(SysModel model, ChatMessageBo chatMessageBo) { double cost = model.getModelPrice(); chatService.deductUserBalance(getUserId(), cost); chatMessageBo.setDeductCost(cost); @@ -316,16 +368,14 @@ public class SseServiceImpl implements ISseService { .style(request.getStyle()) .build(); ImageResponse imageResponse = openAiStreamClient.genImages(image); - SysModelBo sysModelBo = new SysModelBo(); - sysModelBo.setModelName(request.getModel()); - List sysModelList = sysModelService.queryList(sysModelBo); + SysModel sysModel = sysModelService.selectModelByName(request.getModel()); //chatService.deductUserBalance(getUserId(),sysModelList.get(0).getModelPrice()); // 保存消息记录 ChatMessageBo chatMessageBo = new ChatMessageBo(); chatMessageBo.setUserId(getUserId()); chatMessageBo.setModelName(Image.Model.DALL_E_3.getName()); chatMessageBo.setContent(request.getPrompt()); - chatMessageBo.setDeductCost(sysModelList.get(0).getModelPrice()); + chatMessageBo.setDeductCost(sysModel.getModelPrice()); chatMessageBo.setTotalTokens(0); chatMessageService.insertByBo(chatMessageBo); return imageResponse.getData(); @@ -342,16 +392,14 @@ public class SseServiceImpl implements ISseService { .n(1) .build(); ImageResponse imageResponse = openAiStreamClient.genImages(image); - SysModelBo sysModelBo = new SysModelBo(); - sysModelBo.setModelName("dall3"); - List sysModelList = sysModelService.queryList(sysModelBo); + SysModel dall3 = sysModelService.selectModelByName("dall3"); chatService.deductUserBalance(Long.valueOf(userId), 0.3); // 保存消息记录 ChatMessageBo chatMessageBo = new ChatMessageBo(); chatMessageBo.setUserId(getUserId()); chatMessageBo.setModelName(Image.Model.DALL_E_3.getName()); chatMessageBo.setContent(prompt); - chatMessageBo.setDeductCost(sysModelList.get(0).getModelPrice()); + chatMessageBo.setDeductCost(dall3.getModelPrice()); chatMessageBo.setTotalTokens(0); chatMessageService.insertByBo(chatMessageBo); return imageResponse.getData(); @@ -527,12 +575,9 @@ public class SseServiceImpl implements ISseService { chatMessageBo.setDeductCost(0.01); chatMessageBo.setTotalTokens(0); chatMessageService.insertByBo(chatMessageBo); - openAiStreamClient = chatConfig.getOpenAiStreamClient(); - List messageList = new ArrayList<>(); - - Message sysMessage = Message.builder().role(Message.Role.SYSTEM).content("你是一名翻译老师\n" + + Message sysMessage = Message.builder().role(Message.Role.SYSTEM).content("你是一位精通各国语言的翻译大师\n" + "\n" + "请将用户输入词语翻译成{" + translationRequest.getTargetLanguage() + "}\n" + "\n" + @@ -563,25 +608,21 @@ public class SseServiceImpl implements ISseService { @Override public SseEmitter ollamaChat(ChatRequest chatRequest) { + String[] parts = chatRequest.getModel().split("ollama-"); + SysModel sysModel = sysModelService.selectModelByName(parts[1]); final SseEmitter emitter = new SseEmitter(); - String host = "http://localhost:11434/"; - + String host = sysModel.getApiHost(); List msgList = chatRequest.getMessages(); Message message = msgList.get(msgList.size() - 1); - - OllamaAPI ollamaAPI = new OllamaAPI(host); - - ollamaAPI.setRequestTimeoutSeconds(100); - - OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance("qwen2.5:7b"); - + OllamaAPI api = new OllamaAPI(host); + api.setRequestTimeoutSeconds(100); + OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(sysModel.getModelName()); OllamaChatRequestModel requestModel = builder .withMessage(OllamaChatMessageRole.USER, message.getContent().toString()) .build(); - - // 异步执行 Ollama API 调用 + // 异步执行 OllAma API 调用 CompletableFuture.runAsync(() -> { try { StringBuilder response = new StringBuilder(); @@ -595,14 +636,12 @@ public class SseServiceImpl implements ISseService { sendErrorEvent(emitter, e.getMessage()); } }; - ollamaAPI.chat(requestModel, streamHandler); + api.chat(requestModel, streamHandler); emitter.complete(); } catch (Exception e) { sendErrorEvent(emitter, e.getMessage()); } }); - - return emitter; } @@ -620,6 +659,4 @@ public class SseServiceImpl implements ISseService { ChatCompletionResponse chatCompletionResponse = openAiStreamClient.chatCompletion(chatCompletion); return chatCompletionResponse.getChoices().get(0).getMessage().getContent().toString(); } - - } diff --git a/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/service/impl/SysModelServiceImpl.java b/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/service/impl/SysModelServiceImpl.java index 107f9e67..9ed3bd1f 100644 --- a/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/service/impl/SysModelServiceImpl.java +++ b/ruoyi-modules/ruoyi-system/src/main/java/org/ruoyi/system/service/impl/SysModelServiceImpl.java @@ -107,4 +107,11 @@ public class SysModelServiceImpl implements ISysModelService { } return baseMapper.deleteBatchIds(ids) > 0; } + + @Override + public SysModel selectModelByName(String modelName) { + return baseMapper.selectOne( + new LambdaQueryWrapper().eq(SysModel::getModelName, modelName) + ); + } }