Skip to content

Commit 48835a3

Browse files
authored
Fix stubtest mypy enum.Flag edge case (#15933)
Fix edge-case stubtest crashes when an instance of an enum.Flag that is not a member of that enum.Flag is used as a parameter default Fixes #15923. Note: the test cases I've added reproduce the crash, but only if you're using a compiled version of mypy. (Some of them only repro the crash on <=py310, but some repro it on py311+ as well.) We run stubtest tests in CI with compiled mypy, so they do repro the crash in the context of our CI.
1 parent 7141d6b commit 48835a3

File tree

2 files changed

+101
-4
lines changed

2 files changed

+101
-4
lines changed

mypy/stubtest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1553,7 +1553,7 @@ def anytype() -> mypy.types.AnyType:
15531553
value: bool | int | str
15541554
if isinstance(runtime, bytes):
15551555
value = bytes_to_human_readable_repr(runtime)
1556-
elif isinstance(runtime, enum.Enum):
1556+
elif isinstance(runtime, enum.Enum) and isinstance(runtime.name, str):
15571557
value = runtime.name
15581558
elif isinstance(runtime, (bool, int, str)):
15591559
value = runtime

mypy/test/teststubtest.py

+100-3
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def __init__(self, name: str) -> None: ...
6464
6565
class Coroutine(Generic[_T_co, _S, _R]): ...
6666
class Iterable(Generic[_T_co]): ...
67+
class Iterator(Iterable[_T_co]): ...
6768
class Mapping(Generic[_K, _V]): ...
6869
class Match(Generic[AnyStr]): ...
6970
class Sequence(Iterable[_T_co]): ...
@@ -86,7 +87,9 @@ def __init__(self) -> None: pass
8687
def __repr__(self) -> str: pass
8788
class type: ...
8889
89-
class tuple(Sequence[T_co], Generic[T_co]): ...
90+
class tuple(Sequence[T_co], Generic[T_co]):
91+
def __ge__(self, __other: tuple[T_co, ...]) -> bool: pass
92+
9093
class dict(Mapping[KT, VT]): ...
9194
9295
class function: pass
@@ -105,6 +108,39 @@ def classmethod(f: T) -> T: ...
105108
def staticmethod(f: T) -> T: ...
106109
"""
107110

111+
stubtest_enum_stub = """
112+
import sys
113+
from typing import Any, TypeVar, Iterator
114+
115+
_T = TypeVar('_T')
116+
117+
class EnumMeta(type):
118+
def __len__(self) -> int: pass
119+
def __iter__(self: type[_T]) -> Iterator[_T]: pass
120+
def __reversed__(self: type[_T]) -> Iterator[_T]: pass
121+
def __getitem__(self: type[_T], name: str) -> _T: pass
122+
123+
class Enum(metaclass=EnumMeta):
124+
def __new__(cls: type[_T], value: object) -> _T: pass
125+
def __repr__(self) -> str: pass
126+
def __str__(self) -> str: pass
127+
def __format__(self, format_spec: str) -> str: pass
128+
def __hash__(self) -> Any: pass
129+
def __reduce_ex__(self, proto: Any) -> Any: pass
130+
name: str
131+
value: Any
132+
133+
class Flag(Enum):
134+
def __or__(self: _T, other: _T) -> _T: pass
135+
def __and__(self: _T, other: _T) -> _T: pass
136+
def __xor__(self: _T, other: _T) -> _T: pass
137+
def __invert__(self: _T) -> _T: pass
138+
if sys.version_info >= (3, 11):
139+
__ror__ = __or__
140+
__rand__ = __and__
141+
__rxor__ = __xor__
142+
"""
143+
108144

109145
def run_stubtest(
110146
stub: str, runtime: str, options: list[str], config_file: str | None = None
@@ -114,6 +150,8 @@ def run_stubtest(
114150
f.write(stubtest_builtins_stub)
115151
with open("typing.pyi", "w") as f:
116152
f.write(stubtest_typing_stub)
153+
with open("enum.pyi", "w") as f:
154+
f.write(stubtest_enum_stub)
117155
with open(f"{TEST_MODULE_NAME}.pyi", "w") as f:
118156
f.write(stub)
119157
with open(f"{TEST_MODULE_NAME}.py", "w") as f:
@@ -954,23 +992,82 @@ def fizz(self): pass
954992

955993
@collect_cases
956994
def test_enum(self) -> Iterator[Case]:
995+
yield Case(stub="import enum", runtime="import enum", error=None)
957996
yield Case(
958997
stub="""
959-
import enum
960998
class X(enum.Enum):
961999
a: int
9621000
b: str
9631001
c: str
9641002
""",
9651003
runtime="""
966-
import enum
9671004
class X(enum.Enum):
9681005
a = 1
9691006
b = "asdf"
9701007
c = 2
9711008
""",
9721009
error="X.c",
9731010
)
1011+
yield Case(
1012+
stub="""
1013+
class Flags1(enum.Flag):
1014+
a: int
1015+
b: int
1016+
def foo(x: Flags1 = ...) -> None: ...
1017+
""",
1018+
runtime="""
1019+
class Flags1(enum.Flag):
1020+
a = 1
1021+
b = 2
1022+
def foo(x=Flags1.a|Flags1.b): pass
1023+
""",
1024+
error=None,
1025+
)
1026+
yield Case(
1027+
stub="""
1028+
class Flags2(enum.Flag):
1029+
a: int
1030+
b: int
1031+
def bar(x: Flags2 | None = None) -> None: ...
1032+
""",
1033+
runtime="""
1034+
class Flags2(enum.Flag):
1035+
a = 1
1036+
b = 2
1037+
def bar(x=Flags2.a|Flags2.b): pass
1038+
""",
1039+
error="bar",
1040+
)
1041+
yield Case(
1042+
stub="""
1043+
class Flags3(enum.Flag):
1044+
a: int
1045+
b: int
1046+
def baz(x: Flags3 | None = ...) -> None: ...
1047+
""",
1048+
runtime="""
1049+
class Flags3(enum.Flag):
1050+
a = 1
1051+
b = 2
1052+
def baz(x=Flags3(0)): pass
1053+
""",
1054+
error=None,
1055+
)
1056+
yield Case(
1057+
stub="""
1058+
class Flags4(enum.Flag):
1059+
a: int
1060+
b: int
1061+
def spam(x: Flags4 | None = None) -> None: ...
1062+
""",
1063+
runtime="""
1064+
class Flags4(enum.Flag):
1065+
a = 1
1066+
b = 2
1067+
def spam(x=Flags4(0)): pass
1068+
""",
1069+
error="spam",
1070+
)
9741071

9751072
@collect_cases
9761073
def test_decorator(self) -> Iterator[Case]:

0 commit comments

Comments
 (0)