Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue #4502 - allow changing of close response from jetty and javax websocket onClose events #4523

Merged
merged 3 commits into from
Jan 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,7 @@ public void block(long timeout, TimeUnit unit) throws IOException
{
Throwable cause = e.getCause();
if (cause instanceof RuntimeException)
throw (RuntimeException)cause;
else if (cause instanceof IOException)
throw (IOException)cause;
throw new RuntimeException(cause);
lachlan-roberts marked this conversation as resolved.
Show resolved Hide resolved
else
throw new IOException(cause);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import javax.websocket.ClientEndpointConfig;
import javax.websocket.CloseReason;
Expand Down Expand Up @@ -60,6 +62,8 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
private final Logger logger;
private final JavaxWebSocketContainer container;
private final Object endpointInstance;
private final AtomicBoolean closeNotified = new AtomicBoolean();

/**
* List of configured named variables in the uri-template.
* <p>
Expand Down Expand Up @@ -278,24 +282,41 @@ public void onFrame(Frame frame, Callback callback)
dataType = OpCode.UNDEFINED;
}

public void onClose(Frame frame, Callback callback)
lachlan-roberts marked this conversation as resolved.
Show resolved Hide resolved
{
notifyOnClose(CloseStatus.getCloseStatus(frame), callback);
}

@Override
public void onClosed(CloseStatus closeStatus, Callback callback)
{
notifyOnClose(closeStatus, callback);
container.notifySessionListeners((listener) -> listener.onJavaxWebSocketSessionClosed(session));
sbordet marked this conversation as resolved.
Show resolved Hide resolved
}

private void notifyOnClose(CloseStatus closeStatus, Callback callback)
{
// Make sure onClose is only notified once.
if (!closeNotified.compareAndSet(false, true))
lachlan-roberts marked this conversation as resolved.
Show resolved Hide resolved
{
callback.failed(new ClosedChannelException());
return;
}

try
{
if (closeHandle != null)
{
CloseReason closeReason = new CloseReason(CloseReason.CloseCodes.getCloseCode(closeStatus.getCode()), closeStatus.getReason());
closeHandle.invoke(closeReason);
}

callback.succeeded();
}
catch (Throwable cause)
{
callback.failed(new WebSocketException(endpointInstance.getClass().getSimpleName() + " CLOSE method error: " + cause.getMessage(), cause));
}

container.notifySessionListeners((listener) -> listener.onJavaxWebSocketSessionClosed(session));
}

@Override
Expand Down Expand Up @@ -572,11 +593,6 @@ private void acceptMessage(Frame frame, Callback callback)
activeMessageSink = null;
}

public void onClose(Frame frame, Callback callback)
{
callback.succeeded();
}

public void onPing(Frame frame, Callback callback)
{
ByteBuffer payload = BufferUtil.copy(frame.getPayload());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ public class EventSocket

public CountDownLatch openLatch = new CountDownLatch(1);
public CountDownLatch closeLatch = new CountDownLatch(1);
public CountDownLatch errorLatch = new CountDownLatch(1);

@OnOpen
public void onOpen(Session session, EndpointConfig endpointConfig)
Expand Down Expand Up @@ -85,5 +86,6 @@ public void onError(Throwable cause)
if (LOG.isDebugEnabled())
LOG.debug("{} onError(): {}", toString(), cause);
error = cause;
errorLatch.countDown();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
//
// ========================================================================
// Copyright (c) 1995-2020 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under
// the terms of the Eclipse Public License 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0
//
// This Source Code may also be made available under the following
// Secondary Licenses when the conditions for such availability set
// forth in the Eclipse Public License, v. 2.0 are satisfied:
// the Apache License v2.0 which is available at
// https://www.apache.org/licenses/LICENSE-2.0
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//

package org.eclipse.jetty.websocket.javax.tests;

import java.io.IOException;
import java.net.URI;
import java.nio.channels.ClosedChannelException;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import javax.websocket.ClientEndpoint;
import javax.websocket.CloseReason;
import javax.websocket.CloseReason.CloseCodes;
import javax.websocket.EndpointConfig;
import javax.websocket.Session;
import javax.websocket.server.ServerEndpoint;

import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.util.BlockingArrayQueue;
import org.eclipse.jetty.websocket.javax.client.JavaxWebSocketClientContainer;
import org.eclipse.jetty.websocket.javax.server.config.JavaxWebSocketServletContainerInitializer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

public class JavaxOnCloseTest
{
private static BlockingArrayQueue<OnCloseEndpoint> serverEndpoints = new BlockingArrayQueue<>();

private Server server;
private ServerConnector connector;
private JavaxWebSocketClientContainer client = new JavaxWebSocketClientContainer();

@ServerEndpoint("/")
public static class OnCloseEndpoint extends EventSocket
{
private Consumer<Session> onClose;

public void setOnClose(Consumer<Session> onClose)
{
this.onClose = onClose;
}

@Override
public void onOpen(Session session, EndpointConfig endpointConfig)
{
super.onOpen(session, endpointConfig);
serverEndpoints.add(this);
}

@Override
public void onClose(CloseReason reason)
{
super.onClose(reason);
onClose.accept(session);
}
}

@ClientEndpoint
public static class BlockingClientEndpoint extends EventSocket
{
private CountDownLatch blockInClose = new CountDownLatch(1);

public void unBlockClose()
{
blockInClose.countDown();
}

@Override
public void onClose(CloseReason reason)
{
try
{
blockInClose.await();
super.onClose(reason);
}
catch (InterruptedException e)
{
throw new RuntimeException(e);
}
}
}

@BeforeEach
public void start() throws Exception
{
server = new Server();
connector = new ServerConnector(server);
connector.setPort(0);
server.addConnector(connector);

ServletContextHandler contextHandler = new ServletContextHandler(ServletContextHandler.SESSIONS);
contextHandler.setContextPath("/");
server.setHandler(contextHandler);

JavaxWebSocketServletContainerInitializer.configure(contextHandler, ((servletContext, container) ->
container.addEndpoint(OnCloseEndpoint.class)));

client.start();
server.start();
}

@AfterEach
public void stop() throws Exception
{
client.stop();
server.stop();
}

@Test
public void changeStatusCodeInOnClose() throws Exception
{
EventSocket clientEndpoint = new EventSocket();
URI uri = new URI("ws://localhost:" + connector.getLocalPort() + "/");
client.connectToServer(clientEndpoint, uri);

OnCloseEndpoint serverEndpoint = Objects.requireNonNull(serverEndpoints.poll(5, TimeUnit.SECONDS));
serverEndpoint.setOnClose((session) -> assertDoesNotThrow(() ->
session.close(new CloseReason(CloseCodes.SERVICE_RESTART, "custom close reason"))));
assertTrue(serverEndpoint.openLatch.await(5, TimeUnit.SECONDS));

clientEndpoint.session.close();
assertTrue(clientEndpoint.closeLatch.await(5, TimeUnit.SECONDS));
assertThat(clientEndpoint.closeReason.getCloseCode(), is(CloseCodes.SERVICE_RESTART));
assertThat(clientEndpoint.closeReason.getReasonPhrase(), is("custom close reason"));
}

@Test
public void secondCloseFromOnCloseFails() throws Exception
{
EventSocket clientEndpoint = new EventSocket();
URI uri = new URI("ws://localhost:" + connector.getLocalPort() + "/");
client.connectToServer(clientEndpoint, uri);

OnCloseEndpoint serverEndpoint = Objects.requireNonNull(serverEndpoints.poll(5, TimeUnit.SECONDS));
assertTrue(serverEndpoint.openLatch.await(5, TimeUnit.SECONDS));
serverEndpoint.setOnClose((session) -> assertThrows(ClosedChannelException.class, session::close));

serverEndpoint.session.close(new CloseReason(CloseCodes.NORMAL_CLOSURE, "first close"));
assertTrue(clientEndpoint.closeLatch.await(5, TimeUnit.SECONDS));
assertThat(clientEndpoint.closeReason.getCloseCode(), is(CloseCodes.NORMAL_CLOSURE));
assertThat(clientEndpoint.closeReason.getReasonPhrase(), is("first close"));
}

@Test
public void abnormalStatusDoesNotChange() throws Exception
{
BlockingClientEndpoint clientEndpoint = new BlockingClientEndpoint();
URI uri = new URI("ws://localhost:" + connector.getLocalPort() + "/");
client.connectToServer(clientEndpoint, uri);

OnCloseEndpoint serverEndpoint = Objects.requireNonNull(serverEndpoints.poll(5, TimeUnit.SECONDS));
assertTrue(serverEndpoint.openLatch.await(5, TimeUnit.SECONDS));
serverEndpoint.setOnClose((session) ->
{
IOException error = assertThrows(IOException.class,
() -> session.close(new CloseReason(CloseCodes.UNEXPECTED_CONDITION, "abnormal close 2")));
assertThat(error.getCause(), instanceOf(ClosedChannelException.class));
clientEndpoint.unBlockClose();
});

serverEndpoint.session.close(new CloseReason(CloseCodes.PROTOCOL_ERROR, "abnormal close 1"));
assertTrue(clientEndpoint.closeLatch.await(5, TimeUnit.SECONDS));
assertThat(clientEndpoint.closeReason.getCloseCode(), is(CloseCodes.PROTOCOL_ERROR));
assertThat(clientEndpoint.closeReason.getReasonPhrase(), is("abnormal close 1"));
}

@Test
public void onErrorOccurringAfterOnClose() throws Exception
{
EventSocket clientEndpoint = new EventSocket();
URI uri = new URI("ws://localhost:" + connector.getLocalPort() + "/");
client.connectToServer(clientEndpoint, uri);

OnCloseEndpoint serverEndpoint = Objects.requireNonNull(serverEndpoints.poll(5, TimeUnit.SECONDS));
assertTrue(serverEndpoint.openLatch.await(5, TimeUnit.SECONDS));
serverEndpoint.setOnClose((session) ->
{
throw new RuntimeException("trigger onError from onClose");
});

try
{
clientEndpoint.session.close();
}
catch (IOException e)
{
// Ignore. This only occurs in the rare case where the
// close response is received while we are still sending the message.
}

assertTrue(clientEndpoint.closeLatch.await(5, TimeUnit.SECONDS));
assertThat(clientEndpoint.closeReason.getCloseCode(), is(CloseCodes.UNEXPECTED_CONDITION));
assertThat(clientEndpoint.closeReason.getReasonPhrase(), containsString("trigger onError from onClose"));

assertTrue(serverEndpoint.errorLatch.await(5, TimeUnit.SECONDS));
assertThat(serverEndpoint.error, instanceOf(RuntimeException.class));
assertThat(serverEndpoint.error.getMessage(), containsString("trigger onError from onClose"));
}
}
Loading