Skip to content

Commit

Permalink
Add AOT support for container element constraints
Browse files Browse the repository at this point in the history
This commit introduces support for bean validation container
element constraints, including transitive ones.

Transitive constraints in the parameterized types of a container
are not discoverable via the BeanDescriptor, so a complementary
type discovery is done on Spring side to cover the related use
case.

Closes spring-projectsgh-33842
  • Loading branch information
sdeleuze committed Nov 13, 2024
1 parent 525407e commit 30ac635
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,20 @@

import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import jakarta.validation.ConstraintValidator;
import jakarta.validation.NoProviderFoundException;
import jakarta.validation.Validation;
import jakarta.validation.Validator;
import jakarta.validation.ValidatorFactory;
import jakarta.validation.metadata.BeanDescriptor;
import jakarta.validation.metadata.ConstraintDescriptor;
import jakarta.validation.metadata.ConstructorDescriptor;
import jakarta.validation.metadata.MethodDescriptor;
import jakarta.validation.metadata.ContainerElementTypeDescriptor;
import jakarta.validation.metadata.ExecutableDescriptor;
import jakarta.validation.metadata.MethodType;
import jakarta.validation.metadata.ParameterDescriptor;
import jakarta.validation.metadata.PropertyDescriptor;
Expand All @@ -36,13 +40,17 @@

import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.ReflectionHints;
import org.springframework.beans.factory.aot.BeanRegistrationAotContribution;
import org.springframework.beans.factory.aot.BeanRegistrationAotProcessor;
import org.springframework.beans.factory.aot.BeanRegistrationCode;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.core.KotlinDetector;
import org.springframework.core.ResolvableType;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;

/**
* AOT {@code BeanRegistrationAotProcessor} that adds additional hints
Expand Down Expand Up @@ -80,8 +88,8 @@ private static class BeanValidationDelegate {

@Nullable
private static Validator getValidatorIfAvailable() {
try {
return Validation.buildDefaultValidatorFactory().getValidator();
try (ValidatorFactory validator = Validation.buildDefaultValidatorFactory()) {
return validator.getValidator();
}
catch (NoProviderFoundException ex) {
logger.info("No Bean Validation provider available - skipping validation constraint hint inference");
Expand All @@ -95,64 +103,134 @@ public static BeanRegistrationAotContribution processAheadOfTime(RegisteredBean
return null;
}

Class<?> beanClass = registeredBean.getBeanClass();
Set<Class<?>> validatedClasses = new HashSet<>();
Set<Class<? extends ConstraintValidator<?, ?>>> constraintValidatorClasses = new HashSet<>();

processAheadOfTime(beanClass, validatedClasses, constraintValidatorClasses);

if (!validatedClasses.isEmpty() || !constraintValidatorClasses.isEmpty()) {
return new AotContribution(validatedClasses, constraintValidatorClasses);
}
return null;
}

private static void processAheadOfTime(Class<?> clazz, Collection<Class<?>> validatedClasses,
Collection<Class<? extends ConstraintValidator<?, ?>>> constraintValidatorClasses) {

Assert.notNull(validator, "Validator can't be null");

BeanDescriptor descriptor;
try {
descriptor = validator.getConstraintsForClass(registeredBean.getBeanClass());
descriptor = validator.getConstraintsForClass(clazz);
}
catch (RuntimeException ex) {
if (KotlinDetector.isKotlinType(registeredBean.getBeanClass()) && ex instanceof ArrayIndexOutOfBoundsException) {
if (KotlinDetector.isKotlinType(clazz) && ex instanceof ArrayIndexOutOfBoundsException) {
// See https://hibernate.atlassian.net/browse/HV-1796 and https://youtrack.jetbrains.com/issue/KT-40857
logger.warn("Skipping validation constraint hint inference for bean " + registeredBean.getBeanName() +
logger.warn("Skipping validation constraint hint inference for class " + clazz +
" due to an ArrayIndexOutOfBoundsException at validator level");
}
else if (ex instanceof TypeNotPresentException) {
logger.debug("Skipping validation constraint hint inference for bean " +
registeredBean.getBeanName() + " due to a TypeNotPresentException at validator level: " + ex.getMessage());
logger.debug("Skipping validation constraint hint inference for class " +
clazz + " due to a TypeNotPresentException at validator level: " + ex.getMessage());
}
else {
logger.warn("Skipping validation constraint hint inference for bean " +
registeredBean.getBeanName(), ex);
logger.warn("Skipping validation constraint hint inference for class " + clazz, ex);
}
return null;
return;
}

Set<ConstraintDescriptor<?>> constraintDescriptors = new HashSet<>();
for (MethodDescriptor methodDescriptor : descriptor.getConstrainedMethods(MethodType.NON_GETTER, MethodType.GETTER)) {
for (ParameterDescriptor parameterDescriptor : methodDescriptor.getParameterDescriptors()) {
constraintDescriptors.addAll(parameterDescriptor.getConstraintDescriptors());
}
processExecutableDescriptor(descriptor.getConstrainedMethods(MethodType.NON_GETTER, MethodType.GETTER), constraintValidatorClasses);
processExecutableDescriptor(descriptor.getConstrainedConstructors(), constraintValidatorClasses);
processPropertyDescriptors(descriptor.getConstrainedProperties(), constraintValidatorClasses);
if (!constraintValidatorClasses.isEmpty() && shouldProcess(clazz)) {
validatedClasses.add(clazz);
}
for (ConstructorDescriptor constructorDescriptor : descriptor.getConstrainedConstructors()) {
for (ParameterDescriptor parameterDescriptor : constructorDescriptor.getParameterDescriptors()) {
constraintDescriptors.addAll(parameterDescriptor.getConstraintDescriptors());

ReflectionUtils.doWithFields(clazz, field -> {
Class<?> type = field.getType();
if (Iterable.class.isAssignableFrom(type) || List.class.isAssignableFrom(type) || Optional.class.isAssignableFrom(type)) {
ResolvableType resolvableType = ResolvableType.forField(field);
Class<?> genericType = resolvableType.getGeneric(0).toClass();
if (shouldProcess(genericType)) {
validatedClasses.add(clazz);
processAheadOfTime(genericType, validatedClasses, constraintValidatorClasses);
}
}
if (Map.class.isAssignableFrom(type)) {
ResolvableType resolvableType = ResolvableType.forField(field);
Class<?> keyGenericType = resolvableType.getGeneric(0).toClass();
Class<?> valueGenericType = resolvableType.getGeneric(1).toClass();
if (shouldProcess(keyGenericType)) {
validatedClasses.add(clazz);
processAheadOfTime(keyGenericType, validatedClasses, constraintValidatorClasses);
}
if (shouldProcess(valueGenericType)) {
validatedClasses.add(clazz);
processAheadOfTime(valueGenericType, validatedClasses, constraintValidatorClasses);
}
}
});
}

private static boolean shouldProcess(Class<?> clazz) {
return !clazz.getCanonicalName().startsWith("java.");
}

private static void processExecutableDescriptor(Set<? extends ExecutableDescriptor> executableDescriptors,
Collection<Class<? extends ConstraintValidator<?, ?>>> constraintValidatorClasses) {

for (ExecutableDescriptor executableDescriptor : executableDescriptors) {
for (ParameterDescriptor parameterDescriptor : executableDescriptor.getParameterDescriptors()) {
for (ConstraintDescriptor<?> constraintDescriptor : parameterDescriptor.getConstraintDescriptors()) {
constraintValidatorClasses.addAll(constraintDescriptor.getConstraintValidatorClasses());
}
for (ContainerElementTypeDescriptor typeDescriptor : parameterDescriptor.getConstrainedContainerElementTypes()) {
for (ConstraintDescriptor<?> constraintDescriptor : typeDescriptor.getConstraintDescriptors()) {
constraintValidatorClasses.addAll(constraintDescriptor.getConstraintValidatorClasses());
}
}
}
}
for (PropertyDescriptor propertyDescriptor : descriptor.getConstrainedProperties()) {
constraintDescriptors.addAll(propertyDescriptor.getConstraintDescriptors());
}
if (!constraintDescriptors.isEmpty()) {
return new AotContribution(constraintDescriptors);
}

private static void processPropertyDescriptors(Set<PropertyDescriptor> propertyDescriptors,
Collection<Class<? extends ConstraintValidator<?, ?>>> constraintValidatorClasses) {

for (PropertyDescriptor propertyDescriptor : propertyDescriptors) {
for (ConstraintDescriptor<?> constraintDescriptor : propertyDescriptor.getConstraintDescriptors()) {
constraintValidatorClasses.addAll(constraintDescriptor.getConstraintValidatorClasses());
}
for (ContainerElementTypeDescriptor typeDescriptor : propertyDescriptor.getConstrainedContainerElementTypes()) {
for (ConstraintDescriptor<?> constraintDescriptor : typeDescriptor.getConstraintDescriptors()) {
constraintValidatorClasses.addAll(constraintDescriptor.getConstraintValidatorClasses());
}
}
}
return null;
}
}


private static class AotContribution implements BeanRegistrationAotContribution {

private final Collection<ConstraintDescriptor<?>> constraintDescriptors;
private final Collection<Class<?>> validatedClasses;
private final Collection<Class<? extends ConstraintValidator<?, ?>>> constraintValidatorClasses;

public AotContribution(Collection<ConstraintDescriptor<?>> constraintDescriptors) {
this.constraintDescriptors = constraintDescriptors;
public AotContribution(Collection<Class<?>> validatedClasses,
Collection<Class<? extends ConstraintValidator<?, ?>>> constraintValidatorClasses) {

this.validatedClasses = validatedClasses;
this.constraintValidatorClasses = constraintValidatorClasses;
}

@Override
public void applyTo(GenerationContext generationContext, BeanRegistrationCode beanRegistrationCode) {
for (ConstraintDescriptor<?> constraintDescriptor : this.constraintDescriptors) {
for (Class<?> constraintValidatorClass : constraintDescriptor.getConstraintValidatorClasses()) {
generationContext.getRuntimeHints().reflection().registerType(constraintValidatorClass,
MemberCategory.INVOKE_DECLARED_CONSTRUCTORS);
}
ReflectionHints hints = generationContext.getRuntimeHints().reflection();
for (Class<?> validatedClass : this.validatedClasses) {
hints.registerType(validatedClass, MemberCategory.DECLARED_FIELDS);
}
for (Class<? extends ConstraintValidator<?, ?>> constraintValidatorClass : this.constraintValidatorClasses) {
hints.registerType(constraintValidatorClass, MemberCategory.INVOKE_DECLARED_CONSTRUCTORS);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,16 @@
import java.lang.annotation.Repeatable;
import java.lang.annotation.Retention;
import java.lang.annotation.Target;
import java.util.ArrayList;
import java.util.List;

import jakarta.validation.Constraint;
import jakarta.validation.ConstraintValidator;
import jakarta.validation.ConstraintValidatorContext;
import jakarta.validation.Payload;
import jakarta.validation.Valid;
import jakarta.validation.constraints.Pattern;
import org.hibernate.validator.internal.constraintvalidators.bv.PatternValidator;
import org.junit.jupiter.api.Test;

import org.springframework.aot.generate.GenerationContext;
Expand Down Expand Up @@ -67,24 +72,55 @@ void shouldSkipNonAnnotatedType() {
@Test
void shouldProcessMethodParameterLevelConstraint() {
process(MethodParameterLevelConstraint.class);
assertThat(this.generationContext.getRuntimeHints().reflection().typeHints()).hasSize(2);
assertThat(RuntimeHintsPredicates.reflection().onType(MethodParameterLevelConstraint.class)
.withMemberCategory(MemberCategory.DECLARED_FIELDS)).accepts(this.generationContext.getRuntimeHints());
assertThat(RuntimeHintsPredicates.reflection().onType(ExistsValidator.class)
.withMemberCategory(MemberCategory.INVOKE_DECLARED_CONSTRUCTORS)).accepts(this.generationContext.getRuntimeHints());
}

@Test
void shouldProcessConstructorParameterLevelConstraint() {
process(ConstructorParameterLevelConstraint.class);
assertThat(this.generationContext.getRuntimeHints().reflection().typeHints()).hasSize(2);
assertThat(RuntimeHintsPredicates.reflection().onType(ConstructorParameterLevelConstraint.class)
.withMemberCategory(MemberCategory.DECLARED_FIELDS)).accepts(this.generationContext.getRuntimeHints());
assertThat(RuntimeHintsPredicates.reflection().onType(ExistsValidator.class)
.withMemberCategory(MemberCategory.INVOKE_DECLARED_CONSTRUCTORS)).accepts(this.generationContext.getRuntimeHints());
}

@Test
void shouldProcessPropertyLevelConstraint() {
process(PropertyLevelConstraint.class);
assertThat(this.generationContext.getRuntimeHints().reflection().typeHints()).hasSize(2);
assertThat(RuntimeHintsPredicates.reflection().onType(PropertyLevelConstraint.class)
.withMemberCategory(MemberCategory.DECLARED_FIELDS)).accepts(this.generationContext.getRuntimeHints());
assertThat(RuntimeHintsPredicates.reflection().onType(ExistsValidator.class)
.withMemberCategory(MemberCategory.INVOKE_DECLARED_CONSTRUCTORS)).accepts(this.generationContext.getRuntimeHints());
}

@Test
void shouldProcessGenericTypeLevelConstraint() {
process(GenericTypeLevelConstraint.class);
assertThat(this.generationContext.getRuntimeHints().reflection().typeHints()).hasSize(2);
assertThat(RuntimeHintsPredicates.reflection().onType(GenericTypeLevelConstraint.class)
.withMemberCategory(MemberCategory.DECLARED_FIELDS)).accepts(this.generationContext.getRuntimeHints());
assertThat(RuntimeHintsPredicates.reflection().onType(PatternValidator.class)
.withMemberCategory(MemberCategory.INVOKE_DECLARED_CONSTRUCTORS)).accepts(this.generationContext.getRuntimeHints());
}

@Test
void shouldProcessCascadedGenericTypeLevelConstraint() {
process(CascadedGenericTypeLevelConstraint.class);
assertThat(this.generationContext.getRuntimeHints().reflection().typeHints()).hasSize(3);
assertThat(RuntimeHintsPredicates.reflection().onType(CascadedGenericTypeLevelConstraint.class)
.withMemberCategory(MemberCategory.DECLARED_FIELDS)).accepts(this.generationContext.getRuntimeHints());
assertThat(RuntimeHintsPredicates.reflection().onType(Exclude.class)
.withMemberCategory(MemberCategory.DECLARED_FIELDS)).accepts(this.generationContext.getRuntimeHints());
assertThat(RuntimeHintsPredicates.reflection().onType(PatternValidator.class)
.withMemberCategory(MemberCategory.INVOKE_DECLARED_CONSTRUCTORS)).accepts(this.generationContext.getRuntimeHints());
}

private void process(Class<?> beanClass) {
BeanRegistrationAotContribution contribution = createContribution(beanClass);
if (contribution != null) {
Expand Down Expand Up @@ -168,4 +204,44 @@ public void setName(String name) {
}
}

static class Exclude {

@Valid
private List<@Pattern(regexp="^([1-5][x|X]{2}|[1-5][0-9]{2})\\$") String> httpStatus;

public List<String> getHttpStatus() {
return httpStatus;
}

public void setHttpStatus(List<String> httpStatus) {
this.httpStatus = httpStatus;
}
}

static class GenericTypeLevelConstraint {

private List<@Pattern(regexp="^([1-5][x|X]{2}|[1-5][0-9]{2})\\$") String> httpStatus;

public List<String> getHttpStatus() {
return httpStatus;
}

public void setHttpStatus(List<String> httpStatus) {
this.httpStatus = httpStatus;
}
}

static class CascadedGenericTypeLevelConstraint {

private List<Exclude> exclude = new ArrayList<>();

public List<Exclude> getExclude() {
return exclude;
}

public void setExclude(List<Exclude> exclude) {
this.exclude = exclude;
}
}

}

0 comments on commit 30ac635

Please sign in to comment.