diff --git a/jetty-websocket/jetty-websocket-tests/src/test/java/org/eclipse/jetty/websocket/tests/WebSocketOverHTTP2Test.java b/jetty-websocket/jetty-websocket-tests/src/test/java/org/eclipse/jetty/websocket/tests/WebSocketOverHTTP2Test.java index f028132af0a8..16351aa533f7 100644 --- a/jetty-websocket/jetty-websocket-tests/src/test/java/org/eclipse/jetty/websocket/tests/WebSocketOverHTTP2Test.java +++ b/jetty-websocket/jetty-websocket-tests/src/test/java/org/eclipse/jetty/websocket/tests/WebSocketOverHTTP2Test.java @@ -22,6 +22,7 @@ import java.io.InterruptedIOException; import java.net.ConnectException; import java.net.URI; +import java.nio.channels.ClosedChannelException; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.function.Function; @@ -45,6 +46,7 @@ import org.eclipse.jetty.server.HttpChannel; import org.eclipse.jetty.server.HttpConfiguration; import org.eclipse.jetty.server.HttpConnectionFactory; +import org.eclipse.jetty.server.Request; import org.eclipse.jetty.server.SecureRequestCustomizer; import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.ServerConnector; @@ -61,6 +63,7 @@ import org.eclipse.jetty.websocket.server.JettyWebSocketServlet; import org.eclipse.jetty.websocket.server.JettyWebSocketServletFactory; import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; +import org.eclipse.jetty.websocket.servlet.internal.UpgradeHttpServletRequest; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -234,7 +237,8 @@ protected void service(HttpServletRequest request, HttpServletResponse response) assertTrue(wsEndPoint.closeLatch.await(5, TimeUnit.SECONDS)); } - @Test void testWebSocketConnectPortDoesNotExist() throws Exception + @Test + public void testWebSocketConnectPortDoesNotExist() throws Exception { startServer(); startClient(clientConnector -> new ClientConnectionFactoryOverHTTP2.H2(new HTTP2Client(clientConnector))); @@ -250,7 +254,8 @@ protected void service(HttpServletRequest request, HttpServletResponse response) assertThat(cause.getMessage(), containsStringIgnoringCase("Connection refused")); } - @Test void testWebSocketNotFound() throws Exception + @Test + public void testWebSocketNotFound() throws Exception { startServer(); startClient(clientConnector -> new ClientConnectionFactoryOverHTTP2.H2(new HTTP2Client(clientConnector))); @@ -266,7 +271,8 @@ protected void service(HttpServletRequest request, HttpServletResponse response) assertThat(cause.getMessage(), containsStringIgnoringCase("Unexpected HTTP Response Status Code: 501")); } - @Test void testNotNegotiated() throws Exception + @Test + public void testNotNegotiated() throws Exception { startServer(); startClient(clientConnector -> new ClientConnectionFactoryOverHTTP2.H2(new HTTP2Client(clientConnector))); @@ -282,7 +288,8 @@ protected void service(HttpServletRequest request, HttpServletResponse response) assertThat(cause.getMessage(), containsStringIgnoringCase("Unexpected HTTP Response Status Code: 503")); } - @Test void testThrowFromCreator() throws Exception + @Test + public void testThrowFromCreator() throws Exception { startServer(); startClient(clientConnector -> new ClientConnectionFactoryOverHTTP2.H2(new HTTP2Client(clientConnector))); @@ -302,6 +309,22 @@ protected void service(HttpServletRequest request, HttpServletResponse response) assertThat(cause.getMessage(), containsStringIgnoringCase("Unexpected HTTP Response Status Code: 500")); } + @Test + public void testServerConnectionClose() throws Exception + { + startServer(); + startClient(clientConnector -> new ClientConnectionFactoryOverHTTP2.H2(new HTTP2Client(clientConnector))); + + EventSocket wsEndPoint = new EventSocket(); + URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + "/ws/connectionClose"); + + ExecutionException failure = Assertions.assertThrows(ExecutionException.class, () -> + wsClient.connect(wsEndPoint, uri).get(5, TimeUnit.SECONDS)); + + Throwable cause = failure.getCause(); + assertThat(cause, instanceOf(ClosedChannelException.class)); + } + private static class TestJettyWebSocketServlet extends JettyWebSocketServlet { @Override @@ -313,6 +336,13 @@ protected void configure(JettyWebSocketServletFactory factory) { throw new RuntimeException("throwing from creator"); }); + factory.addMapping("/ws/connectionClose", (request, response) -> + { + UpgradeHttpServletRequest servletRequest = (UpgradeHttpServletRequest)request.getHttpServletRequest(); + Request baseRequest = servletRequest.getBaseRequest(); + baseRequest.getHttpChannel().getEndPoint().close(); + return new EchoSocket(); + }); } } } diff --git a/jetty-websocket/websocket-servlet/src/main/java/org/eclipse/jetty/websocket/servlet/internal/UpgradeHttpServletRequest.java b/jetty-websocket/websocket-servlet/src/main/java/org/eclipse/jetty/websocket/servlet/internal/UpgradeHttpServletRequest.java index 85fa5b968e6b..5d13c47b74e9 100644 --- a/jetty-websocket/websocket-servlet/src/main/java/org/eclipse/jetty/websocket/servlet/internal/UpgradeHttpServletRequest.java +++ b/jetty-websocket/websocket-servlet/src/main/java/org/eclipse/jetty/websocket/servlet/internal/UpgradeHttpServletRequest.java @@ -29,6 +29,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; import java.util.TreeMap; import javax.servlet.AsyncContext; import javax.servlet.DispatcherType; @@ -55,6 +56,7 @@ public class UpgradeHttpServletRequest implements HttpServletRequest { private static final String UNSUPPORTED_WITH_WEBSOCKET_UPGRADE = "Feature unsupported with a Upgraded to WebSocket HttpServletRequest"; + private final Request baseRequest; private final ServletContext context; private final DispatcherType dispatcher; private final String method; @@ -110,8 +112,9 @@ public UpgradeHttpServletRequest(HttpServletRequest httpRequest) remoteUser = httpRequest.getRemoteUser(); principal = httpRequest.getUserPrincipal(); - authentication = Request.getBaseRequest(httpRequest).getAuthentication(); - scope = Request.getBaseRequest(httpRequest).getUserIdentityScope(); + baseRequest = Objects.requireNonNull(Request.getBaseRequest(httpRequest)); + authentication = baseRequest.getAuthentication(); + scope = baseRequest.getUserIdentityScope(); Enumeration headerNames = httpRequest.getHeaderNames(); while (headerNames.hasMoreElements()) @@ -278,6 +281,11 @@ public HttpSession getSession() return session; } + public Request getBaseRequest() + { + return baseRequest; + } + @Override public String getRequestedSessionId() {