Skip to content

Commit e1c450d

Browse files
committed
Support Bean Overrides with AOT and native image
This set of commits introduces AOT and native image support for the new Bean Override feature in the Spring TestContext Framework -- for example, for using @⁠TestBean and @⁠MockitoBean in AOT mode on the JVM as well as in a GraalVM native image. Note, however, that @⁠MockitoBean has currently only been tested in a native image when mocking interfaces and using Mockito's ProxyMockMaker, along with a custom runtime hints. Closes gh-32933
2 parents f590511 + 65d2191 commit e1c450d

13 files changed

+400
-133
lines changed

Diff for: spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverride.java

+4
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import java.lang.annotation.RetentionPolicy;
2323
import java.lang.annotation.Target;
2424

25+
import org.springframework.aot.hint.annotation.Reflective;
26+
2527
/**
2628
* Mark a composed annotation as eligible for Bean Override processing.
2729
*
@@ -37,11 +39,13 @@
3739
* {@link org.springframework.test.context.bean.override.mockito.MockitoSpyBean @MockitoSpyBean}.
3840
*
3941
* @author Simon Baslé
42+
* @author Sam Brannen
4043
* @since 6.2
4144
*/
4245
@Retention(RetentionPolicy.RUNTIME)
4346
@Target(ElementType.ANNOTATION_TYPE)
4447
@Documented
48+
@Reflective(BeanOverrideReflectiveProcessor.class)
4549
public @interface BeanOverride {
4650

4751
/**

Diff for: spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideBeanFactoryPostProcessor.java

+24
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.springframework.beans.factory.support.DefaultBeanNameGenerator;
3636
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
3737
import org.springframework.beans.factory.support.RootBeanDefinition;
38+
import org.springframework.context.aot.AbstractAotProcessor;
3839
import org.springframework.core.Ordered;
3940
import org.springframework.core.ResolvableType;
4041
import org.springframework.lang.Nullable;
@@ -105,6 +106,12 @@ private void registerBeanOverride(ConfigurableListableBeanFactory beanFactory, O
105106
private void replaceDefinition(ConfigurableListableBeanFactory beanFactory, OverrideMetadata overrideMetadata,
106107
boolean enforceExistingDefinition) {
107108

109+
// NOTE: This method supports 3 distinct scenarios which must be accounted for.
110+
//
111+
// 1) JVM runtime
112+
// 2) AOT processing
113+
// 3) AOT runtime
114+
108115
if (!(beanFactory instanceof BeanDefinitionRegistry registry)) {
109116
throw new IllegalStateException("Cannot process bean override with a BeanFactory " +
110117
"that doesn't implement BeanDefinitionRegistry: " + beanFactory.getClass().getName());
@@ -147,12 +154,24 @@ else if (enforceExistingDefinition) {
147154

148155
if (existingBeanDefinition != null) {
149156
// Validate the existing bean definition.
157+
//
158+
// Applies during "JVM runtime", "AOT processing", and "AOT runtime".
150159
validateBeanDefinition(beanFactory, beanName);
151160
}
161+
else if (Boolean.getBoolean(AbstractAotProcessor.AOT_PROCESSING)) {
162+
// There was no existing bean definition, but during "AOT processing" we
163+
// do not register the "pseudo" bean definition since our AOT support
164+
// cannot automatically convert that to a functional bean definition for
165+
// use at "AOT runtime". Furthermore, by not registering a bean definition
166+
// for a nonexistent bean, we allow the "JVM runtime" and "AOT runtime"
167+
// to operate the same in the following else-block.
168+
}
152169
else {
153170
// There was no existing bean definition, so we register the "pseudo" bean
154171
// definition to ensure that a suitable bean definition exists for the given
155172
// bean name for proper autowiring candidate resolution.
173+
//
174+
// Applies during "JVM runtime" and "AOT runtime".
156175
registry.registerBeanDefinition(beanName, pseudoBeanDefinition);
157176
}
158177

@@ -163,6 +182,11 @@ else if (enforceExistingDefinition) {
163182
// Now we have an instance (the override) that we can register. At this stage, we don't
164183
// expect a singleton instance to be present. If for some reason a singleton instance
165184
// already exists, the following will throw an exception.
185+
//
186+
// As a bonus, by manually registering a singleton during "AOT processing", we allow
187+
// GenericApplicationContext's preDetermineBeanType() method to transparently register
188+
// runtime hints for a proxy generated by the above createOverride() invocation --
189+
// for example, when @MockitoBean creates a mock based on a JDK dynamic proxy.
166190
beanFactory.registerSingleton(beanName, override);
167191
}
168192

Diff for: spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideContextCustomizer.java

+15-37
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,8 @@
1717
package org.springframework.test.context.bean.override;
1818

1919
import java.util.Set;
20-
import java.util.function.Consumer;
2120

22-
import org.springframework.beans.factory.config.BeanDefinition;
23-
import org.springframework.beans.factory.config.ConstructorArgumentValues;
24-
import org.springframework.beans.factory.config.RuntimeBeanReference;
25-
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
26-
import org.springframework.beans.factory.support.RootBeanDefinition;
21+
import org.springframework.beans.factory.config.ConfigurableBeanFactory;
2722
import org.springframework.context.ConfigurableApplicationContext;
2823
import org.springframework.test.context.ContextCustomizer;
2924
import org.springframework.test.context.MergedContextConfiguration;
@@ -34,6 +29,7 @@
3429
*
3530
* @author Simon Baslé
3631
* @author Stephane Nicoll
32+
* @author Sam Brannen
3733
* @since 6.2
3834
*/
3935
class BeanOverrideContextCustomizer implements ContextCustomizer {
@@ -56,43 +52,25 @@ class BeanOverrideContextCustomizer implements ContextCustomizer {
5652

5753
@Override
5854
public void customizeContext(ConfigurableApplicationContext context, MergedContextConfiguration mergedConfig) {
59-
if (!(context instanceof BeanDefinitionRegistry registry)) {
60-
throw new IllegalStateException("Cannot process bean overrides with an ApplicationContext " +
61-
"that doesn't implement BeanDefinitionRegistry: " + context.getClass().getName());
62-
}
63-
registerInfrastructure(registry);
55+
ConfigurableBeanFactory beanFactory = context.getBeanFactory();
56+
// Since all three Bean Override infrastructure beans are never injected as
57+
// dependencies into other beans within the ApplicationContext, it is sufficient
58+
// to register them as manual singleton instances. In addition, registration of
59+
// the BeanOverrideBeanFactoryPostProcessor as a singleton is a requirement for
60+
// AOT processing, since a bean definition cannot be generated for the
61+
// Set<OverrideMetadata> argument that it accepts in its constructor.
62+
BeanOverrideRegistrar beanOverrideRegistrar = new BeanOverrideRegistrar(beanFactory);
63+
beanFactory.registerSingleton(REGISTRAR_BEAN_NAME, beanOverrideRegistrar);
64+
beanFactory.registerSingleton(INFRASTRUCTURE_BEAN_NAME,
65+
new BeanOverrideBeanFactoryPostProcessor(this.metadata, beanOverrideRegistrar));
66+
beanFactory.registerSingleton(EARLY_INFRASTRUCTURE_BEAN_NAME,
67+
new WrapEarlyBeanPostProcessor(beanOverrideRegistrar));
6468
}
6569

6670
Set<OverrideMetadata> getMetadata() {
6771
return this.metadata;
6872
}
6973

70-
private void registerInfrastructure(BeanDefinitionRegistry registry) {
71-
addInfrastructureBeanDefinition(registry, BeanOverrideRegistrar.class, REGISTRAR_BEAN_NAME,
72-
constructorArgs -> {});
73-
74-
RuntimeBeanReference registrarReference = new RuntimeBeanReference(REGISTRAR_BEAN_NAME);
75-
addInfrastructureBeanDefinition(registry, WrapEarlyBeanPostProcessor.class, EARLY_INFRASTRUCTURE_BEAN_NAME,
76-
constructorArgs -> constructorArgs.addIndexedArgumentValue(0, registrarReference));
77-
addInfrastructureBeanDefinition(registry, BeanOverrideBeanFactoryPostProcessor.class, INFRASTRUCTURE_BEAN_NAME,
78-
constructorArgs -> {
79-
constructorArgs.addIndexedArgumentValue(0, this.metadata);
80-
constructorArgs.addIndexedArgumentValue(1, registrarReference);
81-
});
82-
}
83-
84-
private void addInfrastructureBeanDefinition(BeanDefinitionRegistry registry,
85-
Class<?> clazz, String beanName, Consumer<ConstructorArgumentValues> constructorArgumentsConsumer) {
86-
87-
if (!registry.containsBeanDefinition(beanName)) {
88-
RootBeanDefinition definition = new RootBeanDefinition(clazz);
89-
definition.setRole(BeanDefinition.ROLE_INFRASTRUCTURE);
90-
ConstructorArgumentValues constructorArguments = definition.getConstructorArgumentValues();
91-
constructorArgumentsConsumer.accept(constructorArguments);
92-
registry.registerBeanDefinition(beanName, definition);
93-
}
94-
}
95-
9674
@Override
9775
public boolean equals(Object other) {
9876
if (other == this) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Copyright 2002-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.test.context.bean.override;
18+
19+
import java.lang.reflect.AnnotatedElement;
20+
21+
import org.springframework.aot.hint.ReflectionHints;
22+
import org.springframework.aot.hint.annotation.ReflectiveProcessor;
23+
import org.springframework.core.annotation.MergedAnnotation;
24+
import org.springframework.core.annotation.MergedAnnotations;
25+
26+
import static org.springframework.aot.hint.MemberCategory.INVOKE_DECLARED_CONSTRUCTORS;
27+
28+
/**
29+
* {@link ReflectiveProcessor} that processes {@link BeanOverride @BeanOverride}
30+
* annotations.
31+
*
32+
* @author Sam Brannen
33+
* @since 6.2
34+
*/
35+
class BeanOverrideReflectiveProcessor implements ReflectiveProcessor {
36+
37+
@Override
38+
public void registerReflectionHints(ReflectionHints hints, AnnotatedElement element) {
39+
MergedAnnotations.from(element)
40+
.get(BeanOverride.class)
41+
.synthesize(MergedAnnotation::isPresent)
42+
.map(BeanOverride::value)
43+
.ifPresent(clazz -> hints.registerType(clazz, INVOKE_DECLARED_CONSTRUCTORS));
44+
}
45+
46+
}

Diff for: spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideRegistrar.java

+8-17
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,7 @@
2222

2323
import org.springframework.beans.BeansException;
2424
import org.springframework.beans.factory.BeanCreationException;
25-
import org.springframework.beans.factory.BeanFactory;
26-
import org.springframework.beans.factory.BeanFactoryAware;
2725
import org.springframework.beans.factory.config.ConfigurableBeanFactory;
28-
import org.springframework.lang.Nullable;
2926
import org.springframework.util.Assert;
3027
import org.springframework.util.ReflectionUtils;
3128
import org.springframework.util.StringUtils;
@@ -36,36 +33,31 @@
3633
* for test execution listeners.
3734
*
3835
* @author Simon Baslé
36+
* @author Sam Brannen
3937
* @since 6.2
4038
*/
41-
class BeanOverrideRegistrar implements BeanFactoryAware {
39+
class BeanOverrideRegistrar {
4240

4341
private final Map<OverrideMetadata, String> beanNameRegistry = new HashMap<>();
4442

4543
private final Map<String, OverrideMetadata> earlyOverrideMetadata = new HashMap<>();
4644

47-
@Nullable
48-
private ConfigurableBeanFactory beanFactory;
45+
private final ConfigurableBeanFactory beanFactory;
4946

5047

51-
@Override
52-
public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
53-
if (!(beanFactory instanceof ConfigurableBeanFactory cbf)) {
54-
throw new IllegalStateException("Cannot process bean override with a BeanFactory " +
55-
"that doesn't implement ConfigurableBeanFactory: " + beanFactory.getClass().getName());
56-
}
57-
this.beanFactory = cbf;
48+
BeanOverrideRegistrar(ConfigurableBeanFactory beanFactory) {
49+
Assert.notNull(beanFactory, "ConfigurableBeanFactory must not be null");
50+
this.beanFactory = beanFactory;
5851
}
5952

6053
/**
61-
* Check {@link #markWrapEarly(OverrideMetadata, String) early override}
54+
* Check {@linkplain #markWrapEarly(OverrideMetadata, String) early override}
6255
* records and use the {@link OverrideMetadata} to create an override
63-
* instance from the provided bean, if relevant.
56+
* instance based on the provided bean, if relevant.
6457
*/
6558
Object wrapIfNecessary(Object bean, String beanName) throws BeansException {
6659
OverrideMetadata metadata = this.earlyOverrideMetadata.get(beanName);
6760
if (metadata != null && metadata.getStrategy() == BeanOverrideStrategy.WRAP_BEAN) {
68-
Assert.state(this.beanFactory != null, "ConfigurableBeanFactory must not be null");
6961
bean = metadata.createOverride(beanName, null, bean);
7062
metadata.track(bean, this.beanFactory);
7163
}
@@ -99,7 +91,6 @@ private void inject(Field field, Object target, String beanName) {
9991
try {
10092
ReflectionUtils.makeAccessible(field);
10193
Object existingValue = ReflectionUtils.getField(field, target);
102-
Assert.state(this.beanFactory != null, "ConfigurableBeanFactory must not be null");
10394
Object bean = this.beanFactory.getBean(beanName, field.getType());
10495
if (existingValue == bean) {
10596
return;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Copyright 2002-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.test.context.bean.override.mockito;
18+
19+
import java.lang.annotation.Annotation;
20+
import java.lang.reflect.AnnotatedElement;
21+
import java.util.Arrays;
22+
import java.util.concurrent.atomic.AtomicBoolean;
23+
import java.util.function.Predicate;
24+
25+
import org.springframework.util.ReflectionUtils;
26+
27+
/**
28+
* Utility class that detects {@code org.mockito} annotations as well as the
29+
* annotations in this package (like {@link MockitoBeanSettings @MockitoBeanSettings}).
30+
*
31+
* @author Simon Baslé
32+
* @author Sam Brannen
33+
*/
34+
abstract class MockitoAnnotationDetector {
35+
36+
private static final String MOCKITO_BEAN_PACKAGE = MockitoBeanSettings.class.getPackageName();
37+
38+
private static final String ORG_MOCKITO_PACKAGE = "org.mockito";
39+
40+
private static final Predicate<Annotation> isMockitoAnnotation = annotation -> {
41+
String packageName = annotation.annotationType().getPackageName();
42+
return (packageName.startsWith(MOCKITO_BEAN_PACKAGE) ||
43+
packageName.startsWith(ORG_MOCKITO_PACKAGE));
44+
};
45+
46+
static boolean hasMockitoAnnotations(Class<?> testClass) {
47+
if (isAnnotated(testClass)) {
48+
return true;
49+
}
50+
// TODO Ideally we should short-circuit the search once we've found a Mockito annotation,
51+
// since there's no need to continue searching additional fields or further up the class
52+
// hierarchy; however, that is not possible with ReflectionUtils#doWithFields. Plus, the
53+
// previous invocation of isAnnotated(testClass) only finds annotations declared directly
54+
// on the test class. So, we'll likely need a completely different approach that combines
55+
// the "test class/interface is annotated?" and "field is annotated?" checks in a single
56+
// search algorithm.
57+
AtomicBoolean found = new AtomicBoolean();
58+
ReflectionUtils.doWithFields(testClass, field -> found.set(true), MockitoAnnotationDetector::isAnnotated);
59+
return found.get();
60+
}
61+
62+
private static boolean isAnnotated(AnnotatedElement annotatedElement) {
63+
return Arrays.stream(annotatedElement.getAnnotations()).anyMatch(isMockitoAnnotation);
64+
}
65+
66+
}

Diff for: spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoResetTestExecutionListener.java

+4-3
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
3030
import org.springframework.context.ApplicationContext;
3131
import org.springframework.context.ConfigurableApplicationContext;
32-
import org.springframework.core.NativeDetector;
3332
import org.springframework.core.Ordered;
3433
import org.springframework.lang.Nullable;
3534
import org.springframework.test.context.TestContext;
@@ -58,14 +57,16 @@ public int getOrder() {
5857

5958
@Override
6059
public void beforeTestMethod(TestContext testContext) throws Exception {
61-
if (MockitoTestExecutionListener.mockitoPresent && !NativeDetector.inNativeImage()) {
60+
Class<?> testClass = testContext.getTestClass();
61+
if (MockitoTestExecutionListener.mockitoPresent && MockitoAnnotationDetector.hasMockitoAnnotations(testClass)) {
6262
resetMocks(testContext.getApplicationContext(), MockReset.BEFORE);
6363
}
6464
}
6565

6666
@Override
6767
public void afterTestMethod(TestContext testContext) throws Exception {
68-
if (MockitoTestExecutionListener.mockitoPresent && !NativeDetector.inNativeImage()) {
68+
Class<?> testClass = testContext.getTestClass();
69+
if (MockitoTestExecutionListener.mockitoPresent && MockitoAnnotationDetector.hasMockitoAnnotations(testClass)) {
6970
resetMocks(testContext.getApplicationContext(), MockReset.AFTER);
7071
}
7172
}

0 commit comments

Comments
 (0)