@@ -3,18 +3,21 @@ package org.mockito
3
3
import java .lang .reflect .Method
4
4
import java .util .function
5
5
6
- import org .mockito .internal .handler .ScalaMockHandler .{ArgumentExtractor , Extractors }
6
+ import org .mockito .internal .handler .ScalaMockHandler .{ ArgumentExtractor , Extractors }
7
7
import org .mockito .invocation .InvocationOnMock
8
- import ru .vyarus .java .generics .resolver .GenericsResolver
9
8
import org .scalactic .TripleEquals ._
9
+ import ru .vyarus .java .generics .resolver .GenericsResolver
10
10
11
+ import scala .language .implicitConversions
11
12
import scala .reflect .internal .Symbols
12
13
13
14
private [mockito] object ReflectionUtils {
14
15
15
16
import scala .reflect .runtime .{ universe => ru }
16
17
import ru ._
17
18
19
+ implicit def symbolToMethodSymbol (sym : Symbol ): Symbols # MethodSymbol = sym.asInstanceOf [Symbols # MethodSymbol ]
20
+
18
21
private val mirror = runtimeMirror(getClass.getClassLoader)
19
22
private val customMirror = mirror.asInstanceOf [{
20
23
def methodToJava (sym : Symbols # MethodSymbol ): Method
@@ -23,60 +26,85 @@ private[mockito] object ReflectionUtils {
23
26
implicit class InvocationOnMockOps (invocation : InvocationOnMock ) {
24
27
def returnType : Class [_] = {
25
28
val method = invocation.getMethod
26
- val clazz = method.getDeclaringClass
27
- val javaReturnType = invocation.getMethod.getReturnType
29
+ val javaReturnType = method.getReturnType
28
30
29
31
if (javaReturnType == classOf [Object ])
30
- mirror
31
- .classSymbol(clazz)
32
- .info
33
- .decls
34
- .filter(d => d.isMethod && ! d.isConstructor)
35
- .find(d => customMirror.methodToJava(d.asInstanceOf [Symbols # MethodSymbol ]) === method)
36
- .map(_.asMethod)
37
- .filter(_.returnType.typeSymbol.isClass)
38
- .map(methodSymbol => mirror.runtimeClass(methodSymbol.returnType.typeSymbol.asClass))
39
- .orElse(resolveGenerics)
32
+ resolveWithScalaGenerics(method)
33
+ .orElse(resolveWithJavaGenerics(method))
40
34
.getOrElse(javaReturnType)
41
35
else javaReturnType
42
36
}
43
37
44
- private def resolveGenerics : Option [Class [_]] =
38
+ private def resolveWithScalaGenerics (method : Method ): Option [Class [_]] =
39
+ scala.util
40
+ .Try {
41
+ mirror
42
+ .classSymbol(method.getDeclaringClass)
43
+ .info
44
+ .decls
45
+ .filter(isNonConstructorMethod)
46
+ .find(d => customMirror.methodToJava(d) === method)
47
+ .map(_.asMethod)
48
+ .filter(_.returnType.typeSymbol.isClass)
49
+ .map(methodSymbol => mirror.runtimeClass(methodSymbol.returnType.typeSymbol.asClass))
50
+ }
51
+ .toOption
52
+ .flatten
53
+
54
+ private def resolveWithJavaGenerics (method : Method ): Option [Class [_]] =
45
55
scala.util.Try {
46
- GenericsResolver .resolve(invocation.getMock.getClass).`type`(clazz).method(invocation.getMethod ).resolveReturnClass()
56
+ GenericsResolver .resolve(invocation.getMock.getClass).`type`(clazz).method(method ).resolveReturnClass()
47
57
}.toOption
48
58
}
49
59
60
+ private def isNonConstructorMethod (d : ru.Symbol ): Boolean = d.isMethod && ! d.isConstructor
61
+
50
62
def interfaces [T ](implicit tag : WeakTypeTag [T ]): List [Class [_]] =
51
- tag.tpe match {
52
- case RefinedType (types, _) =>
53
- types.map(tag.mirror.runtimeClass).collect {
54
- case c : Class [_] if c.isInterface => c
63
+ scala.util
64
+ .Try {
65
+ tag.tpe match {
66
+ case RefinedType (types, _) =>
67
+ types.map(tag.mirror.runtimeClass).collect {
68
+ case c : Class [_] if c.isInterface => c
69
+ }
70
+ case _ => List .empty
55
71
}
56
- case _ => List .empty
57
- }
72
+ }
73
+ .toOption
74
+ .getOrElse(List .empty)
58
75
59
76
def markMethodsWithLazyArgs (clazz : Class [_]): Unit =
60
77
Extractors .computeIfAbsent(
61
78
clazz,
62
79
new function.Function [Class [_], ArgumentExtractor ] {
63
- override def apply (t : Class [_]): ArgumentExtractor = {
64
- val mirror = runtimeMirror(clazz.getClassLoader)
65
-
66
- val symbol = mirror.classSymbol(clazz)
67
-
68
- val methodsWithLazyArgs = symbol.info.decls
69
- .collect {
70
- case s if s.isMethod =>
71
- (s.name.toString, s.typeSignature.paramLists.flatten.zipWithIndex.collect {
72
- case (p, idx) if p.typeSignature.toString.startsWith(" =>" ) => idx
73
- }.toSet)
80
+ override def apply (t : Class [_]): ArgumentExtractor =
81
+ scala.util
82
+ .Try {
83
+ ArgumentExtractor {
84
+ mirror
85
+ .classSymbol(clazz)
86
+ .info
87
+ .decls
88
+ .collect {
89
+ case s if isNonConstructorMethod(s) =>
90
+ (customMirror.methodToJava(s), s.typeSignature.paramLists.flatten.zipWithIndex.collect {
91
+ case (p, idx) if p.typeSignature.toString.startsWith(" =>" ) => idx
92
+ }.toSet)
93
+ }
94
+ .toSeq
95
+ .filter(_._2.nonEmpty)
96
+ }
74
97
}
75
- .toMap
76
- .filter(_._2.nonEmpty)
77
-
78
- ArgumentExtractor (methodsWithLazyArgs)
79
- }
98
+ .toOption
99
+ .getOrElse(ArgumentExtractor .Empty )
80
100
}
81
101
)
102
+
103
+ def readDeclaredField [T ](o : AnyRef , field : String ): Option [T ] =
104
+ scala.util.Try {
105
+ val f = o.getClass.getDeclaredField(field)
106
+ f.setAccessible(true )
107
+ f.get(o).asInstanceOf [T ]
108
+ }.toOption
109
+
82
110
}
0 commit comments