Skip to content

Commit a7ebf63

Browse files
committed
Support @⁠MockitoBean at the type level
See spring-projectsgh-33925
1 parent 07455b1 commit a7ebf63

32 files changed

+1024
-116
lines changed

spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideBeanFactoryPostProcessor.java

+39-31
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2024 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -104,11 +104,10 @@ private void registerBeanOverride(ConfigurableListableBeanFactory beanFactory, B
104104
Set<String> generatedBeanNames) {
105105

106106
String beanName = handler.getBeanName();
107-
Field field = handler.getField();
108-
Assert.state(!BeanFactoryUtils.isFactoryDereference(beanName),() -> """
109-
Unable to override bean '%s' for field '%s.%s': a FactoryBean cannot be overridden. \
110-
To override the bean created by the FactoryBean, remove the '&' prefix.""".formatted(
111-
beanName, field.getDeclaringClass().getSimpleName(), field.getName()));
107+
Assert.state(!BeanFactoryUtils.isFactoryDereference(beanName), () -> """
108+
Unable to override bean '%s'%s: a FactoryBean cannot be overridden. \
109+
To override the bean created by the FactoryBean, remove the '&' prefix."""
110+
.formatted(beanName, forField(handler.getField())));
112111

113112
switch (handler.getStrategy()) {
114113
case REPLACE -> replaceOrCreateBean(beanFactory, handler, generatedBeanNames, true);
@@ -134,7 +133,6 @@ private void replaceOrCreateBean(ConfigurableListableBeanFactory beanFactory, Be
134133
// 4) Create bean by-name, with a provided name
135134

136135
String beanName = handler.getBeanName();
137-
Field field = handler.getField();
138136
BeanDefinition existingBeanDefinition = null;
139137
if (beanName == null) {
140138
beanName = getBeanNameForType(beanFactory, handler, requireExistingBean);
@@ -169,11 +167,10 @@ private void replaceOrCreateBean(ConfigurableListableBeanFactory beanFactory, Be
169167
existingBeanDefinition = beanFactory.getBeanDefinition(beanName);
170168
}
171169
else if (requireExistingBean) {
172-
throw new IllegalStateException("""
173-
Unable to replace bean: there is no bean with name '%s' and type %s \
174-
(as required by field '%s.%s')."""
175-
.formatted(beanName, handler.getBeanType(),
176-
field.getDeclaringClass().getSimpleName(), field.getName()));
170+
Field field = handler.getField();
171+
throw new IllegalStateException(
172+
"Unable to replace bean: there is no bean with name '%s' and type %s%s."
173+
.formatted(beanName, handler.getBeanType(), requiredByField(field)));
177174
}
178175
// 4) We are creating a bean by-name with the provided beanName.
179176
}
@@ -264,13 +261,11 @@ private void wrapBean(ConfigurableListableBeanFactory beanFactory, BeanOverrideH
264261
else {
265262
String message = "Unable to select a bean to wrap: ";
266263
if (candidateCount == 0) {
267-
message += "there are no beans of type %s (as required by field '%s.%s')."
268-
.formatted(beanType, field.getDeclaringClass().getSimpleName(), field.getName());
264+
message += "there are no beans of type %s%s.".formatted(beanType, requiredByField(field));
269265
}
270266
else {
271-
message += "found %d beans of type %s (as required by field '%s.%s'): %s"
272-
.formatted(candidateCount, beanType, field.getDeclaringClass().getSimpleName(),
273-
field.getName(), candidateNames);
267+
message += "found %d beans of type %s%s: %s"
268+
.formatted(candidateCount, beanType, requiredByField(field), candidateNames);
274269
}
275270
throw new IllegalStateException(message);
276271
}
@@ -281,11 +276,9 @@ private void wrapBean(ConfigurableListableBeanFactory beanFactory, BeanOverrideH
281276
// We are wrapping an existing bean by-name.
282277
Set<String> candidates = getExistingBeanNamesByType(beanFactory, handler, false);
283278
if (!candidates.contains(beanName)) {
284-
throw new IllegalStateException("""
285-
Unable to wrap bean: there is no bean with name '%s' and type %s \
286-
(as required by field '%s.%s')."""
287-
.formatted(beanName, beanType, field.getDeclaringClass().getSimpleName(),
288-
field.getName()));
279+
throw new IllegalStateException(
280+
"Unable to wrap bean: there is no bean with name '%s' and type %s%s."
281+
.formatted(beanName, beanType, requiredByField(field)));
289282
}
290283
}
291284

@@ -308,8 +301,8 @@ private String getBeanNameForType(ConfigurableListableBeanFactory beanFactory, B
308301
else if (candidateCount == 0) {
309302
if (requireExistingBean) {
310303
throw new IllegalStateException(
311-
"Unable to override bean: there are no beans of type %s (as required by field '%s.%s')."
312-
.formatted(beanType, field.getDeclaringClass().getSimpleName(), field.getName()));
304+
"Unable to override bean: there are no beans of type %s%s."
305+
.formatted(beanType, requiredByField(field)));
313306
}
314307
return null;
315308
}
@@ -320,14 +313,14 @@ else if (candidateCount == 0) {
320313
}
321314

322315
throw new IllegalStateException(
323-
"Unable to select a bean to override: found %d beans of type %s (as required by field '%s.%s'): %s"
324-
.formatted(candidateCount, beanType, field.getDeclaringClass().getSimpleName(),
325-
field.getName(), candidateNames));
316+
"Unable to select a bean to override: found %d beans of type %s%s: %s"
317+
.formatted(candidateCount, beanType, requiredByField(field), candidateNames));
326318
}
327319

328320
private Set<String> getExistingBeanNamesByType(ConfigurableListableBeanFactory beanFactory, BeanOverrideHandler handler,
329321
boolean checkAutowiredCandidate) {
330322

323+
Field field = handler.getField();
331324
ResolvableType resolvableType = handler.getBeanType();
332325
Class<?> type = resolvableType.toClass();
333326

@@ -345,16 +338,16 @@ private Set<String> getExistingBeanNamesByType(ConfigurableListableBeanFactory b
345338
}
346339

347340
// Filter out non-matching autowire candidates.
348-
if (checkAutowiredCandidate) {
349-
DependencyDescriptor descriptor = new DependencyDescriptor(handler.getField(), true);
341+
if (field != null && checkAutowiredCandidate) {
342+
DependencyDescriptor descriptor = new DependencyDescriptor(field, true);
350343
beanNames.removeIf(beanName -> !beanFactory.isAutowireCandidate(beanName, descriptor));
351344
}
352345
// Filter out scoped proxy targets.
353346
beanNames.removeIf(ScopedProxyUtils::isScopedTarget);
354347

355348
// In case of multiple matches, fall back on the field's name as a last resort.
356-
if (beanNames.size() > 1) {
357-
String fieldName = handler.getField().getName();
349+
if (field != null && beanNames.size() > 1) {
350+
String fieldName = field.getName();
358351
if (beanNames.contains(fieldName)) {
359352
return Set.of(fieldName);
360353
}
@@ -452,4 +445,19 @@ private static void destroySingleton(ConfigurableListableBeanFactory beanFactory
452445
dlbf.destroySingleton(beanName);
453446
}
454447

448+
private static String forField(@Nullable Field field) {
449+
if (field == null) {
450+
return "";
451+
}
452+
return " for field '%s.%s'".formatted(field.getDeclaringClass().getSimpleName(), field.getName());
453+
}
454+
455+
private static String requiredByField(@Nullable Field field) {
456+
if (field == null) {
457+
return "";
458+
}
459+
return " (as required by field '%s.%s')".formatted(
460+
field.getDeclaringClass().getSimpleName(), field.getName());
461+
}
462+
455463
}

spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideContextCustomizerFactory.java

+2-5
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,8 @@ public BeanOverrideContextCustomizer createContextCustomizer(Class<?> testClass,
5252
}
5353

5454
private void findBeanOverrideHandlers(Class<?> testClass, Set<BeanOverrideHandler> handlers) {
55-
if (TestContextAnnotationUtils.searchEnclosingClass(testClass)) {
56-
findBeanOverrideHandlers(testClass.getEnclosingClass(), handlers);
57-
}
58-
BeanOverrideHandler.forTestClass(testClass).forEach(handler ->
59-
Assert.state(handlers.add(handler), () ->
55+
BeanOverrideHandler.forTestClass(testClass, TestContextAnnotationUtils::searchEnclosingClass)
56+
.forEach(handler -> Assert.state(handlers.add(handler), () ->
6057
"Duplicate BeanOverrideHandler discovered in test class %s: %s"
6158
.formatted(testClass.getName(), handler)));
6259
}

spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideHandler.java

+107-30
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,20 @@
1717
package org.springframework.test.context.bean.override;
1818

1919
import java.lang.annotation.Annotation;
20+
import java.lang.reflect.AnnotatedElement;
2021
import java.lang.reflect.Field;
2122
import java.lang.reflect.Modifier;
23+
import java.util.ArrayList;
2224
import java.util.Arrays;
2325
import java.util.Collections;
26+
import java.util.Comparator;
2427
import java.util.HashSet;
25-
import java.util.LinkedList;
2628
import java.util.List;
2729
import java.util.Objects;
2830
import java.util.Set;
2931
import java.util.concurrent.atomic.AtomicBoolean;
32+
import java.util.function.BiConsumer;
33+
import java.util.function.Predicate;
3034

3135
import org.springframework.beans.BeanUtils;
3236
import org.springframework.beans.factory.config.BeanDefinition;
@@ -57,8 +61,8 @@
5761
*
5862
* <p>Concrete implementations of {@code BeanOverrideHandler} can store additional
5963
* metadata to use during override {@linkplain #createOverrideInstance instance
60-
* creation} &mdash; for example, based on further processing of the annotation
61-
* or the annotated field.
64+
* creation} &mdash; for example, based on further processing of the annotation,
65+
* the annotated field, or the annotated class.
6266
*
6367
* <p><strong>NOTE</strong>: Only <em>singleton</em> beans can be overridden.
6468
* Any attempt to override a non-singleton bean will result in an exception.
@@ -70,6 +74,7 @@
7074
*/
7175
public abstract class BeanOverrideHandler {
7276

77+
@Nullable
7378
private final Field field;
7479

7580
private final Set<Annotation> fieldAnnotations;
@@ -82,7 +87,7 @@ public abstract class BeanOverrideHandler {
8287
private final BeanOverrideStrategy strategy;
8388

8489

85-
protected BeanOverrideHandler(Field field, ResolvableType beanType, @Nullable String beanName,
90+
protected BeanOverrideHandler(@Nullable Field field, ResolvableType beanType, @Nullable String beanName,
8691
BeanOverrideStrategy strategy) {
8792

8893
this.field = field;
@@ -96,57 +101,121 @@ protected BeanOverrideHandler(Field field, ResolvableType beanType, @Nullable St
96101
* Process the given {@code testClass} and build the corresponding
97102
* {@code BeanOverrideHandler} list derived from {@link BeanOverride @BeanOverride}
98103
* fields in the test class and its type hierarchy.
99-
* <p>This method does not search the enclosing class hierarchy.
104+
* <p>This method does not search the enclosing class hierarchy and does not
105+
* search for {@code @BeanOverride} declarations on classes or interfaces.
100106
* @param testClass the test class to process
101107
* @return a list of bean override handlers
102-
* @see org.springframework.test.context.TestContextAnnotationUtils#searchEnclosingClass(Class)
108+
* @see #forTestClass(Class, Predicate)
103109
*/
104110
public static List<BeanOverrideHandler> forTestClass(Class<?> testClass) {
105-
List<BeanOverrideHandler> handlers = new LinkedList<>();
106-
findHandlers(testClass, testClass, handlers);
111+
return findHandlers(testClass, false, clazz -> false);
112+
}
113+
114+
/**
115+
* Process the given {@code testClass} and build the corresponding
116+
* {@code BeanOverrideHandler} list derived from {@link BeanOverride @BeanOverride}
117+
* fields in the test class and in its type hierarchy as well as from
118+
* {@code @BeanOverride} declarations on classes and interfaces.
119+
* <p>This method additionally searches for {@code @BeanOverride} declarations
120+
* in the enclosing class hierarchy if the supplied predicate evaluates to
121+
* {@code true}.
122+
* @param testClass the test class to process
123+
* @param searchEnclosingClass a predicate which evaluates to {@code true}
124+
* if a search should be performed on the enclosing class &mdash; for example,
125+
* {@code TestContextAnnotationUtils::searchEnclosingClass}
126+
* @return a list of bean override handlers
127+
* @since 6.2.2
128+
* @see org.springframework.test.context.TestContextAnnotationUtils#searchEnclosingClass(Class)
129+
*/
130+
public static List<BeanOverrideHandler> forTestClass(Class<?> testClass, Predicate<Class<?>> searchEnclosingClass) {
131+
return findHandlers(testClass, true, searchEnclosingClass);
132+
}
133+
134+
private static List<BeanOverrideHandler> findHandlers(Class<?> testClass, boolean searchOnTypes,
135+
Predicate<Class<?>> searchEnclosingClass) {
136+
137+
List<BeanOverrideHandler> handlers = new ArrayList<>();
138+
findHandlers(testClass, testClass, handlers, searchOnTypes, searchEnclosingClass);
107139
return handlers;
108140
}
109141

110142
/**
111-
* Find handlers using tail recursion to ensure that "locally declared"
112-
* bean overrides take precedence over inherited bean overrides.
143+
* Find handlers using tail recursion to ensure that "locally declared" bean overrides
144+
* take precedence over inherited bean overrides.
145+
* <p>Note: the search algorithm is effectively the inverse of the algorithm used in
146+
* {@link org.springframework.test.context.TestContextAnnotationUtils#findAnnotationDescriptor(Class, Class)},
147+
* but with tail recursion the semantics should be the same.
113148
* @since 6.2.2
114149
*/
115-
private static void findHandlers(Class<?> clazz, Class<?> testClass, List<BeanOverrideHandler> handlers) {
116-
if (clazz == null || clazz == Object.class) {
117-
return;
150+
private static void findHandlers(Class<?> clazz, Class<?> testClass, List<BeanOverrideHandler> handlers,
151+
boolean searchOnTypes, Predicate<Class<?>> searchEnclosingClass) {
152+
153+
// 1) Search enclosing class hierarchy.
154+
if (searchEnclosingClass.test(clazz)) {
155+
findHandlers(clazz.getEnclosingClass(), testClass, handlers, searchOnTypes, searchEnclosingClass);
118156
}
119157

120-
// 1) Search type hierarchy.
121-
findHandlers(clazz.getSuperclass(), testClass, handlers);
158+
// 2) Search class hierarchy.
159+
Class<?> superclass = clazz.getSuperclass();
160+
if (superclass != null && superclass != Object.class) {
161+
findHandlers(superclass, testClass, handlers, searchOnTypes, searchEnclosingClass);
162+
}
163+
164+
if (searchOnTypes) {
165+
// 3) Search interfaces.
166+
for (Class<?> ifc : clazz.getInterfaces()) {
167+
findHandlers(ifc, testClass, handlers, searchOnTypes, searchEnclosingClass);
168+
}
169+
170+
// 4) Process current class.
171+
processClass(clazz, testClass, handlers);
172+
}
122173

123-
// 2) Process fields in current class.
174+
// 5) Process fields in current class.
124175
ReflectionUtils.doWithLocalFields(clazz, field -> processField(field, testClass, handlers));
125176
}
126177

178+
private static void processClass(Class<?> clazz, Class<?> testClass, List<BeanOverrideHandler> handlers) {
179+
processElement(clazz, testClass, (processor, composedAnnotation) ->
180+
processor.createHandlers(composedAnnotation, testClass).forEach(handlers::add));
181+
}
182+
127183
private static void processField(Field field, Class<?> testClass, List<BeanOverrideHandler> handlers) {
128184
AtomicBoolean overrideAnnotationFound = new AtomicBoolean();
129-
MergedAnnotations.from(field, DIRECT).stream(BeanOverride.class).forEach(mergedAnnotation -> {
185+
processElement(field, testClass, (processor, composedAnnotation) -> {
130186
Assert.state(!Modifier.isStatic(field.getModifiers()),
131187
() -> "@BeanOverride field must not be static: " + field);
132-
MergedAnnotation<?> metaSource = mergedAnnotation.getMetaSource();
133-
Assert.state(metaSource != null, "@BeanOverride annotation must be meta-present");
134-
135-
BeanOverride beanOverride = mergedAnnotation.synthesize();
136-
BeanOverrideProcessor processor = BeanUtils.instantiateClass(beanOverride.value());
137-
Annotation composedAnnotation = metaSource.synthesize();
138-
139188
Assert.state(overrideAnnotationFound.compareAndSet(false, true),
140189
() -> "Multiple @BeanOverride annotations found on field: " + field);
141-
BeanOverrideHandler handler = processor.createHandler(composedAnnotation, testClass, field);
142-
handlers.add(handler);
190+
handlers.add(processor.createHandler(composedAnnotation, testClass, field));
143191
});
144192
}
145193

194+
private static void processElement(AnnotatedElement element, Class<?> testClass,
195+
BiConsumer<BeanOverrideProcessor, Annotation> consumer) {
196+
197+
MergedAnnotations.from(element, DIRECT)
198+
.stream(BeanOverride.class)
199+
.sorted(reversedMetaDistance)
200+
.forEach(mergedAnnotation -> {
201+
MergedAnnotation<?> metaSource = mergedAnnotation.getMetaSource();
202+
Assert.state(metaSource != null, "@BeanOverride annotation must be meta-present");
203+
204+
BeanOverride beanOverride = mergedAnnotation.synthesize();
205+
BeanOverrideProcessor processor = BeanUtils.instantiateClass(beanOverride.value());
206+
Annotation composedAnnotation = metaSource.synthesize();
207+
consumer.accept(processor, composedAnnotation);
208+
});
209+
}
210+
211+
private static final Comparator<MergedAnnotation<? extends Annotation>> reversedMetaDistance =
212+
Comparator.<MergedAnnotation<? extends Annotation>> comparingInt(MergedAnnotation::getDistance).reversed();
213+
146214

147215
/**
148216
* Get the annotated {@link Field}.
149217
*/
218+
@Nullable
150219
public final Field getField() {
151220
return this.field;
152221
}
@@ -243,20 +312,25 @@ public boolean equals(Object other) {
243312
!Objects.equals(this.strategy, that.strategy)) {
244313
return false;
245314
}
315+
316+
// by-name lookup
246317
if (this.beanName != null) {
247318
return true;
248319
}
249320

321+
// TODO Validate equals() logic for null fields.
322+
250323
// by-type lookup
251-
return (Objects.equals(this.field.getName(), that.field.getName()) &&
324+
return (this.field != null && that.field != null &&
325+
Objects.equals(this.field.getName(), that.field.getName()) &&
252326
this.fieldAnnotations.equals(that.fieldAnnotations));
253327
}
254328

255329
@Override
256330
public int hashCode() {
257331
int hash = Objects.hash(getClass(), this.beanType.getType(), this.beanName, this.strategy);
258-
return (this.beanName != null ? hash : hash +
259-
Objects.hash(this.field.getName(), this.fieldAnnotations));
332+
return (this.beanName != null ? hash :
333+
hash + Objects.hash((this.field != null ? this.field.getName() : null), this.fieldAnnotations));
260334
}
261335

262336
@Override
@@ -269,7 +343,10 @@ public String toString() {
269343
.toString();
270344
}
271345

272-
private static Set<Annotation> annotationSet(Field field) {
346+
private static Set<Annotation> annotationSet(@Nullable Field field) {
347+
if (field == null) {
348+
return Collections.emptySet();
349+
}
273350
Annotation[] annotations = field.getAnnotations();
274351
return (annotations.length != 0 ? new HashSet<>(Arrays.asList(annotations)) : Collections.emptySet());
275352
}

0 commit comments

Comments
 (0)