diff --git a/mypy/join.py b/mypy/join.py index fcfc6cbaa0e7..65cc3bef66a4 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -8,7 +8,7 @@ import mypy.typeops from mypy.expandtype import expand_type from mypy.maptype import map_instance_to_supertype -from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT, VARIANCE_NOT_READY +from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT, VARIANCE_NOT_READY, TypeInfo from mypy.state import state from mypy.subtypes import ( SubtypeContext, @@ -168,9 +168,20 @@ def join_instances_via_supertype(self, t: Instance, s: Instance) -> ProperType: # Compute the "best" supertype of t when joined with s. # The definition of "best" may evolve; for now it is the one with # the longest MRO. Ties are broken by using the earlier base. - best: ProperType | None = None + + # Go over both sets of bases in case there's an explicit Protocol base. This is important + # to ensure commutativity of join (although in cases where both classes have relevant + # Protocol bases this maybe might still not be commutative) + base_types: dict[TypeInfo, None] = {} # dict to deduplicate but preserve order for base in t.type.bases: - mapped = map_instance_to_supertype(t, base.type) + base_types[base.type] = None + for base in s.type.bases: + if base.type.is_protocol and is_subtype(t, base): + base_types[base.type] = None + + best: ProperType | None = None + for base_type in base_types: + mapped = map_instance_to_supertype(t, base_type) res = self.join_instances(mapped, s) if best is None or is_better(res, best): best = res @@ -662,6 +673,10 @@ def is_better(t: Type, s: Type) -> bool: if isinstance(t, Instance): if not isinstance(s, Instance): return True + if t.type.is_protocol != s.type.is_protocol: + if t.type.fullname != "builtins.object" and s.type.fullname != "builtins.object": + # mro of protocol is not really relevant + return not t.type.is_protocol # Use len(mro) as a proxy for the better choice. if len(t.type.mro) > len(s.type.mro): return True diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index 381f73ed9862..4cf24ef9cb6c 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -3888,6 +3888,53 @@ def a4(x: List[str], y: List[Never]) -> None: z1[1].append("asdf") # E: "object" has no attribute "append" [builtins fixtures/dict.pyi] + +[case testDeterminismCommutativityWithJoinInvolvingProtocolBaseAndPromotableType] +# flags: --python-version 3.11 +# Regression test for https://github.com/python/mypy/issues/16979#issuecomment-1982246306 +from __future__ import annotations + +from typing import Any, Generic, Protocol, TypeVar, overload, cast +from typing_extensions import Never + +T = TypeVar("T") +U = TypeVar("U") + +class _SupportsCompare(Protocol): + def __lt__(self, other: Any, /) -> bool: + return True + +class Comparable(_SupportsCompare): + pass + +comparable: Comparable = Comparable() + +from typing import _promote + +class floatlike: + def __lt__(self, other: floatlike, /) -> bool: ... + +@_promote(floatlike) +class intlike: + def __lt__(self, other: intlike, /) -> bool: ... + + +class A(Generic[T, U]): + @overload + def __init__(self: A[T, T], a: T, b: T, /) -> None: ... # type: ignore[overload-overlap] + @overload + def __init__(self: A[T, U], a: T, b: U, /) -> Never: ... + def __init__(self, *a) -> None: ... + +def join(a: T, b: T) -> T: ... + +reveal_type(join(intlike(), comparable)) # N: Revealed type is "__main__._SupportsCompare" +reveal_type(join(comparable, intlike())) # N: Revealed type is "__main__._SupportsCompare" +reveal_type(A(intlike(), comparable)) # N: Revealed type is "__main__.A[__main__._SupportsCompare, __main__._SupportsCompare]" +reveal_type(A(comparable, intlike())) # N: Revealed type is "__main__.A[__main__._SupportsCompare, __main__._SupportsCompare]" +[builtins fixtures/tuple.pyi] +[typing fixtures/typing-medium.pyi] + [case testTupleJoinFallbackInference] foo = [ (1, ("a", "b")), diff --git a/test-data/unit/check-protocols.test b/test-data/unit/check-protocols.test index 5e34d5223907..934f48a5e9c3 100644 --- a/test-data/unit/check-protocols.test +++ b/test-data/unit/check-protocols.test @@ -4461,6 +4461,30 @@ f2(a4) # E: Argument 1 to "f2" has incompatible type "A4"; expected "P2" \ # N: foo: expected setter type "C1", got "str" [builtins fixtures/property.pyi] + +[case testExplicitProtocolJoinPreference] +from typing import Protocol, TypeVar + +T = TypeVar("T") + +class Proto1(Protocol): + def foo(self) -> int: ... +class Proto2(Proto1): + def bar(self) -> str: ... +class Proto3(Proto2): + def baz(self) -> str: ... + +class Base: ... + +class A(Base, Proto3): ... +class B(Base, Proto3): ... + +def join(a: T, b: T) -> T: ... + +def main(a: A, b: B) -> None: + reveal_type(join(a, b)) # N: Revealed type is "__main__.Proto3" + reveal_type(join(b, a)) # N: Revealed type is "__main__.Proto3" + [case testProtocolImplementationWithDescriptors] from typing import Any, Protocol