Skip to content

ENH: lazy_xp_function to materialize exceptions on Dask #155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions src/array_api_extra/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def iter_tagged() -> Iterator[tuple[str, Callable[..., Any], dict[str, Any]]]:
if is_dask_namespace(xp):
for name, func, tags in iter_tagged():
n = tags["allow_dask_compute"]
wrapped = _allow_dask_compute(func, n)
wrapped = _dask_wrap(func, n)
monkeypatch.setitem(globals_, name, wrapped)

elif is_jax_namespace(xp):
Expand Down Expand Up @@ -256,13 +256,15 @@ def __call__(self, dsk: Graph, keys: Sequence[Key] | Key, **kwargs: Any) -> Any:
return dask.get(dsk, keys, **kwargs) # type: ignore[attr-defined,no-untyped-call] # pyright: ignore[reportPrivateImportUsage]


def _allow_dask_compute(
def _dask_wrap(
func: Callable[P, T], n: int
) -> Callable[P, T]: # numpydoc ignore=PR01,RT01
"""
Wrap `func` to raise if it attempts to call `dask.compute` more than `n` times.

After the function returns, materialize the graph in order to re-raise exceptions.
"""
import dask.config
import dask

func_name = getattr(func, "__name__", str(func))
n_str = f"only up to {n}" if n else "no"
Expand All @@ -276,7 +278,12 @@ def _allow_dask_compute(
@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
scheduler = CountingDaskScheduler(n, msg)
with dask.config.set({"scheduler": scheduler}):
return func(*args, **kwargs)
with dask.config.set({"scheduler": scheduler}): # pyright: ignore[reportPrivateImportUsage]
out = func(*args, **kwargs)

# Block until the graph materializes and reraise exceptions. This allows
# `pytest.raises` and `pytest.warns` to work as expected. Note that this would
# not work on scheduler='distributed', as it would not block.
return dask.persist(out, scheduler="threads")[0] # type: ignore[no-any-return,attr-defined,no-untyped-call,func-returns-value,index] # pyright: ignore[reportPrivateImportUsage]

return wrapper
29 changes: 29 additions & 0 deletions tests/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,32 @@ def test_lazy_xp_function_cython_ufuncs(xp: ModuleType, library: Backend):
# note that when sparse reduces to scalar it returns a np.generic, which
# would make xp_assert_equal fail.
xp_assert_equal(erf(x), xp.asarray([1.0, 1.0]))


def dask_raises(x: Array) -> Array:
    """
    Return a lazy ``map_blocks`` result whose callback always raises.

    ``dtype`` and ``meta`` are passed explicitly so that ``map_blocks`` has
    enough information to skip a trial run of the callback while the graph is
    being built; the ``ValueError`` is only expected once the graph is computed.
    """

    def _raises(chunk: Array) -> Array:
        # Test that map_blocks doesn't eagerly call the function: a trial run
        # on a dummy input would not have the (3,) shape asserted here.
        assert chunk.shape == (3,)
        msg = "Hello world"
        raise ValueError(msg)

    lazy = x.map_blocks(_raises, dtype=x.dtype, meta=x._meta)
    return lazy


# Tag `dask_raises` with lazy_xp_function. On Dask, tagged functions are wrapped
# so that their output graph is materialized and exceptions are re-raised.
lazy_xp_function(dask_raises)


def test_lazy_xp_function_eagerly_raises(da: ModuleType):
    """
    Check that exceptions raised inside a lazy Dask graph surface eagerly.

    The usual pattern::

        with pytest.raises(Exception):
            func(x)

    discards the output of ``func``, so on Dask the graph would ordinarily
    never be materialized and nothing would be raised. ``lazy_xp_function``
    contains ad-hoc code to materialize the graph and re-raise exceptions,
    which is what makes the pattern work here.
    """
    lazy_input = da.arange(3)
    expect_error = pytest.raises(ValueError, match="Hello world")
    with expect_error:
        dask_raises(lazy_input)