Skip to content

Commit c11dd95

Browse files
committed
ENH: lazy_xp_function namespaces support
1 parent 75e5166 commit c11dd95

File tree

2 files changed

+110
-28
lines changed

2 files changed

+110
-28
lines changed

Diff for: src/array_api_extra/testing.py

+73-28
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,8 @@ def lazy_xp_function( # type: ignore[no-any-explicit]
5757
"""
5858
Tag a function to be tested on lazy backends.
5959
60-
Tag a function, which must be imported in the test module globals, so that when any
61-
tests defined in the same module are executed with ``xp=jax.numpy`` the function is
62-
replaced with a jitted version of itself, and when it is executed with
60+
Tag a function so that when any tests are executed with ``xp=jax.numpy`` the
61+
function is replaced with a jitted version of itself, and when it is executed with
6362
``xp=dask.array`` the function will raise if it attempts to materialize the graph.
6463
This will be later expanded to provide test coverage for other lazy backends.
6564
@@ -121,19 +120,59 @@ def test_myfunc(xp):
121120
122121
Notes
123122
-----
124-
A test function can circumvent this monkey-patching system by calling `func` as an
125-
attribute of the original module. You need to sanitize your code to make sure this
126-
does not happen.
123+
In order for this tag to be effective, the test function must be imported into the
124+
test module globals without namespace; alternatively its namespace must be declared
125+
in a ``lazy_xp_modules`` list in the test module globals.
127126
128-
Example::
127+
Example 1::
129128
130-
import mymodule from mymodule import myfunc
129+
from mymodule import myfunc
131130
132131
lazy_xp_function(myfunc)
133132
134133
def test_myfunc(xp):
135-
a = xp.asarray([1, 2]) b = myfunc(a) # This is jitted when xp=jax.numpy c =
136-
mymodule.myfunc(a) # This is not
134+
x = myfunc(xp.asarray([1, 2]))
135+
136+
Example 2::
137+
138+
import mymodule
139+
140+
lazy_xp_modules = [mymodule]
141+
lazy_xp_function(mymodule.myfunc)
142+
143+
def test_myfunc(xp):
144+
x = mymodule.myfunc(xp.asarray([1, 2]))
145+
146+
A test function can circumvent this monkey-patching system by using a namespace
147+
outside of the two above patterns. You need to sanitize your code to make sure this
148+
only happens intentionally.
149+
150+
Example 1::
151+
152+
import mymodule
153+
from mymodule import myfunc
154+
155+
lazy_xp_function(myfunc)
156+
157+
def test_myfunc(xp):
158+
a = xp.asarray([1, 2])
159+
b = myfunc(a) # This is wrapped when xp=jax.numpy or xp=dask.array
160+
c = mymodule.myfunc(a) # This is not
161+
162+
Example 2::
163+
164+
import mymodule
165+
166+
class naked:
167+
myfunc = mymodule.myfunc
168+
169+
lazy_xp_modules = [mymodule]
170+
lazy_xp_function(mymodule.myfunc)
171+
172+
def test_myfunc(xp):
173+
a = xp.asarray([1, 2])
174+
b = mymodule.myfunc(a) # This is wrapped when xp=jax.numpy or xp=dask.array
175+
c = naked.myfunc(a) # This is not
137176
"""
138177
tags = {
139178
"allow_dask_compute": allow_dask_compute,
@@ -154,11 +193,13 @@ def patch_lazy_xp_functions(
154193
Test lazy execution of functions tagged with :func:`lazy_xp_function`.
155194
156195
If ``xp==jax.numpy``, search for all functions which have been tagged with
157-
:func:`lazy_xp_function` in the globals of the module that defines the current test
196+
:func:`lazy_xp_function` in the globals of the module that defines the current test,
197+
as well as in the ``lazy_xp_modules`` list in the globals of the same module,
158198
and wrap them with :func:`jax.jit`. Unwrap them at the end of the test.
159199
160200
If ``xp==dask.array``, wrap the functions with a decorator that disables
161-
``compute()`` and ``persist()``.
201+
``compute()`` and ``persist()`` and ensures that exceptions and warnings are raised
202+
eagerly.
162203
163204
This function should be typically called by your library's `xp` fixture that runs
164205
tests on multiple backends::
@@ -184,29 +225,33 @@ def xp(request, monkeypatch):
184225
lazy_xp_function : Tag a function to be tested on lazy backends.
185226
pytest.FixtureRequest : `request` test function parameter.
186227
"""
187-
globals_ = cast("dict[str, Any]", request.module.__dict__) # type: ignore[no-any-explicit]
188-
189-
def iter_tagged() -> Iterator[tuple[str, Callable[..., Any], dict[str, Any]]]: # type: ignore[no-any-explicit]
190-
for name, func in globals_.items():
191-
tags: dict[str, Any] | None = None # type: ignore[no-any-explicit]
192-
with contextlib.suppress(AttributeError):
193-
tags = func._lazy_xp_function # pylint: disable=protected-access
194-
if tags is None:
195-
with contextlib.suppress(KeyError, TypeError):
196-
tags = _ufuncs_tags[func]
197-
if tags is not None:
198-
yield name, func, tags
228+
mod = cast(ModuleType, request.module)
229+
mods = [mod, *cast(list[ModuleType], getattr(mod, "lazy_xp_modules", []))]
230+
231+
def iter_tagged() -> ( # type: ignore[no-any-explicit]
232+
Iterator[tuple[ModuleType, str, Callable[..., Any], dict[str, Any]]]
233+
):
234+
for mod in mods:
235+
for name, func in mod.__dict__.items():
236+
tags: dict[str, Any] | None = None # type: ignore[no-any-explicit]
237+
with contextlib.suppress(AttributeError):
238+
tags = func._lazy_xp_function # pylint: disable=protected-access
239+
if tags is None:
240+
with contextlib.suppress(KeyError, TypeError):
241+
tags = _ufuncs_tags[func]
242+
if tags is not None:
243+
yield mod, name, func, tags
199244

200245
if is_dask_namespace(xp):
201-
for name, func, tags in iter_tagged():
246+
for mod, name, func, tags in iter_tagged():
202247
n = tags["allow_dask_compute"]
203248
wrapped = _dask_wrap(func, n)
204-
monkeypatch.setitem(globals_, name, wrapped)
249+
monkeypatch.setattr(mod, name, wrapped)
205250

206251
elif is_jax_namespace(xp):
207252
import jax
208253

209-
for name, func, tags in iter_tagged():
254+
for mod, name, func, tags in iter_tagged():
210255
if tags["jax_jit"]:
211256
# suppress unused-ignore to run mypy in -e lint as well as -e dev
212257
wrapped = cast( # type: ignore[no-any-explicit]
@@ -217,7 +262,7 @@ def iter_tagged() -> Iterator[tuple[str, Callable[..., Any], dict[str, Any]]]:
217262
static_argnames=tags["static_argnames"],
218263
),
219264
)
220-
monkeypatch.setitem(globals_, name, wrapped)
265+
monkeypatch.setattr(mod, name, wrapped)
221266

222267

223268
class CountingDaskScheduler(SchedulerGetCallable):

Diff for: tests/test_testing.py

+37
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def non_materializable(x: Array) -> Array:
108108
and it will trigger an expensive computation in dask.
109109
"""
110110
xp = array_namespace(x)
111+
# Crashes inside jax.jit
111112
# On dask, this triggers two computations of the whole graph
112113
if xp.any(x < 0.0) or xp.any(x > 10.0):
113114
msg = "Values must be in the [0, 10] range"
@@ -261,3 +262,39 @@ def test_lazy_xp_function_eagerly_raises(da: ModuleType):
261262
x = da.arange(3)
262263
with pytest.raises(ValueError, match="Hello world"):
263264
dask_raises(x)
265+
266+
267+
class Wrapped:
268+
def f(x: Array) -> Array: # noqa: N805 # pyright: ignore[reportSelfClsParameterName]
269+
xp = array_namespace(x)
270+
# Crash in jax.jit and trigger compute() on dask
271+
if not xp.all(x):
272+
msg = "Values must be non-zero"
273+
raise ValueError(msg)
274+
return x
275+
276+
277+
class Naked:
278+
f = Wrapped.f # pyright: ignore[reportUnannotatedClassAttribute]
279+
280+
281+
lazy_xp_function(Wrapped.f)
282+
lazy_xp_modules = [Wrapped]
283+
284+
285+
def test_lazy_xp_modules(xp: ModuleType, library: Backend):
286+
x = xp.asarray([1.0, 2.0])
287+
y = Naked.f(x)
288+
xp_assert_equal(y, x)
289+
290+
if library is Backend.JAX:
291+
with pytest.raises(
292+
TypeError, match="Attempted boolean conversion of traced array"
293+
):
294+
Wrapped.f(x)
295+
elif library is Backend.DASK:
296+
with pytest.raises(AssertionError, match=r"dask\.compute"):
297+
Wrapped.f(x)
298+
else:
299+
y = Wrapped.f(x)
300+
xp_assert_equal(y, x)

0 commit comments

Comments
 (0)