22
22
YieldExpr , ExecStmt , Argument , BackquoteExpr , AwaitExpr ,
23
23
)
24
24
from mypy .types import Type , FunctionLike , Instance
25
+ from mypy .traverser import TraverserVisitor
25
26
from mypy .visitor import NodeVisitor
26
27
27
28
@@ -36,7 +37,7 @@ class TransformVisitor(NodeVisitor[Node]):
36
37
37
38
* Do not duplicate TypeInfo nodes. This would generally not be desirable.
38
39
* Only update some name binding cross-references, but only those that
39
- refer to Var nodes, not those targeting ClassDef, TypeInfo or FuncDef
40
+ refer to Var or FuncDef nodes, not those targeting ClassDef or TypeInfo
40
41
nodes.
41
42
* Types are not transformed, but you can override type() to also perform
42
43
type transformation.
@@ -48,6 +49,7 @@ def __init__(self) -> None:
48
49
# There may be multiple references to a Var node. Keep track of
49
50
# Var translations using a dictionary.
50
51
self .var_map = {} # type: Dict[Var, Var]
52
+ self .func_map = {} # type: Dict[FuncDef, FuncDef]
51
53
52
54
def visit_mypy_file (self , node : MypyFile ) -> Node :
53
55
# NOTE: The 'names' and 'imports' instance variables will be empty!
@@ -98,6 +100,18 @@ def copy_argument(self, argument: Argument) -> Argument:
98
100
99
101
def visit_func_def (self , node : FuncDef ) -> FuncDef :
100
102
# Note that a FuncDef must be transformed to a FuncDef.
103
+
104
+ # These contortions are needed to handle the case of recursive
105
+ # references inside the function being transformed.
106
+ # Set up empty nodes for references within this function
107
+ # to other functions defined inside it.
108
+ # Don't create an entry for this function itself though,
109
+ # since we want self-references to point to the original
110
+ # function if this is the top-level node we are transforming.
111
+ init = FuncMapInitializer (self )
112
+ for stmt in node .body .body :
113
+ stmt .accept (init )
114
+
101
115
new = FuncDef (node .name (),
102
116
[self .copy_argument (arg ) for arg in node .arguments ],
103
117
self .block (node .body ),
@@ -113,7 +127,13 @@ def visit_func_def(self, node: FuncDef) -> FuncDef:
113
127
new .is_class = node .is_class
114
128
new .is_property = node .is_property
115
129
new .original_def = node .original_def
116
- return new
130
+
131
+ if node in self .func_map :
132
+ result = self .func_map [node ]
133
+ result .__dict__ = new .__dict__
134
+ return result
135
+ else :
136
+ return new
117
137
118
138
def visit_func_expr (self , node : FuncExpr ) -> Node :
119
139
new = FuncExpr ([self .copy_argument (arg ) for arg in node .arguments ],
@@ -330,6 +350,9 @@ def copy_ref(self, new: RefExpr, original: RefExpr) -> None:
330
350
target = original .node
331
351
if isinstance (target , Var ):
332
352
target = self .visit_var (target )
353
+ elif isinstance (target , FuncDef ):
354
+ if target in self .func_map :
355
+ target = self .func_map [target ]
333
356
new .node = target
334
357
new .is_def = original .is_def
335
358
@@ -527,3 +550,14 @@ def types(self, types: List[Type]) -> List[Type]:
527
550
528
551
def optional_types (self , types : List [Type ]) -> List [Type ]:
529
552
return [self .optional_type (type ) for type in types ]
553
+
554
+
555
+ class FuncMapInitializer (TraverserVisitor ):
556
+ def __init__ (self , transformer : TransformVisitor ) -> None :
557
+ self .transformer = transformer
558
+
559
+ def visit_func_def (self , node : FuncDef ) -> None :
560
+ if node not in self .transformer .func_map :
561
+ self .transformer .func_map [node ] = FuncDef (
562
+ node .name (), node .arguments , node .body , None )
563
+ super ().visit_func_def (node )
0 commit comments