Skip to content

Commit 8873eb0

Browse files
authored
Merge pull request #14865 from dotty-staging/better-bytecode-for-tailrec-methods
Fix #14773: Reuse the param slots for the tailrec local mutable vars.
2 parents 8ab71d6 + 730883f commit 8873eb0

File tree

4 files changed

+142
-7
lines changed

4 files changed

+142
-7
lines changed

compiler/src/dotty/tools/backend/jvm/BCodeSkelBuilder.scala

+43-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ package jvm
44

55
import scala.language.unsafeNulls
66

7+
import scala.annotation.tailrec
8+
79
import scala.collection.{ mutable, immutable }
810

911
import scala.tools.asm
@@ -484,6 +486,14 @@ trait BCodeSkelBuilder extends BCodeHelpers {
484486
slots.getOrElse(locSym, makeLocal(locSym))
485487
}
486488

489+
def reuseLocal(sym: Symbol, loc: Local): Unit =
490+
val existing = slots.put(sym, loc)
491+
if (existing.isDefined)
492+
report.error("attempt to create duplicate local var.", ctx.source.atSpan(sym.span))
493+
494+
def reuseThisSlot(sym: Symbol): Unit =
495+
reuseLocal(sym, Local(symInfoTK(sym), sym.javaSimpleName, 0, sym.is(Synthetic)))
496+
487497
private def makeLocal(sym: Symbol, tk: BType): Local = {
488498
assert(nxtIdx != -1, "not a valid start index")
489499
val loc = Local(tk, sym.javaSimpleName, nxtIdx, sym.is(Synthetic))
@@ -753,18 +763,47 @@ trait BCodeSkelBuilder extends BCodeHelpers {
753763
.addFlagIf(isNative, asm.Opcodes.ACC_NATIVE) // native methods of objects are generated in mirror classes
754764

755765
// TODO needed? for(ann <- m.symbol.annotations) { ann.symbol.initialize }
756-
initJMethod(flags, params.map(_.symbol))
766+
val paramSyms = params.map(_.symbol)
767+
initJMethod(flags, paramSyms)
757768

758769

759770
if (!isAbstractMethod && !isNative) {
771+
// #14773 Reuse locals slots for tailrec-generated mutable vars
772+
val trimmedRhs: Tree =
773+
@tailrec def loop(stats: List[Tree]): List[Tree] =
774+
stats match
775+
case (tree @ ValDef(TailLocalName(_, _), _, _)) :: rest if tree.symbol.isAllOf(Mutable | Synthetic) =>
776+
tree.rhs match
777+
case This(_) =>
778+
locals.reuseThisSlot(tree.symbol)
779+
loop(rest)
780+
case rhs: Ident if paramSyms.contains(rhs.symbol) =>
781+
locals.reuseLocal(tree.symbol, locals(rhs.symbol))
782+
loop(rest)
783+
case _ =>
784+
stats
785+
case _ =>
786+
stats
787+
end loop
788+
789+
rhs match
790+
case Block(stats, expr) =>
791+
val trimmedStats = loop(stats)
792+
if trimmedStats eq stats then
793+
rhs
794+
else
795+
Block(trimmedStats, expr)
796+
case _ =>
797+
rhs
798+
end trimmedRhs
760799

761800
def emitNormalMethodBody(): Unit = {
762801
val veryFirstProgramPoint = currProgramPoint()
763-
genLoad(rhs, returnType)
802+
genLoad(trimmedRhs, returnType)
764803

765-
rhs match {
804+
trimmedRhs match {
766805
case (_: Return) | Block(_, (_: Return)) => ()
767-
case (_: Apply) | Block(_, (_: Apply)) if rhs.symbol eq defn.throwMethod => ()
806+
case (_: Apply) | Block(_, (_: Apply)) if trimmedRhs.symbol eq defn.throwMethod => ()
768807
case tpd.EmptyTree =>
769808
report.error("Concrete method has no definition: " + dd + (
770809
if (ctx.settings.Ydebug.value) "(found: " + methSymbol.owner.info.decls.toList.mkString(", ") + ")"

compiler/src/dotty/tools/dotc/transform/TailRec.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ class TailRec extends MiniPhase {
253253
val tpe =
254254
if (enclosingClass.is(Module)) enclosingClass.thisType
255255
else enclosingClass.classInfo.selfType
256-
val sym = newSymbol(method, nme.SELF, Synthetic | Mutable, tpe)
256+
val sym = newSymbol(method, TailLocalName.fresh(nme.SELF), Synthetic | Mutable, tpe)
257257
varForRewrittenThis = Some(sym)
258258
sym
259259
}

compiler/test/dotty/tools/backend/jvm/DottyBytecodeTest.scala

+1-2
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ trait DottyBytecodeTest {
122122
def assertSameCode(method: MethodNode, expected: List[Instruction]): Unit =
123123
assertSameCode(instructionsFromMethod(method).dropNonOp, expected)
124124
def assertSameCode(actual: List[Instruction], expected: List[Instruction]): Unit = {
125-
assert(actual === expected, s"\nExpected: $expected\nActual : $actual")
125+
assert(actual === expected, "\n" + diffInstructions(actual, expected))
126126
}
127127

128128
def assertInvoke(m: MethodNode, receiver: String, method: String): Unit =
@@ -296,4 +296,3 @@ trait DottyBytecodeTest {
296296
object DottyBytecodeTest {
297297
extension [T](l: List[T]) def stringLines = l.mkString("\n")
298298
}
299-

compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala

+97
Original file line numberDiff line numberDiff line change
@@ -946,6 +946,103 @@ class TestBCode extends DottyBytecodeTest {
946946
}
947947
}
948948

949+
@Test def i14773TailRecReuseParamSlots(): Unit = {
950+
val source =
951+
s"""class Foo {
952+
| @scala.annotation.tailrec // explicit @tailrec here
953+
| final def fact(n: Int, acc: Int): Int =
954+
| if n == 0 then acc
955+
| else fact(n - 1, acc * n)
956+
|}
957+
|
958+
|class IntList(head: Int, tail: IntList | Null) {
959+
| // implicit @tailrec
960+
| final def sum(acc: Int): Int =
961+
| val t = tail
962+
| if t == null then acc + head
963+
| else t.sum(acc + head)
964+
|}
965+
""".stripMargin
966+
967+
checkBCode(source) { dir =>
968+
// The mutable local vars for n and acc reuse the slots of the params n and acc
969+
970+
val fooClass = loadClassNode(dir.lookupName("Foo.class", directory = false).input)
971+
val factMeth = getMethod(fooClass, "fact")
972+
973+
assertSameCode(factMeth, List(
974+
Label(0),
975+
VarOp(ILOAD, 1),
976+
Op(ICONST_0),
977+
Jump(IF_ICMPNE, Label(7)),
978+
VarOp(ILOAD, 2),
979+
Jump(GOTO, Label(26)),
980+
Label(7),
981+
VarOp(ALOAD, 0),
982+
VarOp(ASTORE, 3),
983+
VarOp(ILOAD, 1),
984+
Op(ICONST_1),
985+
Op(ISUB),
986+
VarOp(ISTORE, 4),
987+
VarOp(ILOAD, 2),
988+
VarOp(ILOAD, 1),
989+
Op(IMUL),
990+
VarOp(ISTORE, 5),
991+
VarOp(ALOAD, 3),
992+
VarOp(ASTORE, 0),
993+
VarOp(ILOAD, 4),
994+
VarOp(ISTORE, 1),
995+
VarOp(ILOAD, 5),
996+
VarOp(ISTORE, 2),
997+
Jump(GOTO, Label(29)),
998+
Label(26),
999+
Op(IRETURN),
1000+
Label(29),
1001+
Jump(GOTO, Label(0)),
1002+
Op(NOP),
1003+
Op(ATHROW),
1004+
))
1005+
1006+
// The mutable local vars for this and acc reuse the slots of `this` and of the param acc
1007+
1008+
val intListClass = loadClassNode(dir.lookupName("IntList.class", directory = false).input)
1009+
val sumMeth = getMethod(intListClass, "sum")
1010+
1011+
assertSameCode(sumMeth, List(
1012+
Label(0),
1013+
VarOp(ALOAD, 0),
1014+
Field(GETFIELD, "IntList", "tail", "LIntList;"),
1015+
VarOp(ASTORE, 2),
1016+
VarOp(ALOAD, 2),
1017+
Jump(IFNONNULL, Label(12)),
1018+
VarOp(ILOAD, 1),
1019+
VarOp(ALOAD, 0),
1020+
Field(GETFIELD, "IntList", "head", "I"),
1021+
Op(IADD),
1022+
Jump(GOTO, Label(26)),
1023+
Label(12),
1024+
VarOp(ALOAD, 2),
1025+
VarOp(ASTORE, 3),
1026+
VarOp(ILOAD, 1),
1027+
VarOp(ALOAD, 0),
1028+
Field(GETFIELD, "IntList", "head", "I"),
1029+
Op(IADD),
1030+
VarOp(ISTORE, 4),
1031+
VarOp(ALOAD, 3),
1032+
VarOp(ASTORE, 0),
1033+
VarOp(ILOAD, 4),
1034+
VarOp(ISTORE, 1),
1035+
Jump(GOTO, Label(29)),
1036+
Label(26),
1037+
Op(IRETURN),
1038+
Label(29),
1039+
Jump(GOTO, Label(0)),
1040+
Op(NOP),
1041+
Op(ATHROW),
1042+
))
1043+
}
1044+
}
1045+
9491046
@Test
9501047
def getClazz: Unit = {
9511048
val source = """

0 commit comments

Comments
 (0)