@@ -106,19 +106,18 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
106
106
actual = actual .todense () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
107
107
desired = desired .todense () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
108
108
109
+ actual_np = None
110
+ desired_np = None
109
111
if is_array_api_strict_namespace (xp ):
110
112
# __array__ doesn't work on array-api-strict device arrays
111
113
# We need to convert to the CPU device first
112
- actual = np .asarray (xp .asarray (actual , device = xp .Device ("CPU_DEVICE" )))
113
- desired = np .asarray (xp .asarray (actual , device = xp .Device ("CPU_DEVICE" )))
114
+ actual_np = np .asarray (xp .asarray (actual , device = xp .Device ("CPU_DEVICE" )))
115
+ desired_np = np .asarray (xp .asarray (actual , device = xp .Device ("CPU_DEVICE" )))
114
116
115
- # JAX uses `np.testing`
116
- if is_array_api_strict_namespace (xp ):
117
- # Have to move to CPU for array API strict devices before
118
- # we're allowed to convert into numpy
119
- actual = np .asarray (xp .asarray (actual , device = xp .Device ("CPU_DEVICE" )))
120
- desired = np .asarray (xp .asarray (actual , device = xp .Device ("CPU_DEVICE" )))
121
- np .testing .assert_array_equal (actual , desired , err_msg = err_msg ) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
117
+ # JAX/Dask arrays work with `np.testing`
118
+ actual_np = actual if actual_np is None else actual_np
119
+ desired_np = desired if desired_np is None else desired_np
120
+ np .testing .assert_array_equal (actual_np , desired_np , err_msg = err_msg ) # pyright: ignore[reportUnknownArgumentType]
122
121
123
122
124
123
def xp_assert_close (
@@ -181,14 +180,25 @@ def xp_assert_close(
181
180
actual = actual .todense () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
182
181
desired = desired .todense () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
183
182
184
- # JAX uses `np.testing`
183
+ actual_np = None
184
+ desired_np = None
185
+ if is_array_api_strict_namespace (xp ):
186
+ # __array__ doesn't work on array-api-strict device arrays
187
+ # We need to convert to the CPU device first
188
+ actual_np = np .asarray (xp .asarray (actual , device = xp .Device ("CPU_DEVICE" )))
189
+ desired_np = np .asarray (xp .asarray (actual , device = xp .Device ("CPU_DEVICE" )))
190
+
191
+ # JAX/Dask arrays work with `np.testing`
192
+ actual_np = actual if actual_np is None else actual_np
193
+ desired_np = desired if desired_np is None else desired_np
194
+
185
195
assert isinstance (rtol , float )
186
196
np .testing .assert_allclose ( # pyright: ignore[reportCallIssue]
187
- actual , # pyright: ignore[reportArgumentType]
188
- desired , # pyright: ignore[reportArgumentType]
197
+ actual_np , # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
198
+ desired_np , # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
189
199
rtol = rtol ,
190
200
atol = atol ,
191
- err_msg = err_msg , # type: ignore[call-overload]
201
+ err_msg = err_msg ,
192
202
)
193
203
194
204
0 commit comments