@@ -21,7 +21,7 @@ class Context:
21
21
lines : set [int ]
22
22
23
23
24
- class RegionFinder ( ast . NodeVisitor ) :
24
+ class RegionFinder :
25
25
"""An ast visitor that will find and track regions of code.
26
26
27
27
Functions and classes are tracked by name. Results are in the .regions
@@ -34,13 +34,29 @@ def __init__(self) -> None:
34
34
35
35
def parse_source (self , source : str ) -> None :
36
36
"""Parse `source` and walk the ast to populate the .regions attribute."""
37
- self .visit (ast .parse (source ))
37
+ self .handle_node (ast .parse (source ))
38
38
39
39
def fq_node_name (self ) -> str :
40
40
"""Get the current fully qualified name we're processing."""
41
41
return "." .join (c .name for c in self .context )
42
42
43
- def visit_FunctionDef (self , node : ast .FunctionDef ) -> None :
43
+ def handle_node (self , node : ast .AST ) -> None :
44
+ """Recursively handle any node."""
45
+ if isinstance (node , (ast .FunctionDef , ast .AsyncFunctionDef )):
46
+ self .handle_FunctionDef (node )
47
+ elif isinstance (node , ast .ClassDef ):
48
+ self .handle_ClassDef (node )
49
+ elif isinstance (node , (ast .Module , ast .stmt )):
50
+ # Only modules and statements can contain function and class
51
+ # definitions. Handle them and ignore all others.
52
+ self .handle_node_body (node )
53
+
54
+ def handle_node_body (self , node : ast .AST ) -> None :
55
+ """Recursively handle the nodes in this node's body, if any."""
56
+ for body_node in getattr (node , "body" , ()):
57
+ self .handle_node (body_node )
58
+
59
+ def handle_FunctionDef (self , node : ast .FunctionDef | ast .AsyncFunctionDef ) -> None :
44
60
"""Called for `def` or `async def`."""
45
61
lines = set (range (node .body [0 ].lineno , cast (int , node .body [- 1 ].end_lineno ) + 1 ))
46
62
if self .context and self .context [- 1 ].kind == "class" :
@@ -60,12 +76,10 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
60
76
lines = lines ,
61
77
)
62
78
)
63
- self .generic_visit (node )
79
+ self .handle_node_body (node )
64
80
self .context .pop ()
65
81
66
- visit_AsyncFunctionDef = visit_FunctionDef # type: ignore[assignment]
67
-
68
- def visit_ClassDef (self , node : ast .ClassDef ) -> None :
82
+ def handle_ClassDef (self , node : ast .ClassDef ) -> None :
69
83
"""Called for `class`."""
70
84
# The lines for a class are the lines in the methods of the class.
71
85
# We start empty, and count on visit_FunctionDef to add the lines it
@@ -80,7 +94,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
80
94
lines = lines ,
81
95
)
82
96
)
83
- self .generic_visit (node )
97
+ self .handle_node_body (node )
84
98
self .context .pop ()
85
99
# Class bodies should be excluded from the enclosing classes.
86
100
for ancestor in reversed (self .context ):
0 commit comments