340 lines
13 KiB
Java
340 lines
13 KiB
Java
/*
|
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
|
* contributor license agreements. See the NOTICE file distributed with
|
|
* this work for additional information regarding copyright ownership.
|
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
|
* (the "License"); you may not use this file except in compliance with
|
|
* the License. You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
package org.apache.tomcat.websocket.server;
|
|
|
|
import java.io.IOException;
|
|
import java.nio.charset.StandardCharsets;
|
|
import java.util.ArrayList;
|
|
import java.util.Collections;
|
|
import java.util.Enumeration;
|
|
import java.util.LinkedHashMap;
|
|
import java.util.List;
|
|
import java.util.Map;
|
|
import java.util.Map.Entry;
|
|
|
|
import javax.servlet.ServletException;
|
|
import javax.servlet.ServletRequest;
|
|
import javax.servlet.ServletResponse;
|
|
import javax.servlet.http.HttpServletRequest;
|
|
import javax.servlet.http.HttpServletResponse;
|
|
import javax.websocket.Endpoint;
|
|
import javax.websocket.Extension;
|
|
import javax.websocket.HandshakeResponse;
|
|
import javax.websocket.server.ServerEndpointConfig;
|
|
|
|
import org.apache.tomcat.util.codec.binary.Base64;
|
|
import org.apache.tomcat.util.res.StringManager;
|
|
import org.apache.tomcat.util.security.ConcurrentMessageDigest;
|
|
import org.apache.tomcat.websocket.Constants;
|
|
import org.apache.tomcat.websocket.Transformation;
|
|
import org.apache.tomcat.websocket.TransformationFactory;
|
|
import org.apache.tomcat.websocket.Util;
|
|
import org.apache.tomcat.websocket.WsHandshakeResponse;
|
|
import org.apache.tomcat.websocket.pojo.PojoEndpointServer;
|
|
|
|
public class UpgradeUtil {
|
|
|
|
private static final StringManager sm =
|
|
StringManager.getManager(UpgradeUtil.class.getPackage().getName());
|
|
private static final byte[] WS_ACCEPT =
|
|
"258EAFA5-E914-47DA-95CA-C5AB0DC85B11".getBytes(
|
|
StandardCharsets.ISO_8859_1);
|
|
|
|
private UpgradeUtil() {
|
|
// Utility class. Hide default constructor.
|
|
}
|
|
|
|
/**
|
|
* Checks to see if this is an HTTP request that includes a valid upgrade
|
|
* request to web socket.
|
|
* <p>
|
|
* Note: RFC 2616 does not limit HTTP upgrade to GET requests but the Java
|
|
* WebSocket spec 1.0, section 8.2 implies such a limitation and RFC
|
|
* 6455 section 4.1 requires that a WebSocket Upgrade uses GET.
|
|
* @param request The request to check if it is an HTTP upgrade request for
|
|
* a WebSocket connection
|
|
* @param response The response associated with the request
|
|
* @return <code>true</code> if the request includes an HTTP Upgrade request
|
|
* for the WebSocket protocol, otherwise <code>false</code>
|
|
*/
|
|
public static boolean isWebSocketUpgradeRequest(ServletRequest request,
|
|
ServletResponse response) {
|
|
|
|
return ((request instanceof HttpServletRequest) &&
|
|
(response instanceof HttpServletResponse) &&
|
|
headerContainsToken((HttpServletRequest) request,
|
|
Constants.UPGRADE_HEADER_NAME,
|
|
Constants.UPGRADE_HEADER_VALUE) &&
|
|
"GET".equals(((HttpServletRequest) request).getMethod()));
|
|
}
|
|
|
|
|
|
public static void doUpgrade(WsServerContainer sc, HttpServletRequest req,
|
|
HttpServletResponse resp, ServerEndpointConfig sec,
|
|
Map<String,String> pathParams)
|
|
throws ServletException, IOException {
|
|
|
|
// Validate the rest of the headers and reject the request if that
|
|
// validation fails
|
|
String key;
|
|
String subProtocol = null;
|
|
if (!headerContainsToken(req, Constants.CONNECTION_HEADER_NAME,
|
|
Constants.CONNECTION_HEADER_VALUE)) {
|
|
resp.sendError(HttpServletResponse.SC_BAD_REQUEST);
|
|
return;
|
|
}
|
|
if (!headerContainsToken(req, Constants.WS_VERSION_HEADER_NAME,
|
|
Constants.WS_VERSION_HEADER_VALUE)) {
|
|
resp.setStatus(426);
|
|
resp.setHeader(Constants.WS_VERSION_HEADER_NAME,
|
|
Constants.WS_VERSION_HEADER_VALUE);
|
|
return;
|
|
}
|
|
key = req.getHeader(Constants.WS_KEY_HEADER_NAME);
|
|
if (key == null) {
|
|
resp.sendError(HttpServletResponse.SC_BAD_REQUEST);
|
|
return;
|
|
}
|
|
|
|
|
|
// Origin check
|
|
String origin = req.getHeader(Constants.ORIGIN_HEADER_NAME);
|
|
if (!sec.getConfigurator().checkOrigin(origin)) {
|
|
resp.sendError(HttpServletResponse.SC_FORBIDDEN);
|
|
return;
|
|
}
|
|
// Sub-protocols
|
|
List<String> subProtocols = getTokensFromHeader(req,
|
|
Constants.WS_PROTOCOL_HEADER_NAME);
|
|
subProtocol = sec.getConfigurator().getNegotiatedSubprotocol(
|
|
sec.getSubprotocols(), subProtocols);
|
|
|
|
// Extensions
|
|
// Should normally only be one header but handle the case of multiple
|
|
// headers
|
|
List<Extension> extensionsRequested = new ArrayList<>();
|
|
Enumeration<String> extHeaders = req.getHeaders(Constants.WS_EXTENSIONS_HEADER_NAME);
|
|
while (extHeaders.hasMoreElements()) {
|
|
Util.parseExtensionHeader(extensionsRequested, extHeaders.nextElement());
|
|
}
|
|
// Negotiation phase 1. By default this simply filters out the
|
|
// extensions that the server does not support but applications could
|
|
// use a custom configurator to do more than this.
|
|
List<Extension> installedExtensions = null;
|
|
if (sec.getExtensions().size() == 0) {
|
|
installedExtensions = Constants.INSTALLED_EXTENSIONS;
|
|
} else {
|
|
installedExtensions = new ArrayList<>();
|
|
installedExtensions.addAll(sec.getExtensions());
|
|
installedExtensions.addAll(Constants.INSTALLED_EXTENSIONS);
|
|
}
|
|
List<Extension> negotiatedExtensionsPhase1 = sec.getConfigurator().getNegotiatedExtensions(
|
|
installedExtensions, extensionsRequested);
|
|
|
|
// Negotiation phase 2. Create the Transformations that will be applied
|
|
// to this connection. Note than an extension may be dropped at this
|
|
// point if the client has requested a configuration that the server is
|
|
// unable to support.
|
|
List<Transformation> transformations = createTransformations(negotiatedExtensionsPhase1);
|
|
|
|
List<Extension> negotiatedExtensionsPhase2;
|
|
if (transformations.isEmpty()) {
|
|
negotiatedExtensionsPhase2 = Collections.emptyList();
|
|
} else {
|
|
negotiatedExtensionsPhase2 = new ArrayList<>(transformations.size());
|
|
for (Transformation t : transformations) {
|
|
negotiatedExtensionsPhase2.add(t.getExtensionResponse());
|
|
}
|
|
}
|
|
|
|
// Build the transformation pipeline
|
|
Transformation transformation = null;
|
|
StringBuilder responseHeaderExtensions = new StringBuilder();
|
|
boolean first = true;
|
|
for (Transformation t : transformations) {
|
|
if (first) {
|
|
first = false;
|
|
} else {
|
|
responseHeaderExtensions.append(',');
|
|
}
|
|
append(responseHeaderExtensions, t.getExtensionResponse());
|
|
if (transformation == null) {
|
|
transformation = t;
|
|
} else {
|
|
transformation.setNext(t);
|
|
}
|
|
}
|
|
|
|
// Now we have the full pipeline, validate the use of the RSV bits.
|
|
if (transformation != null && !transformation.validateRsvBits(0)) {
|
|
throw new ServletException(sm.getString("upgradeUtil.incompatibleRsv"));
|
|
}
|
|
|
|
// If we got this far, all is good. Accept the connection.
|
|
resp.setHeader(Constants.UPGRADE_HEADER_NAME,
|
|
Constants.UPGRADE_HEADER_VALUE);
|
|
resp.setHeader(Constants.CONNECTION_HEADER_NAME,
|
|
Constants.CONNECTION_HEADER_VALUE);
|
|
resp.setHeader(HandshakeResponse.SEC_WEBSOCKET_ACCEPT,
|
|
getWebSocketAccept(key));
|
|
if (subProtocol != null && subProtocol.length() > 0) {
|
|
// RFC6455 4.2.2 explicitly states "" is not valid here
|
|
resp.setHeader(Constants.WS_PROTOCOL_HEADER_NAME, subProtocol);
|
|
}
|
|
if (!transformations.isEmpty()) {
|
|
resp.setHeader(Constants.WS_EXTENSIONS_HEADER_NAME, responseHeaderExtensions.toString());
|
|
}
|
|
|
|
WsHandshakeRequest wsRequest = new WsHandshakeRequest(req, pathParams);
|
|
WsHandshakeResponse wsResponse = new WsHandshakeResponse();
|
|
WsPerSessionServerEndpointConfig perSessionServerEndpointConfig =
|
|
new WsPerSessionServerEndpointConfig(sec);
|
|
sec.getConfigurator().modifyHandshake(perSessionServerEndpointConfig,
|
|
wsRequest, wsResponse);
|
|
wsRequest.finished();
|
|
|
|
// Add any additional headers
|
|
for (Entry<String,List<String>> entry :
|
|
wsResponse.getHeaders().entrySet()) {
|
|
for (String headerValue: entry.getValue()) {
|
|
resp.addHeader(entry.getKey(), headerValue);
|
|
}
|
|
}
|
|
|
|
Endpoint ep;
|
|
try {
|
|
Class<?> clazz = sec.getEndpointClass();
|
|
if (Endpoint.class.isAssignableFrom(clazz)) {
|
|
ep = (Endpoint) sec.getConfigurator().getEndpointInstance(
|
|
clazz);
|
|
} else {
|
|
ep = new PojoEndpointServer();
|
|
// Need to make path params available to POJO
|
|
perSessionServerEndpointConfig.getUserProperties().put(
|
|
org.apache.tomcat.websocket.pojo.Constants.POJO_PATH_PARAM_KEY, pathParams);
|
|
}
|
|
} catch (InstantiationException e) {
|
|
throw new ServletException(e);
|
|
}
|
|
|
|
WsHttpUpgradeHandler wsHandler =
|
|
req.upgrade(WsHttpUpgradeHandler.class);
|
|
wsHandler.preInit(ep, perSessionServerEndpointConfig, sc, wsRequest,
|
|
negotiatedExtensionsPhase2, subProtocol, transformation, pathParams,
|
|
req.isSecure());
|
|
|
|
}
|
|
|
|
|
|
private static List<Transformation> createTransformations(
|
|
List<Extension> negotiatedExtensions) {
|
|
|
|
TransformationFactory factory = TransformationFactory.getInstance();
|
|
|
|
LinkedHashMap<String,List<List<Extension.Parameter>>> extensionPreferences =
|
|
new LinkedHashMap<>();
|
|
|
|
// Result will likely be smaller than this
|
|
List<Transformation> result = new ArrayList<>(negotiatedExtensions.size());
|
|
|
|
for (Extension extension : negotiatedExtensions) {
|
|
List<List<Extension.Parameter>> preferences =
|
|
extensionPreferences.get(extension.getName());
|
|
|
|
if (preferences == null) {
|
|
preferences = new ArrayList<>();
|
|
extensionPreferences.put(extension.getName(), preferences);
|
|
}
|
|
|
|
preferences.add(extension.getParameters());
|
|
}
|
|
|
|
for (Map.Entry<String,List<List<Extension.Parameter>>> entry :
|
|
extensionPreferences.entrySet()) {
|
|
Transformation transformation = factory.create(entry.getKey(), entry.getValue(), true);
|
|
if (transformation != null) {
|
|
result.add(transformation);
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
|
|
private static void append(StringBuilder sb, Extension extension) {
|
|
if (extension == null || extension.getName() == null || extension.getName().length() == 0) {
|
|
return;
|
|
}
|
|
|
|
sb.append(extension.getName());
|
|
|
|
for (Extension.Parameter p : extension.getParameters()) {
|
|
sb.append(';');
|
|
sb.append(p.getName());
|
|
if (p.getValue() != null) {
|
|
sb.append('=');
|
|
sb.append(p.getValue());
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
/*
|
|
* This only works for tokens. Quoted strings need more sophisticated
|
|
* parsing.
|
|
*/
|
|
private static boolean headerContainsToken(HttpServletRequest req,
|
|
String headerName, String target) {
|
|
Enumeration<String> headers = req.getHeaders(headerName);
|
|
while (headers.hasMoreElements()) {
|
|
String header = headers.nextElement();
|
|
String[] tokens = header.split(",");
|
|
for (String token : tokens) {
|
|
if (target.equalsIgnoreCase(token.trim())) {
|
|
return true;
|
|
}
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
|
|
/*
|
|
* This only works for tokens. Quoted strings need more sophisticated
|
|
* parsing.
|
|
*/
|
|
private static List<String> getTokensFromHeader(HttpServletRequest req,
|
|
String headerName) {
|
|
List<String> result = new ArrayList<>();
|
|
Enumeration<String> headers = req.getHeaders(headerName);
|
|
while (headers.hasMoreElements()) {
|
|
String header = headers.nextElement();
|
|
String[] tokens = header.split(",");
|
|
for (String token : tokens) {
|
|
result.add(token.trim());
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
|
|
private static String getWebSocketAccept(String key) {
|
|
byte[] digest = ConcurrentMessageDigest.digestSHA1(
|
|
key.getBytes(StandardCharsets.ISO_8859_1), WS_ACCEPT);
|
|
return Base64.encodeBase64String(digest);
|
|
}
|
|
}
|