Skip to content

Commit 0aa2caf

Browse files
authored
Fix recently added enum value type prediction (#10057)
In #9443, some code was added to predict the type of enum values where it is not explicitly when all enum members have the same type. However, it didn't consider that subclasses of Enum that have a custom __new__ implementation may use any type for the enum value (typically it would use only one of their parameters instead of a whole tuple that is specified in the definition of the member). Fix this by avoiding to guess the enum value type in classes that implement __new__. In addition, the added code was buggy in that it didn't only consider class attributes as enum members, but also instance attributes assigned to self.* in __init__. Fix this by ignoring implicit nodes when checking the enum members. Fixes #10000. Co-authored-by: Kevin Wolf <[email protected]>
1 parent 98c2c12 commit 0aa2caf

File tree

2 files changed

+88
-3
lines changed

2 files changed

+88
-3
lines changed

mypy/plugins/enums.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import mypy.plugin # To avoid circular imports.
1717
from mypy.types import Type, Instance, LiteralType, CallableType, ProperType, get_proper_type
18+
from mypy.nodes import TypeInfo
1819

1920
# Note: 'enum.EnumMeta' is deliberately excluded from this list. Classes that directly use
2021
# enum.EnumMeta do not necessarily automatically have the 'name' and 'value' attributes.
@@ -103,6 +104,17 @@ def _infer_value_type_with_auto_fallback(
103104
return ctx.default_attr_type
104105

105106

107+
def _implements_new(info: TypeInfo) -> bool:
108+
"""Check whether __new__ comes from enum.Enum or was implemented in a
109+
subclass. In the latter case, we must infer Any as long as mypy can't infer
110+
the type of _value_ from assignments in __new__.
111+
"""
112+
type_with_new = _first(ti for ti in info.mro if ti.names.get('__new__'))
113+
if type_with_new is None:
114+
return False
115+
return type_with_new.fullname != 'enum.Enum'
116+
117+
106118
def enum_value_callback(ctx: 'mypy.plugin.AttributeContext') -> Type:
107119
"""This plugin refines the 'value' attribute in enums to refer to
108120
the original underlying value. For example, suppose we have the
@@ -135,12 +147,22 @@ class SomeEnum:
135147
# The value-type is still known.
136148
if isinstance(ctx.type, Instance):
137149
info = ctx.type.type
150+
151+
# As long as mypy doesn't understand attribute creation in __new__,
152+
# there is no way to predict the value type if the enum class has a
153+
# custom implementation
154+
if _implements_new(info):
155+
return ctx.default_attr_type
156+
138157
stnodes = (info.get(name) for name in info.names)
139-
# Enums _can_ have methods.
140-
# Omit methods for our value inference.
158+
159+
# Enums _can_ have methods and instance attributes.
160+
# Omit methods and attributes created by assigning to self.*
161+
# for our value inference.
141162
node_types = (
142163
get_proper_type(n.type) if n else None
143-
for n in stnodes)
164+
for n in stnodes
165+
if n is None or not n.implicit)
144166
proper_types = (
145167
_infer_value_type_with_auto_fallback(ctx, t)
146168
for t in node_types
@@ -158,6 +180,13 @@ class SomeEnum:
158180

159181
assert isinstance(ctx.type, Instance)
160182
info = ctx.type.type
183+
184+
# As long as mypy doesn't understand attribute creation in __new__,
185+
# there is no way to predict the value type if the enum class has a
186+
# custom implementation
187+
if _implements_new(info):
188+
return ctx.default_attr_type
189+
161190
stnode = info.get(enum_field_name)
162191
if stnode is None:
163192
return ctx.default_attr_type

test-data/unit/check-enum.test

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1243,3 +1243,59 @@ class Comparator(enum.Enum):
12431243

12441244
reveal_type(Comparator.__foo__) # N: Revealed type is 'builtins.dict[builtins.str, builtins.int]'
12451245
[builtins fixtures/dict.pyi]
1246+
1247+
[case testEnumWithInstanceAttributes]
1248+
from enum import Enum
1249+
class Foo(Enum):
1250+
def __init__(self, value: int) -> None:
1251+
self.foo = "bar"
1252+
A = 1
1253+
B = 2
1254+
1255+
a = Foo.A
1256+
reveal_type(a.value) # N: Revealed type is 'builtins.int'
1257+
reveal_type(a._value_) # N: Revealed type is 'builtins.int'
1258+
1259+
[case testNewSetsUnexpectedValueType]
1260+
from enum import Enum
1261+
1262+
class bytes:
1263+
def __new__(cls): pass
1264+
1265+
class Foo(bytes, Enum):
1266+
def __new__(cls, value: int) -> 'Foo':
1267+
obj = bytes.__new__(cls)
1268+
obj._value_ = "Number %d" % value
1269+
return obj
1270+
A = 1
1271+
B = 2
1272+
1273+
a = Foo.A
1274+
reveal_type(a.value) # N: Revealed type is 'Any'
1275+
reveal_type(a._value_) # N: Revealed type is 'Any'
1276+
[builtins fixtures/__new__.pyi]
1277+
[builtins fixtures/primitives.pyi]
1278+
[typing fixtures/typing-medium.pyi]
1279+
1280+
[case testValueTypeWithNewInParentClass]
1281+
from enum import Enum
1282+
1283+
class bytes:
1284+
def __new__(cls): pass
1285+
1286+
class Foo(bytes, Enum):
1287+
def __new__(cls, value: int) -> 'Foo':
1288+
obj = bytes.__new__(cls)
1289+
obj._value_ = "Number %d" % value
1290+
return obj
1291+
1292+
class Bar(Foo):
1293+
A = 1
1294+
B = 2
1295+
1296+
a = Bar.A
1297+
reveal_type(a.value) # N: Revealed type is 'Any'
1298+
reveal_type(a._value_) # N: Revealed type is 'Any'
1299+
[builtins fixtures/__new__.pyi]
1300+
[builtins fixtures/primitives.pyi]
1301+
[typing fixtures/typing-medium.pyi]

0 commit comments

Comments
 (0)