Skip to content

Commit a4fc868

Browse files
committed
Better handling of generic aliases
1 parent 8e28720 commit a4fc868

File tree

2 files changed

+118
-49
lines changed

2 files changed

+118
-49
lines changed

astroid/brain/brain_typing.py

Lines changed: 90 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,51 @@ class {0}(metaclass=Meta):
4343
"""
4444
TYPING_MEMBERS = set(typing.__all__)
4545

46+
TYPING_ALIAS = frozenset(
47+
(
48+
"typing.Hashable",
49+
"typing.Awaitable",
50+
"typing.Coroutine",
51+
"typing.AsyncIterable",
52+
"typing.AsyncIterator",
53+
"typing.Iterable",
54+
"typing.Iterator",
55+
"typing.Reversible",
56+
"typing.Sized",
57+
"typing.Container",
58+
"typing.Collection",
59+
"typing.Callable",
60+
"typing.AbstractSet",
61+
"typing.MutableSet",
62+
"typing.Mapping",
63+
"typing.MutableMapping",
64+
"typing.Sequence",
65+
"typing.MutableSequence",
66+
"typing.ByteString",
67+
"typing.Tuple",
68+
"typing.List",
69+
"typing.Deque",
70+
"typing.Set",
71+
"typing.FrozenSet",
72+
"typing.MappingView",
73+
"typing.KeysView",
74+
"typing.ItemsView",
75+
"typing.ValuesView",
76+
"typing.ContextManager",
77+
"typing.AsyncContextManager",
78+
"typing.Dict",
79+
"typing.DefaultDict",
80+
"typing.OrderedDict",
81+
"typing.Counter",
82+
"typing.ChainMap",
83+
"typing.Generator",
84+
"typing.AsyncGenerator",
85+
"typing.Type",
86+
"typing.Pattern",
87+
"typing.Match",
88+
)
89+
)
90+
4691

4792
def looks_like_typing_typevar_or_newtype(node):
4893
func = node.func
@@ -88,7 +133,13 @@ def infer_typing_attr(node, context=None):
88133
except InferenceError as exc:
89134
raise UseInferenceDefault from exc
90135

91-
if not value.qname().startswith("typing."):
136+
if (
137+
not value.qname().startswith("typing.")
138+
or PY37
139+
and value.qname() in TYPING_ALIAS
140+
):
141+
# If typing subscript belongs to an alias
142+
# (PY37+) handle it separately later.
92143
raise UseInferenceDefault
93144

94145
node = extract_node(TYPING_TYPE_TEMPLATE.format(value.qname().split(".")[-1]))
@@ -161,8 +212,6 @@ def full_raiser(origin_func, attr, *args, **kwargs):
161212
else:
162213
return origin_func(attr, *args, **kwargs)
163214

164-
if not isinstance(node, nodes.ClassDef):
165-
raise TypeError("The parameter type should be ClassDef")
166215
try:
167216
node.getattr("__class_getitem__")
168217
# If we are here, then we are sure to modify object that do have __class_getitem__ method (which origin is one the
@@ -179,52 +228,51 @@ def infer_typing_alias(
179228
) -> typing.Optional[node_classes.NodeNG]:
180229
"""
181230
Infers the call to _alias function
231+
Insert ClassDef with same name as aliased class
232+
in mro to simulate _GenericAlias.
182233
183234
:param node: call node
184235
:param context: inference context
185236
"""
237+
if (
238+
not isinstance(node.parent, nodes.Assign)
239+
or not len(node.parent.targets) == 1
240+
or not isinstance(node.parent.targets[0], nodes.AssignName)
241+
):
242+
return None
186243
res = next(node.args[0].infer(context=ctx))
244+
assign_name = node.parent.targets[0]
187245

246+
class_def = nodes.ClassDef(
247+
name=assign_name.name,
248+
lineno=assign_name.lineno,
249+
col_offset=assign_name.col_offset,
250+
parent=node.parent,
251+
)
188252
if res != astroid.Uninferable and isinstance(res, nodes.ClassDef):
189-
if not PY39:
190-
# Here the node is a typing object which is an alias toward
191-
# the corresponding object of collection.abc module.
192-
# Before python3.9 there is no subscript allowed for any of the collections.abc objects.
193-
# The subscript ability is given through the typing._GenericAlias class
194-
# which is the metaclass of the typing object but not the metaclass of the inferred
195-
# collections.abc object.
196-
# Thus we fake subscript ability of the collections.abc object
197-
# by mocking the existence of a __class_getitem__ method.
198-
# We can not add `__getitem__` method in the metaclass of the object because
199-
# the metaclass is shared by subscriptable and not subscriptable object
200-
maybe_type_var = node.args[1]
201-
if not (
202-
isinstance(maybe_type_var, node_classes.Tuple)
203-
and not maybe_type_var.elts
204-
):
205-
# The typing object is subscriptable if the second argument of the _alias function
206-
# is a TypeVar or a tuple of TypeVar. We could check the type of the second argument but
207-
# it appears that in the typing module the second argument is only TypeVar or a tuple of TypeVar or empty tuple.
208-
# This last value means the type is not Generic and thus cannot be subscriptable
209-
func_to_add = astroid.extract_node(CLASS_GETITEM_TEMPLATE)
210-
res.locals["__class_getitem__"] = [func_to_add]
211-
else:
212-
# If we are here, then we are sure to modify object that do have __class_getitem__ method (which origin is one the
213-
# protocol defined in collections module) whereas the typing module consider it should not
214-
# We do not want __class_getitem__ to be found in the classdef
215-
_forbid_class_getitem_access(res)
216-
else:
217-
# Within python3.9 discrepencies exist between some collections.abc containers that are subscriptable whereas
218-
# corresponding containers in the typing module are not! This is the case at least for ByteString.
219-
# It is far more to complex and dangerous to try to remove __class_getitem__ method from all the ancestors of the
220-
# current class. Instead we raise an AttributeInferenceError if we try to access it.
221-
maybe_type_var = node.args[1]
222-
if isinstance(maybe_type_var, nodes.Const) and maybe_type_var.value == 0:
223-
# Starting with Python39 the _alias function is in fact instantiation of _SpecialGenericAlias class.
224-
# Thus the type is not Generic if the second argument of the call is equal to zero
225-
_forbid_class_getitem_access(res)
226-
return iter([res])
227-
return iter([astroid.Uninferable])
253+
# Only add `res` as base if it's a `ClassDef`
254+
# This isn't the case for `typing.Pattern` and `typing.Match`
255+
class_def.postinit(bases=[res], body=[], decorators=None)
256+
257+
maybe_type_var = node.args[1]
258+
if (
259+
not PY39
260+
and not (
261+
isinstance(maybe_type_var, node_classes.Tuple) and not maybe_type_var.elts
262+
)
263+
or PY39
264+
and isinstance(maybe_type_var, nodes.Const)
265+
and maybe_type_var.value > 0
266+
):
267+
# If typing alias is subscriptable, add `__class_getitem__` to ClassDef
268+
func_to_add = astroid.extract_node(CLASS_GETITEM_TEMPLATE)
269+
class_def.locals["__class_getitem__"] = [func_to_add]
270+
else:
271+
# If not, make sure that `__class_getitem__` access is forbidden.
272+
# This is an issue in cases where the aliased class implements it,
273+
# but the typing alias doesn't. E.g. `typing.ByteString` for PY39+
274+
_forbid_class_getitem_access(class_def)
275+
return iter([class_def])
228276

229277

230278
MANAGER.register_transform(

tests/unittest_brain.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,7 +1167,7 @@ class Derived(collections.abc.Iterator[int]):
11671167
],
11681168
)
11691169

1170-
@test_utils.require_version(maxver="3.8")
1170+
@test_utils.require_version(maxver="3.9")
11711171
def test_collections_object_not_yet_subscriptable_2(self):
11721172
"""Before python39 Iterator in the collection.abc module is not subscriptable"""
11731173
node = builder.extract_node(
@@ -1194,6 +1194,28 @@ def test_collections_object_subscriptable_3(self):
11941194
inferred.getattr("__class_getitem__")[0], nodes.FunctionDef
11951195
)
11961196

1197+
@test_utils.require_version(minver="3.9")
1198+
def test_collections_object_subscriptable_4(self):
1199+
"""Multiple inheritance with subscriptable collection class"""
1200+
node = builder.extract_node(
1201+
"""
1202+
import collections.abc
1203+
class Derived(collections.abc.Hashable, collections.abc.Iterator[int]):
1204+
pass
1205+
"""
1206+
)
1207+
inferred = next(node.infer())
1208+
assertEqualMro(
1209+
inferred,
1210+
[
1211+
"Derived",
1212+
"Hashable",
1213+
"Iterator",
1214+
"Iterable",
1215+
"object",
1216+
],
1217+
)
1218+
11971219

11981220
@test_utils.require_version("3.6")
11991221
class TypingBrain(unittest.TestCase):
@@ -1398,12 +1420,12 @@ class Derived1(MutableSet[T]):
13981420
"""
13991421
)
14001422
inferred = next(node.infer())
1401-
check_metaclass_is_abc(inferred)
14021423
assertEqualMro(
14031424
inferred,
14041425
[
14051426
"Derived1",
14061427
"MutableSet",
1428+
"MutableSet",
14071429
"Set",
14081430
"Collection",
14091431
"Sized",
@@ -1429,19 +1451,18 @@ class Derived2(typing.OrderedDict[int, str]):
14291451
"""
14301452
)
14311453
inferred = next(node.infer())
1432-
# OrderedDict has no metaclass because it
1433-
# inherits from dict which is C coded
1434-
self.assertIsNone(inferred.metaclass())
14351454
assertEqualMro(
14361455
inferred,
14371456
[
14381457
"Derived2",
14391458
"OrderedDict",
1459+
"OrderedDict",
14401460
"dict",
14411461
"object",
14421462
],
14431463
)
14441464

1465+
@test_utils.require_version(minver="3.7")
14451466
def test_typing_object_not_subscriptable(self):
14461467
"""Hashable is not subscriptable"""
14471468
wrong_node = builder.extract_node(
@@ -1459,10 +1480,10 @@ def test_typing_object_not_subscriptable(self):
14591480
"""
14601481
)
14611482
inferred = next(right_node.infer())
1462-
check_metaclass_is_abc(inferred)
14631483
assertEqualMro(
14641484
inferred,
14651485
[
1486+
"Hashable",
14661487
"Hashable",
14671488
"object",
14681489
],
@@ -1480,10 +1501,10 @@ def test_typing_object_subscriptable(self):
14801501
"""
14811502
)
14821503
inferred = next(right_node.infer())
1483-
check_metaclass_is_abc(inferred)
14841504
assertEqualMro(
14851505
inferred,
14861506
[
1507+
"MutableSet",
14871508
"MutableSet",
14881509
"Set",
14891510
"Collection",

0 commit comments

Comments
 (0)