Skip to content

Commit eb8b7b9

Browse files
rwbartonJukkaL
authored andcommitted
Rewrite references to inner functions in treetransform
Fixes #1323.
1 parent 7ab19c1 commit eb8b7b9

File tree

2 files changed

+66
-2
lines changed

2 files changed

+66
-2
lines changed

mypy/treetransform.py

+36-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
YieldExpr, ExecStmt, Argument, BackquoteExpr, AwaitExpr,
2323
)
2424
from mypy.types import Type, FunctionLike, Instance
25+
from mypy.traverser import TraverserVisitor
2526
from mypy.visitor import NodeVisitor
2627

2728

@@ -36,7 +37,7 @@ class TransformVisitor(NodeVisitor[Node]):
3637
3738
* Do not duplicate TypeInfo nodes. This would generally not be desirable.
3839
* Only update some name binding cross-references, but only those that
39-
refer to Var nodes, not those targeting ClassDef, TypeInfo or FuncDef
40+
refer to Var or FuncDef nodes, not those targeting ClassDef or TypeInfo
4041
nodes.
4142
* Types are not transformed, but you can override type() to also perform
4243
type transformation.
@@ -48,6 +49,7 @@ def __init__(self) -> None:
4849
# There may be multiple references to a Var node. Keep track of
4950
# Var translations using a dictionary.
5051
self.var_map = {} # type: Dict[Var, Var]
52+
self.func_map = {} # type: Dict[FuncDef, FuncDef]
5153

5254
def visit_mypy_file(self, node: MypyFile) -> Node:
5355
# NOTE: The 'names' and 'imports' instance variables will be empty!
@@ -98,6 +100,18 @@ def copy_argument(self, argument: Argument) -> Argument:
98100

99101
def visit_func_def(self, node: FuncDef) -> FuncDef:
100102
# Note that a FuncDef must be transformed to a FuncDef.
103+
104+
# These contortions are needed to handle the case of recursive
105+
# references inside the function being transformed.
106+
# Set up empty nodes for references within this function
107+
# to other functions defined inside it.
108+
# Don't create an entry for this function itself though,
109+
# since we want self-references to point to the original
110+
# function if this is the top-level node we are transforming.
111+
init = FuncMapInitializer(self)
112+
for stmt in node.body.body:
113+
stmt.accept(init)
114+
101115
new = FuncDef(node.name(),
102116
[self.copy_argument(arg) for arg in node.arguments],
103117
self.block(node.body),
@@ -113,7 +127,13 @@ def visit_func_def(self, node: FuncDef) -> FuncDef:
113127
new.is_class = node.is_class
114128
new.is_property = node.is_property
115129
new.original_def = node.original_def
116-
return new
130+
131+
if node in self.func_map:
132+
result = self.func_map[node]
133+
result.__dict__ = new.__dict__
134+
return result
135+
else:
136+
return new
117137

118138
def visit_func_expr(self, node: FuncExpr) -> Node:
119139
new = FuncExpr([self.copy_argument(arg) for arg in node.arguments],
@@ -330,6 +350,9 @@ def copy_ref(self, new: RefExpr, original: RefExpr) -> None:
330350
target = original.node
331351
if isinstance(target, Var):
332352
target = self.visit_var(target)
353+
elif isinstance(target, FuncDef):
354+
if target in self.func_map:
355+
target = self.func_map[target]
333356
new.node = target
334357
new.is_def = original.is_def
335358

@@ -527,3 +550,14 @@ def types(self, types: List[Type]) -> List[Type]:
527550

528551
def optional_types(self, types: List[Type]) -> List[Type]:
529552
return [self.optional_type(type) for type in types]
553+
554+
555+
class FuncMapInitializer(TraverserVisitor):
556+
def __init__(self, transformer: TransformVisitor) -> None:
557+
self.transformer = transformer
558+
559+
def visit_func_def(self, node: FuncDef) -> None:
560+
if node not in self.transformer.func_map:
561+
self.transformer.func_map[node] = FuncDef(
562+
node.name(), node.arguments, node.body, None)
563+
super().visit_func_def(node)

test-data/unit/check-typevar-values.test

+30
Original file line numberDiff line numberDiff line change
@@ -479,3 +479,33 @@ a = g
479479
b = g
480480
b = g
481481
b = f # E: Incompatible types in assignment (expression has type Callable[[T], T], variable has type Callable[[U], U])
482+
483+
[case testInnerFunctionWithTypevarValues]
484+
from typing import TypeVar
485+
T = TypeVar('T', int, str)
486+
U = TypeVar('U', int, str)
487+
def outer(x: T) -> T:
488+
def inner(y: T) -> T:
489+
return x
490+
def inner2(y: U) -> U:
491+
return y
492+
inner(x)
493+
inner(3) # E: Argument 1 to "inner" has incompatible type "int"; expected "str"
494+
inner2(x)
495+
inner2(3)
496+
outer(3)
497+
return x
498+
[out]
499+
main: note: In function "outer":
500+
501+
[case testInnerFunctionMutualRecursionWithTypevarValues]
502+
from typing import TypeVar
503+
T = TypeVar('T', int, str)
504+
def outer(x: T) -> T:
505+
def inner1(y: T) -> T:
506+
return inner2(y)
507+
def inner2(y: T) -> T:
508+
return inner1('a') # E: Argument 1 to "inner1" has incompatible type "str"; expected "int"
509+
return inner1(x)
510+
[out]
511+
main: note: In function "inner2":

0 commit comments

Comments
 (0)