This commit is contained in:
2024-11-30 19:03:49 +08:00
commit 1e6763c160
3806 changed files with 737676 additions and 0 deletions

View File

@@ -0,0 +1,109 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket;
import java.io.IOException;
import java.net.URI;
import java.util.concurrent.atomic.AtomicInteger;
import javax.websocket.ClientEndpointConfig.Builder;
import javax.websocket.ContainerProvider;
import javax.websocket.DeploymentException;
import javax.websocket.WebSocketContainer;
import org.junit.Ignore;
import org.junit.Test;
import org.apache.catalina.Context;
import org.apache.catalina.servlets.DefaultServlet;
import org.apache.catalina.startup.Tomcat;
import org.apache.catalina.startup.TomcatBaseTest;
import org.apache.tomcat.websocket.TesterMessageCountClient.TesterProgrammaticEndpoint;
@Ignore // Not for use in normal unit test runs
public class TestConnectionLimit extends TomcatBaseTest {
/*
* Simple test to see how many outgoing connections can be created on a
* single machine.
*/
@Test
public void testSingleMachine() throws Exception {
Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(TesterEchoServer.Config.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
tomcat.getConnector().setAttribute("maxConnections", "-1");
tomcat.start();
URI uri = new URI("ws://localhost:" + getPort() +
TesterEchoServer.Config.PATH_ASYNC);
AtomicInteger counter = new AtomicInteger(0);
int threadCount = 50;
Thread[] threads = new ConnectionThread[threadCount];
for (int i = 0; i < threadCount; i++) {
threads[i] = new ConnectionThread(counter, uri);
threads[i].start();
}
// Wait for the threads to die
for (Thread thread : threads) {
thread.join();
}
System.out.println("Maximum connection count was " + counter.get());
}
private static class ConnectionThread extends Thread {
private final AtomicInteger counter;
private final URI uri;
private ConnectionThread(AtomicInteger counter, URI uri) {
this.counter = counter;
this.uri = uri;
}
@Override
public void run() {
WebSocketContainer wsContainer =
ContainerProvider.getWebSocketContainer();
int count = 0;
try {
while (true) {
wsContainer.connectToServer(TesterProgrammaticEndpoint.class,
Builder.create().build(), uri);
count = counter.incrementAndGet();
if (count % 100 == 0) {
System.out.println(count + " and counting...");
}
}
} catch (IOException | DeploymentException ioe) {
// Let thread die
}
}
}
}

View File

@@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import javax.websocket.Extension;
import javax.websocket.Extension.Parameter;
import org.junit.Test;
public class TestPerMessageDeflate {
/*
* https://bz.apache.org/bugzilla/show_bug.cgi?id=61491
*/
@Test
public void testSendEmptyMessagePartWithContextTakeover() throws IOException {
// Set up the extension using defaults
List<Parameter> parameters = Collections.emptyList();
List<List<Parameter>> preferences = new ArrayList<>();
preferences.add(parameters);
PerMessageDeflate perMessageDeflate = PerMessageDeflate.negotiate(preferences, true);
perMessageDeflate.setNext(new TesterTransformation());
ByteBuffer bb1 = ByteBuffer.wrap("A".getBytes(StandardCharsets.UTF_8));
MessagePart mp1 = new MessagePart(true, 0, Constants.OPCODE_TEXT, bb1, null, null, -1);
List<MessagePart> uncompressedParts1 = new ArrayList<>();
uncompressedParts1.add(mp1);
perMessageDeflate.sendMessagePart(uncompressedParts1);
ByteBuffer bb2 = ByteBuffer.wrap("".getBytes(StandardCharsets.UTF_8));
MessagePart mp2 = new MessagePart(true, 0, Constants.OPCODE_TEXT, bb2, null, null, -1);
List<MessagePart> uncompressedParts2 = new ArrayList<>();
uncompressedParts2.add(mp2);
perMessageDeflate.sendMessagePart(uncompressedParts2);
}
/*
* Minimal implementation to enable other transformations to be tested.
*/
private static class TesterTransformation implements Transformation {
@Override
public boolean validateRsvBits(int i) {
return false;
}
@Override
public boolean validateRsv(int rsv, byte opCode) {
return false;
}
@Override
public void setNext(Transformation t) {
}
@Override
public List<MessagePart> sendMessagePart(List<MessagePart> messageParts) {
return messageParts;
}
@Override
public TransformationResult getMoreData(byte opCode, boolean fin, int rsv, ByteBuffer dest)
throws IOException {
return null;
}
@Override
public Extension getExtensionResponse() {
return null;
}
@Override
public void close() {
}
}
}

View File

@@ -0,0 +1,469 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket;
import java.util.ArrayList;
import java.util.List;
import javax.websocket.EncodeException;
import javax.websocket.Encoder;
import javax.websocket.EndpointConfig;
import javax.websocket.Extension;
import javax.websocket.Extension.Parameter;
import javax.websocket.MessageHandler;
import org.junit.Assert;
import org.junit.Test;
public class TestUtil {
// Used to init SecureRandom prior to running tests
public static void generateMask() {
Util.generateMask();
}
@Test
public void testGetMessageTypeSimple() {
Assert.assertEquals(
String.class, Util.getMessageType(new SimpleMessageHandler()));
}
@Test
public void testGetMessageTypeSubclass() {
Assert.assertEquals(String.class,
Util.getMessageType(new SubSimpleMessageHandler()));
}
@Test
public void testGetMessageTypeGenericSubclass() {
Assert.assertEquals(String.class,
Util.getMessageType(new GenericSubMessageHandler()));
}
@Test
public void testGetMessageTypeGenericMultipleSubclass() {
Assert.assertEquals(String.class,
Util.getMessageType(new GenericMultipleSubSubMessageHandler()));
}
@Test
public void testGetMessageTypeGenericMultipleSubclassSwap() {
Assert.assertEquals(String.class,
Util.getMessageType(new GenericMultipleSubSubSwapMessageHandler()));
}
@Test
public void testGetEncoderTypeSimple() {
Assert.assertEquals(
String.class, Util.getEncoderType(SimpleEncoder.class));
}
@Test
public void testGetEncoderTypeSubclass() {
Assert.assertEquals(String.class,
Util.getEncoderType(SubSimpleEncoder.class));
}
@Test
public void testGetEncoderTypeGenericSubclass() {
Assert.assertEquals(String.class,
Util.getEncoderType(GenericSubEncoder.class));
}
@Test
public void testGetEncoderTypeGenericMultipleSubclass() {
Assert.assertEquals(String.class,
Util.getEncoderType(GenericMultipleSubSubEncoder.class));
}
@Test
public void testGetEncoderTypeGenericMultipleSubclassSwap() {
Assert.assertEquals(String.class,
Util.getEncoderType(GenericMultipleSubSubSwapEncoder.class));
}
@Test
public void testGetEncoderTypeSimpleWithGenericType() {
Assert.assertEquals(List.class,
Util.getEncoderType(SimpleEncoderWithGenericType.class));
}
@Test
public void testGenericArrayEncoderString() {
Assert.assertEquals(String[].class,
Util.getEncoderType(GenericArrayEncoderString.class));
}
@Test
public void testGenericArraySubEncoderString() {
Assert.assertEquals(String[][].class,
Util.getEncoderType(GenericArraySubEncoderString.class));
}
private static class SimpleMessageHandler
implements MessageHandler.Whole<String> {
@Override
public void onMessage(String message) {
// NO-OP
}
}
private static class SubSimpleMessageHandler extends SimpleMessageHandler {
}
private abstract static class GenericMessageHandler<T> implements MessageHandler.Whole<T> {
}
private static class GenericSubMessageHandler extends GenericMessageHandler<String> {
@Override
public void onMessage(String message) {
// NO-OP
}
}
private static interface Foo<T> {
void doSomething(T thing);
}
private abstract static class GenericMultipleMessageHandler<A,B>
implements MessageHandler.Whole<A>, Foo<B> {
}
private abstract static class GenericMultipleSubMessageHandler<X,Y>
extends GenericMultipleMessageHandler<X,Y> {
}
private static class GenericMultipleSubSubMessageHandler
extends GenericMultipleSubMessageHandler<String,Boolean> {
@Override
public void onMessage(String message) {
// NO-OP
}
@Override
public void doSomething(Boolean thing) {
// NO-OP
}
}
private abstract static class GenericMultipleSubSwapMessageHandler<Y,X>
extends GenericMultipleMessageHandler<X,Y> {
}
private static class GenericMultipleSubSubSwapMessageHandler
extends GenericMultipleSubSwapMessageHandler<Boolean,String> {
@Override
public void onMessage(String message) {
// NO-OP
}
@Override
public void doSomething(Boolean thing) {
// NO-OP
}
}
private static class SimpleEncoder implements Encoder.Text<String> {
@Override
public void init(EndpointConfig endpointConfig) {
// NO-OP
}
@Override
public void destroy() {
// NO-OP
}
@Override
public String encode(String object) throws EncodeException {
return null;
}
}
private static class SubSimpleEncoder extends SimpleEncoder {
}
private abstract static class GenericEncoder<T> implements Encoder.Text<T> {
@Override
public void init(EndpointConfig endpointConfig) {
// NO-OP
}
@Override
public void destroy() {
// NO-OP
}
}
private static class GenericSubEncoder extends GenericEncoder<String> {
@Override
public String encode(String object) throws EncodeException {
return null;
}
}
private abstract static class GenericMultipleEncoder<A,B>
implements Encoder.Text<A>, Foo<B> {
@Override
public void init(EndpointConfig endpointConfig) {
// NO-OP
}
@Override
public void destroy() {
// NO-OP
}
}
private abstract static class GenericMultipleSubEncoder<X,Y>
extends GenericMultipleEncoder<X,Y> {
}
private static class GenericMultipleSubSubEncoder
extends GenericMultipleSubEncoder<String,Boolean> {
@Override
public String encode(String object) throws EncodeException {
return null;
}
@Override
public void doSomething(Boolean thing) {
// NO-OP
}
}
private abstract static class GenericMultipleSubSwapEncoder<Y,X>
extends GenericMultipleEncoder<X,Y> {
}
private static class GenericMultipleSubSubSwapEncoder
extends GenericMultipleSubSwapEncoder<Boolean,String> {
@Override
public String encode(String object) throws EncodeException {
return null;
}
@Override
public void doSomething(Boolean thing) {
// NO-OP
}
}
private static class SimpleEncoderWithGenericType
implements Encoder.Text<List<String>> {
@Override
public void init(EndpointConfig endpointConfig) {
// NO-OP
}
@Override
public void destroy() {
// NO-OP
}
@Override
public String encode(List<String> object) throws EncodeException {
return null;
}
}
private abstract static class GenericArrayEncoder<T> implements Encoder.Text<T[]> {
}
private static class GenericArrayEncoderString extends GenericArrayEncoder<String> {
@Override
public void init(EndpointConfig endpointConfig) {
// NO-OP
}
@Override
public void destroy() {
// NO-OP
}
@Override
public String encode(String[] object) throws EncodeException {
return null;
}
}
private abstract static class GenericArraySubEncoder<T> extends GenericArrayEncoder<T[]> {
}
private static class GenericArraySubEncoderString extends GenericArraySubEncoder<String> {
@Override
public void init(EndpointConfig endpointConfig) {
// NO-OP
}
@Override
public void destroy() {
// NO-OP
}
@Override
public String encode(String[][] object) throws EncodeException {
return null;
}
}
@Test
public void testParseExtensionHeaderSimple01() {
doTestParseExtensionHeaderSimple("ext;a=1;b=2");
}
@Test
public void testParseExtensionHeaderSimple02() {
doTestParseExtensionHeaderSimple("ext;a=\"1\";b=2");
}
@Test
public void testParseExtensionHeaderSimple03() {
doTestParseExtensionHeaderSimple("ext;a=1;b=\"2\"");
}
@Test
public void testParseExtensionHeaderSimple04() {
doTestParseExtensionHeaderSimple(" ext ; a = 1 ; b = 2 ");
}
private void doTestParseExtensionHeaderSimple(String header) {
// Simple test
List<Extension> result = new ArrayList<>();
Util.parseExtensionHeader(result, header);
Assert.assertEquals(1, result.size());
Extension ext = result.get(0);
Assert.assertEquals("ext", ext.getName());
List<Parameter> params = ext.getParameters();
Assert.assertEquals(2, params.size());
Parameter paramA = params.get(0);
Assert.assertEquals("a", paramA.getName());
Assert.assertEquals("1", paramA.getValue());
Parameter paramB = params.get(1);
Assert.assertEquals("b", paramB.getName());
Assert.assertEquals("2", paramB.getValue());
}
@Test
public void testParseExtensionHeaderMultiple01() {
doTestParseExtensionHeaderMultiple("ext;a=1;b=2,ext2;c;d=xyz,ext3");
}
@Test
public void testParseExtensionHeaderMultiple02() {
doTestParseExtensionHeaderMultiple(
" ext ; a = 1 ; b = 2 , ext2 ; c ; d = xyz , ext3 ");
}
private void doTestParseExtensionHeaderMultiple(String header) {
// Simple test
List<Extension> result = new ArrayList<>();
Util.parseExtensionHeader(result, header);
Assert.assertEquals(3, result.size());
Extension ext = result.get(0);
Assert.assertEquals("ext", ext.getName());
List<Parameter> params = ext.getParameters();
Assert.assertEquals(2, params.size());
Parameter paramA = params.get(0);
Assert.assertEquals("a", paramA.getName());
Assert.assertEquals("1", paramA.getValue());
Parameter paramB = params.get(1);
Assert.assertEquals("b", paramB.getName());
Assert.assertEquals("2", paramB.getValue());
Extension ext2 = result.get(1);
Assert.assertEquals("ext2", ext2.getName());
List<Parameter> params2 = ext2.getParameters();
Assert.assertEquals(2, params2.size());
Parameter paramC = params2.get(0);
Assert.assertEquals("c", paramC.getName());
Assert.assertNull(paramC.getValue());
Parameter paramD = params2.get(1);
Assert.assertEquals("d", paramD.getName());
Assert.assertEquals("xyz", paramD.getValue());
Extension ext3 = result.get(2);
Assert.assertEquals("ext3", ext3.getName());
List<Parameter> params3 = ext3.getParameters();
Assert.assertEquals(0, params3.size());
}
@Test(expected=IllegalArgumentException.class)
public void testParseExtensionHeaderInvalid01() {
Util.parseExtensionHeader(new ArrayList<Extension>(), "ext;a=\"1;b=2");
}
@Test(expected=IllegalArgumentException.class)
public void testParseExtensionHeaderInvalid02() {
Util.parseExtensionHeader(new ArrayList<Extension>(), "ext;a=1\";b=2");
}
}

View File

@@ -0,0 +1,228 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket;
import java.net.URI;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.websocket.ClientEndpointConfig;
import javax.websocket.ClientEndpointConfig.Configurator;
import javax.websocket.ContainerProvider;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import org.junit.Assert;
import org.junit.Test;
import org.apache.catalina.Context;
import org.apache.catalina.authenticator.AuthenticatorBase;
import org.apache.catalina.servlets.DefaultServlet;
import org.apache.catalina.startup.Tomcat;
import org.apache.tomcat.util.descriptor.web.LoginConfig;
import org.apache.tomcat.util.descriptor.web.SecurityCollection;
import org.apache.tomcat.util.descriptor.web.SecurityConstraint;
import org.apache.tomcat.websocket.TesterMessageCountClient.BasicText;
import org.apache.tomcat.websocket.TesterMessageCountClient.TesterProgrammaticEndpoint;
public class TestWebSocketFrameClient extends WebSocketBaseTest {
private static final String USER = "Aladdin";
private static final String PWD = "open sesame";
private static final String ROLE = "role";
private static final String URI_PROTECTED = "/foo";
@Test
public void testConnectToServerEndpoint() throws Exception {
Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(TesterFirehoseServer.Config.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
tomcat.start();
WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer();
// BZ 62596
final StringBuilder dummyValue = new StringBuilder(4000);
for (int i = 0; i < 4000; i++) {
dummyValue.append('A');
}
ClientEndpointConfig clientEndpointConfig =
ClientEndpointConfig.Builder.create().configurator(new Configurator() {
@Override
public void beforeRequest(Map<String, List<String>> headers) {
headers.put("Dummy", Collections.singletonList(dummyValue.toString()));
super.beforeRequest(headers);
}
}).build();
Session wsSession = wsContainer.connectToServer(
TesterProgrammaticEndpoint.class,
clientEndpointConfig,
new URI("ws://localhost:" + getPort() +
TesterFirehoseServer.Config.PATH));
CountDownLatch latch =
new CountDownLatch(TesterFirehoseServer.MESSAGE_COUNT);
BasicText handler = new BasicText(latch);
wsSession.addMessageHandler(handler);
wsSession.getBasicRemote().sendText("Hello");
System.out.println("Sent Hello message, waiting for data");
// Ignore the latch result as the message count test below will tell us
// if the right number of messages arrived
handler.getLatch().await(TesterFirehoseServer.WAIT_TIME_MILLIS,
TimeUnit.MILLISECONDS);
Queue<String> messages = handler.getMessages();
Assert.assertEquals(
TesterFirehoseServer.MESSAGE_COUNT, messages.size());
for (String message : messages) {
Assert.assertEquals(TesterFirehoseServer.MESSAGE, message);
}
}
@Test
public void testConnectToRootEndpoint() throws Exception {
Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(TesterEchoServer.Config.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
Context ctx2 = tomcat.addContext("/foo", null);
ctx2.addApplicationListener(TesterEchoServer.Config.class.getName());
Tomcat.addServlet(ctx2, "default", new DefaultServlet());
ctx2.addServletMappingDecoded("/", "default");
tomcat.start();
echoTester("",null);
echoTester("/",null);
echoTester("/foo",null);
echoTester("/foo/",null);
}
public void echoTester(String path, ClientEndpointConfig clientEndpointConfig)
throws Exception {
WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer();
if (clientEndpointConfig == null) {
clientEndpointConfig = ClientEndpointConfig.Builder.create().build();
}
Session wsSession = wsContainer.connectToServer(TesterProgrammaticEndpoint.class,
clientEndpointConfig, new URI("ws://localhost:" + getPort() + path));
CountDownLatch latch = new CountDownLatch(1);
BasicText handler = new BasicText(latch);
wsSession.addMessageHandler(handler);
wsSession.getBasicRemote().sendText("Hello");
boolean latchResult = handler.getLatch().await(10, TimeUnit.SECONDS);
Assert.assertTrue(latchResult);
Queue<String> messages = handler.getMessages();
Assert.assertEquals(1, messages.size());
for (String message : messages) {
Assert.assertEquals("Hello", message);
}
wsSession.close();
}
@Test
public void testConnectToBasicEndpoint() throws Exception {
Tomcat tomcat = getTomcatInstance();
Context ctx = tomcat.addContext(URI_PROTECTED, null);
ctx.addApplicationListener(TesterEchoServer.Config.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
SecurityCollection collection = new SecurityCollection();
collection.addPatternDecoded("/");
String utf8User = "test";
String utf8Pass = "123\u00A3"; // pound sign
tomcat.addUser(utf8User, utf8Pass);
tomcat.addRole(utf8User, ROLE);
SecurityConstraint sc = new SecurityConstraint();
sc.addAuthRole(ROLE);
sc.addCollection(collection);
ctx.addConstraint(sc);
LoginConfig lc = new LoginConfig();
lc.setAuthMethod("BASIC");
ctx.setLoginConfig(lc);
AuthenticatorBase basicAuthenticator = new org.apache.catalina.authenticator.BasicAuthenticator();
ctx.getPipeline().addValve(basicAuthenticator);
tomcat.start();
ClientEndpointConfig clientEndpointConfig = ClientEndpointConfig.Builder.create().build();
clientEndpointConfig.getUserProperties().put(Constants.WS_AUTHENTICATION_USER_NAME, utf8User);
clientEndpointConfig.getUserProperties().put(Constants.WS_AUTHENTICATION_PASSWORD, utf8Pass);
echoTester(URI_PROTECTED, clientEndpointConfig);
}
@Test
public void testConnectToDigestEndpoint() throws Exception {
Tomcat tomcat = getTomcatInstance();
Context ctx = tomcat.addContext(URI_PROTECTED, null);
ctx.addApplicationListener(TesterEchoServer.Config.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
SecurityCollection collection = new SecurityCollection();
collection.addPatternDecoded("/*");
tomcat.addUser(USER, PWD);
tomcat.addRole(USER, ROLE);
SecurityConstraint sc = new SecurityConstraint();
sc.addAuthRole(ROLE);
sc.addCollection(collection);
ctx.addConstraint(sc);
LoginConfig lc = new LoginConfig();
lc.setAuthMethod("DIGEST");
ctx.setLoginConfig(lc);
AuthenticatorBase digestAuthenticator = new org.apache.catalina.authenticator.DigestAuthenticator();
ctx.getPipeline().addValve(digestAuthenticator);
tomcat.start();
ClientEndpointConfig clientEndpointConfig = ClientEndpointConfig.Builder.create().build();
clientEndpointConfig.getUserProperties().put(Constants.WS_AUTHENTICATION_USER_NAME, USER);
clientEndpointConfig.getUserProperties().put(Constants.WS_AUTHENTICATION_PASSWORD,PWD);
echoTester(URI_PROTECTED, clientEndpointConfig);
}
}

View File

@@ -0,0 +1,183 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.net.URI;
import java.security.KeyStore;
import java.util.Queue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManagerFactory;
import javax.websocket.ClientEndpointConfig;
import javax.websocket.ContainerProvider;
import javax.websocket.MessageHandler;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import org.junit.Assert;
import org.junit.Test;
import org.apache.catalina.Context;
import org.apache.catalina.servlets.DefaultServlet;
import org.apache.catalina.startup.Tomcat;
import org.apache.tomcat.util.net.TesterSupport;
import org.apache.tomcat.websocket.TesterMessageCountClient.BasicText;
import org.apache.tomcat.websocket.TesterMessageCountClient.SleepingText;
import org.apache.tomcat.websocket.TesterMessageCountClient.TesterProgrammaticEndpoint;
public class TestWebSocketFrameClientSSL extends WebSocketBaseTest {
@Test
public void testConnectToServerEndpoint() throws Exception {
Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(TesterFirehoseServer.Config.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
TesterSupport.initSsl(tomcat);
tomcat.start();
WebSocketContainer wsContainer =
ContainerProvider.getWebSocketContainer();
ClientEndpointConfig clientEndpointConfig =
ClientEndpointConfig.Builder.create().build();
clientEndpointConfig.getUserProperties().put(
Constants.SSL_CONTEXT_PROPERTY, createSSLContext());
Session wsSession = wsContainer.connectToServer(
TesterProgrammaticEndpoint.class,
clientEndpointConfig,
new URI("wss://localhost:" + getPort() +
TesterFirehoseServer.Config.PATH));
CountDownLatch latch =
new CountDownLatch(TesterFirehoseServer.MESSAGE_COUNT);
BasicText handler = new BasicText(latch);
wsSession.addMessageHandler(handler);
wsSession.getBasicRemote().sendText("Hello");
System.out.println("Sent Hello message, waiting for data");
// Ignore the latch result as the message count test below will tell us
// if the right number of messages arrived
handler.getLatch().await(TesterFirehoseServer.WAIT_TIME_MILLIS,
TimeUnit.MILLISECONDS);
Queue<String> messages = handler.getMessages();
Assert.assertEquals(
TesterFirehoseServer.MESSAGE_COUNT, messages.size());
for (String message : messages) {
Assert.assertEquals(TesterFirehoseServer.MESSAGE, message);
}
}
@Test
public void testBug56032() throws Exception {
Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(TesterFirehoseServer.Config.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
TesterSupport.initSsl(tomcat);
tomcat.start();
WebSocketContainer wsContainer =
ContainerProvider.getWebSocketContainer();
ClientEndpointConfig clientEndpointConfig =
ClientEndpointConfig.Builder.create().build();
clientEndpointConfig.getUserProperties().put(
Constants.SSL_CONTEXT_PROPERTY, createSSLContext());
Session wsSession = wsContainer.connectToServer(
TesterProgrammaticEndpoint.class,
clientEndpointConfig,
new URI("wss://localhost:" + getPort() +
TesterFirehoseServer.Config.PATH));
// Process incoming messages very slowly
MessageHandler handler = new SleepingText(5000);
wsSession.addMessageHandler(handler);
wsSession.getBasicRemote().sendText("Hello");
// Wait long enough for the buffers to fill and the send to timeout
int count = 0;
int limit = TesterFirehoseServer.WAIT_TIME_MILLIS / 100;
System.out.println("Waiting for server to report an error");
while (TesterFirehoseServer.Endpoint.getErrorCount() == 0 && count < limit) {
Thread.sleep(100);
count ++;
}
if (TesterFirehoseServer.Endpoint.getErrorCount() == 0) {
Assert.fail("No error reported by Endpoint when timeout was expected");
}
// Wait up to another 10 seconds for the connection to be closed -
// should be a lot faster.
System.out.println("Waiting for connection to be closed");
count = 0;
limit = (TesterFirehoseServer.SEND_TIME_OUT_MILLIS * 2) / 100;
while (TesterFirehoseServer.Endpoint.getOpenConnectionCount() != 0 && count < limit) {
Thread.sleep(100);
count ++;
}
int openConnectionCount = TesterFirehoseServer.Endpoint.getOpenConnectionCount();
if (openConnectionCount != 0) {
Assert.fail("There are [" + openConnectionCount + "] connections still open");
}
// Close the client session.
wsSession.close();
}
private SSLContext createSSLContext() throws Exception {
// Create the SSL Context
// Java 7 doesn't default to TLSv1.2 but the tests do
SSLContext sslContext = SSLContext.getInstance("TLSv1.2");
// Trust store
File keyStoreFile = new File(TesterSupport.CA_JKS);
KeyStore ks = KeyStore.getInstance("JKS");
try (InputStream is = new FileInputStream(keyStoreFile)) {
ks.load(is, Constants.SSL_TRUSTSTORE_PWD_DEFAULT.toCharArray());
}
TrustManagerFactory tmf = TrustManagerFactory.getInstance(
TrustManagerFactory.getDefaultAlgorithm());
tmf.init(ks);
sslContext.init(null, tmf.getTrustManagers(), null);
return sslContext;
}
}

View File

@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket;
import java.io.IOException;
import org.junit.Assert;
import org.junit.Test;
public class TestWsFrame {
@Test
public void testByteArrayToLong() throws IOException {
Assert.assertEquals(0L, WsFrameBase.byteArrayToLong(new byte[] { 0 }, 0, 1));
Assert.assertEquals(1L, WsFrameBase.byteArrayToLong(new byte[] { 1 }, 0, 1));
Assert.assertEquals(0xFF, WsFrameBase.byteArrayToLong(new byte[] { -1 }, 0, 1));
Assert.assertEquals(0xFFFF,
WsFrameBase.byteArrayToLong(new byte[] { -1, -1 }, 0, 2));
Assert.assertEquals(0xFFFFFF,
WsFrameBase.byteArrayToLong(new byte[] { -1, -1, -1 }, 0, 3));
}
@Test
public void testByteArrayToLongOffset() throws IOException {
Assert.assertEquals(0L, WsFrameBase.byteArrayToLong(new byte[] { 20, 0 }, 1, 1));
Assert.assertEquals(1L, WsFrameBase.byteArrayToLong(new byte[] { 20, 1 }, 1, 1));
Assert.assertEquals(0xFF, WsFrameBase.byteArrayToLong(new byte[] { 20, -1 }, 1, 1));
Assert.assertEquals(0xFFFF,
WsFrameBase.byteArrayToLong(new byte[] { 20, -1, -1 }, 1, 2));
Assert.assertEquals(0xFFFFFF,
WsFrameBase.byteArrayToLong(new byte[] { 20, -1, -1, -1 }, 1, 3));
}
}

View File

@@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket;
import java.net.URI;
import java.nio.ByteBuffer;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.websocket.ClientEndpointConfig;
import javax.websocket.ContainerProvider;
import javax.websocket.PongMessage;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import org.junit.Assert;
import org.junit.Test;
import org.apache.catalina.Context;
import org.apache.catalina.servlets.DefaultServlet;
import org.apache.catalina.startup.Tomcat;
import org.apache.tomcat.websocket.TesterMessageCountClient.TesterEndpoint;
import org.apache.tomcat.websocket.TesterMessageCountClient.TesterProgrammaticEndpoint;
public class TestWsPingPongMessages extends WebSocketBaseTest {
ByteBuffer applicationData = ByteBuffer.wrap(new String("mydata")
.getBytes());
@Test
public void testPingPongMessages() throws Exception {
Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(TesterEchoServer.Config.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
tomcat.start();
WebSocketContainer wsContainer = ContainerProvider
.getWebSocketContainer();
Session wsSession = wsContainer.connectToServer(
TesterProgrammaticEndpoint.class, ClientEndpointConfig.Builder
.create().build(), new URI("ws://localhost:"
+ getPort() + TesterEchoServer.Config.PATH_ASYNC));
CountDownLatch latch = new CountDownLatch(1);
TesterEndpoint tep = (TesterEndpoint) wsSession.getUserProperties()
.get("endpoint");
tep.setLatch(latch);
PongMessageHandler handler = new PongMessageHandler(latch);
wsSession.addMessageHandler(handler);
wsSession.getBasicRemote().sendPing(applicationData);
boolean latchResult = handler.getLatch().await(10, TimeUnit.SECONDS);
Assert.assertTrue(latchResult);
Assert.assertArrayEquals(applicationData.array(),
(handler.getMessages().peek()).getApplicationData().array());
}
public static class PongMessageHandler extends
TesterMessageCountClient.BasicHandler<PongMessage> {
public PongMessageHandler(CountDownLatch latch) {
super(latch);
}
@Override
public void onMessage(PongMessage message) {
getMessages().add(message);
if (getLatch() != null) {
getLatch().countDown();
}
}
}
}

View File

@@ -0,0 +1,247 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket;
import java.io.OutputStream;
import java.io.Writer;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.websocket.ClientEndpointConfig.Builder;
import javax.websocket.ContainerProvider;
import javax.websocket.Endpoint;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import org.junit.Assert;
import org.junit.Test;
import org.apache.catalina.Context;
import org.apache.catalina.servlets.DefaultServlet;
import org.apache.catalina.startup.Tomcat;
import org.apache.tomcat.websocket.TesterMessageCountClient.AsyncBinary;
import org.apache.tomcat.websocket.TesterMessageCountClient.AsyncHandler;
import org.apache.tomcat.websocket.TesterMessageCountClient.AsyncText;
import org.apache.tomcat.websocket.TesterMessageCountClient.TesterAnnotatedEndpoint;
import org.apache.tomcat.websocket.TesterMessageCountClient.TesterEndpoint;
import org.apache.tomcat.websocket.TesterMessageCountClient.TesterProgrammaticEndpoint;
public class TestWsRemoteEndpoint extends WebSocketBaseTest {
private static final String SEQUENCE = "ABCDE";
private static final int S_LEN = SEQUENCE.length();
private static final String TEST_MESSAGE_5K;
static {
StringBuilder sb = new StringBuilder(S_LEN * 1024);
for (int i = 0; i < 1024; i++) {
sb.append(SEQUENCE);
}
TEST_MESSAGE_5K = sb.toString();
}
@Test
public void testWriterAnnotation() throws Exception {
doTestWriter(TesterAnnotatedEndpoint.class, true, TEST_MESSAGE_5K);
}
@Test
public void testWriterProgrammatic() throws Exception {
doTestWriter(TesterProgrammaticEndpoint.class, true, TEST_MESSAGE_5K);
}
@Test
public void testWriterZeroLengthAnnotation() throws Exception {
doTestWriter(TesterAnnotatedEndpoint.class, true, "");
}
@Test
public void testWriterZeroLengthProgrammatic() throws Exception {
doTestWriter(TesterProgrammaticEndpoint.class, true, "");
}
@Test
public void testStreamAnnotation() throws Exception {
doTestWriter(TesterAnnotatedEndpoint.class, false, TEST_MESSAGE_5K);
}
@Test
public void testStreamProgrammatic() throws Exception {
doTestWriter(TesterProgrammaticEndpoint.class, false, TEST_MESSAGE_5K);
}
private void doTestWriter(Class<?> clazz, boolean useWriter, String testMessage) throws Exception {
Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(TesterEchoServer.Config.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
WebSocketContainer wsContainer =
ContainerProvider.getWebSocketContainer();
tomcat.start();
Session wsSession;
URI uri = new URI("ws://localhost:" + getPort() +
TesterEchoServer.Config.PATH_ASYNC);
if (Endpoint.class.isAssignableFrom(clazz)) {
@SuppressWarnings("unchecked")
Class<? extends Endpoint> endpointClazz =
(Class<? extends Endpoint>) clazz;
wsSession = wsContainer.connectToServer(endpointClazz,
Builder.create().build(), uri);
} else {
wsSession = wsContainer.connectToServer(clazz, uri);
}
CountDownLatch latch = new CountDownLatch(1);
TesterEndpoint tep =
(TesterEndpoint) wsSession.getUserProperties().get("endpoint");
tep.setLatch(latch);
AsyncHandler<?> handler;
if (useWriter) {
handler = new AsyncText(latch);
} else {
handler = new AsyncBinary(latch);
}
wsSession.addMessageHandler(handler);
if (useWriter) {
Writer w = wsSession.getBasicRemote().getSendWriter();
for (int i = 0; i < 8; i++) {
w.write(testMessage);
}
w.close();
} else {
OutputStream s = wsSession.getBasicRemote().getSendStream();
for (int i = 0; i < 8; i++) {
s.write(testMessage.getBytes(StandardCharsets.UTF_8));
}
s.close();
}
boolean latchResult = handler.getLatch().await(10, TimeUnit.SECONDS);
Assert.assertTrue(latchResult);
List<String> results = new ArrayList<>();
if (useWriter) {
@SuppressWarnings("unchecked")
List<String> messages = (List<String>) handler.getMessages();
results.addAll(messages);
} else {
// Take advantage of the fact that the message uses characters that
// are represented as a single UTF-8 byte so won't be split across
// binary messages
@SuppressWarnings("unchecked")
List<ByteBuffer> messages = (List<ByteBuffer>) handler.getMessages();
for (ByteBuffer message : messages) {
byte[] bytes = new byte[message.limit()];
message.get(bytes);
results.add(new String(bytes, StandardCharsets.UTF_8));
}
}
int offset = 0;
int i = 0;
for (String result : results) {
if (testMessage.length() == 0) {
Assert.assertEquals(0, result.length());
} else {
// First may be a fragment
Assert.assertEquals(SEQUENCE.substring(offset, S_LEN),
result.substring(0, S_LEN - offset));
i = S_LEN - offset;
while (i + S_LEN < result.length()) {
if (!SEQUENCE.equals(result.substring(i, i + S_LEN))) {
Assert.fail();
}
i += S_LEN;
}
offset = result.length() - i;
if (!SEQUENCE.substring(0, offset).equals(result.substring(i))) {
Assert.fail();
}
}
}
}
@Test
public void testWriterErrorAnnotation() throws Exception {
doTestWriterError(TesterAnnotatedEndpoint.class);
}
@Test
public void testWriterErrorProgrammatic() throws Exception {
doTestWriterError(TesterProgrammaticEndpoint.class);
}
private void doTestWriterError(Class<?> clazz) throws Exception {
Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(TesterEchoServer.Config.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer();
tomcat.start();
Session wsSession;
URI uri = new URI("ws://localhost:" + getPort() + TesterEchoServer.Config.PATH_WRITER_ERROR);
if (Endpoint.class.isAssignableFrom(clazz)) {
@SuppressWarnings("unchecked")
Class<? extends Endpoint> endpointClazz = (Class<? extends Endpoint>) clazz;
wsSession = wsContainer.connectToServer(endpointClazz, Builder.create().build(), uri);
} else {
wsSession = wsContainer.connectToServer(clazz, uri);
}
CountDownLatch latch = new CountDownLatch(1);
TesterEndpoint tep = (TesterEndpoint) wsSession.getUserProperties().get("endpoint");
tep.setLatch(latch);
AsyncHandler<?> handler;
handler = new AsyncText(latch);
wsSession.addMessageHandler(handler);
// This should trigger the error
wsSession.getBasicRemote().sendText("Start");
boolean latchResult = handler.getLatch().await(10, TimeUnit.SECONDS);
Assert.assertTrue(latchResult);
@SuppressWarnings("unchecked")
List<String> messages = (List<String>) handler.getMessages();
Assert.assertEquals(0, messages.size());
}
}

View File

@@ -0,0 +1,107 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket;
import java.nio.ByteBuffer;
import org.junit.Assert;
import org.junit.Test;
public class TestWsSession {
@Test
public void testAppendCloseReasonWithTruncation01() {
doTestAppendCloseReasonWithTruncation(100);
}
@Test
public void testAppendCloseReasonWithTruncation02() {
doTestAppendCloseReasonWithTruncation(119);
}
@Test
public void testAppendCloseReasonWithTruncation03() {
doTestAppendCloseReasonWithTruncation(120);
}
@Test
public void testAppendCloseReasonWithTruncation04() {
doTestAppendCloseReasonWithTruncation(121);
}
@Test
public void testAppendCloseReasonWithTruncation05() {
doTestAppendCloseReasonWithTruncation(122);
}
@Test
public void testAppendCloseReasonWithTruncation06() {
doTestAppendCloseReasonWithTruncation(123);
}
@Test
public void testAppendCloseReasonWithTruncation07() {
doTestAppendCloseReasonWithTruncation(124);
}
@Test
public void testAppendCloseReasonWithTruncation08() {
doTestAppendCloseReasonWithTruncation(125);
}
@Test
public void testAppendCloseReasonWithTruncation09() {
doTestAppendCloseReasonWithTruncation(150);
}
private void doTestAppendCloseReasonWithTruncation(int reasonLength) {
StringBuilder reason = new StringBuilder(reasonLength);
for (int i = 0; i < reasonLength; i++) {
reason.append('a');
}
ByteBuffer buf = ByteBuffer.allocate(256);
WsSession.appendCloseReasonWithTruncation(buf, reason.toString());
// Check the position and contents
if (reasonLength <= 123) {
Assert.assertEquals(reasonLength, buf.position());
for (int i = 0; i < reasonLength; i++) {
Assert.assertEquals('a', buf.get(i));
}
} else {
// Must have been truncated
Assert.assertEquals(123, buf.position());
for (int i = 0; i < 120; i++) {
Assert.assertEquals('a', buf.get(i));
}
Assert.assertEquals(0xE2, buf.get(120) & 0xFF);
Assert.assertEquals(0x80, buf.get(121) & 0xFF);
Assert.assertEquals(0xA6, buf.get(122) & 0xFF);
}
}
}

View File

@@ -0,0 +1,161 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.websocket.ClientEndpointConfig;
import javax.websocket.CloseReason;
import javax.websocket.ContainerProvider;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.MessageHandler;
import javax.websocket.SendHandler;
import javax.websocket.SendResult;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import javax.websocket.server.ServerEndpointConfig;
import org.junit.Assert;
import org.junit.Test;
import org.apache.catalina.Context;
import org.apache.catalina.servlets.DefaultServlet;
import org.apache.catalina.startup.Tomcat;
import org.apache.tomcat.websocket.TesterMessageCountClient.TesterProgrammaticEndpoint;
import org.apache.tomcat.websocket.server.TesterEndpointConfig;
public class TestWsSessionSuspendResume extends WebSocketBaseTest {
@Test
public void test() throws Exception {
Tomcat tomcat = getTomcatInstance();
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(Config.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
tomcat.start();
WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer();
ClientEndpointConfig clientEndpointConfig = ClientEndpointConfig.Builder.create().build();
Session wsSession = wsContainer.connectToServer(
TesterProgrammaticEndpoint.class,
clientEndpointConfig,
new URI("ws://localhost:" + getPort() + Config.PATH));
final CountDownLatch latch = new CountDownLatch(2);
wsSession.addMessageHandler(new MessageHandler.Whole<String>() {
@Override
public void onMessage(String message) {
Assert.assertTrue("[echo, echo, echo]".equals(message));
latch.countDown();
}
});
for (int i = 0; i < 8; i++) {
wsSession.getAsyncRemote().sendText("echo");
}
boolean latchResult = latch.await(30, TimeUnit.SECONDS);
Assert.assertTrue(latchResult);
wsSession.close();
}
public static final class Config extends TesterEndpointConfig {
private static final String PATH = "/echo";
@Override
protected Class<?> getEndpointClass() {
return SuspendResumeEndpoint.class;
}
@Override
protected ServerEndpointConfig getServerEndpointConfig() {
return ServerEndpointConfig.Builder.create(getEndpointClass(), PATH).build();
}
}
public static final class SuspendResumeEndpoint extends Endpoint {
@Override
public void onOpen(Session session, EndpointConfig epc) {
final MessageProcessor processor = new MessageProcessor(session, 3);
session.addMessageHandler(new MessageHandler.Whole<String>() {
@Override
public void onMessage(String message) {
processor.addMessage(message);
}
});
}
@Override
public void onClose(Session session, CloseReason closeReason) {
try {
session.close();
} catch (IOException e) {
e.printStackTrace();
}
}
@Override
public void onError(Session session, Throwable t) {
t.printStackTrace();
}
}
private static final class MessageProcessor {
private final Session session;
private final int count;
private final List<String> messages = new ArrayList<>();
MessageProcessor(Session session, int count) {
this.session = session;
this.count = count;
}
void addMessage(String message) {
if (messages.size() == count) {
((WsSession) session).suspend();
session.getAsyncRemote().sendText(messages.toString(), new SendHandler() {
@Override
public void onResult(SendResult result) {
((WsSession) session).resume();
Assert.assertTrue(result.isOK());
}
});
messages.clear();
} else {
messages.add(message);
}
}
}
}

View File

@@ -0,0 +1,120 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket;
import java.net.URI;
import java.util.Arrays;
import java.util.List;
import javax.websocket.ClientEndpointConfig;
import javax.websocket.ContainerProvider;
import javax.websocket.EndpointConfig;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import javax.websocket.server.ServerEndpoint;
import javax.websocket.server.ServerEndpointConfig;
import org.junit.Assert;
import org.junit.Test;
import org.apache.catalina.Context;
import org.apache.catalina.servlets.DefaultServlet;
import org.apache.catalina.startup.Tomcat;
import org.apache.tomcat.websocket.TesterMessageCountClient.TesterProgrammaticEndpoint;
import org.apache.tomcat.websocket.server.TesterEndpointConfig;
public class TestWsSubprotocols extends WebSocketBaseTest {
@Test
public void testWsSubprotocols() throws Exception {
Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(Config.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
tomcat.start();
WebSocketContainer wsContainer = ContainerProvider
.getWebSocketContainer();
tomcat.start();
Session wsSession = wsContainer.connectToServer(
TesterProgrammaticEndpoint.class, ClientEndpointConfig.Builder
.create().preferredSubprotocols(Arrays.asList("sp3"))
.build(), new URI("ws://localhost:" + getPort()
+ SubProtocolsEndpoint.PATH_BASIC));
Assert.assertTrue(wsSession.isOpen());
if (wsSession.getNegotiatedSubprotocol() != null) {
Assert.assertTrue(wsSession.getNegotiatedSubprotocol().isEmpty());
}
wsSession.close();
SubProtocolsEndpoint.recycle();
wsSession = wsContainer.connectToServer(
TesterProgrammaticEndpoint.class, ClientEndpointConfig.Builder
.create().preferredSubprotocols(Arrays.asList("sp2"))
.build(), new URI("ws://localhost:" + getPort()
+ SubProtocolsEndpoint.PATH_BASIC));
Assert.assertTrue(wsSession.isOpen());
Assert.assertEquals("sp2", wsSession.getNegotiatedSubprotocol());
// Client thread might move faster than server. Wait for upto 5s for the
// subProtocols to be set
int count = 0;
while (count < 50 && SubProtocolsEndpoint.subprotocols == null) {
count++;
Thread.sleep(100);
}
Assert.assertNotNull(SubProtocolsEndpoint.subprotocols);
Assert.assertArrayEquals(new String[]{"sp1","sp2"},
SubProtocolsEndpoint.subprotocols.toArray(new String[2]));
wsSession.close();
SubProtocolsEndpoint.recycle();
}
@ServerEndpoint(value = "/echo", subprotocols = {"sp1","sp2"})
public static class SubProtocolsEndpoint {
public static final String PATH_BASIC = "/echo";
public static volatile List<String> subprotocols;
@OnOpen
public void processOpen(@SuppressWarnings("unused") Session session,
EndpointConfig epc) {
subprotocols = ((ServerEndpointConfig)epc).getSubprotocols();
}
public static void recycle() {
subprotocols = null;
}
}
public static class Config extends TesterEndpointConfig {
@Override
protected Class<?> getEndpointClass() {
return SubProtocolsEndpoint.class;
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,388 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import javax.servlet.ServletContextEvent;
import javax.websocket.ClientEndpointConfig;
import javax.websocket.CloseReason;
import javax.websocket.ContainerProvider;
import javax.websocket.DeploymentException;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.MessageHandler;
import javax.websocket.OnMessage;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import javax.websocket.server.ServerContainer;
import javax.websocket.server.ServerEndpoint;
import javax.websocket.server.ServerEndpointConfig;
import org.junit.Assert;
import org.junit.Test;
import org.apache.catalina.Context;
import org.apache.catalina.servlets.DefaultServlet;
import org.apache.catalina.startup.Tomcat;
import org.apache.tomcat.websocket.server.Constants;
import org.apache.tomcat.websocket.server.WsContextListener;
/*
* This method is split out into a separate class to make it easier to track the
* various permutations and combinations of client and server endpoints.
*
* Each test uses 2 client endpoint and 2 server endpoints with each client
* connecting to each server for a total of four connections (note sometimes
* the two clients and/or the two servers will be the sam)e.
*/
public class TestWsWebSocketContainerGetOpenSessions extends WebSocketBaseTest {
@Test
public void testClientAClientAPojoAPojoA() throws Exception {
Endpoint client1 = new ClientEndpointA();
Endpoint client2 = new ClientEndpointA();
doTest(client1, client2, "/pojoA", "/pojoA", 2, 2, 4, 4);
}
@Test
public void testClientAClientBPojoAPojoA() throws Exception {
Endpoint client1 = new ClientEndpointA();
Endpoint client2 = new ClientEndpointB();
doTest(client1, client2, "/pojoA", "/pojoA", 2, 2, 4, 4);
}
@Test
public void testClientAClientAPojoAPojoB() throws Exception {
Endpoint client1 = new ClientEndpointA();
Endpoint client2 = new ClientEndpointA();
doTest(client1, client2, "/pojoA", "/pojoB", 2, 2, 2, 2);
}
@Test
public void testClientAClientBPojoAPojoB() throws Exception {
Endpoint client1 = new ClientEndpointA();
Endpoint client2 = new ClientEndpointB();
doTest(client1, client2, "/pojoA", "/pojoB", 2, 2, 2, 2);
}
@Test
public void testClientAClientAProgAProgA() throws Exception {
Endpoint client1 = new ClientEndpointA();
Endpoint client2 = new ClientEndpointA();
doTest(client1, client2, "/progA", "/progA", 2, 2, 4, 4);
}
@Test
public void testClientAClientBProgAProgA() throws Exception {
Endpoint client1 = new ClientEndpointA();
Endpoint client2 = new ClientEndpointB();
doTest(client1, client2, "/progA", "/progA", 2, 2, 4, 4);
}
@Test
public void testClientAClientAProgAProgB() throws Exception {
Endpoint client1 = new ClientEndpointA();
Endpoint client2 = new ClientEndpointA();
doTest(client1, client2, "/progA", "/progB", 2, 2, 2, 2);
}
@Test
public void testClientAClientBProgAProgB() throws Exception {
Endpoint client1 = new ClientEndpointA();
Endpoint client2 = new ClientEndpointB();
doTest(client1, client2, "/progA", "/progB", 2, 2, 2, 2);
}
@Test
public void testClientAClientAPojoAProgA() throws Exception {
Endpoint client1 = new ClientEndpointA();
Endpoint client2 = new ClientEndpointA();
doTest(client1, client2, "/pojoA", "/progA", 2, 2, 2, 2);
}
@Test
public void testClientAClientBPojoAProgA() throws Exception {
Endpoint client1 = new ClientEndpointA();
Endpoint client2 = new ClientEndpointB();
doTest(client1, client2, "/pojoA", "/progA", 2, 2, 2, 2);
}
@Test
public void testClientAClientAPojoAProgB() throws Exception {
Endpoint client1 = new ClientEndpointA();
Endpoint client2 = new ClientEndpointA();
doTest(client1, client2, "/pojoA", "/progB", 2, 2, 2, 2);
}
@Test
public void testClientAClientBPojoAProgB() throws Exception {
Endpoint client1 = new ClientEndpointA();
Endpoint client2 = new ClientEndpointB();
doTest(client1, client2, "/pojoA", "/progB", 2, 2, 2, 2);
}
private void doTest(Endpoint client1, Endpoint client2, String server1, String server2,
int client1Count, int client2Count, int server1Count, int server2Count) throws Exception {
Tracker.reset();
Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(Config.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
tomcat.start();
WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer();
Session sClient1Server1 = createSession(wsContainer, client1, "client1", server1);
Session sClient1Server2 = createSession(wsContainer, client1, "client1", server2);
Session sClient2Server1 = createSession(wsContainer, client2, "client2", server1);
Session sClient2Server2 = createSession(wsContainer, client2, "client2", server2);
int delayCount = 0;
// Wait for up to 20s for this to complete. It should be a lot faster
// but some CI systems get be slow at times.
while (Tracker.getUpdateCount() < 8 && delayCount < 400) {
Thread.sleep(50);
delayCount++;
}
Assert.assertTrue(Tracker.checkRecord("client1", client1Count));
Assert.assertTrue(Tracker.checkRecord("client2", client2Count));
// Note: need to strip leading '/' from path
Assert.assertTrue(Tracker.checkRecord(server1.substring(1), server1Count));
Assert.assertTrue(Tracker.checkRecord(server2.substring(1), server2Count));
sClient1Server1.close();
sClient1Server2.close();
sClient2Server1.close();
sClient2Server2.close();
}
private Session createSession(WebSocketContainer wsContainer, Endpoint client,
String clientName, String server)
throws DeploymentException, IOException, URISyntaxException {
Session s = wsContainer.connectToServer(client,
ClientEndpointConfig.Builder.create().build(),
new URI("ws://localhost:" + getPort() + server));
Tracker.addRecord(clientName, s.getOpenSessions().size());
s.getBasicRemote().sendText("X");
return s;
}
public static class Config extends WsContextListener {
@Override
public void contextInitialized(ServletContextEvent sce) {
super.contextInitialized(sce);
ServerContainer sc =
(ServerContainer) sce.getServletContext().getAttribute(
Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE);
try {
sc.addEndpoint(PojoEndpointA.class);
sc.addEndpoint(PojoEndpointB.class);
sc.addEndpoint(ServerEndpointConfig.Builder.create(
ServerEndpointA.class, "/progA").build());
sc.addEndpoint(ServerEndpointConfig.Builder.create(
ServerEndpointB.class, "/progB").build());
} catch (DeploymentException e) {
throw new IllegalStateException(e);
}
}
}
public abstract static class PojoEndpointBase {
@OnMessage
public void onMessage(@SuppressWarnings("unused") String msg, Session session) {
Tracker.addRecord(getTrackingName(), session.getOpenSessions().size());
}
protected abstract String getTrackingName();
}
@ServerEndpoint("/pojoA")
public static class PojoEndpointA extends PojoEndpointBase {
@Override
protected String getTrackingName() {
return "pojoA";
}
}
@ServerEndpoint("/pojoB")
public static class PojoEndpointB extends PojoEndpointBase {
@Override
protected String getTrackingName() {
return "pojoB";
}
}
public abstract static class ServerEndpointBase extends Endpoint{
@Override
public void onOpen(Session session, EndpointConfig config) {
session.addMessageHandler(new TrackerMessageHandler(session, getTrackingName()));
}
protected abstract String getTrackingName();
}
public static final class ServerEndpointA extends ServerEndpointBase {
@Override
protected String getTrackingName() {
return "progA";
}
}
public static final class ServerEndpointB extends ServerEndpointBase {
@Override
protected String getTrackingName() {
return "progB";
}
}
public static final class TrackerMessageHandler implements MessageHandler.Whole<String> {
private final Session session;
private final String trackingName;
public TrackerMessageHandler(Session session, String trackingName) {
this.session = session;
this.trackingName = trackingName;
}
@Override
public void onMessage(String message) {
Tracker.addRecord(trackingName, session.getOpenSessions().size());
}
}
public abstract static class ClientEndpointBase extends Endpoint {
@Override
public void onOpen(Session session, EndpointConfig config) {
// NO-OP
}
@Override
public void onClose(Session session, CloseReason closeReason) {
// NO-OP
}
protected abstract String getTrackingName();
}
public static final class ClientEndpointA extends ClientEndpointBase {
@Override
protected String getTrackingName() {
return "clientA";
}
}
public static final class ClientEndpointB extends ClientEndpointBase {
@Override
protected String getTrackingName() {
return "clientB";
}
}
public static class Tracker {
private static final Map<String, Integer> records = new ConcurrentHashMap<>();
private static final AtomicInteger updateCount = new AtomicInteger(0);
public static void addRecord(String key, int count) {
records.put(key, Integer.valueOf(count));
updateCount.incrementAndGet();
}
public static boolean checkRecord(String key, int expectedCount) {
Integer actualCount = records.get(key);
if (actualCount == null) {
if (expectedCount == 0) {
return true;
} else {
return false;
}
} else {
return actualCount.intValue() == expectedCount;
}
}
public static int getUpdateCount() {
return updateCount.intValue();
}
public static void reset() {
records.clear();
updateCount.set(0);
}
}
}

View File

@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Ignore;
@Ignore // Additional infrastructure is required to run this test
public class TestWsWebSocketContainerWithProxy extends TestWsWebSocketContainer {
@BeforeClass
public static void init() {
// Set the system properties for an HTTP proxy on 192.168.0.100:80
// I used an httpd instance configured as an open forward proxy for this
// Update the IP/hostname as required
System.setProperty("http.proxyHost", "192.168.0.100");
System.setProperty("http.proxyPort", "80");
System.setProperty("http.nonProxyHosts", "");
}
@Before
public void setPort() {
// With httpd 2.2, AllowCONNECT requires fixed ports. From 2.4, a range
// can be used.
getTomcatInstance().getConnector().setPort(8080);
Assert.assertTrue(getTomcatInstance().getConnector().setProperty("address","0.0.0.0"));
}
@Override
protected String getHostName() {
// The IP/hostname where the tests are running. The proxy will connect
// back to this expecting to find the Tomcat instance created by the
// unit test.
return "192.168.0.200";
}
}

View File

@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket;
import java.util.Collections;
import java.util.Set;
import javax.websocket.Endpoint;
import javax.websocket.server.ServerApplicationConfig;
import javax.websocket.server.ServerEndpointConfig;
/**
* This configuration blocks any endpoints discovered by the SCI from being
* deployed. It is intended to prevent testing errors generated when the
* WebSocket SCI scans the test classes for endpoints as it will discover
* multiple endpoints mapped to the same path ('/'). The tests all explicitly
* configure their required endpoints so have no need for SCI based
* configuration.
*/
public class TesterBlockWebSocketSCI implements ServerApplicationConfig {
@Override
public Set<ServerEndpointConfig> getEndpointConfigs(
Set<Class<? extends Endpoint>> scanned) {
return Collections.emptySet();
}
@Override
public Set<Class<?>> getAnnotatedEndpointClasses(Set<Class<?>> scanned) {
return Collections.emptySet();
}
}

View File

@@ -0,0 +1,230 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket;
import java.io.IOException;
import java.nio.ByteBuffer;
import javax.servlet.ServletContextEvent;
import javax.websocket.DeploymentException;
import javax.websocket.OnMessage;
import javax.websocket.Session;
import javax.websocket.server.ServerContainer;
import javax.websocket.server.ServerEndpoint;
import org.apache.tomcat.websocket.server.Constants;
import org.apache.tomcat.websocket.server.WsContextListener;
public class TesterEchoServer {
public static class Config extends WsContextListener {
public static final String PATH_ASYNC = "/echoAsync";
public static final String PATH_BASIC = "/echoBasic";
public static final String PATH_BASIC_LIMIT_LOW = "/echoBasicLimitLow";
public static final String PATH_BASIC_LIMIT_HIGH = "/echoBasicLimitHigh";
public static final String PATH_WRITER_ERROR = "/echoWriterError";
@Override
public void contextInitialized(ServletContextEvent sce) {
super.contextInitialized(sce);
ServerContainer sc =
(ServerContainer) sce.getServletContext().getAttribute(
Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE);
try {
sc.addEndpoint(Async.class);
sc.addEndpoint(Basic.class);
sc.addEndpoint(BasicLimitLow.class);
sc.addEndpoint(BasicLimitHigh.class);
sc.addEndpoint(WriterError.class);
sc.addEndpoint(RootEcho.class);
} catch (DeploymentException e) {
throw new IllegalStateException(e);
}
}
}
@ServerEndpoint("/echoAsync")
public static class Async {
@OnMessage
public void echoTextMessage(Session session, String msg, boolean last) {
try {
session.getBasicRemote().sendText(msg, last);
} catch (IOException e) {
try {
session.close();
} catch (IOException e1) {
// Ignore
}
}
}
@OnMessage
public void echoBinaryMessage(Session session, ByteBuffer msg,
boolean last) {
try {
session.getBasicRemote().sendBinary(msg, last);
} catch (IOException e) {
try {
session.close();
} catch (IOException e1) {
// Ignore
}
}
}
}
@ServerEndpoint("/echoBasic")
public static class Basic {
@OnMessage
public void echoTextMessage(Session session, String msg) {
try {
session.getBasicRemote().sendText(msg);
} catch (IOException e) {
try {
session.close();
} catch (IOException e1) {
// Ignore
}
}
}
@OnMessage
public void echoBinaryMessage(Session session, ByteBuffer msg) {
try {
session.getBasicRemote().sendBinary(msg);
} catch (IOException e) {
try {
session.close();
} catch (IOException e1) {
// Ignore
}
}
}
}
@ServerEndpoint("/echoBasicLimitLow")
public static class BasicLimitLow {
public static final long MAX_SIZE = 10;
@OnMessage(maxMessageSize = MAX_SIZE)
public void echoTextMessage(Session session, String msg) {
try {
session.getBasicRemote().sendText(msg);
} catch (IOException e) {
try {
session.close();
} catch (IOException e1) {
// Ignore
}
}
}
@OnMessage(maxMessageSize = MAX_SIZE)
public void echoBinaryMessage(Session session, ByteBuffer msg) {
try {
session.getBasicRemote().sendBinary(msg);
} catch (IOException e) {
try {
session.close();
} catch (IOException e1) {
// Ignore
}
}
}
}
@ServerEndpoint("/echoBasicLimitHigh")
public static class BasicLimitHigh {
public static final long MAX_SIZE = 32 * 1024;
@OnMessage(maxMessageSize = MAX_SIZE)
public void echoTextMessage(Session session, String msg) {
try {
session.getBasicRemote().sendText(msg);
} catch (IOException e) {
try {
session.close();
} catch (IOException e1) {
// Ignore
}
}
}
@OnMessage(maxMessageSize = MAX_SIZE)
public void echoBinaryMessage(Session session, ByteBuffer msg) {
try {
session.getBasicRemote().sendBinary(msg);
} catch (IOException e) {
try {
session.close();
} catch (IOException e1) {
// Ignore
}
}
}
}
@ServerEndpoint("/echoWriterError")
public static class WriterError {
@OnMessage
public void echoTextMessage(Session session, @SuppressWarnings("unused") String msg) {
try {
session.getBasicRemote().getSendWriter();
// Simulate an error
throw new RuntimeException();
} catch (IOException e) {
// Should not happen
try {
session.close();
} catch (IOException e1) {
// Ignore
}
}
}
}
@ServerEndpoint("/")
public static class RootEcho {
@OnMessage
public void echoTextMessage(Session session, String msg) {
try {
session.getBasicRemote().sendText(msg);
} catch (IOException e) {
try {
session.close();
} catch (IOException e1) {
// Ignore
}
}
}
}
}

View File

@@ -0,0 +1,130 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket;
import java.io.IOException;
import java.util.concurrent.atomic.AtomicInteger;
import javax.websocket.OnClose;
import javax.websocket.OnError;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.RemoteEndpoint.Basic;
import javax.websocket.Session;
import javax.websocket.server.ServerEndpoint;
import org.apache.tomcat.websocket.server.TesterEndpointConfig;
/**
* Sends {@link #MESSAGE_COUNT} messages of size {@link #MESSAGE_SIZE} bytes as
* quickly as possible after the client sends its first message.
*/
public class TesterFirehoseServer {
public static final int MESSAGE_COUNT = 100000;
public static final String MESSAGE;
public static final int MESSAGE_SIZE = 1024;
public static final int WAIT_TIME_MILLIS = 300000;
public static final int SEND_TIME_OUT_MILLIS = 5000;
static {
StringBuilder sb = new StringBuilder(MESSAGE_SIZE);
for (int i = 0; i < MESSAGE_SIZE; i++) {
sb.append('x');
}
MESSAGE = sb.toString();
}
public static class Config extends TesterEndpointConfig {
public static final String PATH = "/firehose";
@Override
protected Class<?> getEndpointClass() {
return Endpoint.class;
}
}
@ServerEndpoint(Config.PATH)
public static class Endpoint {
private static AtomicInteger openConnectionCount = new AtomicInteger(0);
private static AtomicInteger errorCount = new AtomicInteger(0);
private volatile boolean started = false;
public static int getOpenConnectionCount() {
return openConnectionCount.intValue();
}
public static int getErrorCount() {
return errorCount.intValue();
}
@OnOpen
public void onOpen() {
openConnectionCount.incrementAndGet();
}
@OnMessage
public void onMessage(Session session, String msg) throws IOException {
if (started) {
return;
}
synchronized (this) {
if (started) {
return;
} else {
started = true;
}
}
System.out.println("Received " + msg + ", now sending data");
session.getUserProperties().put(
org.apache.tomcat.websocket.Constants.BLOCKING_SEND_TIMEOUT_PROPERTY,
Long.valueOf(SEND_TIME_OUT_MILLIS));
Basic remote = session.getBasicRemote();
remote.setBatchingAllowed(true);
for (int i = 0; i < MESSAGE_COUNT; i++) {
remote.sendText(MESSAGE);
if (i % (MESSAGE_COUNT * 0.4) == 0) {
remote.setBatchingAllowed(false);
remote.setBatchingAllowed(true);
}
}
// Flushing should happen automatically on session close
session.close();
}
@OnError
public void onError(@SuppressWarnings("unused") Throwable t) {
errorCount.incrementAndGet();
}
@OnClose
public void onClose() {
openConnectionCount.decrementAndGet();
}
}
}

View File

@@ -0,0 +1,244 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import javax.websocket.ClientEndpoint;
import javax.websocket.CloseReason;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.MessageHandler;
import javax.websocket.OnClose;
import javax.websocket.OnError;
import javax.websocket.OnOpen;
import javax.websocket.Session;
public class TesterMessageCountClient {
public interface TesterEndpoint {
void setLatch(CountDownLatch latch);
}
public static class TesterProgrammaticEndpoint
extends Endpoint implements TesterEndpoint {
private CountDownLatch latch = null;
@Override
public void setLatch(CountDownLatch latch) {
this.latch = latch;
}
@Override
public void onClose(Session session, CloseReason closeReason) {
clearLatch();
}
@Override
public void onError(Session session, Throwable throwable) {
clearLatch();
}
private void clearLatch() {
if (latch != null) {
while (latch.getCount() > 0) {
latch.countDown();
}
}
}
@Override
public void onOpen(Session session, EndpointConfig config) {
session.getUserProperties().put("endpoint", this);
}
}
@ClientEndpoint
public static class TesterAnnotatedEndpoint implements TesterEndpoint {
private CountDownLatch latch = null;
@Override
public void setLatch(CountDownLatch latch) {
this.latch = latch;
}
@OnClose
public void onClose() {
clearLatch();
}
@OnError
public void onError(@SuppressWarnings("unused") Throwable throwable) {
clearLatch();
}
private void clearLatch() {
if (latch != null) {
while (latch.getCount() > 0) {
latch.countDown();
}
}
}
@OnOpen
public void onOpen(Session session) {
session.getUserProperties().put("endpoint", this);
}
}
public abstract static class BasicHandler<T>
implements MessageHandler.Whole<T> {
private final CountDownLatch latch;
private final Queue<T> messages = new LinkedBlockingQueue<>();
public BasicHandler(CountDownLatch latch) {
this.latch = latch;
}
public CountDownLatch getLatch() {
return latch;
}
public Queue<T> getMessages() {
return messages;
}
}
public static class BasicBinary extends BasicHandler<ByteBuffer> {
public BasicBinary(CountDownLatch latch) {
super(latch);
}
@Override
public void onMessage(ByteBuffer message) {
getMessages().add(message);
if (getLatch() != null) {
getLatch().countDown();
}
}
}
public static class BasicText extends BasicHandler<String> {
private final String expected;
public BasicText(CountDownLatch latch) {
this(latch, null);
}
public BasicText(CountDownLatch latch, String expected) {
super(latch);
this.expected = expected;
}
@Override
public void onMessage(String message) {
if (expected == null) {
getMessages().add(message);
} else {
if (!expected.equals(message)) {
throw new IllegalStateException(
"Expected: [" + expected + "]\r\n" +
"Was: [" + message + "]");
}
}
if (getLatch() != null) {
getLatch().countDown();
}
}
}
public static class SleepingText implements MessageHandler.Whole<String> {
private final int sleep;
public SleepingText(int sleep) {
this.sleep = sleep;
}
@Override
public void onMessage(String message) {
try {
Thread.sleep(sleep);
} catch (InterruptedException e) {
// Ignore
}
}
}
public abstract static class AsyncHandler<T>
implements MessageHandler.Partial<T> {
private final CountDownLatch latch;
private final List<T> messages = new CopyOnWriteArrayList<>();
public AsyncHandler(CountDownLatch latch) {
this.latch = latch;
}
public CountDownLatch getLatch() {
return latch;
}
public List<T> getMessages() {
return messages;
}
}
public static class AsyncBinary extends AsyncHandler<ByteBuffer> {
public AsyncBinary(CountDownLatch latch) {
super(latch);
}
@Override
public void onMessage(ByteBuffer message, boolean last) {
getMessages().add(message);
if (last && getLatch() != null) {
getLatch().countDown();
}
}
}
public static class AsyncText extends AsyncHandler<String> {
public AsyncText(CountDownLatch latch) {
super(latch);
}
@Override
public void onMessage(String message, boolean last) {
getMessages().add(message);
if (last && getLatch() != null) {
getLatch().countDown();
}
}
}
}

View File

@@ -0,0 +1,209 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket;
import java.io.IOException;
import java.net.URI;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import javax.websocket.ClientEndpoint;
import javax.websocket.ClientEndpointConfig;
import javax.websocket.ContainerProvider;
import javax.websocket.Endpoint;
import javax.websocket.Extension;
import javax.websocket.OnClose;
import javax.websocket.OnError;
import javax.websocket.OnMessage;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import org.apache.tomcat.util.ExceptionUtils;
import org.apache.tomcat.websocket.pojo.PojoEndpointClient;
/**
* Runs the Autobahn test suite in client mode for testing the WebSocket client
* implementation.
*/
public class TesterWsClientAutobahn {
private static final String HOST = "localhost";
private static final int PORT = 9001;
private static final String USER_AGENT = "ApacheTomcat8WebSocketClient";
public static void main(String[] args) throws Exception {
WebSocketContainer wsc = ContainerProvider.getWebSocketContainer();
int testCaseCount = getTestCaseCount(wsc);
System.out.println("There are " + testCaseCount + " test cases");
for (int testCase = 1; testCase <= testCaseCount; testCase++) {
if (testCase % 50 == 0) {
System.out.println(testCase);
} else {
System.out.print('.');
}
try {
executeTestCase(wsc, testCase);
} catch (Throwable t) {
ExceptionUtils.handleThrowable(t);
t.printStackTrace();
}
}
System.out.println("Testing complete");
updateReports(wsc);
}
private static int getTestCaseCount(WebSocketContainer wsc)
throws Exception {
URI uri = new URI("ws://" + HOST + ":" + PORT + "/getCaseCount");
CaseCountClient caseCountClient = new CaseCountClient();
wsc.connectToServer(caseCountClient, uri);
return caseCountClient.getCaseCount();
}
private static void executeTestCase(WebSocketContainer wsc, int testCase)
throws Exception {
URI uri = new URI("ws://" + HOST + ":" + PORT + "/runCase?case=" +
testCase + "&agent=" + USER_AGENT);
TestCaseClient testCaseClient = new TestCaseClient();
Extension permessageDeflate = new WsExtension("permessage-deflate");
// Advertise support for client_max_window_bits
// Client only supports some values so there will be some failures here
// Note Autobahn returns a 400 response if you provide a value for
// client_max_window_bits
permessageDeflate.getParameters().add(
new WsExtensionParameter("client_max_window_bits", null));
List<Extension> extensions = new ArrayList<>(1);
extensions.add(permessageDeflate);
Endpoint ep = new PojoEndpointClient(testCaseClient, null);
ClientEndpointConfig.Builder builder = ClientEndpointConfig.Builder.create();
ClientEndpointConfig config = builder.extensions(extensions).build();
wsc.connectToServer(ep, config, uri);
testCaseClient.waitForClose();
}
private static void updateReports(WebSocketContainer wsc)
throws Exception {
URI uri = new URI("ws://" + HOST + ":" + PORT +
"/updateReports?agent=" + USER_AGENT);
UpdateReportsClient updateReportsClient = new UpdateReportsClient();
wsc.connectToServer(updateReportsClient, uri);
}
@ClientEndpoint
public static class CaseCountClient {
private final CountDownLatch latch = new CountDownLatch(1);
private volatile int caseCount = 0;
// Need to wait for message
public int getCaseCount() throws InterruptedException {
latch.await();
return caseCount;
}
@OnMessage
public void onMessage(String msg) {
latch.countDown();
caseCount = Integer.parseInt(msg);
}
@OnError
public void onError(Throwable t) {
latch.countDown();
t.printStackTrace();
}
}
@ClientEndpoint
public static class TestCaseClient {
private final CountDownLatch latch = new CountDownLatch(1);
public void waitForClose() throws InterruptedException {
latch.await();
}
@OnMessage
public void echoTextMessage(Session session, String msg, boolean last) {
try {
if (session.isOpen()) {
session.getBasicRemote().sendText(msg, last);
}
} catch (IOException e) {
try {
session.close();
} catch (IOException e1) {
// Ignore
}
}
}
@OnMessage
public void echoBinaryMessage(Session session, ByteBuffer bb,
boolean last) {
try {
if (session.isOpen()) {
session.getBasicRemote().sendBinary(bb, last);
}
} catch (IOException e) {
try {
session.close();
} catch (IOException e1) {
// Ignore
}
}
}
@OnClose
public void releaseLatch() {
latch.countDown();
}
}
@ClientEndpoint
public static class UpdateReportsClient {
private final CountDownLatch latch = new CountDownLatch(1);
public void waitForClose() throws InterruptedException {
latch.await();
}
@OnClose
public void onClose() {
latch.countDown();
}
}
}

View File

@@ -0,0 +1,72 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket;
import org.junit.After;
import org.junit.Assert;
import org.apache.catalina.Context;
import org.apache.catalina.LifecycleException;
import org.apache.catalina.servlets.DefaultServlet;
import org.apache.catalina.startup.Tomcat;
import org.apache.catalina.startup.TomcatBaseTest;
import org.apache.tomcat.websocket.server.WsContextListener;
public abstract class WebSocketBaseTest extends TomcatBaseTest {
protected Tomcat startServer(
final Class<? extends WsContextListener> configClass)
throws LifecycleException {
Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(configClass.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
tomcat.start();
return tomcat;
}
@After
public void checkBackgroundProcessHasStopped() throws Exception {
// Need to stop Tomcat to ensure background processed have been stopped.
getTomcatInstance().stop();
// Make sure the background process has stopped. In some test
// environments it will continue to run and break other tests that check
// it has stopped.
int count = 0;
// 5s should be plenty here but Gump can be a lot slower so allow 60s.
while (count < 600) {
if (BackgroundProcessManager.getInstance().getProcessCount() == 0) {
break;
}
Thread.sleep(100);
count++;
}
try {
Assert.assertEquals(0, BackgroundProcessManager.getInstance().getProcessCount());
} finally {
// Ensure the next test is not affected
BackgroundProcessManager.getInstance().shutdown();
}
}
}

View File

@@ -0,0 +1,789 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket.pojo;
import java.io.IOException;
import java.net.URI;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import javax.servlet.ServletContextEvent;
import javax.websocket.ClientEndpoint;
import javax.websocket.ContainerProvider;
import javax.websocket.DecodeException;
import javax.websocket.Decoder;
import javax.websocket.DeploymentException;
import javax.websocket.EncodeException;
import javax.websocket.Encoder;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.Extension;
import javax.websocket.MessageHandler;
import javax.websocket.OnError;
import javax.websocket.OnMessage;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import javax.websocket.server.ServerContainer;
import javax.websocket.server.ServerEndpoint;
import javax.websocket.server.ServerEndpointConfig;
import org.junit.Assert;
import org.junit.Ignore;
import org.junit.Test;
import org.apache.catalina.Context;
import org.apache.catalina.servlets.DefaultServlet;
import org.apache.catalina.startup.Tomcat;
import org.apache.catalina.startup.TomcatBaseTest;
import org.apache.tomcat.websocket.pojo.TesterUtil.ServerConfigListener;
import org.apache.tomcat.websocket.pojo.TesterUtil.SingletonConfigurator;
import org.apache.tomcat.websocket.server.WsContextListener;
public class TestEncodingDecoding extends TomcatBaseTest {
private static final String MESSAGE_ONE = "message-one";
private static final String MESSAGE_TWO = "message-two";
private static final String PATH_PROGRAMMATIC_EP = "/echoProgrammaticEP";
private static final String PATH_ANNOTATED_EP = "/echoAnnotatedEP";
private static final String PATH_GENERICS_EP = "/echoGenericsEP";
private static final String PATH_MESSAGES_EP = "/echoMessagesEP";
private static final String PATH_BATCHED_EP = "/echoBatchedEP";
@Test
public void testProgrammaticEndPoints() throws Exception{
Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(ProgramaticServerEndpointConfig.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer();
tomcat.start();
Client client = new Client();
URI uri = new URI("ws://localhost:" + getPort() + PATH_PROGRAMMATIC_EP);
Session session = wsContainer.connectToServer(client, uri);
MsgString msg1 = new MsgString();
msg1.setData(MESSAGE_ONE);
session.getBasicRemote().sendObject(msg1);
// Should not take very long
int i = 0;
while (i < 20) {
if (MsgStringMessageHandler.received.size() > 0 &&
client.received.size() > 0) {
break;
}
i++;
Thread.sleep(100);
}
// Check messages were received
Assert.assertEquals(1, MsgStringMessageHandler.received.size());
Assert.assertEquals(1, client.received.size());
// Check correct messages were received
Assert.assertEquals(MESSAGE_ONE,
((MsgString) MsgStringMessageHandler.received.peek()).getData());
Assert.assertEquals(MESSAGE_ONE,
new String(((MsgByte) client.received.peek()).getData()));
session.close();
}
@Test
public void testAnnotatedEndPoints() throws Exception {
// Set up utility classes
Server server = new Server();
SingletonConfigurator.setInstance(server);
ServerConfigListener.setPojoClazz(Server.class);
Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(ServerConfigListener.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
WebSocketContainer wsContainer =
ContainerProvider.getWebSocketContainer();
tomcat.start();
Client client = new Client();
URI uri = new URI("ws://localhost:" + getPort() + PATH_ANNOTATED_EP);
Session session = wsContainer.connectToServer(client, uri);
MsgString msg1 = new MsgString();
msg1.setData(MESSAGE_ONE);
session.getBasicRemote().sendObject(msg1);
// Should not take very long
int i = 0;
while (i < 20) {
if (server.received.size() > 0 && client.received.size() > 0) {
break;
}
i++;
Thread.sleep(100);
}
// Check messages were received
Assert.assertEquals(1, server.received.size());
Assert.assertEquals(1, client.received.size());
// Check correct messages were received
Assert.assertEquals(MESSAGE_ONE,
((MsgString) server.received.peek()).getData());
Assert.assertEquals(MESSAGE_ONE,
((MsgString) client.received.peek()).getData());
session.close();
// Should not take very long but some failures have been seen
i = testEvent(MsgStringEncoder.class.getName()+":init", 0);
i = testEvent(MsgStringDecoder.class.getName()+":init", i);
i = testEvent(MsgByteEncoder.class.getName()+":init", i);
i = testEvent(MsgByteDecoder.class.getName()+":init", i);
i = testEvent(MsgStringEncoder.class.getName()+":destroy", i);
i = testEvent(MsgStringDecoder.class.getName()+":destroy", i);
i = testEvent(MsgByteEncoder.class.getName()+":destroy", i);
i = testEvent(MsgByteDecoder.class.getName()+":destroy", i);
}
@Test
public void testGenericsCoders() throws Exception {
// Set up utility classes
GenericsServer server = new GenericsServer();
SingletonConfigurator.setInstance(server);
ServerConfigListener.setPojoClazz(GenericsServer.class);
Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(ServerConfigListener.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
WebSocketContainer wsContainer =
ContainerProvider.getWebSocketContainer();
tomcat.start();
GenericsClient client = new GenericsClient();
URI uri = new URI("ws://localhost:" + getPort() + PATH_GENERICS_EP);
Session session = wsContainer.connectToServer(client, uri);
ArrayList<String> list = new ArrayList<>(2);
list.add("str1");
list.add("str2");
session.getBasicRemote().sendObject(list);
// Should not take very long
int i = 0;
while (i < 20) {
if (server.received.size() > 0 && client.received.size() > 0) {
break;
}
i++;
Thread.sleep(100);
}
// Check messages were received
Assert.assertEquals(1, server.received.size());
Assert.assertEquals(server.received.peek().toString(), "[str1, str2]");
Assert.assertEquals(1, client.received.size());
Assert.assertEquals(client.received.peek().toString(), "[str1, str2]");
session.close();
}
@Test
@Ignore // TODO Investigate why this test fails
public void testMessagesEndPoints() throws Exception {
// Set up utility classes
MessagesServer server = new MessagesServer();
SingletonConfigurator.setInstance(server);
ServerConfigListener.setPojoClazz(MessagesServer.class);
Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(ServerConfigListener.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
WebSocketContainer wsContainer =
ContainerProvider.getWebSocketContainer();
tomcat.start();
StringClient client = new StringClient();
URI uri = new URI("ws://localhost:" + getPort() + PATH_MESSAGES_EP);
Session session = wsContainer.connectToServer(client, uri);
session.getBasicRemote().sendText(MESSAGE_ONE);
// Should not take very long
int i = 0;
while (i < 20) {
if (server.received.size() > 0 && client.received.size() > 0) {
break;
}
i++;
Thread.sleep(100);
}
// Check messages were received
Assert.assertEquals(1, server.received.size());
Assert.assertEquals(1, client.received.size());
// Check correct messages were received
Assert.assertEquals(MESSAGE_ONE, server.received.peek());
session.close();
Assert.assertNull(server.t);
}
@Test
@Ignore // TODO Investigate why this test fails
public void testBatchedEndPoints() throws Exception {
// Set up utility classes
BatchedServer server = new BatchedServer();
SingletonConfigurator.setInstance(server);
ServerConfigListener.setPojoClazz(BatchedServer.class);
Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(ServerConfigListener.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
WebSocketContainer wsContainer =
ContainerProvider.getWebSocketContainer();
tomcat.start();
StringClient client = new StringClient();
URI uri = new URI("ws://localhost:" + getPort() + PATH_BATCHED_EP);
Session session = wsContainer.connectToServer(client, uri);
session.getBasicRemote().sendText(MESSAGE_ONE);
// Should not take very long
int i = 0;
while (i++ < 20) {
if (server.received.size() > 0 && client.received.size() > 0) {
break;
}
i++;
Thread.sleep(100);
}
// Check messages were received
Assert.assertEquals(1, server.received.size());
Assert.assertEquals(2, client.received.size());
// Check correct messages were received
Assert.assertEquals(MESSAGE_ONE, server.received.peek());
session.close();
Assert.assertNull(server.t);
}
private int testEvent(String name, int count) throws InterruptedException {
int i = count;
while (i < 50) {
if (Server.isLifeCycleEventCalled(name)) {
break;
}
i++;
Thread.sleep(100);
}
Assert.assertTrue(Server.isLifeCycleEventCalled(name));
return i;
}
@ClientEndpoint(decoders=ListStringDecoder.class, encoders=ListStringEncoder.class)
public static class GenericsClient {
private Queue<Object> received = new ConcurrentLinkedQueue<>();
@OnMessage
public void rx(List<String> in) {
received.add(in);
}
}
@ClientEndpoint(decoders={MsgStringDecoder.class, MsgByteDecoder.class},
encoders={MsgStringEncoder.class, MsgByteEncoder.class})
public static class Client {
private Queue<Object> received = new ConcurrentLinkedQueue<>();
@OnMessage
public void rx(MsgString in) {
received.add(in);
}
@OnMessage
public void rx(MsgByte in) {
received.add(in);
}
}
@ClientEndpoint
public static class StringClient {
private Queue<Object> received = new ConcurrentLinkedQueue<>();
@OnMessage
public void rx(String in) {
received.add(in);
}
}
@ServerEndpoint(value=PATH_GENERICS_EP,
decoders=ListStringDecoder.class,
encoders=ListStringEncoder.class,
configurator=SingletonConfigurator.class)
public static class GenericsServer {
private Queue<Object> received = new ConcurrentLinkedQueue<>();
@OnMessage
public List<String> rx(List<String> in) {
received.add(in);
// Echo the message back
return in;
}
}
@ServerEndpoint(value=PATH_MESSAGES_EP,
configurator=SingletonConfigurator.class)
public static class MessagesServer {
private final Queue<String> received = new ConcurrentLinkedQueue<>();
private volatile Throwable t = null;
@OnMessage
public String onMessage(String message, Session session) {
received.add(message);
session.getAsyncRemote().sendText(MESSAGE_ONE);
return message;
}
@OnError
public void onError(@SuppressWarnings("unused") Session session, Throwable t) {
t.printStackTrace();
this.t = t;
}
}
@ServerEndpoint(value=PATH_BATCHED_EP,
configurator=SingletonConfigurator.class)
public static class BatchedServer {
private final Queue<String> received = new ConcurrentLinkedQueue<>();
private volatile Throwable t = null;
@OnMessage
public String onMessage(String message, Session session) throws IOException {
received.add(message);
session.getAsyncRemote().setBatchingAllowed(true);
session.getAsyncRemote().sendText(MESSAGE_ONE);
return MESSAGE_TWO;
}
@OnError
public void onError(@SuppressWarnings("unused") Session session, Throwable t) {
t.printStackTrace();
this.t = t;
}
}
@ServerEndpoint(value=PATH_ANNOTATED_EP,
decoders={MsgStringDecoder.class, MsgByteDecoder.class},
encoders={MsgStringEncoder.class, MsgByteEncoder.class},
configurator=SingletonConfigurator.class)
public static class Server {
private Queue<Object> received = new ConcurrentLinkedQueue<>();
static Map<String, Boolean> lifeCyclesCalled = new ConcurrentHashMap<>(8);
@OnMessage
public MsgString rx(MsgString in) {
received.add(in);
// Echo the message back
return in;
}
@OnMessage
public MsgByte rx(MsgByte in) {
received.add(in);
// Echo the message back
return in;
}
public static void addLifeCycleEvent(String event){
lifeCyclesCalled.put(event, Boolean.TRUE);
}
public static boolean isLifeCycleEventCalled(String event){
Boolean called = lifeCyclesCalled.get(event);
return called == null ? false : called.booleanValue();
}
}
public static class MsgByteMessageHandler implements
MessageHandler.Whole<MsgByte> {
public static final Queue<Object> received = new ConcurrentLinkedQueue<>();
private final Session session;
public MsgByteMessageHandler(Session session) {
this.session = session;
}
@Override
public void onMessage(MsgByte in) {
System.out.println(getClass() + " received");
received.add(in);
try {
MsgByte msg = new MsgByte();
msg.setData("got it".getBytes());
session.getBasicRemote().sendObject(msg);
} catch (IOException | EncodeException e) {
throw new IllegalStateException(e);
}
}
}
public static class MsgStringMessageHandler implements MessageHandler.Whole<MsgString> {
public static final Queue<Object> received = new ConcurrentLinkedQueue<>();
private final Session session;
public MsgStringMessageHandler(Session session) {
this.session = session;
}
@Override
public void onMessage(MsgString in) {
received.add(in);
try {
MsgByte msg = new MsgByte();
msg.setData(MESSAGE_ONE.getBytes());
session.getBasicRemote().sendObject(msg);
} catch (IOException | EncodeException e) {
e.printStackTrace();
}
}
}
public static class ProgrammaticEndpoint extends Endpoint {
@Override
public void onOpen(Session session, EndpointConfig config) {
session.addMessageHandler(new MsgStringMessageHandler(session));
}
}
public static class MsgString {
private String data;
public String getData() { return data; }
public void setData(String data) { this.data = data; }
}
public static class MsgStringEncoder implements Encoder.Text<MsgString> {
@Override
public void init(EndpointConfig endpointConfig) {
Server.addLifeCycleEvent(getClass().getName() + ":init");
}
@Override
public void destroy() {
Server.addLifeCycleEvent(getClass().getName() + ":destroy");
}
@Override
public String encode(MsgString msg) throws EncodeException {
return "MsgString:" + msg.getData();
}
}
public static class MsgStringDecoder implements Decoder.Text<MsgString> {
@Override
public void init(EndpointConfig endpointConfig) {
Server.addLifeCycleEvent(getClass().getName() + ":init");
}
@Override
public void destroy() {
Server.addLifeCycleEvent(getClass().getName() + ":destroy");
}
@Override
public MsgString decode(String s) throws DecodeException {
MsgString result = new MsgString();
result.setData(s.substring(10));
return result;
}
@Override
public boolean willDecode(String s) {
return s.startsWith("MsgString:");
}
}
public static class MsgByte {
private byte[] data;
public byte[] getData() { return data; }
public void setData(byte[] data) { this.data = data; }
}
public static class MsgByteEncoder implements Encoder.Binary<MsgByte> {
@Override
public void init(EndpointConfig endpointConfig) {
Server.addLifeCycleEvent(getClass().getName() + ":init");
}
@Override
public void destroy() {
Server.addLifeCycleEvent(getClass().getName() + ":destroy");
}
@Override
public ByteBuffer encode(MsgByte msg) throws EncodeException {
byte[] data = msg.getData();
ByteBuffer reply = ByteBuffer.allocate(2 + data.length);
reply.put((byte) 0x12);
reply.put((byte) 0x34);
reply.put(data);
reply.flip();
return reply;
}
}
public static class MsgByteDecoder implements Decoder.Binary<MsgByte> {
@Override
public void init(EndpointConfig endpointConfig) {
Server.addLifeCycleEvent(getClass().getName() + ":init");
}
@Override
public void destroy() {
Server.addLifeCycleEvent(getClass().getName() + ":destroy");
}
@Override
public MsgByte decode(ByteBuffer bb) throws DecodeException {
MsgByte result = new MsgByte();
byte[] data = new byte[bb.limit() - bb.position()];
bb.get(data);
result.setData(data);
return result;
}
@Override
public boolean willDecode(ByteBuffer bb) {
bb.mark();
if (bb.get() == 0x12 && bb.get() == 0x34) {
return true;
}
bb.reset();
return false;
}
}
public static class ListStringEncoder implements Encoder.Text<List<String>> {
@Override
public void init(EndpointConfig endpointConfig) {
Server.addLifeCycleEvent(getClass().getName() + ":init");
}
@Override
public void destroy() {
Server.addLifeCycleEvent(getClass().getName() + ":destroy");
}
@Override
public String encode(List<String> str) throws EncodeException {
StringBuffer sbuf = new StringBuffer();
sbuf.append("[");
for (String s: str){
sbuf.append(s).append(",");
}
sbuf.deleteCharAt(sbuf.lastIndexOf(",")).append("]");
return sbuf.toString();
}
}
public static class ListStringDecoder implements Decoder.Text<List<String>> {
@Override
public void init(EndpointConfig endpointConfig) {
Server.addLifeCycleEvent(getClass().getName() + ":init");
}
@Override
public void destroy() {
Server.addLifeCycleEvent(getClass().getName() + ":destroy");
}
@Override
public List<String> decode(String str) throws DecodeException {
List<String> lst = new ArrayList<>(1);
str = str.substring(1,str.length()-1);
String[] strings = str.split(",");
for (String t : strings){
lst.add(t);
}
return lst;
}
@Override
public boolean willDecode(String str) {
return str.startsWith("[") && str.endsWith("]");
}
}
public static class ProgramaticServerEndpointConfig extends WsContextListener {
@Override
public void contextInitialized(ServletContextEvent sce) {
super.contextInitialized(sce);
ServerContainer sc =
(ServerContainer) sce.getServletContext().getAttribute(
org.apache.tomcat.websocket.server.Constants.
SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE);
try {
sc.addEndpoint(new ServerEndpointConfig() {
@Override
public Map<String, Object> getUserProperties() {
return Collections.emptyMap();
}
@Override
public List<Class<? extends Encoder>> getEncoders() {
List<Class<? extends Encoder>> encoders = new ArrayList<>(2);
encoders.add(MsgStringEncoder.class);
encoders.add(MsgByteEncoder.class);
return encoders;
}
@Override
public List<Class<? extends Decoder>> getDecoders() {
List<Class<? extends Decoder>> decoders = new ArrayList<>(2);
decoders.add(MsgStringDecoder.class);
decoders.add(MsgByteDecoder.class);
return decoders;
}
@Override
public List<String> getSubprotocols() {
return Collections.emptyList();
}
@Override
public String getPath() {
return PATH_PROGRAMMATIC_EP;
}
@Override
public List<Extension> getExtensions() {
return Collections.emptyList();
}
@Override
public Class<?> getEndpointClass() {
return ProgrammaticEndpoint.class;
}
@Override
public Configurator getConfigurator() {
return new ServerEndpointConfig.Configurator() {
};
}
});
} catch (DeploymentException e) {
throw new IllegalStateException(e);
}
}
}
@Test
public void testUnsupportedObject() throws Exception{
Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(ProgramaticServerEndpointConfig.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer();
tomcat.start();
Client client = new Client();
URI uri = new URI("ws://localhost:" + getPort() + PATH_PROGRAMMATIC_EP);
Session session = wsContainer.connectToServer(client, uri);
// This should fail
Object msg1 = new Object();
try {
session.getBasicRemote().sendObject(msg1);
Assert.fail("No exception thrown ");
} catch (EncodeException e) {
// Expected
} catch (Throwable t) {
Assert.fail("Wrong exception type");
} finally {
session.close();
}
}
}

View File

@@ -0,0 +1,150 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket.pojo;
import java.net.URI;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.websocket.ClientEndpoint;
import javax.websocket.ContainerProvider;
import javax.websocket.EndpointConfig;
import javax.websocket.OnClose;
import javax.websocket.OnError;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import javax.websocket.server.ServerEndpoint;
import org.junit.Assert;
import org.junit.Test;
import org.apache.catalina.Context;
import org.apache.catalina.servlets.DefaultServlet;
import org.apache.catalina.startup.Tomcat;
import org.apache.catalina.startup.TomcatBaseTest;
import org.apache.tomcat.websocket.TestUtil;
import org.apache.tomcat.websocket.pojo.TesterUtil.ServerConfigListener;
import org.apache.tomcat.websocket.pojo.TesterUtil.SingletonConfigurator;
public class TestPojoEndpointBase extends TomcatBaseTest {
@Test
public void testBug54716() throws Exception {
TestUtil.generateMask();
// Set up utility classes
Bug54716 server = new Bug54716();
SingletonConfigurator.setInstance(server);
ServerConfigListener.setPojoClazz(Bug54716.class);
Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(ServerConfigListener.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
WebSocketContainer wsContainer =
ContainerProvider.getWebSocketContainer();
tomcat.start();
Client client = new Client();
URI uri = new URI("ws://localhost:" + getPort() + "/");
wsContainer.connectToServer(client, uri);
// Server should close the connection after the exception on open.
boolean closed = client.waitForClose(5);
Assert.assertTrue("Server failed to close connection", closed);
}
@Test
public void testOnOpenPojoMethod() throws Exception {
// Set up utility classes
OnOpenServerEndpoint server = new OnOpenServerEndpoint();
SingletonConfigurator.setInstance(server);
ServerConfigListener.setPojoClazz(OnOpenServerEndpoint.class);
Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(ServerConfigListener.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
WebSocketContainer wsContainer =
ContainerProvider.getWebSocketContainer();
tomcat.start();
Client client = new Client();
URI uri = new URI("ws://localhost:" + getPort() + "/");
Session session = wsContainer.connectToServer(client, uri);
client.waitForClose(5);
Assert.assertTrue(session.isOpen());
}
@ServerEndpoint("/")
public static class OnOpenServerEndpoint {
@OnOpen
public void onOpen(@SuppressWarnings("unused") Session session,
EndpointConfig config) {
if (config == null) {
throw new RuntimeException();
}
}
@OnError
public void onError(@SuppressWarnings("unused") Throwable t){
throw new RuntimeException();
}
}
@ServerEndpoint("/")
public static class Bug54716 {
@OnOpen
public void onOpen() {
throw new RuntimeException();
}
}
@ClientEndpoint
public static final class Client {
private final CountDownLatch closeLatch = new CountDownLatch(1);
@OnClose
public void onClose() {
closeLatch.countDown();
}
public boolean waitForClose(int seconds) throws InterruptedException {
return closeLatch.await(seconds, TimeUnit.SECONDS);
}
}
}

View File

@@ -0,0 +1,149 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket.pojo;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import javax.websocket.ContainerProvider;
import javax.websocket.OnClose;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import org.junit.Assert;
import org.junit.Test;
import org.apache.catalina.Context;
import org.apache.catalina.servlets.DefaultServlet;
import org.apache.catalina.startup.Tomcat;
import org.apache.catalina.startup.TomcatBaseTest;
import org.apache.tomcat.websocket.pojo.TesterUtil.ServerConfigListener;
import org.apache.tomcat.websocket.pojo.TesterUtil.SimpleClient;
import org.apache.tomcat.websocket.pojo.TesterUtil.SingletonConfigurator;
public class TestPojoMethodMapping extends TomcatBaseTest {
private static final String PARAM_ONE = "abcde";
private static final String PARAM_TWO = "12345";
private static final String PARAM_THREE = "true";
@Test
public void test() throws Exception {
// Set up utility classes
Server server = new Server();
SingletonConfigurator.setInstance(server);
ServerConfigListener.setPojoClazz(Server.class);
Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(ServerConfigListener.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
WebSocketContainer wsContainer =
ContainerProvider.getWebSocketContainer();
tomcat.start();
SimpleClient client = new SimpleClient();
URI uri = new URI("ws://localhost:" + getPort() + "/" + PARAM_ONE +
"/" + PARAM_TWO + "/" + PARAM_THREE);
Session session = wsContainer.connectToServer(client, uri);
session.getBasicRemote().sendText("NO-OP");
session.close();
// Give server 20s to close. 5s should be plenty but the Gump VM is slow
int count = 0;
while (count < 200) {
if (server.isClosed()) {
break;
}
count++;
Thread.sleep(100);
}
if (count == 50) {
Assert.fail("Server did not process an onClose event within 5 " +
"seconds of the client sending a close message");
}
// Check no errors
List<String> errors = server.getErrors();
for (String error : errors) {
System.err.println(error);
}
Assert.assertEquals("Found errors", 0, errors.size());
}
@ServerEndpoint(value="/{one}/{two}/{three}",
configurator=SingletonConfigurator.class)
public static final class Server {
private final List<String> errors = new ArrayList<>();
private volatile boolean closed;
@OnOpen
public void onOpen(@PathParam("one") String p1, @PathParam("two")int p2,
@PathParam("three")boolean p3) {
checkParams("onOpen", p1, p2, p3);
}
@OnMessage
public void onMessage(@SuppressWarnings("unused") String msg,
@PathParam("one") String p1, @PathParam("two")int p2,
@PathParam("three")boolean p3) {
checkParams("onMessage", p1, p2, p3);
}
@OnClose
public void onClose(@PathParam("one") String p1,
@PathParam("two")int p2, @PathParam("three")boolean p3) {
checkParams("onClose", p1, p2, p3);
closed = true;
}
public List<String> getErrors() {
return errors;
}
public boolean isClosed() {
return closed;
}
private void checkParams(String method, String p1, int p2, boolean p3) {
checkParam(method, PARAM_ONE, p1);
checkParam(method, PARAM_TWO, Integer.toString(p2));
checkParam(method, PARAM_THREE, Boolean.toString(p3));
}
private void checkParam(String method, String expected, String actual) {
if (!expected.equals(actual)) {
errors.add("Method [" + method + "]. Expected [" + expected +
"] was + [" + actual + "]");
}
}
}
}

View File

@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket.pojo;
import javax.websocket.ClientEndpoint;
import javax.websocket.server.ServerEndpointConfig.Configurator;
import org.apache.tomcat.websocket.server.TesterEndpointConfig;
public class TesterUtil {
public static class ServerConfigListener extends TesterEndpointConfig {
private static Class<?> pojoClazz;
public static void setPojoClazz(Class<?> pojoClazz) {
ServerConfigListener.pojoClazz = pojoClazz;
}
@Override
protected Class<?> getEndpointClass() {
return pojoClazz;
}
}
public static class SingletonConfigurator extends Configurator {
private static Object instance;
public static void setInstance(Object instance) {
SingletonConfigurator.instance = instance;
}
@Override
public <T> T getEndpointInstance(Class<T> clazz)
throws InstantiationException {
@SuppressWarnings("unchecked")
T result = (T) instance;
return result;
}
}
@ClientEndpoint
public static final class SimpleClient {
}
}

View File

@@ -0,0 +1,151 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket.server;
import java.io.IOException;
import java.net.URI;
import java.util.concurrent.atomic.AtomicInteger;
import javax.websocket.ClientEndpoint;
import javax.websocket.ContainerProvider;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import javax.websocket.server.ServerEndpoint;
import org.junit.Assert;
import org.junit.Test;
import org.apache.catalina.Context;
import org.apache.catalina.loader.WebappClassLoaderBase;
import org.apache.catalina.servlets.DefaultServlet;
import org.apache.catalina.startup.Tomcat;
import org.apache.tomcat.websocket.WebSocketBaseTest;
/**
* Tests endpoint methods are called with the correct class loader.
*/
public class TestClassLoader extends WebSocketBaseTest {
private static final String PASS = "PASS";
private static final String FAIL = "FAIL";
/*
* Checks class loader for the server endpoint during onOpen and onMessage
*/
@Test
public void testSimple() throws Exception {
Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(Config.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
tomcat.start();
WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer();
Client client = new Client();
Session wsSession = wsContainer.connectToServer(client,
new URI("ws://localhost:" + getPort() + "/test"));
Assert.assertTrue(wsSession.isOpen());
// Wait up to 5s for a message
int count = 0;
while (count < 50 && client.getMsgCount() < 1) {
Thread.sleep(100);
}
// Check it
Assert.assertEquals(1, client.getMsgCount());
Assert.assertFalse(client.hasFailed());
wsSession.getBasicRemote().sendText("Testing");
// Wait up to 5s for a message
count = 0;
while (count < 50 && client.getMsgCount() < 2) {
Thread.sleep(100);
}
Assert.assertEquals(2, client.getMsgCount());
Assert.assertFalse(client.hasFailed());
wsSession.close();
}
@ClientEndpoint
public static class Client {
private final AtomicInteger msgCount = new AtomicInteger(0);
private boolean failed = false;
public boolean hasFailed() {
return failed;
}
public int getMsgCount() {
return msgCount.get();
}
@OnMessage
public void onMessage(String msg) {
if (!failed && !PASS.equals(msg)) {
failed = true;
}
msgCount.incrementAndGet();
}
}
@ServerEndpoint("/test")
public static class ClassLoaderEndpoint {
@OnOpen
public void onOpen(Session session) throws IOException {
if (Thread.currentThread().getContextClassLoader() instanceof WebappClassLoaderBase) {
session.getBasicRemote().sendText(PASS);
} else {
session.getBasicRemote().sendText(FAIL);
}
}
@OnMessage
public String onMessage(@SuppressWarnings("unused") String msg) {
if (Thread.currentThread().getContextClassLoader() instanceof WebappClassLoaderBase) {
return PASS;
} else {
return FAIL;
}
}
}
public static class Config extends TesterEndpointConfig {
@Override
protected Class<?> getEndpointClass() {
return ClassLoaderEndpoint.class;
}
}
}

View File

@@ -0,0 +1,339 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket.server;
import java.io.IOException;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.websocket.CloseReason;
import javax.websocket.CloseReason.CloseCode;
import javax.websocket.CloseReason.CloseCodes;
import javax.websocket.OnClose;
import javax.websocket.OnError;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.ServerEndpointConfig;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.Before;
import org.junit.Test;
import org.apache.juli.logging.Log;
import org.apache.juli.logging.LogFactory;
import org.apache.tomcat.websocket.WebSocketBaseTest;
/**
* Test the behavior of closing websockets under various conditions.
*/
public class TestClose extends WebSocketBaseTest {
private static Log log = LogFactory.getLog(TestClose.class);
// These are static because it is simpler than trying to inject them into
// the endpoint
private static volatile Events events;
public static class Events {
// Used to block in the @OnMessage
public final CountDownLatch onMessageWait = new CountDownLatch(1);
// Used to check which methods of a server endpoint were called
public final CountDownLatch onErrorCalled = new CountDownLatch(1);
public final CountDownLatch onMessageCalled = new CountDownLatch(1);
public final CountDownLatch onCloseCalled = new CountDownLatch(1);
// Parameter of an @OnClose call
public volatile CloseReason closeReason = null;
// Parameter of an @OnError call
public volatile Throwable onErrorThrowable = null;
//This is set to true for tests where the @OnMessage should send a message
public volatile boolean onMessageSends = false;
}
private static void awaitLatch(CountDownLatch latch, String failMessage) {
try {
if (!latch.await(5000, TimeUnit.MILLISECONDS)) {
Assert.fail(failMessage);
}
} catch (InterruptedException e) {
// Won't happen
throw new RuntimeException(e);
}
}
public static void awaitOnClose(CloseCode... codes) {
Set<CloseCode> set = new HashSet<>();
for (CloseCode code : codes) {
set.add(code);
}
awaitOnClose(set);
}
public static void awaitOnClose(Set<CloseCode> codes) {
awaitLatch(events.onCloseCalled, "onClose not called");
CloseCode received = events.closeReason.getCloseCode();
Assert.assertTrue("Rx: " + received, codes.contains(received));
}
public static void awaitOnError(Class<? extends Throwable> exceptionClazz) {
awaitLatch(events.onErrorCalled, "onError not called");
Assert.assertTrue(events.onErrorThrowable.getClass().getName(),
exceptionClazz.isAssignableFrom(events.onErrorThrowable.getClass()));
}
@Override
@Before
public void setUp() throws Exception {
super.setUp();
events = new Events();
}
@Test
public void testTcpClose() throws Exception {
// TODO
Assume.assumeFalse("This test currently fails for APR",
getTomcatInstance().getConnector().getProtocolHandlerClassName().contains("Apr"));
startServer(TestEndpointConfig.class);
TesterWsClient client = new TesterWsClient("localhost", getPort());
client.httpUpgrade(BaseEndpointConfig.PATH);
client.closeSocket();
awaitOnClose(CloseCodes.CLOSED_ABNORMALLY);
}
@Test
public void testTcpReset() throws Exception {
startServer(TestEndpointConfig.class);
TesterWsClient client = new TesterWsClient("localhost", getPort());
client.httpUpgrade(BaseEndpointConfig.PATH);
client.forceCloseSocket();
// TODO: I'm not entirely sure when onError should be called
awaitOnError(IOException.class);
awaitOnClose(CloseCodes.CLOSED_ABNORMALLY);
}
@Test
public void testWsCloseThenTcpClose() throws Exception {
startServer(TestEndpointConfig.class);
TesterWsClient client = new TesterWsClient("localhost", getPort());
client.httpUpgrade(BaseEndpointConfig.PATH);
client.sendCloseFrame(CloseCodes.GOING_AWAY);
client.closeSocket();
awaitOnClose(CloseCodes.GOING_AWAY);
}
@Test
public void testWsCloseThenTcpReset() throws Exception {
startServer(TestEndpointConfig.class);
TesterWsClient client = new TesterWsClient("localhost", getPort());
client.httpUpgrade(BaseEndpointConfig.PATH);
client.sendCloseFrame(CloseCodes.GOING_AWAY);
client.forceCloseSocket();
// WebSocket 1.1, section 2.1.5 requires this to be CLOSED_ABNORMALLY if
// the container initiates the close and the close code from the client
// if the client initiates it. When the client resets the TCP connection
// after sending the close, different operating systems react different
// ways. Some present the close message then drop the connection, some
// just drop the connection. Therefore, this test has to handle both
// close codes.
awaitOnClose(CloseCodes.CLOSED_ABNORMALLY, CloseCodes.GOING_AWAY);
}
@Test
public void testTcpCloseInOnMessage() throws Exception {
// TODO
Assume.assumeFalse("This test currently fails for APR",
getTomcatInstance().getConnector().getProtocolHandlerClassName().contains("Apr"));
startServer(TestEndpointConfig.class);
TesterWsClient client = new TesterWsClient("localhost", getPort());
client.httpUpgrade(BaseEndpointConfig.PATH);
client.sendTextMessage("Test");
awaitLatch(events.onMessageCalled, "onMessage not called");
client.closeSocket();
events.onMessageWait.countDown();
awaitOnClose(CloseCodes.CLOSED_ABNORMALLY);
}
@Test
public void testTcpResetInOnMessage() throws Exception {
startServer(TestEndpointConfig.class);
TesterWsClient client = new TesterWsClient("localhost", getPort());
client.httpUpgrade(BaseEndpointConfig.PATH);
client.sendTextMessage("Test");
awaitLatch(events.onMessageCalled, "onMessage not called");
client.forceCloseSocket();
events.onMessageWait.countDown();
awaitOnError(IOException.class);
awaitOnClose(CloseCodes.CLOSED_ABNORMALLY);
}
@Test
public void testTcpCloseWhenOnMessageSends() throws Exception {
events.onMessageSends = true;
testTcpCloseInOnMessage();
}
@Test
public void testTcpResetWhenOnMessageSends() throws Exception {
events.onMessageSends = true;
testTcpResetInOnMessage();
}
@Test
public void testWsCloseThenTcpCloseWhenOnMessageSends() throws Exception {
events.onMessageSends = true;
startServer(TestEndpointConfig.class);
TesterWsClient client = new TesterWsClient("localhost", getPort());
client.httpUpgrade(BaseEndpointConfig.PATH);
client.sendTextMessage("Test");
awaitLatch(events.onMessageCalled, "onMessage not called");
client.sendCloseFrame(CloseCodes.NORMAL_CLOSURE);
client.closeSocket();
events.onMessageWait.countDown();
awaitOnClose(CloseCodes.CLOSED_ABNORMALLY);
}
@Test
public void testWsCloseThenTcpResetWhenOnMessageSends() throws Exception {
events.onMessageSends = true;
startServer(TestEndpointConfig.class);
TesterWsClient client = new TesterWsClient("localhost", getPort());
client.httpUpgrade(BaseEndpointConfig.PATH);
client.sendTextMessage("Test");
awaitLatch(events.onMessageCalled, "onMessage not called");
client.sendCloseFrame(CloseCodes.NORMAL_CLOSURE);
client.forceCloseSocket();
events.onMessageWait.countDown();
awaitOnClose(CloseCodes.CLOSED_ABNORMALLY);
}
public static class TestEndpoint {
@OnOpen
public void onOpen() {
log.info("Session opened");
}
@OnMessage
public void onMessage(Session session, String message) {
log.info("Message received: " + message);
events.onMessageCalled.countDown();
awaitLatch(events.onMessageWait, "onMessageWait not triggered");
if (events.onMessageSends) {
try {
int count = 0;
// The latches above are meant to ensure the correct
// sequence of events but in some cases, particularly with
// APR, there is a short delay between the client closing /
// resetting the connection and the server recognising that
// fact. This loop tries to ensure that it lasts much longer
// than that delay so any close / reset from the client
// triggers an error here.
while (count < 10) {
count++;
session.getBasicRemote().sendText("Test reply");
Thread.sleep(500);
}
} catch (IOException | InterruptedException e) {
// Expected to fail
}
}
}
@OnError
public void onError(Throwable t) {
log.info("onError", t);
events.onErrorThrowable = t;
events.onErrorCalled.countDown();
}
@OnClose
public void onClose(CloseReason cr) {
log.info("onClose: " + cr);
events.closeReason = cr;
events.onCloseCalled.countDown();
}
}
public static class TestEndpointConfig extends BaseEndpointConfig {
@Override
protected Class<?> getEndpointClass() {
return TestEndpoint.class;
}
}
public abstract static class BaseEndpointConfig extends TesterEndpointConfig {
public static final String PATH = "/test";
@Override
protected ServerEndpointConfig getServerEndpointConfig() {
return ServerEndpointConfig.Builder.create(getEndpointClass(), PATH).build();
}
}
}

View File

@@ -0,0 +1,168 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket.server;
import java.net.URI;
import java.util.concurrent.atomic.AtomicInteger;
import javax.servlet.ServletContextEvent;
import javax.websocket.ClientEndpoint;
import javax.websocket.CloseReason;
import javax.websocket.ContainerProvider;
import javax.websocket.DeploymentException;
import javax.websocket.OnClose;
import javax.websocket.OnError;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import javax.websocket.server.ServerContainer;
import javax.websocket.server.ServerEndpointConfig;
import org.junit.Assert;
import org.junit.Test;
import org.apache.catalina.Context;
import org.apache.catalina.servlets.DefaultServlet;
import org.apache.catalina.startup.Tomcat;
import org.apache.tomcat.websocket.WebSocketBaseTest;
public class TestCloseBug58624 extends WebSocketBaseTest {
@Test
public void testOnErrorNotCalledWhenClosingConnection() throws Throwable {
Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(Bug58624ServerConfig.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer();
tomcat.start();
Bug58624ClientEndpoint client = new Bug58624ClientEndpoint();
URI uri = new URI("ws://localhost:" + getPort() + Bug58624ServerConfig.PATH);
Session session = wsContainer.connectToServer(client, uri);
// Wait for session to open on the server
int count = 0;
while (count < 50 && Bug58624ServerEndpoint.getOpenSessionCount() == 0) {
count++;
Thread.sleep(100);
}
Assert.assertNotEquals(0, Bug58624ServerEndpoint.getOpenSessionCount());
// Now close the session
session.close();
// Wait for session to close on the server
count = 0;
while (count < 50 && Bug58624ServerEndpoint.getOpenSessionCount() > 0) {
count++;
Thread.sleep(100);
}
Assert.assertEquals(0, Bug58624ServerEndpoint.getOpenSessionCount());
// Ensure no errors were reported on the server
Assert.assertEquals(0, Bug58624ServerEndpoint.getErrorCount());
if (client.getError() != null) {
throw client.getError();
}
}
@ClientEndpoint
public class Bug58624ClientEndpoint {
private volatile Throwable t;
@OnError
public void onError(Throwable t) {
this.t = t;
}
public Throwable getError() {
return this.t;
}
}
public static class Bug58624ServerConfig extends WsContextListener {
public static final String PATH = "/bug58624";
@Override
public void contextInitialized(ServletContextEvent sce) {
super.contextInitialized(sce);
ServerContainer sc = (ServerContainer) sce.getServletContext().getAttribute(
Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE);
ServerEndpointConfig sec = ServerEndpointConfig.Builder.create(
Bug58624ServerEndpoint.class, PATH).build();
try {
sc.addEndpoint(sec);
} catch (DeploymentException e) {
throw new RuntimeException(e);
}
}
}
public static class Bug58624ServerEndpoint {
private static AtomicInteger openSessionCount = new AtomicInteger(0);
private static AtomicInteger errorCount = new AtomicInteger(0);
public static int getOpenSessionCount() {
return openSessionCount.get();
}
public static int getErrorCount() {
return errorCount.get();
}
@OnOpen
public void onOpen() {
openSessionCount.incrementAndGet();
}
@OnMessage
public void onMessage(@SuppressWarnings("unused") Session session, String message) {
System.out.println("Received message " + message);
}
@OnError
public void onError(Throwable t) {
errorCount.incrementAndGet();
t.printStackTrace();
}
@OnClose
public void onClose(@SuppressWarnings("unused") CloseReason cr) {
openSessionCount.decrementAndGet();
}
}
}

View File

@@ -0,0 +1,111 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket.server;
import java.io.IOException;
import java.net.URI;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import javax.websocket.ClientEndpointConfig;
import javax.websocket.ContainerProvider;
import javax.websocket.EndpointConfig;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import javax.websocket.server.ServerEndpoint;
import org.junit.Assert;
import org.junit.Test;
import org.apache.catalina.Context;
import org.apache.catalina.servlets.DefaultServlet;
import org.apache.catalina.startup.Tomcat;
import org.apache.tomcat.websocket.TesterMessageCountClient.BasicText;
import org.apache.tomcat.websocket.TesterMessageCountClient.TesterProgrammaticEndpoint;
import org.apache.tomcat.websocket.WebSocketBaseTest;
/**
* Tests inspired by https://bz.apache.org/bugzilla/show_bug.cgi?id=58835 to
* check that WebSocket connections are closed gracefully on Tomcat shutdown.
*/
public class TestShutdown extends WebSocketBaseTest {
@Test
public void testShutdownBufferedMessages() throws Exception {
Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(EchoBufferedConfig.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
tomcat.start();
WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer();
ClientEndpointConfig clientEndpointConfig = ClientEndpointConfig.Builder.create().build();
Session wsSession = wsContainer.connectToServer(
TesterProgrammaticEndpoint.class,
clientEndpointConfig,
new URI("ws://localhost:" + getPort() + "/test"));
CountDownLatch latch = new CountDownLatch(1);
BasicText handler = new BasicText(latch);
wsSession.addMessageHandler(handler);
wsSession.getBasicRemote().sendText("Hello");
int count = 0;
while (count < 10 && EchoBufferedEndpoint.messageCount.get() == 0) {
Thread.sleep(200);
count++;
}
Assert.assertNotEquals("Message not received by server",
EchoBufferedEndpoint.messageCount.get(), 0);
tomcat.stop();
Assert.assertTrue("Latch expired waiting for message", latch.await(10, TimeUnit.SECONDS));
}
public static class EchoBufferedConfig extends TesterEndpointConfig {
@Override
protected Class<?> getEndpointClass() {
return EchoBufferedEndpoint.class;
}
}
@ServerEndpoint("/test")
public static class EchoBufferedEndpoint {
private static AtomicLong messageCount = new AtomicLong(0);
@OnOpen
public void onOpen(Session session, @SuppressWarnings("unused") EndpointConfig epc)
throws IOException {
session.getAsyncRemote().setBatchingAllowed(true);
}
@OnMessage
public void onMessage(Session session, String msg) throws IOException {
messageCount.incrementAndGet();
session.getBasicRemote().sendText(msg);
}
}
}

View File

@@ -0,0 +1,213 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket.server;
import java.util.Map;
import org.junit.Assert;
import org.junit.Test;
public class TestUriTemplate {
@Test
public void testBasic() throws Exception {
UriTemplate t = new UriTemplate("/{a}/{b}");
Map<String,String> result = t.match(new UriTemplate("/foo/bar"));
Assert.assertEquals(2, result.size());
Assert.assertTrue(result.containsKey("a"));
Assert.assertTrue(result.containsKey("b"));
Assert.assertEquals("foo", result.get("a"));
Assert.assertEquals("bar", result.get("b"));
}
@Test
public void testOneOfTwo() throws Exception {
UriTemplate t = new UriTemplate("/{a}/{b}");
Map<String,String> result = t.match(new UriTemplate("/foo"));
Assert.assertNull(result);
}
@Test(expected=java.lang.IllegalArgumentException.class)
public void testBasicPrefix() throws Exception {
@SuppressWarnings("unused")
UriTemplate t = new UriTemplate("/x{a}/y{b}");
}
@Test(expected=java.lang.IllegalArgumentException.class)
public void testPrefixOneOfTwo() throws Exception {
UriTemplate t = new UriTemplate("/x{a}/y{b}");
t.match(new UriTemplate("/xfoo"));
}
@Test(expected=java.lang.IllegalArgumentException.class)
public void testPrefixTwoOfTwo() throws Exception {
UriTemplate t = new UriTemplate("/x{a}/y{b}");
t.match(new UriTemplate("/ybar"));
}
@Test(expected=java.lang.IllegalArgumentException.class)
public void testQuote1() throws Exception {
UriTemplate t = new UriTemplate("/.{a}");
t.match(new UriTemplate("/yfoo"));
}
@Test(expected=java.lang.IllegalArgumentException.class)
public void testQuote2() throws Exception {
@SuppressWarnings("unused")
UriTemplate t = new UriTemplate("/.{a}");
}
@Test
public void testNoParams() throws Exception {
UriTemplate t = new UriTemplate("/foo/bar");
Map<String,String> result = t.match(new UriTemplate("/foo/bar"));
Assert.assertEquals(0, result.size());
}
@Test
public void testSpecExample1_01() throws Exception {
UriTemplate t = new UriTemplate("/a/b");
Map<String,String> result = t.match(new UriTemplate("/a/b"));
Assert.assertEquals(0, result.size());
}
@Test
public void testSpecExample1_02() throws Exception {
UriTemplate t = new UriTemplate("/a/b");
Map<String,String> result = t.match(new UriTemplate("/a"));
Assert.assertNull(result);
}
@Test
public void testSpecExample1_03() throws Exception {
UriTemplate t = new UriTemplate("/a/b");
Map<String,String> result = t.match(new UriTemplate("/a/bb"));
Assert.assertNull(result);
}
@Test
public void testSpecExample2_01() throws Exception {
UriTemplate t = new UriTemplate("/a/{var}");
Map<String,String> result = t.match(new UriTemplate("/a/b"));
Assert.assertEquals(1, result.size());
Assert.assertEquals("b", result.get("var"));
}
@Test
public void testSpecExample2_02() throws Exception {
UriTemplate t = new UriTemplate("/a/{var}");
Map<String,String> result = t.match(new UriTemplate("/a/apple"));
Assert.assertEquals(1, result.size());
Assert.assertEquals("apple", result.get("var"));
}
@Test
public void testSpecExample2_03() throws Exception {
UriTemplate t = new UriTemplate("/a/{var}");
Map<String,String> result = t.match(new UriTemplate("/a"));
Assert.assertNull(result);
}
@Test
public void testSpecExample2_04() throws Exception {
UriTemplate t = new UriTemplate("/a/{var}");
Map<String,String> result = t.match(new UriTemplate("/a/b/c"));
Assert.assertNull(result);
}
@Test(expected=java.lang.IllegalArgumentException.class)
public void testDuplicate01() throws Exception {
@SuppressWarnings("unused")
UriTemplate t = new UriTemplate("/{var}/{var}");
}
@Test
public void testDuplicate02() throws Exception {
UriTemplate t = new UriTemplate("/{a}/{b}");
Map<String,String> result = t.match(new UriTemplate("/x/x"));
Assert.assertEquals(2, result.size());
Assert.assertEquals("x", result.get("a"));
Assert.assertEquals("x", result.get("b"));
}
public void testEgMailingList01() throws Exception {
UriTemplate t = new UriTemplate("/a/{var}");
Map<String,String> result = t.match(new UriTemplate("/a/b/"));
Assert.assertNull(result);
}
public void testEgMailingList02() throws Exception {
UriTemplate t = new UriTemplate("/a/{var}");
Map<String,String> result = t.match(new UriTemplate("/a/"));
Assert.assertNull(result);
}
@Test
public void testEgMailingList03() throws Exception {
UriTemplate t = new UriTemplate("/a/{var}");
Map<String,String> result = t.match(new UriTemplate("/a"));
Assert.assertNull(result);
}
@Test(expected=java.lang.IllegalArgumentException.class)
public void testEgMailingList04() throws Exception {
UriTemplate t = new UriTemplate("/a/{var1}/{var2}");
@SuppressWarnings("unused")
Map<String,String> result = t.match(new UriTemplate("/a//c"));
}
@Test(expected=java.lang.IllegalArgumentException.class)
public void testEgMailingList05() throws Exception {
UriTemplate t = new UriTemplate("/a/{var}/");
@SuppressWarnings("unused")
Map<String,String> result = t.match(new UriTemplate("/a/b/"));
}
}

View File

@@ -0,0 +1,168 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket.server;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import javax.websocket.CloseReason;
import javax.websocket.ContainerProvider;
import javax.websocket.EncodeException;
import javax.websocket.Encoder;
import javax.websocket.EndpointConfig;
import javax.websocket.OnClose;
import javax.websocket.OnError;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import javax.websocket.server.ServerEndpointConfig;
import org.junit.Ignore;
import org.junit.Test;
import org.apache.catalina.Context;
import org.apache.catalina.servlets.DefaultServlet;
import org.apache.catalina.startup.Tomcat;
import org.apache.tomcat.websocket.WebSocketBaseTest;
import org.apache.tomcat.websocket.pojo.TesterUtil.SimpleClient;
@Ignore // This test requires manual intervention to create breakpoints etc.
public class TestWsRemoteEndpointImplServer extends WebSocketBaseTest {
/*
* https://bz.apache.org/bugzilla/show_bug.cgi?id=58624
*
* This test requires three breakpoints to be set. Two in this file (marked
* A & B with comments) and one (C) at the start of
* WsRemoteEndpointImplServer.doWrite().
*
* With the breakpoints in place, run this test.
* Once breakpoints A & B are reached, progress the thread at breakpoint A
* one line to close the connection.
* Once breakpoint C is reached, allow the thread at breakpoint B to
* continue.
* Then allow the thread at breakpoint C to continue.
*
* In the failure mode, the thread at breakpoint B will not progress past
* the call to sendObject(). If the issue is fixed, the thread at breakpoint
* B will continue past sendObject() and terminate with a TimeoutException.
*/
@Test
public void testClientDropsConnection() throws Exception {
Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(Bug58624Config.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
WebSocketContainer wsContainer =
ContainerProvider.getWebSocketContainer();
tomcat.start();
SimpleClient client = new SimpleClient();
URI uri = new URI("ws://localhost:" + getPort() + Bug58624Config.PATH);
Session session = wsContainer.connectToServer(client, uri);
// Break point A required on following line
session.close();
}
public static class Bug58624Config extends TesterEndpointConfig {
public static final String PATH = "/bug58624";
@Override
protected ServerEndpointConfig getServerEndpointConfig() {
List<Class<? extends Encoder>> encoders = new ArrayList<>();
encoders.add(Bug58624Encoder.class);
return ServerEndpointConfig.Builder.create(
Bug58624Endpoint.class, PATH).encoders(encoders).build();
}
}
public static class Bug58624Endpoint {
private static final ExecutorService ex = Executors.newFixedThreadPool(1);
@OnOpen
public void onOpen(Session session) {
// Disabling blocking timeouts for this test
session.getUserProperties().put(
org.apache.tomcat.websocket.Constants.BLOCKING_SEND_TIMEOUT_PROPERTY,
Long.valueOf(-1));
ex.submit(new Bug58624SendMessage(session));
}
@OnMessage
public void onMessage(String message) {
System.out.println("OnMessage: " + message);
}
@OnError
public void onError(Throwable t) {
System.err.println("OnError:");
t.printStackTrace();
}
@OnClose
public void onClose(@SuppressWarnings("unused") Session session, CloseReason cr) {
System.out.println("Closed " + cr);
}
}
public static class Bug58624SendMessage implements Runnable {
private Session session;
public Bug58624SendMessage(Session session) {
this.session = session;
}
@Override
public void run() {
try {
// Breakpoint B required on following line
session.getBasicRemote().sendObject("test");
} catch (IOException | EncodeException e) {
e.printStackTrace();
}
}
}
public static class Bug58624Encoder implements Encoder.Text<Object> {
@Override
public void destroy() {
}
@Override
public void init(EndpointConfig endpointConfig) {
}
@Override
public String encode(Object object) throws EncodeException {
return (String) object;
}
}
}

View File

@@ -0,0 +1,340 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket.server;
import java.net.URI;
import java.util.Queue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.websocket.ContainerProvider;
import javax.websocket.DeploymentException;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import javax.websocket.server.ServerEndpoint;
import javax.websocket.server.ServerEndpointConfig;
import org.junit.Assert;
import org.junit.Test;
import org.apache.catalina.Context;
import org.apache.catalina.LifecycleState;
import org.apache.catalina.servlets.DefaultServlet;
import org.apache.catalina.startup.Tomcat;
import org.apache.tomcat.unittest.TesterServletContext;
import org.apache.tomcat.websocket.TesterEchoServer;
import org.apache.tomcat.websocket.TesterMessageCountClient.BasicText;
import org.apache.tomcat.websocket.WebSocketBaseTest;
import org.apache.tomcat.websocket.pojo.TesterUtil.SimpleClient;
public class TestWsServerContainer extends WebSocketBaseTest {
@Test
public void testBug54807() throws Exception {
Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(Bug54807Config.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
tomcat.start();
Assert.assertEquals(LifecycleState.STARTED, ctx.getState());
}
@Test
public void testBug58232() throws Exception {
Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(Bug54807Config.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
WebSocketContainer wsContainer =
ContainerProvider.getWebSocketContainer();
tomcat.start();
Assert.assertEquals(LifecycleState.STARTED, ctx.getState());
SimpleClient client = new SimpleClient();
URI uri = new URI("ws://localhost:" + getPort() + "/echoBasic");
try (Session session = wsContainer.connectToServer(client, uri);) {
CountDownLatch latch = new CountDownLatch(1);
BasicText handler = new BasicText(latch);
session.addMessageHandler(handler);
session.getBasicRemote().sendText("echoBasic");
boolean latchResult = handler.getLatch().await(10, TimeUnit.SECONDS);
Assert.assertTrue(latchResult);
Queue<String> messages = handler.getMessages();
Assert.assertEquals(1, messages.size());
for (String message : messages) {
Assert.assertEquals("echoBasic", message);
}
}
}
public static class Bug54807Config extends TesterEndpointConfig {
@Override
protected ServerEndpointConfig getServerEndpointConfig() {
return ServerEndpointConfig.Builder.create(
TesterEchoServer.Basic.class, "/{param}").build();
}
}
@Test
public void testSpecExample3() throws Exception {
WsServerContainer sc = new WsServerContainer(new TesterServletContext());
ServerEndpointConfig configA = ServerEndpointConfig.Builder.create(
Object.class, "/a/{var}/c").build();
ServerEndpointConfig configB = ServerEndpointConfig.Builder.create(
Object.class, "/a/b/c").build();
ServerEndpointConfig configC = ServerEndpointConfig.Builder.create(
Object.class, "/a/{var1}/{var2}").build();
sc.addEndpoint(configA);
sc.addEndpoint(configB);
sc.addEndpoint(configC);
Assert.assertEquals(configB, sc.findMapping("/a/b/c").getConfig());
Assert.assertEquals(configA, sc.findMapping("/a/d/c").getConfig());
Assert.assertEquals(configC, sc.findMapping("/a/x/y").getConfig());
}
@Test
public void testSpecExample4() throws Exception {
WsServerContainer sc = new WsServerContainer(new TesterServletContext());
ServerEndpointConfig configA = ServerEndpointConfig.Builder.create(
Object.class, "/{var1}/d").build();
ServerEndpointConfig configB = ServerEndpointConfig.Builder.create(
Object.class, "/b/{var2}").build();
sc.addEndpoint(configA);
sc.addEndpoint(configB);
Assert.assertEquals(configB, sc.findMapping("/b/d").getConfig());
}
@Test(expected = DeploymentException.class)
public void testDuplicatePaths01() throws Exception {
WsServerContainer sc = new WsServerContainer(new TesterServletContext());
ServerEndpointConfig configA = ServerEndpointConfig.Builder.create(
Object.class, "/a/b/c").build();
ServerEndpointConfig configB = ServerEndpointConfig.Builder.create(
Object.class, "/a/b/c").build();
sc.addEndpoint(configA);
sc.addEndpoint(configB);
}
@Test(expected = DeploymentException.class)
public void testDuplicatePaths02() throws Exception {
WsServerContainer sc = new WsServerContainer(new TesterServletContext());
ServerEndpointConfig configA = ServerEndpointConfig.Builder.create(
Object.class, "/a/b/{var}").build();
ServerEndpointConfig configB = ServerEndpointConfig.Builder.create(
Object.class, "/a/b/{var}").build();
sc.addEndpoint(configA);
sc.addEndpoint(configB);
}
@Test(expected = DeploymentException.class)
public void testDuplicatePaths03() throws Exception {
WsServerContainer sc = new WsServerContainer(new TesterServletContext());
ServerEndpointConfig configA = ServerEndpointConfig.Builder.create(
Object.class, "/a/b/{var1}").build();
ServerEndpointConfig configB = ServerEndpointConfig.Builder.create(
Object.class, "/a/b/{var2}").build();
sc.addEndpoint(configA);
sc.addEndpoint(configB);
}
@Test
public void testDuplicatePaths04() throws Exception {
WsServerContainer sc = new WsServerContainer(new TesterServletContext());
ServerEndpointConfig configA = ServerEndpointConfig.Builder.create(
Object.class, "/a/{var1}/{var2}").build();
ServerEndpointConfig configB = ServerEndpointConfig.Builder.create(
Object.class, "/a/b/{var2}").build();
sc.addEndpoint(configA);
sc.addEndpoint(configB);
Assert.assertEquals(configA, sc.findMapping("/a/x/y").getConfig());
Assert.assertEquals(configB, sc.findMapping("/a/b/y").getConfig());
}
/*
* Simulates a class that gets picked up for extending Endpoint and for
* being annotated.
*/
@Test(expected = DeploymentException.class)
public void testDuplicatePaths11() throws Exception {
WsServerContainer sc = new WsServerContainer(new TesterServletContext());
ServerEndpointConfig configA = ServerEndpointConfig.Builder.create(
Pojo.class, "/foo").build();
sc.addEndpoint(configA, false);
sc.addEndpoint(Pojo.class, true);
}
/*
* POJO auto deployment followed by programmatic duplicate. Keep POJO.
*/
@Test
public void testDuplicatePaths12() throws Exception {
WsServerContainer sc = new WsServerContainer(new TesterServletContext());
ServerEndpointConfig configA = ServerEndpointConfig.Builder.create(
Pojo.class, "/foo").build();
sc.addEndpoint(Pojo.class, true);
sc.addEndpoint(configA);
Assert.assertNotEquals(configA, sc.findMapping("/foo").getConfig());
}
/*
* POJO programmatic followed by programmatic duplicate.
*/
@Test(expected = DeploymentException.class)
public void testDuplicatePaths13() throws Exception {
WsServerContainer sc = new WsServerContainer(new TesterServletContext());
ServerEndpointConfig configA = ServerEndpointConfig.Builder.create(
Pojo.class, "/foo").build();
sc.addEndpoint(Pojo.class);
sc.addEndpoint(configA);
}
/*
* POJO auto deployment followed by programmatic on same path.
*/
@Test(expected = DeploymentException.class)
public void testDuplicatePaths14() throws Exception {
WsServerContainer sc = new WsServerContainer(new TesterServletContext());
ServerEndpointConfig configA = ServerEndpointConfig.Builder.create(
Object.class, "/foo").build();
sc.addEndpoint(Pojo.class, true);
sc.addEndpoint(configA);
}
/*
* Simulates a class that gets picked up for extending Endpoint and for
* being annotated.
*/
@Test(expected = DeploymentException.class)
public void testDuplicatePaths21() throws Exception {
WsServerContainer sc = new WsServerContainer(new TesterServletContext());
ServerEndpointConfig configA = ServerEndpointConfig.Builder.create(
PojoTemplate.class, "/foo/{a}").build();
sc.addEndpoint(configA, false);
sc.addEndpoint(PojoTemplate.class, true);
}
/*
* POJO auto deployment followed by programmatic duplicate. Keep POJO.
*/
@Test
public void testDuplicatePaths22() throws Exception {
WsServerContainer sc = new WsServerContainer(new TesterServletContext());
ServerEndpointConfig configA = ServerEndpointConfig.Builder.create(
PojoTemplate.class, "/foo/{a}").build();
sc.addEndpoint(PojoTemplate.class, true);
sc.addEndpoint(configA);
Assert.assertNotEquals(configA, sc.findMapping("/foo/{a}").getConfig());
}
/*
* POJO programmatic followed by programmatic duplicate.
*/
@Test(expected = DeploymentException.class)
public void testDuplicatePaths23() throws Exception {
WsServerContainer sc = new WsServerContainer(new TesterServletContext());
ServerEndpointConfig configA = ServerEndpointConfig.Builder.create(
PojoTemplate.class, "/foo/{a}").build();
sc.addEndpoint(PojoTemplate.class);
sc.addEndpoint(configA);
}
/*
* POJO auto deployment followed by programmatic on same path.
*/
@Test(expected = DeploymentException.class)
public void testDuplicatePaths24() throws Exception {
WsServerContainer sc = new WsServerContainer(new TesterServletContext());
ServerEndpointConfig configA = ServerEndpointConfig.Builder.create(
Object.class, "/foo/{a}").build();
sc.addEndpoint(PojoTemplate.class, true);
sc.addEndpoint(configA);
}
@ServerEndpoint("/foo")
public static class Pojo {
}
@ServerEndpoint("/foo/{a}")
public static class PojoTemplate {
}
}

View File

@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket.server;
import javax.servlet.ServletContextEvent;
import javax.websocket.DeploymentException;
import javax.websocket.server.ServerContainer;
import javax.websocket.server.ServerEndpointConfig;
public abstract class TesterEndpointConfig extends WsContextListener {
@Override
public void contextInitialized(ServletContextEvent sce) {
super.contextInitialized(sce);
ServerContainer sc = (ServerContainer) sce.getServletContext().getAttribute(
Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE);
try {
ServerEndpointConfig sec = getServerEndpointConfig();
if (sec == null) {
sc.addEndpoint(getEndpointClass());
} else {
sc.addEndpoint(sec);
}
} catch (DeploymentException e) {
throw new RuntimeException(e);
}
}
protected Class<?> getEndpointClass() {
return null;
}
protected ServerEndpointConfig getServerEndpointConfig() {
return null;
}
}

View File

@@ -0,0 +1,131 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket.server;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import javax.websocket.CloseReason.CloseCode;
/**
* A client for testing Websocket behavior that differs from standard client
* behavior.
*/
public class TesterWsClient {
private static final byte[] maskingKey = new byte[] { 0x12, 0x34, 0x56,
0x78 };
private final Socket socket;
public TesterWsClient(String host, int port) throws Exception {
this.socket = new Socket(host, port);
// Set read timeout in case of failure so test doesn't hang
socket.setSoTimeout(2000);
// Disable Nagle's algorithm to ensure packets sent immediately
// TODO: Hoping this causes writes to wait for a TCP ACK for TCP RST
// test cases but I'm not sure?
socket.setTcpNoDelay(true);
}
public void httpUpgrade(String path) throws IOException {
String req = createUpgradeRequest(path);
write(req.getBytes(StandardCharsets.UTF_8));
readUpgradeResponse();
}
public void sendTextMessage(String text) throws IOException {
sendTextMessage(text.getBytes(StandardCharsets.UTF_8));
}
public void sendTextMessage(byte[] utf8Bytes) throws IOException {
write(createFrame(true, 1, utf8Bytes));
}
public void sendCloseFrame(CloseCode closeCode) throws IOException {
int code = closeCode.getCode();
byte[] codeBytes = new byte[2];
codeBytes[0] = (byte) (code >> 8);
codeBytes[1] = (byte) code;
write(createFrame(true, 8, codeBytes));
}
private void readUpgradeResponse() throws IOException {
BufferedReader in = new BufferedReader(new InputStreamReader(
socket.getInputStream()));
String line = in.readLine();
while (line != null && !line.isEmpty()) {
line = in.readLine();
}
}
public void closeSocket() throws IOException {
// Enable SO_LINGER to ensure close() only returns when TCP closing
// handshake completes
socket.setSoLinger(true, 65535);
socket.close();
}
/*
* Send a TCP RST instead of a TCP closing handshake
*/
public void forceCloseSocket() throws IOException {
// SO_LINGER sends a TCP RST when timeout expires
socket.setSoLinger(true, 0);
socket.close();
}
private void write(byte[] bytes) throws IOException {
socket.getOutputStream().write(bytes);
socket.getOutputStream().flush();
}
private static String createUpgradeRequest(String path) {
String[] upgradeRequestLines = { "GET " + path + " HTTP/1.1",
"Connection: Upgrade", "Host: localhost:8080",
"Origin: localhost:8080",
"Sec-WebSocket-Key: OEvAoAKn5jsuqv2/YJ1Wfg==",
"Sec-WebSocket-Version: 13", "Upgrade: websocket" };
StringBuffer sb = new StringBuffer();
for (String line : upgradeRequestLines) {
sb.append(line);
sb.append("\r\n");
}
sb.append("\r\n");
return sb.toString();
}
private static byte[] createFrame(boolean fin, int opCode, byte[] payload) {
byte[] frame = new byte[6 + payload.length];
frame[0] = (byte) (opCode | (fin ? 1 << 7 : 0));
frame[1] = (byte) (0x80 | payload.length);
frame[2] = maskingKey[0];
frame[3] = maskingKey[1];
frame[4] = maskingKey[2];
frame[5] = maskingKey[3];
for (int i = 0; i < payload.length; i++) {
frame[i + 6] = (byte) (payload[i] ^ maskingKey[i % 4]);
}
return frame;
}
}