Skip to content

Commit 80aaf18

Browse files
committed
Avoid masking user exception with ??? for Nothing typed expressions
Code like: val x = if (cond) throw new A else throw new B Was being transformed to: val ifRes = ??? if (cond) ifRes = throw new A else ifRes = throw new B val x = ifRes by way of the use of `gen.mkZero` which throws `???` if the requested type is `Nothing` This commit special cases `Nothing` typed expressions in a similar manner to `Unit` type expressions. The example above is now translated to: if (cond) throw new A else throw new B val x = throw new IllegalStateException() Fixes #120
1 parent 017928c commit 80aaf18

File tree

3 files changed

+50
-2
lines changed

3 files changed

+50
-2
lines changed

Diff for: src/main/scala/scala/async/internal/AnfTransform.scala

+7-2
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ private[async] trait AnfTransform {
4848
val stats :+ expr = anf.transformToList(tree)
4949
def statsExprUnit =
5050
stats :+ expr :+ api.typecheck(atPos(expr.pos)(Literal(Constant(()))))
51+
def statsExprThrow =
52+
stats :+ expr :+ api.typecheck(atPos(expr.pos)(Throw(Apply(Select(New(gen.mkAttributedRef(defn.IllegalStateExceptionClass)), nme.CONSTRUCTOR), Nil))))
5153
expr match {
5254
case Apply(fun, args) if isAwait(fun) =>
5355
val valDef = defineVal(name.await, expr, tree.pos)
@@ -68,6 +70,8 @@ private[async] trait AnfTransform {
6870
// but add Unit value to bring it into form expected by async transform
6971
if (expr.tpe =:= definitions.UnitTpe) {
7072
statsExprUnit
73+
} else if (expr.tpe =:= definitions.NothingTpe) {
74+
statsExprThrow
7175
} else {
7276
val varDef = defineVar(name.ifRes, expr.tpe, tree.pos)
7377
def branchWithAssign(orig: Tree) = api.typecheck(atPos(orig.pos) {
@@ -88,8 +92,9 @@ private[async] trait AnfTransform {
8892
// but add Unit value to bring it into form expected by async transform
8993
if (expr.tpe =:= definitions.UnitTpe) {
9094
statsExprUnit
91-
}
92-
else {
95+
} else if (expr.tpe =:= definitions.NothingTpe) {
96+
statsExprThrow
97+
} else {
9398
val varDef = defineVar(name.matchRes, expr.tpe, tree.pos)
9499
def typedAssign(lhs: Tree) =
95100
api.typecheck(atPos(lhs.pos)(Assign(Ident(varDef.symbol), mkAttributedCastPreservingAnnotations(lhs, tpe(varDef.symbol)))))

Diff for: src/main/scala/scala/async/internal/TransformUtils.scala

+1
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ private[async] trait TransformUtils {
151151

152152
val NonFatalClass = rootMirror.staticModule("scala.util.control.NonFatal")
153153
val Async_await = asyncBase.awaitMethod(c.universe)(c.macroApplication.symbol).ensuring(_ != NoSymbol)
154+
val IllegalStateExceptionClass = rootMirror.staticClass("java.lang.IllegalStateException")
154155
}
155156

156157
// `while(await(x))` ... or `do { await(x); ... } while(...)` contain an `If` that loops;

Diff for: src/test/scala/scala/async/run/anf/AnfTransformSpec.scala

+42
Original file line numberDiff line numberDiff line change
@@ -405,4 +405,46 @@ class AnfTransformSpec {
405405
val applyImplicitView = tree.collect { case x if x.getClass.getName.endsWith("ApplyImplicitView") => x }
406406
applyImplicitView.map(_.toString) mustStartWith List("view(a$macro$")
407407
}
408+
409+
@Test
410+
def nothingTypedIf(): Unit = {
411+
import scala.async.internal.AsyncId.{async, await}
412+
val result = util.Try(async {
413+
if (true) {
414+
val n = await(1)
415+
if (n < 2) {
416+
throw new RuntimeException("case a")
417+
}
418+
else {
419+
throw new RuntimeException("case b")
420+
}
421+
}
422+
else {
423+
"case c"
424+
}
425+
})
426+
427+
assert(result.asInstanceOf[util.Failure[_]].exception.getMessage == "case a")
428+
}
429+
430+
@Test
431+
def nothingTypedMatch(): Unit = {
432+
import scala.async.internal.AsyncId.{async, await}
433+
val result = util.Try(async {
434+
0 match {
435+
case _ if "".isEmpty =>
436+
val n = await(1)
437+
n match {
438+
case _ if n < 2 =>
439+
throw new RuntimeException("case a")
440+
case _ =>
441+
throw new RuntimeException("case b")
442+
}
443+
case _ =>
444+
"case c"
445+
}
446+
})
447+
448+
assert(result.asInstanceOf[util.Failure[_]].exception.getMessage == "case a")
449+
}
408450
}

0 commit comments

Comments
 (0)