Skip to content

Commit 4b5b316

Browse files
ilevkivskyiJukkaL
authored andcommitted
Special-case unions in polymorphic inference (#16461)
Fixes #16451 This special-casing is unfortunate, but this is the best I came up so far.
1 parent f862d3e commit 4b5b316

File tree

3 files changed

+87
-9
lines changed

3 files changed

+87
-9
lines changed

mypy/solve.py

+44-9
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Iterable, Sequence
77
from typing_extensions import TypeAlias as _TypeAlias
88

9-
from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, infer_constraints
9+
from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, infer_constraints, neg_op
1010
from mypy.expandtype import expand_type
1111
from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort
1212
from mypy.join import join_types
@@ -69,6 +69,10 @@ def solve_constraints(
6969
extra_vars.extend([v.id for v in c.extra_tvars if v.id not in vars + extra_vars])
7070
originals.update({v.id: v for v in c.extra_tvars if v.id not in originals})
7171

72+
if allow_polymorphic:
73+
# Constraints inferred from unions require special handling in polymorphic inference.
74+
constraints = skip_reverse_union_constraints(constraints)
75+
7276
# Collect a list of constraints for each type variable.
7377
cmap: dict[TypeVarId, list[Constraint]] = {tv: [] for tv in vars + extra_vars}
7478
for con in constraints:
@@ -431,19 +435,15 @@ def transitive_closure(
431435
uppers[l] |= uppers[upper]
432436
for lt in lowers[lower]:
433437
for ut in uppers[upper]:
434-
# TODO: what if secondary constraints result in inference
435-
# against polymorphic actual (also in below branches)?
436-
remaining |= set(infer_constraints(lt, ut, SUBTYPE_OF))
437-
remaining |= set(infer_constraints(ut, lt, SUPERTYPE_OF))
438+
add_secondary_constraints(remaining, lt, ut)
438439
elif c.op == SUBTYPE_OF:
439440
if c.target in uppers[c.type_var]:
440441
continue
441442
for l in tvars:
442443
if (l, c.type_var) in graph:
443444
uppers[l].add(c.target)
444445
for lt in lowers[c.type_var]:
445-
remaining |= set(infer_constraints(lt, c.target, SUBTYPE_OF))
446-
remaining |= set(infer_constraints(c.target, lt, SUPERTYPE_OF))
446+
add_secondary_constraints(remaining, lt, c.target)
447447
else:
448448
assert c.op == SUPERTYPE_OF
449449
if c.target in lowers[c.type_var]:
@@ -452,11 +452,24 @@ def transitive_closure(
452452
if (c.type_var, u) in graph:
453453
lowers[u].add(c.target)
454454
for ut in uppers[c.type_var]:
455-
remaining |= set(infer_constraints(ut, c.target, SUPERTYPE_OF))
456-
remaining |= set(infer_constraints(c.target, ut, SUBTYPE_OF))
455+
add_secondary_constraints(remaining, c.target, ut)
457456
return graph, lowers, uppers
458457

459458

459+
def add_secondary_constraints(cs: set[Constraint], lower: Type, upper: Type) -> None:
460+
"""Add secondary constraints inferred between lower and upper (in place)."""
461+
if isinstance(get_proper_type(upper), UnionType) and isinstance(
462+
get_proper_type(lower), UnionType
463+
):
464+
# When both types are unions, this can lead to inferring spurious constraints,
465+
# for example Union[T, int] <: S <: Union[T, int] may infer T <: int.
466+
# To avoid this, just skip them for now.
467+
return
468+
# TODO: what if secondary constraints result in inference against polymorphic actual?
469+
cs.update(set(infer_constraints(lower, upper, SUBTYPE_OF)))
470+
cs.update(set(infer_constraints(upper, lower, SUPERTYPE_OF)))
471+
472+
460473
def compute_dependencies(
461474
tvars: list[TypeVarId], graph: Graph, lowers: Bounds, uppers: Bounds
462475
) -> dict[TypeVarId, list[TypeVarId]]:
@@ -494,6 +507,28 @@ def check_linear(scc: set[TypeVarId], lowers: Bounds, uppers: Bounds) -> bool:
494507
return True
495508

496509

510+
def skip_reverse_union_constraints(cs: list[Constraint]) -> list[Constraint]:
511+
"""Avoid ambiguities for constraints inferred from unions during polymorphic inference.
512+
513+
Polymorphic inference implicitly relies on assumption that a reverse of a linear constraint
514+
is a linear constraint. This is however not true in presence of union types, for example
515+
T :> Union[S, int] vs S <: T. Trying to solve such constraints would be detected ambiguous
516+
as (T, S) form a non-linear SCC. However, simply removing the linear part results in a valid
517+
solution T = Union[S, int], S = <free>.
518+
519+
TODO: a cleaner solution may be to avoid inferring such constraints in first place, but
520+
this would require passing around a flag through all infer_constraints() calls.
521+
"""
522+
reverse_union_cs = set()
523+
for c in cs:
524+
p_target = get_proper_type(c.target)
525+
if isinstance(p_target, UnionType):
526+
for item in p_target.items:
527+
if isinstance(item, TypeVarType):
528+
reverse_union_cs.add(Constraint(item, neg_op(c.op), c.origin_type_var))
529+
return [c for c in cs if c not in reverse_union_cs]
530+
531+
497532
def get_vars(target: Type, vars: list[TypeVarId]) -> set[TypeVarId]:
498533
"""Find type variables for which we are solving in a target type."""
499534
return {tv.id for tv in get_all_type_vars(target)} & set(vars)

test-data/unit/check-inference.test

+21
Original file line numberDiff line numberDiff line change
@@ -3767,3 +3767,24 @@ def f(values: List[T]) -> T: ...
37673767
x = foo(f([C()]))
37683768
reveal_type(x) # N: Revealed type is "__main__.C"
37693769
[builtins fixtures/list.pyi]
3770+
3771+
[case testInferenceAgainstGenericCallableUnion]
3772+
from typing import Callable, TypeVar, List, Union
3773+
3774+
T = TypeVar("T")
3775+
S = TypeVar("S")
3776+
3777+
def dec(f: Callable[[S], T]) -> Callable[[S], List[T]]: ...
3778+
@dec
3779+
def func(arg: T) -> Union[T, str]:
3780+
...
3781+
reveal_type(func) # N: Revealed type is "def [S] (S`1) -> builtins.list[Union[S`1, builtins.str]]"
3782+
reveal_type(func(42)) # N: Revealed type is "builtins.list[Union[builtins.int, builtins.str]]"
3783+
3784+
def dec2(f: Callable[[S], List[T]]) -> Callable[[S], T]: ...
3785+
@dec2
3786+
def func2(arg: T) -> List[Union[T, str]]:
3787+
...
3788+
reveal_type(func2) # N: Revealed type is "def [S] (S`4) -> Union[S`4, builtins.str]"
3789+
reveal_type(func2(42)) # N: Revealed type is "Union[builtins.int, builtins.str]"
3790+
[builtins fixtures/list.pyi]

test-data/unit/check-parameter-specification.test

+22
Original file line numberDiff line numberDiff line change
@@ -2086,3 +2086,25 @@ reveal_type(d(b, f1)) # E: Cannot infer type argument 1 of "d" \
20862086
# N: Revealed type is "def (*Any, **Any)"
20872087
reveal_type(d(b, f2)) # N: Revealed type is "def (builtins.int)"
20882088
[builtins fixtures/paramspec.pyi]
2089+
2090+
[case testInferenceAgainstGenericCallableUnionParamSpec]
2091+
from typing import Callable, TypeVar, List, Union
2092+
from typing_extensions import ParamSpec
2093+
2094+
T = TypeVar("T")
2095+
P = ParamSpec("P")
2096+
2097+
def dec(f: Callable[P, T]) -> Callable[P, List[T]]: ...
2098+
@dec
2099+
def func(arg: T) -> Union[T, str]:
2100+
...
2101+
reveal_type(func) # N: Revealed type is "def [T] (arg: T`-1) -> builtins.list[Union[T`-1, builtins.str]]"
2102+
reveal_type(func(42)) # N: Revealed type is "builtins.list[Union[builtins.int, builtins.str]]"
2103+
2104+
def dec2(f: Callable[P, List[T]]) -> Callable[P, T]: ...
2105+
@dec2
2106+
def func2(arg: T) -> List[Union[T, str]]:
2107+
...
2108+
reveal_type(func2) # N: Revealed type is "def [T] (arg: T`-1) -> Union[T`-1, builtins.str]"
2109+
reveal_type(func2(42)) # N: Revealed type is "Union[builtins.int, builtins.str]"
2110+
[builtins fixtures/paramspec.pyi]

0 commit comments

Comments
 (0)