Skip to content

Commit

Permalink
Merge pull request #39642 from mkouba/issue-39224
Browse files Browse the repository at this point in the history
WebSocket Next: endpoint callback arguments injection
  • Loading branch information
mkouba authored Mar 26, 2024
2 parents 8ab1f7d + b5593f8 commit 316f8b7
Show file tree
Hide file tree
Showing 23 changed files with 1,049 additions and 70 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package io.quarkus.websockets.next.deployment;

import java.util.Set;

import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.MethodParameterInfo;
import org.jboss.jandex.Type;

import io.quarkus.gizmo.BytecodeCreator;
import io.quarkus.gizmo.ResultHandle;
import io.quarkus.websockets.next.OnClose;
import io.quarkus.websockets.next.OnOpen;
import io.quarkus.websockets.next.WebSocketConnection;
import io.quarkus.websockets.next.WebSocketServerException;

/**
* Provides arguments for method parameters of a callback method declared on a WebSocket endpoint.
*/
interface CallbackArgument {

/**
*
* @param context
* @return {@code true} if this provider matches the given parameter context, {@code false} otherwise
* @throws WebSocketServerException If an invalid parameter is detected
*/
boolean matches(ParameterContext context);

/**
* This method is only used if {@link #matches(ParameterContext)} previously returned {@code true} for the same parameter
* context.
*
* @param context
* @return the result handle to be passed as an argument to a callback method
*/
ResultHandle get(InvocationBytecodeContext context);

/**
*
* @return the priority
*/
default int priotity() {
return DEFAULT_PRIORITY;
}

static final int DEFAULT_PRIORITY = 1;

interface ParameterContext {

/**
*
* @return the endpoint path
*/
String endpointPath();

/**
*
* @return the callback marker annotation
*/
AnnotationInstance callbackAnnotation();

/**
*
* @return the Java method parameter
*/
MethodParameterInfo parameter();

/**
*
* @return the set of parameter annotations, potentially transformed
*/
Set<AnnotationInstance> parameterAnnotations();

default boolean acceptsMessage() {
return WebSocketDotNames.ON_BINARY_MESSAGE.equals(callbackAnnotation().name())
|| WebSocketDotNames.ON_TEXT_MESSAGE.equals(callbackAnnotation().name())
|| WebSocketDotNames.ON_PONG_MESSAGE.equals(callbackAnnotation().name());
}

}

interface InvocationBytecodeContext extends ParameterContext {

/**
*
* @return the bytecode
*/
BytecodeCreator bytecode();

/**
* Obtains the message directly in the bytecode.
*
* @return the message object or {@code null} for {@link OnOpen} and {@link OnClose} callbacks
*/
ResultHandle getMessage();

/**
* Attempts to obtain the decoded message directly in the bytecode.
*
* @param parameterType
* @return the decoded message object or {@code null} for {@link OnOpen} and {@link OnClose} callbacks
*/
ResultHandle getDecodedMessage(Type parameterType);

/**
* Obtains the current connection directly in the bytecode.
*
* @return the current {@link WebSocketConnection}, never {@code null}
*/
ResultHandle getConnection();

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package io.quarkus.websockets.next.deployment;

import io.quarkus.builder.item.MultiBuildItem;

final class CallbackArgumentBuildItem extends MultiBuildItem {

private final CallbackArgument provider;

CallbackArgumentBuildItem(CallbackArgument provider) {
this.provider = provider;
}

CallbackArgument getProvider() {
return provider;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package io.quarkus.websockets.next.deployment;

import java.util.ArrayList;
import java.util.List;

import io.quarkus.builder.item.SimpleBuildItem;
import io.quarkus.websockets.next.deployment.CallbackArgument.ParameterContext;

final class CallbackArgumentsBuildItem extends SimpleBuildItem {

final List<CallbackArgument> sortedArguments;

CallbackArgumentsBuildItem(List<CallbackArgument> providers) {
this.sortedArguments = providers;
}

/**
*
* @param context
* @return all matching providers, never {@code null}
*/
List<CallbackArgument> findMatching(ParameterContext context) {
List<CallbackArgument> matching = new ArrayList<>();
for (CallbackArgument argument : sortedArguments) {
if (argument.matches(context)) {
matching.add(argument);
}
}
return matching;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package io.quarkus.websockets.next.deployment;

import io.quarkus.gizmo.ResultHandle;

class ConnectionCallbackArgument implements CallbackArgument {

@Override
public boolean matches(ParameterContext context) {
return context.parameter().type().name().equals(WebSocketDotNames.WEB_SOCKET_CONNECTION);
}

@Override
public ResultHandle get(InvocationBytecodeContext context) {
return context.getConnection();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package io.quarkus.websockets.next.deployment;

import io.quarkus.gizmo.MethodDescriptor;
import io.quarkus.gizmo.ResultHandle;
import io.quarkus.websockets.next.WebSocketConnection;

class HandshakeRequestCallbackArgument implements CallbackArgument {

@Override
public boolean matches(ParameterContext context) {
return context.parameter().type().name().equals(WebSocketDotNames.HANDSHAKE_REQUEST);
}

@Override
public ResultHandle get(InvocationBytecodeContext context) {
ResultHandle connection = context.getConnection();
return context.bytecode().invokeInterfaceMethod(MethodDescriptor.ofMethod(WebSocketConnection.class, "handshakeRequest",
WebSocketConnection.HandshakeRequest.class), connection);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package io.quarkus.websockets.next.deployment;

import io.quarkus.gizmo.ResultHandle;

class MessageCallbackArgument implements CallbackArgument {

@Override
public boolean matches(ParameterContext context) {
return context.acceptsMessage() && context.parameterAnnotations().isEmpty();
}

@Override
public ResultHandle get(InvocationBytecodeContext context) {
return context.getDecodedMessage(context.parameter().type());
}

@Override
public int priotity() {
return DEFAULT_PRIORITY - 1;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package io.quarkus.websockets.next.deployment;

import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;

import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.AnnotationValue;

import io.quarkus.arc.processor.Annotations;
import io.quarkus.gizmo.MethodDescriptor;
import io.quarkus.gizmo.ResultHandle;
import io.quarkus.websockets.next.WebSocketConnection;
import io.quarkus.websockets.next.WebSocketServerException;

class PathParamCallbackArgument implements CallbackArgument {

@Override
public boolean matches(ParameterContext context) {
String paramName = getParamName(context);
if (paramName != null) {
if (!context.parameter().type().name().equals(WebSocketDotNames.STRING)) {
throw new WebSocketServerException("Method parameter annotated with @PathParam must be java.lang.String: "
+ WebSocketServerProcessor.callbackToString(context.parameter().method()));
}
List<String> pathParams = getPathParamNames(context.endpointPath());
if (!pathParams.contains(paramName)) {
throw new WebSocketServerException(
String.format(
"@PathParam name [%s] must be used in the endpoint path [%s]: %s", paramName,
context.endpointPath(),
WebSocketServerProcessor.callbackToString(context.parameter().method())));
}
return true;
}
return false;
}

@Override
public ResultHandle get(InvocationBytecodeContext context) {
ResultHandle connection = context.getConnection();
String paramName = getParamName(context);
return context.bytecode().invokeInterfaceMethod(
MethodDescriptor.ofMethod(WebSocketConnection.class, "pathParam", String.class, String.class), connection,
context.bytecode().load(paramName));
}

private String getParamName(ParameterContext context) {
AnnotationInstance pathParamAnnotation = Annotations.find(context.parameterAnnotations(), WebSocketDotNames.PATH_PARAM);
if (pathParamAnnotation != null) {
String paramName;
AnnotationValue nameVal = pathParamAnnotation.value();
if (nameVal != null) {
paramName = nameVal.asString();
} else {
// Try to use the element name
paramName = context.parameter().name();
}
if (paramName == null) {
throw new WebSocketServerException(String.format(
"Unable to extract the path parameter name - method parameter names not recorded for %s: compile the class with -parameters",
context.parameter().method().declaringClass().name()));
}
return paramName;
}
return null;
}

static List<String> getPathParamNames(String path) {
List<String> names = new ArrayList<>();
Matcher m = WebSocketServerProcessor.TRANSLATED_PATH_PARAM_PATTERN.matcher(path);
while (m.find()) {
names.add(m.group().substring(1));
}
return names;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import io.quarkus.websockets.next.OnOpen;
import io.quarkus.websockets.next.OnPongMessage;
import io.quarkus.websockets.next.OnTextMessage;
import io.quarkus.websockets.next.PathParam;
import io.quarkus.websockets.next.WebSocket;
import io.quarkus.websockets.next.WebSocketConnection;
import io.smallrye.common.annotation.Blocking;
Expand Down Expand Up @@ -35,4 +36,6 @@ final class WebSocketDotNames {
static final DotName JSON_OBJECT = DotName.createSimple(JsonObject.class);
static final DotName JSON_ARRAY = DotName.createSimple(JsonArray.class);
static final DotName VOID = DotName.createSimple(Void.class);
static final DotName PATH_PARAM = DotName.createSimple(PathParam.class);
static final DotName HANDSHAKE_REQUEST = DotName.createSimple(WebSocketConnection.HandshakeRequest.class);
}
Loading

0 comments on commit 316f8b7

Please sign in to comment.