Skip to content

Commit f95b57c

Browse files
authored
New footprint calculation scheme (#19639)
Since match type reduction is expensive, it is cached. If a type is reduced (or not reduced) in `tryNormalize` we remember that decision and return the previous result - unless something in the context changed since the last attempt which could lead to a different outcome. Relevant here are: - constraints (regular and GADT) over type parameters - instantations of type variables We keep track of these things in a so-called footprint calculation. The old calculation clearly did not work. It either never worked or was broken by the changes to matchtype reduction. I now changed it to a more straightforward scheme that computes the footprint directly instead of relying on TypeComparer to produce the right trace.
2 parents b160bbb + d6ba9b2 commit f95b57c

File tree

3 files changed

+165
-57
lines changed

3 files changed

+165
-57
lines changed

Diff for: compiler/src/dotty/tools/dotc/core/TypeComparer.scala

+13-35
Original file line numberDiff line numberDiff line change
@@ -3054,7 +3054,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
30543054
end provablyDisjointTypeArgs
30553055

30563056
protected def explainingTypeComparer(short: Boolean) = ExplainingTypeComparer(comparerContext, short)
3057-
protected def trackingTypeComparer = TrackingTypeComparer(comparerContext)
3057+
protected def matchReducer = MatchReducer(comparerContext)
30583058

30593059
private def inSubComparer[T, Cmp <: TypeComparer](comparer: Cmp)(op: Cmp => T): T =
30603060
val saved = myInstance
@@ -3068,8 +3068,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
30683068
inSubComparer(cmp)(op)
30693069
cmp.lastTrace(header)
30703070

3071-
def tracked[T](op: TrackingTypeComparer => T)(using Context): T =
3072-
inSubComparer(trackingTypeComparer)(op)
3071+
def reduceMatchWith[T](op: MatchReducer => T)(using Context): T =
3072+
inSubComparer(matchReducer)(op)
30733073
}
30743074

30753075
object TypeComparer {
@@ -3236,14 +3236,14 @@ object TypeComparer {
32363236
def explained[T](op: ExplainingTypeComparer => T, header: String = "Subtype trace:", short: Boolean = false)(using Context): String =
32373237
comparing(_.explained(op, header, short))
32383238

3239-
def tracked[T](op: TrackingTypeComparer => T)(using Context): T =
3240-
comparing(_.tracked(op))
3239+
def reduceMatchWith[T](op: MatchReducer => T)(using Context): T =
3240+
comparing(_.reduceMatchWith(op))
32413241

32423242
def subCaptures(refs1: CaptureSet, refs2: CaptureSet, frozen: Boolean)(using Context): CaptureSet.CompareResult =
32433243
comparing(_.subCaptures(refs1, refs2, frozen))
32443244
}
32453245

3246-
object TrackingTypeComparer:
3246+
object MatchReducer:
32473247
import printing.*, Texts.*
32483248
enum MatchResult extends Showable:
32493249
case Reduced(tp: Type)
@@ -3259,38 +3259,16 @@ object TrackingTypeComparer:
32593259
case Stuck => "Stuck"
32603260
case NoInstance(fails) => "NoInstance(" ~ Text(fails.map(p.toText(_) ~ p.toText(_)), ", ") ~ ")"
32613261

3262-
class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
3263-
import TrackingTypeComparer.*
3262+
/** A type comparer for reducing match types.
3263+
* TODO: Not sure this needs to be a type comparer. Can we make it a
3264+
* separate class?
3265+
*/
3266+
class MatchReducer(initctx: Context) extends TypeComparer(initctx) {
3267+
import MatchReducer.*
32643268

32653269
init(initctx)
32663270

3267-
override def trackingTypeComparer = this
3268-
3269-
val footprint: mutable.Set[Type] = mutable.Set[Type]()
3270-
3271-
override def bounds(param: TypeParamRef)(using Context): TypeBounds = {
3272-
if (param.binder `ne` caseLambda) footprint += param
3273-
super.bounds(param)
3274-
}
3275-
3276-
override def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean)(using Context): Boolean = {
3277-
if (param.binder `ne` caseLambda) footprint += param
3278-
super.addOneBound(param, bound, isUpper)
3279-
}
3280-
3281-
override def gadtBounds(sym: Symbol)(using Context): TypeBounds | Null = {
3282-
if (sym.exists) footprint += sym.typeRef
3283-
super.gadtBounds(sym)
3284-
}
3285-
3286-
override def gadtAddBound(sym: Symbol, b: Type, isUpper: Boolean): Boolean =
3287-
if (sym.exists) footprint += sym.typeRef
3288-
super.gadtAddBound(sym, b, isUpper)
3289-
3290-
override def typeVarInstance(tvar: TypeVar)(using Context): Type = {
3291-
footprint += tvar
3292-
super.typeVarInstance(tvar)
3293-
}
3271+
override def matchReducer = this
32943272

32953273
def matchCases(scrut: Type, cases: List[MatchTypeCaseSpec])(using Context): Type = {
32963274
// a reference for the type parameters poisoned during matching

Diff for: compiler/src/dotty/tools/dotc/core/Types.scala

+53-22
Original file line numberDiff line numberDiff line change
@@ -5009,6 +5009,8 @@ object Types extends TypeUtils {
50095009
case ex: Throwable =>
50105010
handleRecursive("normalizing", s"${scrutinee.show} match ..." , ex)
50115011

5012+
private def thisMatchType = this
5013+
50125014
def reduced(using Context): Type = {
50135015

50145016
def contextInfo(tp: Type): Type = tp match {
@@ -5023,16 +5025,43 @@ object Types extends TypeUtils {
50235025
tp.underlying
50245026
}
50255027

5026-
def updateReductionContext(footprint: collection.Set[Type]): Unit =
5027-
reductionContext = util.HashMap()
5028-
for (tp <- footprint)
5029-
reductionContext(tp) = contextInfo(tp)
5030-
typr.println(i"footprint for $this $hashCode: ${footprint.toList.map(x => (x, contextInfo(x)))}%, %")
5031-
50325028
def isUpToDate: Boolean =
5033-
reductionContext.keysIterator.forall { tp =>
5029+
reductionContext.keysIterator.forall: tp =>
50345030
reductionContext(tp) `eq` contextInfo(tp)
5035-
}
5031+
5032+
def setReductionContext(): Unit =
5033+
new TypeTraverser:
5034+
var footprint: Set[Type] = Set()
5035+
var deep: Boolean = true
5036+
val seen = util.HashSet[Type]()
5037+
def traverse(tp: Type) =
5038+
if !seen.contains(tp) then
5039+
seen += tp
5040+
tp match
5041+
case tp: NamedType =>
5042+
if tp.symbol.is(TypeParam) then footprint += tp
5043+
traverseChildren(tp)
5044+
case _: AppliedType | _: RefinedType =>
5045+
if deep then traverseChildren(tp)
5046+
case TypeBounds(lo, hi) =>
5047+
traverse(hi)
5048+
case tp: TypeVar =>
5049+
footprint += tp
5050+
traverse(tp.underlying)
5051+
case tp: TypeParamRef =>
5052+
footprint += tp
5053+
case _ =>
5054+
traverseChildren(tp)
5055+
end traverse
5056+
5057+
traverse(scrutinee)
5058+
deep = false
5059+
cases.foreach(traverse)
5060+
reductionContext = util.HashMap()
5061+
for tp <- footprint do
5062+
reductionContext(tp) = contextInfo(tp)
5063+
matchTypes.println(i"footprint for $thisMatchType $hashCode: ${footprint.toList.map(x => (x, contextInfo(x)))}%, %")
5064+
end setReductionContext
50365065

50375066
record("MatchType.reduce called")
50385067
if !Config.cacheMatchReduced
@@ -5043,20 +5072,22 @@ object Types extends TypeUtils {
50435072
record("MatchType.reduce computed")
50445073
if (myReduced != null) record("MatchType.reduce cache miss")
50455074
myReduced =
5046-
trace(i"reduce match type $this $hashCode", matchTypes, show = true)(withMode(Mode.Type) {
5047-
def matchCases(cmp: TrackingTypeComparer): Type =
5048-
val saved = ctx.typerState.snapshot()
5049-
try cmp.matchCases(scrutinee.normalized, cases.map(MatchTypeCaseSpec.analyze(_)))
5050-
catch case ex: Throwable =>
5051-
handleRecursive("reduce type ", i"$scrutinee match ...", ex)
5052-
finally
5053-
updateReductionContext(cmp.footprint)
5054-
ctx.typerState.resetTo(saved)
5055-
// this drops caseLambdas in constraint and undoes any typevar
5056-
// instantiations during matchtype reduction
5057-
5058-
TypeComparer.tracked(matchCases)
5059-
})
5075+
trace(i"reduce match type $this $hashCode", matchTypes, show = true):
5076+
withMode(Mode.Type):
5077+
setReductionContext()
5078+
def matchCases(cmp: MatchReducer): Type =
5079+
val saved = ctx.typerState.snapshot()
5080+
try
5081+
cmp.matchCases(scrutinee.normalized, cases.map(MatchTypeCaseSpec.analyze(_)))
5082+
catch case ex: Throwable =>
5083+
handleRecursive("reduce type ", i"$scrutinee match ...", ex)
5084+
finally
5085+
ctx.typerState.resetTo(saved)
5086+
// this drops caseLambdas in constraint and undoes any typevar
5087+
// instantiations during matchtype reduction
5088+
TypeComparer.reduceMatchWith(matchCases)
5089+
5090+
//else println(i"no change for $this $hashCode / $myReduced")
50605091
myReduced.nn
50615092
}
50625093

Diff for: tests/pos/bad-footprint.scala

+99
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
2+
object NamedTuple:
3+
4+
opaque type AnyNamedTuple = Any
5+
opaque type NamedTuple[N <: Tuple, +V <: Tuple] >: V <: AnyNamedTuple = V
6+
7+
export NamedTupleDecomposition.{Names, DropNames}
8+
9+
/** The type of the named tuple `X` mapped with the type-level function `F`.
10+
* If `X = (n1 : T1, ..., ni : Ti)` then `Map[X, F] = `(n1 : F[T1], ..., ni : F[Ti])`.
11+
*/
12+
type Map[X <: AnyNamedTuple, F[_ <: Tuple.Union[DropNames[X]]]] =
13+
NamedTuple[Names[X], Tuple.Map[DropNames[X], F]]
14+
15+
end NamedTuple
16+
17+
object NamedTupleDecomposition:
18+
import NamedTuple.*
19+
20+
/** The names of a named tuple, represented as a tuple of literal string values. */
21+
type Names[X <: AnyNamedTuple] <: Tuple = X match
22+
case NamedTuple[n, _] => n
23+
24+
/** The value types of a named tuple represented as a regular tuple. */
25+
type DropNames[NT <: AnyNamedTuple] <: Tuple = NT match
26+
case NamedTuple[_, x] => x
27+
end NamedTupleDecomposition
28+
29+
class Expr[Result]
30+
31+
object Expr:
32+
import NamedTuple.{NamedTuple, AnyNamedTuple}
33+
34+
type Of[A] = Expr[A]
35+
36+
type StripExpr[E] = E match
37+
case Expr.Of[b] => b
38+
39+
case class Ref[A]($name: String = "") extends Expr.Of[A]
40+
41+
case class Join[A <: AnyNamedTuple](a: A)
42+
extends Expr.Of[NamedTuple.Map[A, StripExpr]]
43+
end Expr
44+
45+
trait Query[A]
46+
47+
object Query:
48+
// Extension methods to support for-expression syntax for queries
49+
extension [R](x: Query[R])
50+
def map[B](f: Expr.Ref[R] => Expr.Of[B]): Query[B] = ???
51+
52+
case class City(zipCode: Int, name: String, population: Int)
53+
54+
object Test:
55+
import Expr.StripExpr
56+
import NamedTuple.{NamedTuple, AnyNamedTuple}
57+
58+
val cities: Query[City] = ???
59+
val q6 =
60+
cities.map: city =>
61+
val x: NamedTuple[
62+
("name", "zipCode"),
63+
(Expr.Of[String], Expr.Of[Int])] = ???
64+
Expr.Join(x)
65+
66+
/* Was error:
67+
68+
-- [E007] Type Mismatch Error: bad-footprint.scala:60:16 -----------------------
69+
60 | cities.map: city =>
70+
| ^
71+
|Found: Expr.Ref[City] =>
72+
| Expr[
73+
| NamedTuple.NamedTuple[(("name" : String), ("zipCode" : String)), (String,
74+
| Int)]
75+
| ]
76+
|Required: Expr.Ref[City] =>
77+
| Expr[
78+
| NamedTuple.NamedTuple[
79+
| NamedTupleDecomposition.Names[
80+
| NamedTuple.NamedTuple[(("name" : String), ("zipCode" : String)), (
81+
| Expr[String], Expr[Int])]
82+
| ],
83+
| Tuple.Map[
84+
| NamedTupleDecomposition.DropNames[
85+
| NamedTuple.NamedTuple[(("name" : String), ("zipCode" : String)), (
86+
| Expr[String], Expr[Int])]
87+
| ],
88+
| Expr.StripExpr]
89+
| ]
90+
| ]
91+
61 | val x: NamedTuple[
92+
62 | ("name", "zipCode"),
93+
63 | (Expr.Of[String], Expr.Of[Int])] = ???
94+
64 | Expr.Join(x)
95+
|
96+
| longer explanation available when compiling with `-explain`
97+
1 error found
98+
99+
*/

0 commit comments

Comments
 (0)