Skip to content

Commit

Permalink
Merge pull request #32405 from Ladicek/arc-mocking-intercepted-bean
Browse files Browse the repository at this point in the history
ArC: fix spying on intercepted beans
  • Loading branch information
gsmet authored Apr 5, 2023
2 parents 769878a + 503ad0d commit 398e845
Show file tree
Hide file tree
Showing 8 changed files with 110 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import java.util.Map.Entry;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;

import jakarta.enterprise.context.spi.Contextual;
Expand Down Expand Up @@ -371,8 +370,8 @@ private ResultHandle createInterceptor(ResultHandle interceptorBean, BytecodeCre

private ResultHandle createForwardingFunction(MethodCreator init, ClassInfo target, MethodInfo method) {
// Forwarding function
// Function<InvocationContext, Object> forward = ctx -> Foo.interceptMe_original((java.lang.String)ctx.getParameters()[0])
FunctionCreator func = init.createFunction(Function.class);
// BiFunction<Object, InvocationContext, Object> forward = (ignored, ctx) -> Foo.interceptMe_original((java.lang.String)ctx.getParameters()[0])
FunctionCreator func = init.createFunction(BiFunction.class);
BytecodeCreator funcBytecode = func.getBytecode();
List<Type> paramTypes = method.parameterTypes();
ResultHandle[] paramHandles;
Expand All @@ -382,7 +381,7 @@ private ResultHandle createForwardingFunction(MethodCreator init, ClassInfo targ
params = new String[0];
} else {
paramHandles = new ResultHandle[paramTypes.size()];
ResultHandle ctxHandle = funcBytecode.getMethodParam(0);
ResultHandle ctxHandle = funcBytecode.getMethodParam(1);
ResultHandle ctxParamsHandle = funcBytecode.invokeInterfaceMethod(
MethodDescriptor.ofMethod(InvocationContext.class, "getParameters", Object[].class),
ctxHandle);
Expand Down
45 changes: 13 additions & 32 deletions independent-projects/arc/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -35,31 +35,33 @@
</scm>

<properties>
<angus-activation.version>1.0.0</angus-activation.version>
<expressly.version>5.0.0</expressly.version>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.target>11</maven.compiler.target>
<maven.compiler.source>11</maven.compiler.source>
<maven.compiler.release>11</maven.compiler.release>
<!-- Versions -->

<!-- specification versions -->
<version.cdi>4.0.1</version.cdi>
<version.jakarta-annotation>2.1.1</version.jakarta-annotation>
<version.jpa>3.0.0</version.jpa>
<version.jta>2.0.1</version.jta>
<!-- main versions -->
<version.gizmo>1.6.0.Final</version.gizmo>
<version.jandex>3.0.5</version.jandex>
<version.jboss-logging>3.5.0.Final</version.jboss-logging>
<version.mutiny>2.1.0</version.mutiny>
<!-- test versions -->
<version.assertj>3.24.2</version.assertj>
<version.junit5>5.9.2</version.junit5>
<version.kotlin>1.8.10</version.kotlin>
<version.kotlin-coroutines>1.6.4</version.kotlin-coroutines>
<version.maven>3.8.8</version.maven>
<version.assertj>3.24.2</version.assertj>
<version.jboss-logging>3.5.0.Final</version.jboss-logging>
<version.jakarta-annotation>2.1.1</version.jakarta-annotation>
<version.gizmo>1.6.0.Final</version.gizmo>
<version.jpa>3.0.0</version.jpa>
<version.mutiny>2.1.0</version.mutiny>
<version.mockito>5.2.0</version.mockito>
<!-- TCK versions -->
<version.arquillian>1.7.0.Alpha14</version.arquillian>
<version.atinject-tck>2.0.1</version.atinject-tck>
<version.cdi-tck>4.0.9</version.cdi-tck>
<version.junit4>4.13.2</version.junit4>

<!-- Maven plugin versions -->
<version.compiler.plugin>3.11.0</version.compiler.plugin>
<version.enforcer.plugin>3.2.1</version.enforcer.plugin>
<version.surefire.plugin>3.0.0</version.surefire.plugin>
Expand Down Expand Up @@ -140,27 +142,6 @@
</dependency>


<dependency>
<groupId>org.apache.maven</groupId>
<artifactId>maven-plugin-api</artifactId>
<version>${version.maven}</version>
<scope>provided</scope>
</dependency>

<dependency>
<groupId>org.apache.maven.plugin-tools</groupId>
<artifactId>maven-plugin-annotations</artifactId>
<version>${version.maven}</version>
<scope>provided</scope>
</dependency>

<dependency>
<groupId>org.apache.maven</groupId>
<artifactId>maven-core</artifactId>
<version>${version.maven}</version>
<scope>provided</scope>
</dependency>

<dependency>
<groupId>org.jboss.logging</groupId>
<artifactId>jboss-logging</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.function.BiFunction;
import java.util.function.Supplier;

import jakarta.enterprise.context.spi.Context;
Expand Down Expand Up @@ -231,8 +231,7 @@ public final class MethodDescriptors {
String.class);

public static final MethodDescriptor INTERCEPTED_METHOD_METADATA_CONSTRUCTOR = MethodDescriptor.ofConstructor(
InterceptedMethodMetadata.class,
List.class, Method.class, Set.class, Function.class);
InterceptedMethodMetadata.class, List.class, Method.class, Set.class, BiFunction.class);

public static final MethodDescriptor CREATIONAL_CTX_HAS_DEPENDENT_INSTANCES = MethodDescriptor.ofMethod(
CreationalContextImpl.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -403,10 +404,11 @@ public String apply(List<BindingKey> keys) {
}

// Instantiate the forwarding function
// Function<InvocationContext, Object> forward = ctx -> super.foo((java.lang.String)ctx.getParameters()[0])
FunctionCreator func = initMetadataMethod.createFunction(Function.class);
// BiFunction<Object, InvocationContext, Object> forward = (target, ctx) -> target.foo$$superforward((java.lang.String)ctx.getParameters()[0])
FunctionCreator func = initMetadataMethod.createFunction(BiFunction.class);
BytecodeCreator funcBytecode = func.getBytecode();
ResultHandle ctxHandle = funcBytecode.getMethodParam(0);
ResultHandle targetHandle = funcBytecode.getMethodParam(0);
ResultHandle ctxHandle = funcBytecode.getMethodParam(1);
ResultHandle[] superParamHandles;
if (parameters.isEmpty()) {
superParamHandles = new ResultHandle[0];
Expand Down Expand Up @@ -438,7 +440,7 @@ public String apply(List<BindingKey> keys) {
.returnValue(funcBytecode.invokeVirtualMethod(virtualMethodDescriptor, funDecoratorInstance,
superParamHandles));
} else {
ResultHandle superResult = funcBytecode.invokeVirtualMethod(forwardDescriptor, initMetadataMethod.getThis(),
ResultHandle superResult = funcBytecode.invokeVirtualMethod(forwardDescriptor, targetHandle,
superParamHandles);
funcBytecode.returnValue(superResult != null ? superResult : funcBytecode.loadNull());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ private Object proceed(int currentPosition) throws Exception {
.invoke(new NextAroundInvokeInvocationContext(currentPosition + 1, parameters));
} else {
// Invoke the target method
return metadata.aroundInvokeForward.apply(this);
return metadata.aroundInvokeForward.apply(target, this);
}
} catch (InvocationTargetException e) {
Throwable cause = e.getCause();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import java.lang.reflect.Method;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.function.BiFunction;

import jakarta.interceptor.InvocationContext;

Expand All @@ -16,10 +16,10 @@ public class InterceptedMethodMetadata {
public final List<InterceptorInvocation> chain;
public final Method method;
public final Set<Annotation> bindings;
public final Function<InvocationContext, Object> aroundInvokeForward;
public final BiFunction<Object, InvocationContext, Object> aroundInvokeForward;

public InterceptedMethodMetadata(List<InterceptorInvocation> chain, Method method, Set<Annotation> bindings,
Function<InvocationContext, Object> aroundInvokeForward) {
BiFunction<Object, InvocationContext, Object> aroundInvokeForward) {
this.chain = chain;
this.method = method;
this.bindings = bindings;
Expand Down
7 changes: 7 additions & 0 deletions independent-projects/arc/tests/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>${version.mockito}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>jakarta.persistence</groupId>
<artifactId>jakarta.persistence-api</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package io.quarkus.arc.test.mocking;

import static org.junit.jupiter.api.Assertions.assertEquals;

import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

import jakarta.annotation.Priority;
import jakarta.inject.Singleton;
import jakarta.interceptor.AroundInvoke;
import jakarta.interceptor.Interceptor;
import jakarta.interceptor.InterceptorBinding;
import jakarta.interceptor.InvocationContext;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.Mockito;

import io.quarkus.arc.Arc;
import io.quarkus.arc.test.ArcTestContainer;

public class SpyOnInterceptedBeanTest {
@RegisterExtension
public ArcTestContainer container = new ArcTestContainer(MyBean.class, MyInterceptorBinding.class, MyInterceptor.class);

@Test
public void test() {
MyBean bean = Arc.container().instance(MyBean.class).get();

MyBean spy = Mockito.spy(bean);
Mockito.when(spy.getValue()).thenReturn("quux");

assertEquals("intercepted: intercepted: quux42", spy.doSomething(42));
Mockito.verify(spy).doSomething(Mockito.anyInt());
Mockito.verify(spy).doSomethingElse();
Mockito.verify(spy).getValue();
}

@Singleton
static class MyBean {
@MyInterceptorBinding
String doSomething(int param) {
return doSomethingElse() + param;
}

@MyInterceptorBinding
String doSomethingElse() {
return getValue();
}

String getValue() {
return "foobar";
}
}

@Target({ ElementType.TYPE, ElementType.METHOD })
@Retention(RetentionPolicy.RUNTIME)
@Documented
@InterceptorBinding
public @interface MyInterceptorBinding {
}

@MyInterceptorBinding
@Interceptor
@Priority(1)
public static class MyInterceptor {
@AroundInvoke
Object aroundInvoke(InvocationContext ctx) throws Exception {
return "intercepted: " + ctx.proceed();
}
}
}

0 comments on commit 398e845

Please sign in to comment.