Skip to content

Commit 31a731a

Browse files
authored
Better handling of generic aliases (#923)
* Better handling of generic aliases * Use qname for tests mro * Add inference for re.Pattern and re.Match * Add comments
1 parent 76172d4 commit 31a731a

File tree

4 files changed

+325
-93
lines changed

4 files changed

+325
-93
lines changed

astroid/brain/brain_builtin_inference.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
"""Astroid hooks for various builtins."""
2020

2121
from functools import partial
22-
from textwrap import dedent
2322

2423
from astroid import (
2524
MANAGER,
@@ -153,6 +152,23 @@ def _extend_builtins(class_transforms):
153152

154153

155154
def _builtin_filter_predicate(node, builtin_name):
155+
if (
156+
builtin_name == "type"
157+
and node.root().name == "re"
158+
and isinstance(node.func, nodes.Name)
159+
and node.func.name == "type"
160+
and isinstance(node.parent, nodes.Assign)
161+
and len(node.parent.targets) == 1
162+
and isinstance(node.parent.targets[0], nodes.AssignName)
163+
and node.parent.targets[0].name in ("Pattern", "Match")
164+
):
165+
# Handle re.Pattern and re.Match in brain_re
166+
# Match these patterns from stdlib/re.py
167+
# ```py
168+
# Pattern = type(...)
169+
# Match = type(...)
170+
# ```
171+
return False
156172
if isinstance(node.func, nodes.Name) and node.func.name == builtin_name:
157173
return True
158174
if isinstance(node.func, nodes.Attribute):

astroid/brain/brain_re.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
# For details: https://github.com/PyCQA/astroid/blob/master/COPYING.LESSER
33
import sys
44
import astroid
5+
from astroid import MANAGER, inference_tip, nodes, context
56

67
PY36 = sys.version_info >= (3, 6)
8+
PY37 = sys.version_info[:2] >= (3, 7)
9+
PY39 = sys.version_info[:2] >= (3, 9)
710

811
if PY36:
912
# Since Python 3.6 there is the RegexFlag enum
@@ -34,3 +37,50 @@ def _re_transform():
3437
)
3538

3639
astroid.register_module_extender(astroid.MANAGER, "re", _re_transform)
40+
41+
42+
CLASS_GETITEM_TEMPLATE = """
43+
@classmethod
44+
def __class_getitem__(cls, item):
45+
return cls
46+
"""
47+
48+
49+
def _looks_like_pattern_or_match(node: nodes.Call) -> bool:
50+
"""Check for re.Pattern or re.Match call in stdlib.
51+
52+
Match these patterns from stdlib/re.py
53+
```py
54+
Pattern = type(...)
55+
Match = type(...)
56+
```
57+
"""
58+
return (
59+
node.root().name == "re"
60+
and isinstance(node.func, nodes.Name)
61+
and node.func.name == "type"
62+
and isinstance(node.parent, nodes.Assign)
63+
and len(node.parent.targets) == 1
64+
and isinstance(node.parent.targets[0], nodes.AssignName)
65+
and node.parent.targets[0].name in ("Pattern", "Match")
66+
)
67+
68+
69+
def infer_pattern_match(node: nodes.Call, ctx: context.InferenceContext = None):
70+
"""Infer re.Pattern and re.Match as classes. For PY39+ add `__class_getitem__`."""
71+
class_def = nodes.ClassDef(
72+
name=node.parent.targets[0].name,
73+
lineno=node.lineno,
74+
col_offset=node.col_offset,
75+
parent=node.parent,
76+
)
77+
if PY39:
78+
func_to_add = astroid.extract_node(CLASS_GETITEM_TEMPLATE)
79+
class_def.locals["__class_getitem__"] = [func_to_add]
80+
return iter([class_def])
81+
82+
83+
if PY37:
84+
MANAGER.register_transform(
85+
nodes.Call, inference_tip(infer_pattern_match), _looks_like_pattern_or_match
86+
)

astroid/brain/brain_typing.py

Lines changed: 91 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,51 @@ class {0}(metaclass=Meta):
4343
"""
4444
TYPING_MEMBERS = set(typing.__all__)
4545

46+
TYPING_ALIAS = frozenset(
47+
(
48+
"typing.Hashable",
49+
"typing.Awaitable",
50+
"typing.Coroutine",
51+
"typing.AsyncIterable",
52+
"typing.AsyncIterator",
53+
"typing.Iterable",
54+
"typing.Iterator",
55+
"typing.Reversible",
56+
"typing.Sized",
57+
"typing.Container",
58+
"typing.Collection",
59+
"typing.Callable",
60+
"typing.AbstractSet",
61+
"typing.MutableSet",
62+
"typing.Mapping",
63+
"typing.MutableMapping",
64+
"typing.Sequence",
65+
"typing.MutableSequence",
66+
"typing.ByteString",
67+
"typing.Tuple",
68+
"typing.List",
69+
"typing.Deque",
70+
"typing.Set",
71+
"typing.FrozenSet",
72+
"typing.MappingView",
73+
"typing.KeysView",
74+
"typing.ItemsView",
75+
"typing.ValuesView",
76+
"typing.ContextManager",
77+
"typing.AsyncContextManager",
78+
"typing.Dict",
79+
"typing.DefaultDict",
80+
"typing.OrderedDict",
81+
"typing.Counter",
82+
"typing.ChainMap",
83+
"typing.Generator",
84+
"typing.AsyncGenerator",
85+
"typing.Type",
86+
"typing.Pattern",
87+
"typing.Match",
88+
)
89+
)
90+
4691

4792
def looks_like_typing_typevar_or_newtype(node):
4893
func = node.func
@@ -88,7 +133,13 @@ def infer_typing_attr(node, context=None):
88133
except InferenceError as exc:
89134
raise UseInferenceDefault from exc
90135

91-
if not value.qname().startswith("typing."):
136+
if (
137+
not value.qname().startswith("typing.")
138+
or PY37
139+
and value.qname() in TYPING_ALIAS
140+
):
141+
# If typing subscript belongs to an alias
142+
# (PY37+) handle it separately.
92143
raise UseInferenceDefault
93144

94145
node = extract_node(TYPING_TYPE_TEMPLATE.format(value.qname().split(".")[-1]))
@@ -159,8 +210,6 @@ def full_raiser(origin_func, attr, *args, **kwargs):
159210
else:
160211
return origin_func(attr, *args, **kwargs)
161212

162-
if not isinstance(node, nodes.ClassDef):
163-
raise TypeError("The parameter type should be ClassDef")
164213
try:
165214
node.getattr("__class_getitem__")
166215
# If we are here, then we are sure to modify object that do have __class_getitem__ method (which origin is one the
@@ -174,55 +223,54 @@ def full_raiser(origin_func, attr, *args, **kwargs):
174223

175224
def infer_typing_alias(
176225
node: nodes.Call, ctx: context.InferenceContext = None
177-
) -> typing.Optional[node_classes.NodeNG]:
226+
) -> typing.Iterator[nodes.ClassDef]:
178227
"""
179228
Infers the call to _alias function
229+
Insert ClassDef, with same name as aliased class,
230+
in mro to simulate _GenericAlias.
180231
181232
:param node: call node
182233
:param context: inference context
183234
"""
235+
if (
236+
not isinstance(node.parent, nodes.Assign)
237+
or not len(node.parent.targets) == 1
238+
or not isinstance(node.parent.targets[0], nodes.AssignName)
239+
):
240+
return None
184241
res = next(node.args[0].infer(context=ctx))
242+
assign_name = node.parent.targets[0]
185243

244+
class_def = nodes.ClassDef(
245+
name=assign_name.name,
246+
lineno=assign_name.lineno,
247+
col_offset=assign_name.col_offset,
248+
parent=node.parent,
249+
)
186250
if res != astroid.Uninferable and isinstance(res, nodes.ClassDef):
187-
if not PY39:
188-
# Here the node is a typing object which is an alias toward
189-
# the corresponding object of collection.abc module.
190-
# Before python3.9 there is no subscript allowed for any of the collections.abc objects.
191-
# The subscript ability is given through the typing._GenericAlias class
192-
# which is the metaclass of the typing object but not the metaclass of the inferred
193-
# collections.abc object.
194-
# Thus we fake subscript ability of the collections.abc object
195-
# by mocking the existence of a __class_getitem__ method.
196-
# We can not add `__getitem__` method in the metaclass of the object because
197-
# the metaclass is shared by subscriptable and not subscriptable object
198-
maybe_type_var = node.args[1]
199-
if not (
200-
isinstance(maybe_type_var, node_classes.Tuple)
201-
and not maybe_type_var.elts
202-
):
203-
# The typing object is subscriptable if the second argument of the _alias function
204-
# is a TypeVar or a tuple of TypeVar. We could check the type of the second argument but
205-
# it appears that in the typing module the second argument is only TypeVar or a tuple of TypeVar or empty tuple.
206-
# This last value means the type is not Generic and thus cannot be subscriptable
207-
func_to_add = astroid.extract_node(CLASS_GETITEM_TEMPLATE)
208-
res.locals["__class_getitem__"] = [func_to_add]
209-
else:
210-
# If we are here, then we are sure to modify object that do have __class_getitem__ method (which origin is one the
211-
# protocol defined in collections module) whereas the typing module consider it should not
212-
# We do not want __class_getitem__ to be found in the classdef
213-
_forbid_class_getitem_access(res)
214-
else:
215-
# Within python3.9 discrepencies exist between some collections.abc containers that are subscriptable whereas
216-
# corresponding containers in the typing module are not! This is the case at least for ByteString.
217-
# It is far more to complex and dangerous to try to remove __class_getitem__ method from all the ancestors of the
218-
# current class. Instead we raise an AttributeInferenceError if we try to access it.
219-
maybe_type_var = node.args[1]
220-
if isinstance(maybe_type_var, nodes.Const) and maybe_type_var.value == 0:
221-
# Starting with Python39 the _alias function is in fact instantiation of _SpecialGenericAlias class.
222-
# Thus the type is not Generic if the second argument of the call is equal to zero
223-
_forbid_class_getitem_access(res)
224-
return iter([res])
225-
return iter([astroid.Uninferable])
251+
# Only add `res` as base if it's a `ClassDef`
252+
# This isn't the case for `typing.Pattern` and `typing.Match`
253+
class_def.postinit(bases=[res], body=[], decorators=None)
254+
255+
maybe_type_var = node.args[1]
256+
if (
257+
not PY39
258+
and not (
259+
isinstance(maybe_type_var, node_classes.Tuple) and not maybe_type_var.elts
260+
)
261+
or PY39
262+
and isinstance(maybe_type_var, nodes.Const)
263+
and maybe_type_var.value > 0
264+
):
265+
# If typing alias is subscriptable, add `__class_getitem__` to ClassDef
266+
func_to_add = astroid.extract_node(CLASS_GETITEM_TEMPLATE)
267+
class_def.locals["__class_getitem__"] = [func_to_add]
268+
else:
269+
# If not, make sure that `__class_getitem__` access is forbidden.
270+
# This is an issue in cases where the aliased class implements it,
271+
# but the typing alias isn't subscriptable. E.g., `typing.ByteString` for PY39+
272+
_forbid_class_getitem_access(class_def)
273+
return iter([class_def])
226274

227275

228276
MANAGER.register_transform(

0 commit comments

Comments
 (0)