Skip to content

Commit 457a463

Browse files
authored
Improve GADT reasoning for pattern alternatives (#23205)
fixes #22882. Previously, for a pattern alternative: ``` scrutinee match case P1(...) | P2(...) => ... ``` we ignore GADT constraints derived from it, which is too restrictive. Now we try to find a GADT constraint that is subsumed by all and use it, and only ignore GADT constraints when such a constraint cannot be found.
2 parents e8e3903 + ff3df3c commit 457a463

9 files changed

+160
-4
lines changed

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import cc.*
2626
import Capabilities.Capability
2727
import NameKinds.WildcardParamName
2828
import MatchTypes.isConcrete
29+
import scala.util.boundary, boundary.break
2930

3031
/** Provides methods to compare types.
3132
*/
@@ -2090,6 +2091,45 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
20902091
else op2
20912092
end necessaryEither
20922093

2094+
/** Finds the necessary (the weakest) GADT constraint among a list of them.
2095+
* It returns the one being subsumed by all others if exists, and `None` otherwise.
2096+
*
2097+
* This is used when typechecking pattern alternatives, for instance:
2098+
*
2099+
* enum Expr[+T]:
2100+
* case I1(x: Int) extends Expr[Int]
2101+
* case I2(x: Int) extends Expr[Int]
2102+
* case B(x: Boolean) extends Expr[Boolean]
2103+
* import Expr.*
2104+
*
2105+
* The following function should compile:
2106+
*
2107+
* def foo[T](e: Expr[T]): T = e match
2108+
* case I1(_) | I2(_) => 42
2109+
*
2110+
* since `T >: Int` is subsumed by both alternatives in the first match clause.
2111+
*
2112+
* However, the following should not:
2113+
*
2114+
* def foo[T](e: Expr[T]): T = e match
2115+
* case I1(_) | B(_) => 42
2116+
*
2117+
* since the `I1(_)` case gives the constraint `T >: Int` while `B(_)` gives `T >: Boolean`.
2118+
* Neither of the constraints is subsumed by the other.
2119+
*/
2120+
def necessaryGadtConstraint(constrs: List[GadtConstraint], preGadt: GadtConstraint)(using Context): Option[GadtConstraint] = boundary:
2121+
constrs match
2122+
case Nil => break(None)
2123+
case c0 :: constrs =>
2124+
var weakest = c0
2125+
for c <- constrs do
2126+
if subsumes(weakest.constraint, c.constraint, preGadt.constraint) then
2127+
weakest = c
2128+
else if !subsumes(c.constraint, weakest.constraint, preGadt.constraint) then
2129+
// this two constraints are disjoint
2130+
break(None)
2131+
break(Some(weakest))
2132+
20932133
inline def rollbackConstraintsUnless(inline op: Boolean): Boolean =
20942134
val saved = constraint
20952135
var result = false
@@ -3449,6 +3489,9 @@ object TypeComparer {
34493489
def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false)(using Context): Boolean =
34503490
comparing(_.constrainPatternType(pat, scrut, forceInvariantRefinement))
34513491

3492+
def necessaryGadtConstraint(constrs: List[GadtConstraint], preGadt: GadtConstraint)(using Context): Option[GadtConstraint] =
3493+
comparing(_.necessaryGadtConstraint(constrs, preGadt))
3494+
34523495
def explained[T](op: ExplainingTypeComparer => T, header: String = "Subtype trace:", short: Boolean = false)(using Context): String =
34533496
comparing(_.explained(op, header, short))
34543497

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2827,10 +2827,20 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
28272827
else
28282828
assert(ctx.reporter.errorsReported)
28292829
tree.withType(defn.AnyType)
2830-
val savedGadt = nestedCtx.gadt
2831-
val trees1 = tree.trees.mapconserve(typed(_, pt)(using nestedCtx))
2830+
val preGadt = nestedCtx.gadt
2831+
var gadtConstrs: mutable.ArrayBuffer[GadtConstraint] = mutable.ArrayBuffer.empty
2832+
val trees1 = tree.trees.mapconserve: t =>
2833+
nestedCtx.gadtState.restore(preGadt)
2834+
val res = typed(t, pt)(using nestedCtx)
2835+
gadtConstrs += nestedCtx.gadt
2836+
res
28322837
.mapconserve(ensureValueTypeOrWildcard)
2833-
nestedCtx.gadtState.restore(savedGadt) // Disable GADT reasoning for pattern alternatives
2838+
// Look for the necessary constraint that is subsumed by all alternatives.
2839+
// Use that constraint as the outcome if possible, otherwise fallback to not using
2840+
// GADT reasoning for soundness.
2841+
TypeComparer.necessaryGadtConstraint(gadtConstrs.toList, preGadt) match
2842+
case Some(constr) => nestedCtx.gadtState.restore(constr)
2843+
case None => nestedCtx.gadtState.restore(preGadt)
28342844
assignType(cpy.Alternative(tree)(trees1), trees1)
28352845
}
28362846

tests/neg/gadt-alt-expr1.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
enum Expr[+T]:
2+
case I1() extends Expr[Int]
3+
case I2() extends Expr[Int]
4+
case B() extends Expr[Boolean]
5+
import Expr.*
6+
def foo[T](e: Expr[T]): T =
7+
e match
8+
case I1() | I2() => 42 // ok
9+
case B() => true
10+
def bar[T](e: Expr[T]): T =
11+
e match
12+
case I1() | B() => 42 // error
13+
case I2() => 0

tests/neg/gadt-alt-expr2.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
enum Expr[+T]:
2+
case I1() extends Expr[Int]
3+
case I2() extends Expr[Int]
4+
case I3() extends Expr[Int]
5+
case I4() extends Expr[Int]
6+
case I5() extends Expr[Int]
7+
case B() extends Expr[Boolean]
8+
import Expr.*
9+
def test1[T](e: Expr[T]): T =
10+
e match
11+
case I1() | I2() | I3() | I4() | I5() => 42 // ok
12+
case B() => true
13+
def test2[T](e: Expr[T]): T =
14+
e match
15+
case I1() | I2() | I3() | I4() | I5() | B() => 42 // error

tests/neg/gadt-alt-expr3.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
trait A
2+
trait B extends A
3+
trait C extends B
4+
trait D
5+
enum Expr[+T]:
6+
case IsA() extends Expr[A]
7+
case IsB() extends Expr[B]
8+
case IsC() extends Expr[C]
9+
case IsD() extends Expr[D]
10+
import Expr.*
11+
def test1[T](e: Expr[T]): T = e match
12+
case IsA() => new A {}
13+
case IsB() => new B {}
14+
case IsC() => new C {}
15+
def test2[T](e: Expr[T]): T = e match
16+
case IsA() | IsB() =>
17+
// IsA() implies T >: A
18+
// IsB() implies T >: B
19+
// So T >: B is chosen
20+
new B {}
21+
case IsC() => new C {}
22+
def test3[T](e: Expr[T]): T = e match
23+
case IsA() | IsB() | IsC() =>
24+
// T >: C is chosen
25+
new C {}
26+
def test4[T](e: Expr[T]): T = e match
27+
case IsA() | IsB() | IsC() =>
28+
new B {} // error
29+
def test5[T](e: Expr[T]): T = e match
30+
case IsA() | IsB() =>
31+
new A {} // error
32+
def test6[T](e: Expr[T]): T = e match
33+
case IsA() | IsC() | IsD() =>
34+
new C {} // error

tests/neg/gadt-alt-expr4.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
trait A
2+
trait B extends A
3+
trait C extends B
4+
enum Expr[T]:
5+
case IsA() extends Expr[A]
6+
case IsB() extends Expr[B]
7+
case IsC() extends Expr[C]
8+
import Expr.*
9+
def test1[T](e: Expr[T]): T = e match
10+
case IsA() => new A {}
11+
case IsB() => new B {}
12+
case IsC() => new C {}
13+
def test2[T](e: Expr[T]): T = e match
14+
case IsA() | IsB() =>
15+
// IsA() implies T =:= A
16+
// IsB() implies T =:= B
17+
// No necessary constraint can be found
18+
new B {} // error
19+
case IsC() => new C {}

tests/neg/gadt-alt-expr5.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
trait A
2+
trait B extends A
3+
trait C extends B
4+
enum Expr[-T]:
5+
case IsA() extends Expr[A]
6+
case IsB() extends Expr[B]
7+
case IsC() extends Expr[C]
8+
import Expr.*
9+
def test1[T](e: Expr[T]): Unit = e match
10+
case IsA() | IsB() =>
11+
val t1: T = ???
12+
val t2: A = t1
13+
val t3: B = t1 // error
14+
case IsC() =>

tests/neg/gadt-alternatives.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@ import Expr.*
66
def eval[T](e: Expr[T]): T = e match
77
case StringVal(_) | IntVal(_) => "42" // error
88
def eval1[T](e: Expr[T]): T = e match
9-
case IntValAlt(_) | IntVal(_) => 42 // error // limitation
9+
case IntValAlt(_) | IntVal(_) => 42 // previously error, now ok

tests/pos/gadt-alt-doc1.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
trait Document[Doc <: Document[Doc]]
2+
sealed trait Conversion[Doc, V]
3+
4+
case class C[Doc <: Document[Doc]]() extends Conversion[Doc, Doc]
5+
6+
def Test[Doc <: Document[Doc], V](conversion: Conversion[Doc, V]) =
7+
conversion match
8+
case C() | C() => ??? // error

0 commit comments

Comments
 (0)