From 992412b5216f45c616ceb9e0e1568cc9918a727e Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 15 Jan 2025 13:12:38 +0000 Subject: [PATCH 1/4] TST: Run all tests on read-only numpy arrays --- tests/conftest.py | 67 +++++++++++++++++++++++++++++++++++++++++++-- tests/test_at.py | 16 +++-------- tests/test_funcs.py | 2 ++ tests/test_utils.py | 10 +++---- 4 files changed, 75 insertions(+), 20 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 0bf3114..a6a5d54 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,14 +1,24 @@ """Pytest fixtures.""" +from __future__ import annotations + +from collections.abc import Callable from enum import Enum -from typing import cast +from functools import wraps +from typing import ParamSpec, TypeVar, cast +import numpy as np import pytest from array_api_extra._lib._compat import array_namespace from array_api_extra._lib._compat import device as get_device from array_api_extra._lib._typing import Device, ModuleType +T = TypeVar("T") +P = ParamSpec("P") + +np_compat = array_namespace(np.empty(0)) + class Library(Enum): """All array libraries explicitly tested by array-api-extra.""" @@ -50,6 +60,56 @@ def library(request: pytest.FixtureRequest) -> Library: # numpydoc ignore=PR01, return elem +class NumPyReadOnly: + """ + Variant of array_api_compat.numpy producing read-only arrays. + + Note that this is not a full read-only Array API library. Notably, + array_namespace(x) returns array_api_compat.numpy, and as a consequence array + creation functions invoked internally by the tested functions will return + writeable arrays, as long as you don't explicitly pass xp=xp. + For this reason, tests that do pass xp=xp may misbehave and should be skipped + for NUMPY_READONLY. + """ + + def __getattr__(self, name: str) -> object: # numpydoc ignore=PR01,RT01 + """Wrap all functions that return arrays to make their output read-only.""" + func = getattr(np_compat, name) + if not callable(func) or isinstance(func, type): + return func + return self._wrap(func) + + @staticmethod + def _wrap(func: Callable[P, T]) -> Callable[P, T]: # numpydoc ignore=PR01,RT01 + """Wrap func to make all np.ndarrays it returns read-only.""" + + def as_readonly(o: T, seen: set[int]) -> T: # numpydoc ignore=PR01,RT01 + """Unset the writeable flag in o.""" + if id(o) in seen: + return o + seen.add(id(o)) + + try: + # Don't use is_numpy_array(o), as it includes np.generic + if isinstance(o, np.ndarray): + o.flags.writeable = False + except TypeError: + # Cannot interpret as a data type + return o + + # This works with namedtuples too + if isinstance(o, tuple | list): + return type(o)(*(as_readonly(i, seen) for i in o)) # type: ignore[arg-type,return-value] # pyright: ignore[reportArgumentType,reportUnknownArgumentType] + + return o + + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08 + return as_readonly(func(*args, **kwargs), seen=set()) + + return wrapper + + @pytest.fixture def xp(library: Library) -> ModuleType: # numpydoc ignore=PR01,RT03 """ @@ -59,8 +119,9 @@ def xp(library: Library) -> ModuleType: # numpydoc ignore=PR01,RT03 ------- The current array namespace. """ - name = "numpy" if library == Library.NUMPY_READONLY else library.value - xp = pytest.importorskip(name) + if library == Library.NUMPY_READONLY: + return NumPyReadOnly() # type: ignore[return-value] # pyright: ignore[reportReturnType] + xp = pytest.importorskip(library.value) if library == Library.JAX_NUMPY: import jax # type: ignore[import-not-found] # pyright: ignore[reportMissingImports] diff --git a/tests/test_at.py b/tests/test_at.py index ed56f61..18bab67 100644 --- a/tests/test_at.py +++ b/tests/test_at.py @@ -6,7 +6,6 @@ import pytest from array_api_compat import ( # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs] array_namespace, - is_pydata_sparse_array, is_writeable_array, ) @@ -18,14 +17,6 @@ from .conftest import Library -@pytest.fixture -def array(library: Library, xp: ModuleType) -> Array: - x = xp.asarray([10.0, 20.0, 30.0]) - if library == Library.NUMPY_READONLY: - x.flags.writeable = False - return x - - @contextmanager def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]: if copy is False and not is_writeable_array(array): @@ -42,6 +33,9 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]: xp_assert_equal(xp.all(array == array_orig), xp.asarray(copy)) +@pytest.mark.skip_xp_backend( + Library.SPARSE, reason="read-only library without .at support" +) @pytest.mark.parametrize( ("kwargs", "expect_copy"), [ @@ -66,15 +60,13 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]: ) def test_update_ops( xp: ModuleType, - array: Array, kwargs: dict[str, bool | None], expect_copy: bool | None, op: _AtOp, arg: float, expect: list[float], ): - if is_pydata_sparse_array(array): - pytest.skip("at() does not support updates on sparse arrays") + array = xp.asarray([10.0, 20.0, 30.0]) with assert_copy(array, expect_copy): func = cast(Callable[..., Array], getattr(at(array)[1:], op.value)) # type: ignore[no-any-explicit] diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 201295d..13fb59d 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -136,6 +136,7 @@ def test_device(self, xp: ModuleType, device: Device): x = xp.asarray([1, 2, 3], device=device) assert get_device(cov(x)) == device + @pytest.mark.skip_xp_backend(Library.NUMPY_READONLY) def test_xp(self, xp: ModuleType): xp_assert_close( cov(xp.asarray([[0.0, 2.0], [1.0, 1.0], [2.0, 0.0]]).T, xp=xp), @@ -366,6 +367,7 @@ def test_device(self, xp: ModuleType, device: Device): x2 = xp.asarray([2, 3, 4], device=device) assert get_device(setdiff1d(x1, x2)) == device + @pytest.mark.skip_xp_backend(Library.NUMPY_READONLY) def test_xp(self, xp: ModuleType): x1 = xp.asarray([3, 8, 20]) x2 = xp.asarray([2, 3, 4]) diff --git a/tests/test_utils.py b/tests/test_utils.py index 8cf49c2..da05127 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,9 +1,8 @@ -import numpy as np import pytest from array_api_extra._lib._compat import device as get_device from array_api_extra._lib._testing import xp_assert_equal -from array_api_extra._lib._typing import Array, Device, ModuleType +from array_api_extra._lib._typing import Device, ModuleType from array_api_extra._lib._utils import in1d from .conftest import Library @@ -15,10 +14,10 @@ class TestIn1D: @pytest.mark.skip_xp_backend(Library.DASK_ARRAY, reason="no argsort") @pytest.mark.skip_xp_backend(Library.SPARSE, reason="no unique_inverse, no device") # cover both code paths - @pytest.mark.parametrize("x2", [np.arange(9), np.arange(15)]) - def test_no_invert_assume_unique(self, xp: ModuleType, x2: Array): + @pytest.mark.parametrize("n", [9, 15]) + def test_no_invert_assume_unique(self, xp: ModuleType, n: int): x1 = xp.asarray([3, 8, 20]) - x2 = xp.asarray(x2) + x2 = xp.arange(n) expected = xp.asarray([True, True, False]) actual = in1d(x1, x2) xp_assert_equal(actual, expected) @@ -29,6 +28,7 @@ def test_device(self, xp: ModuleType, device: Device): x2 = xp.asarray([2, 3, 4], device=device) assert get_device(in1d(x1, x2)) == device + @pytest.mark.skip_xp_backend(Library.NUMPY_READONLY) @pytest.mark.skip_xp_backend(Library.SPARSE, reason="no arange, no device") def test_xp(self, xp: ModuleType): x1 = xp.asarray([1, 6]) From 9d1a6bcdd8742eb86880b9256a5baae591d456bb Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 15 Jan 2025 13:27:38 +0000 Subject: [PATCH 2/4] simplify Backend --- src/array_api_extra/_lib/_backends.py | 28 ++++++++++----------------- tests/conftest.py | 2 -- 2 files changed, 10 insertions(+), 20 deletions(-) diff --git a/src/array_api_extra/_lib/_backends.py b/src/array_api_extra/_lib/_backends.py index 93c4327..ee2e051 100644 --- a/src/array_api_extra/_lib/_backends.py +++ b/src/array_api_extra/_lib/_backends.py @@ -17,29 +17,23 @@ class Backend(Enum): # numpydoc ignore=PR01,PR02 # type: ignore[no-subclass-an Parameters ---------- value : str - String describing the backend. + Name of the backend's module. is_namespace : Callable[[ModuleType], bool] Function to check whether an input module is the array namespace corresponding to the backend. - module_name : str - Name of the backend's module. """ - ARRAY_API_STRICT = ( - "array_api_strict", - _compat.is_array_api_strict_namespace, - "array_api_strict", - ) - NUMPY = "numpy", _compat.is_numpy_namespace, "numpy" - NUMPY_READONLY = "numpy_readonly", _compat.is_numpy_namespace, "numpy" - CUPY = "cupy", _compat.is_cupy_namespace, "cupy" - TORCH = "torch", _compat.is_torch_namespace, "torch" - DASK_ARRAY = "dask.array", _compat.is_dask_namespace, "dask.array" - SPARSE = "sparse", _compat.is_pydata_sparse_namespace, "sparse" - JAX_NUMPY = "jax.numpy", _compat.is_jax_namespace, "jax.numpy" + ARRAY_API_STRICT = "array_api_strict", _compat.is_array_api_strict_namespace + NUMPY = "numpy", _compat.is_numpy_namespace + NUMPY_READONLY = "numpy_readonly", _compat.is_numpy_namespace + CUPY = "cupy", _compat.is_cupy_namespace + TORCH = "torch", _compat.is_torch_namespace + DASK_ARRAY = "dask.array", _compat.is_dask_namespace + SPARSE = "sparse", _compat.is_pydata_sparse_namespace + JAX_NUMPY = "jax.numpy", _compat.is_jax_namespace def __new__( - cls, value: str, _is_namespace: Callable[[ModuleType], bool], _module_name: str + cls, value: str, _is_namespace: Callable[[ModuleType], bool] ): # numpydoc ignore=GL08 obj = object.__new__(cls) obj._value_ = value @@ -49,10 +43,8 @@ def __init__( self, value: str, # noqa: ARG002 # pylint: disable=unused-argument is_namespace: Callable[[ModuleType], bool], - module_name: str, ): # numpydoc ignore=GL08 self.is_namespace = is_namespace - self.module_name = module_name def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01 """Pretty-print parameterized test names.""" diff --git a/tests/conftest.py b/tests/conftest.py index aa0fb11..9f97318 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,5 @@ """Pytest fixtures.""" -from __future__ import annotations - from collections.abc import Callable from functools import wraps from types import ModuleType From eba9b25d892286dab6144074bd5e07141377d417 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 15 Jan 2025 13:50:31 +0000 Subject: [PATCH 3/4] clarify docs --- tests/conftest.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 9f97318..536ce68 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -46,12 +46,16 @@ class NumPyReadOnly: """ Variant of array_api_compat.numpy producing read-only arrays. + Read-only numpy arrays fail on `__iadd__` etc., whereas read-only libraries such as + JAX and Sparse simply don't define those methods, which makes calls to `+=` fall + back to `__add__`. + Note that this is not a full read-only Array API library. Notably, - array_namespace(x) returns array_api_compat.numpy, and as a consequence array - creation functions invoked internally by the tested functions will return - writeable arrays, as long as you don't explicitly pass xp=xp. - For this reason, tests that do pass xp=xp may misbehave and should be skipped - for NUMPY_READONLY. + `array_namespace(x)` returns array_api_compat.numpy. This is actually the desired + behaviour, so that when a tested function internally calls `xp = + array_namespace(*args) or xp`, it will internally create writeable arrays. + For this reason, tests that explicitly pass xp=xp to the tested functions may + misbehave and should be skipped for NUMPY_READONLY. """ def __getattr__(self, name: str) -> object: # numpydoc ignore=PR01,RT01 From 3cf11df21a1ed6fef89aa530650a387021879f2e Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 15 Jan 2025 13:54:01 +0000 Subject: [PATCH 4/4] Remove unnecessary recursion guard --- tests/conftest.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 536ce68..c588d80 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -69,12 +69,8 @@ def __getattr__(self, name: str) -> object: # numpydoc ignore=PR01,RT01 def _wrap(func: Callable[P, T]) -> Callable[P, T]: # numpydoc ignore=PR01,RT01 """Wrap func to make all np.ndarrays it returns read-only.""" - def as_readonly(o: T, seen: set[int]) -> T: # numpydoc ignore=PR01,RT01 + def as_readonly(o: T) -> T: # numpydoc ignore=PR01,RT01 """Unset the writeable flag in o.""" - if id(o) in seen: - return o - seen.add(id(o)) - try: # Don't use is_numpy_array(o), as it includes np.generic if isinstance(o, np.ndarray): @@ -85,13 +81,13 @@ def as_readonly(o: T, seen: set[int]) -> T: # numpydoc ignore=PR01,RT01 # This works with namedtuples too if isinstance(o, tuple | list): - return type(o)(*(as_readonly(i, seen) for i in o)) # type: ignore[arg-type,return-value] # pyright: ignore[reportArgumentType,reportUnknownArgumentType] + return type(o)(*(as_readonly(i) for i in o)) # type: ignore[arg-type,return-value] # pyright: ignore[reportArgumentType,reportUnknownArgumentType] return o @wraps(func) def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08 - return as_readonly(func(*args, **kwargs), seen=set()) + return as_readonly(func(*args, **kwargs)) return wrapper