Skip to content

Commit 661adb7

Browse files
ilevkivskyiAlexWaygood
authored andcommitted
Fix crash on strict-equality with recursive types (#16483)
Fixes #16473 Potentially we can turn this helper function into a proper visitor, but I don't think it is worth it as of right now. --------- Co-authored-by: Alex Waygood <[email protected]>
1 parent 6c8e0cc commit 661adb7

File tree

4 files changed

+60
-6
lines changed

4 files changed

+60
-6
lines changed

mypy/checkexpr.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -3617,8 +3617,9 @@ def dangerous_comparison(
36173617
self,
36183618
left: Type,
36193619
right: Type,
3620-
original_container: Type | None = None,
36213620
*,
3621+
original_container: Type | None = None,
3622+
seen_types: set[tuple[Type, Type]] | None = None,
36223623
prefer_literal: bool = True,
36233624
) -> bool:
36243625
"""Check for dangerous non-overlapping comparisons like 42 == 'no'.
@@ -3639,6 +3640,12 @@ def dangerous_comparison(
36393640
if not self.chk.options.strict_equality:
36403641
return False
36413642

3643+
if seen_types is None:
3644+
seen_types = set()
3645+
if (left, right) in seen_types:
3646+
return False
3647+
seen_types.add((left, right))
3648+
36423649
left, right = get_proper_types((left, right))
36433650

36443651
# We suppress the error if there is a custom __eq__() method on either
@@ -3694,17 +3701,21 @@ def dangerous_comparison(
36943701
abstract_set = self.chk.lookup_typeinfo("typing.AbstractSet")
36953702
left = map_instance_to_supertype(left, abstract_set)
36963703
right = map_instance_to_supertype(right, abstract_set)
3697-
return self.dangerous_comparison(left.args[0], right.args[0])
3704+
return self.dangerous_comparison(
3705+
left.args[0], right.args[0], seen_types=seen_types
3706+
)
36983707
elif left.type.has_base("typing.Mapping") and right.type.has_base("typing.Mapping"):
36993708
# Similar to above: Mapping ignores the classes, it just compares items.
37003709
abstract_map = self.chk.lookup_typeinfo("typing.Mapping")
37013710
left = map_instance_to_supertype(left, abstract_map)
37023711
right = map_instance_to_supertype(right, abstract_map)
37033712
return self.dangerous_comparison(
3704-
left.args[0], right.args[0]
3705-
) or self.dangerous_comparison(left.args[1], right.args[1])
3713+
left.args[0], right.args[0], seen_types=seen_types
3714+
) or self.dangerous_comparison(left.args[1], right.args[1], seen_types=seen_types)
37063715
elif left_name in ("builtins.list", "builtins.tuple") and right_name == left_name:
3707-
return self.dangerous_comparison(left.args[0], right.args[0])
3716+
return self.dangerous_comparison(
3717+
left.args[0], right.args[0], seen_types=seen_types
3718+
)
37083719
elif left_name in OVERLAPPING_BYTES_ALLOWLIST and right_name in (
37093720
OVERLAPPING_BYTES_ALLOWLIST
37103721
):

mypy/meet.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ def is_overlapping_types(
262262
ignore_promotions: bool = False,
263263
prohibit_none_typevar_overlap: bool = False,
264264
ignore_uninhabited: bool = False,
265+
seen_types: set[tuple[Type, Type]] | None = None,
265266
) -> bool:
266267
"""Can a value of type 'left' also be of type 'right' or vice-versa?
267268
@@ -275,18 +276,27 @@ def is_overlapping_types(
275276
# A type guard forces the new type even if it doesn't overlap the old.
276277
return True
277278

279+
if seen_types is None:
280+
seen_types = set()
281+
if (left, right) in seen_types:
282+
return True
283+
if isinstance(left, TypeAliasType) and isinstance(right, TypeAliasType):
284+
seen_types.add((left, right))
285+
278286
left, right = get_proper_types((left, right))
279287

280288
def _is_overlapping_types(left: Type, right: Type) -> bool:
281289
"""Encode the kind of overlapping check to perform.
282290
283-
This function mostly exists so we don't have to repeat keyword arguments everywhere."""
291+
This function mostly exists, so we don't have to repeat keyword arguments everywhere.
292+
"""
284293
return is_overlapping_types(
285294
left,
286295
right,
287296
ignore_promotions=ignore_promotions,
288297
prohibit_none_typevar_overlap=prohibit_none_typevar_overlap,
289298
ignore_uninhabited=ignore_uninhabited,
299+
seen_types=seen_types.copy(),
290300
)
291301

292302
# We should never encounter this type.

test-data/unit/check-expressions.test

+32
Original file line numberDiff line numberDiff line change
@@ -2378,6 +2378,38 @@ assert a == b
23782378
[builtins fixtures/dict.pyi]
23792379
[typing fixtures/typing-full.pyi]
23802380

2381+
[case testStrictEqualityWithRecursiveMapTypes]
2382+
# flags: --strict-equality
2383+
from typing import Dict
2384+
2385+
R = Dict[str, R]
2386+
2387+
a: R
2388+
b: R
2389+
assert a == b
2390+
2391+
R2 = Dict[int, R2]
2392+
c: R2
2393+
assert a == c # E: Non-overlapping equality check (left operand type: "Dict[str, R]", right operand type: "Dict[int, R2]")
2394+
[builtins fixtures/dict.pyi]
2395+
[typing fixtures/typing-full.pyi]
2396+
2397+
[case testStrictEqualityWithRecursiveListTypes]
2398+
# flags: --strict-equality
2399+
from typing import List, Union
2400+
2401+
R = List[Union[str, R]]
2402+
2403+
a: R
2404+
b: R
2405+
assert a == b
2406+
2407+
R2 = List[Union[int, R2]]
2408+
c: R2
2409+
assert a == c
2410+
[builtins fixtures/list.pyi]
2411+
[typing fixtures/typing-full.pyi]
2412+
23812413
[case testUnimportedHintAny]
23822414
def f(x: Any) -> None: # E: Name "Any" is not defined \
23832415
# N: Did you forget to import it from "typing"? (Suggestion: "from typing import Any")

test-data/unit/fixtures/list.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ T = TypeVar('T')
66

77
class object:
88
def __init__(self) -> None: pass
9+
def __eq__(self, other: object) -> bool: pass
910

1011
class type: pass
1112
class ellipsis: pass

0 commit comments

Comments
 (0)