Skip to content

Fix nondeterministic type checking by making join with explicit Protocol and type promotion commute #18402

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import mypy.typeops
from mypy.expandtype import expand_type
from mypy.maptype import map_instance_to_supertype
from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT, VARIANCE_NOT_READY
from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT, VARIANCE_NOT_READY, TypeInfo
from mypy.state import state
from mypy.subtypes import (
SubtypeContext,
Expand Down Expand Up @@ -168,9 +168,20 @@ def join_instances_via_supertype(self, t: Instance, s: Instance) -> ProperType:
# Compute the "best" supertype of t when joined with s.
# The definition of "best" may evolve; for now it is the one with
# the longest MRO. Ties are broken by using the earlier base.
best: ProperType | None = None

# Go over both sets of bases in case there's an explicit Protocol base. This is important
# to ensure commutativity of join (although in cases where both classes have relevant
# Protocol bases this maybe might still not be commutative)
base_types: dict[TypeInfo, None] = {} # dict to deduplicate but preserve order
for base in t.type.bases:
mapped = map_instance_to_supertype(t, base.type)
base_types[base.type] = None
for base in s.type.bases:
if base.type.is_protocol and is_subtype(t, base):
base_types[base.type] = None

best: ProperType | None = None
for base_type in base_types:
mapped = map_instance_to_supertype(t, base_type)
res = self.join_instances(mapped, s)
if best is None or is_better(res, best):
best = res
Expand Down Expand Up @@ -662,6 +673,10 @@ def is_better(t: Type, s: Type) -> bool:
if isinstance(t, Instance):
if not isinstance(s, Instance):
return True
if t.type.is_protocol != s.type.is_protocol:
if t.type.fullname != "builtins.object" and s.type.fullname != "builtins.object":
# mro of protocol is not really relevant
return not t.type.is_protocol
# Use len(mro) as a proxy for the better choice.
if len(t.type.mro) > len(s.type.mro):
return True
Expand Down
47 changes: 47 additions & 0 deletions test-data/unit/check-inference.test
Original file line number Diff line number Diff line change
Expand Up @@ -3888,6 +3888,53 @@ def a4(x: List[str], y: List[Never]) -> None:
z1[1].append("asdf") # E: "object" has no attribute "append"
[builtins fixtures/dict.pyi]


[case testDeterminismCommutativityWithJoinInvolvingProtocolBaseAndPromotableType]
# flags: --python-version 3.11
# Regression test for https://github.com/python/mypy/issues/16979#issuecomment-1982246306
from __future__ import annotations

from typing import Any, Generic, Protocol, TypeVar, overload, cast
from typing_extensions import Never

T = TypeVar("T")
U = TypeVar("U")

class _SupportsCompare(Protocol):
def __lt__(self, other: Any, /) -> bool:
return True

class Comparable(_SupportsCompare):
pass

comparable: Comparable = Comparable()

from typing import _promote

class floatlike:
def __lt__(self, other: floatlike, /) -> bool: ...

@_promote(floatlike)
class intlike:
def __lt__(self, other: intlike, /) -> bool: ...


class A(Generic[T, U]):
@overload
def __init__(self: A[T, T], a: T, b: T, /) -> None: ... # type: ignore[overload-overlap]
@overload
def __init__(self: A[T, U], a: T, b: U, /) -> Never: ...
def __init__(self, *a) -> None: ...

def join(a: T, b: T) -> T: ...

reveal_type(join(intlike(), comparable)) # N: Revealed type is "__main__._SupportsCompare"
reveal_type(join(comparable, intlike())) # N: Revealed type is "__main__._SupportsCompare"
reveal_type(A(intlike(), comparable)) # N: Revealed type is "__main__.A[__main__._SupportsCompare, __main__._SupportsCompare]"
reveal_type(A(comparable, intlike())) # N: Revealed type is "__main__.A[__main__._SupportsCompare, __main__._SupportsCompare]"
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-medium.pyi]

[case testTupleJoinFallbackInference]
foo = [
(1, ("a", "b")),
Expand Down
24 changes: 24 additions & 0 deletions test-data/unit/check-protocols.test
Original file line number Diff line number Diff line change
Expand Up @@ -4461,6 +4461,30 @@ f2(a4) # E: Argument 1 to "f2" has incompatible type "A4"; expected "P2" \
# N: foo: expected setter type "C1", got "str"
[builtins fixtures/property.pyi]


[case testExplicitProtocolJoinPreference]
from typing import Protocol, TypeVar

T = TypeVar("T")

class Proto1(Protocol):
def foo(self) -> int: ...
class Proto2(Proto1):
def bar(self) -> str: ...
class Proto3(Proto2):
def baz(self) -> str: ...

class Base: ...

class A(Base, Proto3): ...
class B(Base, Proto3): ...

def join(a: T, b: T) -> T: ...

def main(a: A, b: B) -> None:
reveal_type(join(a, b)) # N: Revealed type is "__main__.Proto3"
reveal_type(join(b, a)) # N: Revealed type is "__main__.Proto3"

[case testProtocolImplementationWithDescriptors]
from typing import Any, Protocol

Expand Down