diff --git a/src/array_api_extra/_lib/_backends.py b/src/array_api_extra/_lib/_backends.py index 93c43272..ee2e051e 100644 --- a/src/array_api_extra/_lib/_backends.py +++ b/src/array_api_extra/_lib/_backends.py @@ -17,29 +17,23 @@ class Backend(Enum): # numpydoc ignore=PR01,PR02 # type: ignore[no-subclass-an Parameters ---------- value : str - String describing the backend. + Name of the backend's module. is_namespace : Callable[[ModuleType], bool] Function to check whether an input module is the array namespace corresponding to the backend. - module_name : str - Name of the backend's module. """ - ARRAY_API_STRICT = ( - "array_api_strict", - _compat.is_array_api_strict_namespace, - "array_api_strict", - ) - NUMPY = "numpy", _compat.is_numpy_namespace, "numpy" - NUMPY_READONLY = "numpy_readonly", _compat.is_numpy_namespace, "numpy" - CUPY = "cupy", _compat.is_cupy_namespace, "cupy" - TORCH = "torch", _compat.is_torch_namespace, "torch" - DASK_ARRAY = "dask.array", _compat.is_dask_namespace, "dask.array" - SPARSE = "sparse", _compat.is_pydata_sparse_namespace, "sparse" - JAX_NUMPY = "jax.numpy", _compat.is_jax_namespace, "jax.numpy" + ARRAY_API_STRICT = "array_api_strict", _compat.is_array_api_strict_namespace + NUMPY = "numpy", _compat.is_numpy_namespace + NUMPY_READONLY = "numpy_readonly", _compat.is_numpy_namespace + CUPY = "cupy", _compat.is_cupy_namespace + TORCH = "torch", _compat.is_torch_namespace + DASK_ARRAY = "dask.array", _compat.is_dask_namespace + SPARSE = "sparse", _compat.is_pydata_sparse_namespace + JAX_NUMPY = "jax.numpy", _compat.is_jax_namespace def __new__( - cls, value: str, _is_namespace: Callable[[ModuleType], bool], _module_name: str + cls, value: str, _is_namespace: Callable[[ModuleType], bool] ): # numpydoc ignore=GL08 obj = object.__new__(cls) obj._value_ = value @@ -49,10 +43,8 @@ def __init__( self, value: str, # noqa: ARG002 # pylint: disable=unused-argument is_namespace: Callable[[ModuleType], bool], - module_name: str, ): # numpydoc ignore=GL08 self.is_namespace = is_namespace - self.module_name = module_name def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01 """Pretty-print parameterized test names.""" diff --git a/tests/conftest.py b/tests/conftest.py index 4b07c205..c588d802 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,11 @@ """Pytest fixtures.""" +from collections.abc import Callable +from functools import wraps from types import ModuleType -from typing import cast +from typing import ParamSpec, TypeVar, cast +import numpy as np import pytest from array_api_extra._lib import Backend @@ -10,6 +13,11 @@ from array_api_extra._lib._utils._compat import device as get_device from array_api_extra._lib._utils._typing import Device +T = TypeVar("T") +P = ParamSpec("P") + +np_compat = array_namespace(np.empty(0)) + @pytest.fixture(params=tuple(Backend)) def library(request: pytest.FixtureRequest) -> Backend: # numpydoc ignore=PR01,RT03 @@ -34,6 +42,56 @@ def library(request: pytest.FixtureRequest) -> Backend: # numpydoc ignore=PR01, return elem +class NumPyReadOnly: + """ + Variant of array_api_compat.numpy producing read-only arrays. + + Read-only numpy arrays fail on `__iadd__` etc., whereas read-only libraries such as + JAX and Sparse simply don't define those methods, which makes calls to `+=` fall + back to `__add__`. + + Note that this is not a full read-only Array API library. Notably, + `array_namespace(x)` returns array_api_compat.numpy. This is actually the desired + behaviour, so that when a tested function internally calls `xp = + array_namespace(*args) or xp`, it will internally create writeable arrays. + For this reason, tests that explicitly pass xp=xp to the tested functions may + misbehave and should be skipped for NUMPY_READONLY. + """ + + def __getattr__(self, name: str) -> object: # numpydoc ignore=PR01,RT01 + """Wrap all functions that return arrays to make their output read-only.""" + func = getattr(np_compat, name) + if not callable(func) or isinstance(func, type): + return func + return self._wrap(func) + + @staticmethod + def _wrap(func: Callable[P, T]) -> Callable[P, T]: # numpydoc ignore=PR01,RT01 + """Wrap func to make all np.ndarrays it returns read-only.""" + + def as_readonly(o: T) -> T: # numpydoc ignore=PR01,RT01 + """Unset the writeable flag in o.""" + try: + # Don't use is_numpy_array(o), as it includes np.generic + if isinstance(o, np.ndarray): + o.flags.writeable = False + except TypeError: + # Cannot interpret as a data type + return o + + # This works with namedtuples too + if isinstance(o, tuple | list): + return type(o)(*(as_readonly(i) for i in o)) # type: ignore[arg-type,return-value] # pyright: ignore[reportArgumentType,reportUnknownArgumentType] + + return o + + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08 + return as_readonly(func(*args, **kwargs)) + + return wrapper + + @pytest.fixture def xp(library: Backend) -> ModuleType: # numpydoc ignore=PR01,RT03 """ @@ -43,7 +101,9 @@ def xp(library: Backend) -> ModuleType: # numpydoc ignore=PR01,RT03 ------- The current array namespace. """ - xp = pytest.importorskip(library.module_name) + if library == Backend.NUMPY_READONLY: + return NumPyReadOnly() # type: ignore[return-value] # pyright: ignore[reportReturnType] + xp = pytest.importorskip(library.value) if library == Backend.JAX_NUMPY: import jax # type: ignore[import-not-found] # pyright: ignore[reportMissingImports] diff --git a/tests/test_at.py b/tests/test_at.py index e9159712..749c1b55 100644 --- a/tests/test_at.py +++ b/tests/test_at.py @@ -7,7 +7,6 @@ import pytest from array_api_compat import ( # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs] array_namespace, - is_pydata_sparse_array, is_writeable_array, ) @@ -18,14 +17,6 @@ from array_api_extra._lib._utils._typing import Array -@pytest.fixture -def array(library: Backend, xp: ModuleType) -> Array: - x = xp.asarray([10.0, 20.0, 30.0]) - if library == Backend.NUMPY_READONLY: - x.flags.writeable = False - return x - - @contextmanager def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]: if copy is False and not is_writeable_array(array): @@ -42,6 +33,9 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]: xp_assert_equal(xp.all(array == array_orig), xp.asarray(copy)) +@pytest.mark.skip_xp_backend( + Backend.SPARSE, reason="read-only backend without .at support" +) @pytest.mark.parametrize( ("kwargs", "expect_copy"), [ @@ -66,15 +60,13 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]: ) def test_update_ops( xp: ModuleType, - array: Array, kwargs: dict[str, bool | None], expect_copy: bool | None, op: _AtOp, arg: float, expect: list[float], ): - if is_pydata_sparse_array(array): - pytest.skip("at() does not support updates on sparse arrays") + array = xp.asarray([10.0, 20.0, 30.0]) with assert_copy(array, expect_copy): func = cast(Callable[..., Array], getattr(at(array)[1:], op.value)) # type: ignore[no-any-explicit] diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 5be4a9ad..897b7811 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -136,6 +136,7 @@ def test_device(self, xp: ModuleType, device: Device): x = xp.asarray([1, 2, 3], device=device) assert get_device(cov(x)) == device + @pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY) def test_xp(self, xp: ModuleType): xp_assert_close( cov(xp.asarray([[0.0, 2.0], [1.0, 1.0], [2.0, 0.0]]).T, xp=xp), @@ -366,6 +367,7 @@ def test_device(self, xp: ModuleType, device: Device): x2 = xp.asarray([2, 3, 4], device=device) assert get_device(setdiff1d(x1, x2)) == device + @pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY) def test_xp(self, xp: ModuleType): x1 = xp.asarray([3, 8, 20]) x2 = xp.asarray([2, 3, 4]) diff --git a/tests/test_utils.py b/tests/test_utils.py index 1960b3eb..981d5c03 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,13 +1,12 @@ from types import ModuleType -import numpy as np import pytest from array_api_extra._lib import Backend from array_api_extra._lib._testing import xp_assert_equal from array_api_extra._lib._utils._compat import device as get_device from array_api_extra._lib._utils._helpers import in1d -from array_api_extra._lib._utils._typing import Array, Device +from array_api_extra._lib._utils._typing import Device # mypy: disable-error-code=no-untyped-usage @@ -16,10 +15,10 @@ class TestIn1D: @pytest.mark.skip_xp_backend(Backend.DASK_ARRAY, reason="no argsort") @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no unique_inverse, no device") # cover both code paths - @pytest.mark.parametrize("x2", [np.arange(9), np.arange(15)]) - def test_no_invert_assume_unique(self, xp: ModuleType, x2: Array): + @pytest.mark.parametrize("n", [9, 15]) + def test_no_invert_assume_unique(self, xp: ModuleType, n: int): x1 = xp.asarray([3, 8, 20]) - x2 = xp.asarray(x2) + x2 = xp.arange(n) expected = xp.asarray([True, True, False]) actual = in1d(x1, x2) xp_assert_equal(actual, expected) @@ -30,6 +29,7 @@ def test_device(self, xp: ModuleType, device: Device): x2 = xp.asarray([2, 3, 4], device=device) assert get_device(in1d(x1, x2)) == device + @pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY) @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no arange, no device") def test_xp(self, xp: ModuleType): x1 = xp.asarray([1, 6])