Merge branch 'main' into main

This commit is contained in:
lm
2023-07-06 22:15:22 +08:00
committed by GitHub
13 changed files with 532 additions and 98 deletions

View File

@@ -2,9 +2,8 @@ package com.abin.mallchat.common.common.aspect;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import com.abin.mallchat.common.common.annotation.FrequencyControl; import com.abin.mallchat.common.common.annotation.FrequencyControl;
import com.abin.mallchat.common.common.exception.BusinessException; import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO;
import com.abin.mallchat.common.common.exception.CommonErrorEnum; import com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlUtil;
import com.abin.mallchat.common.common.utils.RedisUtils;
import com.abin.mallchat.common.common.utils.RequestHolder; import com.abin.mallchat.common.common.utils.RequestHolder;
import com.abin.mallchat.common.common.utils.SpElUtils; import com.abin.mallchat.common.common.utils.SpElUtils;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@@ -15,7 +14,12 @@ import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.*; import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlStrategyFactory.TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER;
/** /**
* Description: 频控实现 * Description: 频控实现
@@ -48,25 +52,25 @@ public class FrequencyControlAspect {
} }
keyMap.put(prefix + ":" + key, frequencyControl); keyMap.put(prefix + ":" + key, frequencyControl);
} }
//批量获取redis统计的值 // 将注解的参数转换为编程式调用需要的参数
ArrayList<String> keyList = new ArrayList<>(keyMap.keySet()); List<FrequencyControlDTO> frequencyControlDTOS = keyMap.entrySet().stream().map(entrySet -> buildFrequencyControlDTO(entrySet.getKey(), entrySet.getValue())).collect(Collectors.toList());
List<Integer> countList = RedisUtils.mget(keyList, Integer.class); // 调用编程式注解
for (int i = 0; i < keyList.size(); i++) { return FrequencyControlUtil.executeWithFrequencyControlList(TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER, frequencyControlDTOS, joinPoint::proceed);
String key = keyList.get(i); }
Integer count = countList.get(i);
FrequencyControl frequencyControl = keyMap.get(key); /**
if (Objects.nonNull(count) && count >= frequencyControl.count()) {//频率超过了 * 将注解参数转换为编程式调用所需要的参数
log.warn("frequencyControl limit key:{},count:{}", key, count); *
throw new BusinessException(CommonErrorEnum.FREQUENCY_LIMIT); * @param key 频率控制Key
} * @param frequencyControl 注解
} * @return 编程式调用所需要的参数-FrequencyControlDTO
try { */
return joinPoint.proceed(); private FrequencyControlDTO buildFrequencyControlDTO(String key, FrequencyControl frequencyControl) {
} finally { FrequencyControlDTO frequencyControlDTO = new FrequencyControlDTO();
//不管成功还是失败,都增加次数 frequencyControlDTO.setCount(frequencyControl.count());
keyMap.forEach((k, v) -> { frequencyControlDTO.setTime(frequencyControl.time());
RedisUtils.inc(k, v.time(), v.unit()); frequencyControlDTO.setUnit(frequencyControl.unit());
}); frequencyControlDTO.setKey(key);
} return frequencyControlDTO;
} }
} }

View File

@@ -0,0 +1,42 @@
package com.abin.mallchat.common.common.domain.dto;
import lombok.*;
import java.util.concurrent.TimeUnit;
@Data
@ToString
@Builder
@NoArgsConstructor
@AllArgsConstructor
/** 限流策略定义
* @author linzhihan
* @date 2023/07/03
*
*/
public class FrequencyControlDTO {
/**
* 代表频控的Key 如果target为Key的话 这里要传值用于构建redis的Key target为Ip或者UID的话会从上下文取值 Key字段无需传值
*/
private String key;
/**
* 频控时间范围,默认单位秒
*
* @return 时间范围
*/
private Integer time;
/**
* 频控时间单位,默认秒
*
* @return 单位
*/
private TimeUnit unit;
/**
* 单位时间内最大访问次数
*
* @return 次数
*/
private Integer count;
}

View File

@@ -0,0 +1,39 @@
package com.abin.mallchat.common.common.exception;
import lombok.Data;
/**
* 自定义限流异常
*
* @author linzhihan
* @date 2023/07/034
*/
@Data
public class FrequencyControlException extends RuntimeException {
private static final long serialVersionUID = 1L;
/**
*  错误码
*/
protected Integer errorCode;
/**
*  错误信息
*/
protected String errorMsg;
public FrequencyControlException() {
super();
}
public FrequencyControlException(String errorMsg) {
super(errorMsg);
this.errorMsg = errorMsg;
}
public FrequencyControlException(ErrorEnum error) {
super(error.getErrorMsg());
this.errorCode = error.getErrorCode();
this.errorMsg = error.getErrorMsg();
}
}

View File

@@ -73,4 +73,12 @@ public class GlobalExceptionHandler {
return ApiResult.fail(-1, String.format("不支持'%s'请求", e.getMethod())); return ApiResult.fail(-1, String.format("不支持'%s'请求", e.getMethod()));
} }
/**
* 限流异常
*/
@ExceptionHandler(value = FrequencyControlException.class)
public ApiResult frequencyControlExceptionHandler(FrequencyControlException e) {
log.info("frequencyControl exceptionThe reason is{}", e.getMessage(), e);
return ApiResult.fail(e.getErrorCode(), e.getMessage());
}
} }

View File

@@ -0,0 +1,113 @@
package com.abin.mallchat.common.common.service.frequencycontrol;
import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO;
import com.abin.mallchat.common.common.exception.CommonErrorEnum;
import com.abin.mallchat.common.common.exception.FrequencyControlException;
import com.abin.mallchat.common.common.utils.AssertUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ObjectUtils;
import javax.annotation.PostConstruct;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
* 抽象类频控服务 其他类如果要实现限流服务 直接注入使用通用限流类
* 后期会通过继承此类实现令牌桶等算法
*
* @author linzhihan
* @date 2023/07/03
* @see TotalCountWithInFixTimeFrequencyController 通用限流类
*/
@Slf4j
public abstract class AbstractFrequencyControlService<K extends FrequencyControlDTO> {
@PostConstruct
protected void registerMyselfToFactory() {
FrequencyControlStrategyFactory.registerFrequencyController(getStrategyName(), this);
}
/**
* @param frequencyControlMap 定义的注解频控 Map中的Key-对应redis的单个频控的Key Map中的Value-对应redis的单个频控的Key限制的Value
* @param supplier 函数式入参-代表每个频控方法执行的不同的业务逻辑
* @return 业务方法执行的返回值
* @throws Throwable
*/
private <T> T executeWithFrequencyControlMap(Map<String, K> frequencyControlMap, SupplierThrowWithoutParam<T> supplier) throws Throwable {
if (reachRateLimit(frequencyControlMap)) {
throw new FrequencyControlException(CommonErrorEnum.FREQUENCY_LIMIT);
}
try {
return supplier.get();
} finally {
//不管成功还是失败,都增加次数
addFrequencyControlStatisticsCount(frequencyControlMap);
}
}
/**
* 多限流策略的编程式调用方法 无参的调用方法
*
* @param frequencyControlList 频控列表 包含每一个频率控制的定义以及顺序
* @param supplier 函数式入参-代表每个频控方法执行的不同的业务逻辑
* @return 业务方法执行的返回值
* @throws Throwable 被限流或者限流策略定义错误
*/
@SuppressWarnings("unchecked")
public <T> T executeWithFrequencyControlList(List<K> frequencyControlList, SupplierThrowWithoutParam<T> supplier) throws Throwable {
boolean existsFrequencyControlHasNullKey = frequencyControlList.stream().anyMatch(frequencyControl -> ObjectUtils.isEmpty(frequencyControl.getKey()));
AssertUtil.isFalse(existsFrequencyControlHasNullKey, "限流策略的Key字段不允许出现空值");
Map<String, FrequencyControlDTO> frequencyControlDTOMap = frequencyControlList.stream().collect(Collectors.groupingBy(FrequencyControlDTO::getKey, Collectors.collectingAndThen(Collectors.toList(), list -> list.get(0))));
return executeWithFrequencyControlMap((Map<String, K>) frequencyControlDTOMap, supplier);
}
/**
* 单限流策略的调用方法-编程式调用
*
* @param frequencyControl 单个频控对象
* @param supplier 服务提供着
* @return 业务方法执行结果
* @throws Throwable
*/
public <T> T executeWithFrequencyControl(K frequencyControl, SupplierThrowWithoutParam<T> supplier) throws Throwable {
return executeWithFrequencyControlList(Collections.singletonList(frequencyControl), supplier);
}
@FunctionalInterface
public interface SupplierThrowWithoutParam<T> {
/**
* Gets a result.
*
* @return a result
*/
T get() throws Throwable;
}
/**
* 是否达到限流阈值 子类实现 每个子类都可以自定义自己的限流逻辑判断
*
* @param frequencyControlMap 定义的注解频控 Map中的Key-对应redis的单个频控的Key Map中的Value-对应redis的单个频控的Key限制的Value
* @return true-方法被限流 false-方法没有被限流
*/
protected abstract boolean reachRateLimit(Map<String, K> frequencyControlMap);
/**
* 增加限流统计次数 子类实现 每个子类都可以自定义自己的限流统计信息增加的逻辑
*
* @param frequencyControlMap 定义的注解频控 Map中的Key-对应redis的单个频控的Key Map中的Value-对应redis的单个频控的Key限制的Value
*/
protected abstract void addFrequencyControlStatisticsCount(Map<String, K> frequencyControlMap);
/**
* 获取策略名称
*
* @return 策略名称
*/
protected abstract String getStrategyName();
}

View File

@@ -0,0 +1,51 @@
package com.abin.mallchat.common.common.service.frequencycontrol;
import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* 限流策略工厂
*
* @author linzhihan
* @date 2023/07/03
*/
public class FrequencyControlStrategyFactory {
/**
* 指定时间内总次数限流
*/
public static final String TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER = "TotalCountWithInFixTime";
/**
* 限流策略集合
*/
static Map<String, AbstractFrequencyControlService<?>> frequencyControlServiceStrategyMap = new ConcurrentHashMap<>(8);
/**
* 将策略类放入工厂
*
* @param strategyName 策略名称
* @param abstractFrequencyControlService 策略类
*/
public static <K extends FrequencyControlDTO> void registerFrequencyController(String strategyName, AbstractFrequencyControlService<K> abstractFrequencyControlService) {
frequencyControlServiceStrategyMap.put(strategyName, abstractFrequencyControlService);
}
/**
* 根据名称获取策略类
*
* @param strategyName 策略名称
* @return 对应的限流策略类
*/
@SuppressWarnings("unchecked")
public static <K extends FrequencyControlDTO> AbstractFrequencyControlService<K> getFrequencyControllerByName(String strategyName) {
return (AbstractFrequencyControlService<K>) frequencyControlServiceStrategyMap.get(strategyName);
}
/**
* 构造器私有
*/
private FrequencyControlStrategyFactory() {
}
}

View File

@@ -0,0 +1,55 @@
package com.abin.mallchat.common.common.service.frequencycontrol;
import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO;
import com.abin.mallchat.common.common.utils.AssertUtil;
import org.apache.commons.lang3.ObjectUtils;
import java.util.List;
/**
* 限流工具类 提供编程式的限流调用方法
*
* @author linzhihan
* @date 2023/07/03
*/
public class FrequencyControlUtil {
/**
* 单限流策略的调用方法-编程式调用
*
* @param strategyName 策略名称
* @param frequencyControl 单个频控对象
* @param supplier 服务提供着
* @return 业务方法执行结果
* @throws Throwable
*/
public static <T, K extends FrequencyControlDTO> T executeWithFrequencyControl(String strategyName, K frequencyControl, AbstractFrequencyControlService.SupplierThrowWithoutParam<T> supplier) throws Throwable {
AbstractFrequencyControlService<K> frequencyController = FrequencyControlStrategyFactory.getFrequencyControllerByName(strategyName);
return frequencyController.executeWithFrequencyControl(frequencyControl, supplier);
}
/**
* 多限流策略的编程式调用方法调用方法
*
* @param strategyName 策略名称
* @param frequencyControlList 频控列表 包含每一个频率控制的定义以及顺序
* @param supplier 函数式入参-代表每个频控方法执行的不同的业务逻辑
* @return 业务方法执行的返回值
* @throws Throwable 被限流或者限流策略定义错误
*/
public static <T, K extends FrequencyControlDTO> T executeWithFrequencyControlList(String strategyName, List<K> frequencyControlList, AbstractFrequencyControlService.SupplierThrowWithoutParam<T> supplier) throws Throwable {
boolean existsFrequencyControlHasNullKey = frequencyControlList.stream().anyMatch(frequencyControl -> ObjectUtils.isEmpty(frequencyControl.getKey()));
AssertUtil.isFalse(existsFrequencyControlHasNullKey, "限流策略的Key字段不允许出现空值");
AbstractFrequencyControlService<K> frequencyController = FrequencyControlStrategyFactory.getFrequencyControllerByName(strategyName);
return frequencyController.executeWithFrequencyControlList(frequencyControlList, supplier);
}
/**
* 构造器私有
*/
private FrequencyControlUtil() {
}
}

View File

@@ -0,0 +1,64 @@
package com.abin.mallchat.common.common.service.frequencycontrol;
import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO;
import com.abin.mallchat.common.common.utils.RedisUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import static com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlStrategyFactory.TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER;
/**
* 抽象类频控服务 -使用redis实现 固定时间内不超过固定次数的限流类
*
* @author linzhihan
* @date 2023/07/03
*/
@Slf4j
@Service
public class TotalCountWithInFixTimeFrequencyController extends AbstractFrequencyControlService<FrequencyControlDTO> {
/**
* 是否达到限流阈值 子类实现 每个子类都可以自定义自己的限流逻辑判断
*
* @param frequencyControlMap 定义的注解频控 Map中的Key-对应redis的单个频控的Key Map中的Value-对应redis的单个频控的Key限制的Value
* @return true-方法被限流 false-方法没有被限流
*/
@Override
protected boolean reachRateLimit(Map<String, FrequencyControlDTO> frequencyControlMap) {
//批量获取redis统计的值
List<String> frequencyKeys = new ArrayList<>(frequencyControlMap.keySet());
List<Integer> countList = RedisUtils.mget(frequencyKeys, Integer.class);
for (int i = 0; i < frequencyKeys.size(); i++) {
String key = frequencyKeys.get(i);
Integer count = countList.get(i);
int frequencyControlCount = frequencyControlMap.get(key).getCount();
if (Objects.nonNull(count) && count >= frequencyControlCount) {
//频率超过了
log.warn("frequencyControl limit key:{},count:{}", key, count);
return true;
}
}
return false;
}
/**
* 增加限流统计次数 子类实现 每个子类都可以自定义自己的限流统计信息增加的逻辑
*
* @param frequencyControlMap 定义的注解频控 Map中的Key-对应redis的单个频控的Key Map中的Value-对应redis的单个频控的Key限制的Value
*/
@Override
protected void addFrequencyControlStatisticsCount(Map<String, FrequencyControlDTO> frequencyControlMap) {
frequencyControlMap.forEach((k, v) -> RedisUtils.inc(k, v.getTime(), v.getUnit()));
}
@Override
protected String getStrategyName() {
return TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER;
}
}

View File

@@ -0,0 +1,21 @@
package com.abin.mallchat.custom.chatai.dto;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class GPTRequestDTO {
/**
* 聊天内容
*/
private String content;
/**
* 用户Id
*/
private Long uid;
}

View File

@@ -4,12 +4,17 @@ import cn.hutool.http.HttpResponse;
import com.abin.mallchat.common.chat.domain.entity.Message; import com.abin.mallchat.common.chat.domain.entity.Message;
import com.abin.mallchat.common.chat.domain.entity.msg.MessageExtra; import com.abin.mallchat.common.chat.domain.entity.msg.MessageExtra;
import com.abin.mallchat.common.common.constant.RedisKey; import com.abin.mallchat.common.common.constant.RedisKey;
import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO;
import com.abin.mallchat.common.common.exception.FrequencyControlException;
import com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlUtil;
import com.abin.mallchat.common.common.utils.RedisUtils; import com.abin.mallchat.common.common.utils.RedisUtils;
import com.abin.mallchat.custom.chatai.dto.GPTRequestDTO;
import com.abin.mallchat.custom.chatai.properties.ChatGLM2Properties; import com.abin.mallchat.custom.chatai.properties.ChatGLM2Properties;
import com.abin.mallchat.custom.chatai.utils.ChatGLM2Utils; import com.abin.mallchat.custom.chatai.utils.ChatGLM2Utils;
import com.abin.mallchat.custom.user.domain.vo.response.user.UserInfoResp; import com.abin.mallchat.custom.user.domain.vo.response.user.UserInfoResp;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.Nullable;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
@@ -21,10 +26,15 @@ import java.util.Random;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static com.abin.mallchat.common.common.constant.RedisKey.USER_GLM2_TIME_LAST; import static com.abin.mallchat.common.common.constant.RedisKey.USER_GLM2_TIME_LAST;
import static com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlStrategyFactory.TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER;
@Slf4j @Slf4j
@Component @Component
public class ChatGLM2Handler extends AbstractChatAIHandler { public class ChatGLM2Handler extends AbstractChatAIHandler {
/**
* ChatGLM2Handler限流前缀
*/
private static final String CHAT_GLM2_FREQUENCY_PREFIX = "ChatGLM2Handler";
private static final List<String> ERROR_MSG = Arrays.asList( private static final List<String> ERROR_MSG = Arrays.asList(
"还摸鱼呢?你不下班我还要下班呢。。。。", "还摸鱼呢?你不下班我还要下班呢。。。。",
@@ -74,11 +84,30 @@ public class ChatGLM2Handler extends AbstractChatAIHandler {
protected String doChat(Message message) { protected String doChat(Message message) {
String content = message.getContent().replace("@" + AI_NAME, "").trim(); String content = message.getContent().replace("@" + AI_NAME, "").trim();
Long uid = message.getFromUid(); Long uid = message.getFromUid();
Long minute; try {
FrequencyControlDTO frequencyControlDTO = new FrequencyControlDTO();
frequencyControlDTO.setKey(CHAT_GLM2_FREQUENCY_PREFIX + ":" + uid);
frequencyControlDTO.setUnit(TimeUnit.MINUTES);
frequencyControlDTO.setCount(1);
frequencyControlDTO.setTime(glm2Properties.getMinute().intValue());
return FrequencyControlUtil.executeWithFrequencyControl(TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER, frequencyControlDTO, () -> sendRequestToGPT(new GPTRequestDTO(content, uid)));
} catch (FrequencyControlException e) {
return "你太快了亲爱的~过一会再来找人家~";
} catch (Throwable e) {
return "系统开小差啦~~";
}
}
/**
* TODO
*
* @param gptRequestDTO
* @return
*/
@Nullable
private String sendRequestToGPT(GPTRequestDTO gptRequestDTO) {
String content = gptRequestDTO.getContent();
String text; String text;
if ((minute = userMinutesLater(uid)) > 0) {
text = "你太快了 " + minute + "分钟后重试";
} else {
HttpResponse response = null; HttpResponse response = null;
try { try {
response = ChatGLM2Utils response = ChatGLM2Utils
@@ -92,10 +121,6 @@ public class ChatGLM2Handler extends AbstractChatAIHandler {
log.warn("glm2 doChat warn:", e); log.warn("glm2 doChat warn:", e);
return getErrorText(); return getErrorText();
} }
if (StringUtils.isNotBlank(text)) {
RedisUtils.set(RedisKey.getKey(USER_GLM2_TIME_LAST, uid), new Date(), glm2Properties.getMinute(), TimeUnit.MINUTES);
}
}
return text; return text;
} }

View File

@@ -3,9 +3,10 @@ package com.abin.mallchat.custom.chatai.handler;
import cn.hutool.http.HttpResponse; import cn.hutool.http.HttpResponse;
import com.abin.mallchat.common.chat.domain.entity.Message; import com.abin.mallchat.common.chat.domain.entity.Message;
import com.abin.mallchat.common.chat.domain.entity.msg.MessageExtra; import com.abin.mallchat.common.chat.domain.entity.msg.MessageExtra;
import com.abin.mallchat.common.common.constant.RedisKey; import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO;
import com.abin.mallchat.common.common.utils.DateUtils; import com.abin.mallchat.common.common.exception.FrequencyControlException;
import com.abin.mallchat.common.common.utils.RedisUtils; import com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlUtil;
import com.abin.mallchat.custom.chatai.dto.GPTRequestDTO;
import com.abin.mallchat.custom.chatai.properties.ChatGPTProperties; import com.abin.mallchat.custom.chatai.properties.ChatGPTProperties;
import com.abin.mallchat.custom.chatai.utils.ChatGPTUtils; import com.abin.mallchat.custom.chatai.utils.ChatGPTUtils;
import com.abin.mallchat.custom.user.domain.vo.response.user.UserInfoResp; import com.abin.mallchat.custom.user.domain.vo.response.user.UserInfoResp;
@@ -17,9 +18,15 @@ import org.springframework.util.CollectionUtils;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlStrategyFactory.TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER;
@Slf4j @Slf4j
@Component @Component
public class GPTChatAIHandler extends AbstractChatAIHandler { public class GPTChatAIHandler extends AbstractChatAIHandler {
/**
* GPTChatAIHandler限流前缀
*/
private static final String CHAT_FREQUENCY_PREFIX = "GPTChatAIHandler";
@Autowired @Autowired
private ChatGPTProperties chatGPTProperties; private ChatGPTProperties chatGPTProperties;
@@ -51,16 +58,27 @@ public class GPTChatAIHandler extends AbstractChatAIHandler {
return chatGPTProperties.getAIUserId(); return chatGPTProperties.getAIUserId();
} }
@Override @Override
protected String doChat(Message message) { protected String doChat(Message message) {
String content = message.getContent().replace("@" + AI_NAME, "").trim(); String content = message.getContent().replace("@" + AI_NAME, "").trim();
Long uid = message.getFromUid(); Long uid = message.getFromUid();
Long chatNum; try {
FrequencyControlDTO frequencyControlDTO = new FrequencyControlDTO();
frequencyControlDTO.setKey(CHAT_FREQUENCY_PREFIX + ":" + uid);
frequencyControlDTO.setUnit(TimeUnit.HOURS);
frequencyControlDTO.setCount(chatGPTProperties.getLimit());
frequencyControlDTO.setTime(24);
return FrequencyControlUtil.executeWithFrequencyControl(TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER, frequencyControlDTO, () -> sendRequestToGPT(new GPTRequestDTO(content, uid)));
} catch (FrequencyControlException e) {
return "亲爱的,你今天找我聊了" + chatGPTProperties.getLimit() + "次了~人家累了~明天见";
} catch (Throwable e) {
return "系统开小差啦~~";
}
}
private String sendRequestToGPT(GPTRequestDTO gptRequestDTO) {
String content = gptRequestDTO.getContent();
String text; String text;
if ((chatNum = getUserChatNum(uid)) > chatGPTProperties.getLimit()) {
text = "你今天已经和我聊了" + chatNum + "次了,我累了,明天再聊吧";
} else {
HttpResponse response = null; HttpResponse response = null;
try { try {
response = ChatGPTUtils.create(chatGPTProperties.getKey()) response = ChatGPTUtils.create(chatGPTProperties.getKey())
@@ -70,24 +88,14 @@ public class GPTChatAIHandler extends AbstractChatAIHandler {
.prompt(content) .prompt(content)
.send(); .send();
text = ChatGPTUtils.parseText(response); text = ChatGPTUtils.parseText(response);
userChatNumInrc(uid);
} catch (Exception e) { } catch (Exception e) {
log.warn("gpt doChat warn:", e); log.warn("gpt doChat warn:", e);
text= "我累了,明天再聊吧"; text= "我累了,明天再聊吧";
} }
}
return text; return text;
} }
private Long userChatNumInrc(Long uid) {
return RedisUtils.inc(RedisKey.getKey(RedisKey.USER_CHAT_NUM, uid), DateUtils.getEndTimeByToday().intValue(), TimeUnit.MILLISECONDS);
}
private Long getUserChatNum(Long uid) {
Long num = RedisUtils.get(RedisKey.getKey(RedisKey.USER_CHAT_NUM, uid), Long.class);
return num == null ? 0 : num;
}
@Override @Override

View File

@@ -2,7 +2,6 @@ package com.abin.mallchat.custom.user.service.impl;
import cn.hutool.core.collection.CollectionUtil; import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.ObjectUtil; import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.RandomUtil;
import cn.hutool.json.JSONUtil; import cn.hutool.json.JSONUtil;
import com.abin.mallchat.common.common.annotation.FrequencyControl; import com.abin.mallchat.common.common.annotation.FrequencyControl;
import com.abin.mallchat.common.common.config.ThreadPoolConfig; import com.abin.mallchat.common.common.config.ThreadPoolConfig;
@@ -20,6 +19,8 @@ import com.abin.mallchat.custom.user.service.LoginService;
import com.abin.mallchat.custom.user.service.WebSocketService; import com.abin.mallchat.custom.user.service.WebSocketService;
import com.abin.mallchat.custom.user.service.adapter.WSAdapter; import com.abin.mallchat.custom.user.service.adapter.WSAdapter;
import com.abin.mallchat.custom.user.websocket.NettyUtil; import com.abin.mallchat.custom.user.websocket.NettyUtil;
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import lombok.SneakyThrows; import lombok.SneakyThrows;
@@ -32,13 +33,13 @@ import org.springframework.context.ApplicationEventPublisher;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import java.util.Date;
import java.util.Map; import java.time.Duration;
import java.util.Objects; import java.util.*;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock; import java.util.concurrent.locks.ReentrantLock;
@@ -51,12 +52,18 @@ import java.util.concurrent.locks.ReentrantLock;
@Slf4j @Slf4j
public class WebSocketServiceImpl implements WebSocketService { public class WebSocketServiceImpl implements WebSocketService {
private static final Duration EXPIRE_TIME = Duration.ofHours(1);
private static final Long MAX_MUM_SIZE = 10000L;
private static final AtomicInteger CODE = new AtomicInteger();
/** /**
* 所有请求登录的code与channel关系 * 所有请求登录的code与channel关系
* todo 有可能有人请求了二维码,就是不登录,留个坑,之后处理
*/ */
private static final ConcurrentHashMap<Integer, Channel> WAIT_LOGIN_MAP = new ConcurrentHashMap<>(); private static final Cache<Integer, Channel> WAIT_LOGIN_MAP = Caffeine.newBuilder()
.expireAfterWrite(EXPIRE_TIME)
.maximumSize(MAX_MUM_SIZE)
.build();
/** /**
* 所有已连接的websocket连接列表和一些额外参数 * 所有已连接的websocket连接列表和一些额外参数
*/ */
@@ -70,7 +77,6 @@ public class WebSocketServiceImpl implements WebSocketService {
return ONLINE_WS_MAP; return ONLINE_WS_MAP;
} }
public static final int EXPIRE_SECONDS = 60 * 60;
@Autowired @Autowired
private WxMpService wxMpService; private WxMpService wxMpService;
@Autowired @Autowired
@@ -99,7 +105,7 @@ public class WebSocketServiceImpl implements WebSocketService {
//生成随机不重复的登录码 //生成随机不重复的登录码
Integer code = generateLoginCode(channel); Integer code = generateLoginCode(channel);
//请求微信接口,获取登录码地址 //请求微信接口,获取登录码地址
WxMpQrCodeTicket wxMpQrCodeTicket = wxMpService.getQrcodeService().qrCodeCreateTmpTicket(code, EXPIRE_SECONDS); WxMpQrCodeTicket wxMpQrCodeTicket = wxMpService.getQrcodeService().qrCodeCreateTmpTicket(code, (int) EXPIRE_TIME.getSeconds());
//返回给前端 //返回给前端
sendMsg(channel, WSAdapter.buildLoginResp(wxMpQrCodeTicket)); sendMsg(channel, WSAdapter.buildLoginResp(wxMpQrCodeTicket));
} }
@@ -111,12 +117,11 @@ public class WebSocketServiceImpl implements WebSocketService {
* @return * @return
*/ */
private Integer generateLoginCode(Channel channel) { private Integer generateLoginCode(Channel channel) {
int code;
do { do {
code = RandomUtil.randomInt(Integer.MAX_VALUE); CODE.getAndIncrement();
} while (WAIT_LOGIN_MAP.contains(code) } while (WAIT_LOGIN_MAP.asMap().containsKey(CODE.get())
|| Objects.nonNull(WAIT_LOGIN_MAP.putIfAbsent(code, channel))); || Objects.isNull(WAIT_LOGIN_MAP.get(CODE.get(), c -> channel)));
return code; return CODE.get();
} }
/** /**
@@ -203,12 +208,12 @@ public class WebSocketServiceImpl implements WebSocketService {
@Override @Override
public Boolean scanLoginSuccess(Integer loginCode, User user, String token) { public Boolean scanLoginSuccess(Integer loginCode, User user, String token) {
//发送消息 //发送消息
Channel channel = WAIT_LOGIN_MAP.get(loginCode); Channel channel = WAIT_LOGIN_MAP.getIfPresent(loginCode);
if (Objects.isNull(channel)) { if (Objects.isNull(channel)) {
return Boolean.FALSE; return Boolean.FALSE;
} }
//移除code //移除code
WAIT_LOGIN_MAP.remove(loginCode); WAIT_LOGIN_MAP.invalidate(loginCode);
//用户登录 //用户登录
loginSuccess(channel, user, token); loginSuccess(channel, user, token);
return true; return true;
@@ -216,7 +221,7 @@ public class WebSocketServiceImpl implements WebSocketService {
@Override @Override
public Boolean scanSuccess(Integer loginCode) { public Boolean scanSuccess(Integer loginCode) {
Channel channel = WAIT_LOGIN_MAP.get(loginCode); Channel channel = WAIT_LOGIN_MAP.getIfPresent(loginCode);
if (Objects.isNull(channel)) { if (Objects.isNull(channel)) {
return Boolean.FALSE; return Boolean.FALSE;
} }
@@ -291,4 +296,6 @@ public class WebSocketServiceImpl implements WebSocketService {
reentrantLock.unlock(); reentrantLock.unlock();
Thread.sleep(1000); Thread.sleep(1000);
} }
} }

View File

@@ -8,7 +8,7 @@ import io.netty.handler.codec.http.HttpHeaders;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.util.Objects; import java.util.Optional;
public class HttpHeadersHandler extends ChannelInboundHandlerAdapter { public class HttpHeadersHandler extends ChannelInboundHandlerAdapter {
@@ -19,11 +19,8 @@ public class HttpHeadersHandler extends ChannelInboundHandlerAdapter {
UrlBuilder urlBuilder = UrlBuilder.ofHttp(request.uri()); UrlBuilder urlBuilder = UrlBuilder.ofHttp(request.uri());
// 获取token参数 // 获取token参数
CharSequence sequence = urlBuilder.getQuery().get("token"); String token = Optional.ofNullable(urlBuilder.getQuery()).map(k->k.get("token")).map(CharSequence::toString).orElse("");
if (Objects.nonNull(sequence)) {
String token = sequence.toString();
NettyUtil.setAttr(ctx.channel(), NettyUtil.TOKEN, token); NettyUtil.setAttr(ctx.channel(), NettyUtil.TOKEN, token);
}
// 获取请求路径 // 获取请求路径
request.setUri(urlBuilder.getPath().toString()); request.setUri(urlBuilder.getPath().toString());