Skip to content

Commit 86bdec2

Browse files
committed
Bump array-api-compat to 1.11
1 parent 0db3325 commit 86bdec2

File tree

8 files changed

+548
-957
lines changed

8 files changed

+548
-957
lines changed

Diff for: pixi.lock

+536-946
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: pyproject.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ classifiers = [
2626
"Typing :: Typed",
2727
]
2828
dynamic = ["version"]
29-
dependencies = ["array-api-compat>=1.10.0,<2"]
29+
dependencies = ["array-api-compat>=1.11,<2"]
3030

3131
[project.urls]
3232
Homepage = "https://github.com/data-apis/array-api-extra"
@@ -48,7 +48,7 @@ platforms = ["linux-64", "osx-arm64", "win-64"]
4848

4949
[tool.pixi.dependencies]
5050
python = ">=3.10,<3.14"
51-
array-api-compat = ">=1.10.0,<2"
51+
array-api-compat = ">=1.11,<2"
5252

5353
[tool.pixi.pypi-dependencies]
5454
array-api-extra = { path = ".", editable = true }

Diff for: src/array_api_extra/_lib/_funcs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
556556
_, counts = xp.unique_counts(x)
557557
n = _compat.size(counts)
558558
# FIXME https://github.com/data-apis/array-api-compat/pull/231
559-
if n is None or math.isnan(n): # e.g. Dask, ndonnx
559+
if n is None: # e.g. Dask, ndonnx
560560
return xp.astype(counts, xp.bool).sum()
561561
return xp.asarray(n, device=_compat.device(x))
562562

Diff for: src/array_api_extra/_lib/_utils/_compat.py

+3
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
is_dask_namespace,
1515
is_jax_array,
1616
is_jax_namespace,
17+
is_lazy_array,
1718
is_numpy_array,
1819
is_numpy_namespace,
1920
is_pydata_sparse_array,
@@ -35,6 +36,7 @@
3536
is_dask_namespace,
3637
is_jax_array,
3738
is_jax_namespace,
39+
is_lazy_array,
3840
is_numpy_array,
3941
is_numpy_namespace,
4042
is_pydata_sparse_array,
@@ -56,6 +58,7 @@
5658
"is_dask_namespace",
5759
"is_jax_array",
5860
"is_jax_namespace",
61+
"is_lazy_array",
5962
"is_numpy_array",
6063
"is_numpy_namespace",
6164
"is_pydata_sparse_array",

Diff for: src/array_api_extra/_lib/_utils/_compat.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,6 @@ def is_jax_array(x: object, /) -> bool: ...
3232
def is_numpy_array(x: object, /) -> bool: ...
3333
def is_pydata_sparse_array(x: object, /) -> bool: ...
3434
def is_torch_array(x: object, /) -> bool: ...
35+
def is_lazy_array(x: object, /) -> bool: ...
3536
def is_writeable_array(x: object, /) -> bool: ...
3637
def size(x: Array, /) -> int | None: ...

Diff for: tests/test_at.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from array_api_extra import at
1212
from array_api_extra._lib import Backend
1313
from array_api_extra._lib._at import _AtOp
14-
from array_api_extra._lib._testing import xfail, xp_assert_equal
14+
from array_api_extra._lib._testing import xp_assert_equal
1515
from array_api_extra._lib._utils._compat import array_namespace, is_writeable_array
1616
from array_api_extra._lib._utils._typing import Array, Index
1717
from array_api_extra.testing import lazy_xp_function
@@ -219,7 +219,6 @@ def test_alternate_index_syntax():
219219
def test_incompatible_dtype(
220220
xp: ModuleType,
221221
library: Backend,
222-
request: pytest.FixtureRequest,
223222
op: _AtOp,
224223
copy: bool | None,
225224
bool_mask: bool,
@@ -253,8 +252,6 @@ def test_incompatible_dtype(
253252
z = at_op(x, idx, op, 1.1, copy=copy)
254253

255254
elif library is Backend.DASK:
256-
if op in (_AtOp.MIN, _AtOp.MAX) and bool_mask:
257-
xfail(request, reason="need array-api-compat 1.11")
258255
z = at_op(x, idx, op, 1.1, copy=copy)
259256

260257
elif library is Backend.ARRAY_API_STRICT and op is not _AtOp.SET:

Diff for: tests/test_funcs.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,14 @@
3232

3333
lazy_xp_function(atleast_nd, static_argnames=("ndim", "xp"))
3434
lazy_xp_function(cov, static_argnames="xp")
35-
# FIXME .device attribute https://github.com/data-apis/array-api-compat/pull/238
36-
lazy_xp_function(create_diagonal, jax_jit=False, static_argnames=("offset", "xp"))
35+
lazy_xp_function(create_diagonal, static_argnames=("offset", "xp"))
3736
lazy_xp_function(expand_dims, static_argnames=("axis", "xp"))
3837
lazy_xp_function(kron, static_argnames="xp")
3938
lazy_xp_function(nunique, static_argnames="xp")
4039
lazy_xp_function(pad, static_argnames=("pad_width", "mode", "constant_values", "xp"))
4140
# FIXME calls in1d which calls xp.unique_values without size
4241
lazy_xp_function(setdiff1d, jax_jit=False, static_argnames=("assume_unique", "xp"))
43-
# FIXME .device attribute https://github.com/data-apis/array-api-compat/pull/238
44-
lazy_xp_function(sinc, jax_jit=False, static_argnames="xp")
42+
lazy_xp_function(sinc, static_argnames="xp")
4543

4644

4745
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")

Diff for: vendor_tests/test_vendor.py

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def test_vendor_compat():
1414
is_dask_namespace,
1515
is_jax_array,
1616
is_jax_namespace,
17+
is_lazy_array,
1718
is_numpy_array,
1819
is_numpy_namespace,
1920
is_pydata_sparse_array,
@@ -35,6 +36,7 @@ def test_vendor_compat():
3536
assert not is_dask_namespace(xp)
3637
assert not is_jax_array(x)
3738
assert not is_jax_namespace(xp)
39+
assert not is_lazy_array(x)
3840
assert not is_numpy_array(x)
3941
assert not is_numpy_namespace(xp)
4042
assert not is_pydata_sparse_array(x)

0 commit comments

Comments
 (0)