Skip to content

Commit aa7733a

Browse files
authored
Don't use equality to narrow when value is IntEnum/StrEnum (#17866)
IntEnum/StrEnum values compare equal to the corresponding int/str values, which breaks the logic we use for narrowing based on equality to a literal value. Special case IntEnum/StrEnum to avoid the incorrect behavior. Fix #17860.
1 parent 329e38e commit aa7733a

File tree

2 files changed

+84
-0
lines changed

2 files changed

+84
-0
lines changed

mypy/typeops.py

+8
Original file line numberDiff line numberDiff line change
@@ -1022,6 +1022,14 @@ def custom_special_method(typ: Type, name: str, check_all: bool = False) -> bool
10221022
"""
10231023
typ = get_proper_type(typ)
10241024
if isinstance(typ, Instance):
1025+
if (
1026+
typ.type.is_enum
1027+
and name in ("__eq__", "__ne__")
1028+
and any(base.fullname in ("enum.IntEnum", "enum.StrEnum") for base in typ.type.mro)
1029+
):
1030+
# IntEnum and StrEnum values have non-straightfoward equality, so treat them
1031+
# as if they had custom __eq__ and __ne__
1032+
return True
10251033
method = typ.type.get(name)
10261034
if method and isinstance(method.node, (SYMBOL_FUNCBASE_TYPES, Decorator, Var)):
10271035
if method.node.info:

test-data/unit/check-narrowing.test

+76
Original file line numberDiff line numberDiff line change
@@ -2130,3 +2130,79 @@ else:
21302130

21312131
[typing fixtures/typing-medium.pyi]
21322132
[builtins fixtures/ops.pyi]
2133+
2134+
[case testNarrowingWithIntEnum]
2135+
# mypy: strict-equality
2136+
from __future__ import annotations
2137+
from typing import Any
2138+
from enum import IntEnum, StrEnum
2139+
2140+
class IE(IntEnum):
2141+
X = 1
2142+
Y = 2
2143+
2144+
def f1(x: int) -> None:
2145+
if x == IE.X:
2146+
reveal_type(x) # N: Revealed type is "builtins.int"
2147+
else:
2148+
reveal_type(x) # N: Revealed type is "builtins.int"
2149+
if x != IE.X:
2150+
reveal_type(x) # N: Revealed type is "builtins.int"
2151+
else:
2152+
reveal_type(x) # N: Revealed type is "builtins.int"
2153+
2154+
def f2(x: IE) -> None:
2155+
if x == 1:
2156+
reveal_type(x) # N: Revealed type is "__main__.IE"
2157+
else:
2158+
reveal_type(x) # N: Revealed type is "__main__.IE"
2159+
2160+
def f3(x: object) -> None:
2161+
if x == IE.X:
2162+
reveal_type(x) # N: Revealed type is "builtins.object"
2163+
else:
2164+
reveal_type(x) # N: Revealed type is "builtins.object"
2165+
2166+
def f4(x: int | Any) -> None:
2167+
if x == IE.X:
2168+
reveal_type(x) # N: Revealed type is "Union[builtins.int, Any]"
2169+
else:
2170+
reveal_type(x) # N: Revealed type is "Union[builtins.int, Any]"
2171+
2172+
def f5(x: int) -> None:
2173+
if x is IE.X:
2174+
reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X]"
2175+
else:
2176+
reveal_type(x) # N: Revealed type is "builtins.int"
2177+
if x is not IE.X:
2178+
reveal_type(x) # N: Revealed type is "builtins.int"
2179+
else:
2180+
reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X]"
2181+
[builtins fixtures/primitives.pyi]
2182+
2183+
[case testNarrowingWithStrEnum]
2184+
# mypy: strict-equality
2185+
from enum import StrEnum
2186+
2187+
class SE(StrEnum):
2188+
A = 'a'
2189+
B = 'b'
2190+
2191+
def f1(x: str) -> None:
2192+
if x == SE.A:
2193+
reveal_type(x) # N: Revealed type is "builtins.str"
2194+
else:
2195+
reveal_type(x) # N: Revealed type is "builtins.str"
2196+
2197+
def f2(x: SE) -> None:
2198+
if x == 'a':
2199+
reveal_type(x) # N: Revealed type is "__main__.SE"
2200+
else:
2201+
reveal_type(x) # N: Revealed type is "__main__.SE"
2202+
2203+
def f3(x: object) -> None:
2204+
if x == SE.A:
2205+
reveal_type(x) # N: Revealed type is "builtins.object"
2206+
else:
2207+
reveal_type(x) # N: Revealed type is "builtins.object"
2208+
[builtins fixtures/primitives.pyi]

0 commit comments

Comments
 (0)