Skip to content

Commit fc26a74

Browse files
committed
convert to CPU in xp_assert_equal
1 parent 6ba1f9b commit fc26a74

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

Diff for: src/array_api_extra/_lib/_testing.py

+7
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from ._utils._compat import (
1515
array_namespace,
16+
is_array_api_strict_namespace,
1617
is_cupy_namespace,
1718
is_dask_namespace,
1819
is_pydata_sparse_namespace,
@@ -105,6 +106,12 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
105106
actual = actual.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
106107
desired = desired.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
107108

109+
if is_array_api_strict_namespace(xp):
110+
# __array__ doesn't work on array-api-strict
111+
# We need to convert to the CPU device first
112+
actual = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE")))
113+
desired = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE")))
114+
108115
# JAX uses `np.testing`
109116
np.testing.assert_array_equal(actual, desired, err_msg=err_msg) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
110117

Diff for: tests/test_funcs.py

+2
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,8 @@ def test_xp(self, xp: ModuleType):
724724
def test_device(self, xp: ModuleType, device: Device, equal_nan: bool):
725725
a = xp.asarray([0.0, 0.0, xp.nan], device=device)
726726
b = xp.asarray([1e-9, 1e-4, xp.nan], device=device)
727+
res = isclose(a, b, equal_nan=equal_nan)
728+
assert get_device(res) == device
727729
xp_assert_equal(
728730
isclose(a, b, equal_nan=equal_nan), xp.asarray([True, False, equal_nan])
729731
)

0 commit comments

Comments
 (0)