Skip to content

Commit e4989ff

Browse files
crusaderkyNeilGirdhar
authored andcommitted
TST: xfail_xp_backend (#132)
* TST: `xfail_xp_backend` * nit * nit * nit * pixi update
1 parent e38480d commit e4989ff

File tree

8 files changed

+1792
-1865
lines changed

8 files changed

+1792
-1865
lines changed

Diff for: pixi.lock

+1,652-1,802
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: pyproject.toml

+4-1
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,10 @@ xfail_strict = true
181181
filterwarnings = ["error"]
182182
log_cli_level = "INFO"
183183
testpaths = ["tests"]
184-
markers = ["skip_xp_backend(library, *, reason=None): Skip test for a specific backend"]
184+
markers = [
185+
"skip_xp_backend(library, *, reason=None): Skip test for a specific backend",
186+
"xfail_xp_backend(library, *, reason=None): Xfail test for a specific backend",
187+
]
185188

186189

187190
# Coverage

Diff for: src/array_api_extra/_lib/_testing.py

+20
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import math
99
from types import ModuleType
1010

11+
import pytest
12+
1113
from ._utils._compat import (
1214
array_namespace,
1315
is_cupy_namespace,
@@ -170,3 +172,21 @@ def xp_assert_close(
170172
np.testing.assert_allclose(
171173
actual, desired, rtol=rtol, atol=atol, err_msg=err_msg
172174
)
175+
176+
177+
def xfail(request: pytest.FixtureRequest, reason: str) -> None:
178+
"""
179+
XFAIL the currently running test.
180+
181+
Unlike ``pytest.xfail``, allow rest of test to execute instead of immediately
182+
halting it, so that it may result in a XPASS.
183+
xref https://github.com/pandas-dev/pandas/issues/38902
184+
185+
Parameters
186+
----------
187+
request : pytest.FixtureRequest
188+
``request`` argument of the test function.
189+
reason : str
190+
Reason for the expected failure.
191+
"""
192+
request.node.add_marker(pytest.mark.xfail(reason=reason))

Diff for: tests/conftest.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22

33
from collections.abc import Callable
44
from contextlib import suppress
5-
from functools import wraps
5+
from functools import partial, wraps
66
from types import ModuleType
77
from typing import ParamSpec, TypeVar, cast
88

99
import numpy as np
1010
import pytest
1111

1212
from array_api_extra._lib import Backend
13+
from array_api_extra._lib._testing import xfail
1314
from array_api_extra._lib._utils._compat import array_namespace
1415
from array_api_extra._lib._utils._compat import device as get_device
1516
from array_api_extra._lib._utils._typing import Device
@@ -32,16 +33,20 @@ def library(request: pytest.FixtureRequest) -> Backend: # numpydoc ignore=PR01,
3233
"""
3334
elem = cast(Backend, request.param)
3435

35-
for marker in request.node.iter_markers("skip_xp_backend"):
36-
skip_library = marker.kwargs.get("library") or marker.args[0] # type: ignore[no-untyped-usage]
37-
if not isinstance(skip_library, Backend):
38-
msg = "argument of skip_xp_backend must be a Backend enum"
39-
raise TypeError(msg)
40-
if skip_library == elem:
41-
reason = skip_library.value
42-
with suppress(KeyError):
43-
reason += ":" + cast(str, marker.kwargs["reason"])
44-
pytest.skip(reason=reason)
36+
for marker_name, skip_or_xfail in (
37+
("skip_xp_backend", pytest.skip),
38+
("xfail_xp_backend", partial(xfail, request)),
39+
):
40+
for marker in request.node.iter_markers(marker_name):
41+
library = marker.kwargs.get("library") or marker.args[0] # type: ignore[no-untyped-usage]
42+
if not isinstance(library, Backend):
43+
msg = f"argument of {marker_name} must be a Backend enum"
44+
raise TypeError(msg)
45+
if library == elem:
46+
reason = library.value
47+
with suppress(KeyError):
48+
reason += ":" + cast(str, marker.kwargs["reason"])
49+
skip_or_xfail(reason=reason)
4550

4651
return elem
4752

Diff for: tests/test_at.py

+29-12
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from array_api_extra import at
1212
from array_api_extra._lib import Backend
1313
from array_api_extra._lib._at import _AtOp
14-
from array_api_extra._lib._testing import xp_assert_equal
14+
from array_api_extra._lib._testing import xfail, xp_assert_equal
1515
from array_api_extra._lib._utils._compat import array_namespace, is_writeable_array
1616
from array_api_extra._lib._utils._typing import Array, Index
1717
from array_api_extra.testing import lazy_xp_function
@@ -80,10 +80,12 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
8080
@pytest.mark.parametrize(
8181
("kwargs", "expect_copy"),
8282
[
83-
({"copy": True}, True),
84-
({"copy": False}, False),
85-
({"copy": None}, None), # Behavior is backend-specific
86-
({}, None), # Test that the copy parameter defaults to None
83+
pytest.param({"copy": True}, True, id="copy=True"),
84+
pytest.param({"copy": False}, False, id="copy=False"),
85+
# Behavior is backend-specific
86+
pytest.param({"copy": None}, None, id="copy=None"),
87+
# Test that the copy parameter defaults to None
88+
pytest.param({}, None, id="no copy kwarg"),
8789
],
8890
)
8991
@pytest.mark.parametrize(
@@ -109,10 +111,10 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
109111
True,
110112
True,
111113
marks=(
112-
pytest.mark.skip_xp_backend(
114+
pytest.mark.skip_xp_backend( # test passes when copy=False
113115
Backend.JAX, reason="bool mask update with shaped rhs"
114116
),
115-
pytest.mark.skip_xp_backend(
117+
pytest.mark.xfail_xp_backend(
116118
Backend.DASK, reason="bool mask update with shaped rhs"
117119
),
118120
),
@@ -177,7 +179,12 @@ def test_alternate_index_syntax():
177179
@pytest.mark.parametrize("bool_mask", [False, True])
178180
@pytest.mark.parametrize("op", list(_AtOp))
179181
def test_incompatible_dtype(
180-
xp: ModuleType, library: Backend, op: _AtOp, copy: bool | None, bool_mask: bool
182+
xp: ModuleType,
183+
library: Backend,
184+
request: pytest.FixtureRequest,
185+
op: _AtOp,
186+
copy: bool | None,
187+
bool_mask: bool,
181188
):
182189
"""Test that at() replicates the backend's behaviour for
183190
in-place operations with incompatible dtypes.
@@ -208,8 +215,8 @@ def test_incompatible_dtype(
208215
z = at_op(x, idx, op, 1.1, copy=copy)
209216

210217
elif library is Backend.DASK:
211-
if op in (_AtOp.MIN, _AtOp.MAX):
212-
pytest.xfail(reason="need array-api-compat 1.11")
218+
if op in (_AtOp.MIN, _AtOp.MAX) and bool_mask:
219+
xfail(request, reason="need array-api-compat 1.11")
213220
z = at_op(x, idx, op, 1.1, copy=copy)
214221

215222
elif library is Backend.ARRAY_API_STRICT and op is not _AtOp.SET:
@@ -234,8 +241,18 @@ def test_bool_mask_nd(xp: ModuleType):
234241
xp_assert_equal(z, xp.asarray([[0, 2, 3], [4, 0, 0]]))
235242

236243

237-
@pytest.mark.skip_xp_backend(Backend.DASK, reason="FIXME need scipy's lazywhere")
238-
@pytest.mark.parametrize("bool_mask", [False, True])
244+
@pytest.mark.parametrize(
245+
"bool_mask",
246+
[
247+
False,
248+
pytest.param(
249+
True,
250+
marks=pytest.mark.xfail_xp_backend(
251+
Backend.DASK, reason="FIXME need scipy's lazywhere"
252+
),
253+
),
254+
],
255+
)
239256
def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
240257
x = xp.asarray([math.inf, 1.0, 2.0])
241258
idx = ~xp.isinf(x) if bool_mask else slice(1, None)

Diff for: tests/test_funcs.py

+45-24
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
lazy_xp_function(sinc, jax_jit=False, static_argnames="xp")
4242

4343

44-
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no expand_dims")
44+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")
4545
class TestAtLeastND:
4646
def test_0D(self, xp: ModuleType):
4747
x = xp.asarray(1.0)
@@ -108,12 +108,12 @@ def test_device(self, xp: ModuleType, device: Device):
108108
assert get_device(atleast_nd(x, ndim=2)) == device
109109

110110
def test_xp(self, xp: ModuleType):
111-
x = xp.asarray(1)
112-
y = atleast_nd(x, ndim=0, xp=xp)
113-
xp_assert_equal(y, x)
111+
x = xp.asarray(1.0)
112+
y = atleast_nd(x, ndim=1, xp=xp)
113+
xp_assert_equal(y, xp.ones((1,)))
114114

115115

116-
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype")
116+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
117117
class TestCov:
118118
def test_basic(self, xp: ModuleType):
119119
xp_assert_close(
@@ -152,16 +152,16 @@ def test_device(self, xp: ModuleType, device: Device):
152152
x = xp.asarray([1, 2, 3], device=device)
153153
assert get_device(cov(x)) == device
154154

155-
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="explicit xp")
155+
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
156156
def test_xp(self, xp: ModuleType):
157157
xp_assert_close(
158158
cov(xp.asarray([[0.0, 2.0], [1.0, 1.0], [2.0, 0.0]]).T, xp=xp),
159159
xp.asarray([[1.0, -1.0], [-1.0, 1.0]], dtype=xp.float64),
160160
)
161161

162162

163-
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no device")
164163
class TestCreateDiagonal:
164+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
165165
def test_1d(self, xp: ModuleType):
166166
# from np.diag tests
167167
vals = 100 * xp.arange(5, dtype=xp.float64)
@@ -177,6 +177,7 @@ def test_1d(self, xp: ModuleType):
177177
xp_assert_equal(create_diagonal(vals, offset=2), b)
178178
xp_assert_equal(create_diagonal(vals, offset=-2), c)
179179

180+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
180181
@pytest.mark.parametrize("n", range(1, 10))
181182
@pytest.mark.parametrize("offset", range(1, 10))
182183
def test_create_diagonal(self, xp: ModuleType, n: int, offset: int):
@@ -196,20 +197,22 @@ def test_2d(self, xp: ModuleType):
196197
with pytest.raises(ValueError, match="1-dimensional"):
197198
create_diagonal(xp.asarray([[1]]))
198199

200+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
199201
def test_device(self, xp: ModuleType, device: Device):
200202
x = xp.asarray([1, 2, 3], device=device)
201203
assert get_device(create_diagonal(x)) == device
202204

205+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
203206
def test_xp(self, xp: ModuleType):
204207
x = xp.asarray([1, 2])
205208
y = create_diagonal(x, xp=xp)
206209
xp_assert_equal(y, xp.asarray([[1, 0], [0, 2]]))
207210

208211

209-
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no expand_dims")
210212
class TestExpandDims:
211-
@pytest.mark.skip_xp_backend(Backend.DASK, reason="tuple index out of range")
212-
@pytest.mark.skip_xp_backend(Backend.TORCH, reason="tuple index out of range")
213+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")
214+
@pytest.mark.xfail_xp_backend(Backend.DASK, reason="tuple index out of range")
215+
@pytest.mark.xfail_xp_backend(Backend.TORCH, reason="tuple index out of range")
213216
def test_functionality(self, xp: ModuleType):
214217
def _squeeze_all(b: Array) -> Array:
215218
"""Mimics `np.squeeze(b)`. `xpx.squeeze`?"""
@@ -225,6 +228,7 @@ def _squeeze_all(b: Array) -> Array:
225228
assert b.shape[axis] == 1
226229
assert _squeeze_all(b).shape == s
227230

231+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")
228232
def test_axis_tuple(self, xp: ModuleType):
229233
a = xp.empty((3, 3, 3))
230234
assert expand_dims(a, axis=(0, 1, 2)).shape == (1, 1, 1, 3, 3, 3)
@@ -257,17 +261,19 @@ def test_positive_negative_repeated(self, xp: ModuleType):
257261
with pytest.raises(ValueError, match="Duplicate dimensions"):
258262
expand_dims(a, axis=(3, -3))
259263

264+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")
260265
def test_device(self, xp: ModuleType, device: Device):
261266
x = xp.asarray([1, 2, 3], device=device)
262267
assert get_device(expand_dims(x, axis=0)) == device
263268

269+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")
264270
def test_xp(self, xp: ModuleType):
265271
x = xp.asarray([1, 2, 3])
266272
y = expand_dims(x, axis=(0, 1, 2), xp=xp)
267273
assert y.shape == (1, 1, 1, 3)
268274

269275

270-
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype")
276+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
271277
class TestIsClose:
272278
# FIXME use lazywhere to avoid warnings on inf
273279
@pytest.mark.filterwarnings("ignore:invalid value encountered")
@@ -402,7 +408,7 @@ def test_none_shape_bool(self, xp: ModuleType):
402408
xp_assert_equal(isclose(a, b), xp.asarray([True, False]))
403409

404410
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
405-
@pytest.mark.skip_xp_backend(Backend.TORCH, reason="Array API 2024.12 support")
411+
@pytest.mark.xfail_xp_backend(Backend.TORCH, reason="Array API 2024.12 support")
406412
def test_python_scalar(self, xp: ModuleType):
407413
a = xp.asarray([0.0, 0.1], dtype=xp.float32)
408414
xp_assert_equal(isclose(a, 0.0), xp.asarray([True, False]))
@@ -425,7 +431,7 @@ def test_xp(self, xp: ModuleType):
425431
xp_assert_equal(isclose(a, b, xp=xp), xp.asarray([True, False]))
426432

427433

428-
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no expand_dims")
434+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")
429435
class TestKron:
430436
def test_basic(self, xp: ModuleType):
431437
# Using 0-dimensional array
@@ -526,7 +532,7 @@ def test_xp(self, xp: ModuleType):
526532
xp_assert_equal(nunique(a, xp=xp), xp.asarray(3))
527533

528534

529-
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no arange, no device")
535+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange, no device")
530536
class TestPad:
531537
def test_simple(self, xp: ModuleType):
532538
a = xp.arange(1, 4)
@@ -576,10 +582,24 @@ def test_sequence_of_tuples_width(self, xp: ModuleType):
576582
assert padded.shape == (4, 4)
577583

578584

579-
@pytest.mark.skip_xp_backend(Backend.DASK, reason="no argsort")
580-
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no device kwarg in asarray")
585+
assume_unique = pytest.mark.parametrize(
586+
"assume_unique",
587+
[
588+
True,
589+
pytest.param(
590+
False,
591+
marks=pytest.mark.xfail_xp_backend(
592+
Backend.DASK, reason="NaN-shaped arrays"
593+
),
594+
),
595+
],
596+
)
597+
598+
599+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in asarray()")
581600
class TestSetDiff1D:
582-
@pytest.mark.skip_xp_backend(
601+
@pytest.mark.xfail_xp_backend(Backend.DASK, reason="NaN-shaped arrays")
602+
@pytest.mark.xfail_xp_backend(
583603
Backend.TORCH, reason="index_select not implemented for uint32"
584604
)
585605
def test_setdiff1d(self, xp: ModuleType):
@@ -608,7 +628,7 @@ def test_assume_unique(self, xp: ModuleType):
608628
actual = setdiff1d(x1, x2, assume_unique=True)
609629
xp_assert_equal(actual, expected)
610630

611-
@pytest.mark.parametrize("assume_unique", [True, False])
631+
@assume_unique
612632
@pytest.mark.parametrize("shape1", [(), (1,), (1, 1)])
613633
@pytest.mark.parametrize("shape2", [(), (1,), (1, 1)])
614634
def test_shapes(
@@ -623,8 +643,8 @@ def test_shapes(
623643
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
624644
xp_assert_equal(actual, xp.empty((0,)))
625645

646+
@assume_unique
626647
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
627-
@pytest.mark.parametrize("assume_unique", [True, False])
628648
def test_python_scalar(self, xp: ModuleType, assume_unique: bool):
629649
# Test no dtype promotion to xp.asarray(x2); use x1.dtype
630650
x1 = xp.asarray([3, 1, 2], dtype=xp.int16)
@@ -645,21 +665,22 @@ def test_all_python_scalars(self, assume_unique: bool):
645665
with pytest.raises(TypeError, match="Unrecognized"):
646666
setdiff1d(0, 0, assume_unique=assume_unique)
647667

648-
def test_device(self, xp: ModuleType, device: Device):
668+
@assume_unique
669+
def test_device(self, xp: ModuleType, device: Device, assume_unique: bool):
649670
x1 = xp.asarray([3, 8, 20], device=device)
650671
x2 = xp.asarray([2, 3, 4], device=device)
651-
assert get_device(setdiff1d(x1, x2)) == device
672+
assert get_device(setdiff1d(x1, x2, assume_unique=assume_unique)) == device
652673

653-
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="explicit xp")
674+
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
654675
def test_xp(self, xp: ModuleType):
655676
x1 = xp.asarray([3, 8, 20])
656677
x2 = xp.asarray([2, 3, 4])
657678
expected = xp.asarray([8, 20])
658-
actual = setdiff1d(x1, x2, xp=xp)
679+
actual = setdiff1d(x1, x2, assume_unique=True, xp=xp)
659680
xp_assert_equal(actual, expected)
660681

661682

662-
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype")
683+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
663684
class TestSinc:
664685
def test_simple(self, xp: ModuleType):
665686
xp_assert_equal(sinc(xp.asarray(0.0)), xp.asarray(1.0))

0 commit comments

Comments
 (0)