Skip to content

Commit 370eee0

Browse files
lithomas1lucascolley
authored andcommitted
BUG: isclose: fix multidevice for equal_nan=True (data-apis#177)
* BUG: Fix isclose multidevice * test the right way * fix pre-commit * convert to CPU in xp_assert_equal * fixes * fix tests --------- Co-authored-by: Lucas Colley <[email protected]>
1 parent fc2c7a7 commit 370eee0

File tree

3 files changed

+39
-7
lines changed

3 files changed

+39
-7
lines changed

Diff for: src/array_api_extra/_lib/_funcs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,7 @@ def isclose(
549549
xp=xp,
550550
)
551551
if equal_nan:
552-
out = xp.where(xp.isnan(a) & xp.isnan(b), xp.asarray(True), out)
552+
out = xp.where(xp.isnan(a) & xp.isnan(b), True, out)
553553
return out
554554

555555
if xp.isdtype(a.dtype, "bool") or xp.isdtype(b.dtype, "bool"):

Diff for: src/array_api_extra/_lib/_testing.py

+28-6
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from ._utils._compat import (
1515
array_namespace,
16+
is_array_api_strict_namespace,
1617
is_cupy_namespace,
1718
is_dask_namespace,
1819
is_pydata_sparse_namespace,
@@ -105,8 +106,18 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
105106
actual = actual.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
106107
desired = desired.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
107108

108-
# JAX uses `np.testing`
109-
np.testing.assert_array_equal(actual, desired, err_msg=err_msg) # pyright: ignore[reportUnknownArgumentType]
109+
actual_np = None
110+
desired_np = None
111+
if is_array_api_strict_namespace(xp):
112+
# __array__ doesn't work on array-api-strict device arrays
113+
# We need to convert to the CPU device first
114+
actual_np = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE")))
115+
desired_np = np.asarray(xp.asarray(desired, device=xp.Device("CPU_DEVICE")))
116+
117+
# JAX/Dask arrays work with `np.testing`
118+
actual_np = actual if actual_np is None else actual_np
119+
desired_np = desired if desired_np is None else desired_np
120+
np.testing.assert_array_equal(actual_np, desired_np, err_msg=err_msg) # pyright: ignore[reportUnknownArgumentType]
110121

111122

112123
def xp_assert_close(
@@ -169,14 +180,25 @@ def xp_assert_close(
169180
actual = actual.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
170181
desired = desired.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
171182

172-
# JAX uses `np.testing`
183+
actual_np = None
184+
desired_np = None
185+
if is_array_api_strict_namespace(xp):
186+
# __array__ doesn't work on array-api-strict device arrays
187+
# We need to convert to the CPU device first
188+
actual_np = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE")))
189+
desired_np = np.asarray(xp.asarray(desired, device=xp.Device("CPU_DEVICE")))
190+
191+
# JAX/Dask arrays work with `np.testing`
192+
actual_np = actual if actual_np is None else actual_np
193+
desired_np = desired if desired_np is None else desired_np
194+
173195
assert isinstance(rtol, float)
174196
np.testing.assert_allclose( # pyright: ignore[reportCallIssue]
175-
actual, # pyright: ignore[reportArgumentType]
176-
desired, # pyright: ignore[reportArgumentType]
197+
actual_np, # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
198+
desired_np, # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
177199
rtol=rtol,
178200
atol=atol,
179-
err_msg=err_msg, # type: ignore[call-overload]
201+
err_msg=err_msg,
180202
)
181203

182204

Diff for: tests/test_funcs.py

+10
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,16 @@ def test_xp(self, xp: ModuleType):
716716
b = xp.asarray([1e-9, 1e-4])
717717
xp_assert_equal(isclose(a, b, xp=xp), xp.asarray([True, False]))
718718

719+
@pytest.mark.parametrize("equal_nan", [True, False])
720+
def test_device(self, xp: ModuleType, device: Device, equal_nan: bool):
721+
a = xp.asarray([0.0, 0.0, xp.nan], device=device)
722+
b = xp.asarray([1e-9, 1e-4, xp.nan], device=device)
723+
res = isclose(a, b, equal_nan=equal_nan)
724+
assert get_device(res) == device
725+
xp_assert_equal(
726+
isclose(a, b, equal_nan=equal_nan), xp.asarray([True, False, equal_nan])
727+
)
728+
719729

720730
class TestKron:
721731
def test_basic(self, xp: ModuleType):

0 commit comments

Comments
 (0)