Skip to content

Add type annotations to pyreverse #4551

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ Release date: TBA
..
Put new features and bugfixes here and also in 'doc/whatsnew/2.9.rst'

* Add type annotations to pyreverse dot files

Closes #1548

* astroid has been upgraded to 2.6.0

* ``setuptools_scm`` has been removed and replaced by ``tbump`` in order to not
Expand Down
2 changes: 2 additions & 0 deletions doc/whatsnew/2.9.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ New checkers
Other Changes
=============

* Add type annotations to pyreverse dot files

* Pylint's tags are now the standard form ``vX.Y.Z`` and not ``pylint-X.Y.Z`` anymore.

* Fix false-positive ``too-many-ancestors`` when inheriting from builtin classes,
Expand Down
2 changes: 1 addition & 1 deletion pylint/pyreverse/diagrams.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def class_names(self, nodes):
if isinstance(node, astroid.Instance):
node = node._proxied
if (
isinstance(node, astroid.ClassDef)
isinstance(node, (astroid.ClassDef, astroid.Name, astroid.Subscript))
and hasattr(node, "name")
and not self.has_node(node)
):
Expand Down
6 changes: 4 additions & 2 deletions pylint/pyreverse/inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,8 @@ def visit_assignname(self, node):
self.visit_module(frame)

current = frame.locals_type[node.name]
values = set(node.infer())
ann = utils.get_annotation(node)
values = {ann} if ann else set(node.infer())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here there can be an inference fail, we should catch it (like line 238). Btw it look like the code is similar, maybe we can create a function for it ?

frame.locals_type[node.name] = list(set(current) | values)
except astroid.InferenceError:
pass
Expand All @@ -229,8 +230,9 @@ def handle_assignattr_type(node, parent):

handle instance_attrs_type
"""
ann = utils.get_annotation(node)
try:
values = set(node.infer())
values = {ann} if ann else set(node.infer())
current = set(parent.instance_attrs_type[node.attrname])
parent.instance_attrs_type[node.attrname] = list(current | values)
except astroid.InferenceError:
Expand Down
43 changes: 43 additions & 0 deletions pylint/pyreverse/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
import os
import re
import sys
from typing import Optional, Union

import astroid

RCFILE = ".pyreverserc"

Expand Down Expand Up @@ -213,3 +216,43 @@ def visit(self, node):
if methods[1] is not None:
return methods[1](node)
return None


def get_annotation_label(ann: Union[astroid.Name, astroid.Subscript]) -> str:
label = ""
if isinstance(ann, astroid.Subscript):
label = ann.as_string()
elif isinstance(ann, astroid.Name):
label = ann.name
return label


def get_annotation(
node: Union[astroid.AnnAssign, astroid.AssignAttr]
) -> Optional[Union[astroid.Name, astroid.Subscript]]:
"""return the annotation for `node`"""
ann = None
if isinstance(node.parent, astroid.AnnAssign):
ann = node.parent.annotation
elif isinstance(node, astroid.AssignAttr):
init_method = node.parent.parent
try:
annotations = dict(zip(init_method.locals, init_method.args.annotations))
ann = annotations.get(node.parent.value.name)
except AttributeError:
pass
else:
return ann

default, *_ = node.infer()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There can be an InferenceError raised here.

label = get_annotation_label(ann)
if ann:
label = (
rf"Optional[{label}]"
if getattr(default, "value", "value") is None
and not label.startswith("Optional")
else label
)
if label:
ann.name = label
return ann
24 changes: 21 additions & 3 deletions pylint/pyreverse/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import os

from pylint.graph import DotBackend
from pylint.pyreverse.utils import is_exception
from pylint.pyreverse.utils import get_annotation_label, is_exception
from pylint.pyreverse.vcgutils import VCGPrinter


Expand Down Expand Up @@ -134,11 +134,29 @@ def get_values(self, obj):
if not self.config.only_classnames:
label = r"{}|{}\l|".format(label, r"\l".join(obj.attrs))
for func in obj.methods:
return_type = (
f": {get_annotation_label(func.returns)}" if func.returns else ""
)

if func.args.args:
args = [arg.name for arg in func.args.args if arg.name != "self"]
args = [arg for arg in func.args.args if arg.name != "self"]
else:
args = []
label = r"{}{}({})\l".format(label, func.name, ", ".join(args))

annotations = dict(zip(args, func.args.annotations[1:]))
for arg in args:
annotation_label = ""
ann = annotations.get(arg)
if ann:
annotation_label = get_annotation_label(ann)
annotations[arg] = annotation_label

args = ", ".join(
f"{arg.name}: {ann}" if ann else f"{arg.name}"
for arg, ann in annotations.items()
)

label = fr"{label}{func.name}({args}){return_type}\l"
label = "{%s}" % label
if is_exception(obj.node):
return dict(fontcolor="red", label=label, shape="record")
Expand Down
46 changes: 45 additions & 1 deletion tests/unittest_pyreverse_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@
import os
from difflib import unified_diff

import astroid
import pytest

from pylint.pyreverse.diadefslib import DefaultDiadefGenerator, DiadefsHandler
from pylint.pyreverse.inspector import Linker, project_from_files
from pylint.pyreverse.utils import get_visibility
from pylint.pyreverse.utils import get_annotation, get_visibility
from pylint.pyreverse.writer import DotWriter

_DEFAULTS = {
Expand Down Expand Up @@ -132,3 +133,46 @@ def test_get_visibility(names, expected):
for name in names:
got = get_visibility(name)
assert got == expected, f"got {got} instead of {expected} for value {name}"


@pytest.mark.parametrize(
"assign, label",
[
("a: str = None", "Optional[str]"),
("a: str = 'mystr'", "str"),
("a: Optional[str] = 'str'", "Optional[str]"),
("a: Optional[str] = None", "Optional[str]"),
],
)
def test_get_annotation_annassign(assign, label):
"""AnnAssign"""
node = astroid.extract_node(assign)
got = get_annotation(node.value).name
assert isinstance(node, astroid.AnnAssign)
assert got == label, f"got {got} instead of {label} for value {node}"


@pytest.mark.parametrize(
"init_method, label",
[
("def __init__(self, x: str): self.x = x", "str"),
("def __init__(self, x: str = 'str'): self.x = x", "str"),
("def __init__(self, x: str = None): self.x = x", "Optional[str]"),
("def __init__(self, x: Optional[str]): self.x = x", "Optional[str]"),
("def __init__(self, x: Optional[str] = None): self.x = x", "Optional[str]"),
("def __init__(self, x: Optional[str] = 'str'): self.x = x", "Optional[str]"),
],
)
def test_get_annotation_assignattr(init_method, label):
"""AssignAttr"""
assign = rf"""
class A:
{init_method}
"""
node = astroid.extract_node(assign)
instance_attrs = node.instance_attrs
for _, assign_attrs in instance_attrs.items():
for assign_attr in assign_attrs:
got = get_annotation(assign_attr).name
assert isinstance(assign_attr, astroid.AssignAttr)
assert got == label, f"got {got} instead of {label} for value {node}"