Skip to content

Commit 2ac7c1c

Browse files
committed
Fixup and finish List optimisation
1 parent 90aea07 commit 2ac7c1c

File tree

4 files changed

+91
-12
lines changed

4 files changed

+91
-12
lines changed

Diff for: compiler/src/dotty/tools/dotc/core/Definitions.scala

+1
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,7 @@ class Definitions {
521521
def ListType: TypeRef = ListClass.typeRef
522522
@tu lazy val ListModule: Symbol = requiredModule("scala.collection.immutable.List")
523523
@tu lazy val ListModule_apply: Symbol = ListModule.requiredMethod(nme.apply)
524+
def ListModuleAlias: Symbol = ScalaPackageClass.requiredMethod(nme.List)
524525
@tu lazy val NilModule: Symbol = requiredModule("scala.collection.immutable.Nil")
525526
def NilType: TermRef = NilModule.termRef
526527
@tu lazy val ConsClass: Symbol = requiredClass("scala.collection.immutable.::")

Diff for: compiler/src/dotty/tools/dotc/transform/ArrayApply.scala

+19-4
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,15 @@ class ArrayApply extends MiniPhase {
3434
case _ =>
3535
tree
3636

37-
else if isListOrSeqModuleApply(tree.symbol) then
37+
else if isSeqApply(tree) then
3838
tree.args match
3939
// <List or Seq>(a, b, c) ~> new ::(a, new ::(b, new ::(c, Nil))) but only for reference types
4040
case StripAscription(Apply(wrapArrayMeth, List(StripAscription(rest: JavaSeqLiteral)))) :: Nil
4141
if defn.WrapArrayMethods().contains(wrapArrayMeth.symbol) &&
4242
rest.elems.lengthIs < transformListApplyLimit =>
43-
rest.elems.foldRight(ref(defn.NilModule)): (elem, acc) =>
43+
val consed = rest.elems.foldRight(ref(defn.NilModule)): (elem, acc) =>
4444
New(defn.ConsType, List(elem.ensureConforms(defn.ObjectType), acc))
45+
consed.cast(tree.tpe)
4546

4647
case _ =>
4748
tree
@@ -52,8 +53,22 @@ class ArrayApply extends MiniPhase {
5253
sym.name == nme.apply
5354
&& (sym.owner == defn.ArrayModuleClass || (sym.owner == defn.IArrayModuleClass && !sym.is(Extension)))
5455

55-
private def isListOrSeqModuleApply(sym: Symbol)(using Context): Boolean =
56-
sym == defn.ListModule_apply || sym == defn.SeqModule_apply
56+
private def isListApply(tree: Tree)(using Context): Boolean =
57+
(tree.symbol == defn.ListModule_apply || tree.symbol.name == nme.apply) && appliedCore(tree).match
58+
case Select(qual, _) =>
59+
val sym = qual.symbol
60+
sym == defn.ListModule
61+
|| sym == defn.ListModuleAlias
62+
case _ => false
63+
64+
private def isSeqApply(tree: Tree)(using Context): Boolean =
65+
isListApply(tree) || tree.symbol == defn.SeqModule_apply && appliedCore(tree).match
66+
case Select(qual, _) =>
67+
val sym = qual.symbol
68+
sym == defn.SeqModule
69+
|| sym == defn.SeqModuleAlias
70+
|| sym == defn.CollectionSeqType.symbol.companionModule
71+
case _ => false
5772

5873
/** Only optimize when classtag if it is one of
5974
* - `ClassTag.apply(classOf[XYZ])`

Diff for: compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala

+58-7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
package dotty.tools.backend.jvm
1+
package dotty.tools
2+
package backend.jvm
23

34
import org.junit.Test
45
import org.junit.Assert._
@@ -161,26 +162,76 @@ class ArrayApplyOptTest extends DottyBytecodeTest {
161162
}
162163

163164
@Test def testListApplyAvoidsIntermediateArray = {
164-
val source =
165-
"""
165+
checkApplyAvoidsIntermediateArray("List"):
166+
"""import scala.collection.immutable.{ ::, Nil }
166167
|class Foo {
167168
| def meth1: List[String] = List("1", "2", "3")
168-
| def meth2: List[String] =
169-
| new scala.collection.immutable.::("1", new scala.collection.immutable.::("2", new scala.collection.immutable.::("3", scala.collection.immutable.Nil))).asInstanceOf[List[String]]
169+
| def meth2: List[String] = new ::("1", new ::("2", new ::("3", Nil)))
170+
|}
171+
""".stripMargin
172+
}
173+
174+
@Test def testSeqApplyAvoidsIntermediateArray = {
175+
checkApplyAvoidsIntermediateArray("Seq"):
176+
"""import scala.collection.immutable.{ ::, Nil }
177+
|class Foo {
178+
| def meth1: Seq[String] = Seq("1", "2", "3")
179+
| def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil)))
170180
|}
171181
""".stripMargin
182+
}
183+
184+
@Test def testSeqApplyAvoidsIntermediateArray2 = {
185+
checkApplyAvoidsIntermediateArray("scala.collection.immutable.Seq"):
186+
"""import scala.collection.immutable.{ ::, Seq, Nil }
187+
|class Foo {
188+
| def meth1: Seq[String] = Seq("1", "2", "3")
189+
| def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil)))
190+
|}
191+
""".stripMargin
192+
}
193+
194+
@Test def testSeqApplyAvoidsIntermediateArray3 = {
195+
checkApplyAvoidsIntermediateArray("scala.collection.Seq"):
196+
"""import scala.collection.immutable.{ ::, Nil }, scala.collection.Seq
197+
|class Foo {
198+
| def meth1: Seq[String] = Seq("1", "2", "3")
199+
| def meth2: Seq[String] = new ::("1", new ::("2", new ::("3", Nil)))
200+
|}
201+
""".stripMargin
202+
}
172203

204+
def checkApplyAvoidsIntermediateArray(name: String)(source: String) = {
173205
checkBCode(source) { dir =>
174206
val clsIn = dir.lookupName("Foo.class", directory = false).input
175207
val clsNode = loadClassNode(clsIn)
176208
val meth1 = getMethod(clsNode, "meth1")
177209
val meth2 = getMethod(clsNode, "meth2")
178210

179-
val instructions1 = instructionsFromMethod(meth1)
211+
val instructions1 = instructionsFromMethod(meth1) match
212+
case instr :+ TypeOp(CHECKCAST, _) :+ TypeOp(CHECKCAST, _) :+ (ret @ Op(ARETURN)) =>
213+
instr :+ ret
214+
case instr :+ TypeOp(CHECKCAST, _) :+ (ret @ Op(ARETURN)) =>
215+
// List.apply[?A] doesn't, strictly, return List[?A],
216+
// because it cascades to its definition on IterableFactory
217+
// where it returns CC[A]. The erasure of that is Object,
218+
// which is why Erasure's Typer adds a cast to compensate.
219+
// If we drop that cast while optimising (because using
220+
// the constructor for :: doesn't require the cast like
221+
// List.apply did) then then cons construction chain will
222+
// be typed as ::.
223+
// Unfortunately the LUB of :: and Nil.type is Product
224+
// instead of List, so a cast remains necessary,
225+
// across whatever causes the lub, like `if` or `try` branches.
226+
// Therefore if we dropping the cast may cause a needed cast
227+
// to be necessary, we shouldn't drop the cast,
228+
// which was only motivated by the assert here.
229+
instr :+ ret
230+
case instr => instr
180231
val instructions2 = instructionsFromMethod(meth2)
181232

182233
assert(instructions1 == instructions2,
183-
"the List.apply method " +
234+
s"the $name.apply method\n" +
184235
diffInstructions(instructions1, instructions2))
185236
}
186237
}

Diff for: tests/run/list-apply-eval.scala

+13-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ object Test:
66
counter += 1
77
counter.toString
88

9-
def main(args: Array[String]): Unit =
9+
def main(args: Array[String]): Unit =
1010
//List.apply is subject to an optimisation in cleanup
1111
//ensure that the arguments are evaluated in the currect order
1212
// Rewritten to:
@@ -19,3 +19,15 @@ object Test:
1919

2020
val emptyList = List[Int]()
2121
assert(emptyList == Nil)
22+
23+
// just assert it doesn't throw CCE to List
24+
val queue = scala.collection.mutable.Queue[String]()
25+
26+
// test for the cast instruction described in checkApplyAvoidsIntermediateArray
27+
def lub(b: Boolean): List[(String, String)] =
28+
if b then List(("foo", "bar")) else Nil
29+
30+
// from minimising CI failure in oslib
31+
// again, the lub of :: and Nil is Product, which breaks ++ (which requires IterableOnce)
32+
def lub2(b: Boolean): Unit =
33+
Seq(1) ++ (if (b) Seq(2) else Nil)

0 commit comments

Comments
 (0)