7
7
import inspect
8
8
import re
9
9
import sys
10
+ from collections .abc import Sequence
10
11
from dataclasses import dataclass
11
12
from importlib .abc import Loader
12
13
from pathlib import Path
24
25
from .config import ExamplesConfig
25
26
from .find_examples import CodeExample
26
27
27
- __all__ = 'run_code' , 'InsertPrintStatements'
28
+ __all__ = 'run_code' , 'InsertPrintStatements' , 'IncludePrint'
28
29
29
30
parent_frame_id = 4 if sys .version_info >= (3 , 8 ) else 3
31
+ IncludePrint = Callable [[Path , inspect .FrameInfo , Sequence [Any ]], bool ]
30
32
31
33
32
34
def run_code (
@@ -37,6 +39,7 @@ def run_code(
37
39
config : ExamplesConfig ,
38
40
enable_print_mock : bool ,
39
41
print_callback : Callable [[str ], str ] | None ,
42
+ include_print : IncludePrint | None ,
40
43
module_globals : dict [str , Any ] | None ,
41
44
call : str | None ,
42
45
) -> tuple [InsertPrintStatements , dict [str , Any ]]:
@@ -49,6 +52,7 @@ def run_code(
49
52
config: The `ExamplesConfig` to use.
50
53
enable_print_mock: If True, mock the `print` function.
51
54
print_callback: If not None, a callback to call on `print`.
55
+ include_print: If not None, a function to call to determine if the print statement should be included.
52
56
module_globals: The extra globals to add before calling the module.
53
57
call: If not None, a (coroutine) function to call in the module.
54
58
@@ -63,7 +67,7 @@ def run_code(
63
67
module = importlib .util .module_from_spec (spec )
64
68
65
69
# does nothing if insert_print_statements is False
66
- insert_print = InsertPrintStatements (python_file , config , enable_print_mock , print_callback )
70
+ insert_print = InsertPrintStatements (python_file , config , enable_print_mock , print_callback , include_print )
67
71
68
72
if module_globals :
69
73
module .__dict__ .update (module_globals )
@@ -141,26 +145,40 @@ def not_print(*args):
141
145
142
146
143
147
class MockPrintFunction :
144
- def __init__ (self , file : Path ) -> None :
148
+ __slots__ = 'file' , 'statements' , 'include_print'
149
+
150
+ def __init__ (self , file : Path , include_print : IncludePrint | None ) -> None :
145
151
self .file = file
146
152
self .statements : list [PrintStatement ] = []
153
+ self .include_print = include_print
147
154
148
155
def __call__ (self , * args : Any , sep : str = ' ' , ** kwargs : Any ) -> None :
149
156
frame = inspect .stack ()[parent_frame_id ]
150
157
151
- if self .file . samefile (frame . filename ):
158
+ if self ._include_file (frame , args ):
152
159
# -1 to account for the line number being 1-indexed
153
160
s = PrintStatement (frame .lineno , sep , [Arg (arg ) for arg in args ])
154
161
self .statements .append (s )
155
162
163
+ def _include_file (self , frame : inspect .FrameInfo , args : Sequence [Any ]) -> bool :
164
+ if self .include_print :
165
+ return self .include_print (self .file , frame , args )
166
+ else :
167
+ return self .file .samefile (frame .filename )
168
+
156
169
157
170
class InsertPrintStatements :
158
171
def __init__ (
159
- self , python_path : Path , config : ExamplesConfig , enable : bool , print_callback : Callable [[str ], str ] | None
172
+ self ,
173
+ python_path : Path ,
174
+ config : ExamplesConfig ,
175
+ enable : bool ,
176
+ print_callback : Callable [[str ], str ] | None ,
177
+ include_print : IncludePrint | None ,
160
178
):
161
179
self .file = python_path
162
180
self .config = config
163
- self .print_func = MockPrintFunction (python_path ) if enable else None
181
+ self .print_func = MockPrintFunction (python_path , include_print ) if enable else None
164
182
self .print_callback = print_callback
165
183
self .patch = None
166
184
0 commit comments