Skip to content

Commit 1e55ae6

Browse files
authored
Add type annotations to pyreverse (#4551)
* Add type annotations to pyreverse dot files Closes #1548 - Indicate the attribute is optional in the dot files by inspecting the default value. - Handle the Subscript type annotations - Refactor & move the logic to obtain annotations to utils.py as re-usable functions - Add unittests for the added functions - Add a try/except to deal with a possible InferenceError when using NodeNG.infer method - Create a function in utils and so remove repeated logic in inspector.py - Add unittests to check the InferenceError logic - Adjust the types in function input
1 parent 14731ad commit 1e55ae6

File tree

7 files changed

+177
-28
lines changed

7 files changed

+177
-28
lines changed

ChangeLog

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ Release date: TBA
99
..
1010
Put new features and bugfixes here and also in 'doc/whatsnew/2.9.rst'
1111

12+
* Add type annotations to pyreverse dot files
13+
14+
Closes #1548
15+
1216
* astroid has been upgraded to 2.6.0
1317

1418
* Fix false positive ``useless-type-doc`` on ignored argument using ``pylint.extensions.docparams``

doc/whatsnew/2.9.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ New checkers
4949
Other Changes
5050
=============
5151

52+
* Add type annotations to pyreverse dot files
53+
5254
* Pylint's tags are now the standard form ``vX.Y.Z`` and not ``pylint-X.Y.Z`` anymore.
5355

5456
* Fix false-positive ``too-many-ancestors`` when inheriting from builtin classes,

pylint/pyreverse/diagrams.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def class_names(self, nodes):
122122
if isinstance(node, astroid.Instance):
123123
node = node._proxied
124124
if (
125-
isinstance(node, astroid.ClassDef)
125+
isinstance(node, (astroid.ClassDef, astroid.Name, astroid.Subscript))
126126
and hasattr(node, "name")
127127
and not self.has_node(node)
128128
):

pylint/pyreverse/inspector.py

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -205,36 +205,30 @@ def visit_assignname(self, node):
205205
# the name has been defined as 'global' in the frame and belongs
206206
# there.
207207
frame = node.root()
208-
try:
209-
if not hasattr(frame, "locals_type"):
210-
# If the frame doesn't have a locals_type yet,
211-
# it means it wasn't yet visited. Visit it now
212-
# to add what's missing from it.
213-
if isinstance(frame, astroid.ClassDef):
214-
self.visit_classdef(frame)
215-
elif isinstance(frame, astroid.FunctionDef):
216-
self.visit_functiondef(frame)
217-
else:
218-
self.visit_module(frame)
219-
220-
current = frame.locals_type[node.name]
221-
values = set(node.infer())
222-
frame.locals_type[node.name] = list(set(current) | values)
223-
except astroid.InferenceError:
224-
pass
208+
if not hasattr(frame, "locals_type"):
209+
# If the frame doesn't have a locals_type yet,
210+
# it means it wasn't yet visited. Visit it now
211+
# to add what's missing from it.
212+
if isinstance(frame, astroid.ClassDef):
213+
self.visit_classdef(frame)
214+
elif isinstance(frame, astroid.FunctionDef):
215+
self.visit_functiondef(frame)
216+
else:
217+
self.visit_module(frame)
218+
219+
current = frame.locals_type[node.name]
220+
frame.locals_type[node.name] = list(set(current) | utils.infer_node(node))
225221

226222
@staticmethod
227223
def handle_assignattr_type(node, parent):
228224
"""handle an astroid.assignattr node
229225
230226
handle instance_attrs_type
231227
"""
232-
try:
233-
values = set(node.infer())
234-
current = set(parent.instance_attrs_type[node.attrname])
235-
parent.instance_attrs_type[node.attrname] = list(current | values)
236-
except astroid.InferenceError:
237-
pass
228+
current = set(parent.instance_attrs_type[node.attrname])
229+
parent.instance_attrs_type[node.attrname] = list(
230+
current | utils.infer_node(node)
231+
)
238232

239233
def visit_import(self, node):
240234
"""visit an astroid.Import node

pylint/pyreverse/utils.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
import os
2020
import re
2121
import sys
22+
from typing import Optional, Union
23+
24+
import astroid
2225

2326
RCFILE = ".pyreverserc"
2427

@@ -213,3 +216,60 @@ def visit(self, node):
213216
if methods[1] is not None:
214217
return methods[1](node)
215218
return None
219+
220+
221+
def get_annotation_label(ann: Union[astroid.Name, astroid.Subscript]) -> str:
222+
label = ""
223+
if isinstance(ann, astroid.Subscript):
224+
label = ann.as_string()
225+
elif isinstance(ann, astroid.Name):
226+
label = ann.name
227+
return label
228+
229+
230+
def get_annotation(
231+
node: Union[astroid.AssignAttr, astroid.AssignName]
232+
) -> Optional[Union[astroid.Name, astroid.Subscript]]:
233+
"""return the annotation for `node`"""
234+
ann = None
235+
if isinstance(node.parent, astroid.AnnAssign):
236+
ann = node.parent.annotation
237+
elif isinstance(node, astroid.AssignAttr):
238+
init_method = node.parent.parent
239+
try:
240+
annotations = dict(zip(init_method.locals, init_method.args.annotations))
241+
ann = annotations.get(node.parent.value.name)
242+
except AttributeError:
243+
pass
244+
else:
245+
return ann
246+
247+
try:
248+
default, *_ = node.infer()
249+
except astroid.InferenceError:
250+
default = ""
251+
252+
label = get_annotation_label(ann)
253+
if ann:
254+
label = (
255+
rf"Optional[{label}]"
256+
if getattr(default, "value", "value") is None
257+
and not label.startswith("Optional")
258+
else label
259+
)
260+
if label:
261+
ann.name = label
262+
return ann
263+
264+
265+
def infer_node(node: Union[astroid.AssignAttr, astroid.AssignName]) -> set:
266+
"""Return a set containing the node annotation if it exists
267+
otherwise return a set of the inferred types using the NodeNG.infer method"""
268+
269+
ann = get_annotation(node)
270+
if ann:
271+
return {ann}
272+
try:
273+
return set(node.infer())
274+
except astroid.InferenceError:
275+
return set()

pylint/pyreverse/writer.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import os
2020

2121
from pylint.graph import DotBackend
22-
from pylint.pyreverse.utils import is_exception
22+
from pylint.pyreverse.utils import get_annotation_label, is_exception
2323
from pylint.pyreverse.vcgutils import VCGPrinter
2424

2525

@@ -134,11 +134,29 @@ def get_values(self, obj):
134134
if not self.config.only_classnames:
135135
label = r"{}|{}\l|".format(label, r"\l".join(obj.attrs))
136136
for func in obj.methods:
137+
return_type = (
138+
f": {get_annotation_label(func.returns)}" if func.returns else ""
139+
)
140+
137141
if func.args.args:
138-
args = [arg.name for arg in func.args.args if arg.name != "self"]
142+
args = [arg for arg in func.args.args if arg.name != "self"]
139143
else:
140144
args = []
141-
label = r"{}{}({})\l".format(label, func.name, ", ".join(args))
145+
146+
annotations = dict(zip(args, func.args.annotations[1:]))
147+
for arg in args:
148+
annotation_label = ""
149+
ann = annotations.get(arg)
150+
if ann:
151+
annotation_label = get_annotation_label(ann)
152+
annotations[arg] = annotation_label
153+
154+
args = ", ".join(
155+
f"{arg.name}: {ann}" if ann else f"{arg.name}"
156+
for arg, ann in annotations.items()
157+
)
158+
159+
label = fr"{label}{func.name}({args}){return_type}\l"
142160
label = "{%s}" % label
143161
if is_exception(obj.node):
144162
return dict(fontcolor="red", label=label, shape="record")

tests/unittest_pyreverse_writer.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,14 @@
2222
import codecs
2323
import os
2424
from difflib import unified_diff
25+
from unittest.mock import patch
2526

27+
import astroid
2628
import pytest
2729

2830
from pylint.pyreverse.diadefslib import DefaultDiadefGenerator, DiadefsHandler
2931
from pylint.pyreverse.inspector import Linker, project_from_files
30-
from pylint.pyreverse.utils import get_visibility
32+
from pylint.pyreverse.utils import get_annotation, get_visibility, infer_node
3133
from pylint.pyreverse.writer import DotWriter
3234

3335
_DEFAULTS = {
@@ -132,3 +134,72 @@ def test_get_visibility(names, expected):
132134
for name in names:
133135
got = get_visibility(name)
134136
assert got == expected, f"got {got} instead of {expected} for value {name}"
137+
138+
139+
@pytest.mark.parametrize(
140+
"assign, label",
141+
[
142+
("a: str = None", "Optional[str]"),
143+
("a: str = 'mystr'", "str"),
144+
("a: Optional[str] = 'str'", "Optional[str]"),
145+
("a: Optional[str] = None", "Optional[str]"),
146+
],
147+
)
148+
def test_get_annotation_annassign(assign, label):
149+
"""AnnAssign"""
150+
node = astroid.extract_node(assign)
151+
got = get_annotation(node.value).name
152+
assert isinstance(node, astroid.AnnAssign)
153+
assert got == label, f"got {got} instead of {label} for value {node}"
154+
155+
156+
@pytest.mark.parametrize(
157+
"init_method, label",
158+
[
159+
("def __init__(self, x: str): self.x = x", "str"),
160+
("def __init__(self, x: str = 'str'): self.x = x", "str"),
161+
("def __init__(self, x: str = None): self.x = x", "Optional[str]"),
162+
("def __init__(self, x: Optional[str]): self.x = x", "Optional[str]"),
163+
("def __init__(self, x: Optional[str] = None): self.x = x", "Optional[str]"),
164+
("def __init__(self, x: Optional[str] = 'str'): self.x = x", "Optional[str]"),
165+
],
166+
)
167+
def test_get_annotation_assignattr(init_method, label):
168+
"""AssignAttr"""
169+
assign = rf"""
170+
class A:
171+
{init_method}
172+
"""
173+
node = astroid.extract_node(assign)
174+
instance_attrs = node.instance_attrs
175+
for _, assign_attrs in instance_attrs.items():
176+
for assign_attr in assign_attrs:
177+
got = get_annotation(assign_attr).name
178+
assert isinstance(assign_attr, astroid.AssignAttr)
179+
assert got == label, f"got {got} instead of {label} for value {node}"
180+
181+
182+
@patch("pylint.pyreverse.utils.get_annotation")
183+
@patch("astroid.node_classes.NodeNG.infer", side_effect=astroid.InferenceError)
184+
def test_infer_node_1(mock_infer, mock_get_annotation):
185+
"""Return set() when astroid.InferenceError is raised and an annotation has
186+
not been returned
187+
"""
188+
mock_get_annotation.return_value = None
189+
node = astroid.extract_node("a: str = 'mystr'")
190+
mock_infer.return_value = "x"
191+
assert infer_node(node) == set()
192+
assert mock_infer.called
193+
194+
195+
@patch("pylint.pyreverse.utils.get_annotation")
196+
@patch("astroid.node_classes.NodeNG.infer")
197+
def test_infer_node_2(mock_infer, mock_get_annotation):
198+
"""Return set(node.infer()) when InferenceError is not raised and an
199+
annotation has not been returned
200+
"""
201+
mock_get_annotation.return_value = None
202+
node = astroid.extract_node("a: str = 'mystr'")
203+
mock_infer.return_value = "x"
204+
assert infer_node(node) == set("x")
205+
assert mock_infer.called

0 commit comments

Comments
 (0)