Skip to content

Commit dd3a5a1

Browse files
committed
Support @⁠MockitoBean at the type level
Proof of Concept See spring-projectsgh-33925
1 parent 58a836d commit dd3a5a1

13 files changed

+282
-66
lines changed

spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideBeanFactoryPostProcessor.java

+5-4
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ else if (candidateCount == 0) {
328328
private Set<String> getExistingBeanNamesByType(ConfigurableListableBeanFactory beanFactory, BeanOverrideHandler handler,
329329
boolean checkAutowiredCandidate) {
330330

331+
Field field = handler.getField();
331332
ResolvableType resolvableType = handler.getBeanType();
332333
Class<?> type = resolvableType.toClass();
333334

@@ -345,16 +346,16 @@ private Set<String> getExistingBeanNamesByType(ConfigurableListableBeanFactory b
345346
}
346347

347348
// Filter out non-matching autowire candidates.
348-
if (checkAutowiredCandidate) {
349-
DependencyDescriptor descriptor = new DependencyDescriptor(handler.getField(), true);
349+
if (field != null && checkAutowiredCandidate) {
350+
DependencyDescriptor descriptor = new DependencyDescriptor(field, true);
350351
beanNames.removeIf(beanName -> !beanFactory.isAutowireCandidate(beanName, descriptor));
351352
}
352353
// Filter out scoped proxy targets.
353354
beanNames.removeIf(ScopedProxyUtils::isScopedTarget);
354355

355356
// In case of multiple matches, fall back on the field's name as a last resort.
356-
if (beanNames.size() > 1) {
357-
String fieldName = handler.getField().getName();
357+
if (field != null && beanNames.size() > 1) {
358+
String fieldName = field.getName();
358359
if (beanNames.contains(fieldName)) {
359360
return Set.of(fieldName);
360361
}

spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideContextCustomizerFactory.java

+2-5
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,8 @@ public BeanOverrideContextCustomizer createContextCustomizer(Class<?> testClass,
5252
}
5353

5454
private void findBeanOverrideHandlers(Class<?> testClass, Set<BeanOverrideHandler> handlers) {
55-
if (TestContextAnnotationUtils.searchEnclosingClass(testClass)) {
56-
findBeanOverrideHandlers(testClass.getEnclosingClass(), handlers);
57-
}
58-
BeanOverrideHandler.forTestClass(testClass).forEach(handler ->
59-
Assert.state(handlers.add(handler), () ->
55+
BeanOverrideHandler.forTestClass(testClass, TestContextAnnotationUtils::searchEnclosingClass)
56+
.forEach(handler -> Assert.state(handlers.add(handler), () ->
6057
"Duplicate BeanOverrideHandler discovered in test class %s: %s"
6158
.formatted(testClass.getName(), handler)));
6259
}

spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideHandler.java

+64-12
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.springframework.test.context.bean.override;
1818

1919
import java.lang.annotation.Annotation;
20+
import java.lang.reflect.AnnotatedElement;
2021
import java.lang.reflect.Field;
2122
import java.lang.reflect.Modifier;
2223
import java.util.Arrays;
@@ -27,6 +28,8 @@
2728
import java.util.Objects;
2829
import java.util.Set;
2930
import java.util.concurrent.atomic.AtomicBoolean;
31+
import java.util.function.BiConsumer;
32+
import java.util.function.Predicate;
3033

3134
import org.springframework.beans.BeanUtils;
3235
import org.springframework.beans.factory.config.BeanDefinition;
@@ -70,6 +73,7 @@
7073
*/
7174
public abstract class BeanOverrideHandler {
7275

76+
@Nullable
7377
private final Field field;
7478

7579
private final Set<Annotation> fieldAnnotations;
@@ -82,7 +86,7 @@ public abstract class BeanOverrideHandler {
8286
private final BeanOverrideStrategy strategy;
8387

8488

85-
protected BeanOverrideHandler(Field field, ResolvableType beanType, @Nullable String beanName,
89+
protected BeanOverrideHandler(@Nullable Field field, ResolvableType beanType, @Nullable String beanName,
8690
BeanOverrideStrategy strategy) {
8791

8892
this.field = field;
@@ -102,34 +106,78 @@ protected BeanOverrideHandler(Field field, ResolvableType beanType, @Nullable St
102106
* @see org.springframework.test.context.TestContextAnnotationUtils#searchEnclosingClass(Class)
103107
*/
104108
public static List<BeanOverrideHandler> forTestClass(Class<?> testClass) {
109+
return forTestClass(testClass, clazz -> false);
110+
}
111+
112+
public static List<BeanOverrideHandler> forTestClass(Class<?> testClass, Predicate<Class<?>> searchEnclosingClass) {
105113
List<BeanOverrideHandler> handlers = new LinkedList<>();
106-
ReflectionUtils.doWithFields(testClass, field -> processField(field, testClass, handlers));
114+
findHandlers(testClass, testClass, handlers, searchEnclosingClass);
107115
return handlers;
108116
}
109117

118+
private static void findHandlers(Class<?> clazz, Class<?> testClass, List<BeanOverrideHandler> handlers,
119+
Predicate<Class<?>> searchEnclosingClass) {
120+
121+
if (clazz == null || Object.class == clazz) {
122+
return;
123+
}
124+
125+
// 1) Search enclosing class hierarchy.
126+
if (searchEnclosingClass.test(clazz)) {
127+
findHandlers(clazz.getEnclosingClass(), testClass, handlers, searchEnclosingClass);
128+
}
129+
130+
// 2) Search type hierarchy.
131+
findHandlers(clazz.getSuperclass(), testClass, handlers, searchEnclosingClass);
132+
133+
// 3) Search interfaces.
134+
for (Class<?> ifc : clazz.getInterfaces()) {
135+
findHandlers(ifc, testClass, handlers, searchEnclosingClass);
136+
}
137+
138+
// 4) Process current class.
139+
processClass(clazz, testClass, handlers);
140+
141+
// 5) Process fields in current class.
142+
ReflectionUtils.doWithLocalFields(clazz, field -> processField(field, testClass, handlers));
143+
}
144+
145+
private static void processClass(Class<?> clazz, Class<?> testClass, List<BeanOverrideHandler> handlers) {
146+
processElement(clazz, testClass, (processor, composedAnnotation) -> {
147+
processor.createHandlers(composedAnnotation, testClass).forEach(handlers::add);
148+
});
149+
}
150+
110151
private static void processField(Field field, Class<?> testClass, List<BeanOverrideHandler> handlers) {
111152
AtomicBoolean overrideAnnotationFound = new AtomicBoolean();
112-
MergedAnnotations.from(field, DIRECT).stream(BeanOverride.class).forEach(mergedAnnotation -> {
153+
processElement(field, testClass, (processor, composedAnnotation) -> {
113154
Assert.state(!Modifier.isStatic(field.getModifiers()),
114155
() -> "@BeanOverride field must not be static: " + field);
156+
Assert.state(overrideAnnotationFound.compareAndSet(false, true),
157+
() -> "Multiple @BeanOverride annotations found on field: " + field);
158+
handlers.add(processor.createHandler(composedAnnotation, testClass, field));
159+
});
160+
}
161+
162+
private static void processElement(AnnotatedElement element, Class<?> testClass,
163+
BiConsumer<BeanOverrideProcessor, Annotation> consumer) {
164+
165+
MergedAnnotations.from(element, DIRECT).stream(BeanOverride.class).forEach(mergedAnnotation -> {
115166
MergedAnnotation<?> metaSource = mergedAnnotation.getMetaSource();
116167
Assert.state(metaSource != null, "@BeanOverride annotation must be meta-present");
117168

118169
BeanOverride beanOverride = mergedAnnotation.synthesize();
119170
BeanOverrideProcessor processor = BeanUtils.instantiateClass(beanOverride.value());
120171
Annotation composedAnnotation = metaSource.synthesize();
121-
122-
Assert.state(overrideAnnotationFound.compareAndSet(false, true),
123-
() -> "Multiple @BeanOverride annotations found on field: " + field);
124-
BeanOverrideHandler handler = processor.createHandler(composedAnnotation, testClass, field);
125-
handlers.add(handler);
172+
consumer.accept(processor, composedAnnotation);
126173
});
127174
}
128175

129176

130177
/**
131178
* Get the annotated {@link Field}.
132179
*/
180+
@Nullable
133181
public final Field getField() {
134182
return this.field;
135183
}
@@ -231,15 +279,16 @@ public boolean equals(Object other) {
231279
}
232280

233281
// by-type lookup
234-
return (Objects.equals(this.field.getName(), that.field.getName()) &&
282+
return (this.field != null && that.field != null &&
283+
Objects.equals(this.field.getName(), that.field.getName()) &&
235284
this.fieldAnnotations.equals(that.fieldAnnotations));
236285
}
237286

238287
@Override
239288
public int hashCode() {
240289
int hash = Objects.hash(getClass(), this.beanType.getType(), this.beanName, this.strategy);
241-
return (this.beanName != null ? hash : hash +
242-
Objects.hash(this.field.getName(), this.fieldAnnotations));
290+
return (this.beanName != null ? hash :
291+
hash + Objects.hash((this.field != null ? this.field.getName() : ""), this.fieldAnnotations));
243292
}
244293

245294
@Override
@@ -252,7 +301,10 @@ public String toString() {
252301
.toString();
253302
}
254303

255-
private static Set<Annotation> annotationSet(Field field) {
304+
private static Set<Annotation> annotationSet(@Nullable Field field) {
305+
if (field == null) {
306+
return Collections.emptySet();
307+
}
256308
Annotation[] annotations = field.getAnnotations();
257309
return (annotations.length != 0 ? new HashSet<>(Arrays.asList(annotations)) : Collections.emptySet());
258310
}

spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideProcessor.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.lang.annotation.Annotation;
2020
import java.lang.reflect.Field;
21+
import java.util.List;
2122

2223
/**
2324
* Strategy interface for Bean Override processing, which creates a
@@ -36,7 +37,6 @@
3637
* @author Sam Brannen
3738
* @since 6.2
3839
*/
39-
@FunctionalInterface
4040
public interface BeanOverrideProcessor {
4141

4242
/**
@@ -49,4 +49,8 @@ public interface BeanOverrideProcessor {
4949
*/
5050
BeanOverrideHandler createHandler(Annotation overrideAnnotation, Class<?> testClass, Field field);
5151

52+
default List<BeanOverrideHandler> createHandlers(Annotation overrideAnnotation, Class<?> testClass) {
53+
return List.of();
54+
}
55+
5256
}

spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideTestExecutionListener.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,9 @@ private static void injectFields(TestContext testContext) {
8888
.getBean(BeanOverrideContextCustomizer.REGISTRY_BEAN_NAME, BeanOverrideRegistry.class);
8989

9090
for (BeanOverrideHandler handler : handlers) {
91-
beanOverrideRegistry.inject(testInstance, handler);
91+
if (handler.getField() != null) {
92+
beanOverrideRegistry.inject(testInstance, handler);
93+
}
9294
}
9395
}
9496
}

spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/AbstractMockitoBeanOverrideHandler.java

+10-10
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ abstract class AbstractMockitoBeanOverrideHandler extends BeanOverrideHandler {
3838
private final MockReset reset;
3939

4040

41-
protected AbstractMockitoBeanOverrideHandler(Field field, ResolvableType beanType,
41+
protected AbstractMockitoBeanOverrideHandler(@Nullable Field field, ResolvableType beanType,
4242
@Nullable String beanName, BeanOverrideStrategy strategy, @Nullable MockReset reset) {
4343

4444
super(field, beanType, beanName, strategy);
@@ -56,20 +56,20 @@ MockReset getReset() {
5656

5757
@Override
5858
protected void trackOverrideInstance(Object mock, SingletonBeanRegistry trackingBeanRegistry) {
59-
getMockitoBeans(trackingBeanRegistry).add(mock);
59+
getMockBeans(trackingBeanRegistry).add(mock);
6060
}
6161

62-
private static MockitoBeans getMockitoBeans(SingletonBeanRegistry trackingBeanRegistry) {
63-
String beanName = MockitoBeans.class.getName();
64-
MockitoBeans mockitoBeans = null;
62+
private static MockBeans getMockBeans(SingletonBeanRegistry trackingBeanRegistry) {
63+
String beanName = MockBeans.class.getName();
64+
MockBeans mockBeans = null;
6565
if (trackingBeanRegistry.containsSingleton(beanName)) {
66-
mockitoBeans = (MockitoBeans) trackingBeanRegistry.getSingleton(beanName);
66+
mockBeans = (MockBeans) trackingBeanRegistry.getSingleton(beanName);
6767
}
68-
if (mockitoBeans == null) {
69-
mockitoBeans = new MockitoBeans();
70-
trackingBeanRegistry.registerSingleton(beanName, mockitoBeans);
68+
if (mockBeans == null) {
69+
mockBeans = new MockBeans();
70+
trackingBeanRegistry.registerSingleton(beanName, mockBeans);
7171
}
72-
return mockitoBeans;
72+
return mockBeans;
7373
}
7474

7575
@Override
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Copyright 2002-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.test.context.bean.override.mockito;
18+
19+
import java.util.ArrayList;
20+
import java.util.List;
21+
22+
import org.mockito.Mockito;
23+
24+
/**
25+
* Beans created using Mockito.
26+
*
27+
* @author Andy Wilkinson
28+
* @author Sam Brannen
29+
* @since 6.2
30+
*/
31+
class MockBeans {
32+
33+
private final List<Object> beans = new ArrayList<>();
34+
35+
36+
void add(Object bean) {
37+
this.beans.add(bean);
38+
}
39+
40+
/**
41+
* Reset all Mockito beans configured with the supplied {@link MockReset} strategy.
42+
* <p>No mocks will be reset if the supplied strategy is {@link MockReset#NONE}.
43+
*/
44+
void resetAll(MockReset reset) {
45+
if (reset != MockReset.NONE) {
46+
for (Object bean : this.beans) {
47+
if (reset == MockReset.get(bean)) {
48+
Mockito.reset(bean);
49+
}
50+
}
51+
}
52+
}
53+
54+
}

spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoBean.java

+15-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.lang.annotation.Documented;
2020
import java.lang.annotation.ElementType;
21+
import java.lang.annotation.Repeatable;
2122
import java.lang.annotation.Retention;
2223
import java.lang.annotation.RetentionPolicy;
2324
import java.lang.annotation.Target;
@@ -69,9 +70,10 @@
6970
* @see org.springframework.test.context.bean.override.mockito.MockitoSpyBean @MockitoSpyBean
7071
* @see org.springframework.test.context.bean.override.convention.TestBean @TestBean
7172
*/
72-
@Target(ElementType.FIELD)
73+
@Target({ElementType.FIELD, ElementType.TYPE})
7374
@Retention(RetentionPolicy.RUNTIME)
7475
@Documented
76+
@Repeatable(MockitoBeans.class)
7577
@BeanOverride(MockitoBeanOverrideProcessor.class)
7678
public @interface MockitoBean {
7779

@@ -94,6 +96,18 @@
9496
@AliasFor("value")
9597
String name() default "";
9698

99+
/**
100+
* The types to mock.
101+
* <p>Defaults to none.
102+
* <p>Each type specified here will result in a mock being created and
103+
* registered with the {@code ApplicationContext}.
104+
* <p>Types must be omitted when the annotation is used on a field.
105+
* <p>When {@code @MockitoBean} also defines a {@link #name} this attribute
106+
* can only contain a single value.
107+
* @return the types to mock
108+
*/
109+
Class<?>[] types() default {};
110+
97111
/**
98112
* Extra interfaces that should also be declared by the mock.
99113
* <p>Defaults to none.

spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoBeanOverrideHandler.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class MockitoBeanOverrideHandler extends AbstractMockitoBeanOverrideHandler {
5757
private final boolean serializable;
5858

5959

60-
MockitoBeanOverrideHandler(Field field, ResolvableType typeToMock, MockitoBean mockitoBean) {
60+
MockitoBeanOverrideHandler(@Nullable Field field, ResolvableType typeToMock, MockitoBean mockitoBean) {
6161
this(field, typeToMock, (!mockitoBean.name().isBlank() ? mockitoBean.name() : null),
6262
(mockitoBean.enforceOverride() ? REPLACE : REPLACE_OR_CREATE),
6363
mockitoBean.reset(), mockitoBean.extraInterfaces(), mockitoBean.answers(), mockitoBean.serializable());

0 commit comments

Comments
 (0)