Merge branch 'main' into chat_context

# Conflicts:
#	mallchat-common/pom.xml
#	mallchat-custom-server/pom.xml
#	mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chatai/handler/GPTChatAIHandler.java
This commit is contained in:
zhaoyuhang
2023-07-10 22:46:20 +08:00
81 changed files with 1870 additions and 455 deletions

View File

@@ -84,7 +84,7 @@ public class ChatController {
@GetMapping("/public/msg/page")
@ApiOperation("消息列表")
@FrequencyControl(time = 120, count = 20, target = FrequencyControl.Target.IP)
public ApiResult<CursorPageBaseResp<ChatMessageResp>> getMsgPage1(@Valid ChatMessagePageReq request) {
public ApiResult<CursorPageBaseResp<ChatMessageResp>> getMsgPage(@Valid ChatMessagePageReq request) {
// black(request);
CursorPageBaseResp<ChatMessageResp> msgPage = chatService.getMsgPage(request, RequestHolder.get().getUid());
filterBlackMsg(msgPage);
@@ -94,7 +94,6 @@ public class ChatController {
private void filterBlackMsg(CursorPageBaseResp<ChatMessageResp> memberPage) {
Set<String> blackMembers = getBlackUidSet();
memberPage.getList().removeIf(a -> blackMembers.contains(a.getFromUser().getUid().toString()));
System.out.println(1);
}
@PostMapping("/msg")

View File

@@ -1,6 +1,5 @@
package com.abin.mallchat.custom.chat.domain.vo.request;
import com.abin.mallchat.common.chat.domain.enums.MessageTypeEnum;
import io.swagger.annotations.ApiModelProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
@@ -30,7 +29,7 @@ public class ChatMessageReq {
@ApiModelProperty("消息类型")
@NotNull
private Integer msgType = MessageTypeEnum.TEXT.getType();
private Integer msgType;
@ApiModelProperty("消息内容类型不同传值不同见https://www.yuque.com/snab/mallcaht/rkb2uz5k1qqdmcmd")
@NotNull

View File

@@ -7,7 +7,6 @@ import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.Date;
import java.util.Map;
/**
* Description: 消息
@@ -27,16 +26,8 @@ public class ChatMessageResp {
@Data
public static class UserInfo {
@ApiModelProperty("用户名称")
private String username;
@ApiModelProperty("用户id")
private Long uid;
@ApiModelProperty("头像")
private String avatar;
@ApiModelProperty("归属地")
private String locPlace;
@ApiModelProperty("徽章标识如果没有展示null")
private Badge badge;
}
@Data
@@ -45,36 +36,12 @@ public class ChatMessageResp {
private Long id;
@ApiModelProperty("消息发送时间")
private Date sendTime;
@ApiModelProperty("消息内容-废弃")
@Deprecated
private String content;
@ApiModelProperty("消息链接映射-废弃")
@Deprecated
private Map<String, String> urlTitleMap;
@ApiModelProperty("消息类型 1正常文本 2.撤回消息")
private Integer type;
@ApiModelProperty("消息内容不同的消息类型内容体不同见https://www.yuque.com/snab/mallcaht/rkb2uz5k1qqdmcmd")
private Object body;
@ApiModelProperty("消息标记")
private MessageMark messageMark;
@ApiModelProperty("父消息如果没有父消息返回的是null")
private ReplyMsg reply;
}
@Data
@Deprecated
public static class ReplyMsg {
@ApiModelProperty("消息id")
private Long id;
@ApiModelProperty("用户名称")
private String username;
@ApiModelProperty("消息内容")
private String content;
@ApiModelProperty("是否可消息跳转 0否 1是")
private Integer canCallback;
@ApiModelProperty("跳转间隔的消息条数")
private Integer gapCount;
}
@Data
@@ -88,12 +55,4 @@ public class ChatMessageResp {
@ApiModelProperty("该用户是否已经举报 0否 1是")
private Integer userDislike;
}
@Data
public static class Badge {
@ApiModelProperty("徽章图像")
private String img;
@ApiModelProperty("徽章说明")
private String describe;
}
}

View File

@@ -1,5 +1,6 @@
package com.abin.mallchat.custom.chat.domain.vo.response.msg;
import com.abin.mallchat.common.common.utils.discover.domain.UrlInfo;
import io.swagger.annotations.ApiModelProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
@@ -22,7 +23,7 @@ public class TextMsgResp {
@ApiModelProperty("消息内容")
private String content;
@ApiModelProperty("消息链接映射")
private Map<String, String> urlTitleMap;
private Map<String, UrlInfo> urlContentMap;
@ApiModelProperty("艾特的uid")
private List<Long> atUidList;
@ApiModelProperty("父消息如果没有父消息返回的是null")

View File

@@ -0,0 +1,14 @@
package com.abin.mallchat.custom.chat.service;
import java.util.List;
public interface WeChatMsgOperationService {
/**
* 向被at的用户微信推送群聊消息
*
* @param senderUid senderUid
* @param receiverUidList receiverUidList
* @param msg msg
*/
void publishChatMsgToWeChatUser(long senderUid, List<Long> receiverUidList, String msg);
}

View File

@@ -3,14 +3,9 @@ package com.abin.mallchat.custom.chat.service.adapter;
import cn.hutool.core.bean.BeanUtil;
import com.abin.mallchat.common.chat.domain.entity.Message;
import com.abin.mallchat.common.chat.domain.entity.MessageMark;
import com.abin.mallchat.common.chat.domain.entity.msg.MessageExtra;
import com.abin.mallchat.common.chat.domain.enums.MessageMarkTypeEnum;
import com.abin.mallchat.common.chat.domain.enums.MessageStatusEnum;
import com.abin.mallchat.common.common.domain.enums.YesOrNoEnum;
import com.abin.mallchat.common.user.domain.entity.IpDetail;
import com.abin.mallchat.common.user.domain.entity.IpInfo;
import com.abin.mallchat.common.user.domain.entity.ItemConfig;
import com.abin.mallchat.common.user.domain.entity.User;
import com.abin.mallchat.custom.chat.domain.vo.request.ChatMessageReq;
import com.abin.mallchat.custom.chat.domain.vo.response.ChatMessageResp;
import com.abin.mallchat.custom.chat.service.strategy.msg.AbstractMsgHandler;
@@ -38,37 +33,25 @@ public class MessageAdapter {
}
public static List<ChatMessageResp> buildMsgResp(List<Message> messages, Map<Long, Message> replyMap, Map<Long, User> userMap, List<MessageMark> msgMark, Long receiveUid, Map<Long, ItemConfig> itemMap) {
public static List<ChatMessageResp> buildMsgResp(List<Message> messages, Map<Long, Message> replyMap, List<MessageMark> msgMark, Long receiveUid) {
Map<Long, List<MessageMark>> markMap = msgMark.stream().collect(Collectors.groupingBy(MessageMark::getMsgId));
return messages.stream().map(a -> {
ChatMessageResp resp = new ChatMessageResp();
resp.setFromUser(buildFromUser(userMap.get(a.getFromUid()), itemMap));
resp.setMessage(buildMessage(a, replyMap, userMap, markMap.getOrDefault(a.getId(), new ArrayList<>()), receiveUid));
resp.setFromUser(buildFromUser(a.getFromUid()));
resp.setMessage(buildMessage(a, replyMap, markMap.getOrDefault(a.getId(), new ArrayList<>()), receiveUid));
return resp;
})
.sorted(Comparator.comparing(a -> a.getMessage().getSendTime()))//帮前端排好序,更方便它展示
.collect(Collectors.toList());
}
private static ChatMessageResp.Message buildMessage(Message message, Map<Long, Message> replyMap, Map<Long, User> userMap, List<MessageMark> marks, Long receiveUid) {
private static ChatMessageResp.Message buildMessage(Message message, Map<Long, Message> replyMap, List<MessageMark> marks, Long receiveUid) {
ChatMessageResp.Message messageVO = new ChatMessageResp.Message();
BeanUtil.copyProperties(message, messageVO);
messageVO.setSendTime(message.getCreateTime());
AbstractMsgHandler msgHandler = MsgHandlerFactory.getStrategyNoNull(message.getType());
messageVO.setBody(msgHandler.showMsg(message));
messageVO.setUrlTitleMap(Optional.ofNullable(message.getExtra()).map(MessageExtra::getUrlTitleMap).orElse(null));
Message replyMessage = replyMap.get(message.getReplyMsgId());
//回复消息
if (Objects.nonNull(replyMessage)) {
ChatMessageResp.ReplyMsg replyMsgVO = new ChatMessageResp.ReplyMsg();
replyMsgVO.setId(replyMessage.getId());
replyMsgVO.setContent(replyMessage.getContent());
User replyUser = userMap.get(replyMessage.getFromUid());
replyMsgVO.setUsername(replyUser.getName());
replyMsgVO.setCanCallback(YesOrNoEnum.toStatus(Objects.nonNull(message.getGapCount()) && message.getGapCount() <= CAN_CALLBACK_GAP_COUNT));
replyMsgVO.setGapCount(message.getGapCount());
messageVO.setReply(replyMsgVO);
if (Objects.nonNull(msgHandler)) {
messageVO.setBody(msgHandler.showMsg(message));
}
//消息标记
messageVO.setMessageMark(buildMsgMark(marks, receiveUid));
@@ -87,19 +70,9 @@ public class MessageAdapter {
return mark;
}
private static ChatMessageResp.UserInfo buildFromUser(User fromUser, Map<Long, ItemConfig> itemMap) {
private static ChatMessageResp.UserInfo buildFromUser(Long fromUid) {
ChatMessageResp.UserInfo userInfo = new ChatMessageResp.UserInfo();
userInfo.setUsername(fromUser.getName());
userInfo.setAvatar(fromUser.getAvatar());
userInfo.setLocPlace(Optional.ofNullable(fromUser.getIpInfo()).map(IpInfo::getUpdateIpDetail).map(IpDetail::getCity).orElse(null));
userInfo.setUid(fromUser.getId());
if (Objects.nonNull(fromUser.getItemId())) {
ChatMessageResp.Badge badge = new ChatMessageResp.Badge();
ItemConfig itemConfig = itemMap.get(fromUser.getItemId());
badge.setImg(itemConfig.getImg());
badge.setDescribe(itemConfig.getDescribe());
userInfo.setBadge(badge);
}
userInfo.setUid(fromUid);
return userInfo;
}

View File

@@ -19,8 +19,6 @@ import com.abin.mallchat.common.common.domain.vo.response.CursorPageBaseResp;
import com.abin.mallchat.common.common.event.MessageSendEvent;
import com.abin.mallchat.common.common.utils.AssertUtil;
import com.abin.mallchat.common.user.dao.UserDao;
import com.abin.mallchat.common.user.domain.entity.ItemConfig;
import com.abin.mallchat.common.user.domain.entity.User;
import com.abin.mallchat.common.user.domain.enums.ChatActiveStatusEnum;
import com.abin.mallchat.common.user.domain.enums.RoleEnum;
import com.abin.mallchat.common.user.service.IRoleService;
@@ -49,7 +47,6 @@ import org.springframework.transaction.annotation.Transactional;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* Description: 消息处理类
@@ -226,21 +223,14 @@ public class ChatServiceImpl implements ChatService {
return new ArrayList<>();
}
Map<Long, Message> replyMap = new HashMap<>();
Map<Long, User> userMap;
Map<Long, ItemConfig> itemMap;
//批量查出回复的消息
List<Long> replyIds = messages.stream().map(Message::getReplyMsgId).filter(Objects::nonNull).distinct().collect(Collectors.toList());
if (CollectionUtil.isNotEmpty(replyIds)) {
replyMap = messageDao.listByIds(replyIds).stream().collect(Collectors.toMap(Message::getId, Function.identity()));
}
//批量查询消息关联用户
Set<Long> uidSet = Stream.concat(replyMap.values().stream().map(Message::getFromUid), messages.stream().map(Message::getFromUid)).collect(Collectors.toSet());
userMap = userCache.getUserInfoBatch(uidSet);
//批量查询item信息
itemMap = userMap.values().stream().map(User::getItemId).distinct().filter(Objects::nonNull).map(itemCache::getById).collect(Collectors.toMap(ItemConfig::getId, Function.identity()));
//查询消息标志
List<MessageMark> msgMark = messageMarkDao.getValidMarkByMsgIdBatch(messages.stream().map(Message::getId).collect(Collectors.toList()));
return MessageAdapter.buildMsgResp(messages, replyMap, userMap, msgMark, receiveUid, itemMap);
return MessageAdapter.buildMsgResp(messages, replyMap, msgMark, receiveUid);
}
}

View File

@@ -0,0 +1,115 @@
package com.abin.mallchat.custom.chat.service.impl;
import cn.hutool.core.thread.NamedThreadFactory;
import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO;
import com.abin.mallchat.common.common.exception.FrequencyControlException;
import com.abin.mallchat.common.common.handler.GlobalUncaughtExceptionHandler;
import com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlUtil;
import com.abin.mallchat.common.user.domain.entity.User;
import com.abin.mallchat.common.user.service.cache.UserCache;
import com.abin.mallchat.custom.chat.service.WeChatMsgOperationService;
import lombok.extern.slf4j.Slf4j;
import me.chanjar.weixin.mp.api.WxMpService;
import me.chanjar.weixin.mp.bean.template.WxMpTemplateData;
import me.chanjar.weixin.mp.bean.template.WxMpTemplateMessage;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import static com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlStrategyFactory.TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER;
@Slf4j
@Component
public class WeChatMsgOperationServiceImpl implements WeChatMsgOperationService {
private static final ExecutorService executor = new ThreadPoolExecutor(1, 10, 3000L,
TimeUnit.MILLISECONDS,
new LinkedBlockingQueue<Runnable>(20),
new NamedThreadFactory("wechat-operation-thread", null, false,
new GlobalUncaughtExceptionHandler()));
// at消息的微信推送模板id
private final String atMsgPublishTemplateId = "Xd7sWPZsuWa0UmpvLaZPvaJVjNj1KjEa0zLOm5_Z7IU";
private final String WE_CHAT_MSG_COLOR = "#A349A4";
@Autowired
private UserCache userCache;
@Autowired
WxMpService wxMpService;
@Override
public void publishChatMsgToWeChatUser(long senderUid, List<Long> receiverUidList, String msg) {
User sender = userCache.getUserInfo(senderUid);
Set uidSet = new HashSet();
uidSet.addAll(receiverUidList);
Map<Long, User> userMap = userCache.getUserInfoBatch(uidSet);
userMap.values().forEach(user -> {
if (Objects.nonNull(user.getOpenId())) {
executor.execute(() -> {
WxMpTemplateMessage msgTemplate = getAtMsgTemplate(sender, user.getOpenId(), msg);
publishTemplateMsgCheckLimit(msgTemplate);
});
}
});
}
private void publishTemplateMsgCheckLimit(WxMpTemplateMessage msgTemplate) {
try {
FrequencyControlDTO frequencyControlDTO = new FrequencyControlDTO();
frequencyControlDTO.setKey("TemplateMsg:" + msgTemplate.getToUser());
frequencyControlDTO.setUnit(TimeUnit.HOURS);
frequencyControlDTO.setCount(1);
frequencyControlDTO.setTime(1);
FrequencyControlUtil.executeWithFrequencyControl(TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER, frequencyControlDTO,
() -> publishTemplateMsg(msgTemplate));
} catch (FrequencyControlException e) {
log.info("wx push limit openid:{}", msgTemplate.getToUser());
} catch (Throwable e) {
log.error("wx push error openid:{}", msgTemplate.getToUser());
}
}
/*
* 构造微信模板消息
*/
private WxMpTemplateMessage getAtMsgTemplate(User sender, String openId, String msg) {
return WxMpTemplateMessage.builder()
.toUser(openId)
.templateId(atMsgPublishTemplateId)
.data(generateAtMsgData(sender, msg))
.build();
}
/*
* 构造微信消息模板的数据
*/
private List<WxMpTemplateData> generateAtMsgData(User sender, String msg) {
List dataList = new ArrayList<WxMpTemplateData>();
// todo: 没有消息模板,暂不实现
dataList.add(new WxMpTemplateData("name", sender.getName(), WE_CHAT_MSG_COLOR));
dataList.add(new WxMpTemplateData("content", msg, WE_CHAT_MSG_COLOR));
return dataList;
}
/**
* 推送微信模板消息
*
* @param templateMsg 微信模板消息
*/
protected void publishTemplateMsg(WxMpTemplateMessage templateMsg) {
// WxMpTemplateMsgService wxMpTemplateMsgService = wxMpService.getTemplateMsgService();todo 等审核通过
// try {
// wxMpTemplateMsgService.sendTemplateMsg(templateMsg);
// } catch (WxErrorException e) {
// log.error("publish we chat msg failed! open id is {}, msg is {}.",
// templateMsg.getToUser(), JsonUtils.toStr(templateMsg.getData()));
// }
}
}

View File

@@ -0,0 +1,57 @@
package com.abin.mallchat.custom.chat.service.strategy.msg;
import cn.hutool.core.bean.BeanUtil;
import com.abin.mallchat.common.chat.dao.MessageDao;
import com.abin.mallchat.common.chat.domain.entity.Message;
import com.abin.mallchat.common.chat.domain.entity.msg.EmojisMsgDTO;
import com.abin.mallchat.common.chat.domain.entity.msg.MessageExtra;
import com.abin.mallchat.common.chat.domain.enums.MessageTypeEnum;
import com.abin.mallchat.common.common.utils.AssertUtil;
import com.abin.mallchat.custom.chat.domain.vo.request.ChatMessageReq;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.Optional;
/**
* Description:表情消息
* Author: <a href="https://github.com/zongzibinbin">abin</a>
* Date: 2023-06-04
*/
@Component
public class EmojisMsgHandler extends AbstractMsgHandler {
@Autowired
private MessageDao messageDao;
@Override
MessageTypeEnum getMsgTypeEnum() {
return MessageTypeEnum.EMOJI;
}
@Override
public void checkMsg(ChatMessageReq request, Long uid) {
EmojisMsgDTO body = BeanUtil.toBean(request.getBody(), EmojisMsgDTO.class);
AssertUtil.allCheckValidateThrow(body);
}
@Override
public void saveMsg(Message msg, ChatMessageReq request) {
EmojisMsgDTO body = BeanUtil.toBean(request.getBody(), EmojisMsgDTO.class);
MessageExtra extra = Optional.ofNullable(msg.getExtra()).orElse(new MessageExtra());
Message update = new Message();
update.setId(msg.getId());
update.setExtra(extra);
extra.setEmojisMsgDTO(body);
messageDao.updateById(update);
}
@Override
public Object showMsg(Message msg) {
return msg.getExtra().getEmojisMsgDTO();
}
@Override
public Object showReplyMsg(Message msg) {
return "表情";
}
}

View File

@@ -10,8 +10,9 @@ import com.abin.mallchat.common.chat.domain.enums.MessageTypeEnum;
import com.abin.mallchat.common.chat.service.cache.MsgCache;
import com.abin.mallchat.common.common.domain.enums.YesOrNoEnum;
import com.abin.mallchat.common.common.utils.AssertUtil;
import com.abin.mallchat.common.common.utils.SensitiveWordUtils;
import com.abin.mallchat.common.common.utils.discover.PrioritizedUrlTitleDiscover;
import com.abin.mallchat.common.common.utils.discover.PrioritizedUrlDiscover;
import com.abin.mallchat.common.common.utils.discover.domain.UrlInfo;
import com.abin.mallchat.common.common.utils.sensitiveWord.SensitiveWordBs;
import com.abin.mallchat.common.user.domain.entity.User;
import com.abin.mallchat.common.user.domain.enums.RoleEnum;
import com.abin.mallchat.common.user.service.IRoleService;
@@ -46,8 +47,10 @@ public class TextMsgHandler extends AbstractMsgHandler {
private UserInfoCache userInfoCache;
@Autowired
private IRoleService iRoleService;
private static final PrioritizedUrlTitleDiscover URL_TITLE_DISCOVER = new PrioritizedUrlTitleDiscover();
@Autowired
private SensitiveWordBs sensitiveWordBs;
private static final PrioritizedUrlDiscover URL_TITLE_DISCOVER = new PrioritizedUrlDiscover();
@Override
MessageTypeEnum getMsgTypeEnum() {
@@ -81,7 +84,7 @@ public class TextMsgHandler extends AbstractMsgHandler {
MessageExtra extra = Optional.ofNullable(msg.getExtra()).orElse(new MessageExtra());
Message update = new Message();
update.setId(msg.getId());
update.setContent(SensitiveWordUtils.filter(body.getContent()));
update.setContent(sensitiveWordBs.filter(body.getContent()));
update.setExtra(extra);
//如果有回复消息
if (Objects.nonNull(body.getReplyMsgId())) {
@@ -91,8 +94,8 @@ public class TextMsgHandler extends AbstractMsgHandler {
}
//判断消息url跳转
Map<String, String> urlTitleMap = URL_TITLE_DISCOVER.getContentTitleMap(body.getContent());
extra.setUrlTitleMap(urlTitleMap);
Map<String, UrlInfo> urlContentMap = URL_TITLE_DISCOVER.getUrlContentMap(body.getContent());
extra.setUrlContentMap(urlContentMap);
//艾特功能
if (CollectionUtil.isNotEmpty(body.getAtUidList())) {
extra.setAtUidList(body.getAtUidList());
@@ -106,7 +109,7 @@ public class TextMsgHandler extends AbstractMsgHandler {
public Object showMsg(Message msg) {
TextMsgResp resp = new TextMsgResp();
resp.setContent(msg.getContent());
resp.setUrlTitleMap(Optional.ofNullable(msg.getExtra()).map(MessageExtra::getUrlTitleMap).orElse(null));
resp.setUrlContentMap(Optional.ofNullable(msg.getExtra()).map(MessageExtra::getUrlContentMap).orElse(null));
resp.setAtUidList(Optional.ofNullable(msg.getExtra()).map(MessageExtra::getAtUidList).orElse(null));
//回复消息
Optional<Message> reply = Optional.ofNullable(msg.getReplyMsgId())

View File

@@ -0,0 +1,21 @@
package com.abin.mallchat.custom.chatai.dto;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class GPTRequestDTO {
/**
* 聊天内容
*/
private String content;
/**
* 用户Id
*/
private Long uid;
}

View File

@@ -4,12 +4,17 @@ import cn.hutool.http.HttpResponse;
import com.abin.mallchat.common.chat.domain.entity.Message;
import com.abin.mallchat.common.chat.domain.entity.msg.MessageExtra;
import com.abin.mallchat.common.common.constant.RedisKey;
import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO;
import com.abin.mallchat.common.common.exception.FrequencyControlException;
import com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlUtil;
import com.abin.mallchat.common.common.utils.RedisUtils;
import com.abin.mallchat.custom.chatai.dto.GPTRequestDTO;
import com.abin.mallchat.custom.chatai.properties.ChatGLM2Properties;
import com.abin.mallchat.custom.chatai.utils.ChatGLM2Utils;
import com.abin.mallchat.custom.user.domain.vo.response.user.UserInfoResp;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.Nullable;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
@@ -21,10 +26,15 @@ import java.util.Random;
import java.util.concurrent.TimeUnit;
import static com.abin.mallchat.common.common.constant.RedisKey.USER_GLM2_TIME_LAST;
import static com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlStrategyFactory.TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER;
@Slf4j
@Component
public class ChatGLM2Handler extends AbstractChatAIHandler {
/**
* ChatGLM2Handler限流前缀
*/
private static final String CHAT_GLM2_FREQUENCY_PREFIX = "ChatGLM2Handler";
private static final List<String> ERROR_MSG = Arrays.asList(
"还摸鱼呢?你不下班我还要下班呢。。。。",
@@ -74,27 +84,42 @@ public class ChatGLM2Handler extends AbstractChatAIHandler {
protected String doChat(Message message) {
String content = message.getContent().replace("@" + AI_NAME, "").trim();
Long uid = message.getFromUid();
Long minute;
try {
FrequencyControlDTO frequencyControlDTO = new FrequencyControlDTO();
frequencyControlDTO.setKey(CHAT_GLM2_FREQUENCY_PREFIX + ":" + uid);
frequencyControlDTO.setUnit(TimeUnit.MINUTES);
frequencyControlDTO.setCount(1);
frequencyControlDTO.setTime(glm2Properties.getMinute().intValue());
return FrequencyControlUtil.executeWithFrequencyControl(TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER, frequencyControlDTO, () -> sendRequestToGPT(new GPTRequestDTO(content, uid)));
} catch (FrequencyControlException e) {
return "你太快了亲爱的~过一会再来找人家~";
} catch (Throwable e) {
return "系统开小差啦~~";
}
}
/**
* TODO
*
* @param gptRequestDTO
* @return
*/
@Nullable
private String sendRequestToGPT(GPTRequestDTO gptRequestDTO) {
String content = gptRequestDTO.getContent();
String text;
if ((minute = userMinutesLater(uid)) > 0) {
text = "你太快了 " + minute + "分钟后重试";
} else {
HttpResponse response = null;
try {
response = ChatGLM2Utils
.create()
.url(glm2Properties.getUrl())
.prompt(content)
.timeout(glm2Properties.getTimeout())
.send();
text = ChatGLM2Utils.parseText(response);
} catch (Exception e) {
log.warn("glm2 doChat warn:", e);
return getErrorText();
}
if (StringUtils.isNotBlank(text)) {
RedisUtils.set(RedisKey.getKey(USER_GLM2_TIME_LAST, uid), new Date(), glm2Properties.getMinute(), TimeUnit.MINUTES);
}
HttpResponse response = null;
try {
response = ChatGLM2Utils
.create()
.url(glm2Properties.getUrl())
.prompt(content)
.timeout(glm2Properties.getTimeout())
.send();
text = ChatGLM2Utils.parseText(response);
} catch (Exception e) {
log.warn("glm2 doChat warn:", e);
return getErrorText();
}
return text;
}

View File

@@ -9,6 +9,10 @@ import com.abin.mallchat.custom.chatai.domain.ChatGPTContext;
import com.abin.mallchat.custom.chatai.domain.ChatGPTMsg;
import com.abin.mallchat.custom.chatai.domain.builder.ChatGPTContextBuilder;
import com.abin.mallchat.custom.chatai.domain.builder.ChatGPTMsgBuilder;
import com.abin.mallchat.common.common.domain.dto.FrequencyControlDTO;
import com.abin.mallchat.common.common.exception.FrequencyControlException;
import com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlUtil;
import com.abin.mallchat.custom.chatai.dto.GPTRequestDTO;
import com.abin.mallchat.custom.chatai.properties.ChatGPTProperties;
import com.abin.mallchat.custom.chatai.utils.ChatGPTUtils;
import com.abin.mallchat.custom.user.domain.vo.response.user.UserInfoResp;
@@ -24,9 +28,15 @@ import java.util.concurrent.TimeUnit;
import static com.abin.mallchat.common.common.constant.RedisKey.USER_CHAT_CONTEXT;
import static com.abin.mallchat.common.common.service.frequencycontrol.FrequencyControlStrategyFactory.TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER;
@Slf4j
@Component
public class GPTChatAIHandler extends AbstractChatAIHandler {
/**
* GPTChatAIHandler限流前缀
*/
private static final String CHAT_FREQUENCY_PREFIX = "GPTChatAIHandler";
@Autowired
private ChatGPTProperties chatGPTProperties;
@@ -36,16 +46,18 @@ public class GPTChatAIHandler extends AbstractChatAIHandler {
@Override
protected void init() {
super.init();
UserInfoResp userInfo = userService.getUserInfo(chatGPTProperties.getAIUserId());
if (userInfo == null) {
log.error("根据AIUserId:{} 找不到用户信息", chatGPTProperties.getAIUserId());
throw new RuntimeException("根据AIUserId: " + chatGPTProperties.getAIUserId() + " 找不到用户信息");
if (isUse()) {
UserInfoResp userInfo = userService.getUserInfo(chatGPTProperties.getAIUserId());
if (userInfo == null) {
log.error("根据AIUserId:{} 找不到用户信息", chatGPTProperties.getAIUserId());
throw new RuntimeException("根据AIUserId: " + chatGPTProperties.getAIUserId() + " 找不到用户信息");
}
if (StringUtils.isBlank(userInfo.getName())) {
log.warn("根据AIUserId:{} 找到的用户信息没有name", chatGPTProperties.getAIUserId());
throw new RuntimeException("根据AIUserId: " + chatGPTProperties.getAIUserId() + " 找到的用户没有名字");
}
AI_NAME = userInfo.getName();
}
if (StringUtils.isBlank(userInfo.getName())) {
log.warn("根据AIUserId:{} 找到的用户信息没有name", chatGPTProperties.getAIUserId());
throw new RuntimeException("根据AIUserId: " + chatGPTProperties.getAIUserId() + " 找到的用户没有名字");
}
AI_NAME = userInfo.getName();
}
@Override
@@ -58,37 +70,41 @@ public class GPTChatAIHandler extends AbstractChatAIHandler {
return chatGPTProperties.getAIUserId();
}
@Override
protected String doChat(Message message) {
String prompt = message.getContent().replace("@" + AI_NAME, "").trim();
String content = message.getContent().replace("@" + AI_NAME, "").trim();
Long uid = message.getFromUid();
try {
FrequencyControlDTO frequencyControlDTO = new FrequencyControlDTO();
frequencyControlDTO.setKey(CHAT_FREQUENCY_PREFIX + ":" + uid);
frequencyControlDTO.setUnit(TimeUnit.HOURS);
frequencyControlDTO.setCount(chatGPTProperties.getLimit());
frequencyControlDTO.setTime(24);
return FrequencyControlUtil.executeWithFrequencyControl(TOTAL_COUNT_WITH_IN_FIX_TIME_FREQUENCY_CONTROLLER, frequencyControlDTO, () -> sendRequestToGPT(new GPTRequestDTO(content, uid)));
} catch (FrequencyControlException e) {
return "亲爱的,你今天找我聊了" + chatGPTProperties.getLimit() + "次了~人家累了~明天见";
} catch (Throwable e) {
return "系统开小差啦~~";
}
}
private String sendRequestToGPT(GPTRequestDTO gptRequestDTO) {
String content = gptRequestDTO.getContent();
Long roomId = message.getRoomId();
Long chatNum;
String text;
if ((chatNum = getUserChatNum(uid)) > chatGPTProperties.getLimit()) {
text = "你今天已经和我聊了" + chatNum + "次了,我累了,明天再聊吧";
} else {
try {
ChatGPTContext context = buildContext(message, prompt);// 构建上下文
context = tailorContext(context);// 裁剪上下文
log.info("prompt = {}" , prompt);
Response response = ChatGPTUtils.create(chatGPTProperties.getKey())
.proxyUrl(chatGPTProperties.getProxyUrl())
.model(chatGPTProperties.getModelName())
.timeout(chatGPTProperties.getTimeout())
.maxTokens(chatGPTProperties.getMaxTokens())
.message(context.getMsg())
.send();
text = ChatGPTUtils.parseText(response);
ChatGPTMsg chatGPTMsg = ChatGPTMsgBuilder.assistantMsg(text);
context.addMsg(chatGPTMsg);
RedisUtils.set(RedisKey.getKey(USER_CHAT_CONTEXT, uid, roomId), context, 1L, TimeUnit.HOURS);
userChatNumInrc(uid);
} catch (Exception e) {
log.warn("gpt doChat warn:", e);
text = "我累了,明天再聊吧";
}
HttpResponse response = null;
try {
response = ChatGPTUtils.create(chatGPTProperties.getKey())
.proxyUrl(chatGPTProperties.getProxyUrl())
.model(chatGPTProperties.getModelName())
.timeout(chatGPTProperties.getTimeout())
.prompt(content)
.send();
text = ChatGPTUtils.parseText(response);
} catch (Exception e) {
log.warn("gpt doChat warn:", e);
text = "我累了,明天再聊吧";
}
return text;
}

View File

@@ -5,6 +5,8 @@ import com.abin.mallchat.common.chat.domain.entity.Message;
import com.abin.mallchat.common.common.event.MessageSendEvent;
import com.abin.mallchat.custom.chat.domain.vo.response.ChatMessageResp;
import com.abin.mallchat.custom.chat.service.ChatService;
import com.abin.mallchat.custom.chat.service.WeChatMsgOperationService;
import com.abin.mallchat.custom.chat.service.impl.WeChatMsgOperationServiceImpl;
import com.abin.mallchat.custom.chatai.service.IChatAIService;
import com.abin.mallchat.custom.user.service.WebSocketService;
import com.abin.mallchat.custom.user.service.adapter.WSAdapter;
@@ -15,6 +17,8 @@ import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component;
import org.springframework.transaction.event.TransactionalEventListener;
import java.util.Objects;
/**
* 消息发送监听器
*
@@ -32,6 +36,9 @@ public class MessageSendListener {
@Autowired
private IChatAIService openAIService;
@Autowired
WeChatMsgOperationService weChatMsgOperationService;
@Async
@TransactionalEventListener(classes = MessageSendEvent.class, fallbackExecution = true)
public void notifyAllOnline(MessageSendEvent event) {
@@ -46,4 +53,12 @@ public class MessageSendListener {
openAIService.chat(message);
}
@TransactionalEventListener(classes = MessageSendEvent.class, fallbackExecution = true)
public void publishChatToWechat(@NotNull MessageSendEvent event) {
Message message = messageDao.getById(event.getMsgId());
if (Objects.nonNull(message.getExtra().getAtUidList())) {
weChatMsgOperationService.publishChatMsgToWeChatUser(message.getFromUid(), message.getExtra().getAtUidList(),
message.getContent());
}
}
}

View File

@@ -45,6 +45,11 @@ public class TokenInterceptor implements HandlerInterceptor {
return true;
}
@Override
public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
MDC.remove(MDCKey.UID);
}
/**
* 判断是不是公共方法,可以未登录访问的
*

View File

@@ -0,0 +1,77 @@
package com.abin.mallchat.custom.user.controller;
import com.abin.mallchat.common.common.domain.vo.request.IdReqVO;
import com.abin.mallchat.common.common.domain.vo.response.ApiResult;
import com.abin.mallchat.common.common.domain.vo.response.IdRespVO;
import com.abin.mallchat.common.common.utils.RequestHolder;
import com.abin.mallchat.custom.user.domain.vo.request.user.UserEmojiReq;
import com.abin.mallchat.custom.user.domain.vo.response.user.UserEmojiResp;
import com.abin.mallchat.custom.user.service.UserEmojiService;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import org.springframework.web.bind.annotation.*;
import javax.annotation.Resource;
import javax.validation.Valid;
import java.util.List;
/**
* 用户表情包
*
* @author: WuShiJie
* @createTime: 2023/7/3 14:21
*/
@RestController
@RequestMapping("/capi/user/emoji")
@Api(tags = "用户表情包管理相关接口")
public class UserEmojiController {
/**
* 用户表情包 Service
*/
@Resource
private UserEmojiService emojiService;
/**
* 表情包列表
*
* @return 表情包列表
* @author WuShiJie
* @createTime 2023/7/3 14:46
**/
@GetMapping("/list")
@ApiOperation("表情包列表")
public ApiResult<List<UserEmojiResp>> getEmojisPage() {
return ApiResult.success(emojiService.list(RequestHolder.get().getUid()));
}
/**
* 新增表情包
*
* @param req 用户表情包
* @return 表情包
* @author WuShiJie
* @createTime 2023/7/3 14:46
**/
@PostMapping()
@ApiOperation("新增表情包")
public ApiResult<IdRespVO> insertEmojis(@Valid @RequestBody UserEmojiReq req) {
return emojiService.insert(req, RequestHolder.get().getUid());
}
/**
* 删除表情包
*
* @return 删除结果
* @author WuShiJie
* @createTime 2023/7/3 14:46
**/
@DeleteMapping()
@ApiOperation("删除表情包")
public ApiResult<Void> deleteEmojis(@Valid @RequestBody IdReqVO reqVO) {
emojiService.remove(reqVO.getId(), RequestHolder.get().getUid());
return ApiResult.success();
}
}

View File

@@ -17,6 +17,7 @@ import java.util.stream.Collectors;
@Getter
public enum OssSceneEnum {
CHAT(1, "聊天", "/chat"),
EMOJI(2, "表情包", "/emoji"),
;
private final Integer type;

View File

@@ -23,7 +23,7 @@ public class UploadUrlReq {
@ApiModelProperty(value = "文件名(带后缀)")
@NotBlank
private String fileName;
@ApiModelProperty(value = "上传场景1.聊天室")
@ApiModelProperty(value = "上传场景1.聊天室,2.表情包")
@NotNull
private Integer scene;
}

View File

@@ -0,0 +1,14 @@
package com.abin.mallchat.custom.user.domain.vo.request.user;
import com.abin.mallchat.common.common.domain.vo.request.CursorPageBaseReq;
import lombok.Data;
/**
* 描述此类的作用
*
* @author: WuShiJie
* @createTime: 2023/7/3 14:52
*/
@Data
public class EmojisPageReq extends CursorPageBaseReq {
}

View File

@@ -0,0 +1,26 @@
package com.abin.mallchat.custom.user.domain.vo.request.user;
import io.swagger.annotations.ApiModelProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
* Description: 表情包反参
* Author: <a href="https://github.com/zongzibinbin">abin</a>
* Date: 2023-07-09
*/
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class UserEmojiReq {
/**
* 表情地址
*/
@ApiModelProperty(value = "新增的表情url")
private String expressionUrl;
}

View File

@@ -0,0 +1,32 @@
package com.abin.mallchat.custom.user.domain.vo.response.user;
import io.swagger.annotations.ApiModelProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
* Description: 表情包反参
* Author: <a href="https://github.com/zongzibinbin">abin</a>
* Date: 2023-07-09
*/
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class UserEmojiResp {
/**
* id
*/
@ApiModelProperty(value = "id")
private Long id;
/**
* 表情地址
*/
@ApiModelProperty(value = "表情url")
private String expressionUrl;
}

View File

@@ -38,4 +38,5 @@ public interface LoginService {
* @return
*/
Long getValidUid(String token);
}

View File

@@ -0,0 +1,45 @@
package com.abin.mallchat.custom.user.service;
import com.abin.mallchat.common.common.domain.vo.response.ApiResult;
import com.abin.mallchat.common.common.domain.vo.response.IdRespVO;
import com.abin.mallchat.custom.user.domain.vo.request.user.UserEmojiReq;
import com.abin.mallchat.custom.user.domain.vo.response.user.UserEmojiResp;
import java.util.List;
/**
* 用户表情包 Service
*
* @author: WuShiJie
* @createTime: 2023/7/3 14:22
*/
public interface UserEmojiService {
/**
* 表情包列表
*
* @return 表情包列表
* @author WuShiJie
* @createTime 2023/7/3 14:46
**/
List<UserEmojiResp> list(Long uid);
/**
* 新增表情包
*
* @param emojis 用户表情包
* @param uid 用户ID
* @return 表情包
* @author WuShiJie
* @createTime 2023/7/3 14:46
**/
ApiResult<IdRespVO> insert(UserEmojiReq emojis, Long uid);
/**
* 删除表情包
*
* @param id
* @param uid
*/
void remove(Long id, Long uid);
}

View File

@@ -17,7 +17,6 @@ import org.springframework.stereotype.Service;
*/
@Service
public class OssServiceImpl implements OssService {
private static final String BUCKET_NAME = "mallchat";
@Autowired
private MinIOTemplate minIOTemplate;

View File

@@ -0,0 +1,75 @@
package com.abin.mallchat.custom.user.service.impl;
import com.abin.mallchat.common.common.annotation.RedissonLock;
import com.abin.mallchat.common.common.domain.vo.response.ApiResult;
import com.abin.mallchat.common.common.domain.vo.response.IdRespVO;
import com.abin.mallchat.common.common.utils.AssertUtil;
import com.abin.mallchat.common.user.dao.UserEmojiDao;
import com.abin.mallchat.common.user.domain.entity.UserEmoji;
import com.abin.mallchat.custom.user.domain.vo.request.user.UserEmojiReq;
import com.abin.mallchat.custom.user.domain.vo.response.user.UserEmojiResp;
import com.abin.mallchat.custom.user.service.UserEmojiService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.List;
import java.util.stream.Collectors;
/**
* 用户表情包 ServiceImpl
*
* @author: WuShiJie
* @createTime: 2023/7/3 14:23
*/
@Service
@Slf4j
public class UserEmojiServiceImpl implements UserEmojiService {
@Autowired
private UserEmojiDao userEmojiDao;
@Override
public List<UserEmojiResp> list(Long uid) {
return userEmojiDao.listByUid(uid).
stream()
.map(a -> UserEmojiResp.builder()
.id(a.getId())
.expressionUrl(a.getExpressionUrl())
.build())
.collect(Collectors.toList());
}
/**
* 新增表情包
*
* @param uid 用户ID
* @return 表情包
* @author WuShiJie
* @createTime 2023/7/3 14:46
**/
@Override
@RedissonLock(key = "#uid")
public ApiResult<IdRespVO> insert(UserEmojiReq req, Long uid) {
//校验表情数量是否超过30
int count = userEmojiDao.countByUid(uid);
AssertUtil.isFalse(count > 30, "最多只能添加30个表情哦~~");
//校验表情是否存在
Integer existsCount = userEmojiDao.lambdaQuery()
.eq(UserEmoji::getExpressionUrl, req.getExpressionUrl())
.eq(UserEmoji::getUid, uid)
.count();
AssertUtil.isFalse(existsCount > 0, "当前表情已存在哦~~");
UserEmoji insert = UserEmoji.builder().uid(uid).expressionUrl(req.getExpressionUrl()).build();
userEmojiDao.save(insert);
return ApiResult.success(IdRespVO.id(insert.getId()));
}
@Override
public void remove(Long id, Long uid) {
UserEmoji userEmoji = userEmojiDao.getById(id);
AssertUtil.isNotEmpty(userEmoji, "表情不能为空");
AssertUtil.equal(userEmoji.getUid(), uid, "小黑子,别人表情不是你能删的");
userEmojiDao.removeById(id);
}
}

View File

@@ -4,7 +4,7 @@ import cn.hutool.core.util.StrUtil;
import com.abin.mallchat.common.common.event.UserBlackEvent;
import com.abin.mallchat.common.common.event.UserRegisterEvent;
import com.abin.mallchat.common.common.utils.AssertUtil;
import com.abin.mallchat.common.common.utils.SensitiveWordUtils;
import com.abin.mallchat.common.common.utils.sensitiveWord.SensitiveWordBs;
import com.abin.mallchat.common.user.dao.BlackDao;
import com.abin.mallchat.common.user.dao.ItemConfigDao;
import com.abin.mallchat.common.user.dao.UserBackpackDao;
@@ -63,6 +63,8 @@ public class UserServiceImpl implements UserService {
private BlackDao blackDao;
@Autowired
private UserSummaryCache userSummaryCache;
@Autowired
private SensitiveWordBs sensitiveWordBs;
@Override
public UserInfoResp getUserInfo(Long uid) {
@@ -76,7 +78,7 @@ public class UserServiceImpl implements UserService {
public void modifyName(Long uid, ModifyNameReq req) {
//判断名字是不是重复
String newName = req.getName();
AssertUtil.isFalse(SensitiveWordUtils.hasSensitiveWord(newName), "名字中包含敏感词,请重新输入"); // 判断名字中有没有敏感词
AssertUtil.isFalse(sensitiveWordBs.hasSensitiveWord(newName), "名字中包含敏感词,请重新输入"); // 判断名字中有没有敏感词
User oldUser = userDao.getByName(newName);
AssertUtil.isEmpty(oldUser, "名字已经被抢占了,请换一个哦~~");
//判断改名卡够不够

View File

@@ -2,7 +2,6 @@ package com.abin.mallchat.custom.user.service.impl;
import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.RandomUtil;
import cn.hutool.json.JSONUtil;
import com.abin.mallchat.common.common.annotation.FrequencyControl;
import com.abin.mallchat.common.common.config.ThreadPoolConfig;
@@ -20,6 +19,8 @@ import com.abin.mallchat.custom.user.service.LoginService;
import com.abin.mallchat.custom.user.service.WebSocketService;
import com.abin.mallchat.custom.user.service.adapter.WSAdapter;
import com.abin.mallchat.custom.user.websocket.NettyUtil;
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import io.netty.channel.Channel;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import lombok.SneakyThrows;
@@ -32,9 +33,13 @@ import org.springframework.context.ApplicationEventPublisher;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.stereotype.Component;
import java.time.Duration;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;
@@ -47,12 +52,18 @@ import java.util.concurrent.locks.ReentrantLock;
@Slf4j
public class WebSocketServiceImpl implements WebSocketService {
private static final Duration EXPIRE_TIME = Duration.ofHours(1);
private static final Long MAX_MUM_SIZE = 10000L;
private static final AtomicInteger CODE = new AtomicInteger();
/**
* 所有请求登录的code与channel关系
* todo 有可能有人请求了二维码,就是不登录,留个坑,之后处理
*/
private static final ConcurrentHashMap<Integer, Channel> WAIT_LOGIN_MAP = new ConcurrentHashMap<>();
private static final Cache<Integer, Channel> WAIT_LOGIN_MAP = Caffeine.newBuilder()
.expireAfterWrite(EXPIRE_TIME)
.maximumSize(MAX_MUM_SIZE)
.build();
/**
* 所有已连接的websocket连接列表和一些额外参数
*/
@@ -66,7 +77,6 @@ public class WebSocketServiceImpl implements WebSocketService {
return ONLINE_WS_MAP;
}
public static final int EXPIRE_SECONDS = 60 * 60;
@Autowired
private WxMpService wxMpService;
@Autowired
@@ -95,7 +105,7 @@ public class WebSocketServiceImpl implements WebSocketService {
//生成随机不重复的登录码
Integer code = generateLoginCode(channel);
//请求微信接口,获取登录码地址
WxMpQrCodeTicket wxMpQrCodeTicket = wxMpService.getQrcodeService().qrCodeCreateTmpTicket(code, EXPIRE_SECONDS);
WxMpQrCodeTicket wxMpQrCodeTicket = wxMpService.getQrcodeService().qrCodeCreateTmpTicket(code, (int) EXPIRE_TIME.getSeconds());
//返回给前端
sendMsg(channel, WSAdapter.buildLoginResp(wxMpQrCodeTicket));
}
@@ -107,12 +117,11 @@ public class WebSocketServiceImpl implements WebSocketService {
* @return
*/
private Integer generateLoginCode(Channel channel) {
int code;
do {
code = RandomUtil.randomInt(Integer.MAX_VALUE);
} while (WAIT_LOGIN_MAP.contains(code)
|| Objects.nonNull(WAIT_LOGIN_MAP.putIfAbsent(code, channel)));
return code;
CODE.getAndIncrement();
} while (WAIT_LOGIN_MAP.asMap().containsKey(CODE.get())
|| Objects.isNull(WAIT_LOGIN_MAP.get(CODE.get(), c -> channel)));
return CODE.get();
}
/**
@@ -199,12 +208,12 @@ public class WebSocketServiceImpl implements WebSocketService {
@Override
public Boolean scanLoginSuccess(Integer loginCode, User user, String token) {
//发送消息
Channel channel = WAIT_LOGIN_MAP.get(loginCode);
Channel channel = WAIT_LOGIN_MAP.getIfPresent(loginCode);
if (Objects.isNull(channel)) {
return Boolean.FALSE;
}
//移除code
WAIT_LOGIN_MAP.remove(loginCode);
WAIT_LOGIN_MAP.invalidate(loginCode);
//用户登录
loginSuccess(channel, user, token);
return true;
@@ -212,7 +221,7 @@ public class WebSocketServiceImpl implements WebSocketService {
@Override
public Boolean scanSuccess(Integer loginCode) {
Channel channel = WAIT_LOGIN_MAP.get(loginCode);
Channel channel = WAIT_LOGIN_MAP.getIfPresent(loginCode);
if (Objects.isNull(channel)) {
return Boolean.FALSE;
}
@@ -287,4 +296,6 @@ public class WebSocketServiceImpl implements WebSocketService {
reentrantLock.unlock();
Thread.sleep(1000);
}
}

View File

@@ -1,5 +1,6 @@
package com.abin.mallchat.custom.user.websocket;
import cn.hutool.core.net.url.UrlBuilder;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.FullHttpRequest;
@@ -7,13 +8,23 @@ import io.netty.handler.codec.http.HttpHeaders;
import org.apache.commons.lang3.StringUtils;
import java.net.InetSocketAddress;
import java.util.Optional;
public class HttpHeadersHandler extends ChannelInboundHandlerAdapter {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof FullHttpRequest) {
HttpHeaders headers = ((FullHttpRequest) msg).headers();
FullHttpRequest request = (FullHttpRequest) msg;
UrlBuilder urlBuilder = UrlBuilder.ofHttp(request.uri());
// 获取token参数
String token = Optional.ofNullable(urlBuilder.getQuery()).map(k->k.get("token")).map(CharSequence::toString).orElse("");
NettyUtil.setAttr(ctx.channel(), NettyUtil.TOKEN, token);
// 获取请求路径
request.setUri(urlBuilder.getPath().toString());
HttpHeaders headers = request.headers();
String ip = headers.get("X-Real-IP");
if (StringUtils.isEmpty(ip)) {//如果没经过nginx就直接获取远端地址
InetSocketAddress address = (InetSocketAddress) ctx.channel().remoteAddress();
@@ -21,7 +32,10 @@ public class HttpHeadersHandler extends ChannelInboundHandlerAdapter {
}
NettyUtil.setAttr(ctx.channel(), NettyUtil.IP, ip);
ctx.pipeline().remove(this);
ctx.fireChannelRead(request);
}else
{
ctx.fireChannelRead(msg);
}
ctx.fireChannelRead(msg);
}
}

View File

@@ -1,6 +1,7 @@
package com.abin.mallchat.custom.user.websocket;
import io.netty.channel.Channel;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.util.Attribute;
import io.netty.util.AttributeKey;
@@ -15,6 +16,7 @@ public class NettyUtil {
public static AttributeKey<String> TOKEN = AttributeKey.valueOf("token");
public static AttributeKey<String> IP = AttributeKey.valueOf("ip");
public static AttributeKey<Long> UID = AttributeKey.valueOf("uid");
public static AttributeKey<WebSocketServerHandshaker> HANDSHAKER_ATTR_KEY = AttributeKey.valueOf(WebSocketServerHandshaker.class, "HANDSHAKER");
public static <T> void setAttr(Channel channel, AttributeKey<T> attributeKey, T data) {
Attribute<T> attr = channel.attr(attributeKey);

View File

@@ -10,6 +10,7 @@ import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.stream.ChunkedWriteHandler;
@@ -88,7 +89,7 @@ public class NettyWebSocketServer {
* 4. WebSocketServerProtocolHandler 核心功能是把 http协议升级为 ws 协议,保持长连接;
* 是通过一个状态码 101 来切换的
*/
pipeline.addLast(new WebSocketHandshakeHandler());
pipeline.addLast(new WebSocketServerProtocolHandler("/"));
// 自定义handler ,处理业务逻辑
pipeline.addLast(new NettyWebSocketServerHandler());
}

View File

@@ -19,10 +19,12 @@ import lombok.extern.slf4j.Slf4j;
@Slf4j
public class NettyWebSocketServerHandler extends SimpleChannelInboundHandler<TextWebSocketFrame> {
private WebSocketService webSocketService;
// 当web客户端连接后触发该方法
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
// getService().connect(ctx.channel());
this.webSocketService = getService();
}
// 客户端离线
@@ -45,7 +47,7 @@ public class NettyWebSocketServerHandler extends SimpleChannelInboundHandler<Tex
}
private void userOffLine(ChannelHandlerContext ctx) {
getService().removed(ctx.channel());
this.webSocketService.removed(ctx.channel());
ctx.channel().close();
}
@@ -65,11 +67,11 @@ public class NettyWebSocketServerHandler extends SimpleChannelInboundHandler<Tex
// 关闭用户的连接
userOffLine(ctx);
}
} else if (evt == WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE) {
getService().connect(ctx.channel());
} else if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
this.webSocketService.connect(ctx.channel());
String token = NettyUtil.getAttr(ctx.channel(), NettyUtil.TOKEN);
if (StrUtil.isNotBlank(token)) {
getService().authorize(ctx.channel(), new WSAuthorize(token));
this.webSocketService.authorize(ctx.channel(), new WSAuthorize(token));
}
}
super.userEventTriggered(ctx, evt);
@@ -93,13 +95,13 @@ public class NettyWebSocketServerHandler extends SimpleChannelInboundHandler<Tex
WSReqTypeEnum wsReqTypeEnum = WSReqTypeEnum.of(wsBaseReq.getType());
switch (wsReqTypeEnum) {
case LOGIN:
getService().handleLoginReq(ctx.channel());
this.webSocketService.handleLoginReq(ctx.channel());
log.info("请求二维码 = " + msg.text());
break;
case HEARTBEAT:
break;
case AUTHORIZE:
getService().authorize(ctx.channel(), JSONUtil.toBean(wsBaseReq.getData(), WSAuthorize.class));
this.webSocketService.authorize(ctx.channel(), JSONUtil.toBean(wsBaseReq.getData(), WSAuthorize.class));
log.info("主动认证 = " + msg.text());
break;
default:

View File

@@ -1,42 +0,0 @@
package com.abin.mallchat.custom.user.websocket;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
public class WebSocketHandshakeHandler extends ChannelInboundHandlerAdapter {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof FullHttpRequest) {
FullHttpRequest request = (FullHttpRequest) msg;
String token = request.headers().get("Sec-Websocket-Protocol");
NettyUtil.setAttr(ctx.channel(), NettyUtil.TOKEN, token);
// 构建WebSocket握手处理器
WebSocketServerHandshakerFactory handshakeFactory = new WebSocketServerHandshakerFactory(
request.uri(), token, false);
WebSocketServerHandshaker handshake = handshakeFactory.newHandshaker(request);
final ChannelFuture handshakeFuture = handshake.handshake(ctx.channel(), request);
ctx.pipeline().remove(this);
handshakeFuture.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
if (!future.isSuccess()) {
ctx.fireExceptionCaught(future.cause());
} else {
// 手动触发WebSocket握手状态事件
ctx.fireUserEventTriggered(
WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE);
}
}
});
} else {
super.channelRead(ctx, msg);
}
}
}