Skip to content

Commit b71be8a

Browse files
Handle has-a relationships for type-hinted arguments in class diagrams (#4745)
* Pyreverse - Show class has-a relationships inferred from type-hints Closes #4744 Co-authored-by: Pierre Sassoulas <[email protected]>
1 parent 5e5f48d commit b71be8a

9 files changed

+45
-12
lines changed

ChangeLog

+4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ Release date: TBA
1212
..
1313
Put bug fixes that should not wait for a new minor version here
1414

15+
* pyreverse: Show class has-a relationships inferred from the type-hint
16+
17+
Closes #4744
18+
1519
* Added ``ignored-parents`` option to the design checker to ignore specific
1620
classes from the ``too-many-ancestors`` check (R0901).
1721

doc/whatsnew/2.10.rst

+2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ Extensions
4040
Other Changes
4141
=============
4242

43+
* Pyreverse - Show class has-a relationships inferred from type-hints
44+
4345
* Performance of the Similarity checker has been improved.
4446

4547
* Added ``time.clock`` to deprecated functions/methods for python 3.3

pylint/pyreverse/utils.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -269,9 +269,9 @@ def infer_node(node: Union[astroid.AssignAttr, astroid.AssignName]) -> set:
269269
otherwise return a set of the inferred types using the NodeNG.infer method"""
270270

271271
ann = get_annotation(node)
272-
if ann:
273-
return {ann}
274272
try:
273+
if ann:
274+
return set(ann.infer())
275275
return set(node.infer())
276276
except astroid.InferenceError:
277-
return set()
277+
return {ann} if ann else set()

tests/data/classes_No_Name.dot

+7-5
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ charset="utf-8"
33
rankdir=BT
44
"0" [label="{Ancestor|attr : str\lcls_member\l|get_value()\lset_value(value)\l}", shape="record"];
55
"1" [label="{DoNothing|\l|}", shape="record"];
6-
"2" [label="{Interface|\l|get_value()\lset_value(value)\l}", shape="record"];
7-
"3" [label="{Specialization|TYPE : str\lrelation\ltop : str\l|}", shape="record"];
8-
"3" -> "0" [arrowhead="empty", arrowtail="none"];
9-
"0" -> "2" [arrowhead="empty", arrowtail="node", style="dashed"];
6+
"2" [label="{DoNothing2|\l|}", shape="record"];
7+
"3" [label="{Interface|\l|get_value()\lset_value(value)\l}", shape="record"];
8+
"4" [label="{Specialization|TYPE : str\lrelation\lrelation2\ltop : str\l|}", shape="record"];
9+
"4" -> "0" [arrowhead="empty", arrowtail="none"];
10+
"0" -> "3" [arrowhead="empty", arrowtail="node", style="dashed"];
1011
"1" -> "0" [arrowhead="diamond", arrowtail="none", fontcolor="green", label="cls_member", style="solid"];
11-
"1" -> "3" [arrowhead="diamond", arrowtail="none", fontcolor="green", label="relation", style="solid"];
12+
"1" -> "4" [arrowhead="diamond", arrowtail="none", fontcolor="green", label="relation", style="solid"];
13+
"2" -> "4" [arrowhead="diamond", arrowtail="none", fontcolor="green", label="relation2", style="solid"];
1214
}

tests/data/clientmodule_test.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
""" docstring for file clientmodule.py """
2-
from data.suppliermodule_test import Interface, DoNothing
2+
from data.suppliermodule_test import Interface, DoNothing, DoNothing2
33

44
class Ancestor:
55
""" Ancestor method """
@@ -23,7 +23,8 @@ class Specialization(Ancestor):
2323
TYPE = 'final class'
2424
top = 'class'
2525

26-
def __init__(self, value, _id):
26+
def __init__(self, value, _id, relation2: DoNothing2):
2727
Ancestor.__init__(self, value)
2828
self._id = _id
2929
self.relation = DoNothing()
30+
self.relation2 = relation2

tests/data/suppliermodule_test.py

+2
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@ def set_value(self, value):
88
raise NotImplementedError
99

1010
class DoNothing: pass
11+
12+
class DoNothing2: pass

tests/unittest_pyreverse_diadefs.py

+4
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ class TestDefaultDiadefGenerator:
9999
_should_rels = [
100100
("association", "DoNothing", "Ancestor"),
101101
("association", "DoNothing", "Specialization"),
102+
("association", "DoNothing2", "Specialization"),
102103
("implements", "Ancestor", "Interface"),
103104
("specialization", "Specialization", "Ancestor"),
104105
]
@@ -142,6 +143,7 @@ def test_known_values1(HANDLER, PROJECT):
142143
assert classes == [
143144
(True, "Ancestor"),
144145
(True, "DoNothing"),
146+
(True, "DoNothing2"),
145147
(True, "Interface"),
146148
(True, "Specialization"),
147149
]
@@ -170,6 +172,7 @@ def test_known_values3(HANDLER, PROJECT):
170172
(True, "data.clientmodule_test.Ancestor"),
171173
(True, special),
172174
(True, "data.suppliermodule_test.DoNothing"),
175+
(True, "data.suppliermodule_test.DoNothing2"),
173176
]
174177

175178

@@ -184,6 +187,7 @@ def test_known_values4(HANDLER, PROJECT):
184187
assert classes == [
185188
(True, "Ancestor"),
186189
(True, "DoNothing"),
190+
(True, "DoNothing2"),
187191
(True, "Specialization"),
188192
]
189193

tests/unittest_pyreverse_inspector.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,9 @@ def test_instance_attrs_resolution(project):
6262
klass = project.get_module("data.clientmodule_test")["Specialization"]
6363
assert hasattr(klass, "instance_attrs_type")
6464
type_dict = klass.instance_attrs_type
65-
assert len(type_dict) == 2
65+
assert len(type_dict) == 3
6666
keys = sorted(type_dict.keys())
67-
assert keys == ["_id", "relation"]
67+
assert keys == ["_id", "relation", "relation2"]
6868
assert isinstance(type_dict["relation"][0], astroid.bases.Instance), type_dict[
6969
"relation"
7070
]

tests/unittest_pyreverse_writer.py

+18
Original file line numberDiff line numberDiff line change
@@ -204,3 +204,21 @@ def test_infer_node_2(mock_infer, mock_get_annotation):
204204
mock_infer.return_value = "x"
205205
assert infer_node(node) == set("x")
206206
assert mock_infer.called
207+
208+
209+
def test_infer_node_3():
210+
"""Return a set containing an astroid.ClassDef object when the attribute
211+
has a type annotation"""
212+
node = astroid.extract_node(
213+
"""
214+
class Component:
215+
pass
216+
217+
class Composite:
218+
def __init__(self, component: Component):
219+
self.component = component
220+
"""
221+
)
222+
instance_attr = node.instance_attrs.get("component")[0]
223+
assert isinstance(infer_node(instance_attr), set)
224+
assert isinstance(infer_node(instance_attr).pop(), astroid.ClassDef)

0 commit comments

Comments
 (0)