Skip to content

Commit 636a121

Browse files
committed
Test none shapes
1 parent 5f54c92 commit 636a121

File tree

3 files changed

+48
-4
lines changed

3 files changed

+48
-4
lines changed

Diff for: src/array_api_extra/_lib/_testing.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
Note that this is private API; don't expect it to be stable.
55
"""
66

7+
import math
78
from types import ModuleType
89

910
from ._utils._compat import (
1011
array_namespace,
1112
is_cupy_namespace,
13+
is_dask_namespace,
1214
is_pydata_sparse_namespace,
1315
is_torch_namespace,
1416
)
@@ -40,8 +42,16 @@ def _check_ns_shape_dtype(
4042
msg = f"namespaces do not match: {actual_xp} != f{desired_xp}"
4143
assert actual_xp == desired_xp, msg
4244

43-
msg = f"shapes do not match: {actual.shape} != f{desired.shape}"
44-
assert actual.shape == desired.shape, msg
45+
actual_shape = actual.shape
46+
desired_shape = desired.shape
47+
if is_dask_namespace(desired_xp):
48+
if any(math.isnan(i) for i in actual_shape):
49+
actual_shape = actual.compute().shape
50+
if any(math.isnan(i) for i in desired_shape):
51+
desired_shape = desired.compute().shape
52+
53+
msg = f"shapes do not match: {actual_shape} != f{desired_shape}"
54+
assert actual_shape == desired_shape, msg
4555

4656
msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
4757
assert actual.dtype == desired.dtype, msg

Diff for: tests/test_funcs.py

+14
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,20 @@ def test_bool_dtype(self, xp: ModuleType):
366366
isclose(xp.asarray(True), b, atol=1), xp.asarray([True, True, True])
367367
)
368368

369+
def test_none_shape(self, xp: ModuleType):
370+
a = xp.asarray([1, 5, 0])
371+
b = xp.asarray([1, 4, 2])
372+
b = b[a < 5]
373+
a = a[a < 5]
374+
xp_assert_equal(isclose(a, b), xp.asarray([True, False]))
375+
376+
def test_none_shape_bool(self, xp: ModuleType):
377+
a = xp.asarray([True, True, False])
378+
b = xp.asarray([True, False, True])
379+
b = b[a]
380+
a = a[a]
381+
xp_assert_equal(isclose(a, b), xp.asarray([True, False]))
382+
369383
def test_xp(self, xp: ModuleType):
370384
a = xp.asarray([0.0, 0.0])
371385
b = xp.asarray([1e-9, 1e-4])

Diff for: tests/test_testing.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
# mypy: disable-error-code=no-any-decorated
1111
# pyright: reportUnknownParameterType=false,reportMissingParameterType=false
1212

13-
14-
@pytest.mark.parametrize(
13+
param_assert_equal_close = pytest.mark.parametrize(
1514
"func",
1615
[
1716
xp_assert_equal,
@@ -21,6 +20,9 @@
2120
),
2221
],
2322
)
23+
24+
25+
@param_assert_equal_close
2426
def test_assert_close_equal_basic(xp: ModuleType, func: Callable[..., None]): # type: ignore[no-any-explicit]
2527
func(xp.asarray(0), xp.asarray(0))
2628
func(xp.asarray([1, 2]), xp.asarray([1, 2]))
@@ -68,3 +70,21 @@ def test_assert_close_tolerance(xp: ModuleType):
6870
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), atol=3)
6971
with pytest.raises(AssertionError):
7072
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), atol=1)
73+
74+
75+
@param_assert_equal_close
76+
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no bool indexing by sparse arrays")
77+
def test_assert_close_equal_none_shape(xp: ModuleType, func: Callable[..., None]): # type: ignore[no-any-explicit]
78+
"""On dask and other lazy backends, test that a shape with NaN's or None's
79+
can be compared to a real shape.
80+
"""
81+
a = xp.asarray([1, 2])
82+
a = a[a > 1]
83+
84+
func(a, xp.asarray([2]))
85+
with pytest.raises(AssertionError):
86+
func(a, xp.asarray([2, 3]))
87+
with pytest.raises(AssertionError):
88+
func(a, xp.asarray(2))
89+
with pytest.raises(AssertionError):
90+
func(a, xp.asarray([3]))

0 commit comments

Comments
 (0)