17
17
package org .springframework .test .context .bean .override ;
18
18
19
19
import java .lang .annotation .Annotation ;
20
+ import java .lang .reflect .AnnotatedElement ;
20
21
import java .lang .reflect .Field ;
21
22
import java .lang .reflect .Modifier ;
23
+ import java .util .ArrayList ;
22
24
import java .util .Arrays ;
23
25
import java .util .Collections ;
24
26
import java .util .HashSet ;
25
- import java .util .LinkedList ;
26
27
import java .util .List ;
27
28
import java .util .Objects ;
28
29
import java .util .Set ;
29
30
import java .util .concurrent .atomic .AtomicBoolean ;
31
+ import java .util .function .BiConsumer ;
32
+ import java .util .function .Predicate ;
30
33
31
34
import org .springframework .beans .BeanUtils ;
32
35
import org .springframework .beans .factory .config .BeanDefinition ;
57
60
*
58
61
* <p>Concrete implementations of {@code BeanOverrideHandler} can store additional
59
62
* metadata to use during override {@linkplain #createOverrideInstance instance
60
- * creation} — for example, based on further processing of the annotation
61
- * or the annotated field .
63
+ * creation} — for example, based on further processing of the annotation,
64
+ * the annotated field, or the annotated class .
62
65
*
63
66
* <p><strong>NOTE</strong>: Only <em>singleton</em> beans can be overridden.
64
67
* Any attempt to override a non-singleton bean will result in an exception.
70
73
*/
71
74
public abstract class BeanOverrideHandler {
72
75
76
+ @ Nullable
73
77
private final Field field ;
74
78
75
79
private final Set <Annotation > fieldAnnotations ;
@@ -82,7 +86,7 @@ public abstract class BeanOverrideHandler {
82
86
private final BeanOverrideStrategy strategy ;
83
87
84
88
85
- protected BeanOverrideHandler (Field field , ResolvableType beanType , @ Nullable String beanName ,
89
+ protected BeanOverrideHandler (@ Nullable Field field , ResolvableType beanType , @ Nullable String beanName ,
86
90
BeanOverrideStrategy strategy ) {
87
91
88
92
this .field = field ;
@@ -96,57 +100,115 @@ protected BeanOverrideHandler(Field field, ResolvableType beanType, @Nullable St
96
100
* Process the given {@code testClass} and build the corresponding
97
101
* {@code BeanOverrideHandler} list derived from {@link BeanOverride @BeanOverride}
98
102
* fields in the test class and its type hierarchy.
99
- * <p>This method does not search the enclosing class hierarchy.
103
+ * <p>This method does not search the enclosing class hierarchy and does not
104
+ * search for {@code @BeanOverride} declarations on classes or interfaces.
100
105
* @param testClass the test class to process
101
106
* @return a list of bean override handlers
102
- * @see org.springframework.test.context.TestContextAnnotationUtils#searchEnclosingClass (Class)
107
+ * @see #forTestClass (Class, Predicate )
103
108
*/
104
109
public static List <BeanOverrideHandler > forTestClass (Class <?> testClass ) {
105
- List <BeanOverrideHandler > handlers = new LinkedList <>();
106
- findHandlers (testClass , testClass , handlers );
110
+ return forTestClass (testClass , false , clazz -> false );
111
+ }
112
+
113
+ /**
114
+ * Process the given {@code testClass} and build the corresponding
115
+ * {@code BeanOverrideHandler} list derived from {@link BeanOverride @BeanOverride}
116
+ * fields in the test class and in its type hierarchy as well as from
117
+ * {@code @BeanOverride} declarations on classes and interfaces.
118
+ * <p>This method additionally searches for {@code @BeanOverride} declarations
119
+ * in the enclosing class hierarchy if the supplied predicate evaluates to
120
+ * {@code true}.
121
+ * @param testClass the test class to process
122
+ * @param searchEnclosingClass a predicate which evaluates to {@code true}
123
+ * if a search should be performed on the enclosing class — for example,
124
+ * {@code TestContextAnnotationUtils::searchEnclosingClass}
125
+ * @return a list of bean override handlers
126
+ * @since 6.2.2
127
+ * @see org.springframework.test.context.TestContextAnnotationUtils#searchEnclosingClass(Class)
128
+ */
129
+ public static List <BeanOverrideHandler > forTestClass (Class <?> testClass , Predicate <Class <?>> searchEnclosingClass ) {
130
+ return forTestClass (testClass , true , searchEnclosingClass );
131
+ }
132
+
133
+ private static List <BeanOverrideHandler > forTestClass (Class <?> testClass , boolean searchOnTypes ,
134
+ Predicate <Class <?>> searchEnclosingClass ) {
135
+
136
+ List <BeanOverrideHandler > handlers = new ArrayList <>();
137
+ findHandlers (testClass , testClass , handlers , searchOnTypes , searchEnclosingClass );
107
138
return handlers ;
108
139
}
109
140
110
141
/**
111
- * Find handlers using tail recursion to ensure that "locally declared"
112
- * bean overrides take precedence over inherited bean overrides.
142
+ * Find handlers using tail recursion to ensure that "locally declared" bean overrides
143
+ * take precedence over inherited bean overrides.
144
+ * <p>Note: the search algorithm is effectively the inverse of the algorithm used in
145
+ * {@link org.springframework.test.context.TestContextAnnotationUtils#findAnnotationDescriptor(Class, Class)},
146
+ * but with tail recursion the semantics should be the same.
113
147
* @since 6.2.2
114
148
*/
115
- private static void findHandlers (Class <?> clazz , Class <?> testClass , List <BeanOverrideHandler > handlers ) {
116
- if (clazz == null || clazz == Object .class ) {
117
- return ;
149
+ private static void findHandlers (Class <?> clazz , Class <?> testClass , List <BeanOverrideHandler > handlers ,
150
+ boolean searchOnTypes , Predicate <Class <?>> searchEnclosingClass ) {
151
+
152
+ // 1) Search enclosing class hierarchy.
153
+ if (searchEnclosingClass .test (clazz )) {
154
+ findHandlers (clazz .getEnclosingClass (), testClass , handlers , searchOnTypes , searchEnclosingClass );
118
155
}
119
156
120
- // 1) Search type hierarchy.
121
- findHandlers (clazz .getSuperclass (), testClass , handlers );
157
+ // 2) Search class hierarchy.
158
+ Class <?> superclass = clazz .getSuperclass ();
159
+ if (superclass != null && superclass != Object .class ) {
160
+ findHandlers (superclass , testClass , handlers , searchOnTypes , searchEnclosingClass );
161
+ }
122
162
123
- // 2) Process fields in current class.
163
+ // 3) Search interfaces.
164
+ for (Class <?> ifc : clazz .getInterfaces ()) {
165
+ findHandlers (ifc , testClass , handlers , searchOnTypes , searchEnclosingClass );
166
+ }
167
+
168
+ // 4) Process current class.
169
+ if (searchOnTypes ) {
170
+ processClass (clazz , testClass , handlers );
171
+ }
172
+
173
+ // 5) Process fields in current class.
124
174
ReflectionUtils .doWithLocalFields (clazz , field -> processField (field , testClass , handlers ));
125
175
}
126
176
177
+ private static void processClass (Class <?> clazz , Class <?> testClass , List <BeanOverrideHandler > handlers ) {
178
+ processElement (clazz , testClass , (processor , composedAnnotation ) ->
179
+ processor .createHandlers (composedAnnotation , testClass ).forEach (handlers ::add ));
180
+ }
181
+
127
182
private static void processField (Field field , Class <?> testClass , List <BeanOverrideHandler > handlers ) {
128
183
AtomicBoolean overrideAnnotationFound = new AtomicBoolean ();
129
- MergedAnnotations . from (field , DIRECT ). stream ( BeanOverride . class ). forEach ( mergedAnnotation -> {
184
+ processElement (field , testClass , ( processor , composedAnnotation ) -> {
130
185
Assert .state (!Modifier .isStatic (field .getModifiers ()),
131
186
() -> "@BeanOverride field must not be static: " + field );
187
+ Assert .state (overrideAnnotationFound .compareAndSet (false , true ),
188
+ () -> "Multiple @BeanOverride annotations found on field: " + field );
189
+ handlers .add (processor .createHandler (composedAnnotation , testClass , field ));
190
+ });
191
+ }
192
+
193
+ private static void processElement (AnnotatedElement element , Class <?> testClass ,
194
+ BiConsumer <BeanOverrideProcessor , Annotation > consumer ) {
195
+
196
+ MergedAnnotations .from (element , DIRECT ).stream (BeanOverride .class ).forEach (mergedAnnotation -> {
132
197
MergedAnnotation <?> metaSource = mergedAnnotation .getMetaSource ();
133
198
Assert .state (metaSource != null , "@BeanOverride annotation must be meta-present" );
134
199
135
200
BeanOverride beanOverride = mergedAnnotation .synthesize ();
136
201
BeanOverrideProcessor processor = BeanUtils .instantiateClass (beanOverride .value ());
137
202
Annotation composedAnnotation = metaSource .synthesize ();
138
-
139
- Assert .state (overrideAnnotationFound .compareAndSet (false , true ),
140
- () -> "Multiple @BeanOverride annotations found on field: " + field );
141
- BeanOverrideHandler handler = processor .createHandler (composedAnnotation , testClass , field );
142
- handlers .add (handler );
203
+ consumer .accept (processor , composedAnnotation );
143
204
});
144
205
}
145
206
146
207
147
208
/**
148
209
* Get the annotated {@link Field}.
149
210
*/
211
+ @ Nullable
150
212
public final Field getField () {
151
213
return this .field ;
152
214
}
@@ -243,20 +305,23 @@ public boolean equals(Object other) {
243
305
!Objects .equals (this .strategy , that .strategy )) {
244
306
return false ;
245
307
}
308
+
309
+ // by-name lookup
246
310
if (this .beanName != null ) {
247
311
return true ;
248
312
}
249
313
250
314
// by-type lookup
251
- return (Objects .equals (this .field .getName (), that .field .getName ()) &&
315
+ return (this .field != null && that .field != null &&
316
+ Objects .equals (this .field .getName (), that .field .getName ()) &&
252
317
this .fieldAnnotations .equals (that .fieldAnnotations ));
253
318
}
254
319
255
320
@ Override
256
321
public int hashCode () {
257
322
int hash = Objects .hash (getClass (), this .beanType .getType (), this .beanName , this .strategy );
258
- return (this .beanName != null ? hash : hash +
259
- Objects .hash (this .field . getName (), this .fieldAnnotations ));
323
+ return (this .beanName != null ? hash :
324
+ hash + Objects .hash (( this .field != null ? this . field . getName () : null ), this .fieldAnnotations ));
260
325
}
261
326
262
327
@ Override
@@ -269,7 +334,10 @@ public String toString() {
269
334
.toString ();
270
335
}
271
336
272
- private static Set <Annotation > annotationSet (Field field ) {
337
+ private static Set <Annotation > annotationSet (@ Nullable Field field ) {
338
+ if (field == null ) {
339
+ return Collections .emptySet ();
340
+ }
273
341
Annotation [] annotations = field .getAnnotations ();
274
342
return (annotations .length != 0 ? new HashSet <>(Arrays .asList (annotations )) : Collections .emptySet ());
275
343
}
0 commit comments