Skip to content

Commit cce2933

Browse files
Instantiate argument type vars before implicit search (#19096)
1 parent 1716bcd commit cce2933

File tree

4 files changed

+70
-10
lines changed

4 files changed

+70
-10
lines changed

Diff for: compiler/src/dotty/tools/dotc/typer/Inferencing.scala

+14-9
Original file line numberDiff line numberDiff line change
@@ -383,17 +383,20 @@ object Inferencing {
383383
def isSkolemFree(tp: Type)(using Context): Boolean =
384384
!tp.existsPart(_.isInstanceOf[SkolemType])
385385

386-
/** The list of uninstantiated type variables bound by some prefix of type `T` which
387-
* occur in at least one formal parameter type of a prefix application.
386+
/** The list of uninstantiated type variables bound by some prefix of type `T` or
387+
* by arguments of an application prefix, which occur at least once as a formal type parameter
388+
* of an application either from a prefix or an argument of an application node.
388389
* Considered prefixes are:
389390
* - The function `f` of an application node `f(e1, .., en)`
390391
* - The function `f` of a type application node `f[T1, ..., Tn]`
391392
* - The prefix `p` of a selection `p.f`.
392393
* - The result expression `e` of a block `{s1; .. sn; e}`.
393394
*/
394395
def tvarsInParams(tree: Tree, locked: TypeVars)(using Context): List[TypeVar] = {
395-
@tailrec def boundVars(tree: Tree, acc: List[TypeVar]): List[TypeVar] = tree match {
396-
case Apply(fn, _) => boundVars(fn, acc)
396+
def boundVars(tree: Tree, acc: List[TypeVar]): List[TypeVar] = tree match {
397+
case Apply(fn, args) =>
398+
val argTpVars = args.flatMap(boundVars(_, Nil))
399+
boundVars(fn, acc ++ argTpVars)
397400
case TypeApply(fn, targs) =>
398401
val tvars = targs.filter(_.isInstanceOf[InferredTypeTree]).tpes.collect {
399402
case tvar: TypeVar
@@ -406,16 +409,18 @@ object Inferencing {
406409
case Block(_, expr) => boundVars(expr, acc)
407410
case _ => acc
408411
}
409-
@tailrec def occurring(tree: Tree, toTest: List[TypeVar], acc: List[TypeVar]): List[TypeVar] =
412+
def occurring(tree: Tree, toTest: List[TypeVar], acc: List[TypeVar]): List[TypeVar] =
410413
if (toTest.isEmpty) acc
411414
else tree match {
412-
case Apply(fn, _) =>
415+
case Apply(fn, args) =>
416+
val argsOcc = args.flatMap(occurring(_, toTest, Nil))
417+
val argsNocc = toTest.filterNot(argsOcc.contains)
413418
fn.tpe.widen match {
414419
case mtp: MethodType =>
415-
val (occ, nocc) = toTest.partition(tvar => mtp.paramInfos.exists(tvar.occursIn))
416-
occurring(fn, nocc, occ ::: acc)
420+
val (occ, nocc) = argsNocc.partition(tvar => mtp.paramInfos.exists(tvar.occursIn))
421+
occurring(fn, nocc, occ ::: argsOcc ::: acc)
417422
case _ =>
418-
occurring(fn, toTest, acc)
423+
occurring(fn, argsNocc, argsOcc ::: acc)
419424
}
420425
case TypeApply(fn, targs) => occurring(fn, toTest, acc)
421426
case Select(pre, _) => occurring(pre, toTest, acc)

Diff for: tests/pos/i18578.scala

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
2+
trait Animal
3+
class Dog extends Animal
4+
5+
trait Ev[T]
6+
7+
given Ev[Dog] = ???
8+
given Ev[Animal] = ???
9+
given[T: Ev]: Ev[Set[T]] = ???
10+
11+
def f[T: Ev](v: T): Unit = ???
12+
13+
def main =
14+
val s = Set(new Dog)
15+
f(s) // Ok
16+
f(Set(new Dog)) // Error before changes: Ambiguous given instances: both given instance given_Ev_Dog and given instance given_Ev_Animal match type Ev[T]
17+

Diff for: tests/pos/i7586.scala

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
2+
trait Nat
3+
case object Z extends Nat
4+
case class S[N <: Nat](pred: N) extends Nat
5+
6+
type Z = Z.type
7+
given zero: Z = Z
8+
given succ[N <: Nat](using n: N): S[N] = S(n)
9+
10+
case class Sum[N <: Nat, M <: Nat, R <: Nat](result: R)
11+
12+
given sumZ[N <: Nat](using n: N): Sum[Z, N, N] = Sum(n)
13+
given sumS[N <: Nat, M <: Nat, R <: Nat](
14+
using sum: Sum[N, M, R]
15+
): Sum[S[N], M, S[R]] = Sum(S(sum.result))
16+
17+
def add[N <: Nat, M <: Nat, R <: Nat](n: N, m: M)(
18+
using sum: Sum[N, M, R]
19+
): R = sum.result
20+
21+
case class Prod[N <: Nat, M <: Nat, R <: Nat](result: R)
22+
23+
24+
@main def Test: Unit =
25+
26+
val n1: S[Z] = add(Z, S(Z))
27+
summon[n1.type <:< S[Z]] // OK
28+
29+
val n3: S[S[S[Z]]] = add(S(S(Z)), S(Z))
30+
summon[n3.type <:< S[S[S[Z]]]] // Ok
31+
32+
val m3_2 = add(S(Z), S(S(Z)))
33+
summon[m3_2.type <:< S[S[S[Z]]]] // Error before changes: Cannot prove that (m3_2 : S[S[Nat]]) <:< S[S[S[Z]]]
34+
35+
val m4_2 = add(S(Z), S(S(S(Z))))
36+
summon[m4_2.type <:< S[S[S[S[Z]]]]]
37+
38+

0 commit comments

Comments
 (0)