Skip to content

Commit cd84a8b

Browse files
committed
ENH Test tools for jax.jit and dask
1 parent b5bf75c commit cd84a8b

File tree

10 files changed

+446
-17
lines changed

10 files changed

+446
-17
lines changed

Diff for: docs/api-reference.md

+12
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,15 @@
1818
setdiff1d
1919
sinc
2020
```
21+
22+
## Testing utilities
23+
24+
```{eval-rst}
25+
.. currentmodule:: array_api_extra.testing
26+
.. autosummary::
27+
:nosignatures:
28+
:toctree: generated
29+
30+
lazy_xp_function
31+
patch_lazy_xp_functions
32+
```

Diff for: pixi.lock

+13-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: pyproject.toml

+11-2
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ furo = ">=2023.08.17"
9898
myst-parser = ">=0.13"
9999
sphinx-copybutton = "*"
100100
sphinx-autodoc-typehints = "*"
101+
# Needed to import parsed modules with autodoc
102+
pytest = "*"
101103

102104
[tool.pixi.feature.docs.tasks]
103105
docs = { cmd = "sphinx-build . build/", cwd = "docs" }
@@ -180,8 +182,10 @@ markers = ["skip_xp_backend(library, *, reason=None): Skip test for a specific b
180182

181183
[tool.coverage]
182184
run.source = ["array_api_extra"]
183-
report.exclude_also = ['\.\.\.']
184-
185+
report.exclude_also = [
186+
'\.\.\.',
187+
'if TYPE_CHECKING:',
188+
]
185189

186190
# mypy
187191

@@ -221,6 +225,8 @@ reportMissingImports = false
221225
reportMissingTypeStubs = false
222226
# false positives for input validation
223227
reportUnreachable = false
228+
# ruff handles this
229+
reportUnusedParameter = false
224230

225231
executionEnvironments = [
226232
{ root = "tests", reportPrivateUsage = false },
@@ -282,7 +288,10 @@ messages_control.disable = [
282288
"design", # ignore heavily opinionated design checks
283289
"fixme", # allow FIXME comments
284290
"line-too-long", # ruff handles this
291+
"unused-argument", # ruff handles this
285292
"missing-function-docstring", # numpydoc handles this
293+
"import-error", # mypy handles this
294+
"import-outside-toplevel", # optional dependencies
286295
]
287296

288297

Diff for: src/array_api_extra/_lib/_testing.py

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Testing utilities.
33
44
Note that this is private API; don't expect it to be stable.
5+
See also ..testing for public testing utilities.
56
"""
67

78
import math

Diff for: src/array_api_extra/testing.py

+205
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
"""
2+
Public testing utilities.
3+
4+
See also _lib._testing for additional private testing utilities.
5+
"""
6+
7+
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
8+
from __future__ import annotations
9+
10+
from collections.abc import Callable, Iterable, Sequence
11+
from functools import wraps
12+
from types import ModuleType
13+
from typing import TYPE_CHECKING, Any, TypeVar, cast
14+
15+
import pytest
16+
17+
from array_api_extra._lib._utils._compat import is_dask_namespace, is_jax_namespace
18+
19+
__all__ = ["lazy_xp_function", "patch_lazy_xp_functions"]
20+
21+
if TYPE_CHECKING:
22+
# TODO move outside TYPE_CHECKING
23+
# depends on scikit-learn abandoning Python 3.9
24+
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
25+
from typing import ParamSpec
26+
27+
P = ParamSpec("P")
28+
else:
29+
# Sphinx hacks
30+
class P: # pylint: disable=missing-class-docstring
31+
args: tuple
32+
kwargs: dict
33+
34+
35+
T = TypeVar("T")
36+
37+
38+
def lazy_xp_function( # type: ignore[no-any-explicit]
39+
func: Callable[..., Any],
40+
*,
41+
disable_dask_compute: bool = True,
42+
jax_jit: bool = True,
43+
static_argnums: int | Sequence[int] | None = None,
44+
static_argnames: str | Iterable[str] | None = None,
45+
) -> None: # numpydoc ignore=GL07
46+
"""
47+
Tag a function to be tested on lazy backends.
48+
49+
Tag a function, which must be imported in the test module globals, so that when any
50+
tests defined in the same module are executed with `xp=jax.numpy` the function is
51+
replaced with a jitted version of itself, and when it is executed with
52+
`xp=dask.array` the function will raise if it attempts to materialize the graph.
53+
This will be later expanded to provide test coverage for other lazy backends.
54+
55+
In order for the tag to be effective, the test or a fixture must call
56+
:func:`patch_lazy_xp_functions`.
57+
58+
Parameters
59+
----------
60+
func : callable
61+
Function to be tested.
62+
disable_dask_compute : bool, optional
63+
Set to True to raise an error if `func` attempts to call `dask.compute()` or
64+
`dask.persist()` after calling the calling the :func:`patch_lazy_xp_functions`
65+
test helper with `xp=dask.array`. This is typically inadvertently triggered by
66+
`bool()`, `float()`, or `np.asarray()`. Set to False to allow these calls,
67+
knowing that they are going to be extremely detrimental for performance.
68+
Default: True.
69+
jax_jit : bool, optional
70+
Set to True to replace `func` with `jax.jit(func)` after calling the
71+
:func:`patch_lazy_xp_functions` test helper with `xp=jax.numpy`. Set to False if
72+
`func` is only compatible with eager (non-jitted) JAX. Default: True.
73+
static_argnums : int | Sequence[int], optional
74+
Passed to jax.jit.
75+
Positional arguments to treat as static (compile-time constant).
76+
Default: infer from `static_argnames` using `inspect.signature(func)`.
77+
static_argnames : str | Iterable[str], optional
78+
Passed to jax.jit.
79+
Named arguments to treat as static (compile-time constant).
80+
Default: infer from `static_argnums` using `inspect.signature(func)`.
81+
82+
See Also
83+
--------
84+
patch_lazy_xp_functions : Companion function to call from the test or fixture.
85+
jax.jit : JAX function to compile a function for performance.
86+
87+
Examples
88+
--------
89+
In `test_mymodule.py`::
90+
91+
from array_api_extra.testing import lazy_xp_function
92+
from mymodule import myfunc
93+
94+
lazy_xp_function(myfunc)
95+
96+
def test_myfunc(xp):
97+
a = xp.asarray([1, 2])
98+
# When xp=jax.numpy, this is the same as `b = jax.jit(myfunc)(a)`
99+
# When xp=dask.array, crash on compute() or persist()
100+
b = myfunc(a)
101+
102+
Notes
103+
-----
104+
A test function can circumvent this monkey-patching system by calling `func` as an
105+
attribute of the original module. You need to sanitize your code to make sure this
106+
does not happen.
107+
108+
Example::
109+
110+
import mymodule
111+
from mymodule import myfunc
112+
113+
lazy_xp_function(myfunc)
114+
115+
def test_myfunc(xp):
116+
a = xp.asarray([1, 2])
117+
b = myfunc(a) # This is jitted when xp=jax.numpy
118+
c = mymodule.myfunc(a) # This is not
119+
"""
120+
func.disable_dask_compute = disable_dask_compute # type: ignore[attr-defined] # pyright: ignore[reportFunctionMemberAccess]
121+
if jax_jit:
122+
func.lazy_jax_jit_kwargs = { # type: ignore[attr-defined] # pyright: ignore[reportFunctionMemberAccess]
123+
"static_argnums": static_argnums,
124+
"static_argnames": static_argnames,
125+
}
126+
127+
128+
def patch_lazy_xp_functions(
129+
request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch, *, xp: ModuleType
130+
) -> None:
131+
"""
132+
Test lazy execution of functions tagged with :func:`lazy_xp_function`.
133+
134+
If `xp==jax.numpy`, search for all functions which have been tagged with
135+
:func:`lazy_xp_function` in the globals of the module that defines the current test
136+
and wrap them with :func:`jax.jit`. Unwrap them at the end of the test.
137+
138+
If `xp==dask.array`, wrap the functions with a decorator that disables `compute()`
139+
and `persist()`.
140+
141+
This function should be typically called by your library's `xp` fixture that runs
142+
tests on multiple backends::
143+
144+
@pytest.fixture(params=[numpy, array_api_strict, jax.numpy, dask.array])
145+
def xp(request, monkeypatch):
146+
patch_lazy_xp_functions(request, monkeypatch, xp=request.param)
147+
return request.param
148+
149+
but it can be otherwise be called by the test itself too.
150+
151+
Parameters
152+
----------
153+
request : pytest.FixtureRequest
154+
Pytest fixture, as acquired by the test itself or by one of its fixtures.
155+
monkeypatch : pytest.MonkeyPatch
156+
Pytest fixture, as acquired by the test itself or by one of its fixtures.
157+
xp : module
158+
Array namespace to be tested.
159+
160+
See Also
161+
--------
162+
lazy_xp_function : Tag a function to be tested on lazy backends.
163+
pytest.FixtureRequest : `request` test function parameter.
164+
"""
165+
globals_ = cast(dict[str, Any], request.module.__dict__) # type: ignore[no-any-explicit]
166+
167+
if is_dask_namespace(xp):
168+
for name, func in globals_.items():
169+
if getattr(func, "disable_dask_compute", False):
170+
wrapped = _disable_dask_compute(func)
171+
monkeypatch.setitem(globals_, name, wrapped)
172+
173+
elif is_jax_namespace(xp):
174+
import jax
175+
176+
for name, func in globals_.items():
177+
kwargs = cast( # type: ignore[no-any-explicit]
178+
"dict[str, Any] | None", getattr(func, "lazy_jax_jit_kwargs", None)
179+
)
180+
181+
# suppress unused-ignore to run mypy in -e lint as well as -e dev
182+
if kwargs is not None: # type: ignore[no-untyped-call,unused-ignore]
183+
wrapped = jax.jit(func, **kwargs) # type: ignore[no-untyped-call,unused-ignore]
184+
monkeypatch.setitem(globals_, name, wrapped) # pyright: ignore[reportUnknownArgumentType]
185+
186+
187+
def _disable_dask_compute(
188+
func: Callable[P, T],
189+
) -> Callable[P, T]: # numpydoc ignore=PR01,RT01
190+
"""
191+
Wrap a function to raise if it attempts to call dask.compute or dask.persist.
192+
"""
193+
import dask.config
194+
195+
def get(*args: object, **kwargs: object) -> object: # noqa: ARG001 # numpydoc ignore=PR01
196+
"""Dask scheduler which will always raise when invoked."""
197+
msg = "Called `dask.compute()` or `dask.persist()`"
198+
raise AssertionError(msg)
199+
200+
@wraps(func)
201+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
202+
with dask.config.set({"scheduler": get}):
203+
return func(*args, **kwargs)
204+
205+
return wrapper

Diff for: tests/conftest.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from array_api_extra._lib._utils._compat import array_namespace
1414
from array_api_extra._lib._utils._compat import device as get_device
1515
from array_api_extra._lib._utils._typing import Device
16+
from array_api_extra.testing import patch_lazy_xp_functions
1617

1718
T = TypeVar("T")
1819
P = ParamSpec("P")
@@ -96,7 +97,9 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
9697

9798

9899
@pytest.fixture
99-
def xp(library: Backend) -> ModuleType: # numpydoc ignore=PR01,RT03
100+
def xp(
101+
library: Backend, request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
102+
) -> ModuleType: # numpydoc ignore=PR01,RT03
100103
"""
101104
Parameterized fixture that iterates on all libraries.
102105
@@ -107,6 +110,9 @@ def xp(library: Backend) -> ModuleType: # numpydoc ignore=PR01,RT03
107110
if library == Backend.NUMPY_READONLY:
108111
return NumPyReadOnly() # type: ignore[return-value] # pyright: ignore[reportReturnType]
109112
xp = pytest.importorskip(library.value)
113+
114+
patch_lazy_xp_functions(request, monkeypatch, xp=xp)
115+
110116
if library == Backend.JAX:
111117
import jax
112118

0 commit comments

Comments
 (0)