Skip to content

Commit 6c556aa

Browse files
committed
Improve type inference for functions like fold
When calling a fold with an accumulator like `Nil` or `List()` one used to have add an explicit type ascription. This is now no longer necessary. When instantiating type variables that occur invariantly in the expected type of a lambda, we now replace covariant occurrences of `Nothing` in the (possibly widened) type of the accumulator with fresh type variables. The idea is that a fresh type variable in such places is always better than Nothing. For module values such as `Nil` we widen to `List[<fresh var>]`. This does possibly cause a new type error if the fold really wanted a `Nil` instance. But that case seems very rare, so it looks like a good bet in general to do the widening.
1 parent 8cb4945 commit 6c556aa

File tree

6 files changed

+155
-48
lines changed

6 files changed

+155
-48
lines changed

Diff for: compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

+3-9
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import Flags.*
1010
import config.Config
1111
import config.Printers.typr
1212
import typer.ProtoTypes.{newTypeVar, representedParamRef}
13+
import transform.TypeUtils.isTransparent
1314
import UnificationDirection.*
1415
import NameKinds.AvoidNameKind
1516
import util.SimpleIdentitySet
@@ -566,13 +567,6 @@ trait ConstraintHandling {
566567
inst
567568
end approximation
568569

569-
private def isTransparent(tp: Type, traitOnly: Boolean)(using Context): Boolean = tp match
570-
case AndType(tp1, tp2) =>
571-
isTransparent(tp1, traitOnly) && isTransparent(tp2, traitOnly)
572-
case _ =>
573-
val cls = tp.underlyingClassRef(refinementOK = false).typeSymbol
574-
cls.isTransparentClass && (!traitOnly || cls.is(Trait))
575-
576570
/** If `tp` is an intersection such that some operands are transparent trait instances
577571
* and others are not, replace as many transparent trait instances as possible with Any
578572
* as long as the result is still a subtype of `bound`. But fall back to the
@@ -585,7 +579,7 @@ trait ConstraintHandling {
585579
var dropped: List[Type] = List() // the types dropped so far, last one on top
586580

587581
def dropOneTransparentTrait(tp: Type): Type =
588-
if isTransparent(tp, traitOnly = true) && !kept.contains(tp) then
582+
if tp.isTransparent(traitOnly = true) && !kept.contains(tp) then
589583
dropped = tp :: dropped
590584
defn.AnyType
591585
else tp match
@@ -658,7 +652,7 @@ trait ConstraintHandling {
658652
def widenOr(tp: Type) =
659653
if widenUnions then
660654
val tpw = tp.widenUnion
661-
if (tpw ne tp) && !isTransparent(tpw, traitOnly = false) && (tpw <:< bound) then tpw else tp
655+
if (tpw ne tp) && !tpw.isTransparent() && (tpw <:< bound) then tpw else tp
662656
else tp.hardenUnions
663657

664658
def widenSingle(tp: Type) =

Diff for: compiler/src/dotty/tools/dotc/core/Types.scala

+12-13
Original file line numberDiff line numberDiff line change
@@ -4895,19 +4895,22 @@ object Types {
48954895
/** Instantiate variable with given type */
48964896
def instantiateWith(tp: Type)(using Context): Type = {
48974897
assert(tp ne this, i"self instantiation of $origin, constraint = ${ctx.typerState.constraint}")
4898-
assert(!myInst.exists, i"$origin is already instantiated to $myInst but we attempted to instantiate it to $tp")
4899-
typr.println(i"instantiating $this with $tp")
4898+
if !myInst.exists then
4899+
typr.println(i"instantiating $this with $tp")
49004900

4901-
if Config.checkConstraintsSatisfiable then
4902-
assert(currentEntry.bounds.contains(tp),
4903-
i"$origin is constrained to be $currentEntry but attempted to instantiate it to $tp")
4901+
if Config.checkConstraintsSatisfiable then
4902+
assert(currentEntry.bounds.contains(tp),
4903+
i"$origin is constrained to be $currentEntry but attempted to instantiate it to $tp")
49044904

4905-
if ((ctx.typerState eq owningState.nn.get.uncheckedNN) && !TypeComparer.subtypeCheckInProgress)
4906-
setInst(tp)
4907-
ctx.typerState.constraint = ctx.typerState.constraint.replace(origin, tp)
4905+
if ((ctx.typerState eq owningState.nn.get.uncheckedNN) && !TypeComparer.subtypeCheckInProgress)
4906+
setInst(tp)
4907+
ctx.typerState.constraint = ctx.typerState.constraint.replace(origin, tp)
49084908
tp
49094909
}
49104910

4911+
def typeToInstantiateWith(fromBelow: Boolean)(using Context): Type =
4912+
TypeComparer.instanceType(origin, fromBelow, widenUnions, nestingLevel)
4913+
49114914
/** Instantiate variable from the constraints over its `origin`.
49124915
* If `fromBelow` is true, the variable is instantiated to the lub
49134916
* of its lower bounds in the current constraint; otherwise it is
@@ -4916,11 +4919,7 @@ object Types {
49164919
* is also a singleton type.
49174920
*/
49184921
def instantiate(fromBelow: Boolean)(using Context): Type =
4919-
val tp = TypeComparer.instanceType(origin, fromBelow, widenUnions, nestingLevel)
4920-
if myInst.exists then // The line above might have triggered instantiation of the current type variable
4921-
myInst
4922-
else
4923-
instantiateWith(tp)
4922+
instantiateWith(typeToInstantiateWith(fromBelow))
49244923

49254924
/** Widen unions when instantiating this variable in the current context? */
49264925
def widenUnions(using Context): Boolean = !ctx.typerState.constraint.isHard(this)

Diff for: compiler/src/dotty/tools/dotc/transform/TypeUtils.scala

+11-5
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,9 @@ package transform
44

55
import core.*
66
import TypeErasure.ErasedValueType
7-
import Types.*
8-
import Contexts.*
9-
import Symbols.*
7+
import Types.*, Contexts.*, Symbols.*, Flags.*, Decorators.*
108
import Names.Name
119

12-
import dotty.tools.dotc.core.Decorators.*
13-
1410
object TypeUtils {
1511
/** A decorator that provides methods on types
1612
* that are needed in the transformer pipeline.
@@ -98,5 +94,15 @@ object TypeUtils {
9894
def takesImplicitParams(using Context): Boolean = self.stripPoly match
9995
case mt: MethodType => mt.isImplicitMethod || mt.resType.takesImplicitParams
10096
case _ => false
97+
98+
/** Is this a type deriving only from transparent classes?
99+
* @param traitOnly if true, all class symbols must be transparent traits
100+
*/
101+
def isTransparent(traitOnly: Boolean = false)(using Context): Boolean = self match
102+
case AndType(tp1, tp2) =>
103+
tp1.isTransparent(traitOnly) && tp2.isTransparent(traitOnly)
104+
case _ =>
105+
val cls = self.underlyingClassRef(refinementOK = false).typeSymbol
106+
cls.isTransparentClass && (!traitOnly || cls.is(Trait))
101107
}
102108
}

Diff for: compiler/src/dotty/tools/dotc/typer/Inferencing.scala

+83-20
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ import ProtoTypes.*
99
import NameKinds.UniqueName
1010
import util.Spans.*
1111
import util.{Stats, SimpleIdentityMap, SimpleIdentitySet, SrcPos}
12-
import Decorators.*
12+
import transform.TypeUtils.isTransparent
13+
import Decorators._
1314
import config.Printers.{gadts, typr}
1415
import annotation.tailrec
1516
import reporting.*
@@ -60,7 +61,9 @@ object Inferencing {
6061
def instantiateSelected(tp: Type, tvars: List[Type])(using Context): Unit =
6162
if (tvars.nonEmpty)
6263
IsFullyDefinedAccumulator(
63-
ForceDegree.Value(tvars.contains, IfBottom.flip), minimizeSelected = true
64+
new ForceDegree.Value(IfBottom.flip):
65+
override def appliesTo(tvar: TypeVar) = tvars.contains(tvar),
66+
minimizeSelected = true
6467
).process(tp)
6568

6669
/** Instantiate any type variables in `tp` whose bounds contain a reference to
@@ -154,15 +157,58 @@ object Inferencing {
154157
* their lower bound. Record whether successful.
155158
* 2nd Phase: If first phase was successful, instantiate all remaining type variables
156159
* to their upper bound.
160+
*
161+
* Instance types can be improved by replacing covariant occurrences of Nothing
162+
* with fresh type variables, if `force` allows this in its `canImprove` implementation.
157163
*/
158164
private class IsFullyDefinedAccumulator(force: ForceDegree.Value, minimizeSelected: Boolean = false)
159165
(using Context) extends TypeAccumulator[Boolean] {
160166

161-
private def instantiate(tvar: TypeVar, fromBelow: Boolean): Type = {
167+
/** Replace toplevel-covariant occurrences (i.e. covariant without double flips)
168+
* of Nothing by fresh type variables.
169+
* For singleton types and references to module classes: try to
170+
* improve the widened type. For module classes, the widened type
171+
* is the intersection of all its non-transparent parent types.
172+
*/
173+
private def improve(tvar: TypeVar) = new TypeMap:
174+
def apply(t: Type) = trace(i"improve $t", show = true):
175+
def tryWidened(widened: Type): Type =
176+
val improved = apply(widened)
177+
if improved ne widened then improved else mapOver(t)
178+
if variance > 0 then
179+
t match
180+
case t: TypeRef =>
181+
if t.symbol == defn.NothingClass then
182+
newTypeVar(TypeBounds.empty, nestingLevel = tvar.nestingLevel)
183+
else if t.symbol.is(ModuleClass) then
184+
tryWidened(t.parents.filter(!_.isTransparent())
185+
.foldLeft(defn.AnyType: Type)(TypeComparer.andType(_, _)))
186+
else
187+
mapOver(t)
188+
case t: TermRef =>
189+
tryWidened(t.widen)
190+
case _ =>
191+
mapOver(t)
192+
else t
193+
194+
/** Instantiate type variable with possibly improved computed instance type.
195+
* @return true if variable was instantiated with improved type, which
196+
* in this case should not be instantiated further, false otherwise.
197+
*/
198+
private def instantiate(tvar: TypeVar, fromBelow: Boolean): Boolean =
199+
if fromBelow && force.canImprove(tvar) then
200+
val inst = tvar.typeToInstantiateWith(fromBelow = true)
201+
if apply(true, inst) then
202+
// need to recursively check before improving, since improving adds type vars
203+
// which should not be instantiated at this point
204+
val better = improve(tvar)(inst)
205+
if better <:< TypeComparer.fullUpperBound(tvar.origin) then
206+
typr.println(i"forced instantiation of invariant ${tvar.origin} = $inst, improved to $better")
207+
tvar.instantiateWith(better)
208+
return true
162209
val inst = tvar.instantiate(fromBelow)
163210
typr.println(i"forced instantiation of ${tvar.origin} = $inst")
164-
inst
165-
}
211+
false
166212

167213
private var toMaximize: List[TypeVar] = Nil
168214

@@ -178,26 +224,27 @@ object Inferencing {
178224
&& ctx.typerState.constraint.contains(tvar)
179225
&& {
180226
var fail = false
227+
var skip = false
181228
val direction = instDirection(tvar.origin)
182229
if minimizeSelected then
183230
if direction <= 0 && tvar.hasLowerBound then
184-
instantiate(tvar, fromBelow = true)
231+
skip = instantiate(tvar, fromBelow = true)
185232
else if direction >= 0 && tvar.hasUpperBound then
186-
instantiate(tvar, fromBelow = false)
233+
skip = instantiate(tvar, fromBelow = false)
187234
// else hold off instantiating unbounded unconstrained variable
188235
else if direction != 0 then
189-
instantiate(tvar, fromBelow = direction < 0)
236+
skip = instantiate(tvar, fromBelow = direction < 0)
190237
else if variance >= 0 && tvar.hasLowerBound then
191-
instantiate(tvar, fromBelow = true)
238+
skip = instantiate(tvar, fromBelow = true)
192239
else if (variance > 0 || variance == 0 && !tvar.hasUpperBound)
193240
&& force.ifBottom == IfBottom.ok
194241
then // if variance == 0, prefer upper bound if one is given
195-
instantiate(tvar, fromBelow = true)
242+
skip = instantiate(tvar, fromBelow = true)
196243
else if variance >= 0 && force.ifBottom == IfBottom.fail then
197244
fail = true
198245
else
199246
toMaximize = tvar :: toMaximize
200-
!fail && foldOver(x, tvar)
247+
!fail && (skip || foldOver(x, tvar))
201248
}
202249
case tp => foldOver(x, tp)
203250
}
@@ -467,7 +514,7 @@ object Inferencing {
467514
*
468515
* we want to instantiate U to x.type right away. No need to wait further.
469516
*/
470-
private def variances(tp: Type, pt: Type = WildcardType)(using Context): VarianceMap[TypeVar] = {
517+
def variances(tp: Type, pt: Type = WildcardType)(using Context): VarianceMap[TypeVar] = {
471518
Stats.record("variances")
472519
val constraint = ctx.typerState.constraint
473520

@@ -769,14 +816,30 @@ trait Inferencing { this: Typer =>
769816
}
770817

771818
/** An enumeration controlling the degree of forcing in "is-fully-defined" checks. */
772-
@sharable object ForceDegree {
773-
class Value(val appliesTo: TypeVar => Boolean, val ifBottom: IfBottom):
774-
override def toString = s"ForceDegree.Value(.., $ifBottom)"
775-
val none: Value = new Value(_ => false, IfBottom.ok) { override def toString = "ForceDegree.none" }
776-
val all: Value = new Value(_ => true, IfBottom.ok) { override def toString = "ForceDegree.all" }
777-
val failBottom: Value = new Value(_ => true, IfBottom.fail) { override def toString = "ForceDegree.failBottom" }
778-
val flipBottom: Value = new Value(_ => true, IfBottom.flip) { override def toString = "ForceDegree.flipBottom" }
779-
}
819+
@sharable object ForceDegree:
820+
class Value(val ifBottom: IfBottom):
821+
822+
/** Does `tv` need to be instantiated? */
823+
def appliesTo(tv: TypeVar): Boolean = true
824+
825+
/** Should we try to improve the computed instance type by replacing bottom types
826+
* with fresh type variables?
827+
*/
828+
def canImprove(tv: TypeVar): Boolean = false
829+
830+
override def toString = s"ForceDegree.Value($ifBottom)"
831+
end Value
832+
833+
val none: Value = new Value(IfBottom.ok):
834+
override def appliesTo(tv: TypeVar) = false
835+
override def toString = "ForceDegree.none"
836+
val all: Value = new Value(IfBottom.ok):
837+
override def toString = "ForceDegree.all"
838+
val failBottom: Value = new Value(IfBottom.fail):
839+
override def toString = "ForceDegree.failBottom"
840+
val flipBottom: Value = new Value(IfBottom.flip):
841+
override def toString = "ForceDegree.flipBottom"
842+
end ForceDegree
780843

781844
enum IfBottom:
782845
case ok, fail, flip

Diff for: compiler/src/dotty/tools/dotc/typer/Typer.scala

+12-1
Original file line numberDiff line numberDiff line change
@@ -1622,14 +1622,25 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
16221622
case _ =>
16231623

16241624
if desugared.isEmpty then
1625+
val forceDegree =
1626+
if pt.isValueType then
1627+
// Allow variables that appear invariantly in `pt` to be improved by mapping
1628+
// bottom types in their instance types to fresh type variables
1629+
new ForceDegree.Value(IfBottom.fail):
1630+
val tvmap = variances(pt)
1631+
override def canImprove(tvar: TypeVar) =
1632+
tvmap.computedVariance(tvar) == (0: Integer)
1633+
else
1634+
ForceDegree.failBottom
1635+
16251636
val inferredParams: List[untpd.ValDef] =
16261637
for ((param, i) <- params.zipWithIndex) yield
16271638
if (!param.tpt.isEmpty) param
16281639
else
16291640
val (formalBounds, isErased) = protoFormal(i)
16301641
val formal = formalBounds.loBound
16311642
val isBottomFromWildcard = (formalBounds ne formal) && formal.isExactlyNothing
1632-
val knownFormal = isFullyDefined(formal, ForceDegree.failBottom)
1643+
val knownFormal = isFullyDefined(formal, forceDegree)
16331644
// If the expected formal is a TypeBounds wildcard argument with Nothing as lower bound,
16341645
// try to prioritize inferring from target. See issue 16405 (tests/run/16405.scala)
16351646
val paramType =

Diff for: tests/pos/folds.scala

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
2+
object Test:
3+
extension [A](xs: List[A])
4+
def foldl[B](acc: B)(f: (A, B) => B): B = ???
5+
6+
val xs = List(1, 2, 3)
7+
8+
val _ = xs.foldl(List())((y, ys) => y :: ys)
9+
10+
val _ = xs.foldl(Nil)((y, ys) => y :: ys)
11+
12+
def partition[a](xs: List[a], pred: a => Boolean): Tuple2[List[a], List[a]] = {
13+
xs.foldRight/*[Tuple2[List[a], List[a]]]*/((List(), List())) {
14+
(x, p) => if (pred (x)) (x :: p._1, p._2) else (p._1, x :: p._2)
15+
}
16+
}
17+
18+
def snoc[A](xs: List[A], x: A) = x :: xs
19+
20+
def reverse[A](xs: List[A]) =
21+
xs.foldLeft(Nil)(snoc)
22+
23+
def reverse2[A](xs: List[A]) =
24+
xs.foldLeft(List())(snoc)
25+
26+
val ys: Seq[Int] = xs
27+
ys.foldLeft(Seq())((ys, y) => y +: ys)
28+
ys.foldLeft(Nil)((ys, y) => y +: ys)
29+
30+
def dup[A](xs: List[A]) =
31+
xs.foldRight(Nil)((x, xs) => x :: x :: xs)
32+
33+
def toSet[A](xs: Seq[A]) =
34+
xs.foldLeft(Set.empty)(_ + _)

0 commit comments

Comments
 (0)