diff --git a/.gitignore b/.gitignore index b6e47617..4b61f865 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,6 @@ dmypy.json # Pyre type checker .pyre/ + +# macOS specific iles +.DS_Store diff --git a/README.md b/README.md index 51957644..784197dc 100644 --- a/README.md +++ b/README.md @@ -126,11 +126,11 @@ part of the specification but which are useful for using the array API: [`x.device`](https://data-apis.org/array-api/latest/API_specification/generated/signatures.array_object.array.device.html) in the array API specification. Included because `numpy.ndarray` does not include the `device` attribute and this library does not wrap or extend the - array object. Note that for NumPy, `device(x)` is always `"cpu"`. + array object. Note that for NumPy and dask, `device(x)` is always `"cpu"`. - `to_device(x, device, /, *, stream=None)`: Equivalent to [`x.to_device`](https://data-apis.org/array-api/latest/API_specification/generated/signatures.array_object.array.to_device.html). - Included because neither NumPy's, CuPy's, nor PyTorch's array objects + Included because neither NumPy's, CuPy's, Dask's, nor PyTorch's array objects include this method. For NumPy, this function effectively does nothing since the only supported device is the CPU, but for CuPy, this method supports CuPy CUDA @@ -241,6 +241,30 @@ Unlike the other libraries supported here, JAX array API support is contained entirely in the JAX library. The JAX array API support is tracked at https://github.com/google/jax/issues/18353. +## Dask + +If you're using dask with numpy, many of the same limitations that apply to numpy +will also apply to dask. Besides those differences, other limitations include missing +sort functionality (no `sort` or `argsort`), and limited support for the optional `linalg` +and `fft` extensions. + +In particular, the `fft` namespace is not compliant with the array API spec. Any functions +that you find under the `fft` namespace are the original, unwrapped functions under [`dask.array.fft`](https://docs.dask.org/en/latest/array-api.html#fast-fourier-transforms), which may or may not be Array API compliant. Use at your own risk! + +For `linalg`, several methods are missing, for example: +- `cross` +- `det` +- `eigh` +- `eigvalsh` +- `matrix_power` +- `pinv` +- `slogdet` +- `matrix_norm` +- `matrix_rank` +Other methods may only be partially implemented or return incorrect results at times. + +The minimum supported Dask version is 2023.12.0. + ## Vendoring This library supports vendoring as an installation method. To vendor the diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 2fa963f8..bd2372b5 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -159,7 +159,16 @@ def _check_device(xp, device): if device not in ["cpu", None]: raise ValueError(f"Unsupported device for NumPy: {device!r}") -# device() is not on numpy.ndarray and to_device() is not on numpy.ndarray +# Placeholder object to represent the dask device +# when the array backend is not the CPU. +# (since it is not easy to tell which device a dask array is on) +class _dask_device: + def __repr__(self): + return "DASK_DEVICE" + +_DASK_DEVICE = _dask_device() + +# device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray # or cupy.ndarray. They are not included in array objects of this library # because this library just reuses the respective ndarray classes without # wrapping or subclassing them. These helper functions can be used instead of @@ -181,7 +190,17 @@ def device(x: Array, /) -> Device: """ if is_numpy_array(x): return "cpu" - if is_jax_array(x): + elif is_dask_array(x): + # Peek at the metadata of the jax array to determine type + try: + import numpy as np + if isinstance(x._meta, np.ndarray): + # Must be on CPU since backed by numpy + return "cpu" + except ImportError: + pass + return _DASK_DEVICE + elif is_jax_array(x): # JAX has .device() as a method, but it is being deprecated so that it # can become a property, in accordance with the standard. In order for # this function to not break when JAX makes the flip, we check for diff --git a/array_api_compat/dask/__init__.py b/array_api_compat/dask/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 844cdf91..94d938a4 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -36,7 +36,8 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Optional, Union - from ...common._typing import ndarray, Device, Dtype + + from ...common._typing import Device, Dtype, Array import dask.array as da @@ -60,7 +61,7 @@ def _dask_arange( dtype: Optional[Dtype] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) args = [start] if stop is not None: diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index f2dd80cd..03f16e89 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -1,6 +1,5 @@ from __future__ import annotations -from dask.array.linalg import svd from ...common import _linalg from ..._internal import get_xp @@ -16,8 +15,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Union, Tuple - from ...common._typing import ndarray + from ...common._typing import Array # cupy.linalg doesn't have __all__. If it is added, replace this with # @@ -39,7 +37,16 @@ matrix_rank = get_xp(da)(_linalg.matrix_rank) matrix_norm = get_xp(da)(_linalg.matrix_norm) -def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]: + +# Wrap the svd functions to not pass full_matrices to dask +# when full_matrices=False (as that is the default behavior for dask), +# and dask doesn't have the full_matrices keyword +def svd(x: Array, full_matrices: bool = True, **kwargs) -> SVDResult: + if full_matrices: + raise ValueError("full_matrics=True is not supported by dask.") + return da.linalg.svd(x, **kwargs) + +def svdvals(x: Array) -> Array: # TODO: can't avoid computing U or V for dask _, s, _ = svd(x) return s diff --git a/dask-xfails.txt b/dask-xfails.txt index 39a1dd8a..ecde5420 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -78,6 +78,20 @@ array_api_tests/test_linalg.py::test_tensordot # probably same reason for failing as numpy array_api_tests/test_linalg.py::test_trace +# AssertionError: out.dtype=uint64, but should be uint8 [tensordot(uint8, uint8)] +array_api_tests/test_linalg.py::test_linalg_tensordot + +# AssertionError: out.shape=(1,), but should be () [linalg.vector_norm(keepdims=True)] +array_api_tests/test_linalg.py::test_vector_norm + +# ZeroDivisionError in dask's normalize_chunks/auto_chunks internals +array_api_tests/test_linalg.py::test_inv +array_api_tests/test_linalg.py::test_matrix_power + +# did not raise error for invalid shapes +array_api_tests/test_linalg.py::test_matmul +array_api_tests/test_linalg.py::test_linalg_matmul + # Linalg - these don't exist in dask array_api_tests/test_signatures.py::test_extension_func_signature[linalg.cross] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.det] @@ -88,6 +102,7 @@ array_api_tests/test_signatures.py::test_extension_func_signature[linalg.pinv] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.slogdet] array_api_tests/test_linalg.py::test_cross array_api_tests/test_linalg.py::test_det +array_api_tests/test_linalg.py::test_eigh array_api_tests/test_linalg.py::test_eigvalsh array_api_tests/test_linalg.py::test_pinv array_api_tests/test_linalg.py::test_slogdet @@ -112,7 +127,6 @@ array_api_tests/test_linalg.py::test_solve # missing full_matrics kw # https://github.com/dask/dask/issues/10389 # also only supports 2-d inputs -array_api_tests/test_signatures.py::test_extension_func_signature[linalg.svd] array_api_tests/test_linalg.py::test_svd # Missing dlpack stuff diff --git a/setup.py b/setup.py index 25f6e54c..e917cb88 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setup( name='array_api_compat', version=array_api_compat.__version__, - packages=find_packages(include=['array_api_compat*']), + packages=find_packages(include=["array_api_compat*"]), author="Consortium for Python Data API Standards", description="A wrapper around NumPy and other array libraries to make them compatible with the Array API standard", long_description=long_description, diff --git a/tests/test_common.py b/tests/test_common.py index b84dfdde..66076bfe 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -31,9 +31,6 @@ def test_is_xp_array(library, func): @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"]) def test_device(library): - if library == "dask.array": - pytest.xfail("device() needs to be fixed for dask") - xp = import_(library, wrapper=True) # We can't test much for device() and to_device() other than that