Skip to content

Commit 1236e3a

Browse files
committed
fixes
1 parent 403a6d3 commit 1236e3a

File tree

1 file changed

+23
-13
lines changed

1 file changed

+23
-13
lines changed

Diff for: src/array_api_extra/_lib/_testing.py

+23-13
Original file line numberDiff line numberDiff line change
@@ -106,19 +106,18 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
106106
actual = actual.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
107107
desired = desired.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
108108

109+
actual_np = None
110+
desired_np = None
109111
if is_array_api_strict_namespace(xp):
110112
# __array__ doesn't work on array-api-strict device arrays
111113
# We need to convert to the CPU device first
112-
actual = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE")))
113-
desired = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE")))
114+
actual_np = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE")))
115+
desired_np = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE")))
114116

115-
# JAX uses `np.testing`
116-
if is_array_api_strict_namespace(xp):
117-
# Have to move to CPU for array API strict devices before
118-
# we're allowed to convert into numpy
119-
actual = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE")))
120-
desired = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE")))
121-
np.testing.assert_array_equal(actual, desired, err_msg=err_msg) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
117+
# JAX/Dask arrays work with `np.testing`
118+
actual_np = actual if actual_np is None else actual_np
119+
desired_np = desired if desired_np is None else desired_np
120+
np.testing.assert_array_equal(actual_np, desired_np, err_msg=err_msg) # pyright: ignore[reportUnknownArgumentType]
122121

123122

124123
def xp_assert_close(
@@ -181,14 +180,25 @@ def xp_assert_close(
181180
actual = actual.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
182181
desired = desired.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
183182

184-
# JAX uses `np.testing`
183+
actual_np = None
184+
desired_np = None
185+
if is_array_api_strict_namespace(xp):
186+
# __array__ doesn't work on array-api-strict device arrays
187+
# We need to convert to the CPU device first
188+
actual_np = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE")))
189+
desired_np = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE")))
190+
191+
# JAX/Dask arrays work with `np.testing`
192+
actual_np = actual if actual_np is None else actual_np
193+
desired_np = desired if desired_np is None else desired_np
194+
185195
assert isinstance(rtol, float)
186196
np.testing.assert_allclose( # pyright: ignore[reportCallIssue]
187-
actual, # pyright: ignore[reportArgumentType]
188-
desired, # pyright: ignore[reportArgumentType]
197+
actual_np, # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
198+
desired_np, # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
189199
rtol=rtol,
190200
atol=atol,
191-
err_msg=err_msg, # type: ignore[call-overload]
201+
err_msg=err_msg,
192202
)
193203

194204

0 commit comments

Comments
 (0)