Skip to content

Commit 946634c

Browse files
authored
Merge pull request #11419 from nicoddemus/backport-11414-to-7.4.x
[7.4.x] Fix assert rewriting with assignment expressions (#11414)
2 parents d849a3e + 721a088 commit 946634c

File tree

4 files changed

+64
-14
lines changed

4 files changed

+64
-14
lines changed

AUTHORS

+1
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ Maho
231231
Maik Figura
232232
Mandeep Bhutani
233233
Manuel Krebber
234+
Marc Mueller
234235
Marc Schlaich
235236
Marcelo Duarte Trevisani
236237
Marcin Bachry

changelog/11239.bugfix.rst

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed ``:=`` in asserts impacting unrelated test cases.

src/_pytest/assertion/rewrite.py

+40-14
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import sys
1414
import tokenize
1515
import types
16+
from collections import defaultdict
1617
from pathlib import Path
1718
from pathlib import PurePath
1819
from typing import Callable
@@ -56,13 +57,20 @@
5657
astNum = ast.Num
5758

5859

60+
class Sentinel:
61+
pass
62+
63+
5964
assertstate_key = StashKey["AssertionState"]()
6065

6166
# pytest caches rewritten pycs in pycache dirs
6267
PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
6368
PYC_EXT = ".py" + (__debug__ and "c" or "o")
6469
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
6570

71+
# Special marker that denotes we have just left a scope definition
72+
_SCOPE_END_MARKER = Sentinel()
73+
6674

6775
class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
6876
"""PEP302/PEP451 import hook which rewrites asserts."""
@@ -645,6 +653,8 @@ class AssertionRewriter(ast.NodeVisitor):
645653
.push_format_context() and .pop_format_context() which allows
646654
to build another %-formatted string while already building one.
647655
656+
:scope: A tuple containing the current scope used for variables_overwrite.
657+
648658
:variables_overwrite: A dict filled with references to variables
649659
that change value within an assert. This happens when a variable is
650660
reassigned with the walrus operator
@@ -666,7 +676,10 @@ def __init__(
666676
else:
667677
self.enable_assertion_pass_hook = False
668678
self.source = source
669-
self.variables_overwrite: Dict[str, str] = {}
679+
self.scope: tuple[ast.AST, ...] = ()
680+
self.variables_overwrite: defaultdict[
681+
tuple[ast.AST, ...], Dict[str, str]
682+
] = defaultdict(dict)
670683

671684
def run(self, mod: ast.Module) -> None:
672685
"""Find all assert statements in *mod* and rewrite them."""
@@ -732,9 +745,17 @@ def run(self, mod: ast.Module) -> None:
732745
mod.body[pos:pos] = imports
733746

734747
# Collect asserts.
735-
nodes: List[ast.AST] = [mod]
748+
self.scope = (mod,)
749+
nodes: List[Union[ast.AST, Sentinel]] = [mod]
736750
while nodes:
737751
node = nodes.pop()
752+
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
753+
self.scope = tuple((*self.scope, node))
754+
nodes.append(_SCOPE_END_MARKER)
755+
if node == _SCOPE_END_MARKER:
756+
self.scope = self.scope[:-1]
757+
continue
758+
assert isinstance(node, ast.AST)
738759
for name, field in ast.iter_fields(node):
739760
if isinstance(field, list):
740761
new: List[ast.AST] = []
@@ -1005,7 +1026,7 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
10051026
]
10061027
):
10071028
pytest_temp = self.variable()
1008-
self.variables_overwrite[
1029+
self.variables_overwrite[self.scope][
10091030
v.left.target.id
10101031
] = v.left # type:ignore[assignment]
10111032
v.left.target.id = pytest_temp
@@ -1048,17 +1069,20 @@ def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
10481069
new_args = []
10491070
new_kwargs = []
10501071
for arg in call.args:
1051-
if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite:
1052-
arg = self.variables_overwrite[arg.id] # type:ignore[assignment]
1072+
if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite.get(
1073+
self.scope, {}
1074+
):
1075+
arg = self.variables_overwrite[self.scope][
1076+
arg.id
1077+
] # type:ignore[assignment]
10531078
res, expl = self.visit(arg)
10541079
arg_expls.append(expl)
10551080
new_args.append(res)
10561081
for keyword in call.keywords:
1057-
if (
1058-
isinstance(keyword.value, ast.Name)
1059-
and keyword.value.id in self.variables_overwrite
1060-
):
1061-
keyword.value = self.variables_overwrite[
1082+
if isinstance(
1083+
keyword.value, ast.Name
1084+
) and keyword.value.id in self.variables_overwrite.get(self.scope, {}):
1085+
keyword.value = self.variables_overwrite[self.scope][
10621086
keyword.value.id
10631087
] # type:ignore[assignment]
10641088
res, expl = self.visit(keyword.value)
@@ -1094,12 +1118,14 @@ def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
10941118
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
10951119
self.push_format_context()
10961120
# We first check if we have overwritten a variable in the previous assert
1097-
if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite:
1098-
comp.left = self.variables_overwrite[
1121+
if isinstance(
1122+
comp.left, ast.Name
1123+
) and comp.left.id in self.variables_overwrite.get(self.scope, {}):
1124+
comp.left = self.variables_overwrite[self.scope][
10991125
comp.left.id
11001126
] # type:ignore[assignment]
11011127
if isinstance(comp.left, namedExpr):
1102-
self.variables_overwrite[
1128+
self.variables_overwrite[self.scope][
11031129
comp.left.target.id
11041130
] = comp.left # type:ignore[assignment]
11051131
left_res, left_expl = self.visit(comp.left)
@@ -1119,7 +1145,7 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
11191145
and next_operand.target.id == left_res.id
11201146
):
11211147
next_operand.target.id = self.variable()
1122-
self.variables_overwrite[
1148+
self.variables_overwrite[self.scope][
11231149
left_res.id
11241150
] = next_operand # type:ignore[assignment]
11251151
next_res, next_expl = self.visit(next_operand)

testing/test_assertrewrite.py

+22
Original file line numberDiff line numberDiff line change
@@ -1531,6 +1531,28 @@ def test_gt():
15311531
result.stdout.fnmatch_lines(["*assert 4 > 5", "*where 5 = add_one(4)"])
15321532

15331533

1534+
class TestIssue11239:
1535+
@pytest.mark.skipif(sys.version_info[:2] <= (3, 7), reason="Only Python 3.8+")
1536+
def test_assertion_walrus_different_test_cases(self, pytester: Pytester) -> None:
1537+
"""Regression for (#11239)
1538+
1539+
Walrus operator rewriting would leak to separate test cases if they used the same variables.
1540+
"""
1541+
pytester.makepyfile(
1542+
"""
1543+
def test_1():
1544+
state = {"x": 2}.get("x")
1545+
assert state is not None
1546+
1547+
def test_2():
1548+
db = {"x": 2}
1549+
assert (state := db.get("x")) is not None
1550+
"""
1551+
)
1552+
result = pytester.runpytest()
1553+
assert result.ret == 0
1554+
1555+
15341556
@pytest.mark.skipif(
15351557
sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems"
15361558
)

0 commit comments

Comments
 (0)