Skip to content

ENH/TST: tougher restrictions on array_api_strict #179

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/array_api_extra/_lib/_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ def _op(
msg = f"Can't update read-only array {x}"
raise ValueError(msg)

# Backends without boolean indexing (other than JAX) crash here
if in_place_op: # add(), subtract(), ...
x[idx] = in_place_op(x[idx], y)
else: # set()
Expand Down
21 changes: 18 additions & 3 deletions src/array_api_extra/_lib/_backends.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Backends with which array-api-extra interacts in delegation and testing."""

from __future__ import annotations

from collections.abc import Callable
from enum import Enum
from types import ModuleType
from typing import cast

from ._utils import _compat

Expand All @@ -23,9 +24,14 @@ class Backend(Enum): # numpydoc ignore=PR01,PR02 # type: ignore[no-subclass-an
corresponding to the backend.
"""

# Use :<tag> to prevent Enum from deduplicating items with the same value
ARRAY_API_STRICT = "array_api_strict", _compat.is_array_api_strict_namespace
ARRAY_API_STRICTEST = (
"array_api_strict:strictest",
_compat.is_array_api_strict_namespace,
)
NUMPY = "numpy", _compat.is_numpy_namespace
NUMPY_READONLY = "numpy_readonly", _compat.is_numpy_namespace
NUMPY_READONLY = "numpy:readonly", _compat.is_numpy_namespace
CUPY = "cupy", _compat.is_cupy_namespace
TORCH = "torch", _compat.is_torch_namespace
DASK = "dask.array", _compat.is_dask_namespace
Expand All @@ -48,4 +54,13 @@ def __init__(

def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01
"""Pretty-print parameterized test names."""
return cast(str, self.value)
return self.name.lower()

@property
def modname(self) -> str: # numpydoc ignore=RT01
"""Module name to be imported."""
return self.value.split(":")[0]

def like(self, *others: Backend) -> bool: # numpydoc ignore=PR01,RT01
"""Check if this backend uses the same module as others."""
return any(self.modname == other.modname for other in others)
51 changes: 36 additions & 15 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@

from ._at import at
from ._utils import _compat, _helpers
from ._utils._compat import (
array_namespace,
is_dask_namespace,
is_jax_array,
is_jax_namespace,
from ._utils._compat import array_namespace, is_dask_namespace, is_jax_array
from ._utils._helpers import (
asarrays,
capabilities,
eager_shape,
meta_namespace,
ndindex,
)
from ._utils._helpers import asarrays, eager_shape, meta_namespace, ndindex
from ._utils._typing import Array

__all__ = [
Expand Down Expand Up @@ -152,7 +153,7 @@ def _apply_where( # type: ignore[explicit-any] # numpydoc ignore=PR01,RT01
) -> Array:
"""Helper of `apply_where`. On Dask, this runs on a single chunk."""

if is_jax_namespace(xp):
if not capabilities(xp)["boolean indexing"]:
# jax.jit does not support assignment by boolean mask
return xp.where(cond, f1(*args), f2(*args) if f2 is not None else fill_value)

Expand Down Expand Up @@ -708,14 +709,34 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
# size= is JAX-specific
# https://github.com/data-apis/array-api/issues/883
_, counts = xp.unique_counts(x, size=_compat.size(x))
return xp.astype(counts, xp.bool).sum()

_, counts = xp.unique_counts(x)
n = _compat.size(counts)
# FIXME https://github.com/data-apis/array-api-compat/pull/231
if n is None: # e.g. Dask, ndonnx
return xp.astype(counts, xp.bool).sum()
return xp.asarray(n, device=_compat.device(x))
return (counts > 0).sum()

# There are 3 general use cases:
# 1. backend has unique_counts and it returns an array with known shape
# 2. backend has unique_counts and it returns a None-sized array;
# e.g. Dask, ndonnx
# 3. backend does not have unique_counts; e.g. wrapped JAX
if capabilities(xp)["data-dependent shapes"]:
# xp has unique_counts; O(n) complexity
_, counts = xp.unique_counts(x)
n = _compat.size(counts)
if n is None:
return xp.sum(xp.ones_like(counts))
return xp.asarray(n, device=_compat.device(x))

# xp does not have unique_counts; O(n*logn) complexity
x = xp.sort(xp.reshape(x, -1))
mask = x != xp.roll(x, -1)
default_int = xp.__array_namespace_info__().default_dtypes(
device=_compat.device(x)
)["integral"]
return xp.maximum(
# Special cases:
# - array is size 0
# - array has all elements equal to each other
xp.astype(xp.any(~mask), default_int),
xp.sum(xp.astype(mask, default_int)),
)


def pad(
Expand Down
36 changes: 36 additions & 0 deletions src/array_api_extra/_lib/_utils/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
array_namespace,
is_array_api_obj,
is_dask_namespace,
is_jax_namespace,
is_numpy_array,
is_pydata_sparse_namespace,
)
from ._typing import Array

Expand All @@ -23,6 +25,7 @@

__all__ = [
"asarrays",
"capabilities",
"eager_shape",
"in1d",
"is_python_scalar",
Expand Down Expand Up @@ -270,3 +273,36 @@ def meta_namespace(
# Quietly skip scalars and None's
metas = [cast(Array | None, getattr(a, "_meta", None)) for a in arrays]
return array_namespace(*metas)


def capabilities(xp: ModuleType) -> dict[str, int]:
"""
Return patched ``xp.__array_namespace_info__().capabilities()``.

TODO this helper should be eventually removed once all the special cases
it handles are fixed in the respective backends.

Parameters
----------
xp : array_namespace
The standard-compatible namespace.

Returns
-------
dict
Capabilities of the namespace.
"""
if is_pydata_sparse_namespace(xp):
# No __array_namespace_info__(); no indexing by sparse arrays
return {"boolean indexing": False, "data-dependent shapes": True}
out = xp.__array_namespace_info__().capabilities()
if is_jax_namespace(xp):
# FIXME https://github.com/jax-ml/jax/issues/27418
out = out.copy()
out["boolean indexing"] = False
if is_dask_namespace(xp):
# FIXME https://github.com/data-apis/array-api-compat/pull/290
out = out.copy()
out["boolean indexing"] = True
out["data-dependent shapes"] = True
return out
32 changes: 23 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Pytest fixtures."""

from collections.abc import Callable
from collections.abc import Callable, Generator
from contextlib import suppress
from functools import partial, wraps
from types import ModuleType
Expand All @@ -19,6 +19,7 @@
T = TypeVar("T")
P = ParamSpec("P")

NUMPY_VERSION = tuple(int(v) for v in np.__version__.split(".")[2])
np_compat = array_namespace(np.empty(0)) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]


Expand All @@ -43,7 +44,7 @@ def library(request: pytest.FixtureRequest) -> Backend: # numpydoc ignore=PR01,
msg = f"argument of {marker_name} must be a Backend enum"
raise TypeError(msg)
if library == elem:
reason = library.value
reason = str(library)
with suppress(KeyError):
reason += ":" + cast(str, marker.kwargs["reason"])
skip_or_xfail(reason=reason)
Expand Down Expand Up @@ -104,7 +105,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
@pytest.fixture
def xp(
library: Backend, request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
) -> ModuleType: # numpydoc ignore=PR01,RT03
) -> Generator[ModuleType]: # numpydoc ignore=PR01,RT03
"""
Parameterized fixture that iterates on all libraries.

Expand All @@ -113,25 +114,38 @@ def xp(
The current array namespace.
"""
if library == Backend.NUMPY_READONLY:
return NumPyReadOnly() # type: ignore[return-value] # pyright: ignore[reportReturnType]
xp = pytest.importorskip(library.value)
yield NumPyReadOnly() # type: ignore[misc] # pyright: ignore[reportReturnType]
return

if library.like(Backend.ARRAY_API_STRICT) and NUMPY_VERSION < (1, 26):
pytest.skip("array_api_strict is untested on NumPy <1.26")

xp = pytest.importorskip(library.modname)
# Possibly wrap module with array_api_compat
xp = array_namespace(xp.empty(0))

if library == Backend.ARRAY_API_STRICTEST:
with xp.ArrayAPIStrictFlags(
boolean_indexing=False,
data_dependent_shapes=False,
# writeable=False, # TODO implement in array-api-strict
# lazy=True, # TODO implement in array-api-strict
enabled_extensions=(),
):
yield xp
return

# On Dask and JAX, monkey-patch all functions tagged by `lazy_xp_function`
# in the global scope of the module containing the test function.
patch_lazy_xp_functions(request, monkeypatch, xp=xp)

if library == Backend.ARRAY_API_STRICT and np.__version__ < "1.26":
pytest.skip("array_api_strict is untested on NumPy <1.26")

if library == Backend.JAX:
import jax

# suppress unused-ignore to run mypy in -e lint as well as -e dev
jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call,unused-ignore]

return xp
yield xp


@pytest.fixture(params=[Backend.DASK]) # Can select the test with `pytest -k dask`
Expand Down
5 changes: 3 additions & 2 deletions tests/test_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
pytestmark = [
pytest.mark.skip_xp_backend(
Backend.SPARSE, reason="read-only backend without .at support"
)
),
pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing"),
]


Expand Down Expand Up @@ -256,7 +257,7 @@ def test_incompatible_dtype(
elif library is Backend.DASK:
z = at_op(x, idx, op, 1.1, copy=copy)

elif library is Backend.ARRAY_API_STRICT and op is not _AtOp.SET:
elif library.like(Backend.ARRAY_API_STRICT) and op is not _AtOp.SET:
with pytest.raises(Exception, match=r"cast|promote|dtype"):
_ = at_op(x, idx, op, 1.1, copy=copy)

Expand Down
50 changes: 40 additions & 10 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from array_api_extra._lib._utils._typing import Array, Device
from array_api_extra.testing import lazy_xp_function

from .conftest import NUMPY_VERSION

# some xp backends are untyped
# mypy: disable-error-code=no-untyped-def

Expand All @@ -48,12 +50,6 @@
lazy_xp_function(sinc, static_argnames="xp")


NUMPY_GE2 = int(np.__version__.split(".")[0]) >= 2


@pytest.mark.skip_xp_backend(
Backend.SPARSE, reason="read-only backend without .at support"
)
class TestApplyWhere:
@staticmethod
def f1(x: Array, y: Array | int = 10) -> Array:
Expand Down Expand Up @@ -153,6 +149,14 @@ def test_dont_overwrite_fill_value(self, xp: ModuleType):
xp_assert_equal(actual, xp.asarray([100, 12]))
xp_assert_equal(fill_value, xp.asarray([100, 200]))

@pytest.mark.skip_xp_backend(
Backend.ARRAY_API_STRICTEST,
reason="no boolean indexing -> run everywhere",
)
@pytest.mark.skip_xp_backend(
Backend.SPARSE,
reason="no indexing by sparse array -> run everywhere",
)
def test_dont_run_on_false(self, xp: ModuleType):
x = xp.asarray([1.0, 2.0, 0.0])
y = xp.asarray([0.0, 3.0, 4.0])
Expand Down Expand Up @@ -192,6 +196,7 @@ def test_device(self, xp: ModuleType, device: Device):
y = apply_where(x % 2 == 0, x, self.f1, fill_value=x)
assert get_device(y) == device

@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype")
@pytest.mark.filterwarnings("ignore::RuntimeWarning") # overflows, etc.
@hypothesis.settings(
# The xp and library fixtures are not regenerated between hypothesis iterations
Expand All @@ -217,8 +222,8 @@ def test_hypothesis( # type: ignore[explicit-any,decorated-any]
library: Backend,
):
if (
library in (Backend.NUMPY, Backend.NUMPY_READONLY)
and not NUMPY_GE2
library.like(Backend.NUMPY)
and NUMPY_VERSION < (2, 0)
and dtype is np.float32
):
pytest.xfail(reason="NumPy 1.x dtype promotion for scalars")
Expand Down Expand Up @@ -562,6 +567,9 @@ def test_xp(self, xp: ModuleType):
assert y.shape == (1, 1, 1, 3)


@pytest.mark.filterwarnings( # array_api_strictest
"ignore:invalid value encountered:RuntimeWarning:array_api_strict"
)
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
class TestIsClose:
@pytest.mark.parametrize("swap", [False, True])
Expand Down Expand Up @@ -680,13 +688,15 @@ def test_bool_dtype(self, xp: ModuleType):
isclose(xp.asarray(True), b, atol=1), xp.asarray([True, True, True])
)

@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="unknown shape")
def test_none_shape(self, xp: ModuleType):
a = xp.asarray([1, 5, 0])
b = xp.asarray([1, 4, 2])
b = b[a < 5]
a = a[a < 5]
xp_assert_equal(isclose(a, b), xp.asarray([True, False]))

@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="unknown shape")
def test_none_shape_bool(self, xp: ModuleType):
a = xp.asarray([True, True, False])
b = xp.asarray([True, False, True])
Expand Down Expand Up @@ -819,8 +829,27 @@ def test_empty(self, xp: ModuleType):
a = xp.asarray([])
xp_assert_equal(nunique(a), xp.asarray(0))

def test_device(self, xp: ModuleType, device: Device):
a = xp.asarray(0.0, device=device)
def test_size1(self, xp: ModuleType):
a = xp.asarray([123])
xp_assert_equal(nunique(a), xp.asarray(1))

def test_all_equal(self, xp: ModuleType):
a = xp.asarray([123, 123, 123])
xp_assert_equal(nunique(a), xp.asarray(1))

@pytest.mark.xfail_xp_backend(Backend.DASK, reason="No equal_nan kwarg in unique")
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="sparse#855")
def test_nan(self, xp: ModuleType, library: Backend):
if library.like(Backend.NUMPY) and NUMPY_VERSION < (1, 24):
pytest.xfail("NumPy <1.24 has no equal_nan kwarg in unique")

# Each NaN is counted separately
a = xp.asarray([xp.nan, 123.0, xp.nan])
xp_assert_equal(nunique(a), xp.asarray(3))

@pytest.mark.parametrize("size", [0, 1, 2])
def test_device(self, xp: ModuleType, device: Device, size: int):
a = xp.asarray([0.0] * size, device=device)
assert get_device(nunique(a)) == device

def test_xp(self, xp: ModuleType):
Expand Down Expand Up @@ -895,6 +924,7 @@ def test_sequence_of_tuples_width(self, xp: ModuleType):


@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no argsort")
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="no unique_values")
class TestSetDiff1D:
@pytest.mark.xfail_xp_backend(Backend.DASK, reason="NaN-shaped arrays")
@pytest.mark.xfail_xp_backend(
Expand Down
Loading