Skip to content

Commit ebd3b93

Browse files
committed
Avoid infinite recursion in AOT processing with recursive generics
1 parent 5e08a88 commit ebd3b93

File tree

2 files changed

+37
-8
lines changed

2 files changed

+37
-8
lines changed

spring-context/src/main/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessor.java

+11-8
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
import java.util.Collection;
2020
import java.util.HashSet;
21-
import java.util.List;
2221
import java.util.Map;
2322
import java.util.Optional;
2423
import java.util.Set;
@@ -104,20 +103,24 @@ public static BeanRegistrationAotContribution processAheadOfTime(RegisteredBean
104103
}
105104

106105
Class<?> beanClass = registeredBean.getBeanClass();
106+
Set<Class<?>> visitedClasses = new HashSet<>();
107107
Set<Class<?>> validatedClasses = new HashSet<>();
108108
Set<Class<? extends ConstraintValidator<?, ?>>> constraintValidatorClasses = new HashSet<>();
109109

110-
processAheadOfTime(beanClass, validatedClasses, constraintValidatorClasses);
110+
processAheadOfTime(beanClass, visitedClasses, validatedClasses, constraintValidatorClasses);
111111

112112
if (!validatedClasses.isEmpty() || !constraintValidatorClasses.isEmpty()) {
113113
return new AotContribution(validatedClasses, constraintValidatorClasses);
114114
}
115115
return null;
116116
}
117117

118-
private static void processAheadOfTime(Class<?> clazz, Collection<Class<?>> validatedClasses,
119-
Collection<Class<? extends ConstraintValidator<?, ?>>> constraintValidatorClasses) {
118+
private static void processAheadOfTime(Class<?> clazz, Set<Class<?>> visitedClasses, Set<Class<?>> validatedClasses,
119+
Set<Class<? extends ConstraintValidator<?, ?>>> constraintValidatorClasses) {
120120

121+
if (!visitedClasses.add(clazz)) {
122+
return;
123+
}
121124
Assert.notNull(validator, "Validator can't be null");
122125

123126
BeanDescriptor descriptor;
@@ -149,12 +152,12 @@ else if (ex instanceof TypeNotPresentException) {
149152

150153
ReflectionUtils.doWithFields(clazz, field -> {
151154
Class<?> type = field.getType();
152-
if (Iterable.class.isAssignableFrom(type) || List.class.isAssignableFrom(type) || Optional.class.isAssignableFrom(type)) {
155+
if (Iterable.class.isAssignableFrom(type) || Optional.class.isAssignableFrom(type)) {
153156
ResolvableType resolvableType = ResolvableType.forField(field);
154157
Class<?> genericType = resolvableType.getGeneric(0).toClass();
155158
if (shouldProcess(genericType)) {
156159
validatedClasses.add(clazz);
157-
processAheadOfTime(genericType, validatedClasses, constraintValidatorClasses);
160+
processAheadOfTime(genericType, visitedClasses, validatedClasses, constraintValidatorClasses);
158161
}
159162
}
160163
if (Map.class.isAssignableFrom(type)) {
@@ -163,11 +166,11 @@ else if (ex instanceof TypeNotPresentException) {
163166
Class<?> valueGenericType = resolvableType.getGeneric(1).toClass();
164167
if (shouldProcess(keyGenericType)) {
165168
validatedClasses.add(clazz);
166-
processAheadOfTime(keyGenericType, validatedClasses, constraintValidatorClasses);
169+
processAheadOfTime(keyGenericType, visitedClasses, validatedClasses, constraintValidatorClasses);
167170
}
168171
if (shouldProcess(valueGenericType)) {
169172
validatedClasses.add(clazz);
170-
processAheadOfTime(valueGenericType, validatedClasses, constraintValidatorClasses);
173+
processAheadOfTime(valueGenericType, visitedClasses, validatedClasses, constraintValidatorClasses);
171174
}
172175
}
173176
});

spring-context/src/test/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessorTests.java

+26
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
import java.lang.annotation.Target;
2323
import java.util.ArrayList;
2424
import java.util.List;
25+
import java.util.Map;
26+
import java.util.Optional;
27+
import java.util.Set;
2528

2629
import jakarta.validation.Constraint;
2730
import jakarta.validation.ConstraintValidator;
@@ -31,6 +34,8 @@
3134
import jakarta.validation.constraints.Pattern;
3235
import org.hibernate.validator.internal.constraintvalidators.bv.PatternValidator;
3336
import org.junit.jupiter.api.Test;
37+
import org.junit.jupiter.params.ParameterizedTest;
38+
import org.junit.jupiter.params.provider.ValueSource;
3439

3540
import org.springframework.aot.generate.GenerationContext;
3641
import org.springframework.aot.hint.MemberCategory;
@@ -121,6 +126,15 @@ void shouldProcessTransitiveGenericTypeLevelConstraint() {
121126
.withMemberCategory(MemberCategory.INVOKE_DECLARED_CONSTRUCTORS)).accepts(this.generationContext.getRuntimeHints());
122127
}
123128

129+
@ParameterizedTest // gh-33936
130+
@ValueSource(classes = {BeanWithIterable.class, BeanWithMap.class, BeanWithOptional.class})
131+
void shouldProcessRecursiveGenericsWithoutInfiniteRecursion(Class<?> beanClass) {
132+
process(beanClass);
133+
assertThat(this.generationContext.getRuntimeHints().reflection().typeHints()).hasSize(1);
134+
assertThat(RuntimeHintsPredicates.reflection().onType(beanClass)
135+
.withMemberCategory(MemberCategory.DECLARED_FIELDS)).accepts(this.generationContext.getRuntimeHints());
136+
}
137+
124138
private void process(Class<?> beanClass) {
125139
BeanRegistrationAotContribution contribution = createContribution(beanClass);
126140
if (contribution != null) {
@@ -244,4 +258,16 @@ public void setExclude(List<Exclude> exclude) {
244258
}
245259
}
246260

261+
static class BeanWithIterable {
262+
private final Iterable<BeanWithIterable> beans = Set.of();
263+
}
264+
265+
static class BeanWithMap {
266+
private final Map<String, BeanWithMap> beans = Map.of();
267+
}
268+
269+
static class BeanWithOptional {
270+
private final Optional<BeanWithOptional> beans = Optional.empty();
271+
}
272+
247273
}

0 commit comments

Comments
 (0)