From d8cc001eb849fc69adbcf20020ff901c90906c16 Mon Sep 17 00:00:00 2001 From: zhaoyuhang <1045078399@qq.com> Date: Mon, 10 Jul 2023 23:11:41 +0800 Subject: [PATCH] =?UTF-8?q?feat:=E4=B8=8A=E4=B8=8B=E6=96=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../custom/chatai/domain/ChatGPTMsg.java | 2 + .../chatai/handler/GPTChatAIHandler.java | 37 ++++++++++-------- .../custom/chatai/utils/ChatGPTUtils.java | 39 +++++++++---------- 3 files changed, 41 insertions(+), 37 deletions(-) diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/domain/ChatGPTMsg.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/domain/ChatGPTMsg.java index 056657d..2d913f9 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/domain/ChatGPTMsg.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/domain/ChatGPTMsg.java @@ -2,11 +2,13 @@ package com.abin.mallchat.custom.chatai.domain; import lombok.Getter; import lombok.Setter; +import lombok.ToString; import java.io.Serializable; @Getter @Setter +@ToString public class ChatGPTMsg implements Serializable { private String role; diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/GPTChatAIHandler.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/GPTChatAIHandler.java index d4b5b55..33107bf 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/GPTChatAIHandler.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/GPTChatAIHandler.java @@ -3,16 +3,15 @@ package com.abin.mallchat.custom.chatai.handler; import com.abin.mallchat.common.chat.domain.entity.Message; import com.abin.mallchat.common.chat.domain.entity.msg.MessageExtra; import com.abin.mallchat.common.common.constant.RedisKey; +import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO; +import com.abin.mallchat.common.common.exception.FrequencyControlException; +import com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlUtil; import com.abin.mallchat.common.common.utils.DateUtils; import com.abin.mallchat.common.common.utils.RedisUtils; import com.abin.mallchat.custom.chatai.domain.ChatGPTContext; import com.abin.mallchat.custom.chatai.domain.ChatGPTMsg; import com.abin.mallchat.custom.chatai.domain.builder.ChatGPTContextBuilder; import com.abin.mallchat.custom.chatai.domain.builder.ChatGPTMsgBuilder; -import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO; -import com.abin.mallchat.common.common.exception.FrequencyControlException; -import com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlUtil; -import com.abin.mallchat.custom.chatai.dto.GPTRequestDTO; import com.abin.mallchat.custom.chatai.properties.ChatGPTProperties; import com.abin.mallchat.custom.chatai.utils.ChatGPTUtils; import com.abin.mallchat.custom.user.domain.vo.response.user.UserInfoResp; @@ -27,7 +26,6 @@ import java.util.List; import java.util.concurrent.TimeUnit; import static com.abin.mallchat.common.common.constant.RedisKey.USER_CHAT_CONTEXT; - import static com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlStrategyFactory.TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER; @Slf4j @@ -72,15 +70,16 @@ public class GPTChatAIHandler extends AbstractChatAIHandler { @Override protected String doChat(Message message) { - String content = message.getContent().replace("@" + AI_NAME, "").trim(); Long uid = message.getFromUid(); try { FrequencyControlDTO frequencyControlDTO = new FrequencyControlDTO(); - frequencyControlDTO.setKey(CHAT_FREQUENCY_PREFIX + ":" + uid); + frequencyControlDTO.setKey(RedisKey.getKey(CHAT_FREQUENCY_PREFIX) + ":" + uid); frequencyControlDTO.setUnit(TimeUnit.HOURS); frequencyControlDTO.setCount(chatGPTProperties.getLimit()); frequencyControlDTO.setTime(24); - return FrequencyControlUtil.executeWithFrequencyControl(TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER, frequencyControlDTO, () -> sendRequestToGPT(new GPTRequestDTO(content, uid))); + return FrequencyControlUtil.executeWithFrequencyControl(TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER, + frequencyControlDTO, // 限流参数 + () -> sendRequestToGPT(message)); } catch (FrequencyControlException e) { return "亲爱的,你今天找我聊了" + chatGPTProperties.getLimit() + "次了~人家累了~明天见"; } catch (Throwable e) { @@ -88,20 +87,24 @@ public class GPTChatAIHandler extends AbstractChatAIHandler { } } - private String sendRequestToGPT(GPTRequestDTO gptRequestDTO) { - String content = gptRequestDTO.getContent(); - Long roomId = message.getRoomId(); - Long chatNum; + + private String sendRequestToGPT(Message message) { + ChatGPTContext context = buildContext(message);// 构建上下文 + context = tailorContext(context);// 裁剪上下文 + log.info("context = {}", context); String text; - HttpResponse response = null; try { - response = ChatGPTUtils.create(chatGPTProperties.getKey()) + Response response = ChatGPTUtils.create(chatGPTProperties.getKey()) .proxyUrl(chatGPTProperties.getProxyUrl()) .model(chatGPTProperties.getModelName()) .timeout(chatGPTProperties.getTimeout()) - .prompt(content) + .maxTokens(chatGPTProperties.getMaxTokens()) + .message(context.getMsg()) .send(); text = ChatGPTUtils.parseText(response); + ChatGPTMsg chatGPTMsg = ChatGPTMsgBuilder.assistantMsg(text); + context.addMsg(chatGPTMsg); + RedisUtils.set(RedisKey.getKey(USER_CHAT_CONTEXT, message.getFromUid(), message.getRoomId()), context, 1L, TimeUnit.HOURS); } catch (Exception e) { log.warn("gpt doChat warn:", e); text = "我累了,明天再聊吧"; @@ -119,13 +122,15 @@ public class GPTChatAIHandler extends AbstractChatAIHandler { return tailorContext(context); } - private ChatGPTContext buildContext(Message message, String prompt) { + private ChatGPTContext buildContext(Message message) { + String prompt = message.getContent().replace("@" + AI_NAME, "").trim(); Long uid = message.getFromUid(); Long roomId = message.getRoomId(); ChatGPTContext chatGPTContext = RedisUtils.get(RedisKey.getKey(USER_CHAT_CONTEXT, uid, roomId), ChatGPTContext.class); if (chatGPTContext == null) { chatGPTContext = ChatGPTContextBuilder.initContext(uid, roomId); } + RedisUtils.set(RedisKey.getKey(USER_CHAT_CONTEXT, uid, roomId), chatGPTContext, 1L, TimeUnit.HOURS); chatGPTContext.addMsg(ChatGPTMsgBuilder.userMsg(prompt)); return chatGPTContext; } diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGPTUtils.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGPTUtils.java index 13adbff..b049bda 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGPTUtils.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGPTUtils.java @@ -2,7 +2,6 @@ package com.abin.mallchat.custom.chatai.utils; import com.abin.mallchat.common.common.exception.BusinessException; import com.abin.mallchat.custom.chatai.domain.ChatGPTMsg; -import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONObject; import com.knuddels.jtokkit.Encodings; import com.knuddels.jtokkit.api.Encoding; @@ -13,10 +12,9 @@ import okhttp3.*; import org.apache.commons.lang3.StringUtils; import java.io.IOException; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; @Slf4j public class ChatGPTUtils { @@ -85,17 +83,23 @@ public class ChatGPTUtils { return parseText(response.body().string()); } + public static String parseText(String body) { - log.info("body >>> " + body); - JSONObject jsonObject = JSONObject.parseObject(body); - JSONObject error = jsonObject.getJSONObject("error"); - if (error != null) { - log.error("error >>> " + error); +// log.info("body >>> " + body); + try { + return Arrays.stream(body.split("data:")) + .map(String::trim) + .filter(x -> StringUtils.isNotBlank(x) && !"[DONE]".endsWith(x)) + .map(x -> JSONObject.parseObject(x) + .getJSONArray("choices") + .getJSONObject(0) + .getJSONObject("delta") + .getString("content") + ).filter(Objects::nonNull).collect(Collectors.joining()); + } catch (Exception e) { + log.error("parseText error e:", e); return "闹脾气了,等会再试试吧~"; } - JSONArray choices = JSONObject.parseArray(jsonObject.getString("choices")); - JSONObject choice = choices.getJSONObject(0); - return choice.getJSONObject("message").getString("content"); } public ChatGPTUtils model(String model) { @@ -144,15 +148,6 @@ public class ChatGPTUtils { } public Response send() throws IOException { -// cn.hutool.json.JSONObject param = new cn.hutool.json.JSONObject(); -// param.set("model", model); -// param.set("messages", messages); -// param.set("max_tokens", maxTokens); -// param.set("temperature", temperature); -// param.set("top_p", topP); -// param.set("frequency_penalty", frequencyPenalty); -// param.set("presence_penalty", presencePenalty); -// log.info("headers >>> " + headers); OkHttpClient okHttpClient = new OkHttpClient() .newBuilder() .connectTimeout(10, TimeUnit.SECONDS) @@ -167,7 +162,9 @@ public class ChatGPTUtils { paramMap.put("top_p", topP); paramMap.put("frequency_penalty", frequencyPenalty); paramMap.put("presence_penalty", presencePenalty); + paramMap.put("stream", true); + log.info("paramMap >>> " + JSONObject.toJSONString(paramMap)); Request request = new Request.Builder() .url(StringUtils.isNotBlank(proxyUrl) ? proxyUrl : URL) .addHeader("Content-Type", "application/json")