Skip to content

Commit

Permalink
Issue #3428 - Initial refactor to support javax websocket decoderLists
Browse files Browse the repository at this point in the history
Signed-off-by: Lachlan Roberts <lachlan@webtide.com>
  • Loading branch information
lachlan-roberts committed May 19, 2020
1 parent aa52d67 commit 4b19c19
Show file tree
Hide file tree
Showing 26 changed files with 954 additions and 667 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,10 @@ public EndpointConfig newDefaultEndpointConfig(Class<?> endpointClass)
public JavaxWebSocketFrameHandlerMetadata getMetadata(Class<?> endpointClass, EndpointConfig endpointConfig)
{
if (javax.websocket.Endpoint.class.isAssignableFrom(endpointClass))
{
return createEndpointMetadata((Class<? extends Endpoint>)endpointClass, endpointConfig);
}

if (endpointClass.getAnnotation(ClientEndpoint.class) == null)
{
return null;
}

JavaxWebSocketFrameHandlerMetadata metadata = new JavaxWebSocketFrameHandlerMetadata(endpointConfig);
return discoverJavaxFrameHandlerMetadata(endpointClass, metadata);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
//
// ========================================================================
// Copyright (c) 1995-2020 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under
// the terms of the Eclipse Public License 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0
//
// This Source Code may also be made available under the following
// Secondary Licenses when the conditions for such availability set
// forth in the Eclipse Public License, v. 2.0 are satisfied:
// the Apache License v2.0 which is available at
// https://www.apache.org/licenses/LICENSE-2.0
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//

package org.eclipse.jetty.websocket.javax.common;

import java.io.InputStream;
import java.io.Reader;
import java.nio.ByteBuffer;
import javax.websocket.PongMessage;
import javax.websocket.Session;

import org.eclipse.jetty.websocket.util.InvokerUtils;

// The different kind of @OnMessage method parameter signatures expected.
public class JavaxWebSocketCallingArgs
{
static final InvokerUtils.Arg[] textCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(String.class).required()
};

static final InvokerUtils.Arg[] textPartialCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(String.class).required(),
new InvokerUtils.Arg(boolean.class).required()
};

static final InvokerUtils.Arg[] binaryBufferCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(ByteBuffer.class).required()
};

static final InvokerUtils.Arg[] binaryPartialBufferCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(ByteBuffer.class).required(),
new InvokerUtils.Arg(boolean.class).required()
};

static final InvokerUtils.Arg[] binaryArrayCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(byte[].class).required()
};

static final InvokerUtils.Arg[] binaryPartialArrayCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(byte[].class).required(),
new InvokerUtils.Arg(boolean.class).required()
};

static final InvokerUtils.Arg[] inputStreamCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(InputStream.class).required()
};

static final InvokerUtils.Arg[] readerCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(Reader.class).required()
};

static final InvokerUtils.Arg[] pongCallingArgs = new InvokerUtils.Arg[]{
new InvokerUtils.Arg(Session.class),
new InvokerUtils.Arg(PongMessage.class).required()
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.lang.invoke.MethodType;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand All @@ -46,6 +47,7 @@
import org.eclipse.jetty.websocket.core.exception.ProtocolException;
import org.eclipse.jetty.websocket.core.exception.WebSocketException;
import org.eclipse.jetty.websocket.javax.common.decoders.AvailableDecoders;
import org.eclipse.jetty.websocket.javax.common.decoders.RegisteredDecoder;
import org.eclipse.jetty.websocket.javax.common.messages.DecodedBinaryMessageSink;
import org.eclipse.jetty.websocket.javax.common.messages.DecodedBinaryStreamMessageSink;
import org.eclipse.jetty.websocket.javax.common.messages.DecodedTextMessageSink;
Expand Down Expand Up @@ -95,9 +97,9 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
private MethodHandle openHandle;
private MethodHandle closeHandle;
private MethodHandle errorHandle;
private JavaxWebSocketFrameHandlerMetadata.MessageMetadata textMetadata;
private JavaxWebSocketFrameHandlerMetadata.MessageMetadata binaryMetadata;
private MethodHandle pongHandle;
private JavaxWebSocketMessageMetadata textMetadata;
private JavaxWebSocketMessageMetadata binaryMetadata;

private UpgradeRequest upgradeRequest;

Expand All @@ -114,8 +116,8 @@ public class JavaxWebSocketFrameHandler implements FrameHandler
public JavaxWebSocketFrameHandler(JavaxWebSocketContainer container,
Object endpointInstance,
MethodHandle openHandle, MethodHandle closeHandle, MethodHandle errorHandle,
JavaxWebSocketFrameHandlerMetadata.MessageMetadata textMetadata,
JavaxWebSocketFrameHandlerMetadata.MessageMetadata binaryMetadata,
JavaxWebSocketMessageMetadata textMetadata,
JavaxWebSocketMessageMetadata binaryMetadata,
MethodHandle pongHandle,
EndpointConfig endpointConfig)
{
Expand Down Expand Up @@ -170,26 +172,32 @@ public void onOpen(CoreSession coreSession, Callback callback)
errorHandle = InvokerUtils.bindTo(errorHandle, session);
pongHandle = InvokerUtils.bindTo(pongHandle, session);

JavaxWebSocketFrameHandlerMetadata.MessageMetadata actualTextMetadata = JavaxWebSocketFrameHandlerMetadata.MessageMetadata.copyOf(textMetadata);
JavaxWebSocketMessageMetadata actualTextMetadata = JavaxWebSocketMessageMetadata.copyOf(textMetadata);
if (actualTextMetadata != null)
{
if (actualTextMetadata.isMaxMessageSizeSet())
session.setMaxTextMessageBufferSize(actualTextMetadata.maxMessageSize);
session.setMaxTextMessageBufferSize(actualTextMetadata.getMaxMessageSize());

MethodHandle methodHandle = actualTextMetadata.getMethodHandle();
methodHandle = InvokerUtils.bindTo(methodHandle, endpointInstance, endpointConfig, session);
methodHandle = JavaxWebSocketFrameHandlerFactory.wrapNonVoidReturnType(methodHandle, session);
actualTextMetadata.setMethodHandle(methodHandle);

actualTextMetadata.handle = InvokerUtils.bindTo(actualTextMetadata.handle, endpointInstance, endpointConfig, session);
actualTextMetadata.handle = JavaxWebSocketFrameHandlerFactory.wrapNonVoidReturnType(actualTextMetadata.handle, session);
textSink = JavaxWebSocketFrameHandlerFactory.createMessageSink(session, actualTextMetadata);
textMetadata = actualTextMetadata;
}

JavaxWebSocketFrameHandlerMetadata.MessageMetadata actualBinaryMetadata = JavaxWebSocketFrameHandlerMetadata.MessageMetadata.copyOf(binaryMetadata);
JavaxWebSocketMessageMetadata actualBinaryMetadata = JavaxWebSocketMessageMetadata.copyOf(binaryMetadata);
if (actualBinaryMetadata != null)
{
if (actualBinaryMetadata.isMaxMessageSizeSet())
session.setMaxBinaryMessageBufferSize(actualBinaryMetadata.maxMessageSize);
session.setMaxBinaryMessageBufferSize(actualBinaryMetadata.getMaxMessageSize());

MethodHandle methodHandle = actualBinaryMetadata.getMethodHandle();
methodHandle = InvokerUtils.bindTo(methodHandle, endpointInstance, endpointConfig, session);
methodHandle = JavaxWebSocketFrameHandlerFactory.wrapNonVoidReturnType(methodHandle, session);
actualBinaryMetadata.setMethodHandle(methodHandle);

actualBinaryMetadata.handle = InvokerUtils.bindTo(actualBinaryMetadata.handle, endpointInstance, endpointConfig, session);
actualBinaryMetadata.handle = JavaxWebSocketFrameHandlerFactory.wrapNonVoidReturnType(actualBinaryMetadata.handle, session);
binarySink = JavaxWebSocketFrameHandlerFactory.createMessageSink(session, actualBinaryMetadata);
binaryMetadata = actualBinaryMetadata;
}
Expand Down Expand Up @@ -350,12 +358,12 @@ public Map<Byte, RegisteredMessageHandler> getMessageHandlerMap()
return messageHandlerMap;
}

public JavaxWebSocketFrameHandlerMetadata.MessageMetadata getBinaryMetadata()
public JavaxWebSocketMessageMetadata getBinaryMetadata()
{
return binaryMetadata;
}

public JavaxWebSocketFrameHandlerMetadata.MessageMetadata getTextMetadata()
public JavaxWebSocketMessageMetadata getTextMetadata()
{
return textMetadata;
}
Expand All @@ -369,7 +377,7 @@ private void assertBasicTypeNotRegistered(byte basicWebSocketType, Object messag
}
}

public <T> void addMessageHandler(JavaxWebSocketSession session, Class<T> clazz, MessageHandler.Partial<T> handler)
public <T> void addMessageHandler(Class<T> clazz, MessageHandler.Partial<T> handler)
{
try
{
Expand All @@ -384,29 +392,29 @@ public <T> void addMessageHandler(JavaxWebSocketSession session, Class<T> clazz,
assertBasicTypeNotRegistered(OpCode.BINARY, this.binaryMetadata, handler.getClass().getName());
MessageSink messageSink = new PartialByteArrayMessageSink(coreSession, partialMessageHandler);
this.binarySink = registerMessageHandler(OpCode.BINARY, clazz, handler, messageSink);
JavaxWebSocketFrameHandlerMetadata.MessageMetadata metadata = new JavaxWebSocketFrameHandlerMetadata.MessageMetadata();
metadata.handle = partialMessageHandler;
metadata.sinkClass = PartialByteArrayMessageSink.class;
JavaxWebSocketMessageMetadata metadata = new JavaxWebSocketMessageMetadata();
metadata.setMethodHandle(partialMessageHandler);
metadata.setSinkClass(PartialByteArrayMessageSink.class);
this.binaryMetadata = metadata;
}
else if (ByteBuffer.class.isAssignableFrom(clazz))
{
assertBasicTypeNotRegistered(OpCode.BINARY, this.binaryMetadata, handler.getClass().getName());
MessageSink messageSink = new PartialByteBufferMessageSink(coreSession, partialMessageHandler);
this.binarySink = registerMessageHandler(OpCode.BINARY, clazz, handler, messageSink);
JavaxWebSocketFrameHandlerMetadata.MessageMetadata metadata = new JavaxWebSocketFrameHandlerMetadata.MessageMetadata();
metadata.handle = partialMessageHandler;
metadata.sinkClass = PartialByteBufferMessageSink.class;
JavaxWebSocketMessageMetadata metadata = new JavaxWebSocketMessageMetadata();
metadata.setMethodHandle(partialMessageHandler);
metadata.setSinkClass(PartialByteBufferMessageSink.class);
this.binaryMetadata = metadata;
}
else if (String.class.isAssignableFrom(clazz))
{
assertBasicTypeNotRegistered(OpCode.TEXT, this.textMetadata, handler.getClass().getName());
MessageSink messageSink = new PartialStringMessageSink(coreSession, partialMessageHandler);
this.textSink = registerMessageHandler(OpCode.TEXT, clazz, handler, messageSink);
JavaxWebSocketFrameHandlerMetadata.MessageMetadata metadata = new JavaxWebSocketFrameHandlerMetadata.MessageMetadata();
metadata.handle = partialMessageHandler;
metadata.sinkClass = PartialStringMessageSink.class;
JavaxWebSocketMessageMetadata metadata = new JavaxWebSocketMessageMetadata();
metadata.setMethodHandle(partialMessageHandler);
metadata.setSinkClass(PartialStringMessageSink.class);
this.textMetadata = metadata;
}
else
Expand All @@ -426,67 +434,67 @@ else if (String.class.isAssignableFrom(clazz))
}
}

public <T> void addMessageHandler(JavaxWebSocketSession session, Class<T> clazz, MessageHandler.Whole<T> handler)
public <T> void addMessageHandler(Class<T> clazz, MessageHandler.Whole<T> handler)
{
try
{
MethodHandles.Lookup lookup = JavaxWebSocketFrameHandlerFactory.getServerMethodHandleLookup();
MethodHandle wholeMsgMethodHandle = lookup.findVirtual(MessageHandler.Whole.class, "onMessage", MethodType.methodType(void.class, Object.class));
wholeMsgMethodHandle = wholeMsgMethodHandle.bindTo(handler);
MethodHandle methodHandle = JavaxWebSocketFrameHandlerFactory.getServerMethodHandleLookup()
.findVirtual(MessageHandler.Whole.class, "onMessage", MethodType.methodType(void.class, Object.class))
.bindTo(handler);

if (PongMessage.class.isAssignableFrom(clazz))
{
assertBasicTypeNotRegistered(OpCode.PONG, this.pongHandle, handler.getClass().getName());
this.pongHandle = wholeMsgMethodHandle;
this.pongHandle = methodHandle;
registerMessageHandler(OpCode.PONG, clazz, handler, null);
}
else
{
AvailableDecoders availableDecoders = session.getDecoders();

AvailableDecoders.RegisteredDecoder registeredDecoder = availableDecoders.getRegisteredDecoderFor(clazz);
RegisteredDecoder registeredDecoder = availableDecoders.getFirstRegisteredDecoder(clazz);
if (registeredDecoder == null)
{
throw new IllegalStateException("Unable to find Decoder for type: " + clazz);
}

JavaxWebSocketFrameHandlerMetadata.MessageMetadata metadata = new JavaxWebSocketFrameHandlerMetadata.MessageMetadata();
metadata.handle = wholeMsgMethodHandle;
metadata.registeredDecoder = registeredDecoder;
JavaxWebSocketMessageMetadata metadata = new JavaxWebSocketMessageMetadata();
metadata.setMethodHandle(methodHandle);
metadata.setRegisteredDecoder(registeredDecoder);

if (registeredDecoder.implementsInterface(Decoder.Binary.class))
{
assertBasicTypeNotRegistered(OpCode.BINARY, this.binaryMetadata, handler.getClass().getName());
Decoder.Binary<T> decoder = availableDecoders.getInstanceOf(registeredDecoder);
MessageSink messageSink = new DecodedBinaryMessageSink(coreSession, decoder, wholeMsgMethodHandle);
metadata.sinkClass = messageSink.getClass();
List<RegisteredDecoder> binaryDecoders = availableDecoders.getBinaryDecoders(clazz);
MessageSink messageSink = new DecodedBinaryMessageSink<T>(coreSession, methodHandle, binaryDecoders);
metadata.setSinkClass(messageSink.getClass());
this.binarySink = registerMessageHandler(OpCode.BINARY, clazz, handler, messageSink);
this.binaryMetadata = metadata;
}
else if (registeredDecoder.implementsInterface(Decoder.BinaryStream.class))
{
assertBasicTypeNotRegistered(OpCode.BINARY, this.binaryMetadata, handler.getClass().getName());
Decoder.BinaryStream<T> decoder = availableDecoders.getInstanceOf(registeredDecoder);
MessageSink messageSink = new DecodedBinaryStreamMessageSink(coreSession, decoder, wholeMsgMethodHandle);
metadata.sinkClass = messageSink.getClass();
List<RegisteredDecoder> binaryStreamDecoders = availableDecoders.getBinaryStreamDecoders(clazz);
MessageSink messageSink = new DecodedBinaryStreamMessageSink<T>(coreSession, methodHandle, binaryStreamDecoders);
metadata.setSinkClass(messageSink.getClass());
this.binarySink = registerMessageHandler(OpCode.BINARY, clazz, handler, messageSink);
this.binaryMetadata = metadata;
}
else if (registeredDecoder.implementsInterface(Decoder.Text.class))
{
assertBasicTypeNotRegistered(OpCode.TEXT, this.textMetadata, handler.getClass().getName());
Decoder.Text<T> decoder = availableDecoders.getInstanceOf(registeredDecoder);
MessageSink messageSink = new DecodedTextMessageSink(coreSession, decoder, wholeMsgMethodHandle);
metadata.sinkClass = messageSink.getClass();
List<RegisteredDecoder> textDecoders = availableDecoders.getTextDecoders(clazz);
MessageSink messageSink = new DecodedTextMessageSink<T>(coreSession, methodHandle, textDecoders);
metadata.setSinkClass(messageSink.getClass());
this.textSink = registerMessageHandler(OpCode.TEXT, clazz, handler, messageSink);
this.textMetadata = metadata;
}
else if (registeredDecoder.implementsInterface(Decoder.TextStream.class))
{
assertBasicTypeNotRegistered(OpCode.TEXT, this.textMetadata, handler.getClass().getName());
Decoder.TextStream<T> decoder = availableDecoders.getInstanceOf(registeredDecoder);
MessageSink messageSink = new DecodedTextStreamMessageSink(coreSession, decoder, wholeMsgMethodHandle);
metadata.sinkClass = messageSink.getClass();
List<RegisteredDecoder> textStreamDecoders = availableDecoders.getTextStreamDecoders(clazz);
MessageSink messageSink = new DecodedTextStreamMessageSink<T>(coreSession, methodHandle, textStreamDecoders);
metadata.setSinkClass(messageSink.getClass());
this.textSink = registerMessageHandler(OpCode.TEXT, clazz, handler, messageSink);
this.textMetadata = metadata;
}
Expand Down
Loading

0 comments on commit 4b19c19

Please sign in to comment.