Skip to content

Commit c3861a7

Browse files
committed
Test none shapes
1 parent 5f54c92 commit c3861a7

File tree

3 files changed

+58
-14
lines changed

3 files changed

+58
-14
lines changed

Diff for: src/array_api_extra/_lib/_testing.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
Note that this is private API; don't expect it to be stable.
55
"""
66

7+
import math
78
from types import ModuleType
89

910
from ._utils._compat import (
1011
array_namespace,
1112
is_cupy_namespace,
13+
is_dask_namespace,
1214
is_pydata_sparse_namespace,
1315
is_torch_namespace,
1416
)
@@ -40,8 +42,16 @@ def _check_ns_shape_dtype(
4042
msg = f"namespaces do not match: {actual_xp} != f{desired_xp}"
4143
assert actual_xp == desired_xp, msg
4244

43-
msg = f"shapes do not match: {actual.shape} != f{desired.shape}"
44-
assert actual.shape == desired.shape, msg
45+
actual_shape = actual.shape
46+
desired_shape = desired.shape
47+
if is_dask_namespace(desired_xp):
48+
if any(math.isnan(i) for i in actual_shape):
49+
actual_shape = actual.compute().shape
50+
if any(math.isnan(i) for i in desired_shape):
51+
desired_shape = desired.compute().shape
52+
53+
msg = f"shapes do not match: {actual_shape} != f{desired_shape}"
54+
assert actual_shape == desired_shape, msg
4555

4656
msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
4757
assert actual.dtype == desired.dtype, msg

Diff for: tests/test_funcs.py

+14
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,20 @@ def test_bool_dtype(self, xp: ModuleType):
366366
isclose(xp.asarray(True), b, atol=1), xp.asarray([True, True, True])
367367
)
368368

369+
def test_none_shape(self, xp: ModuleType):
370+
a = xp.asarray([1, 5, 0])
371+
b = xp.asarray([1, 4, 2])
372+
b = b[a < 5]
373+
a = a[a < 5]
374+
xp_assert_equal(isclose(a, b), xp.asarray([True, False]))
375+
376+
def test_none_shape_bool(self, xp: ModuleType):
377+
a = xp.asarray([True, True, False])
378+
b = xp.asarray([True, False, True])
379+
b = b[a]
380+
a = a[a]
381+
xp_assert_equal(isclose(a, b), xp.asarray([True, False]))
382+
369383
def test_xp(self, xp: ModuleType):
370384
a = xp.asarray([0.0, 0.0])
371385
b = xp.asarray([1e-9, 1e-4])

Diff for: tests/test_testing.py

+32-12
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
# mypy: disable-error-code=no-any-decorated
1111
# pyright: reportUnknownParameterType=false,reportMissingParameterType=false
1212

13-
14-
@pytest.mark.parametrize(
13+
param_assert_equal_close = pytest.mark.parametrize(
1514
"func",
1615
[
1716
xp_assert_equal,
@@ -21,6 +20,9 @@
2120
),
2221
],
2322
)
23+
24+
25+
@param_assert_equal_close
2426
def test_assert_close_equal_basic(xp: ModuleType, func: Callable[..., None]): # type: ignore[no-any-explicit]
2527
func(xp.asarray(0), xp.asarray(0))
2628
func(xp.asarray([1, 2]), xp.asarray([1, 2]))
@@ -40,16 +42,7 @@ def test_assert_close_equal_basic(xp: ModuleType, func: Callable[..., None]): #
4042

4143
@pytest.mark.skip_xp_backend(Backend.NUMPY, reason="test other ns vs. numpy")
4244
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="test other ns vs. numpy")
43-
@pytest.mark.parametrize(
44-
"func",
45-
[
46-
xp_assert_equal,
47-
pytest.param(
48-
xp_assert_close,
49-
marks=pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype"),
50-
),
51-
],
52-
)
45+
@param_assert_equal_close
5346
def test_assert_close_equal_namespace(xp: ModuleType, func: Callable[..., None]): # type: ignore[no-any-explicit]
5447
with pytest.raises(AssertionError):
5548
func(xp.asarray(0), np.asarray(0))
@@ -68,3 +61,30 @@ def test_assert_close_tolerance(xp: ModuleType):
6861
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), atol=3)
6962
with pytest.raises(AssertionError):
7063
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), atol=1)
64+
65+
66+
@param_assert_equal_close
67+
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no bool indexing by sparse arrays")
68+
def test_assert_close_equal_none_shape(xp: ModuleType, func: Callable[..., None]): # type: ignore[no-any-explicit]
69+
"""On dask and other lazy backends, test that a shape with NaN's or None's
70+
can be compared to a real shape.
71+
"""
72+
a = xp.asarray([1, 2])
73+
a = a[a > 1]
74+
75+
func(a, xp.asarray([2]))
76+
with pytest.raises(AssertionError):
77+
func(a, xp.asarray([2, 3]))
78+
with pytest.raises(AssertionError):
79+
func(a, xp.asarray(2))
80+
with pytest.raises(AssertionError):
81+
func(a, xp.asarray([3]))
82+
83+
# Swap actual and desired
84+
func(xp.asarray([2]), a)
85+
with pytest.raises(AssertionError):
86+
func(xp.asarray([2, 3]), a)
87+
with pytest.raises(AssertionError):
88+
func(xp.asarray(2), a)
89+
with pytest.raises(AssertionError):
90+
func(xp.asarray([3]), a)

0 commit comments

Comments
 (0)