1
+ package org .springframework .boot .test .mock .mockito ;
2
+
3
+ import org .springframework .beans .BeansException ;
4
+ import org .springframework .beans .factory .config .BeanDefinition ;
5
+ import org .springframework .beans .factory .config .BeanFactoryPostProcessor ;
6
+ import org .springframework .beans .factory .config .ConfigurableListableBeanFactory ;
7
+ import org .springframework .beans .factory .config .ConstructorArgumentValues ;
8
+ import org .springframework .beans .factory .support .BeanDefinitionRegistry ;
9
+ import org .springframework .beans .factory .support .RootBeanDefinition ;
10
+ import org .springframework .core .Ordered ;
11
+
12
+ import java .util .LinkedHashSet ;
13
+ import java .util .Set ;
14
+
15
+ public class MockitoScopedProxyPostProcessor implements BeanFactoryPostProcessor , Ordered {
16
+
17
+ private static final String BEAN_NAME = MockitoScopedProxyPostProcessor .class .getName ();
18
+ public static final String SCOPED_TARGET_PREFIX = "scopedTarget." ;
19
+
20
+ private final Set <Class > mockedTypes ;
21
+
22
+ public MockitoScopedProxyPostProcessor (Set <Class > mockedTypes ) {
23
+ this .mockedTypes = mockedTypes ;
24
+ }
25
+
26
+ public void postProcessBeanFactory (ConfigurableListableBeanFactory beanFactory ) throws BeansException {
27
+ BeanDefinitionRegistry bdr = (BeanDefinitionRegistry ) beanFactory ;
28
+
29
+ for (Class mockedType : mockedTypes ) {
30
+ String [] mockedBeans = beanFactory .getBeanNamesForType (mockedType );
31
+
32
+ for (String mockedBean : mockedBeans ) {
33
+ if (isScopedProxy (mockedBean )) {
34
+ bdr .removeBeanDefinition (mockedBean );
35
+ }
36
+ }
37
+ }
38
+ }
39
+
40
+ private static boolean isScopedProxy (String mockedBean ) {
41
+ return mockedBean .startsWith (SCOPED_TARGET_PREFIX );
42
+ }
43
+
44
+ public static void register (BeanDefinitionRegistry registry , Set <Class > mockedTypes ) {
45
+ BeanDefinition definition = getOrAddBeanDefinition (registry );
46
+
47
+ if (mockedTypes != null ) {
48
+ getConstructorArgs (definition ).addAll (mockedTypes );
49
+ }
50
+ }
51
+
52
+ @ SuppressWarnings ("unchecked" )
53
+ private static Set <Class > getConstructorArgs (BeanDefinition definition ) {
54
+ ConstructorArgumentValues .ValueHolder constructorArg = definition .getConstructorArgumentValues ()
55
+ .getIndexedArgumentValue (0 , Set .class );
56
+ return (Set <Class >) constructorArg .getValue ();
57
+ }
58
+
59
+ private static BeanDefinition getOrAddBeanDefinition (BeanDefinitionRegistry registry ) {
60
+ if (!registry .containsBeanDefinition (BEAN_NAME )) {
61
+ addBeanDefinition (registry );
62
+ }
63
+
64
+ return registry .getBeanDefinition (BEAN_NAME );
65
+ }
66
+
67
+ private static void addBeanDefinition (BeanDefinitionRegistry registry ) {
68
+ RootBeanDefinition def = new RootBeanDefinition (MockitoScopedProxyPostProcessor .class );
69
+ def .setRole (BeanDefinition .ROLE_INFRASTRUCTURE );
70
+ ConstructorArgumentValues constructorArguments = def .getConstructorArgumentValues ();
71
+ constructorArguments .addIndexedArgumentValue (0 , new LinkedHashSet <Class >());
72
+ registry .registerBeanDefinition (BEAN_NAME , def );
73
+ }
74
+
75
+ public int getOrder () {
76
+ return Ordered .HIGHEST_PRECEDENCE ;
77
+ }
78
+ }
0 commit comments