diff --git a/src/dotty/tools/dotc/core/ConstraintHandling.scala b/src/dotty/tools/dotc/core/ConstraintHandling.scala index 544304e8a18d..6f0377a4d8de 100644 --- a/src/dotty/tools/dotc/core/ConstraintHandling.scala +++ b/src/dotty/tools/dotc/core/ConstraintHandling.scala @@ -104,7 +104,7 @@ trait ConstraintHandling { up.forall(addOneBound(_, lo, isUpper = false)) } - protected final def isSubTypeWhenFrozen(tp1: Type, tp2: Type): Boolean = { + final def isSubTypeWhenFrozen(tp1: Type, tp2: Type): Boolean = { val saved = frozenConstraint frozenConstraint = true try isSubType(tp1, tp2) diff --git a/src/dotty/tools/dotc/typer/Inferencing.scala b/src/dotty/tools/dotc/typer/Inferencing.scala index 8df544dd62c6..0a76f45c5089 100644 --- a/src/dotty/tools/dotc/typer/Inferencing.scala +++ b/src/dotty/tools/dotc/typer/Inferencing.scala @@ -17,6 +17,7 @@ import Decorators._ import Uniques._ import ErrorReporting.{errorType, DiagnosticString} import config.Printers._ +import annotation.tailrec import collection.mutable trait Inferencing { this: Checking => @@ -43,9 +44,26 @@ trait Inferencing { this: Checking => if (isFullyDefined(tp, ForceDegree.all)) tp else throw new Error(i"internal error: type of $what $tp is not fully defined, pos = $pos") // !!! DEBUG + + /** Instantiate selected type variables `tvars` in type `tp` */ + def instantiateSelected(tp: Type, tvars: List[Type])(implicit ctx: Context): Unit = + new IsFullyDefinedAccumulator(new ForceDegree.Value(tvars.contains)).process(tp) + /** The accumulator which forces type variables using the policy encoded in `force` - * and returns whether the type is fully defined. Two phases: - * 1st Phase: Try to instantiate covariant and non-variant type variables to + * and returns whether the type is fully defined. The direction in which + * a type variable is instantiated is determined as follows: + * 1. T is minimized if the constraint over T is only from below (i.e. + * constrained lower bound != given lower bound and + * constrained upper bound == given upper bound). + * 2. T is maximized if the constraint over T is only from above (i.e. + * constrained upper bound != given upper bound and + * constrained lower bound == given lower bound). + * If (1) and (2) do not apply: + * 3. T is maximized if it appears only contravariantly in the given type. + * 4. T is minimized in all other cases. + * + * The instantiation is done in two phases: + * 1st Phase: Try to instantiate minimizable type variables to * their lower bound. Record whether successful. * 2nd Phase: If first phase was successful, instantiate all remaining type variables * to their upper bound. @@ -61,14 +79,20 @@ trait Inferencing { this: Checking => case _: WildcardType | _: ProtoType => false case tvar: TypeVar if !tvar.isInstantiated => - if (force == ForceDegree.none) false - else { - val minimize = - variance >= 0 && !( - force == ForceDegree.noBottom && - isBottomType(ctx.typeComparer.approximation(tvar.origin, fromBelow = true))) - if (minimize) instantiate(tvar, fromBelow = true) - else toMaximize = true + force.appliesTo(tvar) && { + val direction = instDirection(tvar.origin) + if (direction != 0) { + if (direction > 0) println(s"inst $tvar dir = up") + instantiate(tvar, direction < 0) + } + else { + val minimize = + variance >= 0 && !( + force == ForceDegree.noBottom && + isBottomType(ctx.typeComparer.approximation(tvar.origin, fromBelow = true))) + if (minimize) instantiate(tvar, fromBelow = true) + else toMaximize = true + } foldOver(x, tvar) } case tp => @@ -93,6 +117,62 @@ trait Inferencing { this: Checking => } } + /** The list of uninstantiated type variables bound by some prefix of type `T` which + * occur in at least one formal parameter type of a prefix application. + * Considered prefixes are: + * - The function `f` of an application node `f(e1, .., en)` + * - The function `f` of a type application node `f[T1, ..., Tn]` + * - The prefix `p` of a selection `p.f`. + * - The result expression `e` of a block `{s1; .. sn; e}`. + */ + def tvarsInParams(tree: Tree)(implicit ctx: Context): List[TypeVar] = { + @tailrec def boundVars(tree: Tree, acc: List[TypeVar]): List[TypeVar] = tree match { + case Apply(fn, _) => boundVars(fn, acc) + case TypeApply(fn, targs) => + val tvars = targs.tpes.collect { + case tvar: TypeVar if !tvar.isInstantiated => tvar + } + boundVars(fn, acc ::: tvars) + case Select(pre, _) => boundVars(pre, acc) + case Block(_, expr) => boundVars(expr, acc) + case _ => acc + } + @tailrec def occurring(tree: Tree, toTest: List[TypeVar], acc: List[TypeVar]): List[TypeVar] = + if (toTest.isEmpty) acc + else tree match { + case Apply(fn, _) => + fn.tpe match { + case mtp: MethodType => + val (occ, nocc) = toTest.partition(tvar => mtp.paramTypes.exists(tvar.occursIn)) + occurring(fn, nocc, occ ::: acc) + case _ => + occurring(fn, toTest, acc) + } + case TypeApply(fn, targs) => occurring(fn, toTest, acc) + case Select(pre, _) => occurring(pre, toTest, acc) + case Block(_, expr) => occurring(expr, toTest, acc) + case _ => acc + } + occurring(tree, boundVars(tree, Nil), Nil) + } + + /** The instantiation direction for given poly param computed + * from the constraint: + * @return 1 (maximize) if constraint is uniformly from above, + * -1 (minimize) if constraint is uniformly from below, + * 0 if unconstrained, or constraint is from below and above. + */ + private def instDirection(param: PolyParam)(implicit ctx: Context): Int = { + val constrained = ctx.typerState.constraint.fullBounds(param) + val original = param.binder.paramBounds(param.paramNum) + val cmp = ctx.typeComparer + val approxBelow = + if (!cmp.isSubTypeWhenFrozen(constrained.lo, original.lo)) 1 else 0 + val approxAbove = + if (!cmp.isSubTypeWhenFrozen(original.hi, constrained.hi)) 1 else 0 + approxAbove - approxBelow + } + def isBottomType(tp: Type)(implicit ctx: Context) = tp == defn.NothingType || tp == defn.NullType @@ -257,9 +337,10 @@ trait Inferencing { this: Checking => } /** An enumeration controlling the degree of forcing in "is-dully-defined" checks. */ -@sharable object ForceDegree extends Enumeration { - val none, // don't force type variables - noBottom, // force type variables, fail if forced to Nothing or Null - all = Value // force type variables, don't fail +@sharable object ForceDegree { + class Value(val appliesTo: TypeVar => Boolean) + val none = new Value(_ => false) + val all = new Value(_ => true) + val noBottom = new Value(_ => true) } diff --git a/src/dotty/tools/dotc/typer/Typer.scala b/src/dotty/tools/dotc/typer/Typer.scala index a2c49cdd9751..4dfd69203323 100644 --- a/src/dotty/tools/dotc/typer/Typer.scala +++ b/src/dotty/tools/dotc/typer/Typer.scala @@ -1312,6 +1312,8 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit case wtp: ExprType => adaptInterpolated(tree.withType(wtp.resultType), pt, original) case wtp: ImplicitMethodType if constrainResult(wtp, followAlias(pt)) => + val tvarsToInstantiate = tvarsInParams(tree) + wtp.paramTypes.foreach(instantiateSelected(_, tvarsToInstantiate)) val constr = ctx.typerState.constraint def addImplicitArgs = { def implicitArgError(msg: => String): Tree = { diff --git a/tests/neg/i739.scala b/tests/neg/i739.scala new file mode 100644 index 000000000000..5385fa42cced --- /dev/null +++ b/tests/neg/i739.scala @@ -0,0 +1,7 @@ +class Foo[A, B] +class Test { + implicit val f: Foo[Int, String] = ??? + def t[A, B >: A](a: A)(implicit f: Foo[A, B]) = ??? + t(1) // error +} + diff --git a/tests/pos/i739.scala b/tests/pos/i739.scala new file mode 100644 index 000000000000..61fed4e5dee4 --- /dev/null +++ b/tests/pos/i739.scala @@ -0,0 +1,17 @@ +class Foo + +object Test { + def foo[T](x: T)(implicit ev: T): T = ??? + + class Fn[T] { + def invoke(implicit ev: T): T = ??? + } + + def bar[T](x: T): Fn[T] = ??? + + def test: Unit = { + implicit val evidence: Foo = new Foo + foo(new Foo) + bar(new Foo).invoke + } +} diff --git a/tests/run/liftedTry.scala b/tests/run/liftedTry.scala index 5ff4add6dee3..ff9af98eca32 100644 --- a/tests/run/liftedTry.scala +++ b/tests/run/liftedTry.scala @@ -12,7 +12,7 @@ object Test { foo(try 3 catch handle) - def main(args: Array[String]): Unit = { + def main(args: Array[String]) = { assert(x == 1) assert(foo(2) == 2) assert(foo(try raise(3) catch handle) == 3) @@ -20,7 +20,6 @@ object Test { } } - object Tr { def fun(a: Int => Unit) = a(2) def foo: Int = {