Skip to content

Commit 3e48031

Browse files
LeMikaelFsbrannen
authored andcommitted
Reject null return value from MethodReplacer for primitive return type
This commit throws an exception instead of silently converting a null return value from a MethodReplacer to a primitive 0/false value. See gh-32412
1 parent f285971 commit 3e48031

File tree

2 files changed

+133
-1
lines changed

2 files changed

+133
-1
lines changed

Diff for: spring-beans/src/main/java/org/springframework/beans/factory/support/CglibSubclassingInstantiationStrategy.java

+13-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.lang.reflect.Constructor;
2020
import java.lang.reflect.Method;
21+
import java.util.Objects;
2122

2223
import org.apache.commons.logging.Log;
2324
import org.apache.commons.logging.LogFactory;
@@ -275,13 +276,24 @@ public ReplaceOverrideMethodInterceptor(RootBeanDefinition beanDefinition, BeanF
275276
this.owner = owner;
276277
}
277278

279+
@Nullable
278280
@Override
279281
public Object intercept(Object obj, Method method, Object[] args, MethodProxy mp) throws Throwable {
280282
ReplaceOverride ro = (ReplaceOverride) getBeanDefinition().getMethodOverrides().getOverride(method);
281283
Assert.state(ro != null, "ReplaceOverride not found");
282284
// TODO could cache if a singleton for minor performance optimization
283285
MethodReplacer mr = this.owner.getBean(ro.getMethodReplacerBeanName(), MethodReplacer.class);
284-
return mr.reimplement(obj, method, args);
286+
return processReturnType(method, mr.reimplement(obj, method, args));
287+
}
288+
289+
@Nullable
290+
private <T> T processReturnType(Method method, @Nullable T returnValue) {
291+
Class<?> returnType = method.getReturnType();
292+
if (returnType != void.class && returnType.isPrimitive()) {
293+
return Objects.requireNonNull(returnValue, () -> "Null return value from replacer does not match primitive return type for: " + method);
294+
}
295+
296+
return returnValue;
285297
}
286298
}
287299

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
package org.springframework.beans.factory.support;
2+
3+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
4+
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
5+
6+
import java.lang.reflect.Method;
7+
import java.util.Map;
8+
import java.util.stream.Stream;
9+
10+
import org.assertj.core.api.ThrowableAssert;
11+
import org.junit.jupiter.api.Test;
12+
import org.springframework.lang.Nullable;
13+
14+
class CglibSubclassingInstantiationStrategyTests {
15+
16+
private final CglibSubclassingInstantiationStrategy strategy = new CglibSubclassingInstantiationStrategy();
17+
18+
@Nullable
19+
public static Object valueToReturnFromReplacer;
20+
21+
@Test
22+
void methodOverride() {
23+
StaticListableBeanFactory beanFactory = new StaticListableBeanFactory(Map.of(
24+
"myBean", new MyBean(),
25+
"replacer", new MyReplacer()
26+
));
27+
28+
RootBeanDefinition bd = new RootBeanDefinition(MyBean.class);
29+
MethodOverrides methodOverrides = new MethodOverrides();
30+
Stream.of("getBoolean", "getShort", "getInt", "getLong", "getFloat", "getDouble", "getByte")
31+
.forEach(methodToOverride -> addOverride(methodOverrides, methodToOverride));
32+
bd.setMethodOverrides(methodOverrides);
33+
34+
MyBean bean = (MyBean) strategy.instantiate(bd, "myBean", beanFactory);
35+
36+
valueToReturnFromReplacer = null;
37+
assertCorrectExceptionThrownBy(bean::getBoolean);
38+
valueToReturnFromReplacer = true;
39+
assertThat(bean.getBoolean()).isTrue();
40+
41+
valueToReturnFromReplacer = null;
42+
assertCorrectExceptionThrownBy(bean::getShort);
43+
valueToReturnFromReplacer = 123;
44+
assertThat(bean.getShort()).isEqualTo((short) 123);
45+
46+
valueToReturnFromReplacer = null;
47+
assertCorrectExceptionThrownBy(bean::getInt);
48+
valueToReturnFromReplacer = 123;
49+
assertThat(bean.getInt()).isEqualTo(123);
50+
51+
valueToReturnFromReplacer = null;
52+
assertCorrectExceptionThrownBy(bean::getLong);
53+
valueToReturnFromReplacer = 123;
54+
assertThat(bean.getLong()).isEqualTo(123L);
55+
56+
valueToReturnFromReplacer = null;
57+
assertCorrectExceptionThrownBy(bean::getFloat);
58+
valueToReturnFromReplacer = 123;
59+
assertThat(bean.getFloat()).isEqualTo(123f);
60+
61+
valueToReturnFromReplacer = null;
62+
assertCorrectExceptionThrownBy(bean::getDouble);
63+
valueToReturnFromReplacer = 123;
64+
assertThat(bean.getDouble()).isEqualTo(123d);
65+
66+
valueToReturnFromReplacer = null;
67+
assertCorrectExceptionThrownBy(bean::getByte);
68+
valueToReturnFromReplacer = 123;
69+
assertThat(bean.getByte()).isEqualTo((byte) 123);
70+
}
71+
72+
private void assertCorrectExceptionThrownBy(ThrowableAssert.ThrowingCallable runnable) {
73+
assertThatThrownBy(runnable)
74+
.isInstanceOf(NullPointerException.class)
75+
.hasMessageMatching("Null return value from replacer does not match primitive return type for: "
76+
+ "\\w+ org\\.springframework\\.beans\\.factory\\.support\\.CglibSubclassingInstantiationStrategyTests\\$MyBean\\.\\w+\\(\\)");
77+
}
78+
79+
private void addOverride(MethodOverrides methodOverrides, String methodToOverride) {
80+
methodOverrides.addOverride(new ReplaceOverride(methodToOverride, "replacer"));
81+
}
82+
83+
static class MyBean {
84+
boolean getBoolean() {
85+
return true;
86+
}
87+
88+
short getShort() {
89+
return 123;
90+
}
91+
92+
int getInt() {
93+
return 123;
94+
}
95+
96+
long getLong() {
97+
return 123;
98+
}
99+
100+
float getFloat() {
101+
return 123;
102+
}
103+
104+
double getDouble() {
105+
return 123;
106+
}
107+
108+
byte getByte() {
109+
return 123;
110+
}
111+
}
112+
113+
static class MyReplacer implements MethodReplacer {
114+
115+
@Override
116+
public Object reimplement(Object obj, Method method, Object[] args) {
117+
return CglibSubclassingInstantiationStrategyTests.valueToReturnFromReplacer;
118+
}
119+
}
120+
}

0 commit comments

Comments
 (0)