Skip to content

Commit d75e0b2

Browse files
committed
Fixes for Metals infer expected type feature
Fixes to makeVarArg and assignType(SeqLiteral) fixes the propagation of vararg type errors, which would trip up InferExpectedType's looking at whether the tree's type are errors. Force LazyRef in AvoidWildcardsMap to avoid unneeded new LazyRef's. LazyRefs, which are created for recursive types, aren't cacheable, so if you TypeMap an AppliedType with one, it will create a brand new instance. OrderingConstraint#init runs AvoidWildcardsMap on param bounds. So when instDirection compares the constraint bounds and the original param bounds (to calculate the instantiate direction), because they are new instances they won't shortcircuit, leading to a recursion overflow. By forcing, it will eq check and return true. Finally, change interpolateTypeVars' instantiation decision, following the logic that isFullyDefined had. In particular, if the typevar has an upper bound constraint, maximise rather than minimise, which fixes the inference of map/flatMap's B type args. Finally, drop needless tree.tpe.isInstanceOf[MethodOrPoly] and tree.tpe.widen in interpolateTypeVars.
1 parent ca19d4a commit d75e0b2

File tree

8 files changed

+86
-87
lines changed

8 files changed

+86
-87
lines changed

Diff for: compiler/src/dotty/tools/dotc/core/Types.scala

+4
Original file line numberDiff line numberDiff line change
@@ -4691,6 +4691,7 @@ object Types extends TypeUtils {
46914691
type BT <: LambdaType
46924692
def paramNum: Int
46934693
def paramName: binder.ThisName = binder.paramNames(paramNum)
4694+
def paramInfo: binder.PInfo = binder.paramInfos(paramNum)
46944695

46954696
override def underlying(using Context): Type = {
46964697
// TODO: update paramInfos's type to nullable
@@ -6629,6 +6630,9 @@ object Types extends TypeUtils {
66296630
range(atVariance(-variance)(apply(bounds.lo)), apply(bounds.hi))
66306631
def apply(t: Type): Type = t match
66316632
case t: WildcardType => mapWild(t)
6633+
case tp: LazyRef => mapOver(tp) match
6634+
case tp1: LazyRef if tp.ref eq tp1.ref => tp
6635+
case tp1 => tp1
66326636
case _ => mapOver(t)
66336637

66346638
// ----- TypeAccumulators ----------------------------------------------------

Diff for: compiler/src/dotty/tools/dotc/typer/Applications.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -883,7 +883,7 @@ trait Applications extends Compatibility {
883883
def makeVarArg(n: Int, elemFormal: Type): Unit = {
884884
val args = typedArgBuf.takeRight(n).toList
885885
typedArgBuf.dropRightInPlace(n)
886-
val elemtpt = TypeTree(elemFormal)
886+
val elemtpt = TypeTree(if !args.exists(_.tpe.isError) then elemFormal else UnspecifiedErrorType)
887887
typedArgBuf += seqToRepeated(SeqLiteral(args, elemtpt))
888888
}
889889

Diff for: compiler/src/dotty/tools/dotc/typer/Inferencing.scala

+31-42
Original file line numberDiff line numberDiff line change
@@ -240,25 +240,12 @@ object Inferencing {
240240
&& {
241241
var fail = false
242242
var skip = false
243-
val direction = instDirection(tvar.origin)
244-
if minimizeSelected then
245-
if direction <= 0 && tvar.hasLowerBound then
246-
skip = instantiate(tvar, fromBelow = true)
247-
else if direction >= 0 && tvar.hasUpperBound then
248-
skip = instantiate(tvar, fromBelow = false)
249-
// else hold off instantiating unbounded unconstrained variable
250-
else if direction != 0 then
251-
skip = instantiate(tvar, fromBelow = direction < 0)
252-
else if variance >= 0 && tvar.hasLowerBound then
253-
skip = instantiate(tvar, fromBelow = true)
254-
else if (variance > 0 || variance == 0 && !tvar.hasUpperBound)
255-
&& force.ifBottom == IfBottom.ok
256-
then // if variance == 0, prefer upper bound if one is given
257-
skip = instantiate(tvar, fromBelow = true)
258-
else if variance >= 0 && force.ifBottom == IfBottom.fail then
259-
fail = true
260-
else
261-
toMaximize = tvar :: toMaximize
243+
instDecision(tvar.origin, variance, minimizeSelected, force.ifBottom) match
244+
case Decision.Min => skip = instantiate(tvar, fromBelow = true)
245+
case Decision.Max => skip = instantiate(tvar, fromBelow = false)
246+
case Decision.Skip => // hold off instantiating unbounded unconstrained variable
247+
case Decision.Fail => fail = true
248+
case Decision.ToMax => toMaximize ::= tvar
262249
!fail && (skip || foldOver(x, tvar))
263250
}
264251
case tp => foldOver(x, tp)
@@ -438,22 +425,20 @@ object Inferencing {
438425
occurring(tree, boundVars(tree, Nil), Nil)
439426
}
440427

441-
/** The instantiation direction for given poly param computed
442-
* from the constraint:
443-
* @return 1 (maximize) if constraint is uniformly from above,
444-
* -1 (minimize) if constraint is uniformly from below,
445-
* 0 if unconstrained, or constraint is from below and above.
446-
*/
447-
private def instDirection(param: TypeParamRef)(using Context): Int = {
448-
val constrained = TypeComparer.fullBounds(param)
449-
val original = param.binder.paramInfos(param.paramNum)
450-
val cmp = TypeComparer
451-
val approxBelow =
452-
if (!cmp.isSubTypeWhenFrozen(constrained.lo, original.lo)) 1 else 0
453-
val approxAbove =
454-
if (!cmp.isSubTypeWhenFrozen(original.hi, constrained.hi)) 1 else 0
455-
approxAbove - approxBelow
456-
}
428+
/** The instantiation decision for given poly param computed from the constraint. */
429+
enum Decision { case Min; case Max; case ToMax; case Skip; case Fail }
430+
private def instDecision(param: TypeParamRef, v: Int, min: Boolean, ifBottom: IfBottom)(using Context): Decision =
431+
import Decision.*
432+
val tb = param.paramInfo // type bounds
433+
val cb = TypeComparer.fullBounds(param) // constrained bounds
434+
val dir = (if cb.lo frozen_<:< tb.lo then 0 else -1) + (if tb.hi frozen_<:< cb.hi then 0 else 1)
435+
if dir < 0 || (min || v >= 0) && !cb.lo.isExactlyNothing then Min
436+
else if dir > 0 || (min || v == 0) && !cb.hi.isTopOfSomeKind then Max // prefer upper bound if one is given
437+
else if min then Skip
438+
else ifBottom match
439+
case IfBottom.ok => if v >= 0 then Min else ToMax
440+
case IfBottom.fail => if v >= 0 then Fail else ToMax
441+
case ifBottom_flip => ToMax
457442

458443
/** Following type aliases and stripping refinements and annotations, if one arrives at a
459444
* class type reference where the class has a companion module, a reference to
@@ -651,16 +636,17 @@ trait Inferencing { this: Typer =>
651636

652637
val ownedVars = state.ownedVars
653638
if (ownedVars ne locked) && !ownedVars.isEmpty then
654-
val qualifying = ownedVars -- locked
639+
val qualifying = (ownedVars -- locked).toList
655640
if (!qualifying.isEmpty) {
656-
typr.println(i"interpolate $tree: ${tree.tpe.widen} in $state, pt = $pt, owned vars = ${state.ownedVars.toList}%, %, qualifying = ${qualifying.toList}%, %, previous = ${locked.toList}%, % / ${state.constraint}")
641+
typr.println(i"interpolate $tree: ${tree.tpe.widen} in $state, pt = $pt, owned vars = ${ownedVars.toList}, qualifying = $qualifying.toList}, previous = ${locked.toList}%, % / ${state.constraint}")
657642
val resultAlreadyConstrained =
658-
tree.isInstanceOf[Apply] || tree.tpe.isInstanceOf[MethodOrPoly]
643+
tree.isInstanceOf[Apply]
659644
if (!resultAlreadyConstrained)
645+
trace(i"constrainResult($tree ${tree.symbol}, ${tree.tpe}, $pt)"):
660646
constrainResult(tree.symbol, tree.tpe, pt)
661647
// This is needed because it could establish singleton type upper bounds. See i2998.scala.
662648

663-
val tp = tree.tpe.widen
649+
val tp = tree.tpe
664650
val vs = variances(tp, pt)
665651

666652
// Avoid interpolating variables occurring in tree's type if typerstate has unreported errors.
@@ -687,6 +673,8 @@ trait Inferencing { this: Typer =>
687673

688674
def constraint = state.constraint
689675

676+
trace(i"interpolateTypeVars($tree: ${tree.tpe}, $pt, $qualifying)", typr, (_: Any) => i"$qualifying $constraint") {
677+
690678
/** Values of this type report type variables to instantiate with variance indication:
691679
* +1 variable appears covariantly, can be instantiated from lower bound
692680
* -1 variable appears contravariantly, can be instantiated from upper bound
@@ -782,12 +770,10 @@ trait Inferencing { this: Typer =>
782770
/** Try to instantiate `tvs`, return any suspended type variables */
783771
def tryInstantiate(tvs: ToInstantiate): ToInstantiate = tvs match
784772
case (hd @ (tvar, v)) :: tvs1 =>
785-
val fromBelow = v == 1 || (v == 0 && tvar.hasLowerBound)
786-
typr.println(
787-
i"interpolate${if v == 0 then " non-occurring" else ""} $tvar in $state in $tree: $tp, fromBelow = $fromBelow, $constraint")
788773
if tvar.isInstantiated then
789774
tryInstantiate(tvs1)
790775
else
776+
val fromBelow = instDecision(tvar.origin, v, false, IfBottom.flip) == Decision.Min
791777
val suspend = tvs1.exists{ (following, _) =>
792778
if fromBelow
793779
then constraint.isLess(following.origin, tvar.origin)
@@ -797,13 +783,16 @@ trait Inferencing { this: Typer =>
797783
typr.println(i"suspended: $hd")
798784
hd :: tryInstantiate(tvs1)
799785
else
786+
typr.println(
787+
i"interpolate${if v == 0 then " non-occurring" else ""} $tvar in $state in $tree: $tp, fromBelow = $fromBelow, $constraint")
800788
tvar.instantiate(fromBelow)
801789
tryInstantiate(tvs1)
802790
case Nil => Nil
803791
if tvs.nonEmpty then doInstantiate(tryInstantiate(tvs))
804792
end doInstantiate
805793

806794
doInstantiate(filterByDeps(toInstantiate))
795+
}
807796
}
808797
end if
809798
tree

Diff for: compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ trait TypeAssigner {
469469
else tree.withType(TypeComparer.lub(expr.tpe :: cases.tpes))
470470

471471
def assignType(tree: untpd.SeqLiteral, elems: List[Tree], elemtpt: Tree)(using Context): SeqLiteral =
472-
tree.withType(seqLitType(tree, elemtpt.tpe))
472+
tree.withType(if elemtpt.tpe.isError then elemtpt.tpe else seqLitType(tree, elemtpt.tpe))
473473

474474
def assignType(tree: untpd.SingletonTypeTree, ref: Tree)(using Context): SingletonTypeTree =
475475
tree.withType(ref.tpe)

Diff for: compiler/test/dotty/tools/dotc/typer/InstantiateModel.scala

+24-26
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,16 @@ package typer
44

55
// Modelling the decision in IsFullyDefined
66
object InstantiateModel:
7-
enum LB { case NN; case LL; case L1 }; import LB.*
8-
enum UB { case AA; case UU; case U1 }; import UB.*
9-
enum Var { case V; case NotV }; import Var.*
10-
enum MSe { case M; case NotM }; import MSe.*
11-
enum Bot { case Fail; case Ok; case Flip }; import Bot.*
12-
enum Act { case Min; case Max; case ToMax; case Skip; case False }; import Act.*
7+
enum LB { case NN; case LL; case L1 }; import LB.*
8+
enum UB { case AA; case UU; case U1 }; import UB.*
9+
import Inferencing.Decision.*
1310

1411
// NN/AA = Nothing/Any
1512
// LL/UU = the original bounds, on the type parameter
1613
// L1/U1 = the constrained bounds, on the type variable
17-
// V = variance >= 0 ("non-contravariant")
18-
// MSe = minimisedSelected
19-
// Bot = IfBottom
2014
// ToMax = delayed maximisation, via addition to toMaximize
2115
// Skip = minimisedSelected "hold off instantiating"
22-
// False = return false
16+
// Fail = IfBottom.fail's bail option
2317

2418
// there are 9 combinations:
2519
// # | LB | UB | d | // d = direction
@@ -34,24 +28,28 @@ object InstantiateModel:
3428
// 8 | NN | UU | 0 | T <: UU
3529
// 9 | NN | AA | 0 | T
3630

37-
def decide(lb: LB, ub: UB, v: Var, bot: Bot, m: MSe): Act = (lb, ub) match
31+
def instDecision(lb: LB, ub: UB, v: Int, ifBottom: IfBottom, min: Boolean) = (lb, ub) match
3832
case (L1, AA) => Min
3933
case (L1, UU) => Min
4034
case (LL, U1) => Max
4135
case (NN, U1) => Max
4236

43-
case (L1, U1) => if m==M || v==V then Min else ToMax
44-
case (LL, UU) => if m==M || v==V then Min else ToMax
45-
case (LL, AA) => if m==M || v==V then Min else ToMax
46-
47-
case (NN, UU) => bot match
48-
case _ if m==M => Max
49-
//case Ok if v==V => Min // removed, i14218 fix
50-
case Fail if v==V => False
51-
case _ => ToMax
52-
53-
case (NN, AA) => bot match
54-
case _ if m==M => Skip
55-
case Ok if v==V => Min
56-
case Fail if v==V => False
57-
case _ => ToMax
37+
case (L1, U1) => if min then Min else pickVar(v, Min, Min, ToMax)
38+
case (LL, UU) => if min then Min else pickVar(v, Min, Min, ToMax)
39+
case (LL, AA) => if min then Min else pickVar(v, Min, Min, ToMax)
40+
41+
case (NN, UU) => ifBottom match
42+
case IfBottom.ok => pickVar(v, Min, ToMax, ToMax)
43+
case IfBottom.fail => pickVar(v, Fail, Fail, ToMax)
44+
case IfBottom.flip => if min then Max else ToMax
45+
46+
case (NN, AA) => ifBottom match
47+
case IfBottom.ok => pickVar(v, Min, Min, ToMax)
48+
case IfBottom.fail => pickVar(v, Fail, Fail, ToMax)
49+
case IfBottom.flip => if min then Skip else ToMax
50+
51+
def interpolateTypeVars(lb: LB, ub: UB, v: Int) =
52+
instDecision(lb, ub, v, IfBottom.flip, min = false)
53+
54+
def pickVar[A](v: Int, cov: A, inv: A, con: A) =
55+
if v > 0 then cov else if v == 0 then inv else con

Diff for: presentation-compiler/src/main/dotty/tools/pc/InferExpectedType.scala

+9-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ import dotty.tools.dotc.core.Types.*
1111
import dotty.tools.dotc.core.Types.Type
1212
import dotty.tools.dotc.interactive.Interactive
1313
import dotty.tools.dotc.interactive.InteractiveDriver
14+
import dotty.tools.dotc.typer.Applications.UnapplyArgs
15+
import dotty.tools.dotc.util.NoSourcePosition
1416
import dotty.tools.dotc.util.SourceFile
1517
import dotty.tools.dotc.util.Spans.Span
1618
import dotty.tools.pc.IndexedContext
@@ -86,9 +88,15 @@ object InterCompletionType:
8688
// List(@@)
8789
case SeqLiteral(_, tpe) :: _ if !tpe.tpe.isErroneous =>
8890
Some(tpe.tpe)
91+
case SeqLiteral(_, _) :: _typed :: rest =>
92+
inferType(rest, span)
8993
// val _: T = @@
9094
// def _: T = @@
9195
case (defn: ValOrDefDef) :: rest if !defn.tpt.tpe.isErroneous => Some(defn.tpt.tpe)
96+
case UnApply(fun, _, pats) :: _ =>
97+
val ind = pats.indexWhere(_.span.contains(span))
98+
if ind < 0 then None
99+
else Some(UnapplyArgs(fun.tpe.finalResultType, fun, pats, NoSourcePosition).argTypes(ind))
92100
// f(@@)
93101
case (app: Apply) :: rest =>
94102
val param =
@@ -98,7 +106,7 @@ object InterCompletionType:
98106
}
99107
params <- app.symbol.paramSymss.find(!_.exists(_.isTypeParam))
100108
param <- params.get(ind)
101-
} yield param.info
109+
} yield param.info.repeatedToSingle
102110
param match
103111
// def f[T](a: T): T = ???
104112
// f[Int](@@)

Diff for: presentation-compiler/test/dotty/tools/pc/tests/InferExpectedTypeSuite.scala

+15-15
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class InferExpectedTypeSuite extends BasePCSuite:
2626
EmptyCancelToken
2727
)
2828
presentationCompiler.asInstanceOf[ScalaPresentationCompiler].inferExpectedType(offsetParams).get().asScala match {
29-
case Some(value) => assertNoDiff(value, expectedType)
29+
case Some(value) => assertNoDiff(expectedType, value)
3030
case None => fail("Empty result.")
3131
}
3232

@@ -55,7 +55,6 @@ class InferExpectedTypeSuite extends BasePCSuite:
5555
|""".stripMargin
5656
)
5757

58-
@Ignore("Not handled correctly.")
5958
@Test def list =
6059
check(
6160
"""|val i: List[Int] = List(@@)
@@ -193,7 +192,6 @@ class InferExpectedTypeSuite extends BasePCSuite:
193192
|""".stripMargin
194193
)
195194

196-
@Ignore("Unapply is not handled correctly.")
197195
@Test def unapply =
198196
check(
199197
"""|val _ =
@@ -223,7 +221,6 @@ class InferExpectedTypeSuite extends BasePCSuite:
223221
|""".stripMargin
224222
)
225223

226-
@Ignore("Generic functions are not handled correctly.")
227224
@Test def flatmap =
228225
check(
229226
"""|val _ : List[Int] = List().flatMap(_ => @@)
@@ -232,7 +229,14 @@ class InferExpectedTypeSuite extends BasePCSuite:
232229
|""".stripMargin
233230
)
234231

235-
@Ignore("Generic functions are not handled correctly.")
232+
@Test def map =
233+
check(
234+
"""|val _ : List[Int] = List().map(_ => @@)
235+
|""".stripMargin,
236+
"""|Int
237+
|""".stripMargin
238+
)
239+
236240
@Test def `for-comprehension` =
237241
check(
238242
"""|val _ : List[Int] =
@@ -245,40 +249,36 @@ class InferExpectedTypeSuite extends BasePCSuite:
245249
)
246250

247251
// bounds
248-
@Ignore("Bounds are not handled correctly.")
249252
@Test def any =
250253
check(
251254
"""|trait Foo
252255
|def foo[T](a: T): Boolean = ???
253256
|val _ = foo(@@)
254257
|""".stripMargin,
255-
"""|<: Any
258+
"""|Any
256259
|""".stripMargin
257260
)
258261

259-
@Ignore("Bounds are not handled correctly.")
260262
@Test def `bounds-1` =
261263
check(
262264
"""|trait Foo
263-
|def foo[T <: Foo](a: Foo): Boolean = ???
265+
|def foo[T <: Foo](a: T): Boolean = ???
264266
|val _ = foo(@@)
265267
|""".stripMargin,
266-
"""|<: Foo
268+
"""|Foo
267269
|""".stripMargin
268270
)
269271

270-
@Ignore("Bounds are not handled correctly.")
271272
@Test def `bounds-2` =
272273
check(
273274
"""|trait Foo
274-
|def foo[T :> Foo](a: Foo): Boolean = ???
275+
|def foo[T >: Foo](a: T): Boolean = ???
275276
|val _ = foo(@@)
276277
|""".stripMargin,
277-
"""|:> Foo
278+
"""|Foo
278279
|""".stripMargin
279280
)
280281

281-
@Ignore("Bounds are not handled correctly.")
282282
@Test def `bounds-3` =
283283
check(
284284
"""|trait A
@@ -287,6 +287,6 @@ class InferExpectedTypeSuite extends BasePCSuite:
287287
|def roo[F >: C <: A](f: F) = ???
288288
|val kjk = roo(@@)
289289
|""".stripMargin,
290-
"""|>: C <: A
290+
"""|C
291291
|""".stripMargin
292292
)

Diff for: tests/neg/recursive-lower-constraint.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ class Bar extends Foo[Bar]
33

44
class A {
55
def foo[T <: Foo[T], U >: Foo[T] <: T](x: T): T = x
6-
foo(new Bar) // error // error
6+
foo(new Bar) // error
77
}

0 commit comments

Comments
 (0)