Skip to content

Commit

Permalink
Initialize late-init extensions for every test class instance
Browse files Browse the repository at this point in the history
  • Loading branch information
marcphilipp committed Jul 23, 2024
1 parent f9ba7bf commit c0dfec1
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
import org.junit.jupiter.api.extension.TestInstantiationException;
import org.junit.jupiter.api.function.Executable;
import org.junit.jupiter.engine.config.JupiterConfiguration;
import org.junit.jupiter.engine.descriptor.ExtensionUtils.ProgrammaticExtensionRegistration;
import org.junit.jupiter.engine.execution.AfterEachMethodAdapter;
import org.junit.jupiter.engine.execution.BeforeEachMethodAdapter;
import org.junit.jupiter.engine.execution.DefaultExecutableInvoker;
Expand Down Expand Up @@ -177,7 +176,7 @@ public JupiterEngineExecutionContext prepare(JupiterEngineExecutionContext conte
registerBeforeEachMethodAdapters(registry);
registerAfterEachMethodAdapters(registry);
this.afterAllMethods.forEach(method -> registerExtensionsFromExecutableParameters(registry, method));
ProgrammaticExtensionRegistration registration = registerExtensionsFromInstanceFields(registry, this.testClass);
registerExtensionsFromInstanceFields(registry, this.testClass);

ThrowableCollector throwableCollector = createThrowableCollector();
ExecutableInvoker executableInvoker = new DefaultExecutableInvoker(context);
Expand All @@ -187,7 +186,7 @@ public JupiterEngineExecutionContext prepare(JupiterEngineExecutionContext conte

// @formatter:off
return context.extend()
.withTestInstancesProvider(testInstancesProvider(context, extensionContext, registration))
.withTestInstancesProvider(testInstancesProvider(context, extensionContext))
.withExtensionRegistry(registry)
.withExtensionContext(extensionContext)
.withThrowableCollector(throwableCollector)
Expand Down Expand Up @@ -276,25 +275,25 @@ private TestInstanceFactory resolveTestInstanceFactory(ExtensionRegistry registr
}

private TestInstancesProvider testInstancesProvider(JupiterEngineExecutionContext parentExecutionContext,
ClassExtensionContext extensionContext, ProgrammaticExtensionRegistration registration) {
ClassExtensionContext extensionContext) {

return (registry, registrar, throwableCollector) -> extensionContext.getTestInstances().orElseGet(
() -> instantiateAndPostProcessTestInstance(parentExecutionContext, extensionContext, registry, registrar,
throwableCollector, registration));
throwableCollector));
}

private TestInstances instantiateAndPostProcessTestInstance(JupiterEngineExecutionContext parentExecutionContext,
ExtensionContext extensionContext, ExtensionRegistry registry, ExtensionRegistrar registrar,
ThrowableCollector throwableCollector, ProgrammaticExtensionRegistration registration) {
ThrowableCollector throwableCollector) {

TestInstances instances = instantiateTestClass(parentExecutionContext, registry, registrar, extensionContext,
throwableCollector);
throwableCollector.execute(() -> {
invokeTestInstancePostProcessors(instances.getInnermostInstance(), registry, extensionContext);
// In addition, we complete programmatic extension registration from instance fields here since the
// best time to do that is immediately following test class instantiation
// and post processing.
registration.complete(instances.getInnermostInstance());
// In addition, we initialize extension registered programmatically from instance fields here
// since the best time to do that is immediately following test class instantiation
// and post-processing.
registrar.initializeExtensions(this.testClass, instances.getInnermostInstance());
});
return instances;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Executable;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Stream;

Expand All @@ -36,7 +34,6 @@
import org.junit.jupiter.api.extension.Extension;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.engine.extension.ExtensionRegistrar;
import org.junit.jupiter.engine.extension.ExtensionRegistrar.RegistrationToken;
import org.junit.jupiter.engine.extension.MutableExtensionRegistry;
import org.junit.platform.commons.PreconditionViolationException;
import org.junit.platform.commons.util.Preconditions;
Expand Down Expand Up @@ -120,9 +117,7 @@ static void registerExtensionsFromStaticFields(ExtensionRegistrar registrar, Cla
* @param clazz the class or interface in which to find the fields; never {@code null}
* @since 5.11
*/
static ProgrammaticExtensionRegistration registerExtensionsFromInstanceFields(ExtensionRegistrar registrar,
Class<?> clazz) {
ProgrammaticExtensionRegistration registration = new ProgrammaticExtensionRegistration();
static void registerExtensionsFromInstanceFields(ExtensionRegistrar registrar, Class<?> clazz) {
streamExtensionRegisteringFields(clazz, ReflectionUtils::isNotStatic) //
.forEach(field -> {
List<Class<? extends Extension>> extensionTypes = streamDeclarativeExtensionTypes(field).collect(
Expand All @@ -133,14 +128,10 @@ static ProgrammaticExtensionRegistration registerExtensionsFromInstanceFields(Ex
extensionTypes.forEach(registrar::registerExtension);
}
if (isAnnotated(field, RegisterExtension.class)) {
registration.proxies.add(new LateInitExtensionProxy( //
registrar.registerExtensionToken(field), //
instance -> readAndValidateExtensionFromField(field, instance, extensionTypes) //
));
registrar.registerUninitializedExtension(clazz, field,
instance -> readAndValidateExtensionFromField(field, instance, extensionTypes));
}
});

return registration;
}

/**
Expand Down Expand Up @@ -245,33 +236,4 @@ private static int getOrder(Field field) {
return findAnnotation(field, Order.class).map(Order::value).orElse(Order.DEFAULT);
}

/**
* @since 5.11
*/
static class ProgrammaticExtensionRegistration {
private final List<LateInitExtensionProxy> proxies = new ArrayList<>();

void complete(Object instance) {
proxies.forEach(proxy -> proxy.complete(instance));
}
}

/**
* @since 5.11
*/
private static class LateInitExtensionProxy {

private final RegistrationToken token;
private final Function<Object, Extension> initializer;

LateInitExtensionProxy(RegistrationToken token, Function<Object, Extension> initializer) {
this.token = token;
this.initializer = initializer;
}

void complete(Object instance) {
token.complete(initializer.apply(instance));
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public boolean mayRegisterTests() {
// --- Node ----------------------------------------------------------------

@Override
public JupiterEngineExecutionContext prepare(JupiterEngineExecutionContext context) throws Exception {
public JupiterEngineExecutionContext prepare(JupiterEngineExecutionContext context) {
MutableExtensionRegistry registry = populateNewExtensionRegistryFromExtendWithAnnotation(
context.getExtensionRegistry(), getTestMethod());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import static org.apiguardian.api.API.Status.INTERNAL;

import java.lang.reflect.Field;
import java.util.function.Function;

import org.apiguardian.api.API;
import org.junit.jupiter.api.extension.Extension;
Expand Down Expand Up @@ -70,10 +71,9 @@ public interface ExtensionRegistrar {
*/
void registerSyntheticExtension(Extension extension, Object source);

RegistrationToken registerExtensionToken(Field source);
void registerUninitializedExtension(Class<?> testClass, Field source,
Function<Object, ? extends Extension> initializer);

interface RegistrationToken {
void complete(Extension extension);
}
void initializeExtensions(Class<?> testClass, Object testInstance);

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.ServiceLoader;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Stream;

import org.apiguardian.api.API;
Expand Down Expand Up @@ -107,6 +110,7 @@ public static MutableExtensionRegistry createRegistryFrom(MutableExtensionRegist

private final Set<Class<? extends Extension>> registeredExtensionTypes;
private final List<Entry> registeredExtensions;
private final Map<Class<?>, List<LateInitEntry>> lateInitExtensions;

private MutableExtensionRegistry() {
this(emptySet(), emptyList());
Expand All @@ -119,7 +123,18 @@ private MutableExtensionRegistry(MutableExtensionRegistry parent) {
private MutableExtensionRegistry(Set<Class<? extends Extension>> registeredExtensionTypes,
List<Entry> registeredExtensions) {
this.registeredExtensionTypes = new LinkedHashSet<>(registeredExtensionTypes);
this.registeredExtensions = new ArrayList<>(registeredExtensions);
this.registeredExtensions = new ArrayList<>(registeredExtensions.size());
this.lateInitExtensions = new LinkedHashMap<>();
registeredExtensions.forEach(entry -> {
Entry newEntry = entry;
if (entry instanceof LateInitEntry) {
LateInitEntry lateInitEntry = ((LateInitEntry) entry).copy();
this.lateInitExtensions.computeIfAbsent(lateInitEntry.getTestClass(), __ -> new ArrayList<>()).add(
lateInitEntry);
newEntry = lateInitEntry;
}
this.registeredExtensions.add(newEntry);
});
}

@Override
Expand Down Expand Up @@ -157,12 +172,21 @@ public void registerSyntheticExtension(Extension extension, Object source) {
}

@Override
public RegistrationToken registerExtensionToken(Field source) {
logger.trace(() -> String.format("Registering local extension token for [%s]%s", source.getType().getName(),
buildSourceInfo(source)));
TokenBasedEntry token = new TokenBasedEntry();
this.registeredExtensions.add(token);
return token;
public void registerUninitializedExtension(Class<?> testClass, Field source,
Function<Object, ? extends Extension> initializer) {
logger.trace(() -> String.format("Registering local extension (late-init) for [%s]%s",
source.getType().getName(), buildSourceInfo(source)));
LateInitEntry entry = new LateInitEntry(testClass, initializer);
lateInitExtensions.computeIfAbsent(testClass, __ -> new ArrayList<>()).add(entry);
this.registeredExtensions.add(entry);
}

@Override
public void initializeExtensions(Class<?> testClass, Object testInstance) {
List<LateInitEntry> entries = lateInitExtensions.remove(testClass);
if (entries != null) {
entries.forEach(entry -> entry.initialize(testInstance));
}
}

private void registerDefaultExtension(Extension extension) {
Expand Down Expand Up @@ -214,22 +238,36 @@ static Entry of(Extension extension) {
Optional<Extension> getExtension();
}

private class TokenBasedEntry implements Entry, RegistrationToken {
private static class LateInitEntry implements Entry {

private final Class<?> testClass;
private final Function<Object, ? extends Extension> initializer;

@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
private Optional<Extension> extension = Optional.empty();

public LateInitEntry(Class<?> testClass, Function<Object, ? extends Extension> initializer) {
this.testClass = testClass;
this.initializer = initializer;
}

@Override
public Optional<Extension> getExtension() {
return extension;
}

@Override
public void complete(Extension value) {
if (!extension.isPresent()) {
extension = Optional.of(value);
registeredExtensionTypes.add(value.getClass());
}
public Class<?> getTestClass() {
return testClass;
}

void initialize(Object testInstance) {
extension = Optional.of(initializer.apply(testInstance));
}

LateInitEntry copy() {
LateInitEntry copy = new LateInitEntry(testClass, initializer);
copy.extension = this.extension;
return copy;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.DynamicTest;
import org.junit.jupiter.api.Named;
import org.junit.jupiter.api.Nested;
Expand Down Expand Up @@ -200,7 +199,6 @@ void registersProgrammaticTestInstancePostProcessors() {
}

@Test
@Disabled("not yet implemented")
void createsExtensionPerInstance() {
var results = executeTests(r -> r //
.selectors(selectClass(InitializationPerInstanceTestCase.class)) //
Expand All @@ -216,7 +214,7 @@ private List<String> getRegisteredLocalExtensions(LogRecordListener listener) {
.map(message -> {
message = message.replaceAll(" from source .+", "");
int beginIndex = message.lastIndexOf('.') + 1;
if (message.contains("token")) {
if (message.contains("late-init")) {
return message.substring(beginIndex, message.indexOf("]"));
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,10 @@ void testWatcherSemanticsWhenRegisteredAtInstanceLevelWithTestInstanceLifecycleP
Class<?> testClass = TestInstancePerMethodInstanceLevelTestWatcherTestCase.class;
assertStatsForAbstractDisabledMethodsTestCase(testClass);

// We get "testDisabled" events for the @Test method and the @RepeatedTest container.
assertThat(TrackingTestWatcher.results.get("testDisabled")).containsExactly("test", "repeatedTest");
// Since the TestWatcher is registered at the instance level with test instance
// lifecycle per-method semantics, we get a "testDisabled" event only for the @Test
// method and NOT for the @RepeatedTest container.
assertThat(TrackingTestWatcher.results.get("testDisabled")).containsExactly("test");
}

@Test
Expand Down

0 comments on commit c0dfec1

Please sign in to comment.