Skip to content

Commit 98efdab

Browse files
authored
Add GADT symbols when typing typing-ahead lambda bodies (#19644)
2 parents 119bc33 + 2c81588 commit 98efdab

File tree

4 files changed

+75
-9
lines changed

4 files changed

+75
-9
lines changed

Diff for: compiler/src/dotty/tools/dotc/typer/Namer.scala

+14-9
Original file line numberDiff line numberDiff line change
@@ -1738,8 +1738,9 @@ class Namer { typer: Typer =>
17381738
val tpe = (paramss: @unchecked) match
17391739
case TypeSymbols(tparams) :: TermSymbols(vparams) :: Nil => tpFun(tparams, vparams)
17401740
case TermSymbols(vparams) :: Nil => tpFun(Nil, vparams)
1741+
val rhsCtx = prepareRhsCtx(ctx.fresh, paramss)
17411742
if (isFullyDefined(tpe, ForceDegree.none)) tpe
1742-
else typedAheadExpr(mdef.rhs, tpe).tpe
1743+
else typedAheadExpr(mdef.rhs, tpe)(using rhsCtx).tpe
17431744

17441745
case TypedSplice(tpt: TypeTree) if !isFullyDefined(tpt.tpe, ForceDegree.none) =>
17451746
mdef match {
@@ -1937,14 +1938,7 @@ class Namer { typer: Typer =>
19371938
var rhsCtx = ctx.fresh.addMode(Mode.InferringReturnType)
19381939
if sym.isInlineMethod then rhsCtx = rhsCtx.addMode(Mode.InlineableBody)
19391940
if sym.is(ExtensionMethod) then rhsCtx = rhsCtx.addMode(Mode.InExtensionMethod)
1940-
val typeParams = paramss.collect { case TypeSymbols(tparams) => tparams }.flatten
1941-
if (typeParams.nonEmpty) {
1942-
// we'll be typing an expression from a polymorphic definition's body,
1943-
// so we must allow constraining its type parameters
1944-
// compare with typedDefDef, see tests/pos/gadt-inference.scala
1945-
rhsCtx.setFreshGADTBounds
1946-
rhsCtx.gadtState.addToConstraint(typeParams)
1947-
}
1941+
rhsCtx = prepareRhsCtx(rhsCtx, paramss)
19481942

19491943
def typedAheadRhs(pt: Type) =
19501944
PrepareInlineable.dropInlineIfError(sym,
@@ -1989,4 +1983,15 @@ class Namer { typer: Typer =>
19891983
lhsType orElse WildcardType
19901984
}
19911985
end inferredResultType
1986+
1987+
/** Prepare a GADT-aware context used to type the RHS of a ValOrDefDef. */
1988+
def prepareRhsCtx(rhsCtx: FreshContext, paramss: List[List[Symbol]])(using Context): FreshContext =
1989+
val typeParams = paramss.collect { case TypeSymbols(tparams) => tparams }.flatten
1990+
if typeParams.nonEmpty then
1991+
// we'll be typing an expression from a polymorphic definition's body,
1992+
// so we must allow constraining its type parameters
1993+
// compare with typedDefDef, see tests/pos/gadt-inference.scala
1994+
rhsCtx.setFreshGADTBounds
1995+
rhsCtx.gadtState.addToConstraint(typeParams)
1996+
rhsCtx
19921997
}

Diff for: tests/pos/i19570.min1.scala

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
enum Op[A]:
2+
case Dup[T]() extends Op[(T, T)]
3+
4+
def foo[R](f: [A] => Op[A] => R): R = ???
5+
6+
def test =
7+
foo([A] => (o: Op[A]) => o match
8+
case o: Op.Dup[u] =>
9+
summon[A =:= (u, u)] // Error: Cannot prove that A =:= (u, u)
10+
()
11+
)
12+
foo[Unit]([A] => (o: Op[A]) => o match
13+
case o: Op.Dup[u] =>
14+
summon[A =:= (u, u)] // Ok
15+
()
16+
)
17+
foo({
18+
val f1 = [B] => (o: Op[B]) => o match
19+
case o: Op.Dup[u] =>
20+
summon[B =:= (u, u)] // Also ok
21+
()
22+
f1
23+
})

Diff for: tests/pos/i19570.min2.scala

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
sealed trait Op[A, B] { def giveA: A; def giveB: B }
2+
final case class Dup[T](x: T) extends Op[T, (T, T)] { def giveA: T = x; def giveB: (T, T) = (x, x) }
3+
4+
class Test:
5+
def foo[R](f: [A, B] => (o: Op[A, B]) => R): R = ???
6+
7+
def m1: Unit =
8+
foo([A, B] => (o: Op[A, B]) => o match
9+
case o: Dup[t] =>
10+
var a1: t = o.giveA
11+
var a2: A = o.giveA
12+
a1 = a2
13+
a2 = a1
14+
15+
var b1: (t, t) = o.giveB
16+
var b2: B = o.giveB
17+
b1 = b2
18+
b2 = b1
19+
20+
summon[A =:= t] // ERROR: Cannot prove that A =:= t.
21+
summon[B =:= (t, t)] // ERROR: Cannot prove that B =:= (t, t).
22+
23+
()
24+
)

Diff for: tests/pos/i19570.orig.scala

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
enum Op[A, B]:
2+
case Dup[T]() extends Op[T, (T, T)]
3+
4+
def foo[R](f: [A, B] => (o: Op[A, B]) => R): R =
5+
f(Op.Dup())
6+
7+
def test =
8+
foo([A, B] => (o: Op[A, B]) => {
9+
o match
10+
case o: Op.Dup[t] =>
11+
summon[A =:= t] // ERROR: Cannot prove that A =:= t.
12+
summon[B =:= (t, t)] // ERROR: Cannot prove that B =:= (t, t).
13+
42
14+
})

0 commit comments

Comments
 (0)