From 66964ab1fbe0318852c0656e2d9a7dd431074316 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Tue, 25 Mar 2025 09:49:29 -0400 Subject: [PATCH 1/6] BUG: Fix isclose multidevice --- src/array_api_extra/_lib/_backends.py | 1 + src/array_api_extra/_lib/_funcs.py | 2 +- src/array_api_extra/_lib/_testing.py | 6 ++++++ tests/conftest.py | 9 +++++++-- tests/test_funcs.py | 5 +++-- 5 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/array_api_extra/_lib/_backends.py b/src/array_api_extra/_lib/_backends.py index f044281..ff5b23c 100644 --- a/src/array_api_extra/_lib/_backends.py +++ b/src/array_api_extra/_lib/_backends.py @@ -24,6 +24,7 @@ class Backend(Enum): # numpydoc ignore=PR01,PR02 # type: ignore[no-subclass-an """ ARRAY_API_STRICT = "array_api_strict", _compat.is_array_api_strict_namespace + ARRAY_API_STRICT_DEVICE1 = "array_api_strict", _compat.is_array_api_strict_namespace NUMPY = "numpy", _compat.is_numpy_namespace NUMPY_READONLY = "numpy_readonly", _compat.is_numpy_namespace CUPY = "cupy", _compat.is_cupy_namespace diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index d0b6738..efe2f37 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -549,7 +549,7 @@ def isclose( xp=xp, ) if equal_nan: - out = xp.where(xp.isnan(a) & xp.isnan(b), xp.asarray(True), out) + out = xp.where(xp.isnan(a) & xp.isnan(b), True, out) return out if xp.isdtype(a.dtype, "bool") or xp.isdtype(b.dtype, "bool"): diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index f592eb4..f6b891b 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -10,6 +10,7 @@ from typing import cast import pytest +from array_api_compat import is_array_api_strict_namespace from ._utils._compat import ( array_namespace, @@ -106,6 +107,11 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None: desired = desired.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] # JAX uses `np.testing` + if is_array_api_strict_namespace(xp): + # Have to move to CPU for array API strict devices before + # we're allowed to convert into numpy + actual = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE"))) + desired = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE"))) np.testing.assert_array_equal(actual, desired, err_msg=err_msg) # pyright: ignore[reportUnknownArgumentType] diff --git a/tests/conftest.py b/tests/conftest.py index 54e2a23..892f739 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -113,7 +113,7 @@ def xp( The current array namespace. """ if library == Backend.NUMPY_READONLY: - return NumPyReadOnly() # type: ignore[return-value] # pyright: ignore[reportReturnType] + return NumPyReadOnly(), None # type: ignore[return-value] # pyright: ignore[reportReturnType] xp = pytest.importorskip(library.value) # Possibly wrap module with array_api_compat xp = array_namespace(xp.empty(0)) @@ -131,7 +131,12 @@ def xp( # suppress unused-ignore to run mypy in -e lint as well as -e dev jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call,unused-ignore] - return xp + device = None + if library == Backend.ARRAY_API_STRICT_DEVICE1: + import array_api_strict + + device = array_api_strict.Device("device1") + return xp, device @pytest.fixture(params=[Backend.DASK]) # Can select the test with `pytest -k dask` diff --git a/tests/test_funcs.py b/tests/test_funcs.py index b93cc7c..f608662 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -629,8 +629,9 @@ def test_some_inf(self, xp: ModuleType): xp_assert_equal(actual, xp.asarray([True, True, True, False, False])) def test_equal_nan(self, xp: ModuleType): - a = xp.asarray([xp.nan, xp.nan, 1.0]) - b = xp.asarray([xp.nan, 1.0, xp.nan]) + xp, device = xp + a = xp.asarray([xp.nan, xp.nan, 1.0], device=device) + b = xp.asarray([xp.nan, 1.0, xp.nan], device=device) xp_assert_equal(isclose(a, b), xp.asarray([False, False, False])) xp_assert_equal(isclose(a, b, equal_nan=True), xp.asarray([True, False, False])) From 2288f7a1f919c2899521d002d46792f1a4e14e96 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Tue, 25 Mar 2025 10:30:51 -0400 Subject: [PATCH 2/6] test the right way --- src/array_api_extra/_lib/_backends.py | 1 - src/array_api_extra/_lib/_testing.py | 3 +-- tests/conftest.py | 9 ++------- tests/test_funcs.py | 10 +++++++--- 4 files changed, 10 insertions(+), 13 deletions(-) diff --git a/src/array_api_extra/_lib/_backends.py b/src/array_api_extra/_lib/_backends.py index ff5b23c..f044281 100644 --- a/src/array_api_extra/_lib/_backends.py +++ b/src/array_api_extra/_lib/_backends.py @@ -24,7 +24,6 @@ class Backend(Enum): # numpydoc ignore=PR01,PR02 # type: ignore[no-subclass-an """ ARRAY_API_STRICT = "array_api_strict", _compat.is_array_api_strict_namespace - ARRAY_API_STRICT_DEVICE1 = "array_api_strict", _compat.is_array_api_strict_namespace NUMPY = "numpy", _compat.is_numpy_namespace NUMPY_READONLY = "numpy_readonly", _compat.is_numpy_namespace CUPY = "cupy", _compat.is_cupy_namespace diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index f6b891b..bb73c54 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -10,7 +10,6 @@ from typing import cast import pytest -from array_api_compat import is_array_api_strict_namespace from ._utils._compat import ( array_namespace, @@ -112,7 +111,7 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None: # we're allowed to convert into numpy actual = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE"))) desired = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE"))) - np.testing.assert_array_equal(actual, desired, err_msg=err_msg) # pyright: ignore[reportUnknownArgumentType] + np.testing.assert_array_equal(actual, desired, err_msg=err_msg) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] def xp_assert_close( diff --git a/tests/conftest.py b/tests/conftest.py index 892f739..54e2a23 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -113,7 +113,7 @@ def xp( The current array namespace. """ if library == Backend.NUMPY_READONLY: - return NumPyReadOnly(), None # type: ignore[return-value] # pyright: ignore[reportReturnType] + return NumPyReadOnly() # type: ignore[return-value] # pyright: ignore[reportReturnType] xp = pytest.importorskip(library.value) # Possibly wrap module with array_api_compat xp = array_namespace(xp.empty(0)) @@ -131,12 +131,7 @@ def xp( # suppress unused-ignore to run mypy in -e lint as well as -e dev jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call,unused-ignore] - device = None - if library == Backend.ARRAY_API_STRICT_DEVICE1: - import array_api_strict - - device = array_api_strict.Device("device1") - return xp, device + return xp @pytest.fixture(params=[Backend.DASK]) # Can select the test with `pytest -k dask` diff --git a/tests/test_funcs.py b/tests/test_funcs.py index f608662..e0421d6 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -629,9 +629,8 @@ def test_some_inf(self, xp: ModuleType): xp_assert_equal(actual, xp.asarray([True, True, True, False, False])) def test_equal_nan(self, xp: ModuleType): - xp, device = xp - a = xp.asarray([xp.nan, xp.nan, 1.0], device=device) - b = xp.asarray([xp.nan, 1.0, xp.nan], device=device) + a = xp.asarray([xp.nan, xp.nan, 1.0]) + b = xp.asarray([xp.nan, 1.0, xp.nan]) xp_assert_equal(isclose(a, b), xp.asarray([False, False, False])) xp_assert_equal(isclose(a, b, equal_nan=True), xp.asarray([True, False, False])) @@ -717,6 +716,11 @@ def test_xp(self, xp: ModuleType): b = xp.asarray([1e-9, 1e-4]) xp_assert_equal(isclose(a, b, xp=xp), xp.asarray([True, False])) + @pytest.mark.parametrize("equal_nan", [True, False]) + def test_device(self, xp: ModuleType, device: Device, equal_nan: bool): + a = xp.asarray([0.0, 0.0, xp.nan], device=device) + b = xp.asarray([1e-9, 1e-4, xp.nan], device=device) + xp_assert_equal(isclose(a, b, equal_nan=equal_nan), xp.asarray([True, False, equal_nan])) class TestKron: def test_basic(self, xp: ModuleType): From 380a81cac12f19142297fc631fea21c5283c0cef Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Tue, 25 Mar 2025 11:22:35 -0400 Subject: [PATCH 3/6] fix pre-commit --- tests/test_funcs.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index e0421d6..3edf733 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -720,7 +720,10 @@ def test_xp(self, xp: ModuleType): def test_device(self, xp: ModuleType, device: Device, equal_nan: bool): a = xp.asarray([0.0, 0.0, xp.nan], device=device) b = xp.asarray([1e-9, 1e-4, xp.nan], device=device) - xp_assert_equal(isclose(a, b, equal_nan=equal_nan), xp.asarray([True, False, equal_nan])) + xp_assert_equal( + isclose(a, b, equal_nan=equal_nan), xp.asarray([True, False, equal_nan]) + ) + class TestKron: def test_basic(self, xp: ModuleType): From 403a6d3aa5c814c1af2882555e1bf789a7078bc8 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Tue, 25 Mar 2025 13:03:22 -0400 Subject: [PATCH 4/6] convert to CPU in xp_assert_equal --- src/array_api_extra/_lib/_testing.py | 7 +++++++ tests/test_funcs.py | 2 ++ 2 files changed, 9 insertions(+) diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index bb73c54..d4519a9 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -13,6 +13,7 @@ from ._utils._compat import ( array_namespace, + is_array_api_strict_namespace, is_cupy_namespace, is_dask_namespace, is_pydata_sparse_namespace, @@ -105,6 +106,12 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None: actual = actual.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] desired = desired.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + if is_array_api_strict_namespace(xp): + # __array__ doesn't work on array-api-strict device arrays + # We need to convert to the CPU device first + actual = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE"))) + desired = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE"))) + # JAX uses `np.testing` if is_array_api_strict_namespace(xp): # Have to move to CPU for array API strict devices before diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 3edf733..46591ed 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -720,6 +720,8 @@ def test_xp(self, xp: ModuleType): def test_device(self, xp: ModuleType, device: Device, equal_nan: bool): a = xp.asarray([0.0, 0.0, xp.nan], device=device) b = xp.asarray([1e-9, 1e-4, xp.nan], device=device) + res = isclose(a, b, equal_nan=equal_nan) + assert get_device(res) == device xp_assert_equal( isclose(a, b, equal_nan=equal_nan), xp.asarray([True, False, equal_nan]) ) From 1236e3a1d3eff66ac9a1d3fcea14ea7cb3ae3f21 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Tue, 25 Mar 2025 17:40:18 +0000 Subject: [PATCH 5/6] fixes --- src/array_api_extra/_lib/_testing.py | 36 ++++++++++++++++++---------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index d4519a9..caae491 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -106,19 +106,18 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None: actual = actual.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] desired = desired.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + actual_np = None + desired_np = None if is_array_api_strict_namespace(xp): # __array__ doesn't work on array-api-strict device arrays # We need to convert to the CPU device first - actual = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE"))) - desired = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE"))) + actual_np = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE"))) + desired_np = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE"))) - # JAX uses `np.testing` - if is_array_api_strict_namespace(xp): - # Have to move to CPU for array API strict devices before - # we're allowed to convert into numpy - actual = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE"))) - desired = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE"))) - np.testing.assert_array_equal(actual, desired, err_msg=err_msg) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + # JAX/Dask arrays work with `np.testing` + actual_np = actual if actual_np is None else actual_np + desired_np = desired if desired_np is None else desired_np + np.testing.assert_array_equal(actual_np, desired_np, err_msg=err_msg) # pyright: ignore[reportUnknownArgumentType] def xp_assert_close( @@ -181,14 +180,25 @@ def xp_assert_close( actual = actual.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] desired = desired.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] - # JAX uses `np.testing` + actual_np = None + desired_np = None + if is_array_api_strict_namespace(xp): + # __array__ doesn't work on array-api-strict device arrays + # We need to convert to the CPU device first + actual_np = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE"))) + desired_np = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE"))) + + # JAX/Dask arrays work with `np.testing` + actual_np = actual if actual_np is None else actual_np + desired_np = desired if desired_np is None else desired_np + assert isinstance(rtol, float) np.testing.assert_allclose( # pyright: ignore[reportCallIssue] - actual, # pyright: ignore[reportArgumentType] - desired, # pyright: ignore[reportArgumentType] + actual_np, # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + desired_np, # type: ignore[arg-type] # pyright: ignore[reportArgumentType] rtol=rtol, atol=atol, - err_msg=err_msg, # type: ignore[call-overload] + err_msg=err_msg, ) From 66155f652ba5dc6d5d31b1522753bbd118f07ca5 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Tue, 25 Mar 2025 17:52:06 +0000 Subject: [PATCH 6/6] fix tests --- src/array_api_extra/_lib/_testing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index caae491..e5ec16a 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -112,7 +112,7 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None: # __array__ doesn't work on array-api-strict device arrays # We need to convert to the CPU device first actual_np = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE"))) - desired_np = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE"))) + desired_np = np.asarray(xp.asarray(desired, device=xp.Device("CPU_DEVICE"))) # JAX/Dask arrays work with `np.testing` actual_np = actual if actual_np is None else actual_np @@ -186,7 +186,7 @@ def xp_assert_close( # __array__ doesn't work on array-api-strict device arrays # We need to convert to the CPU device first actual_np = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE"))) - desired_np = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE"))) + desired_np = np.asarray(xp.asarray(desired, device=xp.Device("CPU_DEVICE"))) # JAX/Dask arrays work with `np.testing` actual_np = actual if actual_np is None else actual_np