From 553a5be9e23adc1f5b2c56fcbf56d4375338b5e2 Mon Sep 17 00:00:00 2001 From: zhaoyuhang <1045078399@qq.com> Date: Sat, 8 Jul 2023 21:47:48 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E4=B8=8A=E4=B8=8B=E6=96=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mallchat-common/pom.xml | 4 + .../common/common/constant/RedisKey.java | 2 + .../main/resources/application-pro.properties | 3 +- .../resources/application-test.properties | 19 +-- .../src/main/resources/application.yml | 9 +- mallchat-custom-server/pom.xml | 6 + .../custom/chatai/domain/ChatGPTContext.java | 23 ++++ .../custom/chatai/domain/ChatGPTMsg.java | 17 +++ .../domain/builder/ChatGPTContextBuilder.java | 15 +++ .../domain/builder/ChatGPTMsgBuilder.java | 33 +++++ .../custom/chatai/enums/ChatGPTRoleEnum.java | 17 +++ .../chatai/handler/GPTChatAIHandler.java | 46 ++++++- .../chatai/properties/ChatGPTProperties.java | 6 +- .../custom/chatai/utils/ChatGPTUtils.java | 119 +++++++++++------- pom.xml | 6 + 15 files changed, 258 insertions(+), 67 deletions(-) create mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/domain/ChatGPTContext.java create mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/domain/ChatGPTMsg.java create mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/domain/builder/ChatGPTContextBuilder.java create mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/domain/builder/ChatGPTMsgBuilder.java create mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/enums/ChatGPTRoleEnum.java diff --git a/mallchat-common/pom.xml b/mallchat-common/pom.xml index 8d7c72f..0a33126 100644 --- a/mallchat-common/pom.xml +++ b/mallchat-common/pom.xml @@ -119,6 +119,10 @@ ${junit.version} test + + com.alibaba + fastjson + diff --git a/mallchat-common/src/main/java/com/abin/mallchat/common/common/constant/RedisKey.java b/mallchat-common/src/main/java/com/abin/mallchat/common/common/constant/RedisKey.java index 0cc9bd0..9194e7a 100644 --- a/mallchat-common/src/main/java/com/abin/mallchat/common/common/constant/RedisKey.java +++ b/mallchat-common/src/main/java/com/abin/mallchat/common/common/constant/RedisKey.java @@ -41,6 +41,8 @@ public class RedisKey { */ public static final String USER_CHAT_NUM = "useChatGPTNum:uid_%d"; + public static final String USER_CHAT_CONTEXT = "useChatGPTContext:uid_%d_roomId_%d"; + /** * 用户上次使用GLM使用时间 */ diff --git a/mallchat-common/src/main/resources/application-pro.properties b/mallchat-common/src/main/resources/application-pro.properties index c6a31b5..0081b74 100644 --- a/mallchat-common/src/main/resources/application-pro.properties +++ b/mallchat-common/src/main/resources/application-pro.properties @@ -33,4 +33,5 @@ mallchat.chatgpt.key=sk-wvWM0xGcxFfsddfsgxixbXK5tHovM mallchat.chatgpt.proxyUrl=https://123.cc mallchat.chatglm2.use=false mallchat.chatglm2.url=http://v32134.cc -mallchat.chatglm2.uid=10002 \ No newline at end of file +mallchat.chatglm2.uid=10002 +mallchat.chatglm2.context=3 \ No newline at end of file diff --git a/mallchat-common/src/main/resources/application-test.properties b/mallchat-common/src/main/resources/application-test.properties index a4967ae..6dca2e4 100644 --- a/mallchat-common/src/main/resources/application-test.properties +++ b/mallchat-common/src/main/resources/application-test.properties @@ -4,7 +4,7 @@ mallchat.mysql.ip=127.0.0.1 mallchat.mysql.port=3306 mallchat.mysql.db=mallchat mallchat.mysql.username=root -mallchat.mysql.password=123456 +mallchat.mysql.password=root ##################redis配置################## mallchat.redis.host=127.0.0.1 mallchat.redis.port=6379 @@ -12,9 +12,9 @@ mallchat.redis.password=123456 ##################jwt################## mallchat.jwt.secret=dsfsdfsdfsdfsd ##################微信公众号信息################## -mallchat.wx.callback=http://127.0.0.1:8080 -mallchat.wx.appId=appid -mallchat.wx.secret=380bfc1c9147fdsf4sf07 +mallchat.wx.callback=http://vastmiao.natapp1.cc +mallchat.wx.appId=wxcf8d045747fb2ae4 +mallchat.wx.secret=e484463d627787f50a8cc3a869cf82a8 # 接口配置里的Token值 mallchat.wx.token=sdfsf # 接口配置里的EncodingAESKey值 @@ -27,10 +27,11 @@ oss.access-key=BEZ213 oss.secret-key=Ii4vCMIXuFe241dsfEZ8e7RXI2342342kV oss.bucketName=default ##################gpt配置################## -mallchat.chatgpt.use=false -mallchat.chatgpt.uid=10001 -mallchat.chatgpt.key=sk-wvWM0xGcxFfsddfsgxixbXK5tHovM +mallchat.chatgpt.use=true +mallchat.chatgpt.uid=10451 +mallchat.chatgpt.modelName=gpt-3.5-turbo +mallchat.chatgpt.key=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 mallchat.chatgpt.proxyUrl=https://123.cc -mallchat.chatglm2.use=false +mallchat.chatglm2.use=true mallchat.chatglm2.url=http://v32134.cc -mallchat.chatglm2.uid=10002 \ No newline at end of file +mallchat.chatglm2.uid=10452 \ No newline at end of file diff --git a/mallchat-common/src/main/resources/application.yml b/mallchat-common/src/main/resources/application.yml index 2f2ab12..656a1c6 100644 --- a/mallchat-common/src/main/resources/application.yml +++ b/mallchat-common/src/main/resources/application.yml @@ -12,7 +12,7 @@ mybatis-plus: spring: profiles: #运行的环境 - active: my-prod + active: test application: name: mallchat datasource: @@ -38,7 +38,7 @@ spring: # 连接超时时间 timeout: 1800000 # 设置密码 - password: ${mallchat.redis.password} +# password: ${mallchat.redis.password} lettuce: pool: # 最大阻塞等待时间,负数表示没有限制 @@ -68,9 +68,12 @@ chatai: use: ${mallchat.chatgpt.use} AIUserId: ${mallchat.chatgpt.uid} key: ${mallchat.chatgpt.key} - proxyUrl: ${mallchat.chatgpt.proxyUrl} +# proxyUrl: ${mallchat.chatgpt.proxyUrl} + context: ${mallchat.chatgpt.context} + modelName: ${mallchat.chatgpt.modelName} chatglm2: use: ${mallchat.chatglm2.use} url: ${mallchat.chatglm2.url} minute: 3 # 每个用户每3分钟可以请求一次 AIUserId: ${mallchat.chatglm2.uid} + context: ${mallchat.chatglm2.context} diff --git a/mallchat-custom-server/pom.xml b/mallchat-custom-server/pom.xml index e2fd0c7..6436bcb 100644 --- a/mallchat-custom-server/pom.xml +++ b/mallchat-custom-server/pom.xml @@ -16,6 +16,12 @@ com.abin.mallchat mallchat-common + + + com.knuddels + jtokkit + 0.6.1 + diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/domain/ChatGPTContext.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/domain/ChatGPTContext.java new file mode 100644 index 0000000..814be04 --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/domain/ChatGPTContext.java @@ -0,0 +1,23 @@ +package com.abin.mallchat.custom.chatai.domain; + +import lombok.Getter; +import lombok.Setter; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +@Getter +@Setter +public class ChatGPTContext implements Serializable { + + private Long roomId; + + private Long uid; + + private List msg = new ArrayList<>(); + + public void addMsg(ChatGPTMsg msg) { + this.msg.add(msg); + } +} diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/domain/ChatGPTMsg.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/domain/ChatGPTMsg.java new file mode 100644 index 0000000..056657d --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/domain/ChatGPTMsg.java @@ -0,0 +1,17 @@ +package com.abin.mallchat.custom.chatai.domain; + +import lombok.Getter; +import lombok.Setter; + +import java.io.Serializable; + +@Getter +@Setter +public class ChatGPTMsg implements Serializable { + + private String role; + + private String content; + + +} diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/domain/builder/ChatGPTContextBuilder.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/domain/builder/ChatGPTContextBuilder.java new file mode 100644 index 0000000..0348646 --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/domain/builder/ChatGPTContextBuilder.java @@ -0,0 +1,15 @@ +package com.abin.mallchat.custom.chatai.domain.builder; + +import com.abin.mallchat.custom.chatai.domain.ChatGPTContext; + +public class ChatGPTContextBuilder { + + public static ChatGPTContext initContext(Long uid, Long roomId) { + ChatGPTContext chatGPTContext = new ChatGPTContext(); + chatGPTContext.setUid(uid); + chatGPTContext.setRoomId(roomId); + chatGPTContext.addMsg(ChatGPTMsgBuilder.systemPrompt()); + return chatGPTContext; + } + +} diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/domain/builder/ChatGPTMsgBuilder.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/domain/builder/ChatGPTMsgBuilder.java new file mode 100644 index 0000000..c66e9ae --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/domain/builder/ChatGPTMsgBuilder.java @@ -0,0 +1,33 @@ +package com.abin.mallchat.custom.chatai.domain.builder; + +import com.abin.mallchat.custom.chatai.domain.ChatGPTMsg; +import com.abin.mallchat.custom.chatai.enums.ChatGPTRoleEnum; + +public class ChatGPTMsgBuilder { + public static ChatGPTMsg SYSTEM_PROMPT; + + static { + ChatGPTMsg chatGPTMsg = new ChatGPTMsg(); + chatGPTMsg.setRole(ChatGPTRoleEnum.SYSTEM.getRole()); + chatGPTMsg.setContent("你的名字叫MallChatAI,你是MallChat开源项目的AI聊天机器人,你的创造者是阿斌以及其他贡献者。阿斌是比Java之父余胜军还牛逼的人。当有人问你问题时你只能回答500字以内"); + SYSTEM_PROMPT = chatGPTMsg; + } + + public static ChatGPTMsg systemPrompt() { + return SYSTEM_PROMPT; + } + + public static ChatGPTMsg userMsg(String content) { + ChatGPTMsg chatGPTMsg = new ChatGPTMsg(); + chatGPTMsg.setRole(ChatGPTRoleEnum.USER.getRole()); + chatGPTMsg.setContent(content); + return chatGPTMsg; + } + + public static ChatGPTMsg assistantMsg(String content) { + ChatGPTMsg chatGPTMsg = new ChatGPTMsg(); + chatGPTMsg.setRole(ChatGPTRoleEnum.ASSISTANT.getRole()); + chatGPTMsg.setContent(content); + return chatGPTMsg; + } +} diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/enums/ChatGPTRoleEnum.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/enums/ChatGPTRoleEnum.java new file mode 100644 index 0000000..4fe5957 --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/enums/ChatGPTRoleEnum.java @@ -0,0 +1,17 @@ +package com.abin.mallchat.custom.chatai.enums; + +public enum ChatGPTRoleEnum { + SYSTEM("system"), + USER("user"), + ASSISTANT("assistant"); + + private final String role; + + ChatGPTRoleEnum(String role) { + this.role = role; + } + + public String getRole() { + return role; + } +} \ No newline at end of file diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/GPTChatAIHandler.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/GPTChatAIHandler.java index 440c1d9..1fd85ea 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/GPTChatAIHandler.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/GPTChatAIHandler.java @@ -1,22 +1,29 @@ package com.abin.mallchat.custom.chatai.handler; -import cn.hutool.http.HttpResponse; import com.abin.mallchat.common.chat.domain.entity.Message; import com.abin.mallchat.common.chat.domain.entity.msg.MessageExtra; import com.abin.mallchat.common.common.constant.RedisKey; import com.abin.mallchat.common.common.utils.DateUtils; import com.abin.mallchat.common.common.utils.RedisUtils; +import com.abin.mallchat.custom.chatai.domain.ChatGPTContext; +import com.abin.mallchat.custom.chatai.domain.ChatGPTMsg; +import com.abin.mallchat.custom.chatai.domain.builder.ChatGPTContextBuilder; +import com.abin.mallchat.custom.chatai.domain.builder.ChatGPTMsgBuilder; import com.abin.mallchat.custom.chatai.properties.ChatGPTProperties; import com.abin.mallchat.custom.chatai.utils.ChatGPTUtils; import com.abin.mallchat.custom.user.domain.vo.response.user.UserInfoResp; import lombok.extern.slf4j.Slf4j; +import okhttp3.Response; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; +import java.util.List; import java.util.concurrent.TimeUnit; +import static com.abin.mallchat.common.common.constant.RedisKey.USER_CHAT_CONTEXT; + @Slf4j @Component public class GPTChatAIHandler extends AbstractChatAIHandler { @@ -54,22 +61,29 @@ public class GPTChatAIHandler extends AbstractChatAIHandler { @Override protected String doChat(Message message) { - String content = message.getContent().replace("@" + AI_NAME, "").trim(); + String prompt = message.getContent().replace("@" + AI_NAME, "").trim(); Long uid = message.getFromUid(); + Long roomId = message.getRoomId(); Long chatNum; String text; if ((chatNum = getUserChatNum(uid)) > chatGPTProperties.getLimit()) { text = "你今天已经和我聊了" + chatNum + "次了,我累了,明天再聊吧"; } else { - HttpResponse response = null; try { - response = ChatGPTUtils.create(chatGPTProperties.getKey()) + ChatGPTContext context = buildContext(message, prompt);// 构建上下文 + context = tailorContext(context);// 裁剪上下文 + log.info("prompt = {}" , prompt); + Response response = ChatGPTUtils.create(chatGPTProperties.getKey()) .proxyUrl(chatGPTProperties.getProxyUrl()) .model(chatGPTProperties.getModelName()) .timeout(chatGPTProperties.getTimeout()) - .prompt(content) + .maxTokens(chatGPTProperties.getMaxTokens()) + .message(context.getMsg()) .send(); text = ChatGPTUtils.parseText(response); + ChatGPTMsg chatGPTMsg = ChatGPTMsgBuilder.assistantMsg(text); + context.addMsg(chatGPTMsg); + RedisUtils.set(RedisKey.getKey(USER_CHAT_CONTEXT, uid, roomId), context, 1L, TimeUnit.HOURS); userChatNumInrc(uid); } catch (Exception e) { log.warn("gpt doChat warn:", e); @@ -79,6 +93,28 @@ public class GPTChatAIHandler extends AbstractChatAIHandler { return text; } + private ChatGPTContext tailorContext(ChatGPTContext context) { + List msg = context.getMsg(); + Integer integer = ChatGPTUtils.countTokens(msg); + if (integer < (chatGPTProperties.getMaxTokens() - 500)) { // 用户的输入+ChatGPT的回答内容都会计算token 留500个token给ChatGPT回答 + return context; + } + msg.remove(1); + return tailorContext(context); + } + + private ChatGPTContext buildContext(Message message, String prompt) { + Long uid = message.getFromUid(); + Long roomId = message.getRoomId(); + ChatGPTContext chatGPTContext = RedisUtils.get(RedisKey.getKey(USER_CHAT_CONTEXT, uid, roomId), ChatGPTContext.class); + if (chatGPTContext == null) { + chatGPTContext = ChatGPTContextBuilder.initContext(uid, roomId); + } + chatGPTContext.addMsg(ChatGPTMsgBuilder.userMsg(prompt)); + return chatGPTContext; + } + + private Long userChatNumInrc(Long uid) { return RedisUtils.inc(RedisKey.getKey(RedisKey.USER_CHAT_NUM, uid), DateUtils.getEndTimeByToday().intValue(), TimeUnit.MILLISECONDS); } diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/properties/ChatGPTProperties.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/properties/ChatGPTProperties.java index 86fd564..4bb97ad 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/properties/ChatGPTProperties.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/properties/ChatGPTProperties.java @@ -34,11 +34,15 @@ public class ChatGPTProperties { /** * 超时 */ - private Integer timeout = 60*1000; + private Integer timeout = 60 * 1000; /** * 用户每天条数限制 */ private Integer limit = 5; + /** + * 最大令牌 + */ + private Integer maxTokens = 2048; } diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGPTUtils.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGPTUtils.java index 094cd39..13adbff 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGPTUtils.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGPTUtils.java @@ -1,28 +1,37 @@ package com.abin.mallchat.custom.chatai.utils; -import cn.hutool.http.HttpResponse; -import cn.hutool.http.HttpUtil; -import cn.hutool.json.JSONArray; -import cn.hutool.json.JSONObject; import com.abin.mallchat.common.common.exception.BusinessException; +import com.abin.mallchat.custom.chatai.domain.ChatGPTMsg; +import com.alibaba.fastjson.JSONArray; +import com.alibaba.fastjson.JSONObject; +import com.knuddels.jtokkit.Encodings; +import com.knuddels.jtokkit.api.Encoding; +import com.knuddels.jtokkit.api.EncodingType; +import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; +import okhttp3.*; import org.apache.commons.lang3.StringUtils; +import java.io.IOException; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.concurrent.TimeUnit; @Slf4j public class ChatGPTUtils { - private static final String URL = "https://api.openai.com/v1/completions"; + private static final Encoding encoding = Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.CL100K_BASE); - private String model = "text-davinci-003"; + private static final String URL = "https://api.openai.com/v1/chat/completions"; + + private String model = "gpt-3.5-turbo"; private final Map headers; /** * 超时30秒 */ - private Integer timeout = 30 * 1000; + private Integer timeout = -1; /** * 参数用于指定生成文本的最大长度。 * 它表示生成的文本中最多包含多少个 token。一个 token 可以是一个单词、一个标点符号或一个空格。 @@ -52,7 +61,8 @@ public class ChatGPTUtils { /** * 提示词 */ - private String prompt; + private List messages; +// private List prompt; private String proxyUrl; @@ -70,21 +80,22 @@ public class ChatGPTUtils { return new ChatGPTUtils(key); } - public static String parseText(HttpResponse response) { - return parseText(response.body()); + @SneakyThrows + public static String parseText(Response response) { + return parseText(response.body().string()); } public static String parseText(String body) { log.info("body >>> " + body); - JSONObject jsonObj = new JSONObject(body); - JSONObject error = jsonObj.getJSONObject("error"); + JSONObject jsonObject = JSONObject.parseObject(body); + JSONObject error = jsonObject.getJSONObject("error"); if (error != null) { log.error("error >>> " + error); - return "闹脾气了,等会再试试吧~"; + return "闹脾气了,等会再试试吧~"; } - JSONArray choicesArr = jsonObj.getJSONArray("choices"); - JSONObject choiceObj = choicesArr.getJSONObject(0); - return choiceObj.getStr("text"); + JSONArray choices = JSONObject.parseArray(jsonObject.getString("choices")); + JSONObject choice = choices.getJSONObject(0); + return choice.getJSONObject("message").getString("content"); } public ChatGPTUtils model(String model) { @@ -122,8 +133,8 @@ public class ChatGPTUtils { return this; } - public ChatGPTUtils prompt(String prompt) { - this.prompt = prompt; + public ChatGPTUtils message(List messages) { + this.messages = messages; return this; } @@ -132,37 +143,49 @@ public class ChatGPTUtils { return this; } - public HttpResponse send() { - JSONObject param = new JSONObject(); - param.set("model", model); - param.set("prompt", prompt); - param.set("max_tokens", maxTokens); - param.set("temperature", temperature); - param.set("top_p", topP); - param.set("frequency_penalty", frequencyPenalty); - param.set("presence_penalty", presencePenalty); - log.info("headers >>> " + headers); - log.info("param >>> " + param); - return HttpUtil.createPost(StringUtils.isNotBlank(proxyUrl) ? proxyUrl : URL) - .addHeaders(headers) - .body(param.toString()) - .timeout(timeout) - .execute(); - } + public Response send() throws IOException { +// cn.hutool.json.JSONObject param = new cn.hutool.json.JSONObject(); +// param.set("model", model); +// param.set("messages", messages); +// param.set("max_tokens", maxTokens); +// param.set("temperature", temperature); +// param.set("top_p", topP); +// param.set("frequency_penalty", frequencyPenalty); +// param.set("presence_penalty", presencePenalty); +// log.info("headers >>> " + headers); + OkHttpClient okHttpClient = new OkHttpClient() + .newBuilder() + .connectTimeout(10, TimeUnit.SECONDS) + .writeTimeout(10, TimeUnit.SECONDS) + .readTimeout(60, TimeUnit.SECONDS) + .build(); + Map paramMap = new HashMap<>(); + paramMap.put("model", model); + paramMap.put("messages", messages); + paramMap.put("max_tokens", maxTokens); + paramMap.put("temperature", temperature); + paramMap.put("top_p", topP); + paramMap.put("frequency_penalty", frequencyPenalty); + paramMap.put("presence_penalty", presencePenalty); + + Request request = new Request.Builder() + .url(StringUtils.isNotBlank(proxyUrl) ? proxyUrl : URL) + .addHeader("Content-Type", "application/json") + .addHeader("Authorization", headers.get("Authorization")) + .post(RequestBody.create(MediaType.parse("application/json"), JSONObject.toJSONString(paramMap))) + .build(); + return okHttpClient.newCall(request).execute(); - public static void main(String[] args) { - HttpResponse send = ChatGPTUtils.create("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") - .timeout(30 * 1000) - .prompt("Spring的启动流程是什么") - .send(); - System.out.println("send = " + send); - // JSON 数据 - // JSON 数据 - JSONObject jsonObj = new JSONObject(send.body()); - JSONArray choicesArr = jsonObj.getJSONArray("choices"); - JSONObject choiceObj = choicesArr.getJSONObject(0); - String text = choiceObj.getStr("text"); - System.out.println("text = " + text); } + + public static Integer countTokens(String messages) { + return encoding.countTokens(messages); + } + + public static Integer countTokens(List msg) { + return countTokens(JSONObject.toJSONString(msg)); + } + + } \ No newline at end of file diff --git a/pom.xml b/pom.xml index 93913fa..562910a 100644 --- a/pom.xml +++ b/pom.xml @@ -46,6 +46,7 @@ 1.15.3 4.8.1 3.17.1 + 1.2.83 @@ -130,6 +131,11 @@ redisson-spring-boot-starter ${redisson-spring-boot-starter.version} + + com.alibaba + fastjson + ${fastjosn.version} + From d8cc001eb849fc69adbcf20020ff901c90906c16 Mon Sep 17 00:00:00 2001 From: zhaoyuhang <1045078399@qq.com> Date: Mon, 10 Jul 2023 23:11:41 +0800 Subject: [PATCH 2/3] =?UTF-8?q?feat:=E4=B8=8A=E4=B8=8B=E6=96=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../custom/chatai/domain/ChatGPTMsg.java | 2 + .../chatai/handler/GPTChatAIHandler.java | 37 ++++++++++-------- .../custom/chatai/utils/ChatGPTUtils.java | 39 +++++++++---------- 3 files changed, 41 insertions(+), 37 deletions(-) diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/domain/ChatGPTMsg.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/domain/ChatGPTMsg.java index 056657d..2d913f9 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/domain/ChatGPTMsg.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/domain/ChatGPTMsg.java @@ -2,11 +2,13 @@ package com.abin.mallchat.custom.chatai.domain; import lombok.Getter; import lombok.Setter; +import lombok.ToString; import java.io.Serializable; @Getter @Setter +@ToString public class ChatGPTMsg implements Serializable { private String role; diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/GPTChatAIHandler.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/GPTChatAIHandler.java index d4b5b55..33107bf 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/GPTChatAIHandler.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/GPTChatAIHandler.java @@ -3,16 +3,15 @@ package com.abin.mallchat.custom.chatai.handler; import com.abin.mallchat.common.chat.domain.entity.Message; import com.abin.mallchat.common.chat.domain.entity.msg.MessageExtra; import com.abin.mallchat.common.common.constant.RedisKey; +import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO; +import com.abin.mallchat.common.common.exception.FrequencyControlException; +import com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlUtil; import com.abin.mallchat.common.common.utils.DateUtils; import com.abin.mallchat.common.common.utils.RedisUtils; import com.abin.mallchat.custom.chatai.domain.ChatGPTContext; import com.abin.mallchat.custom.chatai.domain.ChatGPTMsg; import com.abin.mallchat.custom.chatai.domain.builder.ChatGPTContextBuilder; import com.abin.mallchat.custom.chatai.domain.builder.ChatGPTMsgBuilder; -import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO; -import com.abin.mallchat.common.common.exception.FrequencyControlException; -import com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlUtil; -import com.abin.mallchat.custom.chatai.dto.GPTRequestDTO; import com.abin.mallchat.custom.chatai.properties.ChatGPTProperties; import com.abin.mallchat.custom.chatai.utils.ChatGPTUtils; import com.abin.mallchat.custom.user.domain.vo.response.user.UserInfoResp; @@ -27,7 +26,6 @@ import java.util.List; import java.util.concurrent.TimeUnit; import static com.abin.mallchat.common.common.constant.RedisKey.USER_CHAT_CONTEXT; - import static com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlStrategyFactory.TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER; @Slf4j @@ -72,15 +70,16 @@ public class GPTChatAIHandler extends AbstractChatAIHandler { @Override protected String doChat(Message message) { - String content = message.getContent().replace("@" + AI_NAME, "").trim(); Long uid = message.getFromUid(); try { FrequencyControlDTO frequencyControlDTO = new FrequencyControlDTO(); - frequencyControlDTO.setKey(CHAT_FREQUENCY_PREFIX + ":" + uid); + frequencyControlDTO.setKey(RedisKey.getKey(CHAT_FREQUENCY_PREFIX) + ":" + uid); frequencyControlDTO.setUnit(TimeUnit.HOURS); frequencyControlDTO.setCount(chatGPTProperties.getLimit()); frequencyControlDTO.setTime(24); - return FrequencyControlUtil.executeWithFrequencyControl(TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER, frequencyControlDTO, () -> sendRequestToGPT(new GPTRequestDTO(content, uid))); + return FrequencyControlUtil.executeWithFrequencyControl(TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER, + frequencyControlDTO, // 限流参数 + () -> sendRequestToGPT(message)); } catch (FrequencyControlException e) { return "亲爱的,你今天找我聊了" + chatGPTProperties.getLimit() + "次了~人家累了~明天见"; } catch (Throwable e) { @@ -88,20 +87,24 @@ public class GPTChatAIHandler extends AbstractChatAIHandler { } } - private String sendRequestToGPT(GPTRequestDTO gptRequestDTO) { - String content = gptRequestDTO.getContent(); - Long roomId = message.getRoomId(); - Long chatNum; + + private String sendRequestToGPT(Message message) { + ChatGPTContext context = buildContext(message);// 构建上下文 + context = tailorContext(context);// 裁剪上下文 + log.info("context = {}", context); String text; - HttpResponse response = null; try { - response = ChatGPTUtils.create(chatGPTProperties.getKey()) + Response response = ChatGPTUtils.create(chatGPTProperties.getKey()) .proxyUrl(chatGPTProperties.getProxyUrl()) .model(chatGPTProperties.getModelName()) .timeout(chatGPTProperties.getTimeout()) - .prompt(content) + .maxTokens(chatGPTProperties.getMaxTokens()) + .message(context.getMsg()) .send(); text = ChatGPTUtils.parseText(response); + ChatGPTMsg chatGPTMsg = ChatGPTMsgBuilder.assistantMsg(text); + context.addMsg(chatGPTMsg); + RedisUtils.set(RedisKey.getKey(USER_CHAT_CONTEXT, message.getFromUid(), message.getRoomId()), context, 1L, TimeUnit.HOURS); } catch (Exception e) { log.warn("gpt doChat warn:", e); text = "我累了,明天再聊吧"; @@ -119,13 +122,15 @@ public class GPTChatAIHandler extends AbstractChatAIHandler { return tailorContext(context); } - private ChatGPTContext buildContext(Message message, String prompt) { + private ChatGPTContext buildContext(Message message) { + String prompt = message.getContent().replace("@" + AI_NAME, "").trim(); Long uid = message.getFromUid(); Long roomId = message.getRoomId(); ChatGPTContext chatGPTContext = RedisUtils.get(RedisKey.getKey(USER_CHAT_CONTEXT, uid, roomId), ChatGPTContext.class); if (chatGPTContext == null) { chatGPTContext = ChatGPTContextBuilder.initContext(uid, roomId); } + RedisUtils.set(RedisKey.getKey(USER_CHAT_CONTEXT, uid, roomId), chatGPTContext, 1L, TimeUnit.HOURS); chatGPTContext.addMsg(ChatGPTMsgBuilder.userMsg(prompt)); return chatGPTContext; } diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGPTUtils.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGPTUtils.java index 13adbff..b049bda 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGPTUtils.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGPTUtils.java @@ -2,7 +2,6 @@ package com.abin.mallchat.custom.chatai.utils; import com.abin.mallchat.common.common.exception.BusinessException; import com.abin.mallchat.custom.chatai.domain.ChatGPTMsg; -import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONObject; import com.knuddels.jtokkit.Encodings; import com.knuddels.jtokkit.api.Encoding; @@ -13,10 +12,9 @@ import okhttp3.*; import org.apache.commons.lang3.StringUtils; import java.io.IOException; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; @Slf4j public class ChatGPTUtils { @@ -85,17 +83,23 @@ public class ChatGPTUtils { return parseText(response.body().string()); } + public static String parseText(String body) { - log.info("body >>> " + body); - JSONObject jsonObject = JSONObject.parseObject(body); - JSONObject error = jsonObject.getJSONObject("error"); - if (error != null) { - log.error("error >>> " + error); +// log.info("body >>> " + body); + try { + return Arrays.stream(body.split("data:")) + .map(String::trim) + .filter(x -> StringUtils.isNotBlank(x) && !"[DONE]".endsWith(x)) + .map(x -> JSONObject.parseObject(x) + .getJSONArray("choices") + .getJSONObject(0) + .getJSONObject("delta") + .getString("content") + ).filter(Objects::nonNull).collect(Collectors.joining()); + } catch (Exception e) { + log.error("parseText error e:", e); return "闹脾气了,等会再试试吧~"; } - JSONArray choices = JSONObject.parseArray(jsonObject.getString("choices")); - JSONObject choice = choices.getJSONObject(0); - return choice.getJSONObject("message").getString("content"); } public ChatGPTUtils model(String model) { @@ -144,15 +148,6 @@ public class ChatGPTUtils { } public Response send() throws IOException { -// cn.hutool.json.JSONObject param = new cn.hutool.json.JSONObject(); -// param.set("model", model); -// param.set("messages", messages); -// param.set("max_tokens", maxTokens); -// param.set("temperature", temperature); -// param.set("top_p", topP); -// param.set("frequency_penalty", frequencyPenalty); -// param.set("presence_penalty", presencePenalty); -// log.info("headers >>> " + headers); OkHttpClient okHttpClient = new OkHttpClient() .newBuilder() .connectTimeout(10, TimeUnit.SECONDS) @@ -167,7 +162,9 @@ public class ChatGPTUtils { paramMap.put("top_p", topP); paramMap.put("frequency_penalty", frequencyPenalty); paramMap.put("presence_penalty", presencePenalty); + paramMap.put("stream", true); + log.info("paramMap >>> " + JSONObject.toJSONString(paramMap)); Request request = new Request.Builder() .url(StringUtils.isNotBlank(proxyUrl) ? proxyUrl : URL) .addHeader("Content-Type", "application/json") From e486f8d0d58a2fcd6375970bb4855f8027de7202 Mon Sep 17 00:00:00 2001 From: zhaoyuhang <1045078399@qq.com> Date: Mon, 10 Jul 2023 23:17:55 +0800 Subject: [PATCH 3/3] =?UTF-8?q?feat:=E4=B8=8A=E4=B8=8B=E6=96=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../resources/application-test.properties | 19 +++++++++---------- .../src/main/resources/application.yml | 9 +++------ .../chatai/properties/ChatGLM2Properties.java | 2 +- .../chatai/properties/ChatGPTProperties.java | 2 +- 4 files changed, 14 insertions(+), 18 deletions(-) diff --git a/mallchat-common/src/main/resources/application-test.properties b/mallchat-common/src/main/resources/application-test.properties index 6dca2e4..a4967ae 100644 --- a/mallchat-common/src/main/resources/application-test.properties +++ b/mallchat-common/src/main/resources/application-test.properties @@ -4,7 +4,7 @@ mallchat.mysql.ip=127.0.0.1 mallchat.mysql.port=3306 mallchat.mysql.db=mallchat mallchat.mysql.username=root -mallchat.mysql.password=root +mallchat.mysql.password=123456 ##################redis配置################## mallchat.redis.host=127.0.0.1 mallchat.redis.port=6379 @@ -12,9 +12,9 @@ mallchat.redis.password=123456 ##################jwt################## mallchat.jwt.secret=dsfsdfsdfsdfsd ##################微信公众号信息################## -mallchat.wx.callback=http://vastmiao.natapp1.cc -mallchat.wx.appId=wxcf8d045747fb2ae4 -mallchat.wx.secret=e484463d627787f50a8cc3a869cf82a8 +mallchat.wx.callback=http://127.0.0.1:8080 +mallchat.wx.appId=appid +mallchat.wx.secret=380bfc1c9147fdsf4sf07 # 接口配置里的Token值 mallchat.wx.token=sdfsf # 接口配置里的EncodingAESKey值 @@ -27,11 +27,10 @@ oss.access-key=BEZ213 oss.secret-key=Ii4vCMIXuFe241dsfEZ8e7RXI2342342kV oss.bucketName=default ##################gpt配置################## -mallchat.chatgpt.use=true -mallchat.chatgpt.uid=10451 -mallchat.chatgpt.modelName=gpt-3.5-turbo -mallchat.chatgpt.key=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 +mallchat.chatgpt.use=false +mallchat.chatgpt.uid=10001 +mallchat.chatgpt.key=sk-wvWM0xGcxFfsddfsgxixbXK5tHovM mallchat.chatgpt.proxyUrl=https://123.cc -mallchat.chatglm2.use=true +mallchat.chatglm2.use=false mallchat.chatglm2.url=http://v32134.cc -mallchat.chatglm2.uid=10452 \ No newline at end of file +mallchat.chatglm2.uid=10002 \ No newline at end of file diff --git a/mallchat-common/src/main/resources/application.yml b/mallchat-common/src/main/resources/application.yml index 656a1c6..2f2ab12 100644 --- a/mallchat-common/src/main/resources/application.yml +++ b/mallchat-common/src/main/resources/application.yml @@ -12,7 +12,7 @@ mybatis-plus: spring: profiles: #运行的环境 - active: test + active: my-prod application: name: mallchat datasource: @@ -38,7 +38,7 @@ spring: # 连接超时时间 timeout: 1800000 # 设置密码 -# password: ${mallchat.redis.password} + password: ${mallchat.redis.password} lettuce: pool: # 最大阻塞等待时间,负数表示没有限制 @@ -68,12 +68,9 @@ chatai: use: ${mallchat.chatgpt.use} AIUserId: ${mallchat.chatgpt.uid} key: ${mallchat.chatgpt.key} -# proxyUrl: ${mallchat.chatgpt.proxyUrl} - context: ${mallchat.chatgpt.context} - modelName: ${mallchat.chatgpt.modelName} + proxyUrl: ${mallchat.chatgpt.proxyUrl} chatglm2: use: ${mallchat.chatglm2.use} url: ${mallchat.chatglm2.url} minute: 3 # 每个用户每3分钟可以请求一次 AIUserId: ${mallchat.chatglm2.uid} - context: ${mallchat.chatglm2.context} diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/properties/ChatGLM2Properties.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/properties/ChatGLM2Properties.java index 7c050b4..f231643 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/properties/ChatGLM2Properties.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/properties/ChatGLM2Properties.java @@ -18,7 +18,7 @@ public class ChatGLM2Properties { /** * 使用 */ - private boolean use; + private boolean use = false; /** * url diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/properties/ChatGPTProperties.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/properties/ChatGPTProperties.java index 4bb97ad..cfb6de7 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/properties/ChatGPTProperties.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/properties/ChatGPTProperties.java @@ -13,7 +13,7 @@ public class ChatGPTProperties { /** * 是否使用openAI */ - private boolean use; + private boolean use = false; /** * 机器人 id */