diff --git a/mallchat-common/src/test/java/com/abin/mallchat/common/common/algorithm/ac/CreateTokenTest.java b/mallchat-common/src/test/java/com/abin/mallchat/common/common/algorithm/ac/CreateTokenTest.java index 34d4035..06c1845 100644 --- a/mallchat-common/src/test/java/com/abin/mallchat/common/common/algorithm/ac/CreateTokenTest.java +++ b/mallchat-common/src/test/java/com/abin/mallchat/common/common/algorithm/ac/CreateTokenTest.java @@ -31,5 +31,4 @@ public class CreateTokenTest { log.info("decode error,token:{}", token, e); } } - } diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/user/websocket/HttpHeadersHandler.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/user/websocket/HttpHeadersHandler.java index 30f6cdd..b6cdd43 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/user/websocket/HttpHeadersHandler.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/user/websocket/HttpHeadersHandler.java @@ -1,5 +1,6 @@ package com.abin.mallchat.custom.user.websocket; +import cn.hutool.core.net.url.UrlBuilder; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.handler.codec.http.FullHttpRequest; @@ -13,7 +14,16 @@ public class HttpHeadersHandler extends ChannelInboundHandlerAdapter { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { if (msg instanceof FullHttpRequest) { - HttpHeaders headers = ((FullHttpRequest) msg).headers(); + FullHttpRequest request = (FullHttpRequest) msg; + UrlBuilder urlBuilder = UrlBuilder.ofHttp(request.uri()); + + // 获取token参数 + String token = urlBuilder.getQuery().get("token").toString(); + NettyUtil.setAttr(ctx.channel(), NettyUtil.TOKEN, token); + + // 获取请求路径 + request.setUri(urlBuilder.getPath().toString()); + HttpHeaders headers = request.headers(); String ip = headers.get("X-Real-IP"); if (StringUtils.isEmpty(ip)) {//如果没经过nginx,就直接获取远端地址 InetSocketAddress address = (InetSocketAddress) ctx.channel().remoteAddress(); @@ -21,7 +31,10 @@ public class HttpHeadersHandler extends ChannelInboundHandlerAdapter { } NettyUtil.setAttr(ctx.channel(), NettyUtil.IP, ip); ctx.pipeline().remove(this); + ctx.fireChannelRead(request); + }else + { + ctx.fireChannelRead(msg); } - ctx.fireChannelRead(msg); } } \ No newline at end of file diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/user/websocket/NettyWebSocketServer.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/user/websocket/NettyWebSocketServer.java index 3448528..4ebf470 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/user/websocket/NettyWebSocketServer.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/user/websocket/NettyWebSocketServer.java @@ -10,6 +10,7 @@ import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpServerCodec; +import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; import io.netty.handler.logging.LogLevel; import io.netty.handler.logging.LoggingHandler; import io.netty.handler.stream.ChunkedWriteHandler; @@ -88,7 +89,7 @@ public class NettyWebSocketServer { * 4. WebSocketServerProtocolHandler 核心功能是把 http协议升级为 ws 协议,保持长连接; * 是通过一个状态码 101 来切换的 */ - pipeline.addLast(new NettyWebSocketServerProtocolHandler("/")); + pipeline.addLast(new WebSocketServerProtocolHandler("/")); // 自定义handler ,处理业务逻辑 pipeline.addLast(new NettyWebSocketServerHandler()); } diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/user/websocket/NettyWebSocketServerHandler.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/user/websocket/NettyWebSocketServerHandler.java index ecbbb24..dbc62f9 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/user/websocket/NettyWebSocketServerHandler.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/user/websocket/NettyWebSocketServerHandler.java @@ -71,7 +71,7 @@ public class NettyWebSocketServerHandler extends SimpleChannelInboundHandler out) throws Exception { - if (this.webSocketServerProtocolConfig.handleCloseFrames() && frame instanceof CloseWebSocketFrame) { - WebSocketServerHandshaker handshaker = NettyUtil.getAttr(ctx.channel(),NettyUtil.HANDSHAKER_ATTR_KEY); - if (handshaker != null) { - frame.retain(); - ChannelPromise promise = ctx.newPromise(); - Method closeSent = ReflectUtil.getMethod(super.getClass(), "closeSent", ChannelPromise.class); - closeSent.setAccessible(true); - closeSent.invoke(this,promise); - handshaker.close(ctx, (CloseWebSocketFrame)frame, promise); - } else { - ctx.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE); - } - } else { - super.decode(ctx, frame, out); - } - } -} diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/user/websocket/WebSocketHandshakeHandler.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/user/websocket/WebSocketHandshakeHandler.java deleted file mode 100644 index 8054c90..0000000 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/user/websocket/WebSocketHandshakeHandler.java +++ /dev/null @@ -1,163 +0,0 @@ -package com.abin.mallchat.custom.user.websocket; - - -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.channel.ChannelPipeline; -import io.netty.channel.ChannelPromise; -import io.netty.handler.codec.http.DefaultFullHttpResponse; -import io.netty.handler.codec.http.HttpHeaderNames; -import io.netty.handler.codec.http.HttpMethod; -import io.netty.handler.codec.http.HttpObject; -import io.netty.handler.codec.http.HttpRequest; -import io.netty.handler.codec.http.HttpResponse; -import io.netty.handler.codec.http.HttpResponseStatus; -import io.netty.handler.codec.http.HttpUtil; -import io.netty.handler.codec.http.HttpVersion; -import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakeException; -import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker; -import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory; -import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolConfig; -import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; -import io.netty.handler.ssl.SslHandler; -import io.netty.util.ReferenceCountUtil; -import io.netty.util.concurrent.Future; -import io.netty.util.concurrent.FutureListener; -import io.netty.util.internal.ObjectUtil; - -import java.util.concurrent.TimeUnit; - -public class WebSocketHandshakeHandler extends ChannelInboundHandlerAdapter { - - private final WebSocketServerProtocolConfig serverConfig; - private ChannelHandlerContext ctx; - private ChannelPromise handshakePromise; - private boolean isWebSocketPath; - - public WebSocketHandshakeHandler(WebSocketServerProtocolConfig serverConfig) { - this.serverConfig = (WebSocketServerProtocolConfig)ObjectUtil.checkNotNull(serverConfig, "serverConfig"); - } - - @Override - public void handlerAdded(ChannelHandlerContext ctx) { - this.ctx = ctx; - this.handshakePromise = ctx.newPromise(); - } - - @Override - public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception { - HttpObject httpObject = (HttpObject)msg; - if (httpObject instanceof HttpRequest) { - final HttpRequest req = (HttpRequest)httpObject; - this.isWebSocketPath = this.isWebSocketPath(req); - if (!this.isWebSocketPath) { - ctx.fireChannelRead(msg); - return; - } - try { - if (HttpMethod.GET.equals(req.method())) { - final String token = req.headers().get("Sec-Websocket-Protocol"); - WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(getWebSocketLocation(ctx.pipeline(), req, this.serverConfig.websocketPath()), token, this.serverConfig.decoderConfig()); - final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req); - NettyUtil.setAttr(ctx.channel(),NettyUtil.HANDSHAKER_ATTR_KEY,handshaker); - final ChannelPromise localHandshakePromise = this.handshakePromise; - if (handshaker == null) { - WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel()); - } else { - - ctx.pipeline().remove(this); - ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req); - handshakeFuture.addListener(new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) { - if (!future.isSuccess()) { - localHandshakePromise.tryFailure(future.cause()); - ctx.fireExceptionCaught(future.cause()); - } else { - localHandshakePromise.trySuccess(); - NettyUtil.setAttr(ctx.channel(), NettyUtil.TOKEN, token); - ctx.fireUserEventTriggered(WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE); - } - } - }); - this.applyHandshakeTimeout(); - } - - return; - } - - sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.FORBIDDEN, ctx.alloc().buffer(0))); - } finally { - ReferenceCountUtil.release(req); - } - - return; - } else if (!this.isWebSocketPath) { - ctx.fireChannelRead(msg); - } else { - ReferenceCountUtil.release(msg); - } - - } - - private boolean isWebSocketPath(HttpRequest req) { - String websocketPath = this.serverConfig.websocketPath(); - String uri = req.uri(); - boolean checkStartUri = uri.startsWith(websocketPath); - boolean checkNextUri = "/".equals(websocketPath) || this.checkNextUri(uri, websocketPath); - return this.serverConfig.checkStartsWith() ? checkStartUri && checkNextUri : uri.equals(websocketPath); - } - - private boolean checkNextUri(String uri, String websocketPath) { - int len = websocketPath.length(); - if (uri.length() <= len) { - return true; - } else { - char nextUri = uri.charAt(len); - return nextUri == '/' || nextUri == '?'; - } - } - - private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) { - ChannelFuture f = ctx.channel().writeAndFlush(res); - if (!HttpUtil.isKeepAlive(req) || res.status().code() != 200) { - f.addListener(ChannelFutureListener.CLOSE); - } - - } - - private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) { - String protocol = "ws"; - if (cp.get(SslHandler.class) != null) { - protocol = "wss"; - } - - String host = req.headers().get(HttpHeaderNames.HOST); - return protocol + "://" + host + path; - } - - private void applyHandshakeTimeout() { - final ChannelPromise localHandshakePromise = this.handshakePromise; - long handshakeTimeoutMillis = this.serverConfig.handshakeTimeoutMillis(); - if (handshakeTimeoutMillis > 0L && !localHandshakePromise.isDone()) { - final Future timeoutFuture = this.ctx.executor().schedule(new Runnable() { - @Override - public void run() { - if (!localHandshakePromise.isDone() && localHandshakePromise.tryFailure(new WebSocketServerHandshakeException("handshake timed out"))) { - WebSocketHandshakeHandler.this.ctx.flush().fireUserEventTriggered(WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_TIMEOUT).close(); - } - - } - }, handshakeTimeoutMillis, TimeUnit.MILLISECONDS); - localHandshakePromise.addListener(new FutureListener() { - @Override - public void operationComplete(Future f) { - timeoutFuture.cancel(false); - } - }); - } - } - -}