@@ -64,6 +64,7 @@ def __init__(self, name: str) -> None: ...
64
64
65
65
class Coroutine(Generic[_T_co, _S, _R]): ...
66
66
class Iterable(Generic[_T_co]): ...
67
+ class Iterator(Iterable[_T_co]): ...
67
68
class Mapping(Generic[_K, _V]): ...
68
69
class Match(Generic[AnyStr]): ...
69
70
class Sequence(Iterable[_T_co]): ...
@@ -86,7 +87,9 @@ def __init__(self) -> None: pass
86
87
def __repr__(self) -> str: pass
87
88
class type: ...
88
89
89
- class tuple(Sequence[T_co], Generic[T_co]): ...
90
+ class tuple(Sequence[T_co], Generic[T_co]):
91
+ def __ge__(self, __other: tuple[T_co, ...]) -> bool: pass
92
+
90
93
class dict(Mapping[KT, VT]): ...
91
94
92
95
class function: pass
@@ -105,6 +108,39 @@ def classmethod(f: T) -> T: ...
105
108
def staticmethod(f: T) -> T: ...
106
109
"""
107
110
111
+ stubtest_enum_stub = """
112
+ import sys
113
+ from typing import Any, TypeVar, Iterator
114
+
115
+ _T = TypeVar('_T')
116
+
117
+ class EnumMeta(type):
118
+ def __len__(self) -> int: pass
119
+ def __iter__(self: type[_T]) -> Iterator[_T]: pass
120
+ def __reversed__(self: type[_T]) -> Iterator[_T]: pass
121
+ def __getitem__(self: type[_T], name: str) -> _T: pass
122
+
123
+ class Enum(metaclass=EnumMeta):
124
+ def __new__(cls: type[_T], value: object) -> _T: pass
125
+ def __repr__(self) -> str: pass
126
+ def __str__(self) -> str: pass
127
+ def __format__(self, format_spec: str) -> str: pass
128
+ def __hash__(self) -> Any: pass
129
+ def __reduce_ex__(self, proto: Any) -> Any: pass
130
+ name: str
131
+ value: Any
132
+
133
+ class Flag(Enum):
134
+ def __or__(self: _T, other: _T) -> _T: pass
135
+ def __and__(self: _T, other: _T) -> _T: pass
136
+ def __xor__(self: _T, other: _T) -> _T: pass
137
+ def __invert__(self: _T) -> _T: pass
138
+ if sys.version_info >= (3, 11):
139
+ __ror__ = __or__
140
+ __rand__ = __and__
141
+ __rxor__ = __xor__
142
+ """
143
+
108
144
109
145
def run_stubtest (
110
146
stub : str , runtime : str , options : list [str ], config_file : str | None = None
@@ -114,6 +150,8 @@ def run_stubtest(
114
150
f .write (stubtest_builtins_stub )
115
151
with open ("typing.pyi" , "w" ) as f :
116
152
f .write (stubtest_typing_stub )
153
+ with open ("enum.pyi" , "w" ) as f :
154
+ f .write (stubtest_enum_stub )
117
155
with open (f"{ TEST_MODULE_NAME } .pyi" , "w" ) as f :
118
156
f .write (stub )
119
157
with open (f"{ TEST_MODULE_NAME } .py" , "w" ) as f :
@@ -954,23 +992,82 @@ def fizz(self): pass
954
992
955
993
@collect_cases
956
994
def test_enum (self ) -> Iterator [Case ]:
995
+ yield Case (stub = "import enum" , runtime = "import enum" , error = None )
957
996
yield Case (
958
997
stub = """
959
- import enum
960
998
class X(enum.Enum):
961
999
a: int
962
1000
b: str
963
1001
c: str
964
1002
""" ,
965
1003
runtime = """
966
- import enum
967
1004
class X(enum.Enum):
968
1005
a = 1
969
1006
b = "asdf"
970
1007
c = 2
971
1008
""" ,
972
1009
error = "X.c" ,
973
1010
)
1011
+ yield Case (
1012
+ stub = """
1013
+ class Flags1(enum.Flag):
1014
+ a: int
1015
+ b: int
1016
+ def foo(x: Flags1 = ...) -> None: ...
1017
+ """ ,
1018
+ runtime = """
1019
+ class Flags1(enum.Flag):
1020
+ a = 1
1021
+ b = 2
1022
+ def foo(x=Flags1.a|Flags1.b): pass
1023
+ """ ,
1024
+ error = None ,
1025
+ )
1026
+ yield Case (
1027
+ stub = """
1028
+ class Flags2(enum.Flag):
1029
+ a: int
1030
+ b: int
1031
+ def bar(x: Flags2 | None = None) -> None: ...
1032
+ """ ,
1033
+ runtime = """
1034
+ class Flags2(enum.Flag):
1035
+ a = 1
1036
+ b = 2
1037
+ def bar(x=Flags2.a|Flags2.b): pass
1038
+ """ ,
1039
+ error = "bar" ,
1040
+ )
1041
+ yield Case (
1042
+ stub = """
1043
+ class Flags3(enum.Flag):
1044
+ a: int
1045
+ b: int
1046
+ def baz(x: Flags3 | None = ...) -> None: ...
1047
+ """ ,
1048
+ runtime = """
1049
+ class Flags3(enum.Flag):
1050
+ a = 1
1051
+ b = 2
1052
+ def baz(x=Flags3(0)): pass
1053
+ """ ,
1054
+ error = None ,
1055
+ )
1056
+ yield Case (
1057
+ stub = """
1058
+ class Flags4(enum.Flag):
1059
+ a: int
1060
+ b: int
1061
+ def spam(x: Flags4 | None = None) -> None: ...
1062
+ """ ,
1063
+ runtime = """
1064
+ class Flags4(enum.Flag):
1065
+ a = 1
1066
+ b = 2
1067
+ def spam(x=Flags4(0)): pass
1068
+ """ ,
1069
+ error = "spam" ,
1070
+ )
974
1071
975
1072
@collect_cases
976
1073
def test_decorator (self ) -> Iterator [Case ]:
0 commit comments