diff --git a/mallchat-common/src/main/java/com/abin/mallchat/common/common/aspect/FrequencyControlAspect.java b/mallchat-common/src/main/java/com/abin/mallchat/common/common/aspect/FrequencyControlAspect.java index 25418a6..f0290ab 100644 --- a/mallchat-common/src/main/java/com/abin/mallchat/common/common/aspect/FrequencyControlAspect.java +++ b/mallchat-common/src/main/java/com/abin/mallchat/common/common/aspect/FrequencyControlAspect.java @@ -2,9 +2,8 @@ package com.abin.mallchat.common.common.aspect; import cn.hutool.core.util.StrUtil; import com.abin.mallchat.common.common.annotation.FrequencyControl; -import com.abin.mallchat.common.common.exception.BusinessException; -import com.abin.mallchat.common.common.exception.CommonErrorEnum; -import com.abin.mallchat.common.common.utils.RedisUtils; +import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO; +import com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlUtil; import com.abin.mallchat.common.common.utils.RequestHolder; import com.abin.mallchat.common.common.utils.SpElUtils; import lombok.extern.slf4j.Slf4j; @@ -15,7 +14,12 @@ import org.aspectj.lang.reflect.MethodSignature; import org.springframework.stereotype.Component; import java.lang.reflect.Method; -import java.util.*; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlStrategyFactory.TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER; /** * Description: 频控实现 @@ -48,25 +52,25 @@ public class FrequencyControlAspect { } keyMap.put(prefix + ":" + key, frequencyControl); } - //批量获取redis统计的值 - ArrayList keyList = new ArrayList<>(keyMap.keySet()); - List countList = RedisUtils.mget(keyList, Integer.class); - for (int i = 0; i < keyList.size(); i++) { - String key = keyList.get(i); - Integer count = countList.get(i); - FrequencyControl frequencyControl = keyMap.get(key); - if (Objects.nonNull(count) && count >= frequencyControl.count()) {//频率超过了 - log.warn("frequencyControl limit key:{},count:{}", key, count); - throw new BusinessException(CommonErrorEnum.FREQUENCY_LIMIT); - } - } - try { - return joinPoint.proceed(); - } finally { - //不管成功还是失败,都增加次数 - keyMap.forEach((k, v) -> { - RedisUtils.inc(k, v.time(), v.unit()); - }); - } + // 将注解的参数转换为编程式调用需要的参数 + List frequencyControlDTOS = keyMap.entrySet().stream().map(entrySet -> buildFrequencyControlDTO(entrySet.getKey(), entrySet.getValue())).collect(Collectors.toList()); + // 调用编程式注解 + return FrequencyControlUtil.executeWithFrequencyControlList(TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER, frequencyControlDTOS, joinPoint::proceed); + } + + /** + * 将注解参数转换为编程式调用所需要的参数 + * + * @param key 频率控制Key + * @param frequencyControl 注解 + * @return 编程式调用所需要的参数-FrequencyControlDTO + */ + private FrequencyControlDTO buildFrequencyControlDTO(String key, FrequencyControl frequencyControl) { + FrequencyControlDTO frequencyControlDTO = new FrequencyControlDTO(); + frequencyControlDTO.setCount(frequencyControl.count()); + frequencyControlDTO.setTime(frequencyControl.time()); + frequencyControlDTO.setUnit(frequencyControl.unit()); + frequencyControlDTO.setKey(key); + return frequencyControlDTO; } } diff --git a/mallchat-common/src/main/java/com/abin/mallchat/common/common/domain/dto/FrequencyControlDTO.java b/mallchat-common/src/main/java/com/abin/mallchat/common/common/domain/dto/FrequencyControlDTO.java new file mode 100644 index 0000000..3399773 --- /dev/null +++ b/mallchat-common/src/main/java/com/abin/mallchat/common/common/domain/dto/FrequencyControlDTO.java @@ -0,0 +1,42 @@ +package com.abin.mallchat.common.common.domain.dto; + +import lombok.*; + +import java.util.concurrent.TimeUnit; + +@Data +@ToString +@Builder +@NoArgsConstructor +@AllArgsConstructor +/** 限流策略定义 + * @author linzhihan + * @date 2023/07/03 + * + */ +public class FrequencyControlDTO { + /** + * 代表频控的Key 如果target为Key的话 这里要传值用于构建redis的Key target为Ip或者UID的话会从上下文取值 Key字段无需传值 + */ + private String key; + /** + * 频控时间范围,默认单位秒 + * + * @return 时间范围 + */ + private Integer time; + + /** + * 频控时间单位,默认秒 + * + * @return 单位 + */ + private TimeUnit unit; + + /** + * 单位时间内最大访问次数 + * + * @return 次数 + */ + private Integer count; +} diff --git a/mallchat-common/src/main/java/com/abin/mallchat/common/common/exception/FrequencyControlException.java b/mallchat-common/src/main/java/com/abin/mallchat/common/common/exception/FrequencyControlException.java new file mode 100644 index 0000000..fbcaff7 --- /dev/null +++ b/mallchat-common/src/main/java/com/abin/mallchat/common/common/exception/FrequencyControlException.java @@ -0,0 +1,39 @@ +package com.abin.mallchat.common.common.exception; + +import lombok.Data; + +/** + * 自定义限流异常 + * + * @author linzhihan + * @date 2023/07/034 + */ +@Data +public class FrequencyControlException extends RuntimeException { + private static final long serialVersionUID = 1L; + + /** + *  错误码 + */ + protected Integer errorCode; + + /** + *  错误信息 + */ + protected String errorMsg; + + public FrequencyControlException() { + super(); + } + + public FrequencyControlException(String errorMsg) { + super(errorMsg); + this.errorMsg = errorMsg; + } + + public FrequencyControlException(ErrorEnum error) { + super(error.getErrorMsg()); + this.errorCode = error.getErrorCode(); + this.errorMsg = error.getErrorMsg(); + } +} diff --git a/mallchat-common/src/main/java/com/abin/mallchat/common/common/exception/GlobalExceptionHandler.java b/mallchat-common/src/main/java/com/abin/mallchat/common/common/exception/GlobalExceptionHandler.java index a5a06eb..066e210 100644 --- a/mallchat-common/src/main/java/com/abin/mallchat/common/common/exception/GlobalExceptionHandler.java +++ b/mallchat-common/src/main/java/com/abin/mallchat/common/common/exception/GlobalExceptionHandler.java @@ -73,4 +73,12 @@ public class GlobalExceptionHandler { return ApiResult.fail(-1, String.format("不支持'%s'请求", e.getMethod())); } + /** + * 限流异常 + */ + @ExceptionHandler(value = FrequencyControlException.class) + public ApiResult frequencyControlExceptionHandler(FrequencyControlException e) { + log.info("frequencyControl exception!The reason is:{}", e.getMessage(), e); + return ApiResult.fail(e.getErrorCode(), e.getMessage()); + } } diff --git a/mallchat-common/src/main/java/com/abin/mallchat/common/common/service/frequencycontrol/AbstractFrequencyControlService.java b/mallchat-common/src/main/java/com/abin/mallchat/common/common/service/frequencycontrol/AbstractFrequencyControlService.java new file mode 100644 index 0000000..5b884c5 --- /dev/null +++ b/mallchat-common/src/main/java/com/abin/mallchat/common/common/service/frequencycontrol/AbstractFrequencyControlService.java @@ -0,0 +1,113 @@ +package com.abin.mallchat.common.common.service.frequencycontrol; + +import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO; +import com.abin.mallchat.common.common.exception.CommonErrorEnum; +import com.abin.mallchat.common.common.exception.FrequencyControlException; +import com.abin.mallchat.common.common.utils.AssertUtil; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.ObjectUtils; + +import javax.annotation.PostConstruct; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * 抽象类频控服务 其他类如果要实现限流服务 直接注入使用通用限流类 + * 后期会通过继承此类实现令牌桶等算法 + * + * @author linzhihan + * @date 2023/07/03 + * @see TotalCountWithInFixTimeFrequencyController 通用限流类 + */ +@Slf4j +public abstract class AbstractFrequencyControlService { + + @PostConstruct + protected void registerMyselfToFactory() { + FrequencyControlStrategyFactory.registerFrequencyController(getStrategyName(), this); + } + + /** + * @param frequencyControlMap 定义的注解频控 Map中的Key-对应redis的单个频控的Key Map中的Value-对应redis的单个频控的Key限制的Value + * @param supplier 函数式入参-代表每个频控方法执行的不同的业务逻辑 + * @return 业务方法执行的返回值 + * @throws Throwable + */ + private T executeWithFrequencyControlMap(Map frequencyControlMap, SupplierThrowWithoutParam supplier) throws Throwable { + if (reachRateLimit(frequencyControlMap)) { + throw new FrequencyControlException(CommonErrorEnum.FREQUENCY_LIMIT); + } + try { + return supplier.get(); + } finally { + //不管成功还是失败,都增加次数 + addFrequencyControlStatisticsCount(frequencyControlMap); + } + } + + + /** + * 多限流策略的编程式调用方法 无参的调用方法 + * + * @param frequencyControlList 频控列表 包含每一个频率控制的定义以及顺序 + * @param supplier 函数式入参-代表每个频控方法执行的不同的业务逻辑 + * @return 业务方法执行的返回值 + * @throws Throwable 被限流或者限流策略定义错误 + */ + @SuppressWarnings("unchecked") + public T executeWithFrequencyControlList(List frequencyControlList, SupplierThrowWithoutParam supplier) throws Throwable { + boolean existsFrequencyControlHasNullKey = frequencyControlList.stream().anyMatch(frequencyControl -> ObjectUtils.isEmpty(frequencyControl.getKey())); + AssertUtil.isFalse(existsFrequencyControlHasNullKey, "限流策略的Key字段不允许出现空值"); + Map frequencyControlDTOMap = frequencyControlList.stream().collect(Collectors.groupingBy(FrequencyControlDTO::getKey, Collectors.collectingAndThen(Collectors.toList(), list -> list.get(0)))); + return executeWithFrequencyControlMap((Map) frequencyControlDTOMap, supplier); + } + + /** + * 单限流策略的调用方法-编程式调用 + * + * @param frequencyControl 单个频控对象 + * @param supplier 服务提供着 + * @return 业务方法执行结果 + * @throws Throwable + */ + public T executeWithFrequencyControl(K frequencyControl, SupplierThrowWithoutParam supplier) throws Throwable { + return executeWithFrequencyControlList(Collections.singletonList(frequencyControl), supplier); + } + + + @FunctionalInterface + public interface SupplierThrowWithoutParam { + + /** + * Gets a result. + * + * @return a result + */ + T get() throws Throwable; + } + + /** + * 是否达到限流阈值 子类实现 每个子类都可以自定义自己的限流逻辑判断 + * + * @param frequencyControlMap 定义的注解频控 Map中的Key-对应redis的单个频控的Key Map中的Value-对应redis的单个频控的Key限制的Value + * @return true-方法被限流 false-方法没有被限流 + */ + protected abstract boolean reachRateLimit(Map frequencyControlMap); + + /** + * 增加限流统计次数 子类实现 每个子类都可以自定义自己的限流统计信息增加的逻辑 + * + * @param frequencyControlMap 定义的注解频控 Map中的Key-对应redis的单个频控的Key Map中的Value-对应redis的单个频控的Key限制的Value + */ + protected abstract void addFrequencyControlStatisticsCount(Map frequencyControlMap); + + /** + * 获取策略名称 + * + * @return 策略名称 + */ + protected abstract String getStrategyName(); + +} diff --git a/mallchat-common/src/main/java/com/abin/mallchat/common/common/service/frequencycontrol/FrequencyControlStrategyFactory.java b/mallchat-common/src/main/java/com/abin/mallchat/common/common/service/frequencycontrol/FrequencyControlStrategyFactory.java new file mode 100644 index 0000000..ec211a8 --- /dev/null +++ b/mallchat-common/src/main/java/com/abin/mallchat/common/common/service/frequencycontrol/FrequencyControlStrategyFactory.java @@ -0,0 +1,51 @@ +package com.abin.mallchat.common.common.service.frequencycontrol; + +import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * 限流策略工厂 + * + * @author linzhihan + * @date 2023/07/03 + */ +public class FrequencyControlStrategyFactory { + /** + * 指定时间内总次数限流 + */ + public static final String TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER = "TotalCountWithInFixTime"; + /** + * 限流策略集合 + */ + static Map> frequencyControlServiceStrategyMap = new ConcurrentHashMap<>(8); + + /** + * 将策略类放入工厂 + * + * @param strategyName 策略名称 + * @param abstractFrequencyControlService 策略类 + */ + public static void registerFrequencyController(String strategyName, AbstractFrequencyControlService abstractFrequencyControlService) { + frequencyControlServiceStrategyMap.put(strategyName, abstractFrequencyControlService); + } + + /** + * 根据名称获取策略类 + * + * @param strategyName 策略名称 + * @return 对应的限流策略类 + */ + @SuppressWarnings("unchecked") + public static AbstractFrequencyControlService getFrequencyControllerByName(String strategyName) { + return (AbstractFrequencyControlService) frequencyControlServiceStrategyMap.get(strategyName); + } + + /** + * 构造器私有 + */ + private FrequencyControlStrategyFactory() { + + } +} diff --git a/mallchat-common/src/main/java/com/abin/mallchat/common/common/service/frequencycontrol/FrequencyControlUtil.java b/mallchat-common/src/main/java/com/abin/mallchat/common/common/service/frequencycontrol/FrequencyControlUtil.java new file mode 100644 index 0000000..aa9a30e --- /dev/null +++ b/mallchat-common/src/main/java/com/abin/mallchat/common/common/service/frequencycontrol/FrequencyControlUtil.java @@ -0,0 +1,55 @@ +package com.abin.mallchat.common.common.service.frequencycontrol; + +import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO; +import com.abin.mallchat.common.common.utils.AssertUtil; +import org.apache.commons.lang3.ObjectUtils; + +import java.util.List; + +/** + * 限流工具类 提供编程式的限流调用方法 + * + * @author linzhihan + * @date 2023/07/03 + */ +public class FrequencyControlUtil { + + /** + * 单限流策略的调用方法-编程式调用 + * + * @param strategyName 策略名称 + * @param frequencyControl 单个频控对象 + * @param supplier 服务提供着 + * @return 业务方法执行结果 + * @throws Throwable + */ + public static T executeWithFrequencyControl(String strategyName, K frequencyControl, AbstractFrequencyControlService.SupplierThrowWithoutParam supplier) throws Throwable { + AbstractFrequencyControlService frequencyController = FrequencyControlStrategyFactory.getFrequencyControllerByName(strategyName); + return frequencyController.executeWithFrequencyControl(frequencyControl, supplier); + } + + + /** + * 多限流策略的编程式调用方法调用方法 + * + * @param strategyName 策略名称 + * @param frequencyControlList 频控列表 包含每一个频率控制的定义以及顺序 + * @param supplier 函数式入参-代表每个频控方法执行的不同的业务逻辑 + * @return 业务方法执行的返回值 + * @throws Throwable 被限流或者限流策略定义错误 + */ + public static T executeWithFrequencyControlList(String strategyName, List frequencyControlList, AbstractFrequencyControlService.SupplierThrowWithoutParam supplier) throws Throwable { + boolean existsFrequencyControlHasNullKey = frequencyControlList.stream().anyMatch(frequencyControl -> ObjectUtils.isEmpty(frequencyControl.getKey())); + AssertUtil.isFalse(existsFrequencyControlHasNullKey, "限流策略的Key字段不允许出现空值"); + AbstractFrequencyControlService frequencyController = FrequencyControlStrategyFactory.getFrequencyControllerByName(strategyName); + return frequencyController.executeWithFrequencyControlList(frequencyControlList, supplier); + } + + /** + * 构造器私有 + */ + private FrequencyControlUtil() { + + } + +} diff --git a/mallchat-common/src/main/java/com/abin/mallchat/common/common/service/frequencycontrol/TotalCountWithInFixTimeFrequencyController.java b/mallchat-common/src/main/java/com/abin/mallchat/common/common/service/frequencycontrol/TotalCountWithInFixTimeFrequencyController.java new file mode 100644 index 0000000..531e36c --- /dev/null +++ b/mallchat-common/src/main/java/com/abin/mallchat/common/common/service/frequencycontrol/TotalCountWithInFixTimeFrequencyController.java @@ -0,0 +1,64 @@ +package com.abin.mallchat.common.common.service.frequencycontrol; + +import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO; +import com.abin.mallchat.common.common.utils.RedisUtils; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlStrategyFactory.TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER; + +/** + * 抽象类频控服务 -使用redis实现 固定时间内不超过固定次数的限流类 + * + * @author linzhihan + * @date 2023/07/03 + */ +@Slf4j +@Service +public class TotalCountWithInFixTimeFrequencyController extends AbstractFrequencyControlService { + + + /** + * 是否达到限流阈值 子类实现 每个子类都可以自定义自己的限流逻辑判断 + * + * @param frequencyControlMap 定义的注解频控 Map中的Key-对应redis的单个频控的Key Map中的Value-对应redis的单个频控的Key限制的Value + * @return true-方法被限流 false-方法没有被限流 + */ + @Override + protected boolean reachRateLimit(Map frequencyControlMap) { + //批量获取redis统计的值 + List frequencyKeys = new ArrayList<>(frequencyControlMap.keySet()); + List countList = RedisUtils.mget(frequencyKeys, Integer.class); + for (int i = 0; i < frequencyKeys.size(); i++) { + String key = frequencyKeys.get(i); + Integer count = countList.get(i); + int frequencyControlCount = frequencyControlMap.get(key).getCount(); + if (Objects.nonNull(count) && count >= frequencyControlCount) { + //频率超过了 + log.warn("frequencyControl limit key:{},count:{}", key, count); + return true; + } + } + return false; + } + + /** + * 增加限流统计次数 子类实现 每个子类都可以自定义自己的限流统计信息增加的逻辑 + * + * @param frequencyControlMap 定义的注解频控 Map中的Key-对应redis的单个频控的Key Map中的Value-对应redis的单个频控的Key限制的Value + */ + @Override + protected void addFrequencyControlStatisticsCount(Map frequencyControlMap) { + frequencyControlMap.forEach((k, v) -> RedisUtils.inc(k, v.getTime(), v.getUnit())); + } + + @Override + protected String getStrategyName() { + return TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER; + } +} diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/dto/GPTRequestDTO.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/dto/GPTRequestDTO.java new file mode 100644 index 0000000..cb72432 --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/dto/GPTRequestDTO.java @@ -0,0 +1,21 @@ +package com.abin.mallchat.custom.chatai.dto; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@Builder +@AllArgsConstructor +@NoArgsConstructor +public class GPTRequestDTO { + /** + * 聊天内容 + */ + private String content; + /** + * 用户Id + */ + private Long uid; +} diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/ChatGLM2Handler.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/ChatGLM2Handler.java index 0f7ff1b..83f665a 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/ChatGLM2Handler.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/ChatGLM2Handler.java @@ -4,12 +4,17 @@ import cn.hutool.http.HttpResponse; import com.abin.mallchat.common.chat.domain.entity.Message; import com.abin.mallchat.common.chat.domain.entity.msg.MessageExtra; import com.abin.mallchat.common.common.constant.RedisKey; +import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO; +import com.abin.mallchat.common.common.exception.FrequencyControlException; +import com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlUtil; import com.abin.mallchat.common.common.utils.RedisUtils; +import com.abin.mallchat.custom.chatai.dto.GPTRequestDTO; import com.abin.mallchat.custom.chatai.properties.ChatGLM2Properties; import com.abin.mallchat.custom.chatai.utils.ChatGLM2Utils; import com.abin.mallchat.custom.user.domain.vo.response.user.UserInfoResp; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; +import org.jetbrains.annotations.Nullable; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; @@ -21,10 +26,15 @@ import java.util.Random; import java.util.concurrent.TimeUnit; import static com.abin.mallchat.common.common.constant.RedisKey.USER_GLM2_TIME_LAST; +import static com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlStrategyFactory.TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER; @Slf4j @Component public class ChatGLM2Handler extends AbstractChatAIHandler { + /** + * ChatGLM2Handler限流前缀 + */ + private static final String CHAT_GLM2_FREQUENCY_PREFIX = "ChatGLM2Handler"; private static final List ERROR_MSG = Arrays.asList( "还摸鱼呢?你不下班我还要下班呢。。。。", @@ -74,27 +84,42 @@ public class ChatGLM2Handler extends AbstractChatAIHandler { protected String doChat(Message message) { String content = message.getContent().replace("@" + AI_NAME, "").trim(); Long uid = message.getFromUid(); - Long minute; + try { + FrequencyControlDTO frequencyControlDTO = new FrequencyControlDTO(); + frequencyControlDTO.setKey(CHAT_GLM2_FREQUENCY_PREFIX + ":" + uid); + frequencyControlDTO.setUnit(TimeUnit.MINUTES); + frequencyControlDTO.setCount(1); + frequencyControlDTO.setTime(glm2Properties.getMinute().intValue()); + return FrequencyControlUtil.executeWithFrequencyControl(TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER, frequencyControlDTO, () -> sendRequestToGPT(new GPTRequestDTO(content, uid))); + } catch (FrequencyControlException e) { + return "你太快了亲爱的~过一会再来找人家~"; + } catch (Throwable e) { + return "系统开小差啦~~"; + } + } + + /** + * TODO + * + * @param gptRequestDTO + * @return + */ + @Nullable + private String sendRequestToGPT(GPTRequestDTO gptRequestDTO) { + String content = gptRequestDTO.getContent(); String text; - if ((minute = userMinutesLater(uid)) > 0) { - text = "你太快了 " + minute + "分钟后重试"; - } else { - HttpResponse response = null; - try { - response = ChatGLM2Utils - .create() - .url(glm2Properties.getUrl()) - .prompt(content) - .timeout(glm2Properties.getTimeout()) - .send(); - text = ChatGLM2Utils.parseText(response); - } catch (Exception e) { - log.warn("glm2 doChat warn:", e); - return getErrorText(); - } - if (StringUtils.isNotBlank(text)) { - RedisUtils.set(RedisKey.getKey(USER_GLM2_TIME_LAST, uid), new Date(), glm2Properties.getMinute(), TimeUnit.MINUTES); - } + HttpResponse response = null; + try { + response = ChatGLM2Utils + .create() + .url(glm2Properties.getUrl()) + .prompt(content) + .timeout(glm2Properties.getTimeout()) + .send(); + text = ChatGLM2Utils.parseText(response); + } catch (Exception e) { + log.warn("glm2 doChat warn:", e); + return getErrorText(); } return text; } diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/GPTChatAIHandler.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/GPTChatAIHandler.java index 440c1d9..fd8a594 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/GPTChatAIHandler.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/GPTChatAIHandler.java @@ -3,9 +3,10 @@ package com.abin.mallchat.custom.chatai.handler; import cn.hutool.http.HttpResponse; import com.abin.mallchat.common.chat.domain.entity.Message; import com.abin.mallchat.common.chat.domain.entity.msg.MessageExtra; -import com.abin.mallchat.common.common.constant.RedisKey; -import com.abin.mallchat.common.common.utils.DateUtils; -import com.abin.mallchat.common.common.utils.RedisUtils; +import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO; +import com.abin.mallchat.common.common.exception.FrequencyControlException; +import com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlUtil; +import com.abin.mallchat.custom.chatai.dto.GPTRequestDTO; import com.abin.mallchat.custom.chatai.properties.ChatGPTProperties; import com.abin.mallchat.custom.chatai.utils.ChatGPTUtils; import com.abin.mallchat.custom.user.domain.vo.response.user.UserInfoResp; @@ -17,9 +18,15 @@ import org.springframework.util.CollectionUtils; import java.util.concurrent.TimeUnit; +import static com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlStrategyFactory.TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER; + @Slf4j @Component public class GPTChatAIHandler extends AbstractChatAIHandler { + /** + * GPTChatAIHandler限流前缀 + */ + private static final String CHAT_FREQUENCY_PREFIX = "GPTChatAIHandler"; @Autowired private ChatGPTProperties chatGPTProperties; @@ -51,43 +58,44 @@ public class GPTChatAIHandler extends AbstractChatAIHandler { return chatGPTProperties.getAIUserId(); } - @Override protected String doChat(Message message) { String content = message.getContent().replace("@" + AI_NAME, "").trim(); Long uid = message.getFromUid(); - Long chatNum; + try { + FrequencyControlDTO frequencyControlDTO = new FrequencyControlDTO(); + frequencyControlDTO.setKey(CHAT_FREQUENCY_PREFIX + ":" + uid); + frequencyControlDTO.setUnit(TimeUnit.HOURS); + frequencyControlDTO.setCount(chatGPTProperties.getLimit()); + frequencyControlDTO.setTime(24); + return FrequencyControlUtil.executeWithFrequencyControl(TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER, frequencyControlDTO, () -> sendRequestToGPT(new GPTRequestDTO(content, uid))); + } catch (FrequencyControlException e) { + return "亲爱的,你今天找我聊了" + chatGPTProperties.getLimit() + "次了~人家累了~明天见"; + } catch (Throwable e) { + return "系统开小差啦~~"; + } + } + + private String sendRequestToGPT(GPTRequestDTO gptRequestDTO) { + String content = gptRequestDTO.getContent(); String text; - if ((chatNum = getUserChatNum(uid)) > chatGPTProperties.getLimit()) { - text = "你今天已经和我聊了" + chatNum + "次了,我累了,明天再聊吧"; - } else { - HttpResponse response = null; - try { - response = ChatGPTUtils.create(chatGPTProperties.getKey()) - .proxyUrl(chatGPTProperties.getProxyUrl()) - .model(chatGPTProperties.getModelName()) - .timeout(chatGPTProperties.getTimeout()) - .prompt(content) - .send(); - text = ChatGPTUtils.parseText(response); - userChatNumInrc(uid); - } catch (Exception e) { - log.warn("gpt doChat warn:", e); - text = "我累了,明天再聊吧"; - } + HttpResponse response = null; + try { + response = ChatGPTUtils.create(chatGPTProperties.getKey()) + .proxyUrl(chatGPTProperties.getProxyUrl()) + .model(chatGPTProperties.getModelName()) + .timeout(chatGPTProperties.getTimeout()) + .prompt(content) + .send(); + text = ChatGPTUtils.parseText(response); + } catch (Exception e) { + log.warn("gpt doChat warn:", e); + text= "我累了,明天再聊吧"; } return text; } - private Long userChatNumInrc(Long uid) { - return RedisUtils.inc(RedisKey.getKey(RedisKey.USER_CHAT_NUM, uid), DateUtils.getEndTimeByToday().intValue(), TimeUnit.MILLISECONDS); - } - private Long getUserChatNum(Long uid) { - Long num = RedisUtils.get(RedisKey.getKey(RedisKey.USER_CHAT_NUM, uid), Long.class); - return num == null ? 0 : num; - - } @Override diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/user/service/impl/WebSocketServiceImpl.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/user/service/impl/WebSocketServiceImpl.java index 8515083..04efa6d 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/user/service/impl/WebSocketServiceImpl.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/user/service/impl/WebSocketServiceImpl.java @@ -2,7 +2,6 @@ package com.abin.mallchat.custom.user.service.impl; import cn.hutool.core.collection.CollectionUtil; import cn.hutool.core.util.ObjectUtil; -import cn.hutool.core.util.RandomUtil; import cn.hutool.json.JSONUtil; import com.abin.mallchat.common.common.annotation.FrequencyControl; import com.abin.mallchat.common.common.config.ThreadPoolConfig; @@ -20,6 +19,8 @@ import com.abin.mallchat.custom.user.service.LoginService; import com.abin.mallchat.custom.user.service.WebSocketService; import com.abin.mallchat.custom.user.service.adapter.WSAdapter; import com.abin.mallchat.custom.user.websocket.NettyUtil; +import com.github.benmanes.caffeine.cache.Cache; +import com.github.benmanes.caffeine.cache.Caffeine; import io.netty.channel.Channel; import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; import lombok.SneakyThrows; @@ -32,13 +33,13 @@ import org.springframework.context.ApplicationEventPublisher; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.stereotype.Component; -import java.util.Date; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; -import java.util.Set; + +import java.time.Duration; +import java.util.*; + import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.ReentrantLock; @@ -51,12 +52,18 @@ import java.util.concurrent.locks.ReentrantLock; @Slf4j public class WebSocketServiceImpl implements WebSocketService { + private static final Duration EXPIRE_TIME = Duration.ofHours(1); + private static final Long MAX_MUM_SIZE = 10000L; + + private static final AtomicInteger CODE = new AtomicInteger(); /** * 所有请求登录的code与channel关系 - * todo 有可能有人请求了二维码,就是不登录,留个坑,之后处理 */ - private static final ConcurrentHashMap WAIT_LOGIN_MAP = new ConcurrentHashMap<>(); + private static final Cache WAIT_LOGIN_MAP = Caffeine.newBuilder() + .expireAfterWrite(EXPIRE_TIME) + .maximumSize(MAX_MUM_SIZE) + .build(); /** * 所有已连接的websocket连接列表和一些额外参数 */ @@ -70,7 +77,6 @@ public class WebSocketServiceImpl implements WebSocketService { return ONLINE_WS_MAP; } - public static final int EXPIRE_SECONDS = 60 * 60; @Autowired private WxMpService wxMpService; @Autowired @@ -99,7 +105,7 @@ public class WebSocketServiceImpl implements WebSocketService { //生成随机不重复的登录码 Integer code = generateLoginCode(channel); //请求微信接口,获取登录码地址 - WxMpQrCodeTicket wxMpQrCodeTicket = wxMpService.getQrcodeService().qrCodeCreateTmpTicket(code, EXPIRE_SECONDS); + WxMpQrCodeTicket wxMpQrCodeTicket = wxMpService.getQrcodeService().qrCodeCreateTmpTicket(code, (int) EXPIRE_TIME.getSeconds()); //返回给前端 sendMsg(channel, WSAdapter.buildLoginResp(wxMpQrCodeTicket)); } @@ -111,12 +117,11 @@ public class WebSocketServiceImpl implements WebSocketService { * @return */ private Integer generateLoginCode(Channel channel) { - int code; do { - code = RandomUtil.randomInt(Integer.MAX_VALUE); - } while (WAIT_LOGIN_MAP.contains(code) - || Objects.nonNull(WAIT_LOGIN_MAP.putIfAbsent(code, channel))); - return code; + CODE.getAndIncrement(); + } while (WAIT_LOGIN_MAP.asMap().containsKey(CODE.get()) + || Objects.isNull(WAIT_LOGIN_MAP.get(CODE.get(), c -> channel))); + return CODE.get(); } /** @@ -203,12 +208,12 @@ public class WebSocketServiceImpl implements WebSocketService { @Override public Boolean scanLoginSuccess(Integer loginCode, User user, String token) { //发送消息 - Channel channel = WAIT_LOGIN_MAP.get(loginCode); + Channel channel = WAIT_LOGIN_MAP.getIfPresent(loginCode); if (Objects.isNull(channel)) { return Boolean.FALSE; } //移除code - WAIT_LOGIN_MAP.remove(loginCode); + WAIT_LOGIN_MAP.invalidate(loginCode); //用户登录 loginSuccess(channel, user, token); return true; @@ -216,7 +221,7 @@ public class WebSocketServiceImpl implements WebSocketService { @Override public Boolean scanSuccess(Integer loginCode) { - Channel channel = WAIT_LOGIN_MAP.get(loginCode); + Channel channel = WAIT_LOGIN_MAP.getIfPresent(loginCode); if (Objects.isNull(channel)) { return Boolean.FALSE; } @@ -291,4 +296,6 @@ public class WebSocketServiceImpl implements WebSocketService { reentrantLock.unlock(); Thread.sleep(1000); } + + } diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/user/websocket/HttpHeadersHandler.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/user/websocket/HttpHeadersHandler.java index 5434151..6752060 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/user/websocket/HttpHeadersHandler.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/user/websocket/HttpHeadersHandler.java @@ -8,7 +8,7 @@ import io.netty.handler.codec.http.HttpHeaders; import org.apache.commons.lang3.StringUtils; import java.net.InetSocketAddress; -import java.util.Objects; +import java.util.Optional; public class HttpHeadersHandler extends ChannelInboundHandlerAdapter { @@ -19,11 +19,8 @@ public class HttpHeadersHandler extends ChannelInboundHandlerAdapter { UrlBuilder urlBuilder = UrlBuilder.ofHttp(request.uri()); // 获取token参数 - CharSequence sequence = urlBuilder.getQuery().get("token"); - if (Objects.nonNull(sequence)) { - String token = sequence.toString(); - NettyUtil.setAttr(ctx.channel(), NettyUtil.TOKEN, token); - } + String token = Optional.ofNullable(urlBuilder.getQuery()).map(k->k.get("token")).map(CharSequence::toString).orElse(""); + NettyUtil.setAttr(ctx.channel(), NettyUtil.TOKEN, token); // 获取请求路径 request.setUri(urlBuilder.getPath().toString());