Skip to content

Commit 591db4d

Browse files
authored
Merge pull request #15060 from dwijnand/tuple-specialisation
Support tuple specialisation
2 parents d522766 + 4c0ab7b commit 591db4d

17 files changed

+284
-74
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

+1
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ class Compiler {
9191
new InterceptedMethods, // Special handling of `==`, `|=`, `getClass` methods
9292
new Getters, // Replace non-private vals and vars with getter defs (fields are added later)
9393
new SpecializeFunctions, // Specialized Function{0,1,2} by replacing super with specialized super
94+
new SpecializeTuples, // Specializes Tuples by replacing tuple construction and selection trees
9495
new LiftTry, // Put try expressions that might execute on non-empty stacks into their own methods
9596
new CollectNullableFields, // Collect fields that can be nulled out after use in lazy initialization
9697
new ElimOuterSelect, // Expand outer selections

compiler/src/dotty/tools/dotc/Run.scala

+1-2
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,7 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
282282

283283
private def printTree(last: PrintedTree)(using Context): PrintedTree = {
284284
val unit = ctx.compilationUnit
285-
val prevPhase = ctx.phase.prev // can be a mini-phase
286-
val fusedPhase = ctx.base.fusedContaining(prevPhase)
285+
val fusedPhase = ctx.phase.prevMega
287286
val echoHeader = f"[[syntax trees at end of $fusedPhase%25s]] // ${unit.source}"
288287
val tree = if ctx.isAfterTyper then unit.tpdTree else unit.untpdTree
289288
val treeString = tree.show(using ctx.withProperty(XprintMode, Some(())))

compiler/src/dotty/tools/dotc/core/Definitions.scala

+27
Original file line numberDiff line numberDiff line change
@@ -1326,6 +1326,12 @@ class Definitions {
13261326

13271327
@tu lazy val TupleType: Array[TypeRef | Null] = mkArityArray("scala.Tuple", MaxTupleArity, 1)
13281328

1329+
def isSpecializedTuple(cls: Symbol)(using Context): Boolean =
1330+
cls.isClass && TupleSpecializedClasses.exists(tupleCls => cls.name.isSpecializedNameOf(tupleCls.name))
1331+
1332+
def SpecializedTuple(base: Symbol, args: List[Type])(using Context): Symbol =
1333+
base.owner.requiredClass(base.name.specializedName(args))
1334+
13291335
private class FunType(prefix: String):
13301336
private var classRefs: Array[TypeRef | Null] = new Array(22)
13311337
def apply(n: Int): TypeRef =
@@ -1584,6 +1590,20 @@ class Definitions {
15841590
def isFunctionType(tp: Type)(using Context): Boolean =
15851591
isNonRefinedFunction(tp.dropDependentRefinement)
15861592

1593+
private def withSpecMethods(cls: ClassSymbol, bases: List[Name], paramTypes: Set[TypeRef]) =
1594+
for base <- bases; tp <- paramTypes do
1595+
cls.enter(newSymbol(cls, base.specializedName(List(tp)), Method, ExprType(tp)))
1596+
cls
1597+
1598+
@tu lazy val Tuple1: ClassSymbol = withSpecMethods(requiredClass("scala.Tuple1"), List(nme._1), Tuple1SpecializedParamTypes)
1599+
@tu lazy val Tuple2: ClassSymbol = withSpecMethods(requiredClass("scala.Tuple2"), List(nme._1, nme._2), Tuple2SpecializedParamTypes)
1600+
1601+
@tu lazy val TupleSpecializedClasses: Set[Symbol] = Set(Tuple1, Tuple2)
1602+
@tu lazy val Tuple1SpecializedParamTypes: Set[TypeRef] = Set(IntType, LongType, DoubleType)
1603+
@tu lazy val Tuple2SpecializedParamTypes: Set[TypeRef] = Set(IntType, LongType, DoubleType, CharType, BooleanType)
1604+
@tu lazy val Tuple1SpecializedParamClasses: PerRun[Set[Symbol]] = new PerRun(Tuple1SpecializedParamTypes.map(_.symbol))
1605+
@tu lazy val Tuple2SpecializedParamClasses: PerRun[Set[Symbol]] = new PerRun(Tuple2SpecializedParamTypes.map(_.symbol))
1606+
15871607
// Specialized type parameters defined for scala.Function{0,1,2}.
15881608
@tu lazy val Function1SpecializedParamTypes: collection.Set[TypeRef] =
15891609
Set(IntType, LongType, FloatType, DoubleType)
@@ -1607,6 +1627,13 @@ class Definitions {
16071627
@tu lazy val Function2SpecializedReturnClasses: PerRun[collection.Set[Symbol]] =
16081628
new PerRun(Function2SpecializedReturnTypes.map(_.symbol))
16091629

1630+
def isSpecializableTuple(base: Symbol, args: List[Type])(using Context): Boolean =
1631+
args.length <= 2 && base.isClass && TupleSpecializedClasses.exists(base.asClass.derivesFrom) && args.match
1632+
case List(x) => Tuple1SpecializedParamClasses().contains(x.classSymbol)
1633+
case List(x, y) => Tuple2SpecializedParamClasses().contains(x.classSymbol) && Tuple2SpecializedParamClasses().contains(y.classSymbol)
1634+
case _ => false
1635+
&& base.owner.denot.info.member(base.name.specializedName(args)).exists // when dotc compiles the stdlib there are no specialised classes
1636+
16101637
def isSpecializableFunction(cls: ClassSymbol, paramTypes: List[Type], retType: Type)(using Context): Boolean =
16111638
paramTypes.length <= 2
16121639
&& (cls.derivesFrom(FunctionClass(paramTypes.length)) || isByNameFunctionClass(cls))

compiler/src/dotty/tools/dotc/core/NameOps.scala

+24
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import scala.io.Codec
88
import Int.MaxValue
99
import Names._, StdNames._, Contexts._, Symbols._, Flags._, NameKinds._, Types._
1010
import util.Chars.{isOperatorPart, digit2int}
11+
import Decorators.*
1112
import Definitions._
1213
import nme._
1314

@@ -278,6 +279,29 @@ object NameOps {
278279
classTags.fold(nme.EMPTY)(_ ++ _) ++ nme.specializedTypeNames.suffix)
279280
}
280281

282+
/** Determines if the current name is the specialized name of the given base name.
283+
* For example `typeName("Tuple2$mcII$sp").isSpecializedNameOf(tpnme.Tuple2) == true`
284+
*/
285+
def isSpecializedNameOf(base: N)(using Context): Boolean =
286+
var i = 0
287+
inline def nextString(str: String) = name.startsWith(str, i) && { i += str.length; true }
288+
nextString(base.toString)
289+
&& nextString(nme.specializedTypeNames.prefix.toString)
290+
&& nextString(nme.specializedTypeNames.separator.toString)
291+
&& name.endsWith(nme.specializedTypeNames.suffix.toString)
292+
293+
/** Returns the name of the class specialised to the provided types,
294+
* in the given order. Used for the specialized tuple classes.
295+
*/
296+
def specializedName(args: List[Type])(using Context): N =
297+
val sb = new StringBuilder
298+
sb.append(name.toString)
299+
sb.append(nme.specializedTypeNames.prefix.toString)
300+
sb.append(nme.specializedTypeNames.separator)
301+
args.foreach { arg => sb.append(defn.typeTag(arg)) }
302+
sb.append(nme.specializedTypeNames.suffix)
303+
likeSpacedN(termName(sb.toString))
304+
281305
/** Use for specializing function names ONLY and use it if you are **not**
282306
* creating specialized name from type parameters. The order of names will
283307
* be:

compiler/src/dotty/tools/dotc/core/Phases.scala

+3
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,9 @@ object Phases {
402402
final def prev: Phase =
403403
if (id > FirstPhaseId) myBase.phases(start - 1) else NoPhase
404404

405+
final def prevMega(using Context): Phase =
406+
ctx.base.fusedContaining(ctx.phase.prev)
407+
405408
final def next: Phase =
406409
if (hasNext) myBase.phases(end + 1) else NoPhase
407410

compiler/src/dotty/tools/dotc/core/SymDenotations.scala

+4
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ object SymDenotations {
9292
if (myFlags.is(Trait)) NoInitsInterface & bodyFlags // no parents are initialized from a trait
9393
else NoInits & bodyFlags & parentFlags)
9494

95+
final def setStableConstructor()(using Context): Unit =
96+
val ctorStable = if myFlags.is(Trait) then myFlags.is(NoInits) else isNoInitsRealClass
97+
if ctorStable then primaryConstructor.setFlag(StableRealizable)
98+
9599
def isCurrent(fs: FlagSet)(using Context): Boolean =
96100
def knownFlags(info: Type): FlagSet = info match
97101
case _: SymbolLoader | _: ModuleCompleter => FromStartFlags

compiler/src/dotty/tools/dotc/core/classfile/ClassfileParser.scala

+4-1
Original file line numberDiff line numberDiff line change
@@ -988,7 +988,10 @@ class ClassfileParser(
988988
else return unpickleTASTY(bytes)
989989
}
990990

991-
if (scan(tpnme.ScalaATTR) && !scalaUnpickleWhitelist.contains(classRoot.name))
991+
if scan(tpnme.ScalaATTR) && !scalaUnpickleWhitelist.contains(classRoot.name)
992+
&& !(classRoot.name.startsWith("Tuple") && classRoot.name.endsWith("$sp"))
993+
&& !(classRoot.name.startsWith("Product") && classRoot.name.endsWith("$sp"))
994+
then
992995
// To understand the situation, it's helpful to know that:
993996
// - Scalac emits the `ScalaSig` attribute for classfiles with pickled information
994997
// and the `Scala` attribute for everything else.

compiler/src/dotty/tools/dotc/core/unpickleScala2/Scala2Unpickler.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,9 @@ class Scala2Unpickler(bytes: Array[Byte], classRoot: ClassDenotation, moduleClas
617617
// we need the checkNonCyclic call to insert LazyRefs for F-bounded cycles
618618
else if (!denot.is(Param)) tp1.translateFromRepeated(toArray = false)
619619
else tp1
620-
if (denot.isConstructor) addConstructorTypeParams(denot)
620+
if (denot.isConstructor)
621+
denot.owner.setStableConstructor()
622+
addConstructorTypeParams(denot)
621623
if (atEnd)
622624
assert(!denot.symbol.isSuperAccessor, denot)
623625
else {

compiler/src/dotty/tools/dotc/printing/Formatting.scala

+10-1
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,16 @@ object Formatting {
3333
object ShowAny extends Show[Any]:
3434
def show(x: Any): Shown = x
3535

36-
class ShowImplicits2:
36+
class ShowImplicits3:
3737
given Show[Product] = ShowAny
3838

39+
class ShowImplicits2 extends ShowImplicits3:
40+
given Show[ParamInfo] with
41+
def show(x: ParamInfo) = x match
42+
case x: Symbol => Show[x.type].show(x)
43+
case x: LambdaParam => Show[x.type].show(x)
44+
case _ => ShowAny
45+
3946
class ShowImplicits1 extends ShowImplicits2:
4047
given Show[ImplicitRef] = ShowAny
4148
given Show[Names.Designator] = ShowAny
@@ -99,6 +106,8 @@ object Formatting {
99106
val sep = StringContext.processEscapes(rawsep)
100107
if (rest.nonEmpty) (arg.map(showArg).mkString(sep), rest.tail)
101108
else (arg, suffix)
109+
case arg: Seq[?] =>
110+
(arg.map(showArg).mkString("[", ", ", "]"), suffix)
102111
case _ =>
103112
(showArg(arg), suffix)
104113
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package dotty.tools
2+
package dotc
3+
package transform
4+
5+
import ast.Trees.*, ast.tpd, core.*
6+
import Contexts.*, Types.*, Decorators.*, Symbols.*, DenotTransformers.*
7+
import SymDenotations.*, Scopes.*, StdNames.*, NameOps.*, Names.*
8+
import MegaPhase.MiniPhase
9+
import typer.Inliner.isElideableExpr
10+
11+
/** Specializes Tuples by replacing tuple construction and selection trees.
12+
*
13+
* Specifically:
14+
* 1. Replaces `(1, 1)` (which is `Tuple2.apply[Int, Int](1, 1)`) and
15+
* `new Tuple2[Int, Int](1, 1)` with `new Tuple2$mcII$sp(1, 1)`.
16+
* 2. Replaces `(_: Tuple2[Int, Int])._1` with `(_: Tuple2[Int, Int])._1$mcI$sp`
17+
*/
18+
class SpecializeTuples extends MiniPhase:
19+
import tpd.*
20+
21+
override def phaseName: String = SpecializeTuples.name
22+
override def description: String = SpecializeTuples.description
23+
override def isEnabled(using Context): Boolean = !ctx.settings.scalajs.value
24+
25+
override def transformApply(tree: Apply)(using Context): Tree = tree match
26+
case Apply(TypeApply(fun: NameTree, targs), args)
27+
if fun.symbol.name == nme.apply && fun.symbol.exists && defn.isSpecializableTuple(fun.symbol.owner.companionClass, targs.map(_.tpe))
28+
&& isElideableExpr(tree)
29+
=>
30+
cpy.Apply(tree)(Select(New(defn.SpecializedTuple(fun.symbol.owner.companionClass, targs.map(_.tpe)).typeRef), nme.CONSTRUCTOR), args).withType(tree.tpe)
31+
case Apply(TypeApply(fun: NameTree, targs), args)
32+
if fun.symbol.name == nme.CONSTRUCTOR && fun.symbol.exists && defn.isSpecializableTuple(fun.symbol.owner, targs.map(_.tpe))
33+
&& isElideableExpr(tree)
34+
=>
35+
cpy.Apply(tree)(Select(New(defn.SpecializedTuple(fun.symbol.owner, targs.map(_.tpe)).typeRef), nme.CONSTRUCTOR), args).withType(tree.tpe)
36+
case _ => tree
37+
end transformApply
38+
39+
override def transformSelect(tree: Select)(using Context): Tree = tree match
40+
case Select(qual, nme._1) if isAppliedSpecializableTuple(qual.tpe.widen) =>
41+
Select(qual, nme._1.specializedName(qual.tpe.widen.argInfos.slice(0, 1)))
42+
case Select(qual, nme._2) if isAppliedSpecializableTuple(qual.tpe.widen) =>
43+
Select(qual, nme._2.specializedName(qual.tpe.widen.argInfos.slice(1, 2)))
44+
case _ => tree
45+
46+
private def isAppliedSpecializableTuple(tp: Type)(using Context) = tp match
47+
case AppliedType(tycon, args) => defn.isSpecializableTuple(tycon.classSymbol, args)
48+
case _ => false
49+
end SpecializeTuples
50+
51+
object SpecializeTuples:
52+
val name: String = "specializeTuples"
53+
val description: String = "replaces tuple construction and selection trees"

compiler/src/dotty/tools/dotc/transform/TreeChecker.scala

+10-4
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class TreeChecker extends Phase with SymTransformer {
9191
if (ctx.phaseId <= erasurePhase.id) {
9292
val initial = symd.initial
9393
assert(symd == initial || symd.signature == initial.signature,
94-
i"""Signature of ${sym.showLocated} changed at phase ${ctx.base.fusedContaining(ctx.phase.prev)}
94+
i"""Signature of ${sym.showLocated} changed at phase ${ctx.phase.prevMega}
9595
|Initial info: ${initial.info}
9696
|Initial sig : ${initial.signature}
9797
|Current info: ${symd.info}
@@ -122,8 +122,7 @@ class TreeChecker extends Phase with SymTransformer {
122122
}
123123

124124
def check(phasesToRun: Seq[Phase], ctx: Context): Tree = {
125-
val prevPhase = ctx.phase.prev // can be a mini-phase
126-
val fusedPhase = ctx.base.fusedContaining(prevPhase)
125+
val fusedPhase = ctx.phase.prevMega(using ctx)
127126
report.echo(s"checking ${ctx.compilationUnit} after phase ${fusedPhase}")(using ctx)
128127

129128
inContext(ctx) {
@@ -145,7 +144,7 @@ class TreeChecker extends Phase with SymTransformer {
145144
catch {
146145
case NonFatal(ex) => //TODO CHECK. Check that we are bootstrapped
147146
inContext(checkingCtx) {
148-
println(i"*** error while checking ${ctx.compilationUnit} after phase ${ctx.phase.prev} ***")
147+
println(i"*** error while checking ${ctx.compilationUnit} after phase ${ctx.phase.prevMega(using ctx)} ***")
149148
}
150149
throw ex
151150
}
@@ -422,6 +421,13 @@ class TreeChecker extends Phase with SymTransformer {
422421
assert(tree.qual.typeOpt.isInstanceOf[ThisType], i"expect prefix of Super to be This, actual = ${tree.qual}")
423422
super.typedSuper(tree, pt)
424423

424+
override def typedApply(tree: untpd.Apply, pt: Type)(using Context): Tree = tree match
425+
case Apply(Select(qual, nme.CONSTRUCTOR), _)
426+
if !ctx.phase.erasedTypes
427+
&& defn.isSpecializedTuple(qual.typeOpt.typeSymbol) =>
428+
promote(tree) // e.g. `new Tuple2$mcII$sp(7, 8)` should keep its `(7, 8)` type instead of `Tuple2$mcII$sp`
429+
case _ => super.typedApply(tree, pt)
430+
425431
override def typedTyped(tree: untpd.Typed, pt: Type)(using Context): Tree =
426432
val tpt1 = checkSimpleKinded(typedType(tree.tpt))
427433
val expr1 = tree.expr match

0 commit comments

Comments
 (0)