From 35e556cb2dad647fe6ce61ace12a411f91ba19c0 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Fri, 9 Feb 2024 02:40:19 -0500 Subject: [PATCH 01/11] Fix dask --- .gitignore | 3 +++ README.md | 28 +++++++++++++++++++++++-- array_api_compat/common/_helpers.py | 4 +++- array_api_compat/dask/array/__init__.py | 6 +++--- array_api_compat/dask/array/linalg.py | 2 +- tests/test_common.py | 3 --- 6 files changed, 36 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index b6e47617..4b61f865 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,6 @@ dmypy.json # Pyre type checker .pyre/ + +# macOS specific iles +.DS_Store diff --git a/README.md b/README.md index 5be86271..0bc1141d 100644 --- a/README.md +++ b/README.md @@ -125,11 +125,11 @@ part of the specification but which are useful for using the array API: [`x.device`](https://data-apis.org/array-api/latest/API_specification/generated/signatures.array_object.array.device.html) in the array API specification. Included because `numpy.ndarray` does not include the `device` attribute and this library does not wrap or extend the - array object. Note that for NumPy, `device(x)` is always `"cpu"`. + array object. Note that for NumPy and dask, `device(x)` is always `"cpu"`. - `to_device(x, device, /, *, stream=None)`: Equivalent to [`x.to_device`](https://data-apis.org/array-api/latest/API_specification/generated/signatures.array_object.array.to_device.html). - Included because neither NumPy's, CuPy's, nor PyTorch's array objects + Included because neither NumPy's, CuPy's, Dask's, nor PyTorch's array objects include this method. For NumPy, this function effectively does nothing since the only supported device is the CPU, but for CuPy, this method supports CuPy CUDA @@ -240,6 +240,30 @@ Unlike the other libraries supported here, JAX array API support is contained entirely in the JAX library. The JAX array API support is tracked at https://github.com/google/jax/issues/18353. +## Dask + +If you're using dask with numpy, many of the same limitations that apply to numpy +will also apply to dask. Besides those differences, other limitations include missing +sort functionality (no `sort` or `argsort`), and limited support for the optional `linalg` +and `fft` extensions. + +In particular, the `fft` namespace is not compliant with the array API spec. Any functions +that you find under the `fft` namespace are the original, unwrapped functions under [`dask.array.fft`](https://docs.dask.org/en/latest/array-api.html#fast-fourier-transforms), which may or may not be Array API compliant. Use at your own risk! + +For `linalg`, several methods are missing, for example: +- `cross` +- `det` +- `eigh` +- `eigvalsh` +- `matrix_power` +- `pinv` +- `slogdet` +- `matrix_norm` +- `matrix_rank` +Other methods may only be partially implemented or return incorrect results at times. + +The minimum supported Dask version is 2023.12.0. + ## Vendoring This library supports vendoring as an installation method. To vendor the diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 5e59c7ea..7de739d4 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -179,7 +179,9 @@ def device(x: Array, /) -> Device: out: device a ``device`` object (see the "Device Support" section of the array API specification). """ - if is_numpy_array(x): + if is_numpy_array(x) or is_dask_array(x): + # TODO: dask technically can support GPU arrays + # Detecting the array backend isn't easy for dask, though, so just return CPU for now return "cpu" if is_jax_array(x): # JAX has .device() as a method, but it is being deprecated so that it diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index d6b5e94e..9df8b1bc 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -22,7 +22,7 @@ from dask.array import ( arctanh as atanh, ) -from dask.array import ( +from numpy import ( bool_ as bool, ) from dask.array import ( @@ -67,7 +67,7 @@ uint64, ) -from ..common._helpers import ( +from ...common._helpers import ( array_namespace, device, get_namespace, @@ -75,7 +75,7 @@ size, to_device, ) -from ..internal import _get_all_public_members +from ..._internal import _get_all_public_members from ._aliases import ( UniqueAllResult, UniqueCountsResult, diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index cc9ac880..67470a02 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -7,7 +7,7 @@ ) from dask.array.linalg import * # noqa: F401, F403 -from .._internal import _get_all_public_members +from ..._internal import _get_all_public_members from ._aliases import ( EighResult, QRResult, diff --git a/tests/test_common.py b/tests/test_common.py index b84dfdde..66076bfe 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -31,9 +31,6 @@ def test_is_xp_array(library, func): @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"]) def test_device(library): - if library == "dask.array": - pytest.xfail("device() needs to be fixed for dask") - xp = import_(library, wrapper=True) # We can't test much for device() and to_device() other than that From 9afed290d751e5d41e07948853702cee1a583629 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Fri, 9 Feb 2024 09:52:41 -0500 Subject: [PATCH 02/11] force more recent jax --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2877bf06..879121f1 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -15,7 +15,7 @@ jobs: - name: Install Dependencies run: | python -m pip install --upgrade pip - python -m pip install pytest numpy torch dask[array] jax[cpu] + python -m pip install pytest numpy torch dask[array] jax[cpu]>=0.4.24 - name: Run Tests run: | From 327a1e2e0b1b1ac498aedd043c3c1a0323c6cf1f Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Fri, 9 Feb 2024 14:05:26 -0500 Subject: [PATCH 03/11] fix ci? --- .github/workflows/tests.yml | 2 +- tests/_helpers.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 879121f1..2877bf06 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -15,7 +15,7 @@ jobs: - name: Install Dependencies run: | python -m pip install --upgrade pip - python -m pip install pytest numpy torch dask[array] jax[cpu]>=0.4.24 + python -m pip install pytest numpy torch dask[array] jax[cpu] - name: Run Tests run: | diff --git a/tests/_helpers.py b/tests/_helpers.py index e05ae86c..23cb5db9 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -8,7 +8,7 @@ def import_(library, wrapper=False): if library == 'cupy': return pytest.importorskip(library) - if 'jax' in library and sys.version_info <= (3, 8): + if 'jax' in library and sys.version_info < (3, 9): pytest.skip('JAX array API support does not support Python 3.8') if wrapper: From 6bcc4a9ec0212a77af4e0c166e60b184e317eb7d Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Sat, 17 Feb 2024 15:39:55 -0500 Subject: [PATCH 04/11] update --- array_api_compat/common/_helpers.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 7de739d4..6a4a715f 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -159,7 +159,16 @@ def _check_device(xp, device): if device not in ["cpu", None]: raise ValueError(f"Unsupported device for NumPy: {device!r}") -# device() is not on numpy.ndarray and to_device() is not on numpy.ndarray +# Placeholder object to represent the dask device +# when the array backend is not the CPU. +# (since it is not easy to tell which device a dask array is on) +class _dask_device: + def __repr__(self): + return "DASK_DEVICE" + +DASK_DEVICE = _dask_device() + +# device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray # or cupy.ndarray. They are not included in array objects of this library # because this library just reuses the respective ndarray classes without # wrapping or subclassing them. These helper functions can be used instead of @@ -179,11 +188,19 @@ def device(x: Array, /) -> Device: out: device a ``device`` object (see the "Device Support" section of the array API specification). """ - if is_numpy_array(x) or is_dask_array(x): - # TODO: dask technically can support GPU arrays - # Detecting the array backend isn't easy for dask, though, so just return CPU for now + if is_numpy_array(x): return "cpu" - if is_jax_array(x): + elif is_dask_array(x): + # Peek at the metadata of the jax array to determine type + try: + import numpy as np + if isinstance(x._meta, np.ndarray): + # Must be on CPU since backed by numpy + return "cpu" + except ImportError: + pass + return DASK_DEVICE + elif is_jax_array(x): # JAX has .device() as a method, but it is being deprecated so that it # can become a property, in accordance with the standard. In order for # this function to not break when JAX makes the flip, we check for From e9e740fa320bd7cb133beeabfafcdf81bd09fe64 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Sun, 18 Feb 2024 20:04:18 -0500 Subject: [PATCH 05/11] fix missing dask? --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 25f6e54c..cbc8bd89 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -from setuptools import setup, find_packages +from setuptools import setup, find_packages, find_namespace_packages with open("README.md", "r") as fh: long_description = fh.read() @@ -8,7 +8,7 @@ setup( name='array_api_compat', version=array_api_compat.__version__, - packages=find_packages(include=['array_api_compat*']), + packages=find_packages(include=['array_api_compat*'] + find_namespace_packages(include=["array_api_compat*"])), author="Consortium for Python Data API Standards", description="A wrapper around NumPy and other array libraries to make them compatible with the Array API standard", long_description=long_description, From a4f1b2c3dee193f41aa7ac53a3eee8aadac21ab9 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Sun, 18 Feb 2024 20:08:01 -0500 Subject: [PATCH 06/11] try something else --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index cbc8bd89..e511ce0a 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -from setuptools import setup, find_packages, find_namespace_packages +from setuptools import setup, find_namespace_packages with open("README.md", "r") as fh: long_description = fh.read() @@ -8,7 +8,7 @@ setup( name='array_api_compat', version=array_api_compat.__version__, - packages=find_packages(include=['array_api_compat*'] + find_namespace_packages(include=["array_api_compat*"])), + packages=find_namespace_packages(include=["array_api_compat*"]), author="Consortium for Python Data API Standards", description="A wrapper around NumPy and other array libraries to make them compatible with the Array API standard", long_description=long_description, From 3f0683705dc46d63b382e8057cea82570203eaff Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Sun, 25 Feb 2024 19:56:59 -0500 Subject: [PATCH 07/11] revert namespace package change --- array_api_compat/dask/__init__.py | 0 setup.py | 4 ++-- 2 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 array_api_compat/dask/__init__.py diff --git a/array_api_compat/dask/__init__.py b/array_api_compat/dask/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/setup.py b/setup.py index e511ce0a..e917cb88 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -from setuptools import setup, find_namespace_packages +from setuptools import setup, find_packages with open("README.md", "r") as fh: long_description = fh.read() @@ -8,7 +8,7 @@ setup( name='array_api_compat', version=array_api_compat.__version__, - packages=find_namespace_packages(include=["array_api_compat*"]), + packages=find_packages(include=["array_api_compat*"]), author="Consortium for Python Data API Standards", description="A wrapper around NumPy and other array libraries to make them compatible with the Array API standard", long_description=long_description, From 924d29728432a125b754468b7616b71fd1ea1af6 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Mon, 26 Feb 2024 12:02:17 -0500 Subject: [PATCH 08/11] wrap svd --- array_api_compat/dask/array/_aliases.py | 16 +++++++++++++--- dask-xfails.txt | 1 - 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 14b27070..80b1e6cf 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: from typing import Optional, Tuple, Union - from ...common._typing import Device, Dtype, ndarray + from ...common._typing import Device, Dtype, Array import dask.array as da @@ -37,7 +37,7 @@ def dask_arange( dtype: Optional[Dtype] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) args = [start] if stop is not None: @@ -99,8 +99,18 @@ def dask_arange( matrix_rank = get_xp(da)(_linalg.matrix_rank) matrix_norm = get_xp(da)(_linalg.matrix_norm) +# Wrap the svd functions to not pass full_matrices to dask +# when full_matrices=False (as that is the defualt behavior for dask), +# and dask doesn't have the full_matrices keyword +_svd = get_xp(da)(_linalg.svd) -def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]: +def svd(x: Array, full_matrices: bool = True, **kwargs) -> SVDResult: + if full_matrices: + return _svd(x, full_matrices=full_matrices, **kwargs) + return _svd(x, **kwargs) + + +def svdvals(x: Array) -> Array: # TODO: can't avoid computing U or V for dask _, s, _ = da.linalg.svd(x) return s diff --git a/dask-xfails.txt b/dask-xfails.txt index 39a1dd8a..340962e6 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -112,7 +112,6 @@ array_api_tests/test_linalg.py::test_solve # missing full_matrics kw # https://github.com/dask/dask/issues/10389 # also only supports 2-d inputs -array_api_tests/test_signatures.py::test_extension_func_signature[linalg.svd] array_api_tests/test_linalg.py::test_svd # Missing dlpack stuff From 2e4c7967674a85af3269ee59cb294853e7a4a17c Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Mon, 26 Feb 2024 16:08:53 -0500 Subject: [PATCH 09/11] actually wrap svd properly --- array_api_compat/dask/array/_aliases.py | 5 ++--- array_api_compat/dask/array/linalg.py | 1 + 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 80b1e6cf..9e3adacf 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -102,12 +102,11 @@ def dask_arange( # Wrap the svd functions to not pass full_matrices to dask # when full_matrices=False (as that is the defualt behavior for dask), # and dask doesn't have the full_matrices keyword -_svd = get_xp(da)(_linalg.svd) def svd(x: Array, full_matrices: bool = True, **kwargs) -> SVDResult: if full_matrices: - return _svd(x, full_matrices=full_matrices, **kwargs) - return _svd(x, **kwargs) + raise ValueError("full_matrics=True is not supported by dask.") + return da.linalg.svd(x, **kwargs) def svdvals(x: Array) -> Array: diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 67470a02..4cf7c622 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -19,6 +19,7 @@ matrix_rank, matrix_transpose, qr, + svd, svdvals, vecdot, vector_norm, From 23eb7649cb5ac08cb231256c7a2491af434769f7 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 26 Feb 2024 15:44:52 -0700 Subject: [PATCH 10/11] Rename DASK_DEVICE to _DASK_DEVICE --- array_api_compat/common/_helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 5cb9800f..bd2372b5 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -166,7 +166,7 @@ class _dask_device: def __repr__(self): return "DASK_DEVICE" -DASK_DEVICE = _dask_device() +_DASK_DEVICE = _dask_device() # device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray # or cupy.ndarray. They are not included in array objects of this library @@ -199,7 +199,7 @@ def device(x: Array, /) -> Device: return "cpu" except ImportError: pass - return DASK_DEVICE + return _DASK_DEVICE elif is_jax_array(x): # JAX has .device() as a method, but it is being deprecated so that it # can become a property, in accordance with the standard. In order for From 1f7e47c7750abf97972610a0f27c4af1a3e0a336 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Mon, 26 Feb 2024 21:44:05 -0500 Subject: [PATCH 11/11] try to get green --- dask-xfails.txt | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/dask-xfails.txt b/dask-xfails.txt index 340962e6..ecde5420 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -78,6 +78,20 @@ array_api_tests/test_linalg.py::test_tensordot # probably same reason for failing as numpy array_api_tests/test_linalg.py::test_trace +# AssertionError: out.dtype=uint64, but should be uint8 [tensordot(uint8, uint8)] +array_api_tests/test_linalg.py::test_linalg_tensordot + +# AssertionError: out.shape=(1,), but should be () [linalg.vector_norm(keepdims=True)] +array_api_tests/test_linalg.py::test_vector_norm + +# ZeroDivisionError in dask's normalize_chunks/auto_chunks internals +array_api_tests/test_linalg.py::test_inv +array_api_tests/test_linalg.py::test_matrix_power + +# did not raise error for invalid shapes +array_api_tests/test_linalg.py::test_matmul +array_api_tests/test_linalg.py::test_linalg_matmul + # Linalg - these don't exist in dask array_api_tests/test_signatures.py::test_extension_func_signature[linalg.cross] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.det] @@ -88,6 +102,7 @@ array_api_tests/test_signatures.py::test_extension_func_signature[linalg.pinv] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.slogdet] array_api_tests/test_linalg.py::test_cross array_api_tests/test_linalg.py::test_det +array_api_tests/test_linalg.py::test_eigh array_api_tests/test_linalg.py::test_eigvalsh array_api_tests/test_linalg.py::test_pinv array_api_tests/test_linalg.py::test_slogdet