From fb90e98a4e13c18871a3882fcd8fb82ca100da26 Mon Sep 17 00:00:00 2001 From: zhaoyuhang <1045078399@qq.com> Date: Thu, 29 Jun 2023 23:52:48 +0800 Subject: [PATCH 1/5] openAI --- .../common/common/constant/RedisKey.java | 5 + .../common/common/utils/DateUtils.java | 16 ++ .../src/main/resources/application.yml | 13 +- .../chat/service/impl/ChatServiceImpl.java | 2 + .../service/strategy/msg/TextMsgHandler.java | 2 +- .../custom/openai/enums/OpenAIModelEnums.java | 92 ++++++++++ .../custom/openai/event/OpenAIEvent.java | 14 ++ .../openai/event/listener/OpenAIListener.java | 63 +++++++ .../custom/openai/service/IOpenAIService.java | 11 ++ .../service/impl/OpenAIServiceImpl.java | 171 ++++++++++++++++++ .../custom/openai/utils/OpenAIUtils.java | 158 ++++++++++++++++ 11 files changed, 542 insertions(+), 5 deletions(-) create mode 100644 mallchat-common/src/main/java/com/abin/mallchat/common/common/utils/DateUtils.java create mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/enums/OpenAIModelEnums.java create mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/event/OpenAIEvent.java create mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/event/listener/OpenAIListener.java create mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/service/IOpenAIService.java create mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/service/impl/OpenAIServiceImpl.java create mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/utils/OpenAIUtils.java diff --git a/mallchat-common/src/main/java/com/abin/mallchat/common/common/constant/RedisKey.java b/mallchat-common/src/main/java/com/abin/mallchat/common/common/constant/RedisKey.java index faa6bd2..a86d1c4 100644 --- a/mallchat-common/src/main/java/com/abin/mallchat/common/common/constant/RedisKey.java +++ b/mallchat-common/src/main/java/com/abin/mallchat/common/common/constant/RedisKey.java @@ -36,6 +36,11 @@ public class RedisKey { */ public static final String USER_SUMMARY_STRING = "userSummary:uid_%d"; + /** + * 用户AI聊天次数 + */ + public static final String USER_CHAT_NUM = "userAIChatNum:uid_%d"; + public static String getKey(String key, Object... objects) { return BASE_KEY + String.format(key, objects); } diff --git a/mallchat-common/src/main/java/com/abin/mallchat/common/common/utils/DateUtils.java b/mallchat-common/src/main/java/com/abin/mallchat/common/common/utils/DateUtils.java new file mode 100644 index 0000000..ecc2a59 --- /dev/null +++ b/mallchat-common/src/main/java/com/abin/mallchat/common/common/utils/DateUtils.java @@ -0,0 +1,16 @@ +package com.abin.mallchat.common.common.utils; + +import java.util.Calendar; +import java.util.Date; + +public class DateUtils extends org.apache.commons.lang3.time.DateUtils { + public static Long getEndTimeByToday() { + Calendar instance = Calendar.getInstance(); + Date now = new Date(); + instance.setTime(now); + instance.set(Calendar.HOUR_OF_DAY, 23); + instance.set(Calendar.MINUTE, 59); + instance.set(Calendar.SECOND, 59); + return instance.getTime().getTime() - now.getTime(); + } +} diff --git a/mallchat-common/src/main/resources/application.yml b/mallchat-common/src/main/resources/application.yml index 53b6b9b..0293bb6 100644 --- a/mallchat-common/src/main/resources/application.yml +++ b/mallchat-common/src/main/resources/application.yml @@ -12,7 +12,7 @@ mybatis-plus: spring: profiles: #运行的环境 - active: my-prod + active: test application: name: mallchat datasource: @@ -37,8 +37,8 @@ spring: database: 0 # 连接超时时间 timeout: 1800000 - # 设置密码 - password: ${mallchat.redis.password} +# # 设置密码 +# password: ${mallchat.redis.password} lettuce: pool: # 最大阻塞等待时间,负数表示没有限制 @@ -62,4 +62,9 @@ wx: - appId: ${mallchat.wx.appId} # 第一个公众号的appid secret: ${mallchat.wx.secret} # 公众号的appsecret token: ${mallchat.wx.token} # 接口配置里的Token值 - aesKey: ${mallchat.wx.aesKey} # 接口配置里的EncodingAESKey值 \ No newline at end of file + aesKey: ${mallchat.wx.aesKey} # 接口配置里的EncodingAESKey值 +openai: + use-openai: true + ai-user-id: xxxxx + key: xxxxxxx + proxy-url: https://xxxxxxx \ No newline at end of file diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/service/impl/ChatServiceImpl.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/service/impl/ChatServiceImpl.java index adbb1d3..177dbfa 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/service/impl/ChatServiceImpl.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/service/impl/ChatServiceImpl.java @@ -38,6 +38,7 @@ import com.abin.mallchat.custom.chat.service.strategy.mark.MsgMarkFactory; import com.abin.mallchat.custom.chat.service.strategy.msg.AbstractMsgHandler; import com.abin.mallchat.custom.chat.service.strategy.msg.MsgHandlerFactory; import com.abin.mallchat.custom.chat.service.strategy.msg.RecallMsgHandler; +import com.abin.mallchat.custom.openai.event.OpenAIEvent; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Autowired; @@ -95,6 +96,7 @@ public class ChatServiceImpl implements ChatService { msgHandler.saveMsg(insert, request); //发布消息发送事件 applicationEventPublisher.publishEvent(new MessageSendEvent(this, insert.getId())); + applicationEventPublisher.publishEvent(new OpenAIEvent(this, insert.getId())); return insert.getId(); } diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/service/strategy/msg/TextMsgHandler.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/service/strategy/msg/TextMsgHandler.java index 461ed88..5bf8b67 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/service/strategy/msg/TextMsgHandler.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/service/strategy/msg/TextMsgHandler.java @@ -66,7 +66,7 @@ public class TextMsgHandler extends AbstractMsgHandler { AssertUtil.equal(replyMsg.getRoomId(), request.getRoomId(), "只能回复相同会话内的消息"); } if (CollectionUtil.isNotEmpty(body.getAtUidList())) { - AssertUtil.isTrue(body.getAtUidList().size() > 10, "一次别艾特这么多人"); + AssertUtil.isFalse(body.getAtUidList().size() > 10, "一次别艾特这么多人"); List atUidList = body.getAtUidList(); Map batch = userInfoCache.getBatch(atUidList); AssertUtil.equal(atUidList.size(), batch.values().size(), "@用户不存在"); diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/enums/OpenAIModelEnums.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/enums/OpenAIModelEnums.java new file mode 100644 index 0000000..194c3df --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/enums/OpenAIModelEnums.java @@ -0,0 +1,92 @@ +package com.abin.mallchat.custom.openai.enums; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +import java.util.Arrays; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +@AllArgsConstructor +@Getter +public enum OpenAIModelEnums { + // chat + GPT_35_TURBO("gpt-3.5-turbo", 3, 40000), + GPT_35_TURBO_0301("gpt-3.5-turbo-0301", 3, 40000), + GPT_35_TURBO_0613("gpt-3.5-turbo-0613", 3, 40000), + GPT_35_TURBO_16K("gpt-3.5-turbo-16k", 3, 40000), + GPT_35_TURBO_16K_0613("gpt-3.5-turbo-16k-0613", 3, 40000), + // text + ADA("ada", 60, 150000), + ADA_CODE_SEARCH_CODE("ada-code-search-code", 60, 150000), + ADA_CODE_SEARCH_TEXT("ada-code-search-text", 60, 150000), + ADA_SEARCH_DOCUMENT("ada-search-document", 60, 150000), + ADA_SEARCH_QUERY("ada-search-query", 60, 150000), + ADA_SIMILARITY("ada-similarity", 60, 150000), + BABBAGE("babbage", 60, 150000), + BABBAGE_CODE_SEARCH_CODE("babbage-code-search-code", 60, 150000), + BABBAGE_CODE_SEARCH_TEXT("babbage-code-search-text", 60, 150000), + BABBAGE_SEARCH_DOCUMENT("babbage-search-document", 60, 150000), + BABBAGE_SEARCH_QUERY("babbage-search-query", 60, 150000), + BABBAGE_SIMILARITY("babbage-similarity", 60, 150000), + CODE_DAVINCI_EDIT_001("code-davinci-edit-001", 20, 150000), + CODE_SEARCH_ADA_CODE_001("code-search-ada-code-001", 60, 150000), + CODE_SEARCH_ADA_TEXT_001("code-search-ada-text-001", 60, 150000), + CODE_SEARCH_BABBAGE_CODE_001("code-search-babbage-code-001", 60, 150000), + CODE_SEARCH_BABBAGE_TEXT_001("code-search-babbage-text-001", 60, 150000), + CURIE("curie", 60, 150000), + CURIE_INSTRUCT_BETA("curie-instruct-beta", 60, 150000), + CURIE_SEARCH_DOCUMENT("curie-search-document", 60, 150000), + CURIE_SEARCH_QUERY("curie-search-query", 60, 150000), + CURIE_SIMILARITY("curie-similarity", 60, 150000), + DAVINCI("davinci", 60, 150000), + DAVINCI_INSTRUCT_BETA("davinci-instruct-beta", 60, 150000), + DAVINCI_SEARCH_DOCUMENT("davinci-search-document", 60, 150000), + DAVINCI_SEARCH_QUERY("davinci-search-query", 60, 150000), + DAVINCI_SIMILARITY("davinci-similarity", 60, 150000), + TEXT_ADA_001("text-ada-001", 60, 150000), + TEXT_BABBAGE_001("text-babbage-001", 60, 150000), + TEXT_CURIE_001("text-curie-001", 60, 150000), + TEXT_DAVINCI_001("text-davinci-001", 60, 150000), + TEXT_DAVINCI_002("text-davinci-002", 60, 150000), + TEXT_DAVINCI_003("text-davinci-003", 60, 150000), + TEXT_DAVINCI_EDIT_001("text-davinci-edit-001", 20, 150000), + TEXT_EMBEDDING_ADA_002("text-embedding-ada-002", 60, 150000), + TEXT_SEARCH_ADA_DOC_001("text-search-ada-doc-001", 60, 150000), + TEXT_SEARCH_ADA_QUERY_001("text-search-ada-query-001", 60, 150000), + TEXT_SEARCH_BABBAGE_DOC_001("text-search-babbage-doc-001", 60, 150000), + TEXT_SEARCH_BABBAGE_QUERY_001("text-search-babbage-query-001", 60, 150000), + TEXT_SEARCH_CURIE_DOC_001("text-search-curie-doc-001", 60, 150000), + TEXT_SEARCH_CURIE_QUERY_001("text-search-curie-query-001", 60, 150000), + TEXT_SEARCH_DAVINCI_DOC_001("text-search-davinci-doc-001", 60, 150000), + TEXT_SEARCH_DAVINCI_QUERY_001("text-search-davinci-query-001", 60, 150000), + TEXT_SIMILARITY_ADA_001("text-similarity-ada-001", 60, 150000), + TEXT_SIMILARITY_BABBAGE_001("text-similarity-babbage-001", 60, 150000), + TEXT_SIMILARITY_CURIE_001("text-similarity-curie-001", 60, 150000), + TEXT_SIMILARITY_DAVINCI_001("text-similarity-davinci-001", 60, 150000); + + /** + * 名字 + */ + private final String name; + /** + * 每分钟请求数 + */ + private final Integer RPM; + /** + * 每分钟令牌数 + */ + private final Integer TPM; + + private static final Map cache; + + static { + cache = Arrays.stream(OpenAIModelEnums.values()).collect(Collectors.toMap(OpenAIModelEnums::getName, Function.identity())); + } + + public static OpenAIModelEnums of(String name) { + return cache.get(name); + } + +} diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/event/OpenAIEvent.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/event/OpenAIEvent.java new file mode 100644 index 0000000..493d54e --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/event/OpenAIEvent.java @@ -0,0 +1,14 @@ +package com.abin.mallchat.custom.openai.event; + +import lombok.Getter; +import org.springframework.context.ApplicationEvent; + +@Getter +public class OpenAIEvent extends ApplicationEvent { + private Long msgId; + + public OpenAIEvent(Object source, Long msgId) { + super(source); + this.msgId = msgId; + } +} \ No newline at end of file diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/event/listener/OpenAIListener.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/event/listener/OpenAIListener.java new file mode 100644 index 0000000..160a1f1 --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/event/listener/OpenAIListener.java @@ -0,0 +1,63 @@ +package com.abin.mallchat.custom.openai.event.listener; + +import com.abin.mallchat.common.chat.dao.MessageDao; +import com.abin.mallchat.common.chat.domain.entity.Message; +import com.abin.mallchat.custom.openai.event.OpenAIEvent; +import com.abin.mallchat.custom.openai.service.IOpenAIService; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.jetbrains.annotations.NotNull; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; +import org.springframework.transaction.event.TransactionalEventListener; + +import static com.abin.mallchat.custom.openai.service.impl.OpenAIServiceImpl.MALL_CHAT_AI_NAME; + +/** + * 是否AI回复监听器 + * + * @author zhaoyuhang + * @date 2023/06/29 + */ +@Slf4j +@Component +public class OpenAIListener { + @Autowired + private IOpenAIService openAIService; + @Autowired + private MessageDao messageDao; + + @TransactionalEventListener(classes = OpenAIEvent.class, fallbackExecution = true) + public void notifyAllOnline(@NotNull OpenAIEvent event) { + Message message = messageDao.getById(event.getMsgId()); + if (ATedAI(message)) { + openAIService.chat(message); + } + } + + /** + * @return boolean + * @了AI + */ + private boolean ATedAI(Message message) { + /* 前端传@信息后取消注释 */ + +// MessageExtra extra = message.getExtra(); +// if (extra == null) { +// return false; +// } +// if (CollectionUtils.isEmpty(extra.getAtUidList())) { +// return false; +// } +// if (!extra.getAtUidList().contains(OpenAIServiceImpl.AI_USER_ID)) { +// return false; +// } + + if (StringUtils.isBlank(message.getContent())) { + return false; + } + return StringUtils.contains(message.getContent(), "@" + MALL_CHAT_AI_NAME) + && StringUtils.isNotBlank(message.getContent().replace(MALL_CHAT_AI_NAME, "").trim()); + } + +} diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/service/IOpenAIService.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/service/IOpenAIService.java new file mode 100644 index 0000000..0216a0b --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/service/IOpenAIService.java @@ -0,0 +1,11 @@ +package com.abin.mallchat.custom.openai.service; + +import com.abin.mallchat.common.chat.domain.entity.Message; +import com.abin.mallchat.custom.chat.domain.vo.request.ChatMessageReq; + +public interface IOpenAIService { + + + void chat(ChatMessageReq chatMessageReq, Long uid); + void chat(Message message); +} diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/service/impl/OpenAIServiceImpl.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/service/impl/OpenAIServiceImpl.java new file mode 100644 index 0000000..4924963 --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/service/impl/OpenAIServiceImpl.java @@ -0,0 +1,171 @@ +package com.abin.mallchat.custom.openai.service.impl; + +import cn.hutool.core.bean.BeanUtil; +import cn.hutool.core.thread.NamedThreadFactory; +import cn.hutool.http.HttpResponse; +import cn.hutool.http.HttpUtil; +import com.abin.mallchat.common.chat.domain.entity.Message; +import com.abin.mallchat.common.chat.domain.enums.MessageTypeEnum; +import com.abin.mallchat.common.common.constant.RedisKey; +import com.abin.mallchat.common.common.exception.BusinessException; +import com.abin.mallchat.common.common.handler.GlobalUncaughtExceptionHandler; +import com.abin.mallchat.common.common.utils.DateUtils; +import com.abin.mallchat.common.common.utils.RedisUtils; +import com.abin.mallchat.custom.chat.domain.vo.request.ChatMessageReq; +import com.abin.mallchat.custom.chat.domain.vo.request.msg.TextMsgReq; +import com.abin.mallchat.custom.chat.service.ChatService; +import com.abin.mallchat.custom.openai.enums.OpenAIModelEnums; +import com.abin.mallchat.custom.openai.service.IOpenAIService; +import com.abin.mallchat.custom.openai.utils.OpenAIUtils; +import com.abin.mallchat.custom.user.domain.vo.response.user.UserInfoResp; +import com.abin.mallchat.custom.user.service.UserService; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.DisposableBean; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.annotation.Description; +import org.springframework.context.annotation.Lazy; +import org.springframework.stereotype.Service; + +import java.util.Collections; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +@Slf4j +@Service +public class OpenAIServiceImpl implements IOpenAIService, DisposableBean, InitializingBean { + private static ExecutorService EXECUTOR; + + @Value("${openai.use-openai:false}") + private boolean USE_OPENAI; + @Value("${openai.ai-user-id}") + public Long AI_USER_ID; + + @Value("${openai.model.name:text-davinci-003}") + private String modelName; + @Value("${openai.key}") + private String key; + @Value("${openai.proxy-url:}") + private String proxyUrl; + + @Value("${openai.limit:5}") + private Integer limit; + + @Autowired + private UserService userService; + @Lazy + @Autowired + private ChatService chatService; + + public static String MALL_CHAT_AI_NAME; + + /** + * 聊天 + * + * @param chatMessageReq 提示词 + * @param uid 用户id + */ + @Deprecated + @Override + public void chat(ChatMessageReq chatMessageReq, Long uid) { + TextMsgReq body = BeanUtil.toBean(chatMessageReq.getBody(), TextMsgReq.class); + String content = body.getContent().replace(MALL_CHAT_AI_NAME, "").trim(); + EXECUTOR.execute(() -> { + Long chatNum; + if ((chatNum = userChatNumInrc(uid)) > limit) { + answerMsg("你今天已经和我聊了" + chatNum + "次了,我累了,明天再聊吧", chatMessageReq.getRoomId(), uid); + } else { + chat(content, chatMessageReq.getRoomId(), uid); + } + }); + + } + + @Override + public void chat(Message message) { + String content = message.getContent().replace(MALL_CHAT_AI_NAME, "").trim(); + Long roomId = message.getRoomId(); + Long uid = message.getFromUid(); + EXECUTOR.execute(() -> { + Long chatNum; + if ((chatNum = userChatNumInrc(uid)) > limit) { + answerMsg("你今天已经和我聊了" + chatNum + "次了,我累了,明天再聊吧", roomId, uid); + } else { + chat(content, roomId, uid); + } + }); + + } + + private Long userChatNumInrc(Long uid) { + //todo:白名单 + return RedisUtils.inc(RedisKey.getKey(RedisKey.USER_CHAT_NUM, uid), DateUtils.getEndTimeByToday().intValue(), TimeUnit.MILLISECONDS); + } + + private void chat(String content, Long roomId, Long uid) { + HttpResponse response = OpenAIUtils.create(key) + .proxyUrl(proxyUrl) + .model(modelName) + .prompt(content) + .send(); + String text = OpenAIUtils.parseText(response); + answerMsg(text, roomId, uid); + } + + private void answerMsg(String text, Long roomId, Long uid) { + ChatMessageReq answerReq = new ChatMessageReq(); + answerReq.setRoomId(roomId); + answerReq.setMsgType(MessageTypeEnum.TEXT.getType()); + UserInfoResp userInfo = userService.getUserInfo(uid); + TextMsgReq textMsgReq = new TextMsgReq(); + textMsgReq.setContent("@" + userInfo.getName() + " " + text); + textMsgReq.setAtUidList(Collections.singletonList(uid)); + answerReq.setBody(textMsgReq); + chatService.sendMsg(answerReq, AI_USER_ID); + } + + + @Override + public void afterPropertiesSet() { + if (!USE_OPENAI) { + return; + } + if (StringUtils.isNotBlank(proxyUrl) && !HttpUtil.isHttp(proxyUrl) && !HttpUtil.isHttps(proxyUrl)) { + throw new BusinessException("openai.proxy-url 配置错误"); + } + OpenAIModelEnums modelEnum = OpenAIModelEnums.of(modelName); + if (modelEnum == null) { + throw new BusinessException("openai.model.name 配置错误"); + } + Integer rpm = modelEnum.getRPM(); + EXECUTOR = new ThreadPoolExecutor(10, 10, + 0L, TimeUnit.MILLISECONDS, + new LinkedBlockingQueue<>(rpm), + new NamedThreadFactory("openAI-chat-gpt", + null, + false, + new GlobalUncaughtExceptionHandler()), + (r, executor) -> { + throw new BusinessException("别问的太快了,我的脑子不够用了"); + }); + UserInfoResp userInfo = userService.getUserInfo(AI_USER_ID); + if (userInfo == null) { + throw new BusinessException("openai.ai-user-id 配置错误"); + } + MALL_CHAT_AI_NAME = userInfo.getName(); + } + + @Override + public void destroy() throws Exception { + EXECUTOR.shutdown(); + if (!EXECUTOR.awaitTermination(30, TimeUnit.SECONDS)) { //最多等30秒,处理不完就拉倒 + if (log.isErrorEnabled()) { + log.error("Timed out while waiting for executor [{}] to terminate", EXECUTOR); + } + } + } +} \ No newline at end of file diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/utils/OpenAIUtils.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/utils/OpenAIUtils.java new file mode 100644 index 0000000..dee0570 --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/utils/OpenAIUtils.java @@ -0,0 +1,158 @@ +package com.abin.mallchat.custom.openai.utils; + +import cn.hutool.http.HttpResponse; +import cn.hutool.http.HttpUtil; +import cn.hutool.json.JSONArray; +import cn.hutool.json.JSONObject; +import com.abin.mallchat.common.common.exception.BusinessException; +import org.apache.commons.lang3.StringUtils; + +import java.util.HashMap; +import java.util.Map; + +public class OpenAIUtils { + + private static final String URL = "https://api.openai.com/v1/completions"; + + private String model = "text-davinci-003"; + + private final Map headers; + /** + * 超时30秒 + */ + private Integer timeout = 30 * 1000; + /** + * 参数用于指定生成文本的最大长度。 + * 它表示生成的文本中最多包含多少个 token。一个 token 可以是一个单词、一个标点符号或一个空格。 + */ + private int maxTokens = 2048; + /** + * 用于控制生成文本的多样性。 + * 较高的温度会导致更多的随机性和多样性,但可能会降低生成文本的质量。默认值为 1,建议在 0.7 到 1.3 之间调整。 + */ + private Object temperature = 1; + /** + * 用于控制生成文本的多样性。 + * 它会根据概率选择最高的几个单词,而不是选择概率最高的单词。默认值为 1,建议在 0.7 到 0.9 之间调整。 + */ + private Object topP = 0.9; + /** + * 用于控制生成文本中重复单词的数量。 + * 较高的惩罚值会导致更少的重复单词,但可能会降低生成文本的流畅性。默认值为 0,建议在 0 到 2 之间调整。 + */ + private Object frequencyPenalty = 0.0; + /** + * 用于控制生成文本中出现特定单词的数量。 + * 较高的惩罚值会导致更少的特定单词,但可能会降低生成文本的流畅性。默认值为 0,建议在 0 到 2 之间调整。 + */ + private Object presencePenalty = 0.6; + + /** + * 提示词 + */ + private String prompt; + + private String proxyUrl; + + public OpenAIUtils(String key) { + HashMap _headers_ = new HashMap<>(); + _headers_.put("Content-Type", "application/json"); + if (StringUtils.isBlank(key)) { + throw new BusinessException("openAi key is blank"); + } + _headers_.put("Authorization", "Bearer " + key); + this.headers = _headers_; + } + + public static OpenAIUtils create(String key) { + return new OpenAIUtils(key); + } + + public static String parseText(HttpResponse response) { + return parseText(response.body()); + } + + public static String parseText(String body) { + JSONObject jsonObj = new JSONObject(body); + JSONArray choicesArr = jsonObj.getJSONArray("choices"); + JSONObject choiceObj = choicesArr.getJSONObject(0); + return choiceObj.getStr("text"); + } + + public OpenAIUtils model(String model) { + this.model = model; + return this; + } + + public OpenAIUtils timeout(int timeout) { + this.timeout = timeout; + return this; + } + + public OpenAIUtils maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + public OpenAIUtils temperature(int temperature) { + this.temperature = temperature; + return this; + } + + public OpenAIUtils topP(int topP) { + this.topP = topP; + return this; + } + + public OpenAIUtils frequencyPenalty(int frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + + public OpenAIUtils presencePenalty(int presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + public OpenAIUtils prompt(String prompt) { + this.prompt = prompt; + return this; + } + + public OpenAIUtils proxyUrl(String proxyUrl) { + this.proxyUrl = proxyUrl; + return this; + } + + public HttpResponse send() { + JSONObject param = new JSONObject(); + param.set("model", model); + param.set("prompt", prompt); + param.set("max_tokens", maxTokens); + param.set("temperature", temperature); + param.set("top_p", topP); + param.set("frequency_penalty", frequencyPenalty); + param.set("presence_penalty", presencePenalty); + return HttpUtil.createPost(StringUtils.isNotBlank(proxyUrl) ? proxyUrl : URL) + .addHeaders(headers) + .body(param.toString()) + .timeout(timeout) + .execute(); + } + + public static void main(String[] args) { + HttpResponse send = OpenAIUtils.create("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") + .timeout(30 * 1000) + .prompt("Spring的启动流程是什么") + .send(); + System.out.println("send = " + send); + // JSON 数据 + // JSON 数据 + JSONObject jsonObj = new JSONObject(send.body()); + JSONArray choicesArr = jsonObj.getJSONArray("choices"); + JSONObject choiceObj = choicesArr.getJSONObject(0); + String text = choiceObj.getStr("text"); + System.out.println("text = " + text); + + } +} \ No newline at end of file From c06934bc896aadfdce5d32d2e9386092001d7772 Mon Sep 17 00:00:00 2001 From: zhaoyuhang <1045078399@qq.com> Date: Fri, 30 Jun 2023 00:03:21 +0800 Subject: [PATCH 2/5] openAI --- .../abin/mallchat/custom/openai/utils/OpenAIUtils.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/utils/OpenAIUtils.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/utils/OpenAIUtils.java index dee0570..f01c8a8 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/utils/OpenAIUtils.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/utils/OpenAIUtils.java @@ -5,11 +5,13 @@ import cn.hutool.http.HttpUtil; import cn.hutool.json.JSONArray; import cn.hutool.json.JSONObject; import com.abin.mallchat.common.common.exception.BusinessException; +import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import java.util.HashMap; import java.util.Map; +@Slf4j public class OpenAIUtils { private static final String URL = "https://api.openai.com/v1/completions"; @@ -73,7 +75,13 @@ public class OpenAIUtils { } public static String parseText(String body) { + log.info("body >>> " + body); JSONObject jsonObj = new JSONObject(body); + JSONObject error = jsonObj.getJSONObject("error"); + if (error != null) { + log.error("error >>> " + error); + return "闹脾气了,等会再试试吧~"; + } JSONArray choicesArr = jsonObj.getJSONArray("choices"); JSONObject choiceObj = choicesArr.getJSONObject(0); return choiceObj.getStr("text"); From 440159e8a06feb3bcfe5f781e7b8451380a8ff9c Mon Sep 17 00:00:00 2001 From: zhaoyuhang <1045078399@qq.com> Date: Sat, 1 Jul 2023 13:42:10 +0800 Subject: [PATCH 3/5] chatAI-handler --- .../common/common/constant/RedisKey.java | 9 +- .../src/main/resources/application.yml | 15 +- .../domain/vo/request/ChatMessageReq.java | 4 + .../chat/service/impl/ChatServiceImpl.java | 2 - .../enums/ChatGPTModelEnum.java} | 10 +- .../chatai/handler/AbstractChatAIHandler.java | 130 +++++++++++++ .../chatai/handler/ChatAIHandlerFactory.java | 42 +++++ .../chatai/handler/ChatGLM2Handler.java | 136 ++++++++++++++ .../chatai/handler/GPTChatAIHandler.java | 86 +++++++++ .../chatai/properties/ChatGLM2Properties.java | 43 +++++ .../chatai/properties/ChatGPTProperties.java | 44 +++++ .../custom/chatai/service/IChatAIService.java | 8 + .../service/impl/ChatAIServiceImpl.java | 26 +++ .../custom/chatai/utils/ChatGLM2Utils.java | 107 +++++++++++ .../utils/ChatGPTUtils.java} | 30 +-- .../event/listener/MessageSendListener.java | 10 + .../custom/openai/event/OpenAIEvent.java | 14 -- .../openai/event/listener/OpenAIListener.java | 63 ------- .../custom/openai/service/IOpenAIService.java | 11 -- .../service/impl/OpenAIServiceImpl.java | 171 ------------------ 20 files changed, 673 insertions(+), 288 deletions(-) rename mallchat-custom-server/src/main/java/com/abin/mallchat/custom/{openai/enums/OpenAIModelEnums.java => chatai/enums/ChatGPTModelEnum.java} (92%) create mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/AbstractChatAIHandler.java create mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/ChatAIHandlerFactory.java create mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/ChatGLM2Handler.java create mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/GPTChatAIHandler.java create mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/properties/ChatGLM2Properties.java create mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/properties/ChatGPTProperties.java create mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/service/IChatAIService.java create mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/service/impl/ChatAIServiceImpl.java create mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGLM2Utils.java rename mallchat-custom-server/src/main/java/com/abin/mallchat/custom/{openai/utils/OpenAIUtils.java => chatai/utils/ChatGPTUtils.java} (85%) delete mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/event/OpenAIEvent.java delete mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/event/listener/OpenAIListener.java delete mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/service/IOpenAIService.java delete mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/service/impl/OpenAIServiceImpl.java diff --git a/mallchat-common/src/main/java/com/abin/mallchat/common/common/constant/RedisKey.java b/mallchat-common/src/main/java/com/abin/mallchat/common/common/constant/RedisKey.java index a86d1c4..0cc9bd0 100644 --- a/mallchat-common/src/main/java/com/abin/mallchat/common/common/constant/RedisKey.java +++ b/mallchat-common/src/main/java/com/abin/mallchat/common/common/constant/RedisKey.java @@ -37,9 +37,14 @@ public class RedisKey { public static final String USER_SUMMARY_STRING = "userSummary:uid_%d"; /** - * 用户AI聊天次数 + * 用户GPT聊天次数 */ - public static final String USER_CHAT_NUM = "userAIChatNum:uid_%d"; + public static final String USER_CHAT_NUM = "useChatGPTNum:uid_%d"; + + /** + * 用户上次使用GLM使用时间 + */ + public static final String USER_GLM2_TIME_LAST = "userGLM2UseTime:uid_%d"; public static String getKey(String key, Object... objects) { return BASE_KEY + String.format(key, objects); diff --git a/mallchat-common/src/main/resources/application.yml b/mallchat-common/src/main/resources/application.yml index 0293bb6..e0f925e 100644 --- a/mallchat-common/src/main/resources/application.yml +++ b/mallchat-common/src/main/resources/application.yml @@ -63,8 +63,13 @@ wx: secret: ${mallchat.wx.secret} # 公众号的appsecret token: ${mallchat.wx.token} # 接口配置里的Token值 aesKey: ${mallchat.wx.aesKey} # 接口配置里的EncodingAESKey值 -openai: - use-openai: true - ai-user-id: xxxxx - key: xxxxxxx - proxy-url: https://xxxxxxx \ No newline at end of file +chatai: + chatgpt: + use: true + AIUserId: 10450 + key: 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 + chatglm2: + use: true + url: http://vastmiao.natapp1.cc + minute: 3 # 每个用户每3分钟可以请求一次 + AIUserId: 10451 diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/domain/vo/request/ChatMessageReq.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/domain/vo/request/ChatMessageReq.java index a6a6095..2fd2040 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/domain/vo/request/ChatMessageReq.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/domain/vo/request/ChatMessageReq.java @@ -11,9 +11,13 @@ import javax.validation.constraints.NotNull; /** + * 聊天信息点播 * Description: 消息发送请求体 * Author: abin * Date: 2023-03-23 + * + * @author zhaoyuhang + * @date 2023/06/30 */ @Data @Builder diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/service/impl/ChatServiceImpl.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/service/impl/ChatServiceImpl.java index 177dbfa..adbb1d3 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/service/impl/ChatServiceImpl.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/service/impl/ChatServiceImpl.java @@ -38,7 +38,6 @@ import com.abin.mallchat.custom.chat.service.strategy.mark.MsgMarkFactory; import com.abin.mallchat.custom.chat.service.strategy.msg.AbstractMsgHandler; import com.abin.mallchat.custom.chat.service.strategy.msg.MsgHandlerFactory; import com.abin.mallchat.custom.chat.service.strategy.msg.RecallMsgHandler; -import com.abin.mallchat.custom.openai.event.OpenAIEvent; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Autowired; @@ -96,7 +95,6 @@ public class ChatServiceImpl implements ChatService { msgHandler.saveMsg(insert, request); //发布消息发送事件 applicationEventPublisher.publishEvent(new MessageSendEvent(this, insert.getId())); - applicationEventPublisher.publishEvent(new OpenAIEvent(this, insert.getId())); return insert.getId(); } diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/enums/OpenAIModelEnums.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/enums/ChatGPTModelEnum.java similarity index 92% rename from mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/enums/OpenAIModelEnums.java rename to mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/enums/ChatGPTModelEnum.java index 194c3df..6b73b0f 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/enums/OpenAIModelEnums.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/enums/ChatGPTModelEnum.java @@ -1,4 +1,4 @@ -package com.abin.mallchat.custom.openai.enums; +package com.abin.mallchat.custom.chatai.enums; import lombok.AllArgsConstructor; import lombok.Getter; @@ -10,7 +10,7 @@ import java.util.stream.Collectors; @AllArgsConstructor @Getter -public enum OpenAIModelEnums { +public enum ChatGPTModelEnum { // chat GPT_35_TURBO("gpt-3.5-turbo", 3, 40000), GPT_35_TURBO_0301("gpt-3.5-turbo-0301", 3, 40000), @@ -79,13 +79,13 @@ public enum OpenAIModelEnums { */ private final Integer TPM; - private static final Map cache; + private static final Map cache; static { - cache = Arrays.stream(OpenAIModelEnums.values()).collect(Collectors.toMap(OpenAIModelEnums::getName, Function.identity())); + cache = Arrays.stream(ChatGPTModelEnum.values()).collect(Collectors.toMap(ChatGPTModelEnum::getName, Function.identity())); } - public static OpenAIModelEnums of(String name) { + public static ChatGPTModelEnum of(String name) { return cache.get(name); } diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/AbstractChatAIHandler.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/AbstractChatAIHandler.java new file mode 100644 index 0000000..04e5a6e --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/AbstractChatAIHandler.java @@ -0,0 +1,130 @@ +package com.abin.mallchat.custom.chatai.handler; + +import cn.hutool.core.thread.NamedThreadFactory; +import com.abin.mallchat.common.chat.domain.entity.Message; +import com.abin.mallchat.common.chat.domain.enums.MessageTypeEnum; +import com.abin.mallchat.common.common.exception.BusinessException; +import com.abin.mallchat.common.common.handler.GlobalUncaughtExceptionHandler; +import com.abin.mallchat.custom.chat.domain.vo.request.ChatMessageReq; +import com.abin.mallchat.custom.chat.domain.vo.request.msg.TextMsgReq; +import com.abin.mallchat.custom.chat.service.ChatService; +import com.abin.mallchat.custom.user.domain.vo.response.user.UserInfoResp; +import com.abin.mallchat.custom.user.service.UserService; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.DisposableBean; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; + +import javax.annotation.PostConstruct; +import java.util.Collections; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +@Slf4j +@Component +public abstract class AbstractChatAIHandler implements DisposableBean, InitializingBean { + public static ExecutorService EXECUTOR; + + @Autowired + protected ChatService chatService; + @Autowired + protected UserService userService; + + @PostConstruct + private void init() { + ChatAIHandlerFactory.register(getChatAIUserId(), getChatAIName(), this); + } + // 获取机器人id + public abstract Long getChatAIUserId(); + // 获取机器人名称 + public abstract String getChatAIName(); + + public void chat(Message message) { + if (!supports(message)) { + return; + } + EXECUTOR.execute(() -> { + String text = doChat(message); + if (StringUtils.isNotBlank(text)) { + answerMsg(text, message.getRoomId(), message.getFromUid()); + } + }); + } + + /** + * 支持 + * + * @param message 消息 + * @return boolean true 支持 false 不支持 + */ + protected abstract boolean supports(Message message); + + /** + * 执行聊天 + * + * @param message 消息 + * @return {@link String} AI回答的内容 + */ + protected abstract String doChat(Message message); + + + protected void answerMsg(String text, Long roomId, Long uid) { + UserInfoResp userInfo = userService.getUserInfo(uid); + text = "@" + userInfo.getName() + " " + text; + if (text.length() < 450) { + save(text, roomId, uid); + }else { + int maxLen = 450; + int len = text.length(); + int count = (len + maxLen - 1) / maxLen; + + for (int i = 0; i < count; i++) { + int start = i * maxLen; + int end = Math.min(start + maxLen, len); + save(text.substring(start, end), roomId, uid); + } + } + } + + private void save(String text, Long roomId, Long uid) { + ChatMessageReq answerReq = new ChatMessageReq(); + answerReq.setRoomId(roomId); + answerReq.setMsgType(MessageTypeEnum.TEXT.getType()); + TextMsgReq textMsgReq = new TextMsgReq(); + textMsgReq.setContent(text); + textMsgReq.setAtUidList(Collections.singletonList(uid)); + answerReq.setBody(textMsgReq); + chatService.sendMsg(answerReq, getChatAIUserId()); + } + + @Override + public void afterPropertiesSet() { + EXECUTOR = new ThreadPoolExecutor( + 10, + 10, + 0L, + TimeUnit.MILLISECONDS, + new LinkedBlockingQueue<>(15), + new NamedThreadFactory("openAI-chat-gpt", + null, + false, + new GlobalUncaughtExceptionHandler()), + (r, executor) -> { + throw new BusinessException("别问的太快了,我的脑子不够用了"); + }); + } + + @Override + public void destroy() throws Exception { + EXECUTOR.shutdown(); + if (!EXECUTOR.awaitTermination(30, TimeUnit.SECONDS)) { //最多等30秒,处理不完就拉倒 + if (log.isErrorEnabled()) { + log.error("Timed out while waiting for executor [{}] to terminate", EXECUTOR); + } + } + } +} diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/ChatAIHandlerFactory.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/ChatAIHandlerFactory.java new file mode 100644 index 0000000..8d6356d --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/ChatAIHandlerFactory.java @@ -0,0 +1,42 @@ +package com.abin.mallchat.custom.chatai.handler; + +import com.baomidou.mybatisplus.core.toolkit.CollectionUtils; +import com.baomidou.mybatisplus.core.toolkit.StringUtils; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +public class ChatAIHandlerFactory { + private static final Map CHATAI_ID_MAP = new ConcurrentHashMap<>(); + private static final Map CHATAI_NAME_MAP = new ConcurrentHashMap<>(); + + public static void register(Long aIUserId, String name, AbstractChatAIHandler chatAIHandler) { + CHATAI_ID_MAP.put(aIUserId, chatAIHandler); + CHATAI_NAME_MAP.put(name, chatAIHandler); + } + + public static AbstractChatAIHandler getChatAIHandlerById(List userIds) { + if (CollectionUtils.isEmpty(userIds)) { + return null; + } + for (Long userId : userIds) { + AbstractChatAIHandler chatAIHandler = CHATAI_ID_MAP.get(userId); + if (chatAIHandler != null) { + return chatAIHandler; + } + } + return null; + } + public static AbstractChatAIHandler getChatAIHandlerByName(String text) { + if (StringUtils.isBlank(text)) { + return null; + } + for (Map.Entry entry : CHATAI_NAME_MAP.entrySet()) { + if (text.contains("@"+entry.getKey())) { + return entry.getValue(); + } + } + return null; + } +} diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/ChatGLM2Handler.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/ChatGLM2Handler.java new file mode 100644 index 0000000..ca529f3 --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/ChatGLM2Handler.java @@ -0,0 +1,136 @@ +package com.abin.mallchat.custom.chatai.handler; + +import cn.hutool.http.HttpResponse; +import com.abin.mallchat.common.chat.domain.entity.Message; +import com.abin.mallchat.common.common.constant.RedisKey; +import com.abin.mallchat.common.common.utils.RedisUtils; +import com.abin.mallchat.custom.chatai.properties.ChatGLM2Properties; +import com.abin.mallchat.custom.chatai.utils.ChatGLM2Utils; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; + +import java.util.Arrays; +import java.util.Date; +import java.util.List; +import java.util.Random; +import java.util.concurrent.TimeUnit; + +import static com.abin.mallchat.common.common.constant.RedisKey.USER_GLM2_TIME_LAST; + +@Slf4j +@Component +public class ChatGLM2Handler extends AbstractChatAIHandler { + + private static final List ERROR_MSG = Arrays.asList( + "还摸鱼呢?你不下班我还要下班呢。。。。", + "没给钱,矿工了。。。。", + "服务器被你们玩儿坏了。。。。", + "你们这群人,我都不想理你们了。。。。", + "还艾特我呢?那是另外的价钱。。。。", + "得加钱"); + + + private static final Random RANDOM = new Random(); + + @Autowired + private ChatGLM2Properties glm2Properties; + + @Override + public Long getChatAIUserId() { + return glm2Properties.getAIUserId(); + } + + @Override + public String getChatAIName() { + if (StringUtils.isNotBlank(glm2Properties.getAIUserName())) { + return glm2Properties.getAIUserName(); + } + String name = userService.getUserInfo(glm2Properties.getAIUserId()).getName(); + glm2Properties.setAIUserName(name); + return name; + } + + @Override + protected String doChat(Message message) { + String content = message.getContent().replace("@" +glm2Properties.getAIUserName(), "").trim(); + Long uid = message.getFromUid(); + Long minute; + String text; + if ((minute = userMinutesLater(uid)) > 0) { + text = "你太快了 " + minute + "分钟后重试"; + } else { + HttpResponse response = null; + try { + response = ChatGLM2Utils + .create() + .url(glm2Properties.getUrl()) + .prompt(content) + .send(); + } catch (Exception e) { + e.printStackTrace(); + return getErrorText(); + } + text = ChatGLM2Utils.parseText(response); + if (StringUtils.isNotBlank(text)) { + RedisUtils.set(RedisKey.getKey(USER_GLM2_TIME_LAST, uid), new Date(), glm2Properties.getMinute(), TimeUnit.MINUTES); + } + } + return text; + } + + private static String getErrorText() { + int index = RANDOM.nextInt(ERROR_MSG.size()); + return ERROR_MSG.get(index); + } + + /** + * 用户多少分钟后才能再次聊天 + * + * @param uid + * @return + */ + private Long userMinutesLater(Long uid) { + // 获取用户最后聊天时间 + Date lastChatTime = RedisUtils.get(RedisKey.getKey(USER_GLM2_TIME_LAST, uid), Date.class); + if (lastChatTime == null) { + // 如果没有聊天记录,则可以立即聊天 + return 0L; + } + // 计算当前时间和上次聊天时间之间的时间差 + long now = System.currentTimeMillis(); + long lastChatTimeMillis = lastChatTime.getTime(); + long durationMillis = now - lastChatTimeMillis; + long minutes = TimeUnit.MILLISECONDS.toMinutes(durationMillis); + // 计算剩余等待时间 + long remainingMinutes = glm2Properties.getMinute() - minutes; + return remainingMinutes > 0 ? remainingMinutes : 0L; + } + + + @Override + protected boolean supports(Message message) { + if (!glm2Properties.isUse()) { + return false; + } + /* 前端传@信息后取消注释 */ + +// MessageExtra extra = message.getExtra(); +// if (extra == null) { +// return false; +// } +// if (CollectionUtils.isEmpty(extra.getAtUidList())) { +// return false; +// } +// if (!extra.getAtUidList().contains(ChatAIServiceImpl.AI_USER_ID)) { +// return false; +// } + + if (StringUtils.isBlank(message.getContent())) { + return false; + } + return StringUtils.contains(message.getContent(), "@" + glm2Properties.getAIUserName()) + && StringUtils.isNotBlank(message.getContent().replace(glm2Properties.getAIUserName(), "").trim()); + } +} diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/GPTChatAIHandler.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/GPTChatAIHandler.java new file mode 100644 index 0000000..e981fa7 --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/GPTChatAIHandler.java @@ -0,0 +1,86 @@ +package com.abin.mallchat.custom.chatai.handler; + +import cn.hutool.http.HttpResponse; +import com.abin.mallchat.common.chat.domain.entity.Message; +import com.abin.mallchat.common.common.constant.RedisKey; +import com.abin.mallchat.common.common.utils.DateUtils; +import com.abin.mallchat.common.common.utils.RedisUtils; +import com.abin.mallchat.custom.chatai.properties.ChatGPTProperties; +import com.abin.mallchat.custom.chatai.utils.ChatGPTUtils; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; + +import java.util.concurrent.TimeUnit; + +@Component +public class GPTChatAIHandler extends AbstractChatAIHandler { + + @Autowired + private ChatGPTProperties chatGPTProperties; + + @Override + public Long getChatAIUserId() { + return chatGPTProperties.getAIUserId(); + } + + @Override + public String getChatAIName() { + if (StringUtils.isNotBlank(chatGPTProperties.getAIUserName())) { + return chatGPTProperties.getAIUserName(); + } + String name = userService.getUserInfo(chatGPTProperties.getAIUserId()).getName(); + chatGPTProperties.setAIUserName(name); + return name; + } + + @Override + protected String doChat(Message message) { + String content = message.getContent().replace("@" +chatGPTProperties.getAIUserName(), "").trim(); + Long uid = message.getFromUid(); + Long chatNum; + String text; + if ((chatNum = userChatNumInrc(uid)) > chatGPTProperties.getLimit()) { + text = "你今天已经和我聊了" + chatNum + "次了,我累了,明天再聊吧"; + } else { + HttpResponse response = ChatGPTUtils.create(chatGPTProperties.getKey()) + .proxyUrl(chatGPTProperties.getProxyUrl()) + .model(chatGPTProperties.getModelName()) + .prompt(content) + .send(); + text = ChatGPTUtils.parseText(response); + } + return text; + } + + private Long userChatNumInrc(Long uid) { + //todo:白名单 + return RedisUtils.inc(RedisKey.getKey(RedisKey.USER_CHAT_NUM, uid), DateUtils.getEndTimeByToday().intValue(), TimeUnit.MILLISECONDS); + } + + + @Override + protected boolean supports(Message message) { + if (!chatGPTProperties.isUse()) { + return false; + } + /* 前端传@信息后取消注释 */ + +// MessageExtra extra = message.getExtra(); +// if (extra == null) { +// return false; +// } +// if (CollectionUtils.isEmpty(extra.getAtUidList())) { +// return false; +// } +// if (!extra.getAtUidList().contains(ChatAIServiceImpl.AI_USER_ID)) { +// return false; +// } + + if (StringUtils.isBlank(message.getContent())) { + return false; + } + return StringUtils.contains(message.getContent(), "@" + chatGPTProperties.getAIUserName()) + && StringUtils.isNotBlank(message.getContent().replace(chatGPTProperties.getAIUserName(), "").trim()); + } +} diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/properties/ChatGLM2Properties.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/properties/ChatGLM2Properties.java new file mode 100644 index 0000000..582ce4a --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/properties/ChatGLM2Properties.java @@ -0,0 +1,43 @@ +package com.abin.mallchat.custom.chatai.properties; + +import lombok.Data; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.stereotype.Component; + +/** + * ChatGLM2 配置文件 + * + * @author zhaoyuhang + * @date 2023/06/30 + */ +@Data +@Component +@ConfigurationProperties(prefix = "chatai.chatglm2") +public class ChatGLM2Properties { + + /** + * 使用 + */ + private boolean use; + + /** + * url + */ + private String url; + + /** + * 机器人 id + */ + private Long AIUserId; + + /** + * 机器人名称 + */ + private String AIUserName; + + /** + * 每个用户每3分钟可以请求一次 + */ + private Long minute = 3L; + +} diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/properties/ChatGPTProperties.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/properties/ChatGPTProperties.java new file mode 100644 index 0000000..b7a5c56 --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/properties/ChatGPTProperties.java @@ -0,0 +1,44 @@ +package com.abin.mallchat.custom.chatai.properties; + +import lombok.Data; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.stereotype.Component; + + +@Data +@Component +@ConfigurationProperties(prefix = "chatai.chatgpt") +public class ChatGPTProperties { + + /** + * 是否使用openAI + */ + private boolean use; + /** + * 机器人 id + */ + private Long AIUserId; + + /** + * 机器人名称 + */ + private String AIUserName; + /** + * 模型名称 + */ + private String modelName = "text-davinci-003"; + /** + * openAI key + */ + private String key; + /** + * 代理地址 + */ + private String proxyUrl; + + /** + * 用户每天条数限制 + */ + private Integer limit = 5; + +} diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/service/IChatAIService.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/service/IChatAIService.java new file mode 100644 index 0000000..8cded22 --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/service/IChatAIService.java @@ -0,0 +1,8 @@ +package com.abin.mallchat.custom.chatai.service; + +import com.abin.mallchat.common.chat.domain.entity.Message; + +public interface IChatAIService { + + void chat(Message message); +} diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/service/impl/ChatAIServiceImpl.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/service/impl/ChatAIServiceImpl.java new file mode 100644 index 0000000..37af88e --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/service/impl/ChatAIServiceImpl.java @@ -0,0 +1,26 @@ +package com.abin.mallchat.custom.chatai.service.impl; + +import com.abin.mallchat.common.chat.domain.entity.Message; +import com.abin.mallchat.common.chat.domain.entity.msg.MessageExtra; +import com.abin.mallchat.custom.chatai.handler.AbstractChatAIHandler; +import com.abin.mallchat.custom.chatai.handler.ChatAIHandlerFactory; +import com.abin.mallchat.custom.chatai.service.IChatAIService; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; + +@Slf4j +@Service +public class ChatAIServiceImpl implements IChatAIService { + @Override + public void chat(Message message) { + MessageExtra extra = message.getExtra(); + if (extra == null) { + return; + } + AbstractChatAIHandler chatAI = ChatAIHandlerFactory.getChatAIHandlerByName(message.getContent()); +// AbstractChatAIHandler chatAI = ChatAIHandlerFactory.getChatAIHandlerById(extra.getAtUidList()); + if (chatAI != null) { + chatAI.chat(message); + } + } +} \ No newline at end of file diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGLM2Utils.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGLM2Utils.java new file mode 100644 index 0000000..2729e0a --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGLM2Utils.java @@ -0,0 +1,107 @@ +package com.abin.mallchat.custom.chatai.utils; + +import cn.hutool.http.HttpResponse; +import cn.hutool.http.HttpUtil; +import cn.hutool.json.JSONObject; +import lombok.extern.slf4j.Slf4j; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +@Slf4j +public class ChatGLM2Utils { + + + private final Map headers; + /** + * 超时30秒 + */ + private Integer timeout = 60 * 1000; + + private String url; + /** + * 提示词 + */ + private String prompt; + + /** + * 历史 + */ + private List history; + + + public ChatGLM2Utils() { + HashMap _headers_ = new HashMap<>(); + _headers_.put("Content-Type", "application/json"); + this.headers = _headers_; + } + + public static ChatGLM2Utils create() { + return new ChatGLM2Utils(); + } + + + + + public ChatGLM2Utils url(String url) { + this.url = url; + return this; + } + + + public ChatGLM2Utils timeout(int timeout) { + this.timeout = timeout; + return this; + } + + public ChatGLM2Utils prompt(String prompt) { + this.prompt = prompt; + return this; + } + public HttpResponse send() { + JSONObject param = new JSONObject(); + param.set("prompt", prompt); + log.info("headers >>> " + headers); + log.info("param >>> " + param); + return HttpUtil.createPost(url) + .addHeaders(headers) + .body(param.toString()) + .timeout(timeout) + .execute(); + } + + public static String parseText(String body) { + log.info("body >>> " + body); + JSONObject jsonObj = new JSONObject(body); + if (200 != jsonObj.getInt("status")) { + log.error("status >>> " + jsonObj.getInt("status")); + return "闹脾气了,等会再试试吧~"; + } + return jsonObj.getStr("response"); + } + + public static String parseText(HttpResponse response) { + return parseText(response.body()); + } + + + public static void main(String[] args) { + HttpResponse send = null; + try { + send = ChatGLM2Utils + .create() + .url("http://vastmiao.natapp1.cc") + .timeout(60 * 1000) + .prompt("Spring的启动流程是什么") + .send(); + } catch (Exception e) { + throw new RuntimeException(e); + } + System.out.println("send = " + send); + + System.out.println("parseText(send) = " + parseText(send)); + } + + +} \ No newline at end of file diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/utils/OpenAIUtils.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGPTUtils.java similarity index 85% rename from mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/utils/OpenAIUtils.java rename to mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGPTUtils.java index f01c8a8..c846492 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/utils/OpenAIUtils.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGPTUtils.java @@ -1,4 +1,4 @@ -package com.abin.mallchat.custom.openai.utils; +package com.abin.mallchat.custom.chatai.utils; import cn.hutool.http.HttpResponse; import cn.hutool.http.HttpUtil; @@ -12,7 +12,7 @@ import java.util.HashMap; import java.util.Map; @Slf4j -public class OpenAIUtils { +public class ChatGPTUtils { private static final String URL = "https://api.openai.com/v1/completions"; @@ -56,7 +56,7 @@ public class OpenAIUtils { private String proxyUrl; - public OpenAIUtils(String key) { + public ChatGPTUtils(String key) { HashMap _headers_ = new HashMap<>(); _headers_.put("Content-Type", "application/json"); if (StringUtils.isBlank(key)) { @@ -66,8 +66,8 @@ public class OpenAIUtils { this.headers = _headers_; } - public static OpenAIUtils create(String key) { - return new OpenAIUtils(key); + public static ChatGPTUtils create(String key) { + return new ChatGPTUtils(key); } public static String parseText(HttpResponse response) { @@ -87,47 +87,47 @@ public class OpenAIUtils { return choiceObj.getStr("text"); } - public OpenAIUtils model(String model) { + public ChatGPTUtils model(String model) { this.model = model; return this; } - public OpenAIUtils timeout(int timeout) { + public ChatGPTUtils timeout(int timeout) { this.timeout = timeout; return this; } - public OpenAIUtils maxTokens(int maxTokens) { + public ChatGPTUtils maxTokens(int maxTokens) { this.maxTokens = maxTokens; return this; } - public OpenAIUtils temperature(int temperature) { + public ChatGPTUtils temperature(int temperature) { this.temperature = temperature; return this; } - public OpenAIUtils topP(int topP) { + public ChatGPTUtils topP(int topP) { this.topP = topP; return this; } - public OpenAIUtils frequencyPenalty(int frequencyPenalty) { + public ChatGPTUtils frequencyPenalty(int frequencyPenalty) { this.frequencyPenalty = frequencyPenalty; return this; } - public OpenAIUtils presencePenalty(int presencePenalty) { + public ChatGPTUtils presencePenalty(int presencePenalty) { this.presencePenalty = presencePenalty; return this; } - public OpenAIUtils prompt(String prompt) { + public ChatGPTUtils prompt(String prompt) { this.prompt = prompt; return this; } - public OpenAIUtils proxyUrl(String proxyUrl) { + public ChatGPTUtils proxyUrl(String proxyUrl) { this.proxyUrl = proxyUrl; return this; } @@ -149,7 +149,7 @@ public class OpenAIUtils { } public static void main(String[] args) { - HttpResponse send = OpenAIUtils.create("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") + HttpResponse send = ChatGPTUtils.create("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") .timeout(30 * 1000) .prompt("Spring的启动流程是什么") .send(); diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/common/event/listener/MessageSendListener.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/common/event/listener/MessageSendListener.java index 15b7cab..ba8bcf0 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/common/event/listener/MessageSendListener.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/common/event/listener/MessageSendListener.java @@ -5,9 +5,11 @@ import com.abin.mallchat.common.chat.domain.entity.Message; import com.abin.mallchat.common.common.event.MessageSendEvent; import com.abin.mallchat.custom.chat.domain.vo.response.ChatMessageResp; import com.abin.mallchat.custom.chat.service.ChatService; +import com.abin.mallchat.custom.chatai.service.IChatAIService; import com.abin.mallchat.custom.user.service.WebSocketService; import com.abin.mallchat.custom.user.service.adapter.WSAdapter; import lombok.extern.slf4j.Slf4j; +import org.jetbrains.annotations.NotNull; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Component; @@ -27,6 +29,8 @@ public class MessageSendListener { private ChatService chatService; @Autowired private MessageDao messageDao; + @Autowired + private IChatAIService openAIService; @Async @TransactionalEventListener(classes = MessageSendEvent.class, fallbackExecution = true) @@ -36,4 +40,10 @@ public class MessageSendListener { webSocketService.sendToAllOnline(WSAdapter.buildMsgSend(msgResp), message.getFromUid()); } + @TransactionalEventListener(classes = MessageSendEvent.class, fallbackExecution = true) + public void handlerMsg(@NotNull MessageSendEvent event) { + Message message = messageDao.getById(event.getMsgId()); + openAIService.chat(message); + } + } diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/event/OpenAIEvent.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/event/OpenAIEvent.java deleted file mode 100644 index 493d54e..0000000 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/event/OpenAIEvent.java +++ /dev/null @@ -1,14 +0,0 @@ -package com.abin.mallchat.custom.openai.event; - -import lombok.Getter; -import org.springframework.context.ApplicationEvent; - -@Getter -public class OpenAIEvent extends ApplicationEvent { - private Long msgId; - - public OpenAIEvent(Object source, Long msgId) { - super(source); - this.msgId = msgId; - } -} \ No newline at end of file diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/event/listener/OpenAIListener.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/event/listener/OpenAIListener.java deleted file mode 100644 index 160a1f1..0000000 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/event/listener/OpenAIListener.java +++ /dev/null @@ -1,63 +0,0 @@ -package com.abin.mallchat.custom.openai.event.listener; - -import com.abin.mallchat.common.chat.dao.MessageDao; -import com.abin.mallchat.common.chat.domain.entity.Message; -import com.abin.mallchat.custom.openai.event.OpenAIEvent; -import com.abin.mallchat.custom.openai.service.IOpenAIService; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.StringUtils; -import org.jetbrains.annotations.NotNull; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.stereotype.Component; -import org.springframework.transaction.event.TransactionalEventListener; - -import static com.abin.mallchat.custom.openai.service.impl.OpenAIServiceImpl.MALL_CHAT_AI_NAME; - -/** - * 是否AI回复监听器 - * - * @author zhaoyuhang - * @date 2023/06/29 - */ -@Slf4j -@Component -public class OpenAIListener { - @Autowired - private IOpenAIService openAIService; - @Autowired - private MessageDao messageDao; - - @TransactionalEventListener(classes = OpenAIEvent.class, fallbackExecution = true) - public void notifyAllOnline(@NotNull OpenAIEvent event) { - Message message = messageDao.getById(event.getMsgId()); - if (ATedAI(message)) { - openAIService.chat(message); - } - } - - /** - * @return boolean - * @了AI - */ - private boolean ATedAI(Message message) { - /* 前端传@信息后取消注释 */ - -// MessageExtra extra = message.getExtra(); -// if (extra == null) { -// return false; -// } -// if (CollectionUtils.isEmpty(extra.getAtUidList())) { -// return false; -// } -// if (!extra.getAtUidList().contains(OpenAIServiceImpl.AI_USER_ID)) { -// return false; -// } - - if (StringUtils.isBlank(message.getContent())) { - return false; - } - return StringUtils.contains(message.getContent(), "@" + MALL_CHAT_AI_NAME) - && StringUtils.isNotBlank(message.getContent().replace(MALL_CHAT_AI_NAME, "").trim()); - } - -} diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/service/IOpenAIService.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/service/IOpenAIService.java deleted file mode 100644 index 0216a0b..0000000 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/service/IOpenAIService.java +++ /dev/null @@ -1,11 +0,0 @@ -package com.abin.mallchat.custom.openai.service; - -import com.abin.mallchat.common.chat.domain.entity.Message; -import com.abin.mallchat.custom.chat.domain.vo.request.ChatMessageReq; - -public interface IOpenAIService { - - - void chat(ChatMessageReq chatMessageReq, Long uid); - void chat(Message message); -} diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/service/impl/OpenAIServiceImpl.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/service/impl/OpenAIServiceImpl.java deleted file mode 100644 index 4924963..0000000 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/service/impl/OpenAIServiceImpl.java +++ /dev/null @@ -1,171 +0,0 @@ -package com.abin.mallchat.custom.openai.service.impl; - -import cn.hutool.core.bean.BeanUtil; -import cn.hutool.core.thread.NamedThreadFactory; -import cn.hutool.http.HttpResponse; -import cn.hutool.http.HttpUtil; -import com.abin.mallchat.common.chat.domain.entity.Message; -import com.abin.mallchat.common.chat.domain.enums.MessageTypeEnum; -import com.abin.mallchat.common.common.constant.RedisKey; -import com.abin.mallchat.common.common.exception.BusinessException; -import com.abin.mallchat.common.common.handler.GlobalUncaughtExceptionHandler; -import com.abin.mallchat.common.common.utils.DateUtils; -import com.abin.mallchat.common.common.utils.RedisUtils; -import com.abin.mallchat.custom.chat.domain.vo.request.ChatMessageReq; -import com.abin.mallchat.custom.chat.domain.vo.request.msg.TextMsgReq; -import com.abin.mallchat.custom.chat.service.ChatService; -import com.abin.mallchat.custom.openai.enums.OpenAIModelEnums; -import com.abin.mallchat.custom.openai.service.IOpenAIService; -import com.abin.mallchat.custom.openai.utils.OpenAIUtils; -import com.abin.mallchat.custom.user.domain.vo.response.user.UserInfoResp; -import com.abin.mallchat.custom.user.service.UserService; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.StringUtils; -import org.springframework.beans.factory.DisposableBean; -import org.springframework.beans.factory.InitializingBean; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.context.annotation.Description; -import org.springframework.context.annotation.Lazy; -import org.springframework.stereotype.Service; - -import java.util.Collections; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.ThreadPoolExecutor; -import java.util.concurrent.TimeUnit; - -@Slf4j -@Service -public class OpenAIServiceImpl implements IOpenAIService, DisposableBean, InitializingBean { - private static ExecutorService EXECUTOR; - - @Value("${openai.use-openai:false}") - private boolean USE_OPENAI; - @Value("${openai.ai-user-id}") - public Long AI_USER_ID; - - @Value("${openai.model.name:text-davinci-003}") - private String modelName; - @Value("${openai.key}") - private String key; - @Value("${openai.proxy-url:}") - private String proxyUrl; - - @Value("${openai.limit:5}") - private Integer limit; - - @Autowired - private UserService userService; - @Lazy - @Autowired - private ChatService chatService; - - public static String MALL_CHAT_AI_NAME; - - /** - * 聊天 - * - * @param chatMessageReq 提示词 - * @param uid 用户id - */ - @Deprecated - @Override - public void chat(ChatMessageReq chatMessageReq, Long uid) { - TextMsgReq body = BeanUtil.toBean(chatMessageReq.getBody(), TextMsgReq.class); - String content = body.getContent().replace(MALL_CHAT_AI_NAME, "").trim(); - EXECUTOR.execute(() -> { - Long chatNum; - if ((chatNum = userChatNumInrc(uid)) > limit) { - answerMsg("你今天已经和我聊了" + chatNum + "次了,我累了,明天再聊吧", chatMessageReq.getRoomId(), uid); - } else { - chat(content, chatMessageReq.getRoomId(), uid); - } - }); - - } - - @Override - public void chat(Message message) { - String content = message.getContent().replace(MALL_CHAT_AI_NAME, "").trim(); - Long roomId = message.getRoomId(); - Long uid = message.getFromUid(); - EXECUTOR.execute(() -> { - Long chatNum; - if ((chatNum = userChatNumInrc(uid)) > limit) { - answerMsg("你今天已经和我聊了" + chatNum + "次了,我累了,明天再聊吧", roomId, uid); - } else { - chat(content, roomId, uid); - } - }); - - } - - private Long userChatNumInrc(Long uid) { - //todo:白名单 - return RedisUtils.inc(RedisKey.getKey(RedisKey.USER_CHAT_NUM, uid), DateUtils.getEndTimeByToday().intValue(), TimeUnit.MILLISECONDS); - } - - private void chat(String content, Long roomId, Long uid) { - HttpResponse response = OpenAIUtils.create(key) - .proxyUrl(proxyUrl) - .model(modelName) - .prompt(content) - .send(); - String text = OpenAIUtils.parseText(response); - answerMsg(text, roomId, uid); - } - - private void answerMsg(String text, Long roomId, Long uid) { - ChatMessageReq answerReq = new ChatMessageReq(); - answerReq.setRoomId(roomId); - answerReq.setMsgType(MessageTypeEnum.TEXT.getType()); - UserInfoResp userInfo = userService.getUserInfo(uid); - TextMsgReq textMsgReq = new TextMsgReq(); - textMsgReq.setContent("@" + userInfo.getName() + " " + text); - textMsgReq.setAtUidList(Collections.singletonList(uid)); - answerReq.setBody(textMsgReq); - chatService.sendMsg(answerReq, AI_USER_ID); - } - - - @Override - public void afterPropertiesSet() { - if (!USE_OPENAI) { - return; - } - if (StringUtils.isNotBlank(proxyUrl) && !HttpUtil.isHttp(proxyUrl) && !HttpUtil.isHttps(proxyUrl)) { - throw new BusinessException("openai.proxy-url 配置错误"); - } - OpenAIModelEnums modelEnum = OpenAIModelEnums.of(modelName); - if (modelEnum == null) { - throw new BusinessException("openai.model.name 配置错误"); - } - Integer rpm = modelEnum.getRPM(); - EXECUTOR = new ThreadPoolExecutor(10, 10, - 0L, TimeUnit.MILLISECONDS, - new LinkedBlockingQueue<>(rpm), - new NamedThreadFactory("openAI-chat-gpt", - null, - false, - new GlobalUncaughtExceptionHandler()), - (r, executor) -> { - throw new BusinessException("别问的太快了,我的脑子不够用了"); - }); - UserInfoResp userInfo = userService.getUserInfo(AI_USER_ID); - if (userInfo == null) { - throw new BusinessException("openai.ai-user-id 配置错误"); - } - MALL_CHAT_AI_NAME = userInfo.getName(); - } - - @Override - public void destroy() throws Exception { - EXECUTOR.shutdown(); - if (!EXECUTOR.awaitTermination(30, TimeUnit.SECONDS)) { //最多等30秒,处理不完就拉倒 - if (log.isErrorEnabled()) { - log.error("Timed out while waiting for executor [{}] to terminate", EXECUTOR); - } - } - } -} \ No newline at end of file From a944d1df14109ac9329cfc6bad61d6813b7a0fd2 Mon Sep 17 00:00:00 2001 From: zhaoyuhang <1045078399@qq.com> Date: Sat, 1 Jul 2023 15:24:52 +0800 Subject: [PATCH 4/5] =?UTF-8?q?=E4=BC=98=E5=8C=96chatAI?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/main/resources/application.yml | 10 +++--- .../chatai/handler/ChatGLM2Handler.java | 5 +-- .../chatai/handler/GPTChatAIHandler.java | 32 +++++++++++++------ .../chatai/properties/ChatGLM2Properties.java | 7 +++- .../custom/chatai/utils/ChatGLM2Utils.java | 3 +- .../custom/chatai/utils/ChatGPTUtils.java | 2 ++ 6 files changed, 39 insertions(+), 20 deletions(-) diff --git a/mallchat-common/src/main/resources/application.yml b/mallchat-common/src/main/resources/application.yml index e0f925e..2df7f5b 100644 --- a/mallchat-common/src/main/resources/application.yml +++ b/mallchat-common/src/main/resources/application.yml @@ -65,11 +65,11 @@ wx: aesKey: ${mallchat.wx.aesKey} # 接口配置里的EncodingAESKey值 chatai: chatgpt: - use: true - AIUserId: 10450 - key: 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 + use: false + AIUserId: 10452 + key: xxxxx chatglm2: - use: true - url: http://vastmiao.natapp1.cc + use: false + url: xxxxx minute: 3 # 每个用户每3分钟可以请求一次 AIUserId: 10451 diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/ChatGLM2Handler.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/ChatGLM2Handler.java index ca529f3..13aecfd 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/ChatGLM2Handler.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/ChatGLM2Handler.java @@ -28,7 +28,7 @@ public class ChatGLM2Handler extends AbstractChatAIHandler { "没给钱,矿工了。。。。", "服务器被你们玩儿坏了。。。。", "你们这群人,我都不想理你们了。。。。", - "还艾特我呢?那是另外的价钱。。。。", + "艾特我那是另外的价钱。。。。", "得加钱"); @@ -67,6 +67,7 @@ public class ChatGLM2Handler extends AbstractChatAIHandler { .create() .url(glm2Properties.getUrl()) .prompt(content) + .timeout(glm2Properties.getTimeout()) .send(); } catch (Exception e) { e.printStackTrace(); @@ -123,7 +124,7 @@ public class ChatGLM2Handler extends AbstractChatAIHandler { // if (CollectionUtils.isEmpty(extra.getAtUidList())) { // return false; // } -// if (!extra.getAtUidList().contains(ChatAIServiceImpl.AI_USER_ID)) { +// if (!extra.getAtUidList().contains(glm2Properties.getAIUserId())) { // return false; // } diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/GPTChatAIHandler.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/GPTChatAIHandler.java index e981fa7..f814713 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/GPTChatAIHandler.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/GPTChatAIHandler.java @@ -36,28 +36,40 @@ public class GPTChatAIHandler extends AbstractChatAIHandler { @Override protected String doChat(Message message) { - String content = message.getContent().replace("@" +chatGPTProperties.getAIUserName(), "").trim(); + String content = message.getContent().replace("@" + chatGPTProperties.getAIUserName(), "").trim(); Long uid = message.getFromUid(); Long chatNum; String text; - if ((chatNum = userChatNumInrc(uid)) > chatGPTProperties.getLimit()) { + if ((chatNum = getUserChatNum(uid)) > chatGPTProperties.getLimit()) { text = "你今天已经和我聊了" + chatNum + "次了,我累了,明天再聊吧"; } else { - HttpResponse response = ChatGPTUtils.create(chatGPTProperties.getKey()) - .proxyUrl(chatGPTProperties.getProxyUrl()) - .model(chatGPTProperties.getModelName()) - .prompt(content) - .send(); - text = ChatGPTUtils.parseText(response); + HttpResponse response = null; + try { + response = ChatGPTUtils.create(chatGPTProperties.getKey()) + .proxyUrl(chatGPTProperties.getProxyUrl()) + .model(chatGPTProperties.getModelName()) + .prompt(content) + .send(); + text = ChatGPTUtils.parseText(response); + userChatNumInrc(uid); + } catch (Exception e) { + e.printStackTrace(); + text = "我累了,明天再聊吧"; + } } return text; } private Long userChatNumInrc(Long uid) { - //todo:白名单 return RedisUtils.inc(RedisKey.getKey(RedisKey.USER_CHAT_NUM, uid), DateUtils.getEndTimeByToday().intValue(), TimeUnit.MILLISECONDS); } + private Long getUserChatNum(Long uid) { + Long num = RedisUtils.get(RedisKey.getKey(RedisKey.USER_CHAT_NUM, uid), Long.class); + return num == null ? 0 : num; + + } + @Override protected boolean supports(Message message) { @@ -73,7 +85,7 @@ public class GPTChatAIHandler extends AbstractChatAIHandler { // if (CollectionUtils.isEmpty(extra.getAtUidList())) { // return false; // } -// if (!extra.getAtUidList().contains(ChatAIServiceImpl.AI_USER_ID)) { +// if (!extra.getAtUidList().contains(chatGPTProperties.getAIUserId())) { // return false; // } diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/properties/ChatGLM2Properties.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/properties/ChatGLM2Properties.java index 582ce4a..2ccc99b 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/properties/ChatGLM2Properties.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/properties/ChatGLM2Properties.java @@ -36,8 +36,13 @@ public class ChatGLM2Properties { private String AIUserName; /** - * 每个用户每3分钟可以请求一次 + * 每个用户每?分钟可以请求一次 */ private Long minute = 3L; + /** + * 超时 + */ + private Integer timeout = 60*1000; + } diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGLM2Utils.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGLM2Utils.java index 2729e0a..7ac9e1a 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGLM2Utils.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGLM2Utils.java @@ -42,8 +42,6 @@ public class ChatGLM2Utils { } - - public ChatGLM2Utils url(String url) { this.url = url; return this; @@ -59,6 +57,7 @@ public class ChatGLM2Utils { this.prompt = prompt; return this; } + public HttpResponse send() { JSONObject param = new JSONObject(); param.set("prompt", prompt); diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGPTUtils.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGPTUtils.java index c846492..094cd39 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGPTUtils.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/utils/ChatGPTUtils.java @@ -141,6 +141,8 @@ public class ChatGPTUtils { param.set("top_p", topP); param.set("frequency_penalty", frequencyPenalty); param.set("presence_penalty", presencePenalty); + log.info("headers >>> " + headers); + log.info("param >>> " + param); return HttpUtil.createPost(StringUtils.isNotBlank(proxyUrl) ? proxyUrl : URL) .addHeaders(headers) .body(param.toString()) From 824d3b78c745b3f176cfff7d0faa8962961ee3ee Mon Sep 17 00:00:00 2001 From: zhaoyuhang <1045078399@qq.com> Date: Sat, 1 Jul 2023 21:13:49 +0800 Subject: [PATCH 5/5] =?UTF-8?q?=E6=94=B9=E4=B8=BA=E4=BD=BF=E7=94=A8id?= =?UTF-8?q?=E8=8E=B7=E5=8F=96=E6=9C=BA=E5=99=A8=E4=BA=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/main/resources/application.yml | 4 +-- .../chatai/handler/ChatGLM2Handler.java | 25 ++++++++++--------- .../service/impl/ChatAIServiceImpl.java | 4 +-- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/mallchat-common/src/main/resources/application.yml b/mallchat-common/src/main/resources/application.yml index 2df7f5b..6f13a61 100644 --- a/mallchat-common/src/main/resources/application.yml +++ b/mallchat-common/src/main/resources/application.yml @@ -65,11 +65,11 @@ wx: aesKey: ${mallchat.wx.aesKey} # 接口配置里的EncodingAESKey值 chatai: chatgpt: - use: false + use: true AIUserId: 10452 key: xxxxx chatglm2: - use: false + use: true url: xxxxx minute: 3 # 每个用户每3分钟可以请求一次 AIUserId: 10451 diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/ChatGLM2Handler.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/ChatGLM2Handler.java index 13aecfd..11350ca 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/ChatGLM2Handler.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/ChatGLM2Handler.java @@ -2,6 +2,7 @@ package com.abin.mallchat.custom.chatai.handler; import cn.hutool.http.HttpResponse; import com.abin.mallchat.common.chat.domain.entity.Message; +import com.abin.mallchat.common.chat.domain.entity.msg.MessageExtra; import com.abin.mallchat.common.common.constant.RedisKey; import com.abin.mallchat.common.common.utils.RedisUtils; import com.abin.mallchat.custom.chatai.properties.ChatGLM2Properties; @@ -10,6 +11,7 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; +import org.springframework.util.CollectionUtils; import java.util.Arrays; import java.util.Date; @@ -69,11 +71,11 @@ public class ChatGLM2Handler extends AbstractChatAIHandler { .prompt(content) .timeout(glm2Properties.getTimeout()) .send(); + text = ChatGLM2Utils.parseText(response); } catch (Exception e) { e.printStackTrace(); return getErrorText(); } - text = ChatGLM2Utils.parseText(response); if (StringUtils.isNotBlank(text)) { RedisUtils.set(RedisKey.getKey(USER_GLM2_TIME_LAST, uid), new Date(), glm2Properties.getMinute(), TimeUnit.MINUTES); } @@ -116,17 +118,16 @@ public class ChatGLM2Handler extends AbstractChatAIHandler { return false; } /* 前端传@信息后取消注释 */ - -// MessageExtra extra = message.getExtra(); -// if (extra == null) { -// return false; -// } -// if (CollectionUtils.isEmpty(extra.getAtUidList())) { -// return false; -// } -// if (!extra.getAtUidList().contains(glm2Properties.getAIUserId())) { -// return false; -// } + MessageExtra extra = message.getExtra(); + if (extra == null) { + return false; + } + if (CollectionUtils.isEmpty(extra.getAtUidList())) { + return false; + } + if (!extra.getAtUidList().contains(glm2Properties.getAIUserId())) { + return false; + } if (StringUtils.isBlank(message.getContent())) { return false; diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/service/impl/ChatAIServiceImpl.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/service/impl/ChatAIServiceImpl.java index 37af88e..885c227 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/service/impl/ChatAIServiceImpl.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/service/impl/ChatAIServiceImpl.java @@ -17,8 +17,8 @@ public class ChatAIServiceImpl implements IChatAIService { if (extra == null) { return; } - AbstractChatAIHandler chatAI = ChatAIHandlerFactory.getChatAIHandlerByName(message.getContent()); -// AbstractChatAIHandler chatAI = ChatAIHandlerFactory.getChatAIHandlerById(extra.getAtUidList()); +// AbstractChatAIHandler chatAI = ChatAIHandlerFactory.getChatAIHandlerByName(message.getContent()); + AbstractChatAIHandler chatAI = ChatAIHandlerFactory.getChatAIHandlerById(extra.getAtUidList()); if (chatAI != null) { chatAI.chat(message); }