diff --git a/src/array_api_extra/_lib/_at.py b/src/array_api_extra/_lib/_at.py index 22e18d2..736f77b 100644 --- a/src/array_api_extra/_lib/_at.py +++ b/src/array_api_extra/_lib/_at.py @@ -344,6 +344,7 @@ def _op( msg = f"Can't update read-only array {x}" raise ValueError(msg) + # Backends without boolean indexing (other than JAX) crash here if in_place_op: # add(), subtract(), ... x[idx] = in_place_op(x[idx], y) else: # set() diff --git a/src/array_api_extra/_lib/_backends.py b/src/array_api_extra/_lib/_backends.py index f044281..3beb676 100644 --- a/src/array_api_extra/_lib/_backends.py +++ b/src/array_api_extra/_lib/_backends.py @@ -1,9 +1,10 @@ """Backends with which array-api-extra interacts in delegation and testing.""" +from __future__ import annotations + from collections.abc import Callable from enum import Enum from types import ModuleType -from typing import cast from ._utils import _compat @@ -23,9 +24,14 @@ class Backend(Enum): # numpydoc ignore=PR01,PR02 # type: ignore[no-subclass-an corresponding to the backend. """ + # Use : to prevent Enum from deduplicating items with the same value ARRAY_API_STRICT = "array_api_strict", _compat.is_array_api_strict_namespace + ARRAY_API_STRICTEST = ( + "array_api_strict:strictest", + _compat.is_array_api_strict_namespace, + ) NUMPY = "numpy", _compat.is_numpy_namespace - NUMPY_READONLY = "numpy_readonly", _compat.is_numpy_namespace + NUMPY_READONLY = "numpy:readonly", _compat.is_numpy_namespace CUPY = "cupy", _compat.is_cupy_namespace TORCH = "torch", _compat.is_torch_namespace DASK = "dask.array", _compat.is_dask_namespace @@ -48,4 +54,13 @@ def __init__( def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01 """Pretty-print parameterized test names.""" - return cast(str, self.value) + return self.name.lower() + + @property + def modname(self) -> str: # numpydoc ignore=RT01 + """Module name to be imported.""" + return self.value.split(":")[0] + + def like(self, *others: Backend) -> bool: # numpydoc ignore=PR01,RT01 + """Check if this backend uses the same module as others.""" + return any(self.modname == other.modname for other in others) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index efe2f37..e552392 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -8,13 +8,14 @@ from ._at import at from ._utils import _compat, _helpers -from ._utils._compat import ( - array_namespace, - is_dask_namespace, - is_jax_array, - is_jax_namespace, +from ._utils._compat import array_namespace, is_dask_namespace, is_jax_array +from ._utils._helpers import ( + asarrays, + capabilities, + eager_shape, + meta_namespace, + ndindex, ) -from ._utils._helpers import asarrays, eager_shape, meta_namespace, ndindex from ._utils._typing import Array __all__ = [ @@ -152,7 +153,7 @@ def _apply_where( # type: ignore[explicit-any] # numpydoc ignore=PR01,RT01 ) -> Array: """Helper of `apply_where`. On Dask, this runs on a single chunk.""" - if is_jax_namespace(xp): + if not capabilities(xp)["boolean indexing"]: # jax.jit does not support assignment by boolean mask return xp.where(cond, f1(*args), f2(*args) if f2 is not None else fill_value) @@ -708,14 +709,34 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array: # size= is JAX-specific # https://github.com/data-apis/array-api/issues/883 _, counts = xp.unique_counts(x, size=_compat.size(x)) - return xp.astype(counts, xp.bool).sum() - - _, counts = xp.unique_counts(x) - n = _compat.size(counts) - # FIXME https://github.com/data-apis/array-api-compat/pull/231 - if n is None: # e.g. Dask, ndonnx - return xp.astype(counts, xp.bool).sum() - return xp.asarray(n, device=_compat.device(x)) + return (counts > 0).sum() + + # There are 3 general use cases: + # 1. backend has unique_counts and it returns an array with known shape + # 2. backend has unique_counts and it returns a None-sized array; + # e.g. Dask, ndonnx + # 3. backend does not have unique_counts; e.g. wrapped JAX + if capabilities(xp)["data-dependent shapes"]: + # xp has unique_counts; O(n) complexity + _, counts = xp.unique_counts(x) + n = _compat.size(counts) + if n is None: + return xp.sum(xp.ones_like(counts)) + return xp.asarray(n, device=_compat.device(x)) + + # xp does not have unique_counts; O(n*logn) complexity + x = xp.sort(xp.reshape(x, -1)) + mask = x != xp.roll(x, -1) + default_int = xp.__array_namespace_info__().default_dtypes( + device=_compat.device(x) + )["integral"] + return xp.maximum( + # Special cases: + # - array is size 0 + # - array has all elements equal to each other + xp.astype(xp.any(~mask), default_int), + xp.sum(xp.astype(mask, default_int)), + ) def pad( diff --git a/src/array_api_extra/_lib/_utils/_helpers.py b/src/array_api_extra/_lib/_utils/_helpers.py index 9882d72..6400627 100644 --- a/src/array_api_extra/_lib/_utils/_helpers.py +++ b/src/array_api_extra/_lib/_utils/_helpers.py @@ -12,7 +12,9 @@ array_namespace, is_array_api_obj, is_dask_namespace, + is_jax_namespace, is_numpy_array, + is_pydata_sparse_namespace, ) from ._typing import Array @@ -23,6 +25,7 @@ __all__ = [ "asarrays", + "capabilities", "eager_shape", "in1d", "is_python_scalar", @@ -270,3 +273,36 @@ def meta_namespace( # Quietly skip scalars and None's metas = [cast(Array | None, getattr(a, "_meta", None)) for a in arrays] return array_namespace(*metas) + + +def capabilities(xp: ModuleType) -> dict[str, int]: + """ + Return patched ``xp.__array_namespace_info__().capabilities()``. + + TODO this helper should be eventually removed once all the special cases + it handles are fixed in the respective backends. + + Parameters + ---------- + xp : array_namespace + The standard-compatible namespace. + + Returns + ------- + dict + Capabilities of the namespace. + """ + if is_pydata_sparse_namespace(xp): + # No __array_namespace_info__(); no indexing by sparse arrays + return {"boolean indexing": False, "data-dependent shapes": True} + out = xp.__array_namespace_info__().capabilities() + if is_jax_namespace(xp): + # FIXME https://github.com/jax-ml/jax/issues/27418 + out = out.copy() + out["boolean indexing"] = False + if is_dask_namespace(xp): + # FIXME https://github.com/data-apis/array-api-compat/pull/290 + out = out.copy() + out["boolean indexing"] = True + out["data-dependent shapes"] = True + return out diff --git a/tests/conftest.py b/tests/conftest.py index 54e2a23..4e36885 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ """Pytest fixtures.""" -from collections.abc import Callable +from collections.abc import Callable, Generator from contextlib import suppress from functools import partial, wraps from types import ModuleType @@ -19,6 +19,7 @@ T = TypeVar("T") P = ParamSpec("P") +NUMPY_VERSION = tuple(int(v) for v in np.__version__.split(".")[2]) np_compat = array_namespace(np.empty(0)) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] @@ -43,7 +44,7 @@ def library(request: pytest.FixtureRequest) -> Backend: # numpydoc ignore=PR01, msg = f"argument of {marker_name} must be a Backend enum" raise TypeError(msg) if library == elem: - reason = library.value + reason = str(library) with suppress(KeyError): reason += ":" + cast(str, marker.kwargs["reason"]) skip_or_xfail(reason=reason) @@ -104,7 +105,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08 @pytest.fixture def xp( library: Backend, request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch -) -> ModuleType: # numpydoc ignore=PR01,RT03 +) -> Generator[ModuleType]: # numpydoc ignore=PR01,RT03 """ Parameterized fixture that iterates on all libraries. @@ -113,25 +114,38 @@ def xp( The current array namespace. """ if library == Backend.NUMPY_READONLY: - return NumPyReadOnly() # type: ignore[return-value] # pyright: ignore[reportReturnType] - xp = pytest.importorskip(library.value) + yield NumPyReadOnly() # type: ignore[misc] # pyright: ignore[reportReturnType] + return + + if library.like(Backend.ARRAY_API_STRICT) and NUMPY_VERSION < (1, 26): + pytest.skip("array_api_strict is untested on NumPy <1.26") + + xp = pytest.importorskip(library.modname) # Possibly wrap module with array_api_compat xp = array_namespace(xp.empty(0)) + if library == Backend.ARRAY_API_STRICTEST: + with xp.ArrayAPIStrictFlags( + boolean_indexing=False, + data_dependent_shapes=False, + # writeable=False, # TODO implement in array-api-strict + # lazy=True, # TODO implement in array-api-strict + enabled_extensions=(), + ): + yield xp + return + # On Dask and JAX, monkey-patch all functions tagged by `lazy_xp_function` # in the global scope of the module containing the test function. patch_lazy_xp_functions(request, monkeypatch, xp=xp) - if library == Backend.ARRAY_API_STRICT and np.__version__ < "1.26": - pytest.skip("array_api_strict is untested on NumPy <1.26") - if library == Backend.JAX: import jax # suppress unused-ignore to run mypy in -e lint as well as -e dev jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call,unused-ignore] - return xp + yield xp @pytest.fixture(params=[Backend.DASK]) # Can select the test with `pytest -k dask` diff --git a/tests/test_at.py b/tests/test_at.py index 218b05b..4bde5ce 100644 --- a/tests/test_at.py +++ b/tests/test_at.py @@ -20,7 +20,8 @@ pytestmark = [ pytest.mark.skip_xp_backend( Backend.SPARSE, reason="read-only backend without .at support" - ) + ), + pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing"), ] @@ -256,7 +257,7 @@ def test_incompatible_dtype( elif library is Backend.DASK: z = at_op(x, idx, op, 1.1, copy=copy) - elif library is Backend.ARRAY_API_STRICT and op is not _AtOp.SET: + elif library.like(Backend.ARRAY_API_STRICT) and op is not _AtOp.SET: with pytest.raises(Exception, match=r"cast|promote|dtype"): _ = at_op(x, idx, op, 1.1, copy=copy) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 46591ed..48ad7b0 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -32,6 +32,8 @@ from array_api_extra._lib._utils._typing import Array, Device from array_api_extra.testing import lazy_xp_function +from .conftest import NUMPY_VERSION + # some xp backends are untyped # mypy: disable-error-code=no-untyped-def @@ -48,12 +50,6 @@ lazy_xp_function(sinc, static_argnames="xp") -NUMPY_GE2 = int(np.__version__.split(".")[0]) >= 2 - - -@pytest.mark.skip_xp_backend( - Backend.SPARSE, reason="read-only backend without .at support" -) class TestApplyWhere: @staticmethod def f1(x: Array, y: Array | int = 10) -> Array: @@ -153,6 +149,14 @@ def test_dont_overwrite_fill_value(self, xp: ModuleType): xp_assert_equal(actual, xp.asarray([100, 12])) xp_assert_equal(fill_value, xp.asarray([100, 200])) + @pytest.mark.skip_xp_backend( + Backend.ARRAY_API_STRICTEST, + reason="no boolean indexing -> run everywhere", + ) + @pytest.mark.skip_xp_backend( + Backend.SPARSE, + reason="no indexing by sparse array -> run everywhere", + ) def test_dont_run_on_false(self, xp: ModuleType): x = xp.asarray([1.0, 2.0, 0.0]) y = xp.asarray([0.0, 3.0, 4.0]) @@ -192,6 +196,7 @@ def test_device(self, xp: ModuleType, device: Device): y = apply_where(x % 2 == 0, x, self.f1, fill_value=x) assert get_device(y) == device + @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype") @pytest.mark.filterwarnings("ignore::RuntimeWarning") # overflows, etc. @hypothesis.settings( # The xp and library fixtures are not regenerated between hypothesis iterations @@ -217,8 +222,8 @@ def test_hypothesis( # type: ignore[explicit-any,decorated-any] library: Backend, ): if ( - library in (Backend.NUMPY, Backend.NUMPY_READONLY) - and not NUMPY_GE2 + library.like(Backend.NUMPY) + and NUMPY_VERSION < (2, 0) and dtype is np.float32 ): pytest.xfail(reason="NumPy 1.x dtype promotion for scalars") @@ -562,6 +567,9 @@ def test_xp(self, xp: ModuleType): assert y.shape == (1, 1, 1, 3) +@pytest.mark.filterwarnings( # array_api_strictest + "ignore:invalid value encountered:RuntimeWarning:array_api_strict" +) @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype") class TestIsClose: @pytest.mark.parametrize("swap", [False, True]) @@ -680,6 +688,7 @@ def test_bool_dtype(self, xp: ModuleType): isclose(xp.asarray(True), b, atol=1), xp.asarray([True, True, True]) ) + @pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="unknown shape") def test_none_shape(self, xp: ModuleType): a = xp.asarray([1, 5, 0]) b = xp.asarray([1, 4, 2]) @@ -687,6 +696,7 @@ def test_none_shape(self, xp: ModuleType): a = a[a < 5] xp_assert_equal(isclose(a, b), xp.asarray([True, False])) + @pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="unknown shape") def test_none_shape_bool(self, xp: ModuleType): a = xp.asarray([True, True, False]) b = xp.asarray([True, False, True]) @@ -819,8 +829,27 @@ def test_empty(self, xp: ModuleType): a = xp.asarray([]) xp_assert_equal(nunique(a), xp.asarray(0)) - def test_device(self, xp: ModuleType, device: Device): - a = xp.asarray(0.0, device=device) + def test_size1(self, xp: ModuleType): + a = xp.asarray([123]) + xp_assert_equal(nunique(a), xp.asarray(1)) + + def test_all_equal(self, xp: ModuleType): + a = xp.asarray([123, 123, 123]) + xp_assert_equal(nunique(a), xp.asarray(1)) + + @pytest.mark.xfail_xp_backend(Backend.DASK, reason="No equal_nan kwarg in unique") + @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="sparse#855") + def test_nan(self, xp: ModuleType, library: Backend): + if library.like(Backend.NUMPY) and NUMPY_VERSION < (1, 24): + pytest.xfail("NumPy <1.24 has no equal_nan kwarg in unique") + + # Each NaN is counted separately + a = xp.asarray([xp.nan, 123.0, xp.nan]) + xp_assert_equal(nunique(a), xp.asarray(3)) + + @pytest.mark.parametrize("size", [0, 1, 2]) + def test_device(self, xp: ModuleType, device: Device, size: int): + a = xp.asarray([0.0] * size, device=device) assert get_device(nunique(a)) == device def test_xp(self, xp: ModuleType): @@ -895,6 +924,7 @@ def test_sequence_of_tuples_width(self, xp: ModuleType): @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no argsort") +@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="no unique_values") class TestSetDiff1D: @pytest.mark.xfail_xp_backend(Backend.DASK, reason="NaN-shaped arrays") @pytest.mark.xfail_xp_backend( diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 81b11d1..c7d271c 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -10,6 +10,7 @@ from array_api_extra._lib._utils._compat import device as get_device from array_api_extra._lib._utils._helpers import ( asarrays, + capabilities, eager_shape, in1d, meta_namespace, @@ -27,6 +28,7 @@ @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no unique_inverse") +@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="no unique_inverse") class TestIn1D: # cover both code paths @pytest.mark.parametrize( @@ -161,7 +163,8 @@ def test_ndindex(shape: tuple[int, ...]): assert tuple(ndindex(*shape)) == tuple(np.ndindex(*shape)) -@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="index by sparse array") +@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array") +@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing") def test_eager_shape(xp: ModuleType, library: Backend): a = xp.asarray([1, 2, 3]) # Lazy arrays, like Dask, have an eager shape until you slice them with @@ -194,3 +197,10 @@ def test_dask_metas(self, da: ModuleType): def test_xp(self, xp: ModuleType): args = None, xp.asarray(0), 1 assert meta_namespace(*args, xp=xp) in (xp, np_compat) + + +def test_capabilities(xp: ModuleType): + expect = {"boolean indexing", "data-dependent shapes"} + if xp.__array_api_version__ >= "2024.12": + expect.add("max dimensions") + assert capabilities(xp).keys() == expect diff --git a/tests/test_lazy.py b/tests/test_lazy.py index 7057af8..d360e50 100644 --- a/tests/test_lazy.py +++ b/tests/test_lazy.py @@ -1,3 +1,4 @@ +import contextlib from types import ModuleType from typing import cast @@ -214,24 +215,20 @@ def test_lazy_apply_none_shape_in_args(xp: ModuleType, library: Backend): mxp = np if library is Backend.DASK else xp int_type = xp.asarray(0).dtype + ctx: contextlib.AbstractContextManager[object] if library is Backend.JAX: - # Single output - with pytest.raises(ValueError, match="Output shape must be fully known"): - _ = lazy_apply(mxp.unique_values, x, shape=(None,)) - - # Multi output - with pytest.raises(ValueError, match="Output shape must be fully known"): - _ = lazy_apply( - mxp.unique_counts, - x, - shape=((None,), (None,)), - dtype=(x.dtype, int_type), - ) + ctx = pytest.raises(ValueError, match="Output shape must be fully known") + elif library is Backend.ARRAY_API_STRICTEST: + ctx = pytest.raises(RuntimeError, match="data-dependent shapes") else: - # Single output + ctx = contextlib.nullcontext() + + # Single output + with ctx: values = lazy_apply(mxp.unique_values, x, shape=(None,)) xp_assert_equal(values, xp.asarray([1, 2])) + with ctx: # Multi output values, counts = lazy_apply( mxp.unique_counts, @@ -255,8 +252,9 @@ def f(x: Array) -> Array: lazy_xp_function(check_lazy_apply_none_shape_broadcast) -@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="bool mask") -@pytest.mark.xfail_xp_backend(Backend.JAX, reason="unknown shape") +@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array") +@pytest.mark.skip_xp_backend(Backend.JAX, reason="boolean indexing") +@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing") def test_lazy_apply_none_shape_broadcast(xp: ModuleType): """Broadcast from input array with unknown shape""" x = xp.asarray([1, 2, 2]) diff --git a/tests/test_testing.py b/tests/test_testing.py index 47eaa4d..10ce7ab 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -72,7 +72,8 @@ def test_assert_close_tolerance(xp: ModuleType): @param_assert_equal_close -@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="index by sparse array") +@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array") +@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing") def test_assert_close_equal_none_shape(xp: ModuleType, func: Callable[..., None]): # type: ignore[explicit-any] """On Dask and other lazy backends, test that a shape with NaN's or None's can be compared to a real shape. @@ -222,7 +223,7 @@ def test_lazy_xp_function_cython_ufuncs(xp: ModuleType, library: Backend): pytest.importorskip("scipy") assert erf is not None x = xp.asarray([6.0, 7.0]) - if library in (Backend.ARRAY_API_STRICT, Backend.JAX): + if library.like(Backend.ARRAY_API_STRICT, Backend.JAX): # array-api-strict arrays are auto-converted to NumPy # which results in an assertion error for mismatched namespaces # eager JAX arrays are auto-converted to NumPy in eager JAX