Skip to content

TST: Run all tests on read-only numpy arrays #92

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 10 additions & 18 deletions src/array_api_extra/_lib/_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,23 @@ class Backend(Enum): # numpydoc ignore=PR01,PR02 # type: ignore[no-subclass-an
Parameters
----------
value : str
String describing the backend.
Name of the backend's module.
is_namespace : Callable[[ModuleType], bool]
Function to check whether an input module is the array namespace
corresponding to the backend.
module_name : str
Name of the backend's module.
"""

ARRAY_API_STRICT = (
"array_api_strict",
_compat.is_array_api_strict_namespace,
"array_api_strict",
)
NUMPY = "numpy", _compat.is_numpy_namespace, "numpy"
NUMPY_READONLY = "numpy_readonly", _compat.is_numpy_namespace, "numpy"
CUPY = "cupy", _compat.is_cupy_namespace, "cupy"
TORCH = "torch", _compat.is_torch_namespace, "torch"
DASK_ARRAY = "dask.array", _compat.is_dask_namespace, "dask.array"
SPARSE = "sparse", _compat.is_pydata_sparse_namespace, "sparse"
JAX_NUMPY = "jax.numpy", _compat.is_jax_namespace, "jax.numpy"
ARRAY_API_STRICT = "array_api_strict", _compat.is_array_api_strict_namespace
NUMPY = "numpy", _compat.is_numpy_namespace
NUMPY_READONLY = "numpy_readonly", _compat.is_numpy_namespace
CUPY = "cupy", _compat.is_cupy_namespace
TORCH = "torch", _compat.is_torch_namespace
DASK_ARRAY = "dask.array", _compat.is_dask_namespace
SPARSE = "sparse", _compat.is_pydata_sparse_namespace
JAX_NUMPY = "jax.numpy", _compat.is_jax_namespace

def __new__(
cls, value: str, _is_namespace: Callable[[ModuleType], bool], _module_name: str
cls, value: str, _is_namespace: Callable[[ModuleType], bool]
): # numpydoc ignore=GL08
obj = object.__new__(cls)
obj._value_ = value
Expand All @@ -49,10 +43,8 @@ def __init__(
self,
value: str, # noqa: ARG002 # pylint: disable=unused-argument
is_namespace: Callable[[ModuleType], bool],
module_name: str,
): # numpydoc ignore=GL08
self.is_namespace = is_namespace
self.module_name = module_name

def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01
"""Pretty-print parameterized test names."""
Expand Down
64 changes: 62 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
"""Pytest fixtures."""

from collections.abc import Callable
from functools import wraps
from types import ModuleType
from typing import cast
from typing import ParamSpec, TypeVar, cast

import numpy as np
import pytest

from array_api_extra._lib import Backend
from array_api_extra._lib._utils._compat import array_namespace
from array_api_extra._lib._utils._compat import device as get_device
from array_api_extra._lib._utils._typing import Device

T = TypeVar("T")
P = ParamSpec("P")

np_compat = array_namespace(np.empty(0))


@pytest.fixture(params=tuple(Backend))
def library(request: pytest.FixtureRequest) -> Backend: # numpydoc ignore=PR01,RT03
Expand All @@ -34,6 +42,56 @@ def library(request: pytest.FixtureRequest) -> Backend: # numpydoc ignore=PR01,
return elem


class NumPyReadOnly:
"""
Variant of array_api_compat.numpy producing read-only arrays.

Note that this is not a full read-only Array API library. Notably,
array_namespace(x) returns array_api_compat.numpy, and as a consequence array
creation functions invoked internally by the tested functions will return
writeable arrays, as long as you don't explicitly pass xp=xp.
For this reason, tests that do pass xp=xp may misbehave and should be skipped
for NUMPY_READONLY.
"""

def __getattr__(self, name: str) -> object: # numpydoc ignore=PR01,RT01
"""Wrap all functions that return arrays to make their output read-only."""
func = getattr(np_compat, name)
if not callable(func) or isinstance(func, type):
return func
return self._wrap(func)

@staticmethod
def _wrap(func: Callable[P, T]) -> Callable[P, T]: # numpydoc ignore=PR01,RT01
"""Wrap func to make all np.ndarrays it returns read-only."""

def as_readonly(o: T, seen: set[int]) -> T: # numpydoc ignore=PR01,RT01
"""Unset the writeable flag in o."""
if id(o) in seen:
return o
seen.add(id(o))

try:
# Don't use is_numpy_array(o), as it includes np.generic
if isinstance(o, np.ndarray):
o.flags.writeable = False
except TypeError:
# Cannot interpret as a data type
return o

# This works with namedtuples too
if isinstance(o, tuple | list):
return type(o)(*(as_readonly(i, seen) for i in o)) # type: ignore[arg-type,return-value] # pyright: ignore[reportArgumentType,reportUnknownArgumentType]

return o

@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
return as_readonly(func(*args, **kwargs), seen=set())

return wrapper


@pytest.fixture
def xp(library: Backend) -> ModuleType: # numpydoc ignore=PR01,RT03
"""
Expand All @@ -43,7 +101,9 @@ def xp(library: Backend) -> ModuleType: # numpydoc ignore=PR01,RT03
-------
The current array namespace.
"""
xp = pytest.importorskip(library.module_name)
if library == Backend.NUMPY_READONLY:
return NumPyReadOnly() # type: ignore[return-value] # pyright: ignore[reportReturnType]
xp = pytest.importorskip(library.value)
if library == Backend.JAX_NUMPY:
import jax # type: ignore[import-not-found] # pyright: ignore[reportMissingImports]

Expand Down
16 changes: 4 additions & 12 deletions tests/test_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import pytest
from array_api_compat import ( # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs]
array_namespace,
is_pydata_sparse_array,
is_writeable_array,
)

Expand All @@ -18,14 +17,6 @@
from array_api_extra._lib._utils._typing import Array


@pytest.fixture
def array(library: Backend, xp: ModuleType) -> Array:
x = xp.asarray([10.0, 20.0, 30.0])
if library == Backend.NUMPY_READONLY:
x.flags.writeable = False
return x


@contextmanager
def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
if copy is False and not is_writeable_array(array):
Expand All @@ -42,6 +33,9 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
xp_assert_equal(xp.all(array == array_orig), xp.asarray(copy))


@pytest.mark.skip_xp_backend(
Backend.SPARSE, reason="read-only backend without .at support"
)
@pytest.mark.parametrize(
("kwargs", "expect_copy"),
[
Expand All @@ -66,15 +60,13 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
)
def test_update_ops(
xp: ModuleType,
array: Array,
kwargs: dict[str, bool | None],
expect_copy: bool | None,
op: _AtOp,
arg: float,
expect: list[float],
):
if is_pydata_sparse_array(array):
pytest.skip("at() does not support updates on sparse arrays")
array = xp.asarray([10.0, 20.0, 30.0])

with assert_copy(array, expect_copy):
func = cast(Callable[..., Array], getattr(at(array)[1:], op.value)) # type: ignore[no-any-explicit]
Expand Down
2 changes: 2 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def test_device(self, xp: ModuleType, device: Device):
x = xp.asarray([1, 2, 3], device=device)
assert get_device(cov(x)) == device

@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Notably, read-only numpy disallows __iadd__ etc., whereas jax and sparse simply don't define these methods, which causes them to fall back to a = a.__add__(b)

def test_xp(self, xp: ModuleType):
xp_assert_close(
cov(xp.asarray([[0.0, 2.0], [1.0, 1.0], [2.0, 0.0]]).T, xp=xp),
Expand Down Expand Up @@ -366,6 +367,7 @@ def test_device(self, xp: ModuleType, device: Device):
x2 = xp.asarray([2, 3, 4], device=device)
assert get_device(setdiff1d(x1, x2)) == device

@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY)
def test_xp(self, xp: ModuleType):
x1 = xp.asarray([3, 8, 20])
x2 = xp.asarray([2, 3, 4])
Expand Down
10 changes: 5 additions & 5 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from types import ModuleType

import numpy as np
import pytest

from array_api_extra._lib import Backend
from array_api_extra._lib._testing import xp_assert_equal
from array_api_extra._lib._utils._compat import device as get_device
from array_api_extra._lib._utils._helpers import in1d
from array_api_extra._lib._utils._typing import Array, Device
from array_api_extra._lib._utils._typing import Device

# mypy: disable-error-code=no-untyped-usage

Expand All @@ -16,10 +15,10 @@ class TestIn1D:
@pytest.mark.skip_xp_backend(Backend.DASK_ARRAY, reason="no argsort")
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no unique_inverse, no device")
# cover both code paths
@pytest.mark.parametrize("x2", [np.arange(9), np.arange(15)])
def test_no_invert_assume_unique(self, xp: ModuleType, x2: Array):
@pytest.mark.parametrize("n", [9, 15])
def test_no_invert_assume_unique(self, xp: ModuleType, n: int):
x1 = xp.asarray([3, 8, 20])
x2 = xp.asarray(x2)
x2 = xp.arange(n)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NumPyReadOnly.asarray(x2, copy=None) was tainting the globally defined x2, ruining it for the following iterations of the test.

expected = xp.asarray([True, True, False])
actual = in1d(x1, x2)
xp_assert_equal(actual, expected)
Expand All @@ -30,6 +29,7 @@ def test_device(self, xp: ModuleType, device: Device):
x2 = xp.asarray([2, 3, 4], device=device)
assert get_device(in1d(x1, x2)) == device

@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY)
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no arange, no device")
def test_xp(self, xp: ModuleType):
x1 = xp.asarray([1, 6])
Expand Down
Loading