Skip to content

Commit 1cfb2ac

Browse files
committed
Detect target of factory method with AOT
Previously, if a factory method is defined on a parent, the generated code would blindly use the method's declaring class for both the target of the generated code, and the signature of the method. This commit improves the resolution by considering the factory metadata in the BeanDefinition. Closes spring-projectsgh-32609
1 parent 0268180 commit 1cfb2ac

File tree

12 files changed

+211
-64
lines changed

12 files changed

+211
-64
lines changed

spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java

+14-14
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2023 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,7 +17,6 @@
1717
package org.springframework.beans.factory.aot;
1818

1919
import java.lang.reflect.Constructor;
20-
import java.lang.reflect.Executable;
2120
import java.lang.reflect.Modifier;
2221
import java.util.List;
2322
import java.util.function.Predicate;
@@ -35,6 +34,7 @@
3534
import org.springframework.beans.factory.config.BeanDefinitionHolder;
3635
import org.springframework.beans.factory.support.InstanceSupplier;
3736
import org.springframework.beans.factory.support.RegisteredBean;
37+
import org.springframework.beans.factory.support.RegisteredBean.InstantiationDescriptor;
3838
import org.springframework.beans.factory.support.RootBeanDefinition;
3939
import org.springframework.core.ResolvableType;
4040
import org.springframework.javapoet.ClassName;
@@ -62,7 +62,7 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme
6262

6363
private final BeanDefinitionMethodGeneratorFactory beanDefinitionMethodGeneratorFactory;
6464

65-
private final Supplier<Executable> constructorOrFactoryMethod;
65+
private final Supplier<InstantiationDescriptor> instantiationDescriptor;
6666

6767

6868
DefaultBeanRegistrationCodeFragments(BeanRegistrationsCode beanRegistrationsCode,
@@ -72,7 +72,7 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme
7272
this.beanRegistrationsCode = beanRegistrationsCode;
7373
this.registeredBean = registeredBean;
7474
this.beanDefinitionMethodGeneratorFactory = beanDefinitionMethodGeneratorFactory;
75-
this.constructorOrFactoryMethod = SingletonSupplier.of(registeredBean::resolveConstructorOrFactoryMethod);
75+
this.instantiationDescriptor = SingletonSupplier.of(registeredBean::resolveInstantiationDescriptor);
7676
}
7777

7878

@@ -82,7 +82,7 @@ public ClassName getTarget(RegisteredBean registeredBean) {
8282
throw new IllegalStateException("Default code generation is not supported for bean definitions "
8383
+ "declaring an instance supplier callback: " + registeredBean.getMergedBeanDefinition());
8484
}
85-
Class<?> target = extractDeclaringClass(registeredBean.getBeanType(), this.constructorOrFactoryMethod.get());
85+
Class<?> target = extractDeclaringClass(registeredBean, this.instantiationDescriptor.get());
8686
while (target.getName().startsWith("java.") && registeredBean.isInnerBean()) {
8787
RegisteredBean parent = registeredBean.getParent();
8888
Assert.state(parent != null, "No parent available for inner bean");
@@ -91,14 +91,14 @@ public ClassName getTarget(RegisteredBean registeredBean) {
9191
return (target.isArray() ? ClassName.get(target.getComponentType()) : ClassName.get(target));
9292
}
9393

94-
private Class<?> extractDeclaringClass(ResolvableType beanType, Executable executable) {
95-
Class<?> declaringClass = ClassUtils.getUserClass(executable.getDeclaringClass());
96-
if (executable instanceof Constructor<?>
97-
&& AccessControl.forMember(executable).isPublic()
94+
private Class<?> extractDeclaringClass(RegisteredBean registeredBean, InstantiationDescriptor instantiationDescriptor) {
95+
Class<?> declaringClass = ClassUtils.getUserClass(instantiationDescriptor.target());
96+
if (instantiationDescriptor.executable() instanceof Constructor<?>
97+
&& AccessControl.forMember(instantiationDescriptor.executable()).isPublic()
9898
&& FactoryBean.class.isAssignableFrom(declaringClass)) {
99-
return extractTargetClassFromFactoryBean(declaringClass, beanType);
99+
return extractTargetClassFromFactoryBean(declaringClass, registeredBean.getBeanType());
100100
}
101-
return executable.getDeclaringClass();
101+
return declaringClass;
102102
}
103103

104104
/**
@@ -238,9 +238,9 @@ public CodeBlock generateInstanceSupplierCode(GenerationContext generationContex
238238
throw new IllegalStateException("Default code generation is not supported for bean definitions declaring "
239239
+ "an instance supplier callback: " + this.registeredBean.getMergedBeanDefinition());
240240
}
241-
return new InstanceSupplierCodeGenerator(generationContext,
242-
beanRegistrationCode.getClassName(), beanRegistrationCode.getMethods(), allowDirectSupplierShortcut)
243-
.generateCode(this.registeredBean, this.constructorOrFactoryMethod.get());
241+
return new InstanceSupplierCodeGenerator(generationContext, beanRegistrationCode.getClassName(),
242+
beanRegistrationCode.getMethods(), allowDirectSupplierShortcut).generateCode(
243+
this.registeredBean, this.instantiationDescriptor.get());
244244
}
245245

246246
@Override

spring-beans/src/main/java/org/springframework/beans/factory/aot/InstanceSupplierCodeGenerator.java

+37-23
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2023 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -46,6 +46,7 @@
4646
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
4747
import org.springframework.beans.factory.support.InstanceSupplier;
4848
import org.springframework.beans.factory.support.RegisteredBean;
49+
import org.springframework.beans.factory.support.RegisteredBean.InstantiationDescriptor;
4950
import org.springframework.core.KotlinDetector;
5051
import org.springframework.core.MethodParameter;
5152
import org.springframework.core.ResolvableType;
@@ -121,13 +122,26 @@ public InstanceSupplierCodeGenerator(GenerationContext generationContext,
121122
* @param constructorOrFactoryMethod the executable to use to create the bean
122123
* @return the generated code
123124
*/
125+
@Deprecated(since = "6.1.7")
124126
public CodeBlock generateCode(RegisteredBean registeredBean, Executable constructorOrFactoryMethod) {
127+
return generateCode(registeredBean, new InstantiationDescriptor(
128+
constructorOrFactoryMethod, constructorOrFactoryMethod.getDeclaringClass()));
129+
}
130+
131+
/**
132+
* Generate the instance supplier code.
133+
* @param registeredBean the bean to handle
134+
* @param instantiationDescriptor the executable to use to create the bean
135+
* @return the generated code
136+
*/
137+
public CodeBlock generateCode(RegisteredBean registeredBean, InstantiationDescriptor instantiationDescriptor) {
138+
Executable constructorOrFactoryMethod = instantiationDescriptor.executable();
125139
registerRuntimeHintsIfNecessary(registeredBean, constructorOrFactoryMethod);
126140
if (constructorOrFactoryMethod instanceof Constructor<?> constructor) {
127141
return generateCodeForConstructor(registeredBean, constructor);
128142
}
129143
if (constructorOrFactoryMethod instanceof Method method) {
130-
return generateCodeForFactoryMethod(registeredBean, method);
144+
return generateCodeForFactoryMethod(registeredBean, method, instantiationDescriptor.target());
131145
}
132146
throw new IllegalStateException(
133147
"No suitable executor found for " + registeredBean.getBeanName());
@@ -253,21 +267,21 @@ private CodeBlock generateNewInstanceCodeForConstructor(boolean dependsOnBean,
253267
declaringClass.getSimpleName(), args);
254268
}
255269

256-
private CodeBlock generateCodeForFactoryMethod(RegisteredBean registeredBean, Method factoryMethod) {
270+
private CodeBlock generateCodeForFactoryMethod(RegisteredBean registeredBean, Method factoryMethod, Class<?> target) {
257271
String beanName = registeredBean.getBeanName();
258-
Class<?> declaringClass = ClassUtils.getUserClass(factoryMethod.getDeclaringClass());
272+
Class<?> targetToUse = ClassUtils.getUserClass(target);
259273
boolean dependsOnBean = !Modifier.isStatic(factoryMethod.getModifiers());
260274

261275
Visibility accessVisibility = getAccessVisibility(registeredBean, factoryMethod);
262276
if (accessVisibility != Visibility.PRIVATE) {
263277
return generateCodeForAccessibleFactoryMethod(
264-
beanName, factoryMethod, declaringClass, dependsOnBean);
278+
beanName, factoryMethod, targetToUse, dependsOnBean);
265279
}
266-
return generateCodeForInaccessibleFactoryMethod(beanName, factoryMethod, declaringClass);
280+
return generateCodeForInaccessibleFactoryMethod(beanName, factoryMethod, targetToUse);
267281
}
268282

269283
private CodeBlock generateCodeForAccessibleFactoryMethod(String beanName,
270-
Method factoryMethod, Class<?> declaringClass, boolean dependsOnBean) {
284+
Method factoryMethod, Class<?> target, boolean dependsOnBean) {
271285

272286
this.generationContext.getRuntimeHints().reflection().registerMethod(
273287
factoryMethod, ExecutableMode.INTROSPECT);
@@ -276,20 +290,20 @@ private CodeBlock generateCodeForAccessibleFactoryMethod(String beanName,
276290
Class<?> suppliedType = ClassUtils.resolvePrimitiveIfNecessary(factoryMethod.getReturnType());
277291
CodeBlock.Builder code = CodeBlock.builder();
278292
code.add("$T.<$T>forFactoryMethod($T.class, $S)", BeanInstanceSupplier.class,
279-
suppliedType, declaringClass, factoryMethod.getName());
293+
suppliedType, target, factoryMethod.getName());
280294
code.add(".withGenerator(($L) -> $T.$L())", REGISTERED_BEAN_PARAMETER_NAME,
281-
declaringClass, factoryMethod.getName());
295+
target, factoryMethod.getName());
282296
return code.build();
283297
}
284298

285299
GeneratedMethod getInstanceMethod = generateGetInstanceSupplierMethod(method ->
286300
buildGetInstanceMethodForFactoryMethod(method, beanName, factoryMethod,
287-
declaringClass, dependsOnBean, PRIVATE_STATIC));
301+
target, dependsOnBean, PRIVATE_STATIC));
288302
return generateReturnStatement(getInstanceMethod);
289303
}
290304

291305
private CodeBlock generateCodeForInaccessibleFactoryMethod(
292-
String beanName, Method factoryMethod, Class<?> declaringClass) {
306+
String beanName, Method factoryMethod, Class<?> target) {
293307

294308
this.generationContext.getRuntimeHints().reflection().registerMethod(factoryMethod, ExecutableMode.INVOKE);
295309
GeneratedMethod getInstanceMethod = generateGetInstanceSupplierMethod(method -> {
@@ -298,19 +312,19 @@ private CodeBlock generateCodeForInaccessibleFactoryMethod(
298312
method.addModifiers(PRIVATE_STATIC);
299313
method.returns(ParameterizedTypeName.get(BeanInstanceSupplier.class, suppliedType));
300314
method.addStatement(generateInstanceSupplierForFactoryMethod(
301-
factoryMethod, suppliedType, declaringClass, factoryMethod.getName()));
315+
factoryMethod, suppliedType, target, factoryMethod.getName()));
302316
});
303317
return generateReturnStatement(getInstanceMethod);
304318
}
305319

306320
private void buildGetInstanceMethodForFactoryMethod(MethodSpec.Builder method,
307-
String beanName, Method factoryMethod, Class<?> declaringClass,
321+
String beanName, Method factoryMethod, Class<?> target,
308322
boolean dependsOnBean, javax.lang.model.element.Modifier... modifiers) {
309323

310324
String factoryMethodName = factoryMethod.getName();
311325
Class<?> suppliedType = ClassUtils.resolvePrimitiveIfNecessary(factoryMethod.getReturnType());
312326
CodeWarnings codeWarnings = new CodeWarnings();
313-
codeWarnings.detectDeprecation(declaringClass, factoryMethod, suppliedType)
327+
codeWarnings.detectDeprecation(target, factoryMethod, suppliedType)
314328
.detectDeprecation(Arrays.stream(factoryMethod.getParameters()).map(Parameter::getType));
315329

316330
method.addJavadoc("Get the bean instance supplier for '$L'.", beanName);
@@ -320,41 +334,41 @@ private void buildGetInstanceMethodForFactoryMethod(MethodSpec.Builder method,
320334

321335
CodeBlock.Builder code = CodeBlock.builder();
322336
code.add(generateInstanceSupplierForFactoryMethod(
323-
factoryMethod, suppliedType, declaringClass, factoryMethodName));
337+
factoryMethod, suppliedType, target, factoryMethodName));
324338

325339
boolean hasArguments = factoryMethod.getParameterCount() > 0;
326340
CodeBlock arguments = hasArguments ?
327-
new AutowiredArgumentsCodeGenerator(declaringClass, factoryMethod)
341+
new AutowiredArgumentsCodeGenerator(target, factoryMethod)
328342
.generateCode(factoryMethod.getParameterTypes())
329343
: NO_ARGS;
330344

331345
CodeBlock newInstance = generateNewInstanceCodeForMethod(
332-
dependsOnBean, declaringClass, factoryMethodName, arguments);
346+
dependsOnBean, target, factoryMethodName, arguments);
333347
code.add(generateWithGeneratorCode(hasArguments, newInstance));
334348
method.addStatement(code.build());
335349
}
336350

337351
private CodeBlock generateInstanceSupplierForFactoryMethod(Method factoryMethod,
338-
Class<?> suppliedType, Class<?> declaringClass, String factoryMethodName) {
352+
Class<?> suppliedType, Class<?> target, String factoryMethodName) {
339353

340354
if (factoryMethod.getParameterCount() == 0) {
341355
return CodeBlock.of("return $T.<$T>forFactoryMethod($T.class, $S)",
342-
BeanInstanceSupplier.class, suppliedType, declaringClass, factoryMethodName);
356+
BeanInstanceSupplier.class, suppliedType, target, factoryMethodName);
343357
}
344358

345359
CodeBlock parameterTypes = generateParameterTypesCode(factoryMethod.getParameterTypes(), 0);
346360
return CodeBlock.of("return $T.<$T>forFactoryMethod($T.class, $S, $L)",
347-
BeanInstanceSupplier.class, suppliedType, declaringClass, factoryMethodName, parameterTypes);
361+
BeanInstanceSupplier.class, suppliedType, target, factoryMethodName, parameterTypes);
348362
}
349363

350364
private CodeBlock generateNewInstanceCodeForMethod(boolean dependsOnBean,
351-
Class<?> declaringClass, String factoryMethodName, CodeBlock args) {
365+
Class<?> target, String factoryMethodName, CodeBlock args) {
352366

353367
if (!dependsOnBean) {
354-
return CodeBlock.of("$T.$L($L)", declaringClass, factoryMethodName, args);
368+
return CodeBlock.of("$T.$L($L)", target, factoryMethodName, args);
355369
}
356370
return CodeBlock.of("$L.getBeanFactory().getBean($T.class).$L($L)",
357-
REGISTERED_BEAN_PARAMETER_NAME, declaringClass, factoryMethodName, args);
371+
REGISTERED_BEAN_PARAMETER_NAME, target, factoryMethodName, args);
358372
}
359373

360374
private CodeBlock generateReturnStatement(GeneratedMethod generatedMethod) {

spring-beans/src/main/java/org/springframework/beans/factory/support/RegisteredBean.java

+36-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2023 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,6 +17,8 @@
1717
package org.springframework.beans.factory.support;
1818

1919
import java.lang.reflect.Executable;
20+
import java.lang.reflect.Method;
21+
import java.lang.reflect.Modifier;
2022
import java.util.Set;
2123
import java.util.function.BiFunction;
2224
import java.util.function.Supplier;
@@ -206,12 +208,33 @@ public RegisteredBean getParent() {
206208
/**
207209
* Resolve the constructor or factory method to use for this bean.
208210
* @return the {@link java.lang.reflect.Constructor} or {@link java.lang.reflect.Method}
211+
* @deprecated in favor of {@link #resolveInstantiationDescriptor()}
209212
*/
213+
@Deprecated(since = "6.1.7")
210214
public Executable resolveConstructorOrFactoryMethod() {
211215
return new ConstructorResolver((AbstractAutowireCapableBeanFactory) getBeanFactory())
212216
.resolveConstructorOrFactoryMethod(getBeanName(), getMergedBeanDefinition());
213217
}
214218

219+
/**
220+
* Resolve the {@linkplain InstantiationDescriptor descriptor} to use to
221+
* instantiate this bean. It defines the {@link java.lang.reflect.Constructor}
222+
* or {@link java.lang.reflect.Method} to use as well as additional metadata.
223+
* @since 6.1.7
224+
*/
225+
public InstantiationDescriptor resolveInstantiationDescriptor() {
226+
Executable executable = resolveConstructorOrFactoryMethod();
227+
if (executable instanceof Method method && !Modifier.isStatic(method.getModifiers())) {
228+
String factoryBeanName = getMergedBeanDefinition().getFactoryBeanName();
229+
if (factoryBeanName != null && this.beanFactory.containsBean(factoryBeanName)) {
230+
Class<?> target = this.beanFactory.getMergedBeanDefinition(factoryBeanName)
231+
.getResolvableType().toClass();
232+
return new InstantiationDescriptor(executable, target);
233+
}
234+
}
235+
return new InstantiationDescriptor(executable, executable.getDeclaringClass());
236+
}
237+
215238
/**
216239
* Resolve an autowired argument.
217240
* @param descriptor the descriptor for the dependency (field/method/constructor)
@@ -237,6 +260,18 @@ public String toString() {
237260
.append("mergedBeanDefinition", getMergedBeanDefinition()).toString();
238261
}
239262

263+
/**
264+
* Describe how a bean should be instantiated.
265+
* @param executable the {@link Executable} to invoke
266+
* @param target the target of the executable, usually {@link Executable#getDeclaringClass()}
267+
* @since 6.1.7
268+
*/
269+
public record InstantiationDescriptor(Executable executable, Class<?> target) {
270+
271+
public InstantiationDescriptor(Executable executable) {
272+
this(executable, executable.getDeclaringClass());
273+
}
274+
}
240275

241276
/**
242277
* Resolver used to obtain inner-bean details.

spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2023 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -161,7 +161,8 @@ void generateWithTargetTypeUsingGenericsSetsBothBeanClassAndTargetType() {
161161

162162
@Test
163163
void generateWithBeanClassAndFactoryMethodNameSetsTargetTypeAndBeanClass() {
164-
this.beanFactory.registerSingleton("factory", new SimpleBeanConfiguration());
164+
this.beanFactory.registerBeanDefinition("factory",
165+
new RootBeanDefinition(SimpleBeanConfiguration.class));
165166
RootBeanDefinition beanDefinition = new RootBeanDefinition(SimpleBean.class);
166167
beanDefinition.setFactoryBeanName("factory");
167168
beanDefinition.setFactoryMethodName("simpleBean");
@@ -182,7 +183,8 @@ void generateWithBeanClassAndFactoryMethodNameSetsTargetTypeAndBeanClass() {
182183

183184
@Test
184185
void generateWithTargetTypeAndFactoryMethodNameSetsOnlyBeanClass() {
185-
this.beanFactory.registerSingleton("factory", new SimpleBeanConfiguration());
186+
this.beanFactory.registerBeanDefinition("factory",
187+
new RootBeanDefinition(SimpleBeanConfiguration.class));
186188
RootBeanDefinition beanDefinition = new RootBeanDefinition();
187189
beanDefinition.setTargetType(SimpleBean.class);
188190
beanDefinition.setFactoryBeanName("factory");

0 commit comments

Comments
 (0)