From 6997c91fa3558f5746d87999b0771dc3bd695d43 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 29 Jan 2025 13:12:49 +0000 Subject: [PATCH] MAINT: import everything from array-api-compat --- src/array_api_extra/_lib/_utils/_compat.py | 15 +++++++++++++++ src/array_api_extra/_lib/_utils/_compat.pyi | 5 +++++ vendor_tests/test_vendor.py | 14 ++++++++++++++ 3 files changed, 34 insertions(+) diff --git a/src/array_api_extra/_lib/_utils/_compat.py b/src/array_api_extra/_lib/_utils/_compat.py index f4def9f3..34958149 100644 --- a/src/array_api_extra/_lib/_utils/_compat.py +++ b/src/array_api_extra/_lib/_utils/_compat.py @@ -6,14 +6,19 @@ from ...._array_api_compat_vendor import ( array_namespace, device, + is_array_api_obj, is_array_api_strict_namespace, + is_cupy_array, is_cupy_namespace, is_dask_array, is_dask_namespace, is_jax_array, is_jax_namespace, + is_numpy_array, is_numpy_namespace, + is_pydata_sparse_array, is_pydata_sparse_namespace, + is_torch_array, is_torch_namespace, is_writeable_array, size, @@ -22,14 +27,19 @@ from array_api_compat import ( array_namespace, device, + is_array_api_obj, is_array_api_strict_namespace, + is_cupy_array, is_cupy_namespace, is_dask_array, is_dask_namespace, is_jax_array, is_jax_namespace, + is_numpy_array, is_numpy_namespace, + is_pydata_sparse_array, is_pydata_sparse_namespace, + is_torch_array, is_torch_namespace, is_writeable_array, size, @@ -38,14 +48,19 @@ __all__ = [ "array_namespace", "device", + "is_array_api_obj", "is_array_api_strict_namespace", + "is_cupy_array", "is_cupy_namespace", "is_dask_array", "is_dask_namespace", "is_jax_array", "is_jax_namespace", + "is_numpy_array", "is_numpy_namespace", + "is_pydata_sparse_array", "is_pydata_sparse_namespace", + "is_torch_array", "is_torch_namespace", "is_writeable_array", "size", diff --git a/src/array_api_extra/_lib/_utils/_compat.pyi b/src/array_api_extra/_lib/_utils/_compat.pyi index e409091e..5c8b6260 100644 --- a/src/array_api_extra/_lib/_utils/_compat.pyi +++ b/src/array_api_extra/_lib/_utils/_compat.pyi @@ -18,6 +18,7 @@ def array_namespace( use_compat: bool | None = None, ) -> ArrayModule: ... def device(x: Array, /) -> Device: ... +def is_array_api_obj(x: object, /) -> bool: ... def is_array_api_strict_namespace(xp: ModuleType, /) -> bool: ... def is_cupy_namespace(xp: ModuleType, /) -> bool: ... def is_dask_namespace(xp: ModuleType, /) -> bool: ... @@ -25,7 +26,11 @@ def is_jax_namespace(xp: ModuleType, /) -> bool: ... def is_numpy_namespace(xp: ModuleType, /) -> bool: ... def is_pydata_sparse_namespace(xp: ModuleType, /) -> bool: ... def is_torch_namespace(xp: ModuleType, /) -> bool: ... +def is_cupy_array(x: object, /) -> bool: ... def is_dask_array(x: object, /) -> bool: ... def is_jax_array(x: object, /) -> bool: ... +def is_numpy_array(x: object, /) -> bool: ... +def is_pydata_sparse_array(x: object, /) -> bool: ... +def is_torch_array(x: object, /) -> bool: ... def is_writeable_array(x: object, /) -> bool: ... def size(x: Array, /) -> int | None: ... diff --git a/vendor_tests/test_vendor.py b/vendor_tests/test_vendor.py index 9402217b..7aaa9eba 100644 --- a/vendor_tests/test_vendor.py +++ b/vendor_tests/test_vendor.py @@ -6,12 +6,19 @@ def test_vendor_compat(): from ._array_api_compat_vendor import ( # type: ignore[attr-defined] array_namespace, device, + is_array_api_obj, + is_array_api_strict_namespace, + is_cupy_array, is_cupy_namespace, is_dask_array, is_dask_namespace, is_jax_array, is_jax_namespace, + is_numpy_array, + is_numpy_namespace, + is_pydata_sparse_array, is_pydata_sparse_namespace, + is_torch_array, is_torch_namespace, is_writeable_array, size, @@ -20,12 +27,19 @@ def test_vendor_compat(): x = xp.asarray([1, 2, 3]) assert array_namespace(x) is xp device(x) + assert is_array_api_obj(x) + assert is_array_api_strict_namespace(xp) + assert not is_cupy_array(x) assert not is_cupy_namespace(xp) assert not is_dask_array(x) assert not is_dask_namespace(xp) assert not is_jax_array(x) assert not is_jax_namespace(xp) + assert not is_numpy_array(x) + assert not is_numpy_namespace(xp) + assert not is_pydata_sparse_array(x) assert not is_pydata_sparse_namespace(xp) + assert not is_torch_array(x) assert not is_torch_namespace(xp) assert is_writeable_array(x) assert size(x) == 3