Skip to content

Commit 194e6d3

Browse files
committed
Add type annotations to pyreverse dot files
- Add a try/except to deal with a possible InferenceError when using NodeNG.infer method - Create a function in utils and so remove repeated logic in inspector.py - Add unittests to check the InferenceError logic - Adjust the types in function input
1 parent 1d3d13b commit 194e6d3

File tree

3 files changed

+64
-28
lines changed

3 files changed

+64
-28
lines changed

pylint/pyreverse/inspector.py

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -205,38 +205,30 @@ def visit_assignname(self, node):
205205
# the name has been defined as 'global' in the frame and belongs
206206
# there.
207207
frame = node.root()
208-
try:
209-
if not hasattr(frame, "locals_type"):
210-
# If the frame doesn't have a locals_type yet,
211-
# it means it wasn't yet visited. Visit it now
212-
# to add what's missing from it.
213-
if isinstance(frame, astroid.ClassDef):
214-
self.visit_classdef(frame)
215-
elif isinstance(frame, astroid.FunctionDef):
216-
self.visit_functiondef(frame)
217-
else:
218-
self.visit_module(frame)
219-
220-
current = frame.locals_type[node.name]
221-
ann = utils.get_annotation(node)
222-
values = {ann} if ann else set(node.infer())
223-
frame.locals_type[node.name] = list(set(current) | values)
224-
except astroid.InferenceError:
225-
pass
208+
if not hasattr(frame, "locals_type"):
209+
# If the frame doesn't have a locals_type yet,
210+
# it means it wasn't yet visited. Visit it now
211+
# to add what's missing from it.
212+
if isinstance(frame, astroid.ClassDef):
213+
self.visit_classdef(frame)
214+
elif isinstance(frame, astroid.FunctionDef):
215+
self.visit_functiondef(frame)
216+
else:
217+
self.visit_module(frame)
218+
219+
current = frame.locals_type[node.name]
220+
frame.locals_type[node.name] = list(set(current) | utils.infer_node(node))
226221

227222
@staticmethod
228223
def handle_assignattr_type(node, parent):
229224
"""handle an astroid.assignattr node
230225
231226
handle instance_attrs_type
232227
"""
233-
ann = utils.get_annotation(node)
234-
try:
235-
values = {ann} if ann else set(node.infer())
236-
current = set(parent.instance_attrs_type[node.attrname])
237-
parent.instance_attrs_type[node.attrname] = list(current | values)
238-
except astroid.InferenceError:
239-
pass
228+
current = set(parent.instance_attrs_type[node.attrname])
229+
parent.instance_attrs_type[node.attrname] = list(
230+
current | utils.infer_node(node)
231+
)
240232

241233
def visit_import(self, node):
242234
"""visit an astroid.Import node

pylint/pyreverse/utils.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def get_annotation_label(ann: Union[astroid.Name, astroid.Subscript]) -> str:
228228

229229

230230
def get_annotation(
231-
node: Union[astroid.AnnAssign, astroid.AssignAttr]
231+
node: Union[astroid.AssignAttr, astroid.AssignName]
232232
) -> Optional[Union[astroid.Name, astroid.Subscript]]:
233233
"""return the annotation for `node`"""
234234
ann = None
@@ -244,7 +244,11 @@ def get_annotation(
244244
else:
245245
return ann
246246

247-
default, *_ = node.infer()
247+
try:
248+
default, *_ = node.infer()
249+
except astroid.InferenceError:
250+
default = ""
251+
248252
label = get_annotation_label(ann)
249253
if ann:
250254
label = (
@@ -256,3 +260,16 @@ def get_annotation(
256260
if label:
257261
ann.name = label
258262
return ann
263+
264+
265+
def infer_node(node: Union[astroid.AssignAttr, astroid.AssignName]) -> set:
266+
"""Return a set containing the node annotation if it exists
267+
otherwise return a set of the inferred types using the NodeNG.infer method"""
268+
269+
ann = get_annotation(node)
270+
if ann:
271+
return {ann}
272+
try:
273+
return set(node.infer())
274+
except astroid.InferenceError:
275+
return set()

tests/unittest_pyreverse_writer.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,14 @@
2222
import codecs
2323
import os
2424
from difflib import unified_diff
25+
from unittest.mock import patch
2526

2627
import astroid
2728
import pytest
2829

2930
from pylint.pyreverse.diadefslib import DefaultDiadefGenerator, DiadefsHandler
3031
from pylint.pyreverse.inspector import Linker, project_from_files
31-
from pylint.pyreverse.utils import get_annotation, get_visibility
32+
from pylint.pyreverse.utils import get_annotation, get_visibility, infer_node
3233
from pylint.pyreverse.writer import DotWriter
3334

3435
_DEFAULTS = {
@@ -176,3 +177,29 @@ class A:
176177
got = get_annotation(assign_attr).name
177178
assert isinstance(assign_attr, astroid.AssignAttr)
178179
assert got == label, f"got {got} instead of {label} for value {node}"
180+
181+
182+
@patch("pylint.pyreverse.utils.get_annotation")
183+
@patch("astroid.node_classes.NodeNG.infer", side_effect=astroid.InferenceError)
184+
def test_infer_node_1(mock_infer, mock_get_annotation):
185+
"""Return set() when astroid.InferenceError is raised and an annotation has
186+
not been returned
187+
"""
188+
mock_get_annotation.return_value = None
189+
node = astroid.extract_node("a: str = 'mystr'")
190+
mock_infer.return_value = "x"
191+
assert infer_node(node) == set()
192+
assert mock_infer.called
193+
194+
195+
@patch("pylint.pyreverse.utils.get_annotation")
196+
@patch("astroid.node_classes.NodeNG.infer")
197+
def test_infer_node_2(mock_infer, mock_get_annotation):
198+
"""Return set(node.infer()) when InferenceError is not raised and an
199+
annotation has not been returned
200+
"""
201+
mock_get_annotation.return_value = None
202+
node = astroid.extract_node("a: str = 'mystr'")
203+
mock_infer.return_value = "x"
204+
assert infer_node(node) == set("x")
205+
assert mock_infer.called

0 commit comments

Comments
 (0)