From 5e7b3a3bedb846f01ab01e056e58c1df4d2ab3d7 Mon Sep 17 00:00:00 2001 From: Stefano Cordio Date: Sat, 23 Nov 2024 16:47:20 +0100 Subject: [PATCH] Avoid infinite recursion in BeanValidationBeanRegistrationAotProcessor Prior to this commit, AOT processing for bean validation failed with a StackOverflowError for constraints with fields having recursive generic types. With this commit, the algorithm tracks visited classes and aborts preemptively when a cycle is detected. Closes gh-33950 Co-authored-by: Sam Brannen --- ...alidationBeanRegistrationAotProcessor.java | 19 ++++++++------ ...tionBeanRegistrationAotProcessorTests.java | 26 +++++++++++++++++++ 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/spring-context/src/main/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessor.java b/spring-context/src/main/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessor.java index 0647386c21cb..dfc55d7c0b4d 100644 --- a/spring-context/src/main/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessor.java +++ b/spring-context/src/main/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessor.java @@ -18,7 +18,6 @@ import java.util.Collection; import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -104,10 +103,11 @@ public static BeanRegistrationAotContribution processAheadOfTime(RegisteredBean } Class beanClass = registeredBean.getBeanClass(); + Set> visitedClasses = new HashSet<>(); Set> validatedClasses = new HashSet<>(); Set>> constraintValidatorClasses = new HashSet<>(); - processAheadOfTime(beanClass, validatedClasses, constraintValidatorClasses); + processAheadOfTime(beanClass, visitedClasses, validatedClasses, constraintValidatorClasses); if (!validatedClasses.isEmpty() || !constraintValidatorClasses.isEmpty()) { return new AotContribution(validatedClasses, constraintValidatorClasses); @@ -115,9 +115,12 @@ public static BeanRegistrationAotContribution processAheadOfTime(RegisteredBean return null; } - private static void processAheadOfTime(Class clazz, Collection> validatedClasses, - Collection>> constraintValidatorClasses) { + private static void processAheadOfTime(Class clazz, Set> visitedClasses, Set> validatedClasses, + Set>> constraintValidatorClasses) { + if (!visitedClasses.add(clazz)) { + return; + } Assert.notNull(validator, "Validator can't be null"); BeanDescriptor descriptor; @@ -149,12 +152,12 @@ else if (ex instanceof TypeNotPresentException) { ReflectionUtils.doWithFields(clazz, field -> { Class type = field.getType(); - if (Iterable.class.isAssignableFrom(type) || List.class.isAssignableFrom(type) || Optional.class.isAssignableFrom(type)) { + if (Iterable.class.isAssignableFrom(type) || Optional.class.isAssignableFrom(type)) { ResolvableType resolvableType = ResolvableType.forField(field); Class genericType = resolvableType.getGeneric(0).toClass(); if (shouldProcess(genericType)) { validatedClasses.add(clazz); - processAheadOfTime(genericType, validatedClasses, constraintValidatorClasses); + processAheadOfTime(genericType, visitedClasses, validatedClasses, constraintValidatorClasses); } } if (Map.class.isAssignableFrom(type)) { @@ -163,11 +166,11 @@ else if (ex instanceof TypeNotPresentException) { Class valueGenericType = resolvableType.getGeneric(1).toClass(); if (shouldProcess(keyGenericType)) { validatedClasses.add(clazz); - processAheadOfTime(keyGenericType, validatedClasses, constraintValidatorClasses); + processAheadOfTime(keyGenericType, visitedClasses, validatedClasses, constraintValidatorClasses); } if (shouldProcess(valueGenericType)) { validatedClasses.add(clazz); - processAheadOfTime(valueGenericType, validatedClasses, constraintValidatorClasses); + processAheadOfTime(valueGenericType, visitedClasses, validatedClasses, constraintValidatorClasses); } } }); diff --git a/spring-context/src/test/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessorTests.java b/spring-context/src/test/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessorTests.java index d43d8033317d..bbcdf3e4b707 100644 --- a/spring-context/src/test/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessorTests.java +++ b/spring-context/src/test/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessorTests.java @@ -22,6 +22,9 @@ import java.lang.annotation.Target; import java.util.ArrayList; import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; import jakarta.validation.Constraint; import jakarta.validation.ConstraintValidator; @@ -31,6 +34,8 @@ import jakarta.validation.constraints.Pattern; import org.hibernate.validator.internal.constraintvalidators.bv.PatternValidator; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.hint.MemberCategory; @@ -121,6 +126,15 @@ void shouldProcessTransitiveGenericTypeLevelConstraint() { .withMemberCategory(MemberCategory.INVOKE_DECLARED_CONSTRUCTORS)).accepts(this.generationContext.getRuntimeHints()); } + @ParameterizedTest // gh-33936 + @ValueSource(classes = {BeanWithIterable.class, BeanWithMap.class, BeanWithOptional.class}) + void shouldProcessRecursiveGenericsWithoutInfiniteRecursion(Class beanClass) { + process(beanClass); + assertThat(this.generationContext.getRuntimeHints().reflection().typeHints()).hasSize(1); + assertThat(RuntimeHintsPredicates.reflection().onType(beanClass) + .withMemberCategory(MemberCategory.DECLARED_FIELDS)).accepts(this.generationContext.getRuntimeHints()); + } + private void process(Class beanClass) { BeanRegistrationAotContribution contribution = createContribution(beanClass); if (contribution != null) { @@ -244,4 +258,16 @@ public void setExclude(List exclude) { } } + static class BeanWithIterable { + private final Iterable beans = Set.of(); + } + + static class BeanWithMap { + private final Map beans = Map.of(); + } + + static class BeanWithOptional { + private final Optional beans = Optional.empty(); + } + }