Skip to content

Commit a8487a7

Browse files
committed
ENH: tougher restrictions on array_api_strict
1 parent 5326cbd commit a8487a7

File tree

10 files changed

+171
-49
lines changed

10 files changed

+171
-49
lines changed

Diff for: src/array_api_extra/_lib/_at.py

+1
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,7 @@ def _op(
344344
msg = f"Can't update read-only array {x}"
345345
raise ValueError(msg)
346346

347+
# Backends without boolean indexing (other than JAX) crash here
347348
if in_place_op: # add(), subtract(), ...
348349
x[idx] = in_place_op(x[idx], y)
349350
else: # set()

Diff for: src/array_api_extra/_lib/_backends.py

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class Backend(Enum): # numpydoc ignore=PR01,PR02 # type: ignore[no-subclass-an
2424
"""
2525

2626
ARRAY_API_STRICT = "array_api_strict", _compat.is_array_api_strict_namespace
27+
ARRAY_API_STRICTEST = "array_api_strictest", _compat.is_array_api_strict_namespace
2728
NUMPY = "numpy", _compat.is_numpy_namespace
2829
NUMPY_READONLY = "numpy_readonly", _compat.is_numpy_namespace
2930
CUPY = "cupy", _compat.is_cupy_namespace

Diff for: src/array_api_extra/_lib/_funcs.py

+36-15
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88

99
from ._at import at
1010
from ._utils import _compat, _helpers
11-
from ._utils._compat import (
12-
array_namespace,
13-
is_dask_namespace,
14-
is_jax_array,
15-
is_jax_namespace,
11+
from ._utils._compat import array_namespace, is_dask_namespace, is_jax_array
12+
from ._utils._helpers import (
13+
asarrays,
14+
capabilities,
15+
eager_shape,
16+
meta_namespace,
17+
ndindex,
1618
)
17-
from ._utils._helpers import asarrays, eager_shape, meta_namespace, ndindex
1819
from ._utils._typing import Array
1920

2021
__all__ = [
@@ -152,7 +153,7 @@ def _apply_where( # type: ignore[explicit-any] # numpydoc ignore=PR01,RT01
152153
) -> Array:
153154
"""Helper of `apply_where`. On Dask, this runs on a single chunk."""
154155

155-
if is_jax_namespace(xp):
156+
if not capabilities(xp)["boolean indexing"]:
156157
# jax.jit does not support assignment by boolean mask
157158
return xp.where(cond, f1(*args), f2(*args) if f2 is not None else fill_value)
158159

@@ -708,14 +709,34 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
708709
# size= is JAX-specific
709710
# https://github.com/data-apis/array-api/issues/883
710711
_, counts = xp.unique_counts(x, size=_compat.size(x))
711-
return xp.astype(counts, xp.bool).sum()
712-
713-
_, counts = xp.unique_counts(x)
714-
n = _compat.size(counts)
715-
# FIXME https://github.com/data-apis/array-api-compat/pull/231
716-
if n is None: # e.g. Dask, ndonnx
717-
return xp.astype(counts, xp.bool).sum()
718-
return xp.asarray(n, device=_compat.device(x))
712+
return (counts > 0).sum()
713+
714+
# There are 3 general use cases:
715+
# 1. backend has unique_counts and it returns an array with known shape
716+
# 2. backend has unique_counts and it returns a None-sized array;
717+
# e.g. Dask, ndonnx
718+
# 3. backend does not have unique_counts; e.g. wrapped JAX
719+
if capabilities(xp)["data-dependent shapes"]:
720+
# xp has unique_counts; O(n) complexity
721+
_, counts = xp.unique_counts(x)
722+
n = _compat.size(counts)
723+
if n is None:
724+
return xp.sum(xp.ones_like(counts))
725+
return xp.asarray(n, device=_compat.device(x))
726+
727+
# xp does not have unique_counts; O(n*logn) complexity
728+
x = xp.sort(xp.reshape(x, -1))
729+
mask = x != xp.roll(x, -1)
730+
default_int = xp.__array_namespace_info__().default_dtypes(
731+
device=_compat.device(x)
732+
)["integral"]
733+
return xp.maximum(
734+
# Special cases:
735+
# - array is size 0
736+
# - array has all elements equal to each other
737+
xp.astype(xp.any(~mask), default_int),
738+
xp.sum(xp.astype(mask, default_int)),
739+
)
719740

720741

721742
def pad(

Diff for: src/array_api_extra/_lib/_utils/_helpers.py

+33
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
array_namespace,
1313
is_array_api_obj,
1414
is_dask_namespace,
15+
is_jax_namespace,
1516
is_numpy_array,
17+
is_pydata_sparse_namespace,
1618
)
1719
from ._typing import Array
1820

@@ -23,6 +25,7 @@
2325

2426
__all__ = [
2527
"asarrays",
28+
"capabilities",
2629
"eager_shape",
2730
"in1d",
2831
"is_python_scalar",
@@ -270,3 +273,33 @@ def meta_namespace(
270273
# Quietly skip scalars and None's
271274
metas = [cast(Array | None, getattr(a, "_meta", None)) for a in arrays]
272275
return array_namespace(*metas)
276+
277+
278+
def capabilities(xp: ModuleType) -> dict[str, int]:
279+
"""
280+
Return patched ``xp.__array_namespace_info__().capabilities()``.
281+
282+
Parameters
283+
----------
284+
xp : array_namespace
285+
The standard-compatible namespace.
286+
287+
Returns
288+
-------
289+
dict
290+
Capabilities of the namespace.
291+
"""
292+
if is_pydata_sparse_namespace(xp):
293+
# No __array_namespace_info__(); no indexing by sparse arrays
294+
return {"boolean indexing": False, "data-dependent shapes": True}
295+
out = xp.__array_namespace_info__().capabilities()
296+
if is_jax_namespace(xp):
297+
# FIXME https://github.com/jax-ml/jax/issues/27418
298+
out = out.copy()
299+
out["boolean indexing"] = False
300+
if is_dask_namespace(xp):
301+
# FIXME https://github.com/data-apis/array-api-compat/pull/290
302+
out = out.copy()
303+
out["boolean indexing"] = True
304+
out["data-dependent shapes"] = True
305+
return out

Diff for: tests/conftest.py

+26-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Pytest fixtures."""
22

3-
from collections.abc import Callable
3+
from collections.abc import Callable, Generator
44
from contextlib import suppress
55
from functools import partial, wraps
66
from types import ModuleType
@@ -104,7 +104,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
104104
@pytest.fixture
105105
def xp(
106106
library: Backend, request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
107-
) -> ModuleType: # numpydoc ignore=PR01,RT03
107+
) -> Generator[ModuleType]: # numpydoc ignore=PR01,RT03
108108
"""
109109
Parameterized fixture that iterates on all libraries.
110110
@@ -113,7 +113,27 @@ def xp(
113113
The current array namespace.
114114
"""
115115
if library == Backend.NUMPY_READONLY:
116-
return NumPyReadOnly() # type: ignore[return-value] # pyright: ignore[reportReturnType]
116+
yield NumPyReadOnly() # type: ignore[misc] # pyright: ignore[reportReturnType]
117+
return
118+
119+
if (
120+
library in (Backend.ARRAY_API_STRICT, Backend.ARRAY_API_STRICTEST)
121+
and np.__version__ < "1.26"
122+
):
123+
pytest.skip("array_api_strict is untested on NumPy <1.26")
124+
125+
if library == Backend.ARRAY_API_STRICTEST:
126+
xp = pytest.importorskip("array_api_strict")
127+
with xp.ArrayAPIStrictFlags(
128+
boolean_indexing=False,
129+
data_dependent_shapes=False,
130+
# writeable=False, # TODO implement in array-api-strict
131+
# lazy=True, # TODO implement in array-api-strict
132+
enabled_extensions=(),
133+
):
134+
yield xp
135+
return
136+
117137
xp = pytest.importorskip(library.value)
118138
# Possibly wrap module with array_api_compat
119139
xp = array_namespace(xp.empty(0))
@@ -122,16 +142,15 @@ def xp(
122142
# in the global scope of the module containing the test function.
123143
patch_lazy_xp_functions(request, monkeypatch, xp=xp)
124144

125-
if library == Backend.ARRAY_API_STRICT and np.__version__ < "1.26":
126-
pytest.skip("array_api_strict is untested on NumPy <1.26")
127-
128145
if library == Backend.JAX:
129146
import jax
130147

131148
# suppress unused-ignore to run mypy in -e lint as well as -e dev
132149
jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call,unused-ignore]
150+
yield xp
151+
return
133152

134-
return xp
153+
yield xp
135154

136155

137156
@pytest.fixture(params=[Backend.DASK]) # Can select the test with `pytest -k dask`

Diff for: tests/test_at.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
pytestmark = [
2121
pytest.mark.skip_xp_backend(
2222
Backend.SPARSE, reason="read-only backend without .at support"
23-
)
23+
),
24+
pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing"),
2425
]
2526

2627

@@ -256,7 +257,10 @@ def test_incompatible_dtype(
256257
elif library is Backend.DASK:
257258
z = at_op(x, idx, op, 1.1, copy=copy)
258259

259-
elif library is Backend.ARRAY_API_STRICT and op is not _AtOp.SET:
260+
elif (
261+
library in (Backend.ARRAY_API_STRICT, Backend.ARRAY_API_STRICTEST)
262+
and op is not _AtOp.SET
263+
):
260264
with pytest.raises(Exception, match=r"cast|promote|dtype"):
261265
_ = at_op(x, idx, op, 1.1, copy=copy)
262266

Diff for: tests/test_funcs.py

+41-7
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,9 @@
4848
lazy_xp_function(sinc, static_argnames="xp")
4949

5050

51-
NUMPY_GE2 = int(np.__version__.split(".")[0]) >= 2
51+
NUMPY_VERSION = tuple(int(v) for v in np.__version__.split(".")[2])
5252

5353

54-
@pytest.mark.skip_xp_backend(
55-
Backend.SPARSE, reason="read-only backend without .at support"
56-
)
5754
class TestApplyWhere:
5855
@staticmethod
5956
def f1(x: Array, y: Array | int = 10) -> Array:
@@ -153,6 +150,14 @@ def test_dont_overwrite_fill_value(self, xp: ModuleType):
153150
xp_assert_equal(actual, xp.asarray([100, 12]))
154151
xp_assert_equal(fill_value, xp.asarray([100, 200]))
155152

153+
@pytest.mark.skip_xp_backend(
154+
Backend.ARRAY_API_STRICTEST,
155+
reason="no boolean indexing -> run everywhere",
156+
)
157+
@pytest.mark.skip_xp_backend(
158+
Backend.SPARSE,
159+
reason="no indexing by sparse array -> run everywhere",
160+
)
156161
def test_dont_run_on_false(self, xp: ModuleType):
157162
x = xp.asarray([1.0, 2.0, 0.0])
158163
y = xp.asarray([0.0, 3.0, 4.0])
@@ -192,6 +197,7 @@ def test_device(self, xp: ModuleType, device: Device):
192197
y = apply_where(x % 2 == 0, x, self.f1, fill_value=x)
193198
assert get_device(y) == device
194199

200+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
195201
@pytest.mark.filterwarnings("ignore::RuntimeWarning") # overflows, etc.
196202
@hypothesis.settings(
197203
# The xp and library fixtures are not regenerated between hypothesis iterations
@@ -218,7 +224,7 @@ def test_hypothesis( # type: ignore[explicit-any,decorated-any]
218224
):
219225
if (
220226
library in (Backend.NUMPY, Backend.NUMPY_READONLY)
221-
and not NUMPY_GE2
227+
and NUMPY_VERSION < (2, 0)
222228
and dtype is np.float32
223229
):
224230
pytest.xfail(reason="NumPy 1.x dtype promotion for scalars")
@@ -562,6 +568,9 @@ def test_xp(self, xp: ModuleType):
562568
assert y.shape == (1, 1, 1, 3)
563569

564570

571+
@pytest.mark.filterwarnings( # array_api_strictest
572+
"ignore:invalid value encountered:RuntimeWarning:array_api_strict"
573+
)
565574
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
566575
class TestIsClose:
567576
@pytest.mark.parametrize("swap", [False, True])
@@ -680,13 +689,15 @@ def test_bool_dtype(self, xp: ModuleType):
680689
isclose(xp.asarray(True), b, atol=1), xp.asarray([True, True, True])
681690
)
682691

692+
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="unknown shape")
683693
def test_none_shape(self, xp: ModuleType):
684694
a = xp.asarray([1, 5, 0])
685695
b = xp.asarray([1, 4, 2])
686696
b = b[a < 5]
687697
a = a[a < 5]
688698
xp_assert_equal(isclose(a, b), xp.asarray([True, False]))
689699

700+
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="unknown shape")
690701
def test_none_shape_bool(self, xp: ModuleType):
691702
a = xp.asarray([True, True, False])
692703
b = xp.asarray([True, False, True])
@@ -819,8 +830,30 @@ def test_empty(self, xp: ModuleType):
819830
a = xp.asarray([])
820831
xp_assert_equal(nunique(a), xp.asarray(0))
821832

822-
def test_device(self, xp: ModuleType, device: Device):
823-
a = xp.asarray(0.0, device=device)
833+
def test_size1(self, xp: ModuleType):
834+
a = xp.asarray([123])
835+
xp_assert_equal(nunique(a), xp.asarray(1))
836+
837+
def test_all_equal(self, xp: ModuleType):
838+
a = xp.asarray([123, 123, 123])
839+
xp_assert_equal(nunique(a), xp.asarray(1))
840+
841+
@pytest.mark.xfail_xp_backend(Backend.DASK, reason="No equal_nan kwarg in unique")
842+
@pytest.mark.xfail_xp_backend(
843+
Backend.SPARSE, reason="Non-compliant equal_nan=True behaviour"
844+
)
845+
def test_nan(self, xp: ModuleType, library: Backend):
846+
is_numpy = library in (Backend.NUMPY, Backend.NUMPY_READONLY)
847+
if is_numpy and NUMPY_VERSION < (1, 24):
848+
pytest.xfail("NumPy <1.24 has no equal_nan kwarg in unique")
849+
850+
# Each NaN is counted separately
851+
a = xp.asarray([xp.nan, 123.0, xp.nan])
852+
xp_assert_equal(nunique(a), xp.asarray(3))
853+
854+
@pytest.mark.parametrize("size", [0, 1, 2])
855+
def test_device(self, xp: ModuleType, device: Device, size: int):
856+
a = xp.asarray([0.0] * size, device=device)
824857
assert get_device(nunique(a)) == device
825858

826859
def test_xp(self, xp: ModuleType):
@@ -895,6 +928,7 @@ def test_sequence_of_tuples_width(self, xp: ModuleType):
895928

896929

897930
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no argsort")
931+
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="no unique_values")
898932
class TestSetDiff1D:
899933
@pytest.mark.xfail_xp_backend(Backend.DASK, reason="NaN-shaped arrays")
900934
@pytest.mark.xfail_xp_backend(

Diff for: tests/test_helpers.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from array_api_extra._lib._utils._compat import device as get_device
1111
from array_api_extra._lib._utils._helpers import (
1212
asarrays,
13+
capabilities,
1314
eager_shape,
1415
in1d,
1516
meta_namespace,
@@ -27,6 +28,7 @@
2728

2829

2930
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no unique_inverse")
31+
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="no unique_inverse")
3032
class TestIn1D:
3133
# cover both code paths
3234
@pytest.mark.parametrize(
@@ -161,7 +163,8 @@ def test_ndindex(shape: tuple[int, ...]):
161163
assert tuple(ndindex(*shape)) == tuple(np.ndindex(*shape))
162164

163165

164-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="index by sparse array")
166+
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array")
167+
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing")
165168
def test_eager_shape(xp: ModuleType, library: Backend):
166169
a = xp.asarray([1, 2, 3])
167170
# Lazy arrays, like Dask, have an eager shape until you slice them with
@@ -194,3 +197,10 @@ def test_dask_metas(self, da: ModuleType):
194197
def test_xp(self, xp: ModuleType):
195198
args = None, xp.asarray(0), 1
196199
assert meta_namespace(*args, xp=xp) in (xp, np_compat)
200+
201+
202+
def test_capabilities(xp: ModuleType):
203+
expect = {"boolean indexing", "data-dependent shapes"}
204+
if xp.__array_api_version__ >= "2024.12":
205+
expect.add("max dimensions")
206+
assert capabilities(xp).keys() == expect

0 commit comments

Comments
 (0)