Skip to content

Commit 8a8c8fe

Browse files
committed
Detect target of factory method with AOT
Previously, if a factory method is defined on a parent, the generated code would blindly use the method's declaring class for both the target of the generated code, and the signature of the method. This commit improves the resolution by considering the factory metadata in the BeanDefinition. Closes gh-32609
1 parent f45e7b9 commit 8a8c8fe

File tree

12 files changed

+214
-63
lines changed

12 files changed

+214
-63
lines changed

spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java

+14-14
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2023 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,7 +17,6 @@
1717
package org.springframework.beans.factory.aot;
1818

1919
import java.lang.reflect.Constructor;
20-
import java.lang.reflect.Executable;
2120
import java.lang.reflect.Modifier;
2221
import java.util.List;
2322
import java.util.function.Predicate;
@@ -35,6 +34,7 @@
3534
import org.springframework.beans.factory.config.BeanDefinitionHolder;
3635
import org.springframework.beans.factory.support.InstanceSupplier;
3736
import org.springframework.beans.factory.support.RegisteredBean;
37+
import org.springframework.beans.factory.support.RegisteredBean.InstantiationDescriptor;
3838
import org.springframework.beans.factory.support.RootBeanDefinition;
3939
import org.springframework.core.ResolvableType;
4040
import org.springframework.javapoet.ClassName;
@@ -62,7 +62,7 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme
6262

6363
private final BeanDefinitionMethodGeneratorFactory beanDefinitionMethodGeneratorFactory;
6464

65-
private final Supplier<Executable> constructorOrFactoryMethod;
65+
private final Supplier<InstantiationDescriptor> instantiationDescriptor;
6666

6767

6868
DefaultBeanRegistrationCodeFragments(BeanRegistrationsCode beanRegistrationsCode,
@@ -72,7 +72,7 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme
7272
this.beanRegistrationsCode = beanRegistrationsCode;
7373
this.registeredBean = registeredBean;
7474
this.beanDefinitionMethodGeneratorFactory = beanDefinitionMethodGeneratorFactory;
75-
this.constructorOrFactoryMethod = SingletonSupplier.of(registeredBean::resolveConstructorOrFactoryMethod);
75+
this.instantiationDescriptor = SingletonSupplier.of(registeredBean::resolveInstantiationDescriptor);
7676
}
7777

7878

@@ -82,7 +82,7 @@ public ClassName getTarget(RegisteredBean registeredBean) {
8282
throw new IllegalStateException("Default code generation is not supported for bean definitions "
8383
+ "declaring an instance supplier callback: " + registeredBean.getMergedBeanDefinition());
8484
}
85-
Class<?> target = extractDeclaringClass(registeredBean.getBeanType(), this.constructorOrFactoryMethod.get());
85+
Class<?> target = extractDeclaringClass(registeredBean, this.instantiationDescriptor.get());
8686
while (target.getName().startsWith("java.") && registeredBean.isInnerBean()) {
8787
RegisteredBean parent = registeredBean.getParent();
8888
Assert.state(parent != null, "No parent available for inner bean");
@@ -91,14 +91,14 @@ public ClassName getTarget(RegisteredBean registeredBean) {
9191
return (target.isArray() ? ClassName.get(target.getComponentType()) : ClassName.get(target));
9292
}
9393

94-
private Class<?> extractDeclaringClass(ResolvableType beanType, Executable executable) {
95-
Class<?> declaringClass = ClassUtils.getUserClass(executable.getDeclaringClass());
96-
if (executable instanceof Constructor<?>
97-
&& AccessControl.forMember(executable).isPublic()
94+
private Class<?> extractDeclaringClass(RegisteredBean registeredBean, InstantiationDescriptor instantiationDescriptor) {
95+
Class<?> declaringClass = ClassUtils.getUserClass(instantiationDescriptor.targetClass());
96+
if (instantiationDescriptor.executable() instanceof Constructor<?>
97+
&& AccessControl.forMember(instantiationDescriptor.executable()).isPublic()
9898
&& FactoryBean.class.isAssignableFrom(declaringClass)) {
99-
return extractTargetClassFromFactoryBean(declaringClass, beanType);
99+
return extractTargetClassFromFactoryBean(declaringClass, registeredBean.getBeanType());
100100
}
101-
return executable.getDeclaringClass();
101+
return declaringClass;
102102
}
103103

104104
/**
@@ -238,9 +238,9 @@ public CodeBlock generateInstanceSupplierCode(GenerationContext generationContex
238238
throw new IllegalStateException("Default code generation is not supported for bean definitions declaring "
239239
+ "an instance supplier callback: " + this.registeredBean.getMergedBeanDefinition());
240240
}
241-
return new InstanceSupplierCodeGenerator(generationContext,
242-
beanRegistrationCode.getClassName(), beanRegistrationCode.getMethods(), allowDirectSupplierShortcut)
243-
.generateCode(this.registeredBean, this.constructorOrFactoryMethod.get());
241+
return new InstanceSupplierCodeGenerator(generationContext, beanRegistrationCode.getClassName(),
242+
beanRegistrationCode.getMethods(), allowDirectSupplierShortcut).generateCode(
243+
this.registeredBean, this.instantiationDescriptor.get());
244244
}
245245

246246
@Override

spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java

+39-23
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2023 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -46,6 +46,7 @@
4646
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
4747
import org.springframework.beans.factory.support.InstanceSupplier;
4848
import org.springframework.beans.factory.support.RegisteredBean;
49+
import org.springframework.beans.factory.support.RegisteredBean.InstantiationDescriptor;
4950
import org.springframework.core.KotlinDetector;
5051
import org.springframework.core.MethodParameter;
5152
import org.springframework.core.ResolvableType;
@@ -120,14 +121,29 @@ public InstanceSupplierCodeGenerator(GenerationContext generationContext,
120121
* @param registeredBean the bean to handle
121122
* @param constructorOrFactoryMethod the executable to use to create the bean
122123
* @return the generated code
124+
* @deprecated in favor of {@link #generateCode(RegisteredBean, InstantiationDescriptor)}
123125
*/
126+
@Deprecated(since = "6.1.7")
124127
public CodeBlock generateCode(RegisteredBean registeredBean, Executable constructorOrFactoryMethod) {
128+
return generateCode(registeredBean, new InstantiationDescriptor(
129+
constructorOrFactoryMethod, constructorOrFactoryMethod.getDeclaringClass()));
130+
}
131+
132+
/**
133+
* Generate the instance supplier code.
134+
* @param registeredBean the bean to handle
135+
* @param instantiationDescriptor the executable to use to create the bean
136+
* @return the generated code
137+
* @since 6.1.7
138+
*/
139+
public CodeBlock generateCode(RegisteredBean registeredBean, InstantiationDescriptor instantiationDescriptor) {
140+
Executable constructorOrFactoryMethod = instantiationDescriptor.executable();
125141
registerRuntimeHintsIfNecessary(registeredBean, constructorOrFactoryMethod);
126142
if (constructorOrFactoryMethod instanceof Constructor<?> constructor) {
127143
return generateCodeForConstructor(registeredBean, constructor);
128144
}
129145
if (constructorOrFactoryMethod instanceof Method method) {
130-
return generateCodeForFactoryMethod(registeredBean, method);
146+
return generateCodeForFactoryMethod(registeredBean, method, instantiationDescriptor.targetClass());
131147
}
132148
throw new IllegalStateException(
133149
"No suitable executor found for " + registeredBean.getBeanName());
@@ -253,21 +269,21 @@ private CodeBlock generateNewInstanceCodeForConstructor(boolean dependsOnBean,
253269
declaringClass.getSimpleName(), args);
254270
}
255271

256-
private CodeBlock generateCodeForFactoryMethod(RegisteredBean registeredBean, Method factoryMethod) {
272+
private CodeBlock generateCodeForFactoryMethod(RegisteredBean registeredBean, Method factoryMethod, Class<?> targetClass) {
257273
String beanName = registeredBean.getBeanName();
258-
Class<?> declaringClass = ClassUtils.getUserClass(factoryMethod.getDeclaringClass());
274+
Class<?> targetClassToUse = ClassUtils.getUserClass(targetClass);
259275
boolean dependsOnBean = !Modifier.isStatic(factoryMethod.getModifiers());
260276

261277
Visibility accessVisibility = getAccessVisibility(registeredBean, factoryMethod);
262278
if (accessVisibility != Visibility.PRIVATE) {
263279
return generateCodeForAccessibleFactoryMethod(
264-
beanName, factoryMethod, declaringClass, dependsOnBean);
280+
beanName, factoryMethod, targetClassToUse, dependsOnBean);
265281
}
266-
return generateCodeForInaccessibleFactoryMethod(beanName, factoryMethod, declaringClass);
282+
return generateCodeForInaccessibleFactoryMethod(beanName, factoryMethod, targetClassToUse);
267283
}
268284

269285
private CodeBlock generateCodeForAccessibleFactoryMethod(String beanName,
270-
Method factoryMethod, Class<?> declaringClass, boolean dependsOnBean) {
286+
Method factoryMethod, Class<?> targetClass, boolean dependsOnBean) {
271287

272288
this.generationContext.getRuntimeHints().reflection().registerMethod(
273289
factoryMethod, ExecutableMode.INTROSPECT);
@@ -276,20 +292,20 @@ private CodeBlock generateCodeForAccessibleFactoryMethod(String beanName,
276292
Class<?> suppliedType = ClassUtils.resolvePrimitiveIfNecessary(factoryMethod.getReturnType());
277293
CodeBlock.Builder code = CodeBlock.builder();
278294
code.add("$T.<$T>forFactoryMethod($T.class, $S)", BeanInstanceSupplier.class,
279-
suppliedType, declaringClass, factoryMethod.getName());
295+
suppliedType, targetClass, factoryMethod.getName());
280296
code.add(".withGenerator(($L) -> $T.$L())", REGISTERED_BEAN_PARAMETER_NAME,
281-
declaringClass, factoryMethod.getName());
297+
targetClass, factoryMethod.getName());
282298
return code.build();
283299
}
284300

285301
GeneratedMethod getInstanceMethod = generateGetInstanceSupplierMethod(method ->
286302
buildGetInstanceMethodForFactoryMethod(method, beanName, factoryMethod,
287-
declaringClass, dependsOnBean, PRIVATE_STATIC));
303+
targetClass, dependsOnBean, PRIVATE_STATIC));
288304
return generateReturnStatement(getInstanceMethod);
289305
}
290306

291307
private CodeBlock generateCodeForInaccessibleFactoryMethod(
292-
String beanName, Method factoryMethod, Class<?> declaringClass) {
308+
String beanName, Method factoryMethod, Class<?> targetClass) {
293309

294310
this.generationContext.getRuntimeHints().reflection().registerMethod(factoryMethod, ExecutableMode.INVOKE);
295311
GeneratedMethod getInstanceMethod = generateGetInstanceSupplierMethod(method -> {
@@ -298,19 +314,19 @@ private CodeBlock generateCodeForInaccessibleFactoryMethod(
298314
method.addModifiers(PRIVATE_STATIC);
299315
method.returns(ParameterizedTypeName.get(BeanInstanceSupplier.class, suppliedType));
300316
method.addStatement(generateInstanceSupplierForFactoryMethod(
301-
factoryMethod, suppliedType, declaringClass, factoryMethod.getName()));
317+
factoryMethod, suppliedType, targetClass, factoryMethod.getName()));
302318
});
303319
return generateReturnStatement(getInstanceMethod);
304320
}
305321

306322
private void buildGetInstanceMethodForFactoryMethod(MethodSpec.Builder method,
307-
String beanName, Method factoryMethod, Class<?> declaringClass,
323+
String beanName, Method factoryMethod, Class<?> targetClass,
308324
boolean dependsOnBean, javax.lang.model.element.Modifier... modifiers) {
309325

310326
String factoryMethodName = factoryMethod.getName();
311327
Class<?> suppliedType = ClassUtils.resolvePrimitiveIfNecessary(factoryMethod.getReturnType());
312328
CodeWarnings codeWarnings = new CodeWarnings();
313-
codeWarnings.detectDeprecation(declaringClass, factoryMethod, suppliedType)
329+
codeWarnings.detectDeprecation(targetClass, factoryMethod, suppliedType)
314330
.detectDeprecation(Arrays.stream(factoryMethod.getParameters()).map(Parameter::getType));
315331

316332
method.addJavadoc("Get the bean instance supplier for '$L'.", beanName);
@@ -320,41 +336,41 @@ private void buildGetInstanceMethodForFactoryMethod(MethodSpec.Builder method,
320336

321337
CodeBlock.Builder code = CodeBlock.builder();
322338
code.add(generateInstanceSupplierForFactoryMethod(
323-
factoryMethod, suppliedType, declaringClass, factoryMethodName));
339+
factoryMethod, suppliedType, targetClass, factoryMethodName));
324340

325341
boolean hasArguments = factoryMethod.getParameterCount() > 0;
326342
CodeBlock arguments = hasArguments ?
327-
new AutowiredArgumentsCodeGenerator(declaringClass, factoryMethod)
343+
new AutowiredArgumentsCodeGenerator(targetClass, factoryMethod)
328344
.generateCode(factoryMethod.getParameterTypes())
329345
: NO_ARGS;
330346

331347
CodeBlock newInstance = generateNewInstanceCodeForMethod(
332-
dependsOnBean, declaringClass, factoryMethodName, arguments);
348+
dependsOnBean, targetClass, factoryMethodName, arguments);
333349
code.add(generateWithGeneratorCode(hasArguments, newInstance));
334350
method.addStatement(code.build());
335351
}
336352

337353
private CodeBlock generateInstanceSupplierForFactoryMethod(Method factoryMethod,
338-
Class<?> suppliedType, Class<?> declaringClass, String factoryMethodName) {
354+
Class<?> suppliedType, Class<?> targetClass, String factoryMethodName) {
339355

340356
if (factoryMethod.getParameterCount() == 0) {
341357
return CodeBlock.of("return $T.<$T>forFactoryMethod($T.class, $S)",
342-
BeanInstanceSupplier.class, suppliedType, declaringClass, factoryMethodName);
358+
BeanInstanceSupplier.class, suppliedType, targetClass, factoryMethodName);
343359
}
344360

345361
CodeBlock parameterTypes = generateParameterTypesCode(factoryMethod.getParameterTypes(), 0);
346362
return CodeBlock.of("return $T.<$T>forFactoryMethod($T.class, $S, $L)",
347-
BeanInstanceSupplier.class, suppliedType, declaringClass, factoryMethodName, parameterTypes);
363+
BeanInstanceSupplier.class, suppliedType, targetClass, factoryMethodName, parameterTypes);
348364
}
349365

350366
private CodeBlock generateNewInstanceCodeForMethod(boolean dependsOnBean,
351-
Class<?> declaringClass, String factoryMethodName, CodeBlock args) {
367+
Class<?> targetClass, String factoryMethodName, CodeBlock args) {
352368

353369
if (!dependsOnBean) {
354-
return CodeBlock.of("$T.$L($L)", declaringClass, factoryMethodName, args);
370+
return CodeBlock.of("$T.$L($L)", targetClass, factoryMethodName, args);
355371
}
356372
return CodeBlock.of("$L.getBeanFactory().getBean($T.class).$L($L)",
357-
REGISTERED_BEAN_PARAMETER_NAME, declaringClass, factoryMethodName, args);
373+
REGISTERED_BEAN_PARAMETER_NAME, targetClass, factoryMethodName, args);
358374
}
359375

360376
private CodeBlock generateReturnStatement(GeneratedMethod generatedMethod) {

spring-beans/src/main/java/org/springframework/beans/factory/support/RegisteredBean.java

+38-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2023 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,6 +17,8 @@
1717
package org.springframework.beans.factory.support;
1818

1919
import java.lang.reflect.Executable;
20+
import java.lang.reflect.Method;
21+
import java.lang.reflect.Modifier;
2022
import java.util.Set;
2123
import java.util.function.BiFunction;
2224
import java.util.function.Supplier;
@@ -206,12 +208,33 @@ public RegisteredBean getParent() {
206208
/**
207209
* Resolve the constructor or factory method to use for this bean.
208210
* @return the {@link java.lang.reflect.Constructor} or {@link java.lang.reflect.Method}
211+
* @deprecated in favor of {@link #resolveInstantiationDescriptor()}
209212
*/
213+
@Deprecated(since = "6.1.7")
210214
public Executable resolveConstructorOrFactoryMethod() {
211215
return new ConstructorResolver((AbstractAutowireCapableBeanFactory) getBeanFactory())
212216
.resolveConstructorOrFactoryMethod(getBeanName(), getMergedBeanDefinition());
213217
}
214218

219+
/**
220+
* Resolve the {@linkplain InstantiationDescriptor descriptor} to use to
221+
* instantiate this bean. It defines the {@link java.lang.reflect.Constructor}
222+
* or {@link java.lang.reflect.Method} to use as well as additional metadata.
223+
* @since 6.1.7
224+
*/
225+
public InstantiationDescriptor resolveInstantiationDescriptor() {
226+
Executable executable = resolveConstructorOrFactoryMethod();
227+
if (executable instanceof Method method && !Modifier.isStatic(method.getModifiers())) {
228+
String factoryBeanName = getMergedBeanDefinition().getFactoryBeanName();
229+
if (factoryBeanName != null && this.beanFactory.containsBean(factoryBeanName)) {
230+
Class<?> target = this.beanFactory.getMergedBeanDefinition(factoryBeanName)
231+
.getResolvableType().toClass();
232+
return new InstantiationDescriptor(executable, target);
233+
}
234+
}
235+
return new InstantiationDescriptor(executable, executable.getDeclaringClass());
236+
}
237+
215238
/**
216239
* Resolve an autowired argument.
217240
* @param descriptor the descriptor for the dependency (field/method/constructor)
@@ -237,6 +260,20 @@ public String toString() {
237260
.append("mergedBeanDefinition", getMergedBeanDefinition()).toString();
238261
}
239262

263+
/**
264+
* Describe how a bean should be instantiated. While the {@code targetClass}
265+
* is usually the declaring class of the {@code executable}, there are cases
266+
* where retaining the actual concrete type is necessary.
267+
* @param executable the {@link Executable} to invoke
268+
* @param targetClass the target {@link Class} of the executable
269+
* @since 6.1.7
270+
*/
271+
public record InstantiationDescriptor(Executable executable, Class<?> targetClass) {
272+
273+
public InstantiationDescriptor(Executable executable) {
274+
this(executable, executable.getDeclaringClass());
275+
}
276+
}
240277

241278
/**
242279
* Resolver used to obtain inner-bean details.

spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2023 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -161,7 +161,8 @@ void generateWithTargetTypeUsingGenericsSetsBothBeanClassAndTargetType() {
161161

162162
@Test
163163
void generateWithBeanClassAndFactoryMethodNameSetsTargetTypeAndBeanClass() {
164-
this.beanFactory.registerSingleton("factory", new SimpleBeanConfiguration());
164+
this.beanFactory.registerBeanDefinition("factory",
165+
new RootBeanDefinition(SimpleBeanConfiguration.class));
165166
RootBeanDefinition beanDefinition = new RootBeanDefinition(SimpleBean.class);
166167
beanDefinition.setFactoryBeanName("factory");
167168
beanDefinition.setFactoryMethodName("simpleBean");
@@ -182,7 +183,8 @@ void generateWithBeanClassAndFactoryMethodNameSetsTargetTypeAndBeanClass() {
182183

183184
@Test
184185
void generateWithTargetTypeAndFactoryMethodNameSetsOnlyBeanClass() {
185-
this.beanFactory.registerSingleton("factory", new SimpleBeanConfiguration());
186+
this.beanFactory.registerBeanDefinition("factory",
187+
new RootBeanDefinition(SimpleBeanConfiguration.class));
186188
RootBeanDefinition beanDefinition = new RootBeanDefinition();
187189
beanDefinition.setTargetType(SimpleBean.class);
188190
beanDefinition.setFactoryBeanName("factory");

0 commit comments

Comments
 (0)