Skip to content

add include_print config function #57

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ include = ["/README.md", "/Makefile", "/pytest_examples", "/tests"]

[project]
name = "pytest-examples"
version = "0.0.15"
version = "0.0.16"
description = "Pytest plugin for testing examples in docstrings and markdown files."
authors = [
{name = "Samuel Colvin", email = "[email protected]"},
Expand Down
4 changes: 3 additions & 1 deletion pytest_examples/eval_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from .config import DEFAULT_LINE_LENGTH, ExamplesConfig
from .lint import FormatError, black_check, black_format, ruff_check, ruff_format
from .run_code import InsertPrintStatements, run_code
from .run_code import IncludePrint, InsertPrintStatements, run_code

if TYPE_CHECKING:
from typing import Literal
Expand All @@ -29,6 +29,7 @@ def __init__(self, *, tmp_path: Path, pytest_request: pytest.FixtureRequest):
self.to_update: list[CodeExample] = []
self.config: ExamplesConfig = ExamplesConfig()
self.print_callback: Callable[[str], str] | None = None
self.include_print: IncludePrint | None = None

def set_config(
self,
Expand Down Expand Up @@ -172,6 +173,7 @@ def _run(
config=self.config,
enable_print_mock=enable_print_mock,
print_callback=self.print_callback,
include_print=self.include_print,
module_globals=module_globals,
call=call,
)
Expand Down
30 changes: 24 additions & 6 deletions pytest_examples/run_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import inspect
import re
import sys
from collections.abc import Sequence
from dataclasses import dataclass
from importlib.abc import Loader
from pathlib import Path
Expand All @@ -24,9 +25,10 @@
from .config import ExamplesConfig
from .find_examples import CodeExample

__all__ = 'run_code', 'InsertPrintStatements'
__all__ = 'run_code', 'InsertPrintStatements', 'IncludePrint'

parent_frame_id = 4 if sys.version_info >= (3, 8) else 3
IncludePrint = Callable[[Path, inspect.FrameInfo, Sequence[Any]], bool]


def run_code(
Expand All @@ -37,6 +39,7 @@ def run_code(
config: ExamplesConfig,
enable_print_mock: bool,
print_callback: Callable[[str], str] | None,
include_print: IncludePrint | None,
module_globals: dict[str, Any] | None,
call: str | None,
) -> tuple[InsertPrintStatements, dict[str, Any]]:
Expand All @@ -49,6 +52,7 @@ def run_code(
config: The `ExamplesConfig` to use.
enable_print_mock: If True, mock the `print` function.
print_callback: If not None, a callback to call on `print`.
include_print: If not None, a function to call to determine if the print statement should be included.
module_globals: The extra globals to add before calling the module.
call: If not None, a (coroutine) function to call in the module.

Expand All @@ -63,7 +67,7 @@ def run_code(
module = importlib.util.module_from_spec(spec)

# does nothing if insert_print_statements is False
insert_print = InsertPrintStatements(python_file, config, enable_print_mock, print_callback)
insert_print = InsertPrintStatements(python_file, config, enable_print_mock, print_callback, include_print)

if module_globals:
module.__dict__.update(module_globals)
Expand Down Expand Up @@ -141,26 +145,40 @@ def not_print(*args):


class MockPrintFunction:
def __init__(self, file: Path) -> None:
__slots__ = 'file', 'statements', 'include_print'

def __init__(self, file: Path, include_print: IncludePrint | None) -> None:
self.file = file
self.statements: list[PrintStatement] = []
self.include_print = include_print

def __call__(self, *args: Any, sep: str = ' ', **kwargs: Any) -> None:
frame = inspect.stack()[parent_frame_id]

if self.file.samefile(frame.filename):
if self._include_file(frame, args):
# -1 to account for the line number being 1-indexed
s = PrintStatement(frame.lineno, sep, [Arg(arg) for arg in args])
self.statements.append(s)

def _include_file(self, frame: inspect.FrameInfo, args: Sequence[Any]) -> bool:
if self.include_print:
return self.include_print(self.file, frame, args)
else:
return self.file.samefile(frame.filename)


class InsertPrintStatements:
def __init__(
self, python_path: Path, config: ExamplesConfig, enable: bool, print_callback: Callable[[str], str] | None
self,
python_path: Path,
config: ExamplesConfig,
enable: bool,
print_callback: Callable[[str], str] | None,
include_print: IncludePrint | None,
):
self.file = python_path
self.config = config
self.print_func = MockPrintFunction(python_path) if enable else None
self.print_func = MockPrintFunction(python_path, include_print) if enable else None
self.print_callback = print_callback
self.patch = None

Expand Down
19 changes: 19 additions & 0 deletions tests/test_insert_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,3 +432,22 @@ async def main():

module_dict = eval_example.run_print_check(example, call='main')
assert module_dict['main_called']


def test_custom_include_print(tmp_path, eval_example):
# note this file is no written here as it's not required
md_file = tmp_path / 'test.md'
python_code = """
print('yes')
#> yes
print('no')
"""
example = CodeExample.create(python_code, path=md_file)
eval_example.set_config(line_length=30)

def custom_include_print(path, frame, args):
return 'yes' in args

eval_example.include_print = custom_include_print

eval_example.run_print_check(example, call='main')
5 changes: 3 additions & 2 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.