13
13
import sys
14
14
import tokenize
15
15
import types
16
+ from collections import defaultdict
16
17
from pathlib import Path
17
18
from pathlib import PurePath
18
19
from typing import Callable
56
57
astNum = ast .Num
57
58
58
59
60
+ class Sentinel :
61
+ pass
62
+
63
+
59
64
assertstate_key = StashKey ["AssertionState" ]()
60
65
61
66
# pytest caches rewritten pycs in pycache dirs
62
67
PYTEST_TAG = f"{ sys .implementation .cache_tag } -pytest-{ version } "
63
68
PYC_EXT = ".py" + (__debug__ and "c" or "o" )
64
69
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
65
70
71
+ # Special marker that denotes we have just left a scope definition
72
+ _SCOPE_END_MARKER = Sentinel ()
73
+
66
74
67
75
class AssertionRewritingHook (importlib .abc .MetaPathFinder , importlib .abc .Loader ):
68
76
"""PEP302/PEP451 import hook which rewrites asserts."""
@@ -645,6 +653,8 @@ class AssertionRewriter(ast.NodeVisitor):
645
653
.push_format_context() and .pop_format_context() which allows
646
654
to build another %-formatted string while already building one.
647
655
656
+ :scope: A tuple containing the current scope used for variables_overwrite.
657
+
648
658
:variables_overwrite: A dict filled with references to variables
649
659
that change value within an assert. This happens when a variable is
650
660
reassigned with the walrus operator
@@ -666,7 +676,10 @@ def __init__(
666
676
else :
667
677
self .enable_assertion_pass_hook = False
668
678
self .source = source
669
- self .variables_overwrite : Dict [str , str ] = {}
679
+ self .scope : tuple [ast .AST , ...] = ()
680
+ self .variables_overwrite : defaultdict [
681
+ tuple [ast .AST , ...], Dict [str , str ]
682
+ ] = defaultdict (dict )
670
683
671
684
def run (self , mod : ast .Module ) -> None :
672
685
"""Find all assert statements in *mod* and rewrite them."""
@@ -732,9 +745,17 @@ def run(self, mod: ast.Module) -> None:
732
745
mod .body [pos :pos ] = imports
733
746
734
747
# Collect asserts.
735
- nodes : List [ast .AST ] = [mod ]
748
+ self .scope = (mod ,)
749
+ nodes : List [Union [ast .AST , Sentinel ]] = [mod ]
736
750
while nodes :
737
751
node = nodes .pop ()
752
+ if isinstance (node , (ast .FunctionDef , ast .AsyncFunctionDef , ast .ClassDef )):
753
+ self .scope = tuple ((* self .scope , node ))
754
+ nodes .append (_SCOPE_END_MARKER )
755
+ if node == _SCOPE_END_MARKER :
756
+ self .scope = self .scope [:- 1 ]
757
+ continue
758
+ assert isinstance (node , ast .AST )
738
759
for name , field in ast .iter_fields (node ):
739
760
if isinstance (field , list ):
740
761
new : List [ast .AST ] = []
@@ -1005,7 +1026,7 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
1005
1026
]
1006
1027
):
1007
1028
pytest_temp = self .variable ()
1008
- self .variables_overwrite [
1029
+ self .variables_overwrite [self . scope ][
1009
1030
v .left .target .id
1010
1031
] = v .left # type:ignore[assignment]
1011
1032
v .left .target .id = pytest_temp
@@ -1048,17 +1069,20 @@ def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
1048
1069
new_args = []
1049
1070
new_kwargs = []
1050
1071
for arg in call .args :
1051
- if isinstance (arg , ast .Name ) and arg .id in self .variables_overwrite :
1052
- arg = self .variables_overwrite [arg .id ] # type:ignore[assignment]
1072
+ if isinstance (arg , ast .Name ) and arg .id in self .variables_overwrite .get (
1073
+ self .scope , {}
1074
+ ):
1075
+ arg = self .variables_overwrite [self .scope ][
1076
+ arg .id
1077
+ ] # type:ignore[assignment]
1053
1078
res , expl = self .visit (arg )
1054
1079
arg_expls .append (expl )
1055
1080
new_args .append (res )
1056
1081
for keyword in call .keywords :
1057
- if (
1058
- isinstance (keyword .value , ast .Name )
1059
- and keyword .value .id in self .variables_overwrite
1060
- ):
1061
- keyword .value = self .variables_overwrite [
1082
+ if isinstance (
1083
+ keyword .value , ast .Name
1084
+ ) and keyword .value .id in self .variables_overwrite .get (self .scope , {}):
1085
+ keyword .value = self .variables_overwrite [self .scope ][
1062
1086
keyword .value .id
1063
1087
] # type:ignore[assignment]
1064
1088
res , expl = self .visit (keyword .value )
@@ -1094,12 +1118,14 @@ def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
1094
1118
def visit_Compare (self , comp : ast .Compare ) -> Tuple [ast .expr , str ]:
1095
1119
self .push_format_context ()
1096
1120
# We first check if we have overwritten a variable in the previous assert
1097
- if isinstance (comp .left , ast .Name ) and comp .left .id in self .variables_overwrite :
1098
- comp .left = self .variables_overwrite [
1121
+ if isinstance (
1122
+ comp .left , ast .Name
1123
+ ) and comp .left .id in self .variables_overwrite .get (self .scope , {}):
1124
+ comp .left = self .variables_overwrite [self .scope ][
1099
1125
comp .left .id
1100
1126
] # type:ignore[assignment]
1101
1127
if isinstance (comp .left , namedExpr ):
1102
- self .variables_overwrite [
1128
+ self .variables_overwrite [self . scope ][
1103
1129
comp .left .target .id
1104
1130
] = comp .left # type:ignore[assignment]
1105
1131
left_res , left_expl = self .visit (comp .left )
@@ -1119,7 +1145,7 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
1119
1145
and next_operand .target .id == left_res .id
1120
1146
):
1121
1147
next_operand .target .id = self .variable ()
1122
- self .variables_overwrite [
1148
+ self .variables_overwrite [self . scope ][
1123
1149
left_res .id
1124
1150
] = next_operand # type:ignore[assignment]
1125
1151
next_res , next_expl = self .visit (next_operand )
0 commit comments