Skip to content

Fix #14773: Reuse the param slots for the tailrec local mutable vars. #14865

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 43 additions & 4 deletions compiler/src/dotty/tools/backend/jvm/BCodeSkelBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ package jvm

import scala.language.unsafeNulls

import scala.annotation.tailrec

import scala.collection.{ mutable, immutable }

import scala.tools.asm
Expand Down Expand Up @@ -484,6 +486,14 @@ trait BCodeSkelBuilder extends BCodeHelpers {
slots.getOrElse(locSym, makeLocal(locSym))
}

def reuseLocal(sym: Symbol, loc: Local): Unit =
val existing = slots.put(sym, loc)
if (existing.isDefined)
report.error("attempt to create duplicate local var.", ctx.source.atSpan(sym.span))

def reuseThisSlot(sym: Symbol): Unit =
reuseLocal(sym, Local(symInfoTK(sym), sym.javaSimpleName, 0, sym.is(Synthetic)))

private def makeLocal(sym: Symbol, tk: BType): Local = {
assert(nxtIdx != -1, "not a valid start index")
val loc = Local(tk, sym.javaSimpleName, nxtIdx, sym.is(Synthetic))
Expand Down Expand Up @@ -753,18 +763,47 @@ trait BCodeSkelBuilder extends BCodeHelpers {
.addFlagIf(isNative, asm.Opcodes.ACC_NATIVE) // native methods of objects are generated in mirror classes

// TODO needed? for(ann <- m.symbol.annotations) { ann.symbol.initialize }
initJMethod(flags, params.map(_.symbol))
val paramSyms = params.map(_.symbol)
initJMethod(flags, paramSyms)


if (!isAbstractMethod && !isNative) {
// #14773 Reuse locals slots for tailrec-generated mutable vars
val trimmedRhs: Tree =
@tailrec def loop(stats: List[Tree]): List[Tree] =
stats match
case (tree @ ValDef(TailLocalName(_, _), _, _)) :: rest if tree.symbol.isAllOf(Mutable | Synthetic) =>
tree.rhs match
case This(_) =>
locals.reuseThisSlot(tree.symbol)
loop(rest)
case rhs: Ident if paramSyms.contains(rhs.symbol) =>
locals.reuseLocal(tree.symbol, locals(rhs.symbol))
loop(rest)
case _ =>
stats
case _ =>
stats
end loop

rhs match
case Block(stats, expr) =>
val trimmedStats = loop(stats)
if trimmedStats eq stats then
rhs
else
Block(trimmedStats, expr)
case _ =>
rhs
end trimmedRhs

def emitNormalMethodBody(): Unit = {
val veryFirstProgramPoint = currProgramPoint()
genLoad(rhs, returnType)
genLoad(trimmedRhs, returnType)

rhs match {
trimmedRhs match {
case (_: Return) | Block(_, (_: Return)) => ()
case (_: Apply) | Block(_, (_: Apply)) if rhs.symbol eq defn.throwMethod => ()
case (_: Apply) | Block(_, (_: Apply)) if trimmedRhs.symbol eq defn.throwMethod => ()
case tpd.EmptyTree =>
report.error("Concrete method has no definition: " + dd + (
if (ctx.settings.Ydebug.value) "(found: " + methSymbol.owner.info.decls.toList.mkString(", ") + ")"
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/transform/TailRec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ class TailRec extends MiniPhase {
val tpe =
if (enclosingClass.is(Module)) enclosingClass.thisType
else enclosingClass.classInfo.selfType
val sym = newSymbol(method, nme.SELF, Synthetic | Mutable, tpe)
val sym = newSymbol(method, TailLocalName.fresh(nme.SELF), Synthetic | Mutable, tpe)
varForRewrittenThis = Some(sym)
sym
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ trait DottyBytecodeTest {
def assertSameCode(method: MethodNode, expected: List[Instruction]): Unit =
assertSameCode(instructionsFromMethod(method).dropNonOp, expected)
def assertSameCode(actual: List[Instruction], expected: List[Instruction]): Unit = {
assert(actual === expected, s"\nExpected: $expected\nActual : $actual")
assert(actual === expected, "\n" + diffInstructions(actual, expected))
}

def assertInvoke(m: MethodNode, receiver: String, method: String): Unit =
Expand Down Expand Up @@ -296,4 +296,3 @@ trait DottyBytecodeTest {
object DottyBytecodeTest {
extension [T](l: List[T]) def stringLines = l.mkString("\n")
}

97 changes: 97 additions & 0 deletions compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -946,6 +946,103 @@ class TestBCode extends DottyBytecodeTest {
}
}

@Test def i14773TailRecReuseParamSlots(): Unit = {
val source =
s"""class Foo {
| @scala.annotation.tailrec // explicit @tailrec here
| final def fact(n: Int, acc: Int): Int =
| if n == 0 then acc
| else fact(n - 1, acc * n)
|}
|
|class IntList(head: Int, tail: IntList | Null) {
| // implicit @tailrec
| final def sum(acc: Int): Int =
| val t = tail
| if t == null then acc + head
| else t.sum(acc + head)
|}
""".stripMargin

checkBCode(source) { dir =>
// The mutable local vars for n and acc reuse the slots of the params n and acc

val fooClass = loadClassNode(dir.lookupName("Foo.class", directory = false).input)
val factMeth = getMethod(fooClass, "fact")

assertSameCode(factMeth, List(
Label(0),
VarOp(ILOAD, 1),
Op(ICONST_0),
Jump(IF_ICMPNE, Label(7)),
VarOp(ILOAD, 2),
Jump(GOTO, Label(26)),
Label(7),
VarOp(ALOAD, 0),
VarOp(ASTORE, 3),
VarOp(ILOAD, 1),
Op(ICONST_1),
Op(ISUB),
VarOp(ISTORE, 4),
VarOp(ILOAD, 2),
VarOp(ILOAD, 1),
Op(IMUL),
VarOp(ISTORE, 5),
VarOp(ALOAD, 3),
VarOp(ASTORE, 0),
VarOp(ILOAD, 4),
VarOp(ISTORE, 1),
VarOp(ILOAD, 5),
VarOp(ISTORE, 2),
Jump(GOTO, Label(29)),
Label(26),
Op(IRETURN),
Label(29),
Jump(GOTO, Label(0)),
Op(NOP),
Op(ATHROW),
))

// The mutable local vars for this and acc reuse the slots of `this` and of the param acc

val intListClass = loadClassNode(dir.lookupName("IntList.class", directory = false).input)
val sumMeth = getMethod(intListClass, "sum")

assertSameCode(sumMeth, List(
Label(0),
VarOp(ALOAD, 0),
Field(GETFIELD, "IntList", "tail", "LIntList;"),
VarOp(ASTORE, 2),
VarOp(ALOAD, 2),
Jump(IFNONNULL, Label(12)),
VarOp(ILOAD, 1),
VarOp(ALOAD, 0),
Field(GETFIELD, "IntList", "head", "I"),
Op(IADD),
Jump(GOTO, Label(26)),
Label(12),
VarOp(ALOAD, 2),
VarOp(ASTORE, 3),
VarOp(ILOAD, 1),
VarOp(ALOAD, 0),
Field(GETFIELD, "IntList", "head", "I"),
Op(IADD),
VarOp(ISTORE, 4),
VarOp(ALOAD, 3),
VarOp(ASTORE, 0),
VarOp(ILOAD, 4),
VarOp(ISTORE, 1),
Jump(GOTO, Label(29)),
Label(26),
Op(IRETURN),
Label(29),
Jump(GOTO, Label(0)),
Op(NOP),
Op(ATHROW),
))
}
}

@Test
def getClazz: Unit = {
val source = """
Expand Down