|
13 | 13 |
|
14 | 14 | from ._utils._compat import (
|
15 | 15 | array_namespace,
|
| 16 | + is_array_api_strict_namespace, |
16 | 17 | is_cupy_namespace,
|
17 | 18 | is_dask_namespace,
|
18 | 19 | is_pydata_sparse_namespace,
|
@@ -105,8 +106,18 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
|
105 | 106 | actual = actual.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
|
106 | 107 | desired = desired.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
|
107 | 108 |
|
108 |
| - # JAX uses `np.testing` |
109 |
| - np.testing.assert_array_equal(actual, desired, err_msg=err_msg) # pyright: ignore[reportUnknownArgumentType] |
| 109 | + actual_np = None |
| 110 | + desired_np = None |
| 111 | + if is_array_api_strict_namespace(xp): |
| 112 | + # __array__ doesn't work on array-api-strict device arrays |
| 113 | + # We need to convert to the CPU device first |
| 114 | + actual_np = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE"))) |
| 115 | + desired_np = np.asarray(xp.asarray(desired, device=xp.Device("CPU_DEVICE"))) |
| 116 | + |
| 117 | + # JAX/Dask arrays work with `np.testing` |
| 118 | + actual_np = actual if actual_np is None else actual_np |
| 119 | + desired_np = desired if desired_np is None else desired_np |
| 120 | + np.testing.assert_array_equal(actual_np, desired_np, err_msg=err_msg) # pyright: ignore[reportUnknownArgumentType] |
110 | 121 |
|
111 | 122 |
|
112 | 123 | def xp_assert_close(
|
@@ -169,14 +180,25 @@ def xp_assert_close(
|
169 | 180 | actual = actual.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
|
170 | 181 | desired = desired.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
|
171 | 182 |
|
172 |
| - # JAX uses `np.testing` |
| 183 | + actual_np = None |
| 184 | + desired_np = None |
| 185 | + if is_array_api_strict_namespace(xp): |
| 186 | + # __array__ doesn't work on array-api-strict device arrays |
| 187 | + # We need to convert to the CPU device first |
| 188 | + actual_np = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE"))) |
| 189 | + desired_np = np.asarray(xp.asarray(desired, device=xp.Device("CPU_DEVICE"))) |
| 190 | + |
| 191 | + # JAX/Dask arrays work with `np.testing` |
| 192 | + actual_np = actual if actual_np is None else actual_np |
| 193 | + desired_np = desired if desired_np is None else desired_np |
| 194 | + |
173 | 195 | assert isinstance(rtol, float)
|
174 | 196 | np.testing.assert_allclose( # pyright: ignore[reportCallIssue]
|
175 |
| - actual, # pyright: ignore[reportArgumentType] |
176 |
| - desired, # pyright: ignore[reportArgumentType] |
| 197 | + actual_np, # type: ignore[arg-type] # pyright: ignore[reportArgumentType] |
| 198 | + desired_np, # type: ignore[arg-type] # pyright: ignore[reportArgumentType] |
177 | 199 | rtol=rtol,
|
178 | 200 | atol=atol,
|
179 |
| - err_msg=err_msg, # type: ignore[call-overload] |
| 201 | + err_msg=err_msg, |
180 | 202 | )
|
181 | 203 |
|
182 | 204 |
|
|
0 commit comments