From 80aaf18d5111322baee73dad30eb0a81cdd62314 Mon Sep 17 00:00:00 2001
From: Jason Zaugg <jzaugg@gmail.com>
Date: Mon, 27 Jul 2015 13:15:43 +1000
Subject: [PATCH] Avoid masking user exception with ??? for Nothing typed
 expressions

Code like:

    val x = if (cond) throw new A else throw new B

Was being transformed to:

    val ifRes = ???
    if (cond) ifRes = throw new A else ifRes = throw new B
    val x = ifRes

by way of the use of `gen.mkZero` which throws `???` if the requested type is `Nothing`

This commit special cases `Nothing` typed expressions in a similar manner to `Unit` type expressions.
The example above is now translated to:

    if (cond) throw new A else throw new B
    val x = throw new IllegalStateException()

Fixes #120
---
 .../scala/async/internal/AnfTransform.scala   |  9 +++-
 .../scala/async/internal/TransformUtils.scala |  1 +
 .../async/run/anf/AnfTransformSpec.scala      | 42 +++++++++++++++++++
 3 files changed, 50 insertions(+), 2 deletions(-)

diff --git a/src/main/scala/scala/async/internal/AnfTransform.scala b/src/main/scala/scala/async/internal/AnfTransform.scala
index 585f3882..55e164bf 100644
--- a/src/main/scala/scala/async/internal/AnfTransform.scala
+++ b/src/main/scala/scala/async/internal/AnfTransform.scala
@@ -48,6 +48,8 @@ private[async] trait AnfTransform {
           val stats :+ expr = anf.transformToList(tree)
           def statsExprUnit =
             stats :+ expr :+ api.typecheck(atPos(expr.pos)(Literal(Constant(()))))
+          def statsExprThrow =
+            stats :+ expr :+ api.typecheck(atPos(expr.pos)(Throw(Apply(Select(New(gen.mkAttributedRef(defn.IllegalStateExceptionClass)), nme.CONSTRUCTOR), Nil))))
           expr match {
             case Apply(fun, args) if isAwait(fun) =>
               val valDef = defineVal(name.await, expr, tree.pos)
@@ -68,6 +70,8 @@ private[async] trait AnfTransform {
               // but add Unit value to bring it into form expected by async transform
               if (expr.tpe =:= definitions.UnitTpe) {
                 statsExprUnit
+              } else if (expr.tpe =:= definitions.NothingTpe) {
+                statsExprThrow
               } else {
                 val varDef = defineVar(name.ifRes, expr.tpe, tree.pos)
                 def branchWithAssign(orig: Tree) = api.typecheck(atPos(orig.pos) {
@@ -88,8 +92,9 @@ private[async] trait AnfTransform {
               // but add Unit value to bring it into form expected by async transform
               if (expr.tpe =:= definitions.UnitTpe) {
                 statsExprUnit
-              }
-              else {
+              } else if (expr.tpe =:= definitions.NothingTpe) {
+                statsExprThrow
+              } else {
                 val varDef = defineVar(name.matchRes, expr.tpe, tree.pos)
                 def typedAssign(lhs: Tree) =
                   api.typecheck(atPos(lhs.pos)(Assign(Ident(varDef.symbol), mkAttributedCastPreservingAnnotations(lhs, tpe(varDef.symbol)))))
diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala
index df958b87..547f9807 100644
--- a/src/main/scala/scala/async/internal/TransformUtils.scala
+++ b/src/main/scala/scala/async/internal/TransformUtils.scala
@@ -151,6 +151,7 @@ private[async] trait TransformUtils {
 
     val NonFatalClass = rootMirror.staticModule("scala.util.control.NonFatal")
     val Async_await   = asyncBase.awaitMethod(c.universe)(c.macroApplication.symbol).ensuring(_ != NoSymbol)
+    val IllegalStateExceptionClass = rootMirror.staticClass("java.lang.IllegalStateException")
   }
 
   // `while(await(x))` ... or `do { await(x); ... } while(...)` contain an `If` that loops;
diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
index 2cce7e88..13cc351b 100644
--- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
+++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
@@ -405,4 +405,46 @@ class AnfTransformSpec {
     val applyImplicitView = tree.collect { case x if x.getClass.getName.endsWith("ApplyImplicitView") => x }
     applyImplicitView.map(_.toString) mustStartWith List("view(a$macro$")
   }
+
+  @Test
+  def nothingTypedIf(): Unit = {
+    import scala.async.internal.AsyncId.{async, await}
+    val result = util.Try(async {
+      if (true) {
+        val n = await(1)
+        if (n < 2) {
+          throw new RuntimeException("case a")
+        }
+        else {
+          throw new RuntimeException("case b")
+        }
+      }
+      else {
+        "case c"
+      }
+    })
+
+    assert(result.asInstanceOf[util.Failure[_]].exception.getMessage == "case a")
+  }
+
+  @Test
+  def nothingTypedMatch(): Unit = {
+    import scala.async.internal.AsyncId.{async, await}
+    val result = util.Try(async {
+      0 match {
+        case _ if "".isEmpty =>
+          val n = await(1)
+          n match {
+            case _ if n < 2 =>
+              throw new RuntimeException("case a")
+            case _ =>
+              throw new RuntimeException("case b")
+          }
+        case _ =>
+          "case c"
+      }
+    })
+
+    assert(result.asInstanceOf[util.Failure[_]].exception.getMessage == "case a")
+  }
 }