Skip to content

Commit d2153ed

Browse files
committed
Reimplement support for type aliases in SAM types
This was dropped in #18201 which restricted SAM types to valid parent types, but it turns out that there is code in the wild that relies on refinements being allowed here. To support this properly, we had to enhance ExpandSAMs to move refinements into type members to pass Ycheck (previous Scala 3 releases would accept the code in tests/run/i18315.scala but fail Ycheck). Fixes #18315.
1 parent 082dc6f commit d2153ed

File tree

4 files changed

+84
-45
lines changed

4 files changed

+84
-45
lines changed

Diff for: compiler/src/dotty/tools/dotc/ast/tpd.scala

+13-10
Original file line numberDiff line numberDiff line change
@@ -349,24 +349,27 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
349349

350350
/** An anonymous class
351351
*
352-
* new parents { forwarders }
352+
* new parents { termForwarders; typeAliases }
353353
*
354-
* where `forwarders` contains forwarders for all functions in `fns`.
355-
* @param parents a non-empty list of class types
356-
* @param fns a non-empty of functions for which forwarders should be defined in the class.
357-
* The class has the same owner as the first function in `fns`.
358-
* Its position is the union of all functions in `fns`.
354+
* @param parents a non-empty list of class types
355+
* @param termForwarders a non-empty list of forwarding definitions specified by their name and the definition they forward to.
356+
* @param typeMembers a possibly-empty list of type members specified by their name and their right hand side.
357+
*
358+
* The class has the same owner as the first function in `termForwarders`.
359+
* Its position is the union of all symbols in `termForwarders`.
359360
*/
360-
def AnonClass(parents: List[Type], fns: List[TermSymbol], methNames: List[TermName])(using Context): Block = {
361-
AnonClass(fns.head.owner, parents, fns.map(_.span).reduceLeft(_ union _)) { cls =>
362-
def forwarder(fn: TermSymbol, name: TermName) = {
361+
def AnonClass(parents: List[Type], termForwarders: List[(TermName, TermSymbol)],
362+
typeMembers: List[(TypeName, TypeBounds)] = Nil)(using Context): Block = {
363+
AnonClass(termForwarders.head._2.owner, parents, termForwarders.map(_._2.span).reduceLeft(_ union _)) { cls =>
364+
def forwarder(name: TermName, fn: TermSymbol) = {
363365
val fwdMeth = fn.copy(cls, name, Synthetic | Method | Final).entered.asTerm
364366
for overridden <- fwdMeth.allOverriddenSymbols do
365367
if overridden.is(Extension) then fwdMeth.setFlag(Extension)
366368
if !overridden.is(Deferred) then fwdMeth.setFlag(Override)
367369
DefDef(fwdMeth, ref(fn).appliedToArgss(_))
368370
}
369-
fns.lazyZip(methNames).map(forwarder)
371+
termForwarders.map((name, sym) => forwarder(name, sym)) ++
372+
typeMembers.map((name, info) => TypeDef(newSymbol(cls, name, Synthetic, info).entered))
370373
}
371374
}
372375

Diff for: compiler/src/dotty/tools/dotc/core/Types.scala

+38-21
Original file line numberDiff line numberDiff line change
@@ -5536,13 +5536,16 @@ object Types {
55365536
* and PolyType not allowed!) according to `possibleSamMethods`.
55375537
* - can be instantiated without arguments or with just () as argument.
55385538
*
5539+
* Additionally, a SAM type may contain type aliases refinements if they refine
5540+
* an existing type member.
5541+
*
55395542
* The pattern `SAMType(samMethod, samParent)` matches a SAM type, where `samMethod` is the
55405543
* type of the single abstract method and `samParent` is a subtype of the matched
55415544
* SAM type which has been stripped of wildcards to turn it into a valid parent
55425545
* type.
55435546
*/
55445547
object SAMType {
5545-
/** If possible, return a type which is both a subtype of `origTp` and a type
5548+
/** If possible, return a type which is both a subtype of `origTp` and a (possibly refined) type
55465549
* application of `samClass` where none of the type arguments are
55475550
* wildcards (thus making it a valid parent type), otherwise return
55485551
* NoType.
@@ -5572,27 +5575,41 @@ object Types {
55725575
* we arbitrarily pick the upper-bound.
55735576
*/
55745577
def samParent(origTp: Type, samClass: Symbol, samMeth: Symbol)(using Context): Type =
5575-
val tp = origTp.baseType(samClass)
5578+
val tp0 = origTp.baseType(samClass)
5579+
5580+
/** Copy type aliases refinements to `toTp` from `fromTp` */
5581+
def withRefinements(toType: Type, fromTp: Type): Type = fromTp.dealias match
5582+
case RefinedType(fromParent, name, info: TypeAlias) if tp0.member(name).exists =>
5583+
val parent1 = withRefinements(toType, fromParent)
5584+
RefinedType(toType, name, info)
5585+
case _ => toType
5586+
val tp = withRefinements(tp0, origTp)
5587+
55765588
if !(tp <:< origTp) then NoType
5577-
else tp match
5578-
case tp @ AppliedType(tycon, args) if tp.hasWildcardArg =>
5579-
val accu = new TypeAccumulator[VarianceMap[Symbol]]:
5580-
def apply(vmap: VarianceMap[Symbol], t: Type): VarianceMap[Symbol] = t match
5581-
case tp: TypeRef if tp.symbol.isAllOf(ClassTypeParam) =>
5582-
vmap.recordLocalVariance(tp.symbol, variance)
5583-
case _ =>
5584-
foldOver(vmap, t)
5585-
val vmap = accu(VarianceMap.empty, samMeth.info)
5586-
val tparams = tycon.typeParamSymbols
5587-
val args1 = args.zipWithConserve(tparams):
5588-
case (arg @ TypeBounds(lo, hi), tparam) =>
5589-
val v = vmap.computedVariance(tparam)
5590-
if v.uncheckedNN < 0 then lo
5591-
else hi
5592-
case (arg, _) => arg
5593-
tp.derivedAppliedType(tycon, args1)
5594-
case _ =>
5595-
tp
5589+
else
5590+
def approxWildcardArgs(tp: Type): Type = tp match
5591+
case tp @ AppliedType(tycon, args) if tp.hasWildcardArg =>
5592+
val accu = new TypeAccumulator[VarianceMap[Symbol]]:
5593+
def apply(vmap: VarianceMap[Symbol], t: Type): VarianceMap[Symbol] = t match
5594+
case tp: TypeRef if tp.symbol.isAllOf(ClassTypeParam) =>
5595+
vmap.recordLocalVariance(tp.symbol, variance)
5596+
case _ =>
5597+
foldOver(vmap, t)
5598+
val vmap = accu(VarianceMap.empty, samMeth.info)
5599+
val tparams = tycon.typeParamSymbols
5600+
val args1 = args.zipWithConserve(tparams):
5601+
case (arg @ TypeBounds(lo, hi), tparam) =>
5602+
val v = vmap.computedVariance(tparam)
5603+
if v.uncheckedNN < 0 then lo
5604+
else hi
5605+
case (arg, _) => arg
5606+
tp.derivedAppliedType(tycon, args1)
5607+
case tp @ RefinedType(parent, name, info) =>
5608+
tp.derivedRefinedType(approxWildcardArgs(parent), name, info)
5609+
case _ =>
5610+
tp
5611+
approxWildcardArgs(tp)
5612+
end samParent
55965613

55975614
def samClass(tp: Type)(using Context): Symbol = tp match
55985615
case tp: ClassInfo =>

Diff for: compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala

+18-14
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import core._
66
import Scopes.newScope
77
import Contexts._, Symbols._, Types._, Flags._, Decorators._, StdNames._, Constants._
88
import MegaPhase._
9+
import Names.TypeName
910
import SymUtils._
1011
import NullOpsDecorator._
1112
import ast.untpd
@@ -51,16 +52,28 @@ class ExpandSAMs extends MiniPhase:
5152
case tpe if defn.isContextFunctionType(tpe) =>
5253
tree
5354
case SAMType(_, tpe) if tpe.isRef(defn.PartialFunctionClass) =>
54-
val tpe1 = checkRefinements(tpe, fn)
55-
toPartialFunction(tree, tpe1)
55+
toPartialFunction(tree, tpe)
5656
case SAMType(_, tpe) if ExpandSAMs.isPlatformSam(tpe.classSymbol.asClass) =>
57-
checkRefinements(tpe, fn)
5857
tree
5958
case tpe =>
60-
val tpe1 = checkRefinements(tpe.stripNull, fn)
59+
// A SAM type is allowed to have type aliases refinements (see
60+
// SAMType#samParent) which must be converted into type members if
61+
// the closure is desugared into a class.
62+
val refinements = collection.mutable.ListBuffer[(TypeName, TypeAlias)]()
63+
def collectAndStripRefinements(tp: Type): Type = tp match
64+
case RefinedType(parent, name, info: TypeAlias) =>
65+
val res = collectAndStripRefinements(parent)
66+
refinements += ((name.asTypeName, info))
67+
res
68+
case _ => tp
69+
val tpe1 = collectAndStripRefinements(tpe)
6170
val Seq(samDenot) = tpe1.possibleSamMethods
6271
cpy.Block(tree)(stats,
63-
AnonClass(tpe1 :: Nil, fn.symbol.asTerm :: Nil, samDenot.symbol.asTerm.name :: Nil))
72+
AnonClass(List(tpe1),
73+
List(samDenot.symbol.asTerm.name -> fn.symbol.asTerm),
74+
refinements.toList
75+
)
76+
)
6477
}
6578
case _ =>
6679
tree
@@ -171,13 +184,4 @@ class ExpandSAMs extends MiniPhase:
171184
List(isDefinedAtDef, applyOrElseDef)
172185
}
173186
}
174-
175-
private def checkRefinements(tpe: Type, tree: Tree)(using Context): Type = tpe.dealias match {
176-
case RefinedType(parent, name, _) =>
177-
if (name.isTermName && tpe.member(name).symbol.ownersIterator.isEmpty) // if member defined in the refinement
178-
report.error(em"Lambda does not define $name", tree.srcPos)
179-
checkRefinements(parent, tree)
180-
case tpe =>
181-
tpe
182-
}
183187
end ExpandSAMs

Diff for: tests/run/i18315.scala

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
trait Sam1:
2+
type T
3+
def apply(x: T): T
4+
5+
trait Sam2:
6+
var x: Int = 1 // To force anonymous class generation
7+
type T
8+
def apply(x: T): T
9+
10+
object Test:
11+
def main(args: Array[String]): Unit =
12+
val s1: Sam1 { type T = String } = x => x.trim
13+
s1.apply("foo")
14+
val s2: Sam2 { type T = Int } = x => x + 1
15+
s2.apply(1)

0 commit comments

Comments
 (0)