Skip to content

Commit 4395afe

Browse files
committed
Merge pull request #151 from retronym/topic/late-expansion-fixes
Late expansion fixes
2 parents 656748c + 549a656 commit 4395afe

File tree

9 files changed

+495
-82
lines changed

9 files changed

+495
-82
lines changed

Diff for: src/main/scala/scala/async/internal/AnfTransform.scala

+105-40
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,27 @@ private[async] trait AnfTransform {
2727
val tree1 = adjustTypeOfTranslatedPatternMatches(block, owner)
2828

2929
var mode: AnfMode = Anf
30+
31+
object trace {
32+
private var indent = -1
33+
34+
private def indentString = " " * indent
35+
36+
def apply[T](args: Any)(t: => T): T = {
37+
def prefix = mode.toString.toLowerCase
38+
indent += 1
39+
def oneLine(s: Any) = s.toString.replaceAll("""\n""", "\\\\n").take(127)
40+
try {
41+
AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})")
42+
val result = t
43+
AsyncUtils.trace(s"${indentString}= ${oneLine(result)}")
44+
result
45+
} finally {
46+
indent -= 1
47+
}
48+
}
49+
}
50+
3051
typingTransform(tree1, owner)((tree, api) => {
3152
def blockToList(tree: Tree): List[Tree] = tree match {
3253
case Block(stats, expr) => stats :+ expr
@@ -97,8 +118,11 @@ private[async] trait AnfTransform {
97118
val ifWithAssign = treeCopy.If(tree, cond, branchWithAssign(thenp), branchWithAssign(elsep)).setType(definitions.UnitTpe)
98119
stats :+ varDef :+ ifWithAssign :+ atPos(tree.pos)(gen.mkAttributedStableRef(varDef.symbol)).setType(tree.tpe)
99120
}
100-
case LabelDef(name, params, rhs) =>
101-
statsExprUnit
121+
case ld @ LabelDef(name, params, rhs) =>
122+
if (ld.symbol.info.resultType.typeSymbol == definitions.UnitClass)
123+
statsExprUnit
124+
else
125+
stats :+ expr
102126

103127
case Match(scrut, cases) =>
104128
// if type of match is Unit don't introduce assignment,
@@ -134,26 +158,6 @@ private[async] trait AnfTransform {
134158
}
135159
}
136160

137-
object trace {
138-
private var indent = -1
139-
140-
private def indentString = " " * indent
141-
142-
def apply[T](args: Any)(t: => T): T = {
143-
def prefix = mode.toString.toLowerCase
144-
indent += 1
145-
def oneLine(s: Any) = s.toString.replaceAll("""\n""", "\\\\n").take(127)
146-
try {
147-
AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})")
148-
val result = t
149-
AsyncUtils.trace(s"${indentString}= ${oneLine(result)}")
150-
result
151-
} finally {
152-
indent -= 1
153-
}
154-
}
155-
}
156-
157161
def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = {
158162
val sym = api.currentOwner.newTermSymbol(name.fresh(prefix), pos, SYNTHETIC).setInfo(uncheckedBounds(lhs.tpe))
159163
internal.valDef(sym, internal.changeOwner(lhs, api.currentOwner, sym)).setType(NoType).setPos(pos)
@@ -219,8 +223,29 @@ private[async] trait AnfTransform {
219223
funStats ++ argStatss.flatten.flatten :+ typedNewApply
220224

221225
case Block(stats, expr) =>
222-
val trees = stats.flatMap(linearize.transformToList).filterNot(isLiteralUnit) ::: linearize.transformToList(expr)
223-
eliminateMatchEndLabelParameter(trees)
226+
val stats1 = stats.flatMap(linearize.transformToList).filterNot(isLiteralUnit)
227+
val exprs1 = linearize.transformToList(expr)
228+
val trees = stats1 ::: exprs1
229+
def isMatchEndLabel(t: Tree): Boolean = t match {
230+
case ValDef(_, _, _, t) if isMatchEndLabel(t) => true
231+
case ld: LabelDef if ld.name.toString.startsWith("matchEnd") => true
232+
case _ => false
233+
}
234+
def groupsEndingWith[T](ts: List[T])(f: T => Boolean): List[List[T]] = if (ts.isEmpty) Nil else {
235+
ts.indexWhere(f) match {
236+
case -1 => List(ts)
237+
case i =>
238+
val (ts1, ts2) = ts.splitAt(i + 1)
239+
ts1 :: groupsEndingWith(ts2)(f)
240+
}
241+
}
242+
val matchGroups = groupsEndingWith(trees)(isMatchEndLabel)
243+
val trees1 = matchGroups.flatMap(eliminateMatchEndLabelParameter)
244+
val result = trees1 flatMap {
245+
case Block(stats, expr) => stats :+ expr
246+
case t => t :: Nil
247+
}
248+
result
224249

225250
case ValDef(mods, name, tpt, rhs) =>
226251
if (containsAwait(rhs)) {
@@ -260,7 +285,10 @@ private[async] trait AnfTransform {
260285
scrutStats :+ treeCopy.Match(tree, scrutExpr, caseDefs)
261286

262287
case LabelDef(name, params, rhs) =>
263-
List(LabelDef(name, params, newBlock(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol))
288+
if (tree.symbol.info.typeSymbol == definitions.UnitClass)
289+
List(treeCopy.LabelDef(tree, name, params, api.typecheck(newBlock(linearize.transformToList(rhs), Literal(Constant(()))))).setSymbol(tree.symbol))
290+
else
291+
List(treeCopy.LabelDef(tree, name, params, api.typecheck(listToBlock(linearize.transformToList(rhs)))).setSymbol(tree.symbol))
264292

265293
case TypeApply(fun, targs) =>
266294
val funStats :+ simpleFun = linearize.transformToList(fun)
@@ -274,7 +302,7 @@ private[async] trait AnfTransform {
274302

275303
// Replace the label parameters on `matchEnd` with use of a `matchRes` temporary variable
276304
//
277-
// CaseDefs are translated to labels without parmeters. A terminal label, `matchEnd`, accepts
305+
// CaseDefs are translated to labels without parameters. A terminal label, `matchEnd`, accepts
278306
// a parameter which is the result of the match (this is regular, so even Unit-typed matches have this).
279307
//
280308
// For our purposes, it is easier to:
@@ -286,34 +314,71 @@ private[async] trait AnfTransform {
286314
val caseDefToMatchResult = collection.mutable.Map[Symbol, Symbol]()
287315

288316
val matchResults = collection.mutable.Buffer[Tree]()
289-
val statsExpr0 = statsExpr.reverseMap {
290-
case ld @ LabelDef(_, param :: Nil, body) =>
317+
def modifyLabelDef(ld: LabelDef): (Tree, Tree) = {
318+
val symTab = c.universe.asInstanceOf[reflect.internal.SymbolTable]
319+
val param = ld.params.head
320+
val ld2 = if (ld.params.head.tpe.typeSymbol == definitions.UnitClass) {
321+
// Unit typed match: eliminate the label def parameter, but don't create a matchres temp variable to
322+
// store the result for cleaner generated code.
323+
caseDefToMatchResult(ld.symbol) = NoSymbol
324+
val rhs2 = substituteTrees(ld.rhs, param.symbol :: Nil, api.typecheck(literalUnit) :: Nil)
325+
(treeCopy.LabelDef(ld, ld.name, Nil, api.typecheck(literalUnit)), rhs2)
326+
} else {
327+
// Otherwise, create the matchres var. We'll callers of the label def below.
328+
// Remember: we're iterating through the statement sequence in reverse, so we'll get
329+
// to the LabelDef and mutate `matchResults` before we'll get to its callers.
291330
val matchResult = linearize.defineVar(name.matchRes, param.tpe, ld.pos)
292331
matchResults += matchResult
293332
caseDefToMatchResult(ld.symbol) = matchResult.symbol
294-
val ld2 = treeCopy.LabelDef(ld, ld.name, Nil, body.substituteSymbols(param.symbol :: Nil, matchResult.symbol :: Nil))
295-
setInfo(ld.symbol, methodType(Nil, ld.symbol.info.resultType))
296-
ld2
333+
val rhs2 = ld.rhs.substituteSymbols(param.symbol :: Nil, matchResult.symbol :: Nil)
334+
(treeCopy.LabelDef(ld, ld.name, Nil, api.typecheck(literalUnit)), rhs2)
335+
}
336+
setInfo(ld.symbol, methodType(Nil, definitions.UnitTpe))
337+
ld2
338+
}
339+
val statsExpr0 = statsExpr.reverse.flatMap {
340+
case ld @ LabelDef(_, param :: Nil, _) =>
341+
val (ld1, after) = modifyLabelDef(ld)
342+
List(after, ld1)
343+
case a @ ValDef(mods, name, tpt, ld @ LabelDef(_, param :: Nil, _)) =>
344+
val (ld1, after) = modifyLabelDef(ld)
345+
List(treeCopy.ValDef(a, mods, name, tpt, after), ld1)
297346
case t =>
298-
if (caseDefToMatchResult.isEmpty) t
299-
else typingTransform(t)((tree, api) =>
347+
if (caseDefToMatchResult.isEmpty) t :: Nil
348+
else typingTransform(t)((tree, api) => {
349+
def typedPos(pos: Position)(t: Tree): Tree =
350+
api.typecheck(atPos(pos)(t))
300351
tree match {
301352
case Apply(fun, arg :: Nil) if isLabel(fun.symbol) && caseDefToMatchResult.contains(fun.symbol) =>
302-
api.typecheck(atPos(tree.pos)(newBlock(Assign(Ident(caseDefToMatchResult(fun.symbol)), api.recur(arg)) :: Nil, treeCopy.Apply(tree, fun, Nil))))
303-
case Block(stats, expr) =>
353+
val temp = caseDefToMatchResult(fun.symbol)
354+
if (temp == NoSymbol)
355+
typedPos(tree.pos)(newBlock(api.recur(arg) :: Nil, treeCopy.Apply(tree, fun, Nil)))
356+
else
357+
// setType needed for LateExpansion.shadowingRefinedType test case. There seems to be an inconsistency
358+
// in the trees after pattern matcher.
359+
// TODO miminize the problem in patmat and fix in scalac.
360+
typedPos(tree.pos)(newBlock(Assign(Ident(temp), api.recur(internal.setType(arg, fun.tpe.paramLists.head.head.info))) :: Nil, treeCopy.Apply(tree, fun, Nil)))
361+
case Block(stats, expr: Apply) if isLabel(expr.symbol) =>
304362
api.default(tree) match {
305-
case Block(stats, Block(stats1, expr)) =>
306-
treeCopy.Block(tree, stats ::: stats1, expr)
363+
case Block(stats0, Block(stats1, expr1)) =>
364+
// flatten the block returned by `case Apply` above into the enclosing block for
365+
// cleaner generated code.
366+
treeCopy.Block(tree, stats0 ::: stats1, expr1)
307367
case t => t
308368
}
309369
case _ =>
310370
api.default(tree)
311371
}
312-
)
372+
}) :: Nil
313373
}
314374
matchResults.toList match {
315-
case Nil => statsExpr
316-
case r1 :: Nil => (r1 +: statsExpr0.reverse) :+ atPos(tree.pos)(gen.mkAttributedIdent(r1.symbol))
375+
case _ if caseDefToMatchResult.isEmpty =>
376+
statsExpr // return the original trees if nothing changed
377+
case Nil =>
378+
statsExpr0.reverse :+ literalUnit // must have been a unit-typed match, no matchRes variable to definne or refer to
379+
case r1 :: Nil =>
380+
// { var matchRes = _; ....; matchRes }
381+
(r1 +: statsExpr0.reverse) :+ atPos(tree.pos)(gen.mkAttributedIdent(r1.symbol))
317382
case _ => c.error(macroPos, "Internal error: unexpected tree encountered during ANF transform " + statsExpr); statsExpr
318383
}
319384
}

Diff for: src/main/scala/scala/async/internal/AsyncBase.scala

+4-2
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,14 @@ abstract class AsyncBase {
5555

5656
protected[async] def asyncMethod(u: Universe)(asyncMacroSymbol: u.Symbol): u.Symbol = {
5757
import u._
58-
asyncMacroSymbol.owner.typeSignature.member(newTermName("async"))
58+
if (asyncMacroSymbol == null) NoSymbol
59+
else asyncMacroSymbol.owner.typeSignature.member(newTermName("async"))
5960
}
6061

6162
protected[async] def awaitMethod(u: Universe)(asyncMacroSymbol: u.Symbol): u.Symbol = {
6263
import u._
63-
asyncMacroSymbol.owner.typeSignature.member(newTermName("await"))
64+
if (asyncMacroSymbol == null) NoSymbol
65+
else asyncMacroSymbol.owner.typeSignature.member(newTermName("await"))
6466
}
6567

6668
protected[async] def nullOut(u: Universe)(name: u.Expr[String], v: u.Expr[Any]): u.Expr[Unit] =

Diff for: src/main/scala/scala/async/internal/AsyncTransform.scala

+10-2
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,16 @@ trait AsyncTransform {
4949
List(emptyConstructor, stateVar) ++ resultAndAccessors ++ List(execContextValDef) ++ List(applyDefDefDummyBody, apply0DefDef)
5050
}
5151

52-
val tryToUnit = appliedType(definitions.FunctionClass(1), futureSystemOps.tryType[Any], typeOf[Unit])
53-
val template = Template(List(tryToUnit, typeOf[() => Unit]).map(TypeTree(_)), emptyValDef, body)
52+
val customParents = futureSystemOps.stateMachineClassParents
53+
val tycon = if (customParents.exists(!_.typeSymbol.asClass.isTrait)) {
54+
// prefer extending a class to reduce the class file size of the state machine.
55+
symbolOf[scala.runtime.AbstractFunction1[Any, Any]]
56+
} else {
57+
// ... unless a custom future system already extends some class
58+
symbolOf[scala.Function1[Any, Any]]
59+
}
60+
val tryToUnit = appliedType(tycon, futureSystemOps.tryType[Any], typeOf[Unit])
61+
val template = Template((futureSystemOps.stateMachineClassParents ::: List(tryToUnit, typeOf[() => Unit])).map(TypeTree(_)), emptyValDef, body)
5462

5563
val t = ClassDef(NoMods, name.stateMachineT, Nil, template)
5664
typecheckClassDef(t)

Diff for: src/main/scala/scala/async/internal/ExprBuilder.scala

+38-18
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
*/
44
package scala.async.internal
55

6-
import scala.reflect.macros.Context
76
import scala.collection.mutable.ListBuffer
87
import collection.mutable
98
import language.existentials
@@ -34,18 +33,17 @@ trait ExprBuilder {
3433

3534
var stats: List[Tree]
3635

37-
def statsAnd(trees: List[Tree]): List[Tree] = {
38-
val body = stats match {
36+
def treesThenStats(trees: List[Tree]): List[Tree] = {
37+
(stats match {
3938
case init :+ last if tpeOf(last) =:= definitions.NothingTpe =>
40-
adaptToUnit(init :+ Typed(last, TypeTree(definitions.AnyTpe)))
39+
adaptToUnit((trees ::: init) :+ Typed(last, TypeTree(definitions.AnyTpe)))
4140
case _ =>
42-
adaptToUnit(stats)
43-
}
44-
Try(body, Nil, adaptToUnit(trees)) :: Nil
41+
adaptToUnit(trees ::: stats)
42+
}) :: Nil
4543
}
4644

4745
final def allStats: List[Tree] = this match {
48-
case a: AsyncStateWithAwait => statsAnd(a.awaitable.resultValDef :: Nil)
46+
case a: AsyncStateWithAwait => treesThenStats(a.awaitable.resultValDef :: Nil)
4947
case _ => stats
5048
}
5149

@@ -63,7 +61,7 @@ trait ExprBuilder {
6361
List(nextState)
6462

6563
def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef = {
66-
mkHandlerCase(state, statsAnd(mkStateTree(nextState, symLookup) :: Nil))
64+
mkHandlerCase(state, treesThenStats(mkStateTree(nextState, symLookup) :: Nil))
6765
}
6866

6967
override val toString: String =
@@ -99,10 +97,10 @@ trait ExprBuilder {
9997
if (futureSystemOps.continueCompletedFutureOnSameThread)
10098
If(futureSystemOps.isCompleted(c.Expr[futureSystem.Fut[_]](awaitable.expr)).tree,
10199
adaptToUnit(ifIsFailureTree[T](futureSystemOps.getCompleted[Any](c.Expr[futureSystem.Fut[Any]](awaitable.expr)).tree) :: Nil),
102-
Block(toList(callOnComplete), Return(literalUnit)))
100+
Block(toList(callOnComplete), Return(literalUnit))) :: Nil
103101
else
104-
Block(toList(callOnComplete), Return(literalUnit))
105-
mkHandlerCase(state, stats ++ List(mkStateTree(onCompleteState, symLookup), tryGetOrCallOnComplete))
102+
toList(callOnComplete) ::: Return(literalUnit) :: Nil
103+
mkHandlerCase(state, stats ++ List(mkStateTree(onCompleteState, symLookup)) ++ tryGetOrCallOnComplete)
106104
}
107105

108106
private def tryGetTree(tryReference: => Tree) =
@@ -251,12 +249,17 @@ trait ExprBuilder {
251249
case LabelDef(name, _, _) => name.toString.startsWith("case")
252250
case _ => false
253251
}
254-
val (before, _ :: after) = (stats :+ expr).span(_ ne t)
255-
before.reverse.takeWhile(isPatternCaseLabelDef) ::: after.takeWhile(isPatternCaseLabelDef)
252+
val span = (stats :+ expr).filterNot(isLiteralUnit).span(_ ne t)
253+
span match {
254+
case (before, _ :: after) =>
255+
before.reverse.takeWhile(isPatternCaseLabelDef) ::: after.takeWhile(isPatternCaseLabelDef)
256+
case _ =>
257+
stats :+ expr
258+
}
256259
}
257260

258261
// populate asyncStates
259-
for (stat <- (stats :+ expr)) stat match {
262+
def add(stat: Tree): Unit = stat match {
260263
// the val name = await(..) pattern
261264
case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) =>
262265
val onCompleteState = nextState()
@@ -315,10 +318,13 @@ trait ExprBuilder {
315318
asyncStates ++= builder.asyncStates
316319
currState = afterLabelState
317320
stateBuilder = new AsyncStateBuilder(currState, symLookup)
321+
case b @ Block(stats, expr) =>
322+
(stats :+ expr) foreach (add)
318323
case _ =>
319324
checkForUnsupportedAwait(stat)
320325
stateBuilder += stat
321326
}
327+
for (stat <- (stats :+ expr)) add(stat)
322328
val lastState = stateBuilder.resultSimple(endState)
323329
asyncStates += lastState
324330
}
@@ -357,8 +363,8 @@ trait ExprBuilder {
357363
val caseForLastState: CaseDef = {
358364
val lastState = asyncStates.last
359365
val lastStateBody = c.Expr[T](lastState.body)
360-
val rhs = futureSystemOps.completeProm(
361-
c.Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), futureSystemOps.tryySuccess[T](lastStateBody))
366+
val rhs = futureSystemOps.completeWithSuccess(
367+
c.Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), lastStateBody)
362368
mkHandlerCase(lastState.state, Block(rhs.tree, Return(literalUnit)))
363369
}
364370
asyncStates.toList match {
@@ -392,7 +398,10 @@ trait ExprBuilder {
392398
* }
393399
*/
394400
private def resumeFunTree[T: WeakTypeTag]: Tree = {
395-
val body = Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T]))
401+
val stateMemberSymbol = symLookup.stateMachineMember(name.state)
402+
val stateMemberRef = symLookup.memberRef(name.state)
403+
val body = Match(stateMemberRef, mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T]) ++ List(CaseDef(Ident(nme.WILDCARD), EmptyTree, Throw(Apply(Select(New(Ident(defn.IllegalStateExceptionClass)), termNames.CONSTRUCTOR), List())))))
404+
396405
Try(
397406
body,
398407
List(
@@ -462,13 +471,24 @@ trait ExprBuilder {
462471
private def tpeOf(t: Tree): Type = t match {
463472
case _ if t.tpe != null => t.tpe
464473
case Try(body, Nil, _) => tpeOf(body)
474+
case Block(_, expr) => tpeOf(expr)
475+
case Literal(Constant(value)) if value == () => definitions.UnitTpe
476+
case Return(_) => definitions.NothingTpe
465477
case _ => NoType
466478
}
467479

468480
private def adaptToUnit(rhs: List[Tree]): c.universe.Block = {
469481
rhs match {
482+
case (rhs: Block) :: Nil if tpeOf(rhs) <:< definitions.UnitTpe =>
483+
rhs
470484
case init :+ last if tpeOf(last) <:< definitions.UnitTpe =>
471485
Block(init, last)
486+
case init :+ (last @ Literal(Constant(()))) =>
487+
Block(init, last)
488+
case init :+ (last @ Block(_, Return(_) | Literal(Constant(())))) =>
489+
Block(init, last)
490+
case init :+ (Block(stats, expr)) =>
491+
Block(init, Block(stats :+ expr, literalUnit))
472492
case _ =>
473493
Block(rhs, literalUnit)
474494
}

Diff for: src/main/scala/scala/async/internal/FutureSystem.scala

+2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ trait FutureSystem {
3333
def promType[A: WeakTypeTag]: Type
3434
def tryType[A: WeakTypeTag]: Type
3535
def execContextType: Type
36+
def stateMachineClassParents: List[Type] = Nil
3637

3738
/** Create an empty promise */
3839
def createProm[A: WeakTypeTag]: Expr[Prom[A]]
@@ -55,6 +56,7 @@ trait FutureSystem {
5556

5657
/** Complete a promise with a value */
5758
def completeProm[A](prom: Expr[Prom[A]], value: Expr[Tryy[A]]): Expr[Unit]
59+
def completeWithSuccess[A: WeakTypeTag](prom: Expr[Prom[A]], value: Expr[A]): Expr[Unit] = completeProm(prom, tryySuccess(value))
5860

5961
def spawn(tree: Tree, execContext: Tree): Tree =
6062
future(c.Expr[Unit](tree))(c.Expr[ExecContext](execContext)).tree

0 commit comments

Comments
 (0)