Skip to content

Draft: Overloading result-based pruning shouldn't prefer inapplicable alt #21730

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 18 additions & 14 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2129,11 +2129,13 @@ trait Applications extends Compatibility {
return resolveMapped(alts, alt => stripImplicit(alt.widen), pt)
case _ =>

var found = withoutMode(Mode.ImplicitsEnabled)(resolveOverloaded1(alts, pt))
var candidatesAndFound = withoutMode(Mode.ImplicitsEnabled)(resolveOverloaded1(alts, pt))
def candidates = candidatesAndFound._1
def found = candidatesAndFound._2
if found.isEmpty && ctx.mode.is(Mode.ImplicitsEnabled) then
found = resolveOverloaded1(alts, pt)
candidatesAndFound = resolveOverloaded1(alts, pt)
found match
case alt :: Nil => adaptByResult(alt, alts) :: Nil
case alt :: Nil => adaptByResult(alt, candidates) :: Nil
case _ => found
end resolve

Expand Down Expand Up @@ -2178,7 +2180,7 @@ trait Applications extends Compatibility {
* It might be called twice from the public `resolveOverloaded` method, once with
* implicits and SAM conversions enabled, and once without.
*/
private def resolveOverloaded1(alts: List[TermRef], pt: Type)(using Context): List[TermRef] =
private def resolveOverloaded1(alts: List[TermRef], pt: Type)(using Context): (List[TermRef], List[TermRef]) =
trace(i"resolve over $alts%, %, pt = $pt", typr, show = true) {
record(s"resolveOverloaded1", alts.length)

Expand Down Expand Up @@ -2306,7 +2308,9 @@ trait Applications extends Compatibility {
else
record("resolveOverloaded.narrowedByShape", alts2.length)
pretypeArgs(alts2, pt)
narrowByTrees(alts2, pt.typedArgs(normArg(alts2, _, _)), resultType)
val alts3 = narrowByTrees(alts2, pt.typedArgs(normArg(alts2, _, _)), resultType)
overload.println(i"narrowed by trees: ${alts3.map(_.symbol.showDcl)}%, %")
alts3

case pt @ PolyProto(targs1, pt1) =>
val alts1 = alts.filterConserve(pt.canInstantiate)
Expand Down Expand Up @@ -2372,42 +2376,42 @@ trait Applications extends Compatibility {
if pt.unusableForInference then
// `pt` might have become erroneous by typing arguments of FunProtos.
// If `pt` is erroneous, don't try to go further; report the error in `pt` instead.
candidates
(candidates, candidates)
else
val found = narrowMostSpecific(candidates)
if found.length <= 1 then found
if found.length <= 1 then (candidates, found)
else
val deepPt = pt.deepenProto
deepPt match
case pt @ FunProto(_, PolyProto(targs, resType)) =>
// try to narrow further with snd argument list and following type params
resolveMapped(candidates,
skipParamClause(pt.typedArgs().tpes, targs.tpes), resType)
(candidates, resolveMapped(candidates,
skipParamClause(pt.typedArgs().tpes, targs.tpes), resType))
case pt @ FunProto(_, resType: FunOrPolyProto) =>
// try to narrow further with snd argument list
resolveMapped(candidates,
skipParamClause(pt.typedArgs().tpes, Nil), resType)
(candidates, resolveMapped(candidates,
skipParamClause(pt.typedArgs().tpes, Nil), resType))
case _ =>
// prefer alternatives that need no eta expansion
val noCurried = alts.filterConserve(!resultIsMethod(_))
val noCurriedCount = noCurried.length
if noCurriedCount == 1 then
noCurried
(candidates, noCurried)
else if noCurriedCount > 1 && noCurriedCount < alts.length then
resolveOverloaded1(noCurried, pt)
else
// prefer alternatves that match without default parameters
val noDefaults = alts.filterConserve(!_.symbol.hasDefaultParams)
val noDefaultsCount = noDefaults.length
if noDefaultsCount == 1 then
noDefaults
(candidates, noDefaults)
else if noDefaultsCount > 1 && noDefaultsCount < alts.length then
resolveOverloaded1(noDefaults, pt)
else if deepPt ne pt then
// try again with a deeper known expected type
resolveOverloaded1(alts, deepPt)
else
candidates
(candidates, candidates)
}
end resolveOverloaded1

Expand Down
12 changes: 12 additions & 0 deletions tests/pos/i21410.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
class A
object Test:
type F[X] <: Any = X match
case A => Int

def foo[T](x: String): T = ???
def foo[U](x: U): F[U] = ???

val x1 = foo(A())
val y: Int = x1

val x2: Int = foo(A()) // error
Loading