Skip to content

Commit aeb45fd

Browse files
authored
Improve logic for finding print statements (#58)
* fix logic for finding print statement * fix test
1 parent 8068283 commit aeb45fd

File tree

4 files changed

+54
-30
lines changed

4 files changed

+54
-30
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ include = ["/README.md", "/Makefile", "/pytest_examples", "/tests"]
99

1010
[project]
1111
name = "pytest-examples"
12-
version = "0.0.16"
12+
version = "0.0.17"
1313
description = "Pytest plugin for testing examples in docstrings and markdown files."
1414
authors = [
1515
{name = "Samuel Colvin", email = "[email protected]"},

pytest_examples/run_code.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,8 @@ def __call__(self, *args: Any, sep: str = ' ', **kwargs: Any) -> None:
156156
frame = inspect.stack()[parent_frame_id]
157157

158158
if self._include_file(frame, args):
159-
# -1 to account for the line number being 1-indexed
160-
s = PrintStatement(frame.lineno, sep, [Arg(arg) for arg in args])
159+
lineno = self._find_line_number(frame)
160+
s = PrintStatement(lineno, sep, [Arg(arg) for arg in args])
161161
self.statements.append(s)
162162

163163
def _include_file(self, frame: inspect.FrameInfo, args: Sequence[Any]) -> bool:
@@ -166,6 +166,17 @@ def _include_file(self, frame: inspect.FrameInfo, args: Sequence[Any]) -> bool:
166166
else:
167167
return self.file.samefile(frame.filename)
168168

169+
def _find_line_number(self, inspect_frame: inspect.FrameInfo) -> int:
170+
"""Find the line number of the print statement in the file that is being executed."""
171+
frame = inspect_frame.frame
172+
while True:
173+
if self.file.samefile(frame.f_code.co_filename):
174+
return frame.f_lineno
175+
elif frame.f_back:
176+
frame = frame.f_back
177+
else:
178+
raise RuntimeError(f'Could not find line number of print statement at {inspect_frame}')
179+
169180

170181
class InsertPrintStatements:
171182
def __init__(
@@ -256,18 +267,6 @@ def _insert_print_args(
256267
triple_quotes_prefix_re = re.compile('^ *(?:"{3}|\'{3})', re.MULTILINE)
257268

258269

259-
def find_print_line(lines: list[str], line_no: int) -> int:
260-
"""For 3.7 we have to reverse through lines to find the print statement lint."""
261-
return line_no
262-
263-
for back in range(100):
264-
new_line_no = line_no - back
265-
m = re.search(r'^ *print\(', lines[new_line_no - 1])
266-
if m:
267-
return new_line_no
268-
return line_no
269-
270-
271270
def remove_old_print(lines: list[str], line_index: int) -> None:
272271
"""Remove the old print statement."""
273272
try:
@@ -294,12 +293,12 @@ def remove_old_print(lines: list[str], line_index: int) -> None:
294293
def find_print_location(example: CodeExample, line_no: int) -> tuple[int, int]:
295294
"""Find the line and column of the print statement.
296295
297-
:param example: the `CodeExample`
298-
:param line_no: The line number on which the print statement starts (or approx on 3.7)
299-
:return: tuple if `(line, column)` of the print statement
300-
"""
301-
# For 3.7 we have to reverse through lines to find the print statement lint
296+
Args:
297+
example: the `CodeExample`
298+
line_no: The line number on which the print statement starts or approx
302299
300+
Return: tuple if `(line, column)` of the print statement
301+
"""
303302
m = ast.parse(example.source, filename=example.path.name)
304303
return find_print(m, line_no) or (line_no, 0)
305304

tests/test_insert_print.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations as _annotations
22

3+
import sys
4+
35
import pytest
46
from _pytest.outcomes import Failed
57

@@ -397,8 +399,6 @@ def main():
397399

398400

399401
def test_run_main_print(tmp_path, eval_example):
400-
# note this file is no written here as it's not required
401-
md_file = tmp_path / 'test.md'
402402
python_code = """
403403
main_called = False
404404
@@ -408,16 +408,15 @@ def main():
408408
print(1, 2, 3)
409409
#> 1 2 3
410410
"""
411-
example = CodeExample.create(python_code, path=md_file)
411+
# note this file is no written here as it's not required
412+
example = CodeExample.create(python_code, path=tmp_path / 'test.md')
412413
eval_example.set_config(line_length=30)
413414

414415
module_dict = eval_example.run_print_check(example, call='main')
415416
assert module_dict['main_called']
416417

417418

418419
def test_run_main_print_async(tmp_path, eval_example):
419-
# note this file is no written here as it's not required
420-
md_file = tmp_path / 'test.md'
421420
python_code = """
422421
main_called = False
423422
@@ -427,22 +426,22 @@ async def main():
427426
print(1, 2, 3)
428427
#> 1 2 3
429428
"""
430-
example = CodeExample.create(python_code, path=md_file)
429+
# note this file is no written here as it's not required
430+
example = CodeExample.create(python_code, path=tmp_path / 'test.md')
431431
eval_example.set_config(line_length=30)
432432

433433
module_dict = eval_example.run_print_check(example, call='main')
434434
assert module_dict['main_called']
435435

436436

437437
def test_custom_include_print(tmp_path, eval_example):
438-
# note this file is no written here as it's not required
439-
md_file = tmp_path / 'test.md'
440438
python_code = """
441439
print('yes')
442440
#> yes
443441
print('no')
444442
"""
445-
example = CodeExample.create(python_code, path=md_file)
443+
# note this file is no written here as it's not required
444+
example = CodeExample.create(python_code, path=tmp_path / 'test.md')
446445
eval_example.set_config(line_length=30)
447446

448447
def custom_include_print(path, frame, args):
@@ -451,3 +450,29 @@ def custom_include_print(path, frame, args):
451450
eval_example.include_print = custom_include_print
452451

453452
eval_example.run_print_check(example, call='main')
453+
454+
455+
def test_print_different_file(tmp_path, eval_example):
456+
other_file = tmp_path / 'other.py'
457+
other_code = """
458+
def does_print():
459+
print('hello')
460+
"""
461+
other_file.write_text(other_code)
462+
sys.path.append(str(tmp_path))
463+
python_code = """
464+
import other
465+
466+
other.does_print()
467+
#> hello
468+
"""
469+
example = CodeExample.create(python_code, path=tmp_path / 'test.md')
470+
471+
eval_example.include_print = lambda p, f, a: True
472+
473+
eval_example.run_print_check(example, call='main')
474+
475+
del sys.modules['other']
476+
other_file.write_text(('\n' * 30) + other_code)
477+
478+
eval_example.run_print_check(example, call='main')

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)