Skip to content

Reflection robustness #82

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 66 additions & 38 deletions common/src/main/scala/org/mockito/ReflectionUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,21 @@ package org.mockito
import java.lang.reflect.Method
import java.util.function

import org.mockito.internal.handler.ScalaMockHandler.{ArgumentExtractor, Extractors}
import org.mockito.internal.handler.ScalaMockHandler.{ ArgumentExtractor, Extractors }
import org.mockito.invocation.InvocationOnMock
import ru.vyarus.java.generics.resolver.GenericsResolver
import org.scalactic.TripleEquals._
import ru.vyarus.java.generics.resolver.GenericsResolver

import scala.language.implicitConversions
import scala.reflect.internal.Symbols

private[mockito] object ReflectionUtils {

import scala.reflect.runtime.{ universe => ru }
import ru._

implicit def symbolToMethodSymbol(sym: Symbol): Symbols#MethodSymbol = sym.asInstanceOf[Symbols#MethodSymbol]

private val mirror = runtimeMirror(getClass.getClassLoader)
private val customMirror = mirror.asInstanceOf[{
def methodToJava(sym: Symbols#MethodSymbol): Method
Expand All @@ -23,60 +26,85 @@ private[mockito] object ReflectionUtils {
implicit class InvocationOnMockOps(invocation: InvocationOnMock) {
def returnType: Class[_] = {
val method = invocation.getMethod
val clazz = method.getDeclaringClass
val javaReturnType = invocation.getMethod.getReturnType
val javaReturnType = method.getReturnType

if (javaReturnType == classOf[Object])
mirror
.classSymbol(clazz)
.info
.decls
.filter(d => d.isMethod && !d.isConstructor)
.find(d => customMirror.methodToJava(d.asInstanceOf[Symbols#MethodSymbol]) === method)
.map(_.asMethod)
.filter(_.returnType.typeSymbol.isClass)
.map(methodSymbol => mirror.runtimeClass(methodSymbol.returnType.typeSymbol.asClass))
.orElse(resolveGenerics)
resolveWithScalaGenerics(method)
.orElse(resolveWithJavaGenerics(method))
.getOrElse(javaReturnType)
else javaReturnType
}

private def resolveGenerics: Option[Class[_]] =
private def resolveWithScalaGenerics(method: Method): Option[Class[_]] =
scala.util
.Try {
mirror
.classSymbol(method.getDeclaringClass)
.info
.decls
.filter(isNonConstructorMethod)
.find(d => customMirror.methodToJava(d) === method)
.map(_.asMethod)
.filter(_.returnType.typeSymbol.isClass)
.map(methodSymbol => mirror.runtimeClass(methodSymbol.returnType.typeSymbol.asClass))
}
.toOption
.flatten

private def resolveWithJavaGenerics(method: Method): Option[Class[_]] =
scala.util.Try {
GenericsResolver.resolve(invocation.getMock.getClass).`type`(clazz).method(invocation.getMethod).resolveReturnClass()
GenericsResolver.resolve(invocation.getMock.getClass).`type`(clazz).method(method).resolveReturnClass()
}.toOption
}

private def isNonConstructorMethod(d: ru.Symbol): Boolean = d.isMethod && !d.isConstructor

def interfaces[T](implicit tag: WeakTypeTag[T]): List[Class[_]] =
tag.tpe match {
case RefinedType(types, _) =>
types.map(tag.mirror.runtimeClass).collect {
case c: Class[_] if c.isInterface => c
scala.util
.Try {
tag.tpe match {
case RefinedType(types, _) =>
types.map(tag.mirror.runtimeClass).collect {
case c: Class[_] if c.isInterface => c
}
case _ => List.empty
}
case _ => List.empty
}
}
.toOption
.getOrElse(List.empty)

def markMethodsWithLazyArgs(clazz: Class[_]): Unit =
Extractors.computeIfAbsent(
clazz,
new function.Function[Class[_], ArgumentExtractor] {
override def apply(t: Class[_]): ArgumentExtractor = {
val mirror = runtimeMirror(clazz.getClassLoader)

val symbol = mirror.classSymbol(clazz)

val methodsWithLazyArgs = symbol.info.decls
.collect {
case s if s.isMethod =>
(s.name.toString, s.typeSignature.paramLists.flatten.zipWithIndex.collect {
case (p, idx) if p.typeSignature.toString.startsWith("=>") => idx
}.toSet)
override def apply(t: Class[_]): ArgumentExtractor =
scala.util
.Try {
ArgumentExtractor {
mirror
.classSymbol(clazz)
.info
.decls
.collect {
case s if isNonConstructorMethod(s) =>
(customMirror.methodToJava(s), s.typeSignature.paramLists.flatten.zipWithIndex.collect {
case (p, idx) if p.typeSignature.toString.startsWith("=>") => idx
}.toSet)
}
.toSeq
.filter(_._2.nonEmpty)
}
}
.toMap
.filter(_._2.nonEmpty)

ArgumentExtractor(methodsWithLazyArgs)
}
.toOption
.getOrElse(ArgumentExtractor.Empty)
}
)

def readDeclaredField[T](o: AnyRef, field: String): Option[T] =
scala.util.Try {
val f = o.getClass.getDeclaredField(field)
f.setAccessible(true)
f.get(o).asInstanceOf[T]
}.toOption

}
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
package org.mockito.internal.handler
package org.mockito
package internal.handler

import java.lang.reflect.Method
import java.lang.reflect.Modifier.isAbstract
import java.util.concurrent.ConcurrentHashMap

import org.mockito.ReflectionUtils.readDeclaredField
import org.mockito.internal.handler.ScalaMockHandler._
import org.mockito.internal.invocation.{InterceptedInvocation, MockitoMethod}
import org.mockito.invocation.{Invocation, MockHandler}
import org.mockito.internal.invocation.mockref.MockReference
import org.mockito.internal.invocation.{ InterceptedInvocation, MockitoMethod, RealMethod }
import org.mockito.invocation.{ Invocation, MockHandler }
import org.mockito.mock.MockCreationSettings
import org.scalactic.TripleEquals._

class ScalaMockHandler[T](mockSettings: MockCreationSettings[T]) extends MockHandlerImpl[T](mockSettings) {
override def handle(invocation: Invocation): AnyRef =
Expand All @@ -15,15 +20,13 @@ class ScalaMockHandler[T](mockSettings: MockCreationSettings[T]) extends MockHan
else {
val scalaInvocation = invocation match {
case i: InterceptedInvocation =>
val mockitoMethod: MockitoMethod = readField(i, "mockitoMethod")
new InterceptedInvocation(
readField(i, "mockRef"),
mockitoMethod,
unwrapByNameArgs(mockitoMethod, i.getRawArguments.asInstanceOf[Array[Any]]).asInstanceOf[Array[Object]],
readField(i, "realMethod"),
i.getLocation,
i.getSequenceNumber
)
val byNameAwareInvocation = for {
mockitoMethod <- i.mockitoMethod
mockRef <- i.mockRef
realMethod <- i.realMethod
byNameArgs = unwrapByNameArgs(mockitoMethod, i.getRawArguments)
} yield new InterceptedInvocation(mockRef, mockitoMethod, byNameArgs, realMethod, i.getLocation, i.getSequenceNumber)
byNameAwareInvocation.getOrElse(invocation)
case other => other
}
super.handle(scalaInvocation)
Expand All @@ -34,25 +37,25 @@ object ScalaMockHandler {
def apply[T](mockSettings: MockCreationSettings[T]): MockHandler[T] =
new InvocationNotifierHandler[T](new ScalaNullResultGuardian[T](new ScalaMockHandler(mockSettings)), mockSettings)

private def readField[T](invocation: InterceptedInvocation, field: String): T = {
val f = classOf[InterceptedInvocation].getDeclaredField(field)
f.setAccessible(true)
f.get(invocation).asInstanceOf[T]
implicit class InterceptedInvocationOps(i: InterceptedInvocation) {
def mockitoMethod: Option[MockitoMethod] = readDeclaredField(i, "mockitoMethod")
def mockRef: Option[MockReference[Object]] = readDeclaredField(i, "mockRef")
def realMethod: Option[RealMethod] = readDeclaredField(i, "realMethod")
}

private def unwrapByNameArgs(method: MockitoMethod, args: Array[Any]): Array[Any] = {
val declaringClass = method.getJavaMethod.getDeclaringClass
if (Extractors.containsKey(declaringClass)) Extractors.get(declaringClass).transformArgs(method.getName, args)
else args
}
private def unwrapByNameArgs(method: MockitoMethod, args: Array[AnyRef]): Array[Object] =
Extractors
.getOrDefault(method.getJavaMethod.getDeclaringClass, ArgumentExtractor.Empty)
.transformArgs(method.getJavaMethod, args.asInstanceOf[Array[Any]])
.asInstanceOf[Array[Object]]

val Extractors = new ConcurrentHashMap[Class[_], ArgumentExtractor]

case class ArgumentExtractor(toTransform: Map[String, Set[Int]]) {

def transformArgs(methodName: String, args: Array[Any]): Array[Any] =
case class ArgumentExtractor(toTransform: Seq[(Method, Set[Int])]) {
def transformArgs(method: Method, args: Array[Any]): Array[Any] =
toTransform
.get(methodName)
.find(_._1 === method)
.map(_._2)
.map { transformIndices =>
args.zipWithIndex.map {
case (arg: Function0[_], idx) if transformIndices.contains(idx) => arg()
Expand All @@ -61,4 +64,8 @@ object ScalaMockHandler {
}
.getOrElse(args)
}

object ArgumentExtractor {
val Empty = ArgumentExtractor(Seq.empty)
}
}