Skip to content

Commit 9d20bf1

Browse files
authored
Merge pull request #82 from mockito/reflection-robustness
Reflection robustness
2 parents 99fa938 + 0d926fd commit 9d20bf1

File tree

2 files changed

+98
-63
lines changed

2 files changed

+98
-63
lines changed

common/src/main/scala/org/mockito/ReflectionUtils.scala

+66-38
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,21 @@ package org.mockito
33
import java.lang.reflect.Method
44
import java.util.function
55

6-
import org.mockito.internal.handler.ScalaMockHandler.{ArgumentExtractor, Extractors}
6+
import org.mockito.internal.handler.ScalaMockHandler.{ ArgumentExtractor, Extractors }
77
import org.mockito.invocation.InvocationOnMock
8-
import ru.vyarus.java.generics.resolver.GenericsResolver
98
import org.scalactic.TripleEquals._
9+
import ru.vyarus.java.generics.resolver.GenericsResolver
1010

11+
import scala.language.implicitConversions
1112
import scala.reflect.internal.Symbols
1213

1314
private[mockito] object ReflectionUtils {
1415

1516
import scala.reflect.runtime.{ universe => ru }
1617
import ru._
1718

19+
implicit def symbolToMethodSymbol(sym: Symbol): Symbols#MethodSymbol = sym.asInstanceOf[Symbols#MethodSymbol]
20+
1821
private val mirror = runtimeMirror(getClass.getClassLoader)
1922
private val customMirror = mirror.asInstanceOf[{
2023
def methodToJava(sym: Symbols#MethodSymbol): Method
@@ -23,60 +26,85 @@ private[mockito] object ReflectionUtils {
2326
implicit class InvocationOnMockOps(invocation: InvocationOnMock) {
2427
def returnType: Class[_] = {
2528
val method = invocation.getMethod
26-
val clazz = method.getDeclaringClass
27-
val javaReturnType = invocation.getMethod.getReturnType
29+
val javaReturnType = method.getReturnType
2830

2931
if (javaReturnType == classOf[Object])
30-
mirror
31-
.classSymbol(clazz)
32-
.info
33-
.decls
34-
.filter(d => d.isMethod && !d.isConstructor)
35-
.find(d => customMirror.methodToJava(d.asInstanceOf[Symbols#MethodSymbol]) === method)
36-
.map(_.asMethod)
37-
.filter(_.returnType.typeSymbol.isClass)
38-
.map(methodSymbol => mirror.runtimeClass(methodSymbol.returnType.typeSymbol.asClass))
39-
.orElse(resolveGenerics)
32+
resolveWithScalaGenerics(method)
33+
.orElse(resolveWithJavaGenerics(method))
4034
.getOrElse(javaReturnType)
4135
else javaReturnType
4236
}
4337

44-
private def resolveGenerics: Option[Class[_]] =
38+
private def resolveWithScalaGenerics(method: Method): Option[Class[_]] =
39+
scala.util
40+
.Try {
41+
mirror
42+
.classSymbol(method.getDeclaringClass)
43+
.info
44+
.decls
45+
.filter(isNonConstructorMethod)
46+
.find(d => customMirror.methodToJava(d) === method)
47+
.map(_.asMethod)
48+
.filter(_.returnType.typeSymbol.isClass)
49+
.map(methodSymbol => mirror.runtimeClass(methodSymbol.returnType.typeSymbol.asClass))
50+
}
51+
.toOption
52+
.flatten
53+
54+
private def resolveWithJavaGenerics(method: Method): Option[Class[_]] =
4555
scala.util.Try {
46-
GenericsResolver.resolve(invocation.getMock.getClass).`type`(clazz).method(invocation.getMethod).resolveReturnClass()
56+
GenericsResolver.resolve(invocation.getMock.getClass).`type`(clazz).method(method).resolveReturnClass()
4757
}.toOption
4858
}
4959

60+
private def isNonConstructorMethod(d: ru.Symbol): Boolean = d.isMethod && !d.isConstructor
61+
5062
def interfaces[T](implicit tag: WeakTypeTag[T]): List[Class[_]] =
51-
tag.tpe match {
52-
case RefinedType(types, _) =>
53-
types.map(tag.mirror.runtimeClass).collect {
54-
case c: Class[_] if c.isInterface => c
63+
scala.util
64+
.Try {
65+
tag.tpe match {
66+
case RefinedType(types, _) =>
67+
types.map(tag.mirror.runtimeClass).collect {
68+
case c: Class[_] if c.isInterface => c
69+
}
70+
case _ => List.empty
5571
}
56-
case _ => List.empty
57-
}
72+
}
73+
.toOption
74+
.getOrElse(List.empty)
5875

5976
def markMethodsWithLazyArgs(clazz: Class[_]): Unit =
6077
Extractors.computeIfAbsent(
6178
clazz,
6279
new function.Function[Class[_], ArgumentExtractor] {
63-
override def apply(t: Class[_]): ArgumentExtractor = {
64-
val mirror = runtimeMirror(clazz.getClassLoader)
65-
66-
val symbol = mirror.classSymbol(clazz)
67-
68-
val methodsWithLazyArgs = symbol.info.decls
69-
.collect {
70-
case s if s.isMethod =>
71-
(s.name.toString, s.typeSignature.paramLists.flatten.zipWithIndex.collect {
72-
case (p, idx) if p.typeSignature.toString.startsWith("=>") => idx
73-
}.toSet)
80+
override def apply(t: Class[_]): ArgumentExtractor =
81+
scala.util
82+
.Try {
83+
ArgumentExtractor {
84+
mirror
85+
.classSymbol(clazz)
86+
.info
87+
.decls
88+
.collect {
89+
case s if isNonConstructorMethod(s) =>
90+
(customMirror.methodToJava(s), s.typeSignature.paramLists.flatten.zipWithIndex.collect {
91+
case (p, idx) if p.typeSignature.toString.startsWith("=>") => idx
92+
}.toSet)
93+
}
94+
.toSeq
95+
.filter(_._2.nonEmpty)
96+
}
7497
}
75-
.toMap
76-
.filter(_._2.nonEmpty)
77-
78-
ArgumentExtractor(methodsWithLazyArgs)
79-
}
98+
.toOption
99+
.getOrElse(ArgumentExtractor.Empty)
80100
}
81101
)
102+
103+
def readDeclaredField[T](o: AnyRef, field: String): Option[T] =
104+
scala.util.Try {
105+
val f = o.getClass.getDeclaredField(field)
106+
f.setAccessible(true)
107+
f.get(o).asInstanceOf[T]
108+
}.toOption
109+
82110
}
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1-
package org.mockito.internal.handler
1+
package org.mockito
2+
package internal.handler
23

4+
import java.lang.reflect.Method
35
import java.lang.reflect.Modifier.isAbstract
46
import java.util.concurrent.ConcurrentHashMap
57

8+
import org.mockito.ReflectionUtils.readDeclaredField
69
import org.mockito.internal.handler.ScalaMockHandler._
7-
import org.mockito.internal.invocation.{InterceptedInvocation, MockitoMethod}
8-
import org.mockito.invocation.{Invocation, MockHandler}
10+
import org.mockito.internal.invocation.mockref.MockReference
11+
import org.mockito.internal.invocation.{ InterceptedInvocation, MockitoMethod, RealMethod }
12+
import org.mockito.invocation.{ Invocation, MockHandler }
913
import org.mockito.mock.MockCreationSettings
14+
import org.scalactic.TripleEquals._
1015

1116
class ScalaMockHandler[T](mockSettings: MockCreationSettings[T]) extends MockHandlerImpl[T](mockSettings) {
1217
override def handle(invocation: Invocation): AnyRef =
@@ -15,15 +20,13 @@ class ScalaMockHandler[T](mockSettings: MockCreationSettings[T]) extends MockHan
1520
else {
1621
val scalaInvocation = invocation match {
1722
case i: InterceptedInvocation =>
18-
val mockitoMethod: MockitoMethod = readField(i, "mockitoMethod")
19-
new InterceptedInvocation(
20-
readField(i, "mockRef"),
21-
mockitoMethod,
22-
unwrapByNameArgs(mockitoMethod, i.getRawArguments.asInstanceOf[Array[Any]]).asInstanceOf[Array[Object]],
23-
readField(i, "realMethod"),
24-
i.getLocation,
25-
i.getSequenceNumber
26-
)
23+
val byNameAwareInvocation = for {
24+
mockitoMethod <- i.mockitoMethod
25+
mockRef <- i.mockRef
26+
realMethod <- i.realMethod
27+
byNameArgs = unwrapByNameArgs(mockitoMethod, i.getRawArguments)
28+
} yield new InterceptedInvocation(mockRef, mockitoMethod, byNameArgs, realMethod, i.getLocation, i.getSequenceNumber)
29+
byNameAwareInvocation.getOrElse(invocation)
2730
case other => other
2831
}
2932
super.handle(scalaInvocation)
@@ -34,25 +37,25 @@ object ScalaMockHandler {
3437
def apply[T](mockSettings: MockCreationSettings[T]): MockHandler[T] =
3538
new InvocationNotifierHandler[T](new ScalaNullResultGuardian[T](new ScalaMockHandler(mockSettings)), mockSettings)
3639

37-
private def readField[T](invocation: InterceptedInvocation, field: String): T = {
38-
val f = classOf[InterceptedInvocation].getDeclaredField(field)
39-
f.setAccessible(true)
40-
f.get(invocation).asInstanceOf[T]
40+
implicit class InterceptedInvocationOps(i: InterceptedInvocation) {
41+
def mockitoMethod: Option[MockitoMethod] = readDeclaredField(i, "mockitoMethod")
42+
def mockRef: Option[MockReference[Object]] = readDeclaredField(i, "mockRef")
43+
def realMethod: Option[RealMethod] = readDeclaredField(i, "realMethod")
4144
}
4245

43-
private def unwrapByNameArgs(method: MockitoMethod, args: Array[Any]): Array[Any] = {
44-
val declaringClass = method.getJavaMethod.getDeclaringClass
45-
if (Extractors.containsKey(declaringClass)) Extractors.get(declaringClass).transformArgs(method.getName, args)
46-
else args
47-
}
46+
private def unwrapByNameArgs(method: MockitoMethod, args: Array[AnyRef]): Array[Object] =
47+
Extractors
48+
.getOrDefault(method.getJavaMethod.getDeclaringClass, ArgumentExtractor.Empty)
49+
.transformArgs(method.getJavaMethod, args.asInstanceOf[Array[Any]])
50+
.asInstanceOf[Array[Object]]
4851

4952
val Extractors = new ConcurrentHashMap[Class[_], ArgumentExtractor]
5053

51-
case class ArgumentExtractor(toTransform: Map[String, Set[Int]]) {
52-
53-
def transformArgs(methodName: String, args: Array[Any]): Array[Any] =
54+
case class ArgumentExtractor(toTransform: Seq[(Method, Set[Int])]) {
55+
def transformArgs(method: Method, args: Array[Any]): Array[Any] =
5456
toTransform
55-
.get(methodName)
57+
.find(_._1 === method)
58+
.map(_._2)
5659
.map { transformIndices =>
5760
args.zipWithIndex.map {
5861
case (arg: Function0[_], idx) if transformIndices.contains(idx) => arg()
@@ -61,4 +64,8 @@ object ScalaMockHandler {
6164
}
6265
.getOrElse(args)
6366
}
67+
68+
object ArgumentExtractor {
69+
val Empty = ArgumentExtractor(Seq.empty)
70+
}
6471
}

0 commit comments

Comments
 (0)