Skip to content

Commit 8fa3fd2

Browse files
authored
Merge pull request #115 from crusaderky/test_jit
2 parents 9064806 + 9a6b7b5 commit 8fa3fd2

File tree

11 files changed

+590
-17
lines changed

11 files changed

+590
-17
lines changed

Diff for: docs/index.md

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
:hidden:
66
self
77
api-reference.md
8+
testing-utils.md
89
contributing.md
910
contributors.md
1011
```

Diff for: docs/testing-utils.md

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Testing Utilities
2+
3+
These additional functions are meant to be used while unit testing Array API
4+
compliant packages:
5+
6+
```{eval-rst}
7+
.. currentmodule:: array_api_extra.testing
8+
.. autosummary::
9+
:nosignatures:
10+
:toctree: generated
11+
12+
lazy_xp_function
13+
patch_lazy_xp_functions
14+
```

Diff for: pixi.lock

+73-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: pyproject.toml

+17-2
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ array-api-compat = ">=1.10.0,<2"
5454
array-api-extra = { path = ".", editable = true }
5555

5656
[tool.pixi.feature.lint.dependencies]
57+
typing-extensions = "*"
5758
pre-commit = "*"
5859
pylint = "*"
5960
basedmypy = "*"
@@ -63,6 +64,9 @@ numpydoc = ">=1.8.0,<2"
6364
array-api-strict = "*"
6465
numpy = "*"
6566
pytest = "*"
67+
dask-core = "*" # No distributed, tornado, etc.
68+
# NOTE: don't add jax, pytorch, sparse, cupy here
69+
# as they slow down mypy and are not portable across target OSs
6670

6771
[tool.pixi.feature.lint.tasks]
6872
pre-commit-install = "pre-commit install"
@@ -98,6 +102,10 @@ furo = ">=2023.08.17"
98102
myst-parser = ">=0.13"
99103
sphinx-copybutton = "*"
100104
sphinx-autodoc-typehints = "*"
105+
# Needed to import parsed modules with autodoc
106+
dask-core = "*"
107+
pytest = "*"
108+
typing-extensions = "*"
101109

102110
[tool.pixi.feature.docs.tasks]
103111
docs = { cmd = "sphinx-build . build/", cwd = "docs" }
@@ -180,8 +188,10 @@ markers = ["skip_xp_backend(library, *, reason=None): Skip test for a specific b
180188

181189
[tool.coverage]
182190
run.source = ["array_api_extra"]
183-
report.exclude_also = ['\.\.\.']
184-
191+
report.exclude_also = [
192+
'\.\.\.',
193+
'if TYPE_CHECKING:',
194+
]
185195

186196
# mypy
187197

@@ -221,6 +231,8 @@ reportMissingImports = false
221231
reportMissingTypeStubs = false
222232
# false positives for input validation
223233
reportUnreachable = false
234+
# ruff handles this
235+
reportUnusedParameter = false
224236

225237
executionEnvironments = [
226238
{ root = "tests", reportPrivateUsage = false },
@@ -282,7 +294,10 @@ messages_control.disable = [
282294
"design", # ignore heavily opinionated design checks
283295
"fixme", # allow FIXME comments
284296
"line-too-long", # ruff handles this
297+
"unused-argument", # ruff handles this
285298
"missing-function-docstring", # numpydoc handles this
299+
"import-error", # mypy handles this
300+
"import-outside-toplevel", # optional dependencies
286301
]
287302

288303

Diff for: src/array_api_extra/_lib/_testing.py

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Testing utilities.
33
44
Note that this is private API; don't expect it to be stable.
5+
See also ..testing for public testing utilities.
56
"""
67

78
import math

Diff for: src/array_api_extra/testing.py

+262
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
"""
2+
Public testing utilities.
3+
4+
See also _lib._testing for additional private testing utilities.
5+
"""
6+
7+
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
8+
from __future__ import annotations
9+
10+
from collections.abc import Callable, Iterable, Sequence
11+
from functools import wraps
12+
from types import ModuleType
13+
from typing import TYPE_CHECKING, Any, TypeVar, cast
14+
15+
import pytest
16+
17+
from array_api_extra._lib._utils._compat import is_dask_namespace, is_jax_namespace
18+
19+
__all__ = ["lazy_xp_function", "patch_lazy_xp_functions"]
20+
21+
# Typing-only names live under TYPE_CHECKING; the `else` branch supplies runtime
# stand-ins for the names that are looked up when this module is executed.
if TYPE_CHECKING:
    # TODO move ParamSpec outside TYPE_CHECKING
    # depends on scikit-learn abandoning Python 3.9
    # https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
    from typing import ParamSpec

    from dask.typing import Graph, Key, SchedulerGetCallable
    from typing_extensions import override

    P = ParamSpec("P")
else:
    # Runtime base class for CountingDaskScheduler, which subclasses this name.
    SchedulerGetCallable = object

    # Sphinx hacks
    # Stand-in for the ParamSpec: exposes the `args` / `kwargs` attributes that
    # annotations such as `P.args` and `P.kwargs` refer to.
    class P:  # pylint: disable=missing-class-docstring
        args: tuple
        kwargs: dict

    # Identity replacement for `typing_extensions.override`.
    # `T` is only defined further down; this works because the module uses
    # `from __future__ import annotations`, so annotations are not evaluated here.
    def override(func: Callable[P, T]) -> Callable[P, T]:
        return func


# Generic return type shared by the wrappers defined in this module.
T = TypeVar("T")
44+
45+
46+
def lazy_xp_function( # type: ignore[no-any-explicit]
47+
func: Callable[..., Any],
48+
*,
49+
allow_dask_compute: int = 0,
50+
jax_jit: bool = True,
51+
static_argnums: int | Sequence[int] | None = None,
52+
static_argnames: str | Iterable[str] | None = None,
53+
) -> None: # numpydoc ignore=GL07
54+
"""
55+
Tag a function to be tested on lazy backends.
56+
57+
Tag a function, which must be imported in the test module globals, so that when any
58+
tests defined in the same module are executed with ``xp=jax.numpy`` the function is
59+
replaced with a jitted version of itself, and when it is executed with
60+
``xp=dask.array`` the function will raise if it attempts to materialize the graph.
61+
This will be later expanded to provide test coverage for other lazy backends.
62+
63+
In order for the tag to be effective, the test or a fixture must call
64+
:func:`patch_lazy_xp_functions`.
65+
66+
Parameters
67+
----------
68+
func : callable
69+
Function to be tested.
70+
allow_dask_compute : int, optional
71+
Number of times `func` is allowed to internally materialize the Dask graph. This
72+
is typically triggered by ``bool()``, ``float()``, or ``np.asarray()``.
73+
74+
Set to 1 if you are aware that `func` converts the input parameters to numpy and
75+
want to let it do so at least for the time being, knowing that it is going to be
76+
extremely detrimental for performance.
77+
78+
If a test needs values higher than 1 to pass, it is a canary that the conversion
79+
to numpy/bool/float is happening multiple times, which translates to multiple
80+
computations of the whole graph. Short of making the function fully lazy, you
81+
should at least add explicit calls to ``np.asarray()`` early in the function.
82+
*Note:* the counter of `allow_dask_compute` resets after each call to `func`, so
83+
a test function that invokes `func` multiple times should still work with this
84+
parameter set to 1.
85+
86+
Default: 0, meaning that `func` must be fully lazy and never materialize the
87+
graph.
88+
jax_jit : bool, optional
89+
Set to True to replace `func` with ``jax.jit(func)`` after calling the
90+
:func:`patch_lazy_xp_functions` test helper with ``xp=jax.numpy``. Set to False
91+
if `func` is only compatible with eager (non-jitted) JAX. Default: True.
92+
static_argnums : int | Sequence[int], optional
93+
Passed to jax.jit. Positional arguments to treat as static (compile-time
94+
constant). Default: infer from `static_argnames` using
95+
`inspect.signature(func)`.
96+
static_argnames : str | Iterable[str], optional
97+
Passed to jax.jit. Named arguments to treat as static (compile-time constant).
98+
Default: infer from `static_argnums` using `inspect.signature(func)`.
99+
100+
See Also
101+
--------
102+
patch_lazy_xp_functions : Companion function to call from the test or fixture.
103+
jax.jit : JAX function to compile a function for performance.
104+
105+
Examples
106+
--------
107+
In ``test_mymodule.py``::
108+
109+
from array_api_extra.testing import lazy_xp_function from mymodule import myfunc
110+
111+
lazy_xp_function(myfunc)
112+
113+
def test_myfunc(xp):
114+
a = xp.asarray([1, 2])
115+
# When xp=jax.numpy, this is the same as `b = jax.jit(myfunc)(a)`
116+
# When xp=dask.array, crash on compute() or persist()
117+
b = myfunc(a)
118+
119+
Notes
120+
-----
121+
A test function can circumvent this monkey-patching system by calling `func` as an
122+
attribute of the original module. You need to sanitize your code to make sure this
123+
does not happen.
124+
125+
Example::
126+
127+
import mymodule from mymodule import myfunc
128+
129+
lazy_xp_function(myfunc)
130+
131+
def test_myfunc(xp):
132+
a = xp.asarray([1, 2]) b = myfunc(a) # This is jitted when xp=jax.numpy c =
133+
mymodule.myfunc(a) # This is not
134+
"""
135+
func.allow_dask_compute = allow_dask_compute # type: ignore[attr-defined] # pyright: ignore[reportFunctionMemberAccess]
136+
if jax_jit:
137+
func.lazy_jax_jit_kwargs = { # type: ignore[attr-defined] # pyright: ignore[reportFunctionMemberAccess]
138+
"static_argnums": static_argnums,
139+
"static_argnames": static_argnames,
140+
}
141+
142+
143+
def patch_lazy_xp_functions(
    request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch, *, xp: ModuleType
) -> None:
    """
    Test lazy execution of functions tagged with :func:`lazy_xp_function`.

    The globals of the module that defines the current test are scanned for
    functions which have been tagged with :func:`lazy_xp_function`.

    If ``xp==dask.array``, every tagged function is wrapped with a decorator that
    disables ``compute()`` and ``persist()``.

    If ``xp==jax.numpy``, every function tagged for jitting is wrapped with
    :func:`jax.jit`.

    The replacements are installed through `monkeypatch`, so the functions are
    unwrapped at the end of the test.

    This function is typically called by your library's `xp` fixture that runs
    tests on multiple backends::

        @pytest.fixture(params=[numpy, array_api_strict, jax.numpy, dask.array])
        def xp(request, monkeypatch):
            patch_lazy_xp_functions(request, monkeypatch, xp=request.param)
            return request.param

    but it can also be called by the test itself.

    Parameters
    ----------
    request : pytest.FixtureRequest
        Pytest fixture, as acquired by the test itself or by one of its fixtures.
    monkeypatch : pytest.MonkeyPatch
        Pytest fixture, as acquired by the test itself or by one of its fixtures.
    xp : module
        Array namespace to be tested.

    See Also
    --------
    lazy_xp_function : Tag a function to be tested on lazy backends.
    pytest.FixtureRequest : `request` test function parameter.
    """
    module_globals = cast("dict[str, Any]", request.module.__dict__)  # type: ignore[no-any-explicit]

    if is_dask_namespace(xp):
        for attr_name, obj in module_globals.items():
            max_computes = getattr(obj, "allow_dask_compute", None)
            if max_computes is None:
                # Not tagged by lazy_xp_function
                continue
            assert isinstance(max_computes, int)
            guarded = _allow_dask_compute(obj, max_computes)
            monkeypatch.setitem(module_globals, attr_name, guarded)
        return

    if not is_jax_namespace(xp):
        # Nothing to patch for the other backends
        return

    import jax

    for attr_name, obj in module_globals.items():
        jit_kwargs = cast(  # type: ignore[no-any-explicit]
            "dict[str, Any] | None", getattr(obj, "lazy_jax_jit_kwargs", None)
        )
        if jit_kwargs is None:
            # Untagged, or tagged with jax_jit=False
            continue
        # suppress unused-ignore to run mypy in -e lint as well as -e dev
        jitted = cast(Callable[..., Any], jax.jit(obj, **jit_kwargs))  # type: ignore[no-any-explicit,no-untyped-call,unused-ignore]
        monkeypatch.setitem(module_globals, attr_name, jitted)
201+
202+
203+
class CountingDaskScheduler(SchedulerGetCallable):
    """
    Dask scheduler that keeps a tally of the graph computations it is asked to run.

    Every call bumps the tally; once it goes past 'max_count' an assertion error
    carrying 'msg' is raised. Otherwise the work is delegated to Dask's own
    'synchronous' scheduler.

    Parameters
    ----------
    max_count : int
        Maximum number of allowed calls to `dask.compute`.
    msg : str
        Assertion to raise when the count exceeds `max_count`.
    """

    count: int
    max_count: int
    msg: str

    def __init__(self, max_count: int, msg: str):  # numpydoc ignore=GL08
        self.max_count = max_count
        self.msg = msg
        # No computation has been requested yet
        self.count = 0

    @override
    def __call__(self, dsk: Graph, keys: Sequence[Key] | Key, **kwargs: Any) -> Any:  # type: ignore[no-any-decorated,no-any-explicit]  # numpydoc ignore=GL08
        import dask

        calls_so_far = self.count + 1
        self.count = calls_so_far
        # Failing right here should yield a nice traceback
        # pointing at the offending line in the user's code
        assert calls_so_far <= self.max_count, self.msg

        return dask.get(dsk, keys, **kwargs)  # type: ignore[attr-defined,no-untyped-call]  # pyright: ignore[reportPrivateImportUsage]
237+
238+
239+
def _allow_dask_compute(
    func: Callable[P, T], n: int
) -> Callable[P, T]:  # numpydoc ignore=PR01,RT01
    """
    Return `func` wrapped so that more than `n` calls to `dask.compute` raise.

    A fresh counting scheduler is installed for every invocation of the wrapper, so
    the allowance of `n` applies to each call of `func` separately.
    """
    import dask.config

    allowance = f"only up to {n}" if n else "no"
    label = getattr(func, "__name__", str(func))
    one_more = n + 1
    msg = (
        f"Called `dask.compute()` or `dask.persist()` {one_more} times, "
        f"but {allowance} calls are allowed. Set "
        f"`lazy_xp_function({label}, allow_dask_compute={one_more})` "
        "to allow for more (but note that this will harm performance). "
    )

    @wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:  # numpydoc ignore=GL08
        # The scheduler is only swapped for the duration of this call
        with dask.config.set({"scheduler": CountingDaskScheduler(n, msg)}):
            return func(*args, **kwargs)

    return wrapper

Diff for: tests/conftest.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from array_api_extra._lib._utils._compat import array_namespace
1414
from array_api_extra._lib._utils._compat import device as get_device
1515
from array_api_extra._lib._utils._typing import Device
16+
from array_api_extra.testing import patch_lazy_xp_functions
1617

1718
T = TypeVar("T")
1819
P = ParamSpec("P")
@@ -96,7 +97,9 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
9697

9798

9899
@pytest.fixture
99-
def xp(library: Backend) -> ModuleType: # numpydoc ignore=PR01,RT03
100+
def xp(
101+
library: Backend, request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
102+
) -> ModuleType: # numpydoc ignore=PR01,RT03
100103
"""
101104
Parameterized fixture that iterates on all libraries.
102105
@@ -107,6 +110,9 @@ def xp(library: Backend) -> ModuleType: # numpydoc ignore=PR01,RT03
107110
if library == Backend.NUMPY_READONLY:
108111
return NumPyReadOnly() # type: ignore[return-value] # pyright: ignore[reportReturnType]
109112
xp = pytest.importorskip(library.value)
113+
114+
patch_lazy_xp_functions(request, monkeypatch, xp=xp)
115+
110116
if library == Backend.JAX:
111117
import jax
112118

0 commit comments

Comments
 (0)