diff --git a/core/src/test/scala/user/org/mockito/IdiomaticMockitoTest.scala b/core/src/test/scala/user/org/mockito/IdiomaticMockitoTest.scala index 6ad59d21..4f509bc3 100644 --- a/core/src/test/scala/user/org/mockito/IdiomaticMockitoTest.scala +++ b/core/src/test/scala/user/org/mockito/IdiomaticMockitoTest.scala @@ -3,22 +3,22 @@ package user.org.mockito import org.mockito.captor.ArgCaptor import org.mockito.exceptions.verification._ import org.mockito.invocation.InvocationOnMock -import org.mockito.{ ArgumentMatchersSugar, IdiomaticMockito } +import org.mockito.{ArgumentMatchersSugar, IdiomaticMockito, MockitoSugar} import org.scalatest.prop.TableDrivenPropertyChecks -import org.scalatest.{ Matchers, WordSpec } -import user.org.mockito.matchers.{ ValueCaseClass, ValueClass } +import org.scalatest.{Matchers, WordSpec} +import user.org.mockito.matchers.{ValueCaseClass, ValueClass} case class Bread(name: String) extends AnyVal case class Cheese(name: String) class IdiomaticMockitoTest extends WordSpec with Matchers with IdiomaticMockito with ArgumentMatchersSugar with TableDrivenPropertyChecks { val scenarios = Table( - ("testDouble", "orgDouble"), - ("mock", () => mock[Org]), - ("spy", () => spy(new Org)) + ("testDouble", "orgDouble", "foo"), + ("mock", () => mock[Org], () => mock[Foo]), + ("spy", () => spy(new Org), () => spy(new Foo)) ) - forAll(scenarios) { (testDouble, orgDouble) => + forAll(scenarios) { (testDouble, orgDouble, foo) => testDouble should { "stub a return value" in { val org = orgDouble() @@ -563,6 +563,74 @@ class IdiomaticMockitoTest extends WordSpec with Matchers with IdiomaticMockito org.valueCaseClass(2, ValueCaseClass(100)) shouldBe "mocked!" org.valueCaseClass(2, any[ValueCaseClass]) was called } + + + "default answer should deal with default arguments" in { + val aMock = foo() + + aMock.iHaveSomeDefaultArguments("I'm not gonna pass the second argument") + aMock.iHaveSomeDefaultArguments("I'm gonna pass the second argument", "second argument") + + aMock.iHaveSomeDefaultArguments("I'm not gonna pass the second argument", "default value") was called + aMock.iHaveSomeDefaultArguments("I'm gonna pass the second argument", "second argument") was called + } + + "work with by-name arguments (argument order doesn't matter when not using matchers)" in { + val aMock = foo() + + aMock.iStartWithByNameArgs("arg1", "arg2") shouldReturn "mocked!" + + aMock.iStartWithByNameArgs("arg1", "arg2") shouldBe "mocked!" + aMock.iStartWithByNameArgs("arg111", "arg2") should not be "mocked!" + + aMock.iStartWithByNameArgs("arg1", "arg2") was called + aMock.iStartWithByNameArgs("arg111", "arg2") was called + } + + "work with primitive by-name arguments" in { + val aMock = foo() + + aMock.iHavePrimitiveByNameArgs(1, "arg2") shouldReturn "mocked!" + + aMock.iHavePrimitiveByNameArgs(1, "arg2") shouldBe "mocked!" + aMock.iHavePrimitiveByNameArgs(2, "arg2") should not be "mocked!" + + aMock.iHavePrimitiveByNameArgs(1, "arg2") was called + aMock.iHavePrimitiveByNameArgs(2, "arg2") was called + } + + "work with Function0 arguments" in { + val aMock = foo() + + aMock.iHaveFunction0Args(eqTo("arg1"), function0("arg2")) shouldReturn "mocked!" + + aMock.iHaveFunction0Args("arg1", () => "arg2") shouldBe "mocked!" + aMock.iHaveFunction0Args("arg1", () => "arg3") should not be "mocked!" + + aMock.iHaveFunction0Args(eqTo("arg1"), function0("arg2")) was called + aMock.iHaveFunction0Args(eqTo("arg1"), function0("arg3")) was called + } + + "reset" in { + val aMock = foo() + + aMock.bar shouldReturn "mocked!" + aMock.iHavePrimitiveByNameArgs(1, "arg2") shouldReturn "mocked!" + + aMock.bar shouldBe "mocked!" + aMock.iHavePrimitiveByNameArgs(1, "arg2") shouldBe "mocked!" + + MockitoSugar.reset(aMock) + + aMock.bar should not be "mocked!" + aMock.iHavePrimitiveByNameArgs(1, "arg2") should not be "mocked!" + + //to verify the reset mock handler still handles by-name params + aMock.iHavePrimitiveByNameArgs(1, "arg2") shouldReturn "mocked!" + + aMock.iHavePrimitiveByNameArgs(1, "arg2") shouldBe "mocked!" + } + } } diff --git a/macro/src/main/scala/org/mockito/DoSomethingMacro.scala b/macro/src/main/scala/org/mockito/DoSomethingMacro.scala index eadcf686..d8a16698 100644 --- a/macro/src/main/scala/org/mockito/DoSomethingMacro.scala +++ b/macro/src/main/scala/org/mockito/DoSomethingMacro.scala @@ -14,8 +14,11 @@ object DoSomethingMacro { val r = c.Expr[T] { c.macroApplication match { case q"$_.DoSomethingOps[$r]($v).willBe($_.returned).by[$_]($obj.$method[..$targs](...$args))" => + if (args.exists(a => hasMatchers(c)(a))) { val newArgs = args.map(a => transformArgs(c)(a)) q"_root_.org.mockito.MockitoSugar.doReturn[$r]($v).when($obj).$method[..$targs](...$newArgs)" + } else + q"_root_.org.mockito.MockitoSugar.doReturn[$r]($v).when($obj).$method[..$targs](...$args)" case q"$_.DoSomethingOps[$r]($v).willBe($_.returned).by[$_]($obj.$method[..$targs])" => q"_root_.org.mockito.MockitoSugar.doReturn[$r]($v).when($obj).$method[..$targs]" @@ -33,8 +36,11 @@ object DoSomethingMacro { val r = c.Expr[T] { c.macroApplication match { case q"$_.DoSomethingOps[$r]($v).willBe($_.answered).by[$_]($obj.$method[..$targs](...$args))" => + if (args.exists(a => hasMatchers(c)(a))) { val newArgs = args.map(a => transformArgs(c)(a)) q"_root_.org.mockito.MockitoSugar.doAnswer($v).when($obj).$method[..$targs](...$newArgs)" + } else + q"_root_.org.mockito.MockitoSugar.doAnswer($v).when($obj).$method[..$targs](...$args)" case q"$_.DoSomethingOps[$r]($v).willBe($_.answered).by[$_]($obj.$method[..$targs])" => q"_root_.org.mockito.MockitoSugar.doAnswer($v).when($obj).$method[..$targs]" @@ -52,8 +58,11 @@ object DoSomethingMacro { val r = c.Expr[T] { c.macroApplication match { case q"$_.ThrowSomethingOps[$_]($v).willBe($_.thrown).by[$_]($obj.$method[..$targs](...$args))" => + if (args.exists(a => hasMatchers(c)(a))) { val newArgs = args.map(a => transformArgs(c)(a)) q"_root_.org.mockito.MockitoSugar.doThrow($v).when($obj).$method[..$targs](...$newArgs)" + } else + q"_root_.org.mockito.MockitoSugar.doThrow($v).when($obj).$method[..$targs](...$args)" case q"$_.ThrowSomethingOps[$_]($v).willBe($_.thrown).by[$_]($obj.$method[..$targs])" => q"_root_.org.mockito.MockitoSugar.doThrow($v).when($obj).$method[..$targs]" @@ -71,8 +80,11 @@ object DoSomethingMacro { val r = c.Expr[T] { c.macroApplication match { case q"$_.theRealMethod.willBe($_.called).by[$_]($obj.$method[..$targs](...$args))" => + if (args.exists(a => hasMatchers(c)(a))) { val newArgs = args.map(a => transformArgs(c)(a)) q"_root_.org.mockito.MockitoSugar.doCallRealMethod.when($obj).$method[..$targs](...$newArgs)" + } else + q"_root_.org.mockito.MockitoSugar.doCallRealMethod.when($obj).$method[..$targs](...$args)" case q"$_.theRealMethod.willBe($_.called).by[$_]($obj.$method[..$targs])" => q"_root_.org.mockito.MockitoSugar.doCallRealMethod.when($obj).$method[..$targs]" diff --git a/macro/src/main/scala/org/mockito/Utils.scala b/macro/src/main/scala/org/mockito/Utils.scala index 2f9ac44d..065c64ed 100644 --- a/macro/src/main/scala/org/mockito/Utils.scala +++ b/macro/src/main/scala/org/mockito/Utils.scala @@ -2,6 +2,8 @@ package org.mockito import scala.reflect.macros.blackbox object Utils { + private[mockito] def hasMatchers(c: blackbox.Context)(args: List[c.Tree]): Boolean = + args.exists(arg => isMatcher(c)(arg)) private[mockito] def isMatcher(c: blackbox.Context)(arg: c.Tree): Boolean = { import c.universe._ diff --git a/macro/src/main/scala/org/mockito/VerifyMacro.scala b/macro/src/main/scala/org/mockito/VerifyMacro.scala index ca8a0e0b..92466232 100644 --- a/macro/src/main/scala/org/mockito/VerifyMacro.scala +++ b/macro/src/main/scala/org/mockito/VerifyMacro.scala @@ -19,8 +19,11 @@ object VerifyMacro { val r = c.Expr[Unit] { c.macroApplication match { case q"$_.StubbingOps[$_]($obj.$method[..$targs](...$args)).was($_.called)($order)" => + if (args.exists(a => hasMatchers(c)(a))) { val newArgs = args.map(a => transformArgs(c)(a)) q"$order.verify($obj).$method[..$targs](...$newArgs)" + } else + q"$order.verify($obj).$method[..$targs](...$args)" case q"$_.StubbingOps[$_]($obj.$method[..$targs]).was($_.called)($order)" => q"$order.verify($obj).$method[..$targs]" @@ -41,8 +44,11 @@ object VerifyMacro { q"_root_.org.mockito.MockitoSugar.verifyZeroInteractions($obj)" case q"$_.StubbingOps[$_]($obj.$method[..$targs](...$args)).wasNever($_.called)($order)" => + if (args.exists(a => hasMatchers(c)(a))) { val newArgs = args.map(a => transformArgs(c)(a)) q"$order.verifyWithMode($obj, _root_.org.mockito.Mockito.never).$method[..$targs](...$newArgs)" + } else + q"$order.verifyWithMode($obj, _root_.org.mockito.Mockito.never).$method[..$targs](...$args)" case q"$_.StubbingOps[$_]($obj.$method[..$targs]).wasNever($_.called)($order)" => q"$order.verifyWithMode($obj, _root_.org.mockito.Mockito.never).$method[..$targs]" @@ -65,8 +71,11 @@ object VerifyMacro { val r = c.Expr[Unit] { c.macroApplication match { case q"$_.StubbingOps[$_]($obj.$method[..$targs](...$args)).wasCalled($times)($order)" => + if (args.exists(a => hasMatchers(c)(a))) { val newArgs = args.map(a => transformArgs(c)(a)) q"$order.verifyWithMode($obj, _root_.org.mockito.Mockito.times($times.times)).$method[..$targs](...$newArgs)" + } else + q"$order.verifyWithMode($obj, _root_.org.mockito.Mockito.times($times.times)).$method[..$targs](...$args)" case q"$_.StubbingOps[$_]($obj.$method[..$targs]).wasCalled($times)($order)" => q"$order.verifyWithMode($obj, _root_.org.mockito.Mockito.times($times.times)).$method[..$targs]" @@ -86,8 +95,11 @@ object VerifyMacro { val r = c.Expr[Unit] { c.macroApplication match { case q"$_.StubbingOps[$_]($obj.$method[..$targs](...$args)).wasCalled($times)($order)" => + if (args.exists(a => hasMatchers(c)(a))) { val newArgs = args.map(a => transformArgs(c)(a)) q"$order.verifyWithMode($obj, _root_.org.mockito.Mockito.atLeast($times.times)).$method[..$targs](...$newArgs)" + } else + q"$order.verifyWithMode($obj, _root_.org.mockito.Mockito.atLeast($times.times)).$method[..$targs](...$args)" case q"$_.StubbingOps[$_]($obj.$method[..$targs]).wasCalled($times)($order)" => q"$order.verifyWithMode($obj, _root_.org.mockito.Mockito.atLeast($times.times)).$method[..$targs]" @@ -107,8 +119,11 @@ object VerifyMacro { val r = c.Expr[Unit] { c.macroApplication match { case q"$_.StubbingOps[$_]($obj.$method[..$targs](...$args)).wasCalled($times)($order)" => + if (args.exists(a => hasMatchers(c)(a))) { val newArgs = args.map(a => transformArgs(c)(a)) q"$order.verifyWithMode($obj, _root_.org.mockito.Mockito.atMost($times.times)).$method[..$targs](...$newArgs)" + } else + q"$order.verifyWithMode($obj, _root_.org.mockito.Mockito.atMost($times.times)).$method[..$targs](...$args)" case q"$_.StubbingOps[$_]($obj.$method[..$targs]).wasCalled($times)($order)" => q"$order.verifyWithMode($obj, _root_.org.mockito.Mockito.atMost($times.times)).$method[..$targs]" @@ -128,8 +143,11 @@ object VerifyMacro { val r = c.Expr[Unit] { c.macroApplication match { case q"$_.StubbingOps[$_]($obj.$method[..$targs](...$args)).wasCalled($_)($order)" => + if (args.exists(a => hasMatchers(c)(a))) { val newArgs = args.map(a => transformArgs(c)(a)) q"$order.verifyWithMode($obj, _root_.org.mockito.Mockito.only).$method[..$targs](...$newArgs)" + } else + q"$order.verifyWithMode($obj, _root_.org.mockito.Mockito.only).$method[..$targs](...$args)" case q"$_.StubbingOps[$_]($obj.$method[..$targs]).wasCalled($_)($order)" => q"$order.verifyWithMode($obj, _root_.org.mockito.Mockito.only).$method[..$targs]" diff --git a/macro/src/main/scala/org/mockito/WhenMacro.scala b/macro/src/main/scala/org/mockito/WhenMacro.scala index 11829b07..dffad149 100644 --- a/macro/src/main/scala/org/mockito/WhenMacro.scala +++ b/macro/src/main/scala/org/mockito/WhenMacro.scala @@ -19,8 +19,11 @@ object WhenMacro { val r = c.Expr[ReturnActions[T]] { c.macroApplication match { case q"$_.StubbingOps[$t]($obj.$method[..$targs](...$args)).shouldReturn" => + if (args.exists(a => hasMatchers(c)(a))) { val newArgs = args.map(a => transformArgs(c)(a)) q"new _root_.org.mockito.WhenMacro.ReturnActions(_root_.org.mockito.Mockito.when[$t]($obj.$method[..$targs](...$newArgs)))" + } else + q"new _root_.org.mockito.WhenMacro.ReturnActions(_root_.org.mockito.Mockito.when[$t]($obj.$method[..$targs](...$args)))" case q"$_.StubbingOps[$t]($obj.$method[..$targs]).shouldReturn" => q"new _root_.org.mockito.WhenMacro.ReturnActions(_root_.org.mockito.Mockito.when[$t]($obj.$method[..$targs]))" @@ -38,8 +41,11 @@ object WhenMacro { val r = c.Expr[Unit] { c.macroApplication match { case q"$_.StubbingOps[$t]($obj.$method[..$targs](...$args)).isLenient()" => + if (args.exists(a => hasMatchers(c)(a))) { val newArgs = args.map(a => transformArgs(c)(a)) q"new _root_.org.mockito.stubbing.ScalaFirstStubbing(_root_.org.mockito.Mockito.when[$t]($obj.$method[..$targs](...$newArgs))).isLenient()" + } else + q"new _root_.org.mockito.stubbing.ScalaFirstStubbing(_root_.org.mockito.Mockito.when[$t]($obj.$method[..$targs](...$args))).isLenient()" case q"$_.StubbingOps[$t]($obj.$method[..$targs]).isLenient()" => q"new _root_.org.mockito.stubbing.ScalaFirstStubbing(_root_.org.mockito.Mockito.when[$t]($obj.$method[..$targs])).isLenient()" @@ -59,8 +65,11 @@ object WhenMacro { val r = c.Expr[ScalaOngoingStubbing[T]] { c.macroApplication match { case q"$_.StubbingOps[$t]($obj.$method[..$targs](...$args)).shouldCall($_.realMethod)" => + if (args.exists(a => hasMatchers(c)(a))) { val newArgs = args.map(a => transformArgs(c)(a)) q"new _root_.org.mockito.stubbing.ScalaOngoingStubbing(_root_.org.mockito.Mockito.when[$t]($obj.$method[..$targs](...$newArgs)).thenCallRealMethod())" + } else + q"new _root_.org.mockito.stubbing.ScalaOngoingStubbing(_root_.org.mockito.Mockito.when[$t]($obj.$method[..$targs](...$args)).thenCallRealMethod())" case q"$_.StubbingOps[$t]($obj.$method[..$targs]).shouldCall($_.realMethod)" => q"new _root_.org.mockito.stubbing.ScalaOngoingStubbing(_root_.org.mockito.Mockito.when[$t]($obj.$method[..$targs]).thenCallRealMethod())" @@ -82,8 +91,11 @@ object WhenMacro { val r = c.Expr[ThrowActions[T]] { c.macroApplication match { case q"$_.StubbingOps[$t]($obj.$method[..$targs](...$args)).shouldThrow" => + if (args.exists(a => hasMatchers(c)(a))) { val newArgs = args.map(a => transformArgs(c)(a)) q"new _root_.org.mockito.WhenMacro.ThrowActions(_root_.org.mockito.Mockito.when[$t]($obj.$method[..$targs](...$newArgs)))" + } else + q"new _root_.org.mockito.WhenMacro.ThrowActions(_root_.org.mockito.Mockito.when[$t]($obj.$method[..$targs](...$args)))" case q"$_.StubbingOps[$t]($obj.$method[..$targs]).shouldThrow" => q"new _root_.org.mockito.WhenMacro.ThrowActions(_root_.org.mockito.Mockito.when[$t]($obj.$method[..$targs]))" @@ -127,8 +139,11 @@ object WhenMacro { val r = c.Expr[AnswerActions[T]] { c.macroApplication match { case q"$_.StubbingOps[$t]($obj.$method[..$targs](...$args)).shouldAnswer" => + if (args.exists(a => hasMatchers(c)(a))) { val newArgs = args.map(a => transformArgs(c)(a)) q"new _root_.org.mockito.WhenMacro.AnswerActions(_root_.org.mockito.Mockito.when[$t]($obj.$method[..$targs](...$newArgs)))" + } else + q"new _root_.org.mockito.WhenMacro.AnswerActions(_root_.org.mockito.Mockito.when[$t]($obj.$method[..$targs](...$args)))" case q"$_.StubbingOps[$t]($obj.$method[..$targs]).shouldAnswer" => q"new _root_.org.mockito.WhenMacro.AnswerActions(_root_.org.mockito.Mockito.when[$t]($obj.$method[..$targs]))"