Skip to content

Commit 0c7bc23

Browse files
committed
Redesign BeanOverrideRegistry internals
1 parent 470bf3b commit 0c7bc23

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideRegistry.java

+14-6
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727

2828
import org.springframework.beans.factory.BeanCreationException;
2929
import org.springframework.beans.factory.config.ConfigurableBeanFactory;
30+
import org.springframework.lang.Nullable;
3031
import org.springframework.util.Assert;
3132
import org.springframework.util.ReflectionUtils;
32-
import org.springframework.util.StringUtils;
3333

3434
/**
3535
* An internal class used to track {@link BeanOverrideHandler}-related state after
@@ -110,14 +110,13 @@ Object wrapBeanIfNecessary(Object bean, String beanName) {
110110
void inject(Object target, BeanOverrideHandler handler) {
111111
Field field = handler.getField();
112112
Assert.notNull(field, () -> "BeanOverrideHandler must have a non-null field: " + handler);
113-
String beanName = this.handlerToBeanNameMap.get(handler);
114-
Assert.state(StringUtils.hasLength(beanName), () -> "No bean found for BeanOverrideHandler: " + handler);
115-
inject(field, target, beanName);
113+
Object bean = getBeanForHandler(handler, field.getType());
114+
Assert.state(bean != null, () -> "No bean found for BeanOverrideHandler: " + handler);
115+
inject(field, target, bean);
116116
}
117117

118-
private void inject(Field field, Object target, String beanName) {
118+
private void inject(Field field, Object target, Object bean) {
119119
try {
120-
Object bean = this.beanFactory.getBean(beanName, field.getType());
121120
ReflectionUtils.makeAccessible(field);
122121
ReflectionUtils.setField(field, target, bean);
123122
}
@@ -126,4 +125,13 @@ private void inject(Field field, Object target, String beanName) {
126125
}
127126
}
128127

128+
@Nullable
129+
private Object getBeanForHandler(BeanOverrideHandler handler, Class<?> requiredType) {
130+
String beanName = this.handlerToBeanNameMap.get(handler);
131+
if (beanName != null) {
132+
return this.beanFactory.getBean(beanName, requiredType);
133+
}
134+
return null;
135+
}
136+
129137
}

0 commit comments

Comments
 (0)