Skip to content

Commit de6a090

Browse files
authored
Approximate MatchTypes with lub of case bodies, if non-recursive (#19761)
2 parents 469c980 + d687dee commit de6a090

File tree

7 files changed

+65
-9
lines changed

7 files changed

+65
-9
lines changed

Diff for: compiler/src/dotty/tools/dotc/core/TypeComparer.scala

+7
Original file line numberDiff line numberDiff line change
@@ -2857,6 +2857,13 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
28572857
tp
28582858
case tp: HKTypeLambda =>
28592859
tp
2860+
case tp: ParamRef =>
2861+
val st = tp.superTypeNormalized
2862+
if st.exists then
2863+
disjointnessBoundary(st)
2864+
else
2865+
// workaround for when ParamRef#underlying returns NoType
2866+
defn.AnyType
28602867
case tp: TypeProxy =>
28612868
disjointnessBoundary(tp.superTypeNormalized)
28622869
case tp: WildcardType =>

Diff for: compiler/src/dotty/tools/dotc/typer/Typer.scala

+9-1
Original file line numberDiff line numberDiff line change
@@ -2375,7 +2375,15 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
23752375
report.error(MatchTypeScrutineeCannotBeHigherKinded(sel1Tpe), sel1.srcPos)
23762376
val pt1 = if (bound1.isEmpty) pt else bound1.tpe
23772377
val cases1 = tree.cases.mapconserve(typedTypeCase(_, sel1Tpe, pt1))
2378-
assignType(cpy.MatchTypeTree(tree)(bound1, sel1, cases1), bound1, sel1, cases1)
2378+
val bound2 = if tree.bound.isEmpty then
2379+
val lub = cases1.foldLeft(defn.NothingType: Type): (acc, case1) =>
2380+
if !acc.exists then NoType
2381+
else if case1.body.tpe.isProvisional then NoType
2382+
else acc | case1.body.tpe
2383+
if lub.exists then TypeTree(lub, inferred = true)
2384+
else bound1
2385+
else bound1
2386+
assignType(cpy.MatchTypeTree(tree)(bound2, sel1, cases1), bound2, sel1, cases1)
23792387
}
23802388

23812389
def typedByNameTypeTree(tree: untpd.ByNameTypeTree)(using Context): ByNameTypeTree = tree.result match

Diff for: tests/pos/13633.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ object Sums extends App:
2121

2222
type Reverse[A] = ReverseLoop[A, EmptyTuple]
2323

24-
type PlusTri[A, B, C] = (A, B, C) match
24+
type PlusTri[A, B, C] <: Tuple = (A, B, C) match
2525
case (false, false, false) => (false, false)
2626
case (true, false, false) | (false, true, false) | (false, false, true) => (false, true)
2727
case (true, true, false) | (true, false, true) | (false, true, true) => (true, false)

Diff for: tests/pos/Tuple.Drop.scala

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import compiletime.ops.int.*
2+
3+
type Drop[T <: Tuple, N <: Int] <: Tuple = N match
4+
case 0 => T
5+
case S[n1] => T match
6+
case EmptyTuple => EmptyTuple
7+
case x *: xs => Drop[xs, n1]

Diff for: tests/pos/Tuple.Elem.scala

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import compiletime.ops.int.*
2+
3+
type Elem[T <: Tuple, I <: Int] = T match
4+
case h *: tail =>
5+
I match
6+
case 0 => h
7+
case S[j] => Elem[tail, j]

Diff for: tests/pos/i19710.scala

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import scala.util.NotGiven
2+
3+
type HasName1 = [n] =>> [x] =>> x match {
4+
case n => true
5+
case _ => false
6+
}
7+
@main def Test = {
8+
summon[HasName1["foo"]["foo"] =:= true]
9+
summon[NotGiven[HasName1["foo"]["bar"] =:= true]]
10+
summon[Tuple.Filter[(1, "foo", 2, "bar"), HasName1["foo"]] =:= Tuple1["foo"]] // error
11+
}

Diff for: tests/run-macros/type-show/Test_2.scala

+23-7
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,34 @@
11

22
object Test {
33
import TypeToolbox.*
4+
5+
def assertEql[A](obt: A, exp: A): Unit =
6+
assert(obt == exp, s"\nexpected: $exp\nobtained: $obt")
7+
48
def main(args: Array[String]): Unit = {
59
val x = 5
6-
assert(show[x.type] == "x.type")
7-
assert(show[Nil.type] == "scala.Nil.type")
8-
assert(show[Int] == "scala.Int")
9-
assert(show[Int => Int] == "scala.Function1[scala.Int, scala.Int]")
10-
assert(show[(Int, String)] == "scala.Tuple2[scala.Int, scala.Predef.String]")
11-
assert(show[[X] =>> X match { case Int => Int }] ==
10+
assertEql(show[x.type], "x.type")
11+
assertEql(show[Nil.type], "scala.Nil.type")
12+
assertEql(show[Int], "scala.Int")
13+
assertEql(show[Int => Int], "scala.Function1[scala.Int, scala.Int]")
14+
assertEql(show[(Int, String)], "scala.Tuple2[scala.Int, scala.Predef.String]")
15+
assertEql(show[[X] =>> X match { case Int => Int }],
1216
"""[X >: scala.Nothing <: scala.Any] =>> X match {
1317
| case scala.Int => scala.Int
1418
|}""".stripMargin)
15-
assert(showStructure[[X] =>> X match { case Int => Int }] == """TypeLambda(List(X), List(TypeBounds(TypeRef(ThisType(TypeRef(NoPrefix(), "scala")), "Nothing"), TypeRef(ThisType(TypeRef(NoPrefix(), "scala")), "Any"))), MatchType(TypeRef(ThisType(TypeRef(NoPrefix(), "scala")), "Any"), ParamRef(binder, 0), List(MatchCase(TypeRef(TermRef(ThisType(TypeRef(NoPrefix(), "<root>")), "scala"), "Int"), TypeRef(TermRef(ThisType(TypeRef(NoPrefix(), "<root>")), "scala"), "Int")))))""")
19+
assertEql(showStructure[[X] =>> X match { case Int => Int }],
20+
"""TypeLambda("""+
21+
"""List(X), """+
22+
"""List(TypeBounds("""+
23+
"""TypeRef(ThisType(TypeRef(NoPrefix(), "scala")), "Nothing"), """+
24+
"""TypeRef(ThisType(TypeRef(NoPrefix(), "scala")), "Any"))), """+
25+
"""MatchType("""+
26+
"""TypeRef(TermRef(ThisType(TypeRef(NoPrefix(), "<root>")), "scala"), "Int"), """+ // match type bound
27+
"""ParamRef(binder, 0), """+
28+
"""List("""+
29+
"""MatchCase("""+
30+
"""TypeRef(TermRef(ThisType(TypeRef(NoPrefix(), "<root>")), "scala"), "Int"), """+
31+
"""TypeRef(TermRef(ThisType(TypeRef(NoPrefix(), "<root>")), "scala"), "Int")))))""")
1632

1733
// TODO: more complex types:
1834
// - implicit function types

0 commit comments

Comments
 (0)