@@ -57,9 +57,8 @@ def lazy_xp_function( # type: ignore[no-any-explicit]
57
57
"""
58
58
Tag a function to be tested on lazy backends.
59
59
60
- Tag a function, which must be imported in the test module globals, so that when any
61
- tests defined in the same module are executed with ``xp=jax.numpy`` the function is
62
- replaced with a jitted version of itself, and when it is executed with
60
+ Tag a function so that when any tests are executed with ``xp=jax.numpy`` the
61
+ function is replaced with a jitted version of itself, and when it is executed with
63
62
``xp=dask.array`` the function will raise if it attempts to materialize the graph.
64
63
This will be later expanded to provide test coverage for other lazy backends.
65
64
@@ -121,19 +120,59 @@ def test_myfunc(xp):
121
120
122
121
Notes
123
122
-----
124
- A test function can circumvent this monkey-patching system by calling `func` as an
125
- attribute of the original module. You need to sanitize your code to make sure this
126
- does not happen .
123
+ In order for this tag to be effective, the test function must be imported into the
124
+ test module globals without namespace; alternatively its namespace must be declared
125
+ in a ``lazy_xp_modules`` list in the test module globals .
127
126
128
- Example::
127
+ Example 1 ::
129
128
130
- import mymodule from mymodule import myfunc
129
+ from mymodule import myfunc
131
130
132
131
lazy_xp_function(myfunc)
133
132
134
133
def test_myfunc(xp):
135
- a = xp.asarray([1, 2]) b = myfunc(a) # This is jitted when xp=jax.numpy c =
136
- mymodule.myfunc(a) # This is not
134
+ x = myfunc(xp.asarray([1, 2]))
135
+
136
+ Example 2::
137
+
138
+ import mymodule
139
+
140
+ lazy_xp_modules = [mymodule]
141
+ lazy_xp_function(mymodule.myfunc)
142
+
143
+ def test_myfunc(xp):
144
+ x = mymodule.myfunc(xp.asarray([1, 2]))
145
+
146
+ A test function can circumvent this monkey-patching system by using a namespace
147
+ outside of the two above patterns. You need to sanitize your code to make sure this
148
+ only happens intentionally.
149
+
150
+ Example 1::
151
+
152
+ import mymodule
153
+ from mymodule import myfunc
154
+
155
+ lazy_xp_function(myfunc)
156
+
157
+ def test_myfunc(xp):
158
+ a = xp.asarray([1, 2])
159
+ b = myfunc(a) # This is wrapped when xp=jax.numpy or xp=dask.array
160
+ c = mymodule.myfunc(a) # This is not
161
+
162
+ Example 2::
163
+
164
+ import mymodule
165
+
166
+ class naked:
167
+ myfunc = mymodule.myfunc
168
+
169
+ lazy_xp_modules = [mymodule]
170
+ lazy_xp_function(mymodule.myfunc)
171
+
172
+ def test_myfunc(xp):
173
+ a = xp.asarray([1, 2])
174
+ b = mymodule.myfunc(a) # This is wrapped when xp=jax.numpy or xp=dask.array
175
+ c = naked.myfunc(a) # This is not
137
176
"""
138
177
tags = {
139
178
"allow_dask_compute" : allow_dask_compute ,
@@ -154,11 +193,13 @@ def patch_lazy_xp_functions(
154
193
Test lazy execution of functions tagged with :func:`lazy_xp_function`.
155
194
156
195
If ``xp==jax.numpy``, search for all functions which have been tagged with
157
- :func:`lazy_xp_function` in the globals of the module that defines the current test
196
+ :func:`lazy_xp_function` in the globals of the module that defines the current test,
197
+ as well as in the ``lazy_xp_modules`` list in the globals of the same module,
158
198
and wrap them with :func:`jax.jit`. Unwrap them at the end of the test.
159
199
160
200
If ``xp==dask.array``, wrap the functions with a decorator that disables
161
- ``compute()`` and ``persist()``.
201
+ ``compute()`` and ``persist()`` and ensures that exceptions and warnings are raised
202
+ eagerly.
162
203
163
204
This function should be typically called by your library's `xp` fixture that runs
164
205
tests on multiple backends::
@@ -184,29 +225,33 @@ def xp(request, monkeypatch):
184
225
lazy_xp_function : Tag a function to be tested on lazy backends.
185
226
pytest.FixtureRequest : `request` test function parameter.
186
227
"""
187
- globals_ = cast ("dict[str, Any]" , request .module .__dict__ ) # type: ignore[no-any-explicit]
188
-
189
- def iter_tagged () -> Iterator [tuple [str , Callable [..., Any ], dict [str , Any ]]]: # type: ignore[no-any-explicit]
190
- for name , func in globals_ .items ():
191
- tags : dict [str , Any ] | None = None # type: ignore[no-any-explicit]
192
- with contextlib .suppress (AttributeError ):
193
- tags = func ._lazy_xp_function # pylint: disable=protected-access
194
- if tags is None :
195
- with contextlib .suppress (KeyError , TypeError ):
196
- tags = _ufuncs_tags [func ]
197
- if tags is not None :
198
- yield name , func , tags
228
+ mod = cast (ModuleType , request .module )
229
+ mods = [mod , * cast (list [ModuleType ], getattr (mod , "lazy_xp_modules" , []))]
230
+
231
+ def iter_tagged () -> ( # type: ignore[no-any-explicit]
232
+ Iterator [tuple [ModuleType , str , Callable [..., Any ], dict [str , Any ]]]
233
+ ):
234
+ for mod in mods :
235
+ for name , func in mod .__dict__ .items ():
236
+ tags : dict [str , Any ] | None = None # type: ignore[no-any-explicit]
237
+ with contextlib .suppress (AttributeError ):
238
+ tags = func ._lazy_xp_function # pylint: disable=protected-access
239
+ if tags is None :
240
+ with contextlib .suppress (KeyError , TypeError ):
241
+ tags = _ufuncs_tags [func ]
242
+ if tags is not None :
243
+ yield mod , name , func , tags
199
244
200
245
if is_dask_namespace (xp ):
201
- for name , func , tags in iter_tagged ():
246
+ for mod , name , func , tags in iter_tagged ():
202
247
n = tags ["allow_dask_compute" ]
203
248
wrapped = _dask_wrap (func , n )
204
- monkeypatch .setitem ( globals_ , name , wrapped )
249
+ monkeypatch .setattr ( mod , name , wrapped )
205
250
206
251
elif is_jax_namespace (xp ):
207
252
import jax
208
253
209
- for name , func , tags in iter_tagged ():
254
+ for mod , name , func , tags in iter_tagged ():
210
255
if tags ["jax_jit" ]:
211
256
# suppress unused-ignore to run mypy in -e lint as well as -e dev
212
257
wrapped = cast ( # type: ignore[no-any-explicit]
@@ -217,7 +262,7 @@ def iter_tagged() -> Iterator[tuple[str, Callable[..., Any], dict[str, Any]]]:
217
262
static_argnames = tags ["static_argnames" ],
218
263
),
219
264
)
220
- monkeypatch .setitem ( globals_ , name , wrapped )
265
+ monkeypatch .setattr ( mod , name , wrapped )
221
266
222
267
223
268
class CountingDaskScheduler (SchedulerGetCallable ):
0 commit comments