Skip to content

Commit 8068283

Browse files
authored
add include_print config function (#57)
1 parent 9a18f65 commit 8068283

File tree

5 files changed

+50
-10
lines changed

5 files changed

+50
-10
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ include = ["/README.md", "/Makefile", "/pytest_examples", "/tests"]
99

1010
[project]
1111
name = "pytest-examples"
12-
version = "0.0.15"
12+
version = "0.0.16"
1313
description = "Pytest plugin for testing examples in docstrings and markdown files."
1414
authors = [
1515
{name = "Samuel Colvin", email = "[email protected]"},

pytest_examples/eval_example.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from .config import DEFAULT_LINE_LENGTH, ExamplesConfig
1111
from .lint import FormatError, black_check, black_format, ruff_check, ruff_format
12-
from .run_code import InsertPrintStatements, run_code
12+
from .run_code import IncludePrint, InsertPrintStatements, run_code
1313

1414
if TYPE_CHECKING:
1515
from typing import Literal
@@ -29,6 +29,7 @@ def __init__(self, *, tmp_path: Path, pytest_request: pytest.FixtureRequest):
2929
self.to_update: list[CodeExample] = []
3030
self.config: ExamplesConfig = ExamplesConfig()
3131
self.print_callback: Callable[[str], str] | None = None
32+
self.include_print: IncludePrint | None = None
3233

3334
def set_config(
3435
self,
@@ -172,6 +173,7 @@ def _run(
172173
config=self.config,
173174
enable_print_mock=enable_print_mock,
174175
print_callback=self.print_callback,
176+
include_print=self.include_print,
175177
module_globals=module_globals,
176178
call=call,
177179
)

pytest_examples/run_code.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import inspect
88
import re
99
import sys
10+
from collections.abc import Sequence
1011
from dataclasses import dataclass
1112
from importlib.abc import Loader
1213
from pathlib import Path
@@ -24,9 +25,10 @@
2425
from .config import ExamplesConfig
2526
from .find_examples import CodeExample
2627

27-
__all__ = 'run_code', 'InsertPrintStatements'
28+
__all__ = 'run_code', 'InsertPrintStatements', 'IncludePrint'
2829

2930
parent_frame_id = 4 if sys.version_info >= (3, 8) else 3
31+
IncludePrint = Callable[[Path, inspect.FrameInfo, Sequence[Any]], bool]
3032

3133

3234
def run_code(
@@ -37,6 +39,7 @@ def run_code(
3739
config: ExamplesConfig,
3840
enable_print_mock: bool,
3941
print_callback: Callable[[str], str] | None,
42+
include_print: IncludePrint | None,
4043
module_globals: dict[str, Any] | None,
4144
call: str | None,
4245
) -> tuple[InsertPrintStatements, dict[str, Any]]:
@@ -49,6 +52,7 @@ def run_code(
4952
config: The `ExamplesConfig` to use.
5053
enable_print_mock: If True, mock the `print` function.
5154
print_callback: If not None, a callback to call on `print`.
55+
include_print: If not None, a function to call to determine if the print statement should be included.
5256
module_globals: The extra globals to add before calling the module.
5357
call: If not None, a (coroutine) function to call in the module.
5458
@@ -63,7 +67,7 @@ def run_code(
6367
module = importlib.util.module_from_spec(spec)
6468

6569
# does nothing if insert_print_statements is False
66-
insert_print = InsertPrintStatements(python_file, config, enable_print_mock, print_callback)
70+
insert_print = InsertPrintStatements(python_file, config, enable_print_mock, print_callback, include_print)
6771

6872
if module_globals:
6973
module.__dict__.update(module_globals)
@@ -141,26 +145,40 @@ def not_print(*args):
141145

142146

143147
class MockPrintFunction:
144-
def __init__(self, file: Path) -> None:
148+
__slots__ = 'file', 'statements', 'include_print'
149+
150+
def __init__(self, file: Path, include_print: IncludePrint | None) -> None:
145151
self.file = file
146152
self.statements: list[PrintStatement] = []
153+
self.include_print = include_print
147154

148155
def __call__(self, *args: Any, sep: str = ' ', **kwargs: Any) -> None:
149156
frame = inspect.stack()[parent_frame_id]
150157

151-
if self.file.samefile(frame.filename):
158+
if self._include_file(frame, args):
152159
# -1 to account for the line number being 1-indexed
153160
s = PrintStatement(frame.lineno, sep, [Arg(arg) for arg in args])
154161
self.statements.append(s)
155162

163+
def _include_file(self, frame: inspect.FrameInfo, args: Sequence[Any]) -> bool:
164+
if self.include_print:
165+
return self.include_print(self.file, frame, args)
166+
else:
167+
return self.file.samefile(frame.filename)
168+
156169

157170
class InsertPrintStatements:
158171
def __init__(
159-
self, python_path: Path, config: ExamplesConfig, enable: bool, print_callback: Callable[[str], str] | None
172+
self,
173+
python_path: Path,
174+
config: ExamplesConfig,
175+
enable: bool,
176+
print_callback: Callable[[str], str] | None,
177+
include_print: IncludePrint | None,
160178
):
161179
self.file = python_path
162180
self.config = config
163-
self.print_func = MockPrintFunction(python_path) if enable else None
181+
self.print_func = MockPrintFunction(python_path, include_print) if enable else None
164182
self.print_callback = print_callback
165183
self.patch = None
166184

tests/test_insert_print.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,3 +432,22 @@ async def main():
432432

433433
module_dict = eval_example.run_print_check(example, call='main')
434434
assert module_dict['main_called']
435+
436+
437+
def test_custom_include_print(tmp_path, eval_example):
438+
# note this file is no written here as it's not required
439+
md_file = tmp_path / 'test.md'
440+
python_code = """
441+
print('yes')
442+
#> yes
443+
print('no')
444+
"""
445+
example = CodeExample.create(python_code, path=md_file)
446+
eval_example.set_config(line_length=30)
447+
448+
def custom_include_print(path, frame, args):
449+
return 'yes' in args
450+
451+
eval_example.include_print = custom_include_print
452+
453+
eval_example.run_print_check(example, call='main')

uv.lock

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)