Skip to content

Commit ff6bebe

Browse files
Backport "List(...) optimization to avoid intermediate array" to LTS (#20779)
Backports #17166 to the LTS branch. PR submitted by the release tooling. [skip ci]
2 parents 12af044 + 063cddd commit ff6bebe

File tree

4 files changed

+309
-28
lines changed

4 files changed

+309
-28
lines changed

Diff for: compiler/src/dotty/tools/dotc/core/Definitions.scala

+15-11
Original file line numberDiff line numberDiff line change
@@ -513,14 +513,16 @@ class Definitions {
513513
methodNames.map(getWrapVarargsArrayModule.requiredMethod(_))
514514
})
515515

516-
@tu lazy val ListClass: Symbol = requiredClass("scala.collection.immutable.List")
517-
def ListType: TypeRef = ListClass.typeRef
518-
@tu lazy val ListModule: Symbol = requiredModule("scala.collection.immutable.List")
519-
@tu lazy val NilModule: Symbol = requiredModule("scala.collection.immutable.Nil")
520-
def NilType: TermRef = NilModule.termRef
521-
@tu lazy val ConsClass: Symbol = requiredClass("scala.collection.immutable.::")
522-
def ConsType: TypeRef = ConsClass.typeRef
523-
@tu lazy val SeqFactoryClass: Symbol = requiredClass("scala.collection.SeqFactory")
516+
@tu lazy val ListClass: Symbol = requiredClass("scala.collection.immutable.List")
517+
def ListType: TypeRef = ListClass.typeRef
518+
@tu lazy val ListModule: Symbol = requiredModule("scala.collection.immutable.List")
519+
@tu lazy val ListModule_apply: Symbol = ListModule.requiredMethod(nme.apply)
520+
def ListModuleAlias: Symbol = ScalaPackageClass.requiredMethod(nme.List)
521+
@tu lazy val NilModule: Symbol = requiredModule("scala.collection.immutable.Nil")
522+
def NilType: TermRef = NilModule.termRef
523+
@tu lazy val ConsClass: Symbol = requiredClass("scala.collection.immutable.::")
524+
def ConsType: TypeRef = ConsClass.typeRef
525+
@tu lazy val SeqFactoryClass: Symbol = requiredClass("scala.collection.SeqFactory")
524526

525527
@tu lazy val SingletonClass: ClassSymbol =
526528
// needed as a synthetic class because Scala 2.x refers to it in classfiles
@@ -530,16 +532,18 @@ class Definitions {
530532
List(AnyType), EmptyScope)
531533
@tu lazy val SingletonType: TypeRef = SingletonClass.typeRef
532534

533-
@tu lazy val CollectionSeqType: TypeRef = requiredClassRef("scala.collection.Seq")
534-
@tu lazy val SeqType: TypeRef = requiredClassRef("scala.collection.immutable.Seq")
535+
@tu lazy val CollectionSeqType: TypeRef = requiredClassRef("scala.collection.Seq")
536+
@tu lazy val SeqType: TypeRef = requiredClassRef("scala.collection.immutable.Seq")
537+
@tu lazy val SeqModule: Symbol = requiredModule("scala.collection.immutable.Seq")
538+
@tu lazy val SeqModule_apply: Symbol = SeqModule.requiredMethod(nme.apply)
539+
def SeqModuleAlias: Symbol = ScalaPackageClass.requiredMethod(nme.Seq)
535540
def SeqClass(using Context): ClassSymbol = SeqType.symbol.asClass
536541
@tu lazy val Seq_apply : Symbol = SeqClass.requiredMethod(nme.apply)
537542
@tu lazy val Seq_head : Symbol = SeqClass.requiredMethod(nme.head)
538543
@tu lazy val Seq_drop : Symbol = SeqClass.requiredMethod(nme.drop)
539544
@tu lazy val Seq_lengthCompare: Symbol = SeqClass.requiredMethod(nme.lengthCompare, List(IntType))
540545
@tu lazy val Seq_length : Symbol = SeqClass.requiredMethod(nme.length)
541546
@tu lazy val Seq_toSeq : Symbol = SeqClass.requiredMethod(nme.toSeq)
542-
@tu lazy val SeqModule: Symbol = requiredModule("scala.collection.immutable.Seq")
543547

544548

545549
@tu lazy val StringOps: Symbol = requiredClass("scala.collection.StringOps")

Diff for: compiler/src/dotty/tools/dotc/transform/ArrayApply.scala

+55-16
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
1-
package dotty.tools.dotc
1+
package dotty.tools
2+
package dotc
23
package transform
34

4-
import core.*
5+
import ast.tpd
6+
import core.*, Contexts.*, Decorators.*, Symbols.*, Flags.*, StdNames.*
7+
import reporting.trace
8+
import util.Property
59
import MegaPhase.*
6-
import Contexts.*
7-
import Symbols.*
8-
import Flags.*
9-
import StdNames.*
10-
import dotty.tools.dotc.ast.tpd
11-
12-
1310

1411
/** This phase rewrites calls to `Array.apply` to a direct instantiation of the array in the bytecode.
1512
*
@@ -22,27 +19,69 @@ class ArrayApply extends MiniPhase {
2219

2320
override def description: String = ArrayApply.description
2421

25-
override def transformApply(tree: tpd.Apply)(using Context): tpd.Tree =
22+
private val TransformListApplyBudgetKey = new Property.Key[Int]
23+
private def transformListApplyBudget(using Context) =
24+
ctx.property(TransformListApplyBudgetKey).getOrElse(8) // default is 8, as originally implemented in nsc
25+
26+
override def prepareForApply(tree: Apply)(using Context): Context = tree match
27+
case SeqApplyArgs(elems) =>
28+
ctx.fresh.setProperty(TransformListApplyBudgetKey, transformListApplyBudget - elems.length)
29+
case _ => ctx
30+
31+
override def transformApply(tree: Apply)(using Context): Tree =
2632
if isArrayModuleApply(tree.symbol) then
27-
tree.args match {
28-
case StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: ct :: Nil
33+
tree.args match
34+
case StripAscription(Apply(wrapRefArrayMeth, (seqLit: JavaSeqLiteral) :: Nil)) :: ct :: Nil
2935
if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) && elideClassTag(ct) =>
3036
seqLit
3137

32-
case elem0 :: StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: Nil
38+
case elem0 :: StripAscription(Apply(wrapRefArrayMeth, (seqLit: JavaSeqLiteral) :: Nil)) :: Nil
3339
if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) =>
34-
tpd.JavaSeqLiteral(elem0 :: seqLit.elems, seqLit.elemtpt)
40+
JavaSeqLiteral(elem0 :: seqLit.elems, seqLit.elemtpt)
3541

3642
case _ =>
3743
tree
38-
}
3944

40-
else tree
45+
else tree match
46+
case SeqApplyArgs(elems) if transformListApplyBudget > 0 || elems.isEmpty =>
47+
val consed = elems.foldRight(ref(defn.NilModule)): (elem, acc) =>
48+
New(defn.ConsType, List(elem.ensureConforms(defn.ObjectType), acc))
49+
consed.cast(tree.tpe)
50+
case _ => tree
4151

4252
private def isArrayModuleApply(sym: Symbol)(using Context): Boolean =
4353
sym.name == nme.apply
4454
&& (sym.owner == defn.ArrayModuleClass || (sym.owner == defn.IArrayModuleClass && !sym.is(Extension)))
4555

56+
private def isListApply(tree: Tree)(using Context): Boolean =
57+
(tree.symbol == defn.ListModule_apply || tree.symbol.name == nme.apply) && appliedCore(tree).match
58+
case Select(qual, _) =>
59+
val sym = qual.symbol
60+
sym == defn.ListModule
61+
|| sym == defn.ListModuleAlias
62+
case _ => false
63+
64+
private def isSeqApply(tree: Tree)(using Context): Boolean =
65+
isListApply(tree) || tree.symbol == defn.SeqModule_apply && appliedCore(tree).match
66+
case Select(qual, _) =>
67+
val sym = qual.symbol
68+
sym == defn.SeqModule
69+
|| sym == defn.SeqModuleAlias
70+
|| sym == defn.CollectionSeqType.symbol.companionModule
71+
case _ => false
72+
73+
private object SeqApplyArgs:
74+
def unapply(tree: Apply)(using Context): Option[List[Tree]] =
75+
if isSeqApply(tree) then
76+
tree.args match
77+
// <List or Seq>(a, b, c) ~> new ::(a, new ::(b, new ::(c, Nil))) but only for reference types
78+
case StripAscription(Apply(wrapArrayMeth, List(StripAscription(rest: JavaSeqLiteral)))) :: Nil
79+
if defn.WrapArrayMethods().contains(wrapArrayMeth.symbol) =>
80+
Some(rest.elems)
81+
case _ => None
82+
else None
83+
84+
4685
/** Only optimize when classtag if it is one of
4786
* - `ClassTag.apply(classOf[XYZ])`
4887
* - `ClassTag.apply(java.lang.XYZ.Type)` for boxed primitives `XYZ``

Diff for: compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala

+151-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
package dotty.tools.backend.jvm
1+
package dotty.tools
2+
package backend.jvm
23

34
import org.junit.Test
45
import org.junit.Assert._
@@ -160,4 +161,153 @@ class ArrayApplyOptTest extends DottyBytecodeTest {
160161
}
161162
}
162163

164+
@Test def testListApplyAvoidsIntermediateArray = {
165+
checkApplyAvoidsIntermediateArray("List"):
166+
"""import scala.collection.immutable.{ ::, Nil }
167+
|class Foo {
168+
| def meth1: List[String] = List("1", "2", "3")
169+
| def meth2: List[String] = new ::("1", new ::("2", new ::("3", Nil)))
170+
|}
171+
""".stripMargin
172+
}
173+
174+
@Test def testSeqApplyAvoidsIntermediateArray = {
175+
checkApplyAvoidsIntermediateArray("Seq"):
176+
"""import scala.collection.immutable.{ ::, Nil }
177+
|class Foo {
178+
| def meth1: Seq[String] = Seq("1", "2", "3")
179+
| def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil)))
180+
|}
181+
""".stripMargin
182+
}
183+
184+
@Test def testSeqApplyAvoidsIntermediateArray2 = {
185+
checkApplyAvoidsIntermediateArray("scala.collection.immutable.Seq"):
186+
"""import scala.collection.immutable.{ ::, Seq, Nil }
187+
|class Foo {
188+
| def meth1: Seq[String] = Seq("1", "2", "3")
189+
| def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil)))
190+
|}
191+
""".stripMargin
192+
}
193+
194+
@Test def testSeqApplyAvoidsIntermediateArray3 = {
195+
checkApplyAvoidsIntermediateArray("scala.collection.Seq"):
196+
"""import scala.collection.immutable.{ ::, Nil }, scala.collection.Seq
197+
|class Foo {
198+
| def meth1: Seq[String] = Seq("1", "2", "3")
199+
| def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil)))
200+
|}
201+
""".stripMargin
202+
}
203+
204+
@Test def testListApplyAvoidsIntermediateArray_max1 = {
205+
checkApplyAvoidsIntermediateArray_examples("max1"):
206+
""" def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", "6", "7")
207+
| def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::("5", new ::("6", new ::("7", Nil)))))))
208+
""".stripMargin
209+
}
210+
211+
@Test def testListApplyAvoidsIntermediateArray_max2 = {
212+
checkApplyAvoidsIntermediateArray_examples("max2"):
213+
""" def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", "6", List[Object]())
214+
| def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::("5", new ::("6", new ::(Nil, Nil)))))))
215+
""".stripMargin
216+
}
217+
218+
@Test def testListApplyAvoidsIntermediateArray_max3 = {
219+
checkApplyAvoidsIntermediateArray_examples("max3"):
220+
""" def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", List[Object]("6"))
221+
| def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::("5", new ::(new ::("6", Nil), Nil))))))
222+
""".stripMargin
223+
}
224+
225+
@Test def testListApplyAvoidsIntermediateArray_max4 = {
226+
checkApplyAvoidsIntermediateArray_examples("max4"):
227+
""" def meth1: List[Object] = List[Object]("1", "2", "3", "4", List[Object]("5", "6"))
228+
| def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::(new ::("5", new ::("6", Nil)), Nil)))))
229+
""".stripMargin
230+
}
231+
232+
@Test def testListApplyAvoidsIntermediateArray_over1 = {
233+
checkApplyAvoidsIntermediateArray_examples("over1"):
234+
""" def meth1: List[Object] = List("1", "2", "3", "4", "5", "6", "7", "8")
235+
| def meth2: List[Object] = List(wrapRefArray(Array("1", "2", "3", "4", "5", "6", "7", "8"))*)
236+
""".stripMargin
237+
}
238+
239+
@Test def testListApplyAvoidsIntermediateArray_over2 = {
240+
checkApplyAvoidsIntermediateArray_examples("over2"):
241+
""" def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", "6", "7", List[Object]())
242+
| def meth2: List[Object] = List(wrapRefArray(Array[Object]("1", "2", "3", "4", "5", "6", "7", Nil))*)
243+
""".stripMargin
244+
}
245+
246+
@Test def testListApplyAvoidsIntermediateArray_over3 = {
247+
checkApplyAvoidsIntermediateArray_examples("over3"):
248+
""" def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", "6", List[Object]("7"))
249+
| def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::("5", new ::("6", new ::(List(wrapRefArray(Array[Object]("7"))*), Nil)))))))
250+
""".stripMargin
251+
}
252+
253+
@Test def testListApplyAvoidsIntermediateArray_over4 = {
254+
checkApplyAvoidsIntermediateArray_examples("over4"):
255+
""" def meth1: List[Object] = List[Object]("1", "2", "3", "4", "5", List[Object]("6", "7"))
256+
| def meth2: List[Object] = new ::("1", new ::("2", new ::("3", new ::("4", new ::("5", new ::(List(wrapRefArray(Array[Object]("6", "7"))*), Nil))))))
257+
""".stripMargin
258+
}
259+
260+
@Test def testListApplyAvoidsIntermediateArray_max5 = {
261+
checkApplyAvoidsIntermediateArray_examples("max5"):
262+
""" def meth1: List[Object] = List[Object](List[Object](List[Object](List[Object](List[Object](List[Object](List[Object](List[Object]())))))))
263+
| def meth2: List[Object] = new ::(new ::(new ::(new ::(new ::(new ::(new ::(Nil, Nil), Nil), Nil), Nil), Nil), Nil), Nil)
264+
""".stripMargin
265+
}
266+
267+
@Test def testListApplyAvoidsIntermediateArray_over5 = {
268+
checkApplyAvoidsIntermediateArray_examples("over5"):
269+
""" def meth1: List[Object] = List[Object](List[Object](List[Object](List[Object](List[Object](List[Object](List[Object](List[Object](List[Object]()))))))))
270+
| def meth2: List[Object] = new ::(new ::(new ::(new ::(new ::(new ::(new ::(List[Object](wrapRefArray(Array[Object](Nil))*), Nil), Nil), Nil), Nil), Nil), Nil), Nil)
271+
""".stripMargin
272+
}
273+
274+
@Test def testListApplyAvoidsIntermediateArray_max6 = {
275+
checkApplyAvoidsIntermediateArray_examples("max6"):
276+
""" def meth1: List[Object] = List[Object]("1", "2", List[Object]("3", "4", List[Object](List[Object]())))
277+
| def meth2: List[Object] = new ::("1", new ::("2", new ::(new ::("3", new ::("4", new ::(new ::(Nil, Nil), Nil))), Nil)))
278+
""".stripMargin
279+
}
280+
281+
@Test def testListApplyAvoidsIntermediateArray_over6 = {
282+
checkApplyAvoidsIntermediateArray_examples("over6"):
283+
""" def meth1: List[Object] = List[Object]("1", "2", List[Object]("3", "4", List[Object]("5")))
284+
| def meth2: List[Object] = new ::("1", new ::("2", new ::(new ::("3", new ::("4", new ::(new ::("5", Nil), Nil))), Nil)))
285+
""".stripMargin
286+
}
287+
288+
def checkApplyAvoidsIntermediateArray_examples(name: String)(body: String): Unit = {
289+
checkApplyAvoidsIntermediateArray(s"List_$name"):
290+
s"""import scala.collection.immutable.{ ::, Nil }, scala.runtime.ScalaRunTime.wrapRefArray
291+
|class Foo {
292+
|$body
293+
|}
294+
""".stripMargin
295+
}
296+
297+
def checkApplyAvoidsIntermediateArray(name: String)(source: String): Unit = {
298+
checkBCode(source) { dir =>
299+
val clsIn = dir.lookupName("Foo.class", directory = false).input
300+
val clsNode = loadClassNode(clsIn)
301+
val meth1 = getMethod(clsNode, "meth1")
302+
val meth2 = getMethod(clsNode, "meth2")
303+
304+
val instructions1 = instructionsFromMethod(meth1).filter { case TypeOp(CHECKCAST, _) => false case _ => true }
305+
val instructions2 = instructionsFromMethod(meth2).filter { case TypeOp(CHECKCAST, _) => false case _ => true }
306+
307+
assert(instructions1 == instructions2,
308+
s"the $name.apply method\n" +
309+
diffInstructions(instructions1, instructions2))
310+
}
311+
}
312+
163313
}

Diff for: tests/run/list-apply-eval.scala

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
object Test:
2+
3+
var counter = 0
4+
5+
def next =
6+
counter += 1
7+
counter.toString
8+
9+
def main(args: Array[String]): Unit =
10+
//List.apply is subject to an optimisation in cleanup
11+
//ensure that the arguments are evaluated in the currect order
12+
// Rewritten to:
13+
// val myList: List = new collection.immutable.::(Test.this.next(), new collection.immutable.::(Test.this.next(), new collection.immutable.::(Test.this.next(), scala.collection.immutable.Nil)));
14+
val myList = List(next, next, next)
15+
assert(myList == List("1", "2", "3"), myList)
16+
17+
val mySeq = Seq(next, next, next)
18+
assert(mySeq == Seq("4", "5", "6"), mySeq)
19+
20+
val emptyList = List[Int]()
21+
assert(emptyList == Nil)
22+
23+
// just assert it doesn't throw CCE to List
24+
val queue = scala.collection.mutable.Queue[String]()
25+
26+
// test for the cast instruction described in checkApplyAvoidsIntermediateArray
27+
def lub(b: Boolean): List[(String, String)] =
28+
if b then List(("foo", "bar")) else Nil
29+
30+
// from minimising CI failure in oslib
31+
// again, the lub of :: and Nil is Product, which breaks ++ (which requires IterableOnce)
32+
def lub2(b: Boolean): Unit =
33+
Seq(1) ++ (if (b) Seq(2) else Nil)
34+
35+
// Examples of arity and nesting arity
36+
// to find the thresholds and reproduce the behaviour of nsc
37+
def examples(): Unit =
38+
val max1 = List[Object]("1", "2", "3", "4", "5", "6", "7") // 7 cons w/ 7 string heads + nil
39+
val max2 = List[Object]("1", "2", "3", "4", "5", "6", List[Object]()) // 7 cons w/ 6 string heads + 1 nil head + nil
40+
val max3 = List[Object]("1", "2", "3", "4", "5", List[Object]("6"))
41+
val max4 = List[Object]("1", "2", "3", "4", List[Object]("5", "6"))
42+
43+
val over1 = List[Object]("1", "2", "3", "4", "5", "6", "7", "8") // wrap 8-sized array
44+
val over2 = List[Object]("1", "2", "3", "4", "5", "6", "7", List[Object]()) // wrap 8-sized array
45+
val over3 = List[Object]("1", "2", "3", "4", "5", "6", List[Object]("7")) // wrap 1-sized array with 7
46+
val over4 = List[Object]("1", "2", "3", "4", "5", List[Object]("6", "7")) // wrap 2
47+
48+
val max5 =
49+
List[Object](
50+
List[Object](
51+
List[Object](
52+
List[Object](
53+
List[Object](
54+
List[Object](
55+
List[Object](
56+
List[Object](
57+
)))))))) // 7 cons + 1 nil
58+
59+
val over5 =
60+
List[Object](
61+
List[Object](
62+
List[Object](
63+
List[Object](
64+
List[Object](
65+
List[Object](
66+
List[Object](
67+
List[Object]( List[Object]()
68+
)))))))) // 7 cons + 1-sized array wrapping nil
69+
70+
val max6 =
71+
List[Object]( // ::(
72+
"1", "2", List[Object]( // 1, ::(2, ::(::(
73+
"3", "4", List[Object]( // 3, ::(4, ::(::(
74+
List[Object]() // Nil, Nil
75+
) // ), Nil))
76+
) // ), Nil))
77+
) // )
78+
// 7 cons + 4 string heads + 4 nils for nested lists
79+
80+
val max7 =
81+
List[Object]( // ::(
82+
"1", "2", List[Object]( // 1, ::(2, ::(::(
83+
"3", "4", List[Object]( // 3, ::(4, ::(::(
84+
"5" // 5, Nil
85+
) // ), Nil))
86+
) // ), Nil))
87+
) // )
88+
// 7 cons + 5 string heads + 3 nils for nested lists

0 commit comments

Comments
 (0)