编程式限流

This commit is contained in:
linzhihan
2023-07-06 21:54:04 +08:00
parent 89d81fc1ed
commit 58daaef47c
12 changed files with 506 additions and 75 deletions

View File

@@ -0,0 +1,21 @@
package com.abin.mallchat.custom.chatai.dto;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class GPTRequestDTO {
/**
* 聊天内容
*/
private String content;
/**
* 用户Id
*/
private Long uid;
}

View File

@@ -4,12 +4,17 @@ import cn.hutool.http.HttpResponse;
import com.abin.mallchat.common.chat.domain.entity.Message;
import com.abin.mallchat.common.chat.domain.entity.msg.MessageExtra;
import com.abin.mallchat.common.common.constant.RedisKey;
import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO;
import com.abin.mallchat.common.common.exception.FrequencyControlException;
import com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlUtil;
import com.abin.mallchat.common.common.utils.RedisUtils;
import com.abin.mallchat.custom.chatai.dto.GPTRequestDTO;
import com.abin.mallchat.custom.chatai.properties.ChatGLM2Properties;
import com.abin.mallchat.custom.chatai.utils.ChatGLM2Utils;
import com.abin.mallchat.custom.user.domain.vo.response.user.UserInfoResp;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.Nullable;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
@@ -21,10 +26,15 @@ import java.util.Random;
import java.util.concurrent.TimeUnit;
import static com.abin.mallchat.common.common.constant.RedisKey.USER_GLM2_TIME_LAST;
import static com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlStrategyFactory.TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER;
@Slf4j
@Component
public class ChatGLM2Handler extends AbstractChatAIHandler {
/**
* ChatGLM2Handler限流前缀
*/
private static final String CHAT_GLM2_FREQUENCY_PREFIX = "ChatGLM2Handler";
private static final List<String> ERROR_MSG = Arrays.asList(
"还摸鱼呢?你不下班我还要下班呢。。。。",
@@ -74,27 +84,42 @@ public class ChatGLM2Handler extends AbstractChatAIHandler {
protected String doChat(Message message) {
String content = message.getContent().replace("@" + AI_NAME, "").trim();
Long uid = message.getFromUid();
Long minute;
try {
FrequencyControlDTO frequencyControlDTO = new FrequencyControlDTO();
frequencyControlDTO.setKey(CHAT_GLM2_FREQUENCY_PREFIX + ":" + uid);
frequencyControlDTO.setUnit(TimeUnit.MINUTES);
frequencyControlDTO.setCount(1);
frequencyControlDTO.setTime(glm2Properties.getMinute().intValue());
return FrequencyControlUtil.executeWithFrequencyControl(TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER, frequencyControlDTO, () -> sendRequestToGPT(new GPTRequestDTO(content, uid)));
} catch (FrequencyControlException e) {
return "你太快了亲爱的~过一会再来找人家~";
} catch (Throwable e) {
return "系统开小差啦~~";
}
}
/**
* TODO
*
* @param gptRequestDTO
* @return
*/
@Nullable
private String sendRequestToGPT(GPTRequestDTO gptRequestDTO) {
String content = gptRequestDTO.getContent();
String text;
if ((minute = userMinutesLater(uid)) > 0) {
text = "你太快了 " + minute + "分钟后重试";
} else {
HttpResponse response = null;
try {
response = ChatGLM2Utils
.create()
.url(glm2Properties.getUrl())
.prompt(content)
.timeout(glm2Properties.getTimeout())
.send();
text = ChatGLM2Utils.parseText(response);
} catch (Exception e) {
log.warn("glm2 doChat warn:", e);
return getErrorText();
}
if (StringUtils.isNotBlank(text)) {
RedisUtils.set(RedisKey.getKey(USER_GLM2_TIME_LAST, uid), new Date(), glm2Properties.getMinute(), TimeUnit.MINUTES);
}
HttpResponse response = null;
try {
response = ChatGLM2Utils
.create()
.url(glm2Properties.getUrl())
.prompt(content)
.timeout(glm2Properties.getTimeout())
.send();
text = ChatGLM2Utils.parseText(response);
} catch (Exception e) {
log.warn("glm2 doChat warn:", e);
return getErrorText();
}
return text;
}

View File

@@ -3,9 +3,10 @@ package com.abin.mallchat.custom.chatai.handler;
import cn.hutool.http.HttpResponse;
import com.abin.mallchat.common.chat.domain.entity.Message;
import com.abin.mallchat.common.chat.domain.entity.msg.MessageExtra;
import com.abin.mallchat.common.common.constant.RedisKey;
import com.abin.mallchat.common.common.utils.DateUtils;
import com.abin.mallchat.common.common.utils.RedisUtils;
import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO;
import com.abin.mallchat.common.common.exception.FrequencyControlException;
import com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlUtil;
import com.abin.mallchat.custom.chatai.dto.GPTRequestDTO;
import com.abin.mallchat.custom.chatai.properties.ChatGPTProperties;
import com.abin.mallchat.custom.chatai.utils.ChatGPTUtils;
import com.abin.mallchat.custom.user.domain.vo.response.user.UserInfoResp;
@@ -17,9 +18,15 @@ import org.springframework.util.CollectionUtils;
import java.util.concurrent.TimeUnit;
import static com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlStrategyFactory.TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER;
@Slf4j
@Component
public class GPTChatAIHandler extends AbstractChatAIHandler {
/**
* GPTChatAIHandler限流前缀
*/
private static final String CHAT_FREQUENCY_PREFIX = "GPTChatAIHandler";
@Autowired
private ChatGPTProperties chatGPTProperties;
@@ -51,43 +58,44 @@ public class GPTChatAIHandler extends AbstractChatAIHandler {
return chatGPTProperties.getAIUserId();
}
@Override
protected String doChat(Message message) {
String content = message.getContent().replace("@" + AI_NAME, "").trim();
Long uid = message.getFromUid();
Long chatNum;
try {
FrequencyControlDTO frequencyControlDTO = new FrequencyControlDTO();
frequencyControlDTO.setKey(CHAT_FREQUENCY_PREFIX + ":" + uid);
frequencyControlDTO.setUnit(TimeUnit.HOURS);
frequencyControlDTO.setCount(chatGPTProperties.getLimit());
frequencyControlDTO.setTime(24);
return FrequencyControlUtil.executeWithFrequencyControl(TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER, frequencyControlDTO, () -> sendRequestToGPT(new GPTRequestDTO(content, uid)));
} catch (FrequencyControlException e) {
return "亲爱的,你今天找我聊了" + chatGPTProperties.getLimit() + "次了~人家累了~明天见";
} catch (Throwable e) {
return "系统开小差啦~~";
}
}
private String sendRequestToGPT(GPTRequestDTO gptRequestDTO) {
String content = gptRequestDTO.getContent();
String text;
if ((chatNum = getUserChatNum(uid)) > chatGPTProperties.getLimit()) {
text = "你今天已经和我聊了" + chatNum + "次了,我累了,明天再聊吧";
} else {
HttpResponse response = null;
try {
response = ChatGPTUtils.create(chatGPTProperties.getKey())
.proxyUrl(chatGPTProperties.getProxyUrl())
.model(chatGPTProperties.getModelName())
.timeout(chatGPTProperties.getTimeout())
.prompt(content)
.send();
text = ChatGPTUtils.parseText(response);
userChatNumInrc(uid);
} catch (Exception e) {
log.warn("gpt doChat warn:", e);
text = "我累了,明天再聊吧";
}
HttpResponse response = null;
try {
response = ChatGPTUtils.create(chatGPTProperties.getKey())
.proxyUrl(chatGPTProperties.getProxyUrl())
.model(chatGPTProperties.getModelName())
.timeout(chatGPTProperties.getTimeout())
.prompt(content)
.send();
text = ChatGPTUtils.parseText(response);
} catch (Exception e) {
log.warn("gpt doChat warn:", e);
text= "我累了,明天再聊吧";
}
return text;
}
private Long userChatNumInrc(Long uid) {
return RedisUtils.inc(RedisKey.getKey(RedisKey.USER_CHAT_NUM, uid), DateUtils.getEndTimeByToday().intValue(), TimeUnit.MILLISECONDS);
}
private Long getUserChatNum(Long uid) {
Long num = RedisUtils.get(RedisKey.getKey(RedisKey.USER_CHAT_NUM, uid), Long.class);
return num == null ? 0 : num;
}
@Override

View File

@@ -8,6 +8,7 @@ import io.netty.handler.codec.http.HttpHeaders;
import org.apache.commons.lang3.StringUtils;
import java.net.InetSocketAddress;
import java.util.Optional;
public class HttpHeadersHandler extends ChannelInboundHandlerAdapter {
@@ -18,7 +19,7 @@ public class HttpHeadersHandler extends ChannelInboundHandlerAdapter {
UrlBuilder urlBuilder = UrlBuilder.ofHttp(request.uri());
// 获取token参数
String token = urlBuilder.getQuery().get("token").toString();
String token = Optional.ofNullable(urlBuilder.getQuery()).map(k->k.get("token")).map(CharSequence::toString).orElse("");
NettyUtil.setAttr(ctx.channel(), NettyUtil.TOKEN, token);
// 获取请求路径