Files
apache-tomcat-8.5.51/test/org/apache/tomcat/websocket/TestWsWebSocketContainerGetOpenSessions.java
2024-11-30 19:03:49 +08:00

389 lines
12 KiB
Java

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import javax.servlet.ServletContextEvent;
import javax.websocket.ClientEndpointConfig;
import javax.websocket.CloseReason;
import javax.websocket.ContainerProvider;
import javax.websocket.DeploymentException;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.MessageHandler;
import javax.websocket.OnMessage;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import javax.websocket.server.ServerContainer;
import javax.websocket.server.ServerEndpoint;
import javax.websocket.server.ServerEndpointConfig;
import org.junit.Assert;
import org.junit.Test;
import org.apache.catalina.Context;
import org.apache.catalina.servlets.DefaultServlet;
import org.apache.catalina.startup.Tomcat;
import org.apache.tomcat.websocket.server.Constants;
import org.apache.tomcat.websocket.server.WsContextListener;
/*
* This method is split out into a separate class to make it easier to track the
* various permutations and combinations of client and server endpoints.
*
* Each test uses 2 client endpoints and 2 server endpoints with each client
* connecting to each server for a total of four connections (note sometimes
* the two clients and/or the two servers will be the same).
*/
public class TestWsWebSocketContainerGetOpenSessions extends WebSocketBaseTest {
/**
 * Two {@link ClientEndpointA} instances each connect twice to the annotated
 * endpoint at {@code /pojoA}. Expects a recorded open session count of 2 for
 * each client and 4 for {@code pojoA}.
 */
@Test
public void testClientAClientAPojoAPojoA() throws Exception {
    doTest(new ClientEndpointA(), new ClientEndpointA(),
            "/pojoA", "/pojoA", 2, 2, 4, 4);
}
/**
 * A {@link ClientEndpointA} and a {@link ClientEndpointB} each connect twice
 * to the annotated endpoint at {@code /pojoA}. Expects a recorded open
 * session count of 2 for each client and 4 for {@code pojoA}.
 */
@Test
public void testClientAClientBPojoAPojoA() throws Exception {
    doTest(new ClientEndpointA(), new ClientEndpointB(),
            "/pojoA", "/pojoA", 2, 2, 4, 4);
}
/**
 * Two {@link ClientEndpointA} instances each connect once to {@code /pojoA}
 * and once to {@code /pojoB}. Expects a recorded open session count of 2 for
 * each client and 2 for each of the two annotated server endpoints.
 */
@Test
public void testClientAClientAPojoAPojoB() throws Exception {
    doTest(new ClientEndpointA(), new ClientEndpointA(),
            "/pojoA", "/pojoB", 2, 2, 2, 2);
}
/**
 * A {@link ClientEndpointA} and a {@link ClientEndpointB} each connect once
 * to {@code /pojoA} and once to {@code /pojoB}. Expects a recorded open
 * session count of 2 for each client and 2 for each of the two annotated
 * server endpoints.
 */
@Test
public void testClientAClientBPojoAPojoB() throws Exception {
    doTest(new ClientEndpointA(), new ClientEndpointB(),
            "/pojoA", "/pojoB", 2, 2, 2, 2);
}
/**
 * Two {@link ClientEndpointA} instances each connect twice to the
 * programmatic endpoint at {@code /progA}. Expects a recorded open session
 * count of 2 for each client and 4 for {@code progA}.
 */
@Test
public void testClientAClientAProgAProgA() throws Exception {
    doTest(new ClientEndpointA(), new ClientEndpointA(),
            "/progA", "/progA", 2, 2, 4, 4);
}
/**
 * A {@link ClientEndpointA} and a {@link ClientEndpointB} each connect twice
 * to the programmatic endpoint at {@code /progA}. Expects a recorded open
 * session count of 2 for each client and 4 for {@code progA}.
 */
@Test
public void testClientAClientBProgAProgA() throws Exception {
    doTest(new ClientEndpointA(), new ClientEndpointB(),
            "/progA", "/progA", 2, 2, 4, 4);
}
/**
 * Two {@link ClientEndpointA} instances each connect once to {@code /progA}
 * and once to {@code /progB}. Expects a recorded open session count of 2 for
 * each client and 2 for each of the two programmatic server endpoints.
 */
@Test
public void testClientAClientAProgAProgB() throws Exception {
    doTest(new ClientEndpointA(), new ClientEndpointA(),
            "/progA", "/progB", 2, 2, 2, 2);
}
/**
 * A {@link ClientEndpointA} and a {@link ClientEndpointB} each connect once
 * to {@code /progA} and once to {@code /progB}. Expects a recorded open
 * session count of 2 for each client and 2 for each of the two programmatic
 * server endpoints.
 */
@Test
public void testClientAClientBProgAProgB() throws Exception {
    doTest(new ClientEndpointA(), new ClientEndpointB(),
            "/progA", "/progB", 2, 2, 2, 2);
}
/**
 * Two {@link ClientEndpointA} instances each connect once to the annotated
 * endpoint at {@code /pojoA} and once to the programmatic endpoint at
 * {@code /progA}. Expects a recorded open session count of 2 for each client
 * and 2 for each server endpoint.
 */
@Test
public void testClientAClientAPojoAProgA() throws Exception {
    doTest(new ClientEndpointA(), new ClientEndpointA(),
            "/pojoA", "/progA", 2, 2, 2, 2);
}
/**
 * A {@link ClientEndpointA} and a {@link ClientEndpointB} each connect once
 * to the annotated endpoint at {@code /pojoA} and once to the programmatic
 * endpoint at {@code /progA}. Expects a recorded open session count of 2 for
 * each client and 2 for each server endpoint.
 */
@Test
public void testClientAClientBPojoAProgA() throws Exception {
    doTest(new ClientEndpointA(), new ClientEndpointB(),
            "/pojoA", "/progA", 2, 2, 2, 2);
}
/**
 * Two {@link ClientEndpointA} instances each connect once to the annotated
 * endpoint at {@code /pojoA} and once to the programmatic endpoint at
 * {@code /progB}. Expects a recorded open session count of 2 for each client
 * and 2 for each server endpoint.
 */
@Test
public void testClientAClientAPojoAProgB() throws Exception {
    doTest(new ClientEndpointA(), new ClientEndpointA(),
            "/pojoA", "/progB", 2, 2, 2, 2);
}
/**
 * A {@link ClientEndpointA} and a {@link ClientEndpointB} each connect once
 * to the annotated endpoint at {@code /pojoA} and once to the programmatic
 * endpoint at {@code /progB}. Expects a recorded open session count of 2 for
 * each client and 2 for each server endpoint.
 */
@Test
public void testClientAClientBPojoAProgB() throws Exception {
    doTest(new ClientEndpointA(), new ClientEndpointB(),
            "/pojoA", "/progB", 2, 2, 2, 2);
}
/**
 * Starts Tomcat with the test endpoints deployed, opens the four
 * client/server connections and checks the open session count that each
 * party last recorded in the {@link Tracker}.
 *
 * @param client1      endpoint used for the two connections tracked as
 *                     {@code client1}
 * @param client2      endpoint used for the two connections tracked as
 *                     {@code client2}
 * @param server1      path of the first server endpoint, including the
 *                     leading '/'
 * @param server2      path of the second server endpoint, including the
 *                     leading '/'
 * @param client1Count expected last recorded count for {@code client1}
 * @param client2Count expected last recorded count for {@code client2}
 * @param server1Count expected last recorded count for {@code server1}
 * @param server2Count expected last recorded count for {@code server2}
 *
 * @throws Exception if Tomcat fails to start, a connection cannot be
 *                   established or the wait is interrupted
 */
private void doTest(Endpoint client1, Endpoint client2, String server1, String server2,
        int client1Count, int client2Count, int server1Count, int server2Count) throws Exception {
    Tracker.reset();
    Tomcat tomcat = getTomcatInstance();
    // No file system docBase required
    Context ctx = tomcat.addContext("", null);
    ctx.addApplicationListener(Config.class.getName());
    Tomcat.addServlet(ctx, "default", new DefaultServlet());
    ctx.addServletMappingDecoded("/", "default");
    tomcat.start();

    WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer();

    // try-with-resources ensures every session that was opened is closed
    // again, even when a later connection attempt or an assertion fails.
    // The sessions are closed in reverse order of declaration.
    try (Session sClient1Server1 = createSession(wsContainer, client1, "client1", server1);
            Session sClient1Server2 = createSession(wsContainer, client1, "client1", server2);
            Session sClient2Server1 = createSession(wsContainer, client2, "client2", server1);
            Session sClient2Server2 = createSession(wsContainer, client2, "client2", server2)) {

        // Four records are added by the clients in createSession() and one
        // more by the server side for each of the four messages received,
        // giving eight updates in total.
        // Wait for up to 20s for this to complete. It should be a lot
        // faster but some CI systems can be slow at times.
        int delayCount = 0;
        while (Tracker.getUpdateCount() < 8 && delayCount < 400) {
            Thread.sleep(50);
            delayCount++;
        }

        Assert.assertTrue("Unexpected open session count for client1",
                Tracker.checkRecord("client1", client1Count));
        Assert.assertTrue("Unexpected open session count for client2",
                Tracker.checkRecord("client2", client2Count));
        // Note: need to strip leading '/' from path
        Assert.assertTrue("Unexpected open session count for " + server1,
                Tracker.checkRecord(server1.substring(1), server1Count));
        Assert.assertTrue("Unexpected open session count for " + server2,
                Tracker.checkRecord(server2.substring(1), server2Count));
    }
}
/**
 * Connects {@code client} to the given server path on the local test port,
 * records the size of the new session's open session set under
 * {@code clientName} and then sends the single text message {@code "X"}.
 * The server endpoints in this class add their own record when a message
 * arrives.
 *
 * @param wsContainer container used to create the connection
 * @param client      client endpoint instance to connect
 * @param clientName  key under which the client side count is recorded
 * @param server      server endpoint path, including the leading '/'
 *
 * @return the newly opened session; the caller is responsible for closing it
 */
private Session createSession(WebSocketContainer wsContainer, Endpoint client,
        String clientName, String server)
        throws DeploymentException, IOException, URISyntaxException {
    ClientEndpointConfig clientConfig = ClientEndpointConfig.Builder.create().build();
    URI target = new URI("ws://localhost:" + getPort() + server);
    Session session = wsContainer.connectToServer(client, clientConfig, target);

    int openSessionCount = session.getOpenSessions().size();
    Tracker.addRecord(clientName, openSessionCount);

    session.getBasicRemote().sendText("X");
    return session;
}
/**
 * Context listener that, after the standard WebSocket initialisation,
 * deploys the two annotated endpoints and the two programmatic endpoints
 * (at {@code /progA} and {@code /progB}) used by these tests.
 */
public static class Config extends WsContextListener {

    @Override
    public void contextInitialized(ServletContextEvent sce) {
        super.contextInitialized(sce);

        Object attribute = sce.getServletContext().getAttribute(
                Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE);
        ServerContainer serverContainer = (ServerContainer) attribute;

        try {
            // Annotated (POJO) endpoints
            serverContainer.addEndpoint(PojoEndpointA.class);
            serverContainer.addEndpoint(PojoEndpointB.class);

            // Programmatic endpoints
            ServerEndpointConfig progAConfig =
                    ServerEndpointConfig.Builder.create(ServerEndpointA.class, "/progA").build();
            serverContainer.addEndpoint(progAConfig);

            ServerEndpointConfig progBConfig =
                    ServerEndpointConfig.Builder.create(ServerEndpointB.class, "/progB").build();
            serverContainer.addEndpoint(progBConfig);
        } catch (DeploymentException deploymentFailure) {
            // Surface deployment problems as an unchecked failure
            throw new IllegalStateException(deploymentFailure);
        }
    }
}
/**
 * Shared behaviour for the annotated server endpoints: every text message
 * received causes the size of the session's open session set to be recorded
 * in the {@link Tracker} under the subclass' tracking name.
 */
public abstract static class PojoEndpointBase {

    /**
     * @return the key under which this endpoint records its counts
     */
    protected abstract String getTrackingName();

    @OnMessage
    public void onMessage(@SuppressWarnings("unused") String msg, Session session) {
        String trackingName = getTrackingName();
        int openSessionCount = session.getOpenSessions().size();
        Tracker.addRecord(trackingName, openSessionCount);
    }
}
/**
 * Annotated server endpoint mapped to {@code /pojoA}. Supplies the tracking
 * name {@code pojoA} to {@link PojoEndpointBase}.
 */
@ServerEndpoint("/pojoA")
public static class PojoEndpointA extends PojoEndpointBase {

    private static final String TRACKING_NAME = "pojoA";

    @Override
    protected String getTrackingName() {
        return TRACKING_NAME;
    }
}
/**
 * Annotated server endpoint mapped to {@code /pojoB}. Supplies the tracking
 * name {@code pojoB} to {@link PojoEndpointBase}.
 */
@ServerEndpoint("/pojoB")
public static class PojoEndpointB extends PojoEndpointBase {

    private static final String TRACKING_NAME = "pojoB";

    @Override
    protected String getTrackingName() {
        return TRACKING_NAME;
    }
}
/**
 * Shared behaviour for the programmatic server endpoints: when a session
 * opens, a {@link TrackerMessageHandler} bound to that session and to the
 * subclass' tracking name is registered on it.
 */
public abstract static class ServerEndpointBase extends Endpoint {

    /**
     * @return the key under which this endpoint records its counts
     */
    protected abstract String getTrackingName();

    @Override
    public void onOpen(Session session, EndpointConfig config) {
        TrackerMessageHandler handler =
                new TrackerMessageHandler(session, getTrackingName());
        session.addMessageHandler(handler);
    }
}
/**
 * Programmatic server endpoint whose tracking name is {@code progA}.
 * {@link Config} deploys it at {@code /progA}.
 */
public static final class ServerEndpointA extends ServerEndpointBase {

    private static final String TRACKING_NAME = "progA";

    @Override
    protected String getTrackingName() {
        return TRACKING_NAME;
    }
}
/**
 * Programmatic server endpoint whose tracking name is {@code progB}.
 * {@link Config} deploys it at {@code /progB}.
 */
public static final class ServerEndpointB extends ServerEndpointBase {

    private static final String TRACKING_NAME = "progB";

    @Override
    protected String getTrackingName() {
        return TRACKING_NAME;
    }
}
/**
 * Whole text message handler that ignores the message content and instead
 * records, for every message received, the size of its session's open
 * session set in the {@link Tracker} under a fixed tracking name.
 */
public static final class TrackerMessageHandler implements MessageHandler.Whole<String> {

    private final Session session;
    private final String trackingName;

    /**
     * @param session      session whose open session set is measured
     * @param trackingName key under which the counts are recorded
     */
    public TrackerMessageHandler(Session session, String trackingName) {
        this.session = session;
        this.trackingName = trackingName;
    }

    @Override
    public void onMessage(String message) {
        int openSessionCount = session.getOpenSessions().size();
        Tracker.addRecord(trackingName, openSessionCount);
    }
}
/**
 * Base class for the client endpoints used by these tests. Both lifecycle
 * callbacks are deliberately empty; subclasses only provide a tracking name.
 */
public abstract static class ClientEndpointBase extends Endpoint {

    /**
     * @return the tracking name associated with this client endpoint type
     */
    protected abstract String getTrackingName();

    @Override
    public void onClose(Session session, CloseReason closeReason) {
        // Nothing to do when the session closes
    }

    @Override
    public void onOpen(Session session, EndpointConfig config) {
        // Nothing to do when the session opens
    }
}
/**
 * Client endpoint whose tracking name is {@code clientA}.
 */
public static final class ClientEndpointA extends ClientEndpointBase {

    private static final String TRACKING_NAME = "clientA";

    @Override
    protected String getTrackingName() {
        return TRACKING_NAME;
    }
}
/**
 * Client endpoint whose tracking name is {@code clientB}.
 */
public static final class ClientEndpointB extends ClientEndpointBase {

    private static final String TRACKING_NAME = "clientB";

    @Override
    protected String getTrackingName() {
        return TRACKING_NAME;
    }
}
/**
 * Static store for the counts reported during a test. For each key only the
 * most recently added value is kept; the total number of additions since the
 * last reset is tracked separately.
 */
public static class Tracker {

    private static final Map<String, Integer> records = new ConcurrentHashMap<>();
    private static final AtomicInteger updateCount = new AtomicInteger(0);

    /**
     * Clears all stored values and sets the update counter back to zero.
     */
    public static void reset() {
        records.clear();
        updateCount.set(0);
    }

    /**
     * Stores {@code count} for {@code key}, replacing any earlier value, and
     * increments the update counter.
     */
    public static void addRecord(String key, int count) {
        records.put(key, Integer.valueOf(count));
        updateCount.incrementAndGet();
    }

    /**
     * @return the number of {@link #addRecord(String, int)} calls since the
     *         last {@link #reset()}
     */
    public static int getUpdateCount() {
        return updateCount.get();
    }

    /**
     * Compares the value stored for {@code key} with {@code expectedCount}.
     * A key with no stored value only matches an expected count of zero.
     *
     * @return {@code true} if the stored value matches the expectation
     */
    public static boolean checkRecord(String key, int expectedCount) {
        Integer recorded = records.get(key);
        if (recorded == null) {
            return expectedCount == 0;
        }
        return recorded.intValue() == expectedCount;
    }
}
}