diff --git a/src/array_api_extra/_lib/_utils/_compat.py b/src/array_api_extra/_lib/_utils/_compat.py index b9997450..c6eec4cd 100644 --- a/src/array_api_extra/_lib/_utils/_compat.py +++ b/src/array_api_extra/_lib/_utils/_compat.py @@ -23,6 +23,7 @@ is_torch_namespace, is_writeable_array, size, + to_device, ) except ImportError: from array_api_compat import ( @@ -45,6 +46,7 @@ is_torch_namespace, is_writeable_array, size, + to_device, ) __all__ = [ @@ -67,4 +69,5 @@ "is_torch_namespace", "is_writeable_array", "size", + "to_device", ] diff --git a/src/array_api_extra/_lib/_utils/_compat.pyi b/src/array_api_extra/_lib/_utils/_compat.pyi index f40d7556..48addda4 100644 --- a/src/array_api_extra/_lib/_utils/_compat.pyi +++ b/src/array_api_extra/_lib/_utils/_compat.pyi @@ -4,6 +4,7 @@ from __future__ import annotations from types import ModuleType +from typing import Any, TypeGuard # TODO import from typing (requires Python >=3.13) from typing_extensions import TypeIs @@ -12,29 +13,33 @@ from ._typing import Array, Device # pylint: disable=missing-class-docstring,unused-argument -class Namespace(ModuleType): - def device(self, x: Array, /) -> Device: ... - def array_namespace( *xs: Array | complex | None, api_version: str | None = None, use_compat: bool | None = None, -) -> Namespace: ... +) -> ModuleType: ... def device(x: Array, /) -> Device: ... def is_array_api_obj(x: object, /) -> TypeIs[Array]: ... -def is_array_api_strict_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ... -def is_cupy_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ... -def is_dask_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ... -def is_jax_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ... -def is_numpy_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ... -def is_pydata_sparse_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ... -def is_torch_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ... -def is_cupy_array(x: object, /) -> TypeIs[Array]: ... -def is_dask_array(x: object, /) -> TypeIs[Array]: ... -def is_jax_array(x: object, /) -> TypeIs[Array]: ... -def is_numpy_array(x: object, /) -> TypeIs[Array]: ... -def is_pydata_sparse_array(x: object, /) -> TypeIs[Array]: ... -def is_torch_array(x: object, /) -> TypeIs[Array]: ... -def is_lazy_array(x: object, /) -> TypeIs[Array]: ... -def is_writeable_array(x: object, /) -> TypeIs[Array]: ... +def is_array_api_strict_namespace(xp: ModuleType, /) -> bool: ... +def is_cupy_namespace(xp: ModuleType, /) -> bool: ... +def is_dask_namespace(xp: ModuleType, /) -> bool: ... +def is_jax_namespace(xp: ModuleType, /) -> bool: ... +def is_numpy_namespace(xp: ModuleType, /) -> bool: ... +def is_pydata_sparse_namespace(xp: ModuleType, /) -> bool: ... +def is_torch_namespace(xp: ModuleType, /) -> bool: ... +def is_cupy_array(x: object, /) -> TypeGuard[Array]: ... +def is_dask_array(x: object, /) -> TypeGuard[Array]: ... +def is_jax_array(x: object, /) -> TypeGuard[Array]: ... +def is_numpy_array(x: object, /) -> TypeGuard[Array]: ... +def is_pydata_sparse_array(x: object, /) -> TypeGuard[Array]: ... +def is_torch_array(x: object, /) -> TypeGuard[Array]: ... +def is_lazy_array(x: object, /) -> TypeGuard[Array]: ... +def is_writeable_array(x: object, /) -> TypeGuard[Array]: ... def size(x: Array, /) -> int | None: ... +def to_device( # type: ignore[explicit-any] + x: Array, + device: Device, # pylint: disable=redefined-outer-name + /, + *, + stream: int | Any | None = None, +) -> Array: ... diff --git a/vendor_tests/test_vendor.py b/vendor_tests/test_vendor.py index 4613edc7..374cba11 100644 --- a/vendor_tests/test_vendor.py +++ b/vendor_tests/test_vendor.py @@ -23,11 +23,12 @@ def test_vendor_compat(): is_torch_namespace, is_writeable_array, size, + to_device, ) x = xp.asarray([1, 2, 3]) assert array_namespace(x) is xp - device(x) + to_device(x, device(x)) assert is_array_api_obj(x) assert is_array_api_strict_namespace(xp) assert not is_cupy_array(x)