Skip to content

Commit f328ad6

Browse files
authored
Fix nondeterministic type checking caused by nonassociativity of joins (#19147)
I thought about doing this in `join_type_list`, but most callers look like they do have some deterministic order. Fixes #19121 (torchvision case only, haven't looked at xarray) Fixes #16979 (OP case only, bzoracler case fixed by #18402)
1 parent 33d1eed commit f328ad6

File tree

3 files changed

+68
-15
lines changed

3 files changed

+68
-15
lines changed

mypy/solve.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, infer_constraints, neg_op
1010
from mypy.expandtype import expand_type
1111
from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort
12-
from mypy.join import join_types
12+
from mypy.join import join_type_list
1313
from mypy.meet import meet_type_list, meet_types
1414
from mypy.subtypes import is_subtype
1515
from mypy.typeops import get_all_type_vars
@@ -247,10 +247,16 @@ def solve_iteratively(
247247
return solutions
248248

249249

250+
def _join_sorted_key(t: Type) -> int:
251+
t = get_proper_type(t)
252+
if isinstance(t, UnionType):
253+
return -1
254+
return 0
255+
256+
250257
def solve_one(lowers: Iterable[Type], uppers: Iterable[Type]) -> Type | None:
251258
"""Solve constraints by finding by using meets of upper bounds, and joins of lower bounds."""
252-
bottom: Type | None = None
253-
top: Type | None = None
259+
254260
candidate: Type | None = None
255261

256262
# Filter out previous results of failed inference, they will only spoil the current pass...
@@ -267,19 +273,26 @@ def solve_one(lowers: Iterable[Type], uppers: Iterable[Type]) -> Type | None:
267273
candidate.ambiguous = True
268274
return candidate
269275

276+
bottom: Type | None = None
277+
top: Type | None = None
278+
270279
# Process each bound separately, and calculate the lower and upper
271280
# bounds based on constraints. Note that we assume that the constraint
272281
# targets do not have constraint references.
273-
for target in lowers:
274-
if bottom is None:
275-
bottom = target
276-
else:
277-
if type_state.infer_unions:
278-
# This deviates from the general mypy semantics because
279-
# recursive types are union-heavy in 95% of cases.
280-
bottom = UnionType.make_union([bottom, target])
281-
else:
282-
bottom = join_types(bottom, target)
282+
if type_state.infer_unions:
283+
# This deviates from the general mypy semantics because
284+
# recursive types are union-heavy in 95% of cases.
285+
bottom = UnionType.make_union(list(lowers))
286+
else:
287+
# The order of lowers is non-deterministic.
288+
# We attempt to sort lowers because joins are non-associative. For instance:
289+
# join(join(int, str), int | str) == join(object, int | str) == object
290+
# join(int, join(str, int | str)) == join(int, int | str) == int | str
291+
# Note that joins in theory should be commutative, but in practice some bugs mean this is
292+
# also a source of non-deterministic type checking results.
293+
sorted_lowers = sorted(lowers, key=_join_sorted_key)
294+
if sorted_lowers:
295+
bottom = join_type_list(sorted_lowers)
283296

284297
for target in uppers:
285298
if top is None:

test-data/unit/check-generics.test

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3563,3 +3563,43 @@ def foo(x: T):
35633563
reveal_type(C) # N: Revealed type is "Overload(def [T, S] (x: builtins.int, y: S`-1) -> __main__.C[__main__.Int[S`-1]], def [T, S] (x: builtins.str, y: S`-1) -> __main__.C[__main__.Str[S`-1]])"
35643564
reveal_type(C(0, x)) # N: Revealed type is "__main__.C[__main__.Int[T`-1]]"
35653565
reveal_type(C("yes", x)) # N: Revealed type is "__main__.C[__main__.Str[T`-1]]"
3566+
3567+
[case testDeterminismFromJoinOrderingInSolver]
3568+
# Used to fail non-deterministically
3569+
# https://github.com/python/mypy/issues/19121
3570+
from __future__ import annotations
3571+
from typing import Generic, Iterable, Iterator, Self, TypeVar
3572+
3573+
_T1 = TypeVar("_T1")
3574+
_T2 = TypeVar("_T2")
3575+
_T3 = TypeVar("_T3")
3576+
_T_co = TypeVar("_T_co", covariant=True)
3577+
3578+
class Base(Iterable[_T1]):
3579+
def __iter__(self) -> Iterator[_T1]: ...
3580+
class A(Base[_T1]): ...
3581+
class B(Base[_T1]): ...
3582+
class C(Base[_T1]): ...
3583+
class D(Base[_T1]): ...
3584+
class E(Base[_T1]): ...
3585+
3586+
class zip2(Generic[_T_co]):
3587+
def __new__(
3588+
cls,
3589+
iter1: Iterable[_T1],
3590+
iter2: Iterable[_T2],
3591+
iter3: Iterable[_T3],
3592+
) -> zip2[tuple[_T1, _T2, _T3]]: ...
3593+
def __iter__(self) -> Self: ...
3594+
def __next__(self) -> _T_co: ...
3595+
3596+
def draw(
3597+
colors1: A[str] | B[str] | C[int] | D[int | str],
3598+
colors2: A[str] | B[str] | C[int] | D[int | str],
3599+
colors3: A[str] | B[str] | C[int] | D[int | str],
3600+
) -> None:
3601+
for c1, c2, c3 in zip2(colors1, colors2, colors3):
3602+
reveal_type(c1) # N: Revealed type is "Union[builtins.int, builtins.str]"
3603+
reveal_type(c2) # N: Revealed type is "Union[builtins.int, builtins.str]"
3604+
reveal_type(c3) # N: Revealed type is "Union[builtins.int, builtins.str]"
3605+
[builtins fixtures/tuple.pyi]

test-data/unit/check-recursive-types.test

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ reveal_type(flatten([1, [2, [3]]])) # N: Revealed type is "builtins.list[builti
5454

5555
class Bad: ...
5656
x: Nested[int] = [1, [2, [3]]]
57-
x = [1, [Bad()]] # E: List item 0 has incompatible type "Bad"; expected "Union[int, Nested[int]]"
57+
x = [1, [Bad()]] # E: List item 1 has incompatible type "List[Bad]"; expected "Union[int, Nested[int]]"
5858
[builtins fixtures/isinstancelist.pyi]
5959

6060
[case testRecursiveAliasGenericInferenceNested]
@@ -605,7 +605,7 @@ class NT(NamedTuple, Generic[T]):
605605
class A: ...
606606
class B(A): ...
607607

608-
nti: NT[int] = NT(key=0, value=NT(key=1, value=A())) # E: Argument "value" to "NT" has incompatible type "A"; expected "Union[int, NT[int]]"
608+
nti: NT[int] = NT(key=0, value=NT(key=1, value=A())) # E: Argument "value" to "NT" has incompatible type "NT[A]"; expected "Union[int, NT[int]]"
609609
reveal_type(nti) # N: Revealed type is "Tuple[builtins.int, Union[builtins.int, ...], fallback=__main__.NT[builtins.int]]"
610610

611611
nta: NT[A]

0 commit comments

Comments
 (0)