Skip to content

Commit 560d189

Browse files
authored
Merge pull request #89 from lithomas1/fix-dask
Fix dask
2 parents 40603a9 + 1f7e47c commit 560d189

File tree

9 files changed

+80
-15
lines changed

9 files changed

+80
-15
lines changed

Diff for: .gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,6 @@ dmypy.json
127127

128128
# Pyre type checker
129129
.pyre/
130+
131+
# macOS specific files
132+
.DS_Store

Diff for: README.md

+26-2
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,11 @@ part of the specification but which are useful for using the array API:
126126
[`x.device`](https://data-apis.org/array-api/latest/API_specification/generated/signatures.array_object.array.device.html)
127127
in the array API specification. Included because `numpy.ndarray` does not
128128
include the `device` attribute and this library does not wrap or extend the
129-
array object. Note that for NumPy, `device(x)` is always `"cpu"`.
129+
array object. Note that for NumPy, and for Dask arrays backed by NumPy, `device(x)` is always `"cpu"`.
130130

131131
- `to_device(x, device, /, *, stream=None)`: Equivalent to
132132
[`x.to_device`](https://data-apis.org/array-api/latest/API_specification/generated/signatures.array_object.array.to_device.html).
133-
Included because neither NumPy's, CuPy's, nor PyTorch's array objects
133+
Included because neither NumPy's, CuPy's, Dask's, nor PyTorch's array objects
134134
include this method. For NumPy, this function effectively does nothing since
135135
the only supported device is the CPU, but for CuPy, this method supports
136136
CuPy CUDA
@@ -241,6 +241,30 @@ Unlike the other libraries supported here, JAX array API support is contained
241241
entirely in the JAX library. The JAX array API support is tracked at
242242
https://github.com/google/jax/issues/18353.
243243

244+
## Dask
245+
246+
If you're using Dask with NumPy, many of the same limitations that apply to NumPy
247+
will also apply to Dask. Besides those differences, other limitations include missing
248+
sort functionality (no `sort` or `argsort`), and limited support for the optional `linalg`
249+
and `fft` extensions.
250+
251+
In particular, the `fft` namespace is not compliant with the array API spec. Any functions
252+
that you find under the `fft` namespace are the original, unwrapped functions under [`dask.array.fft`](https://docs.dask.org/en/latest/array-api.html#fast-fourier-transforms), which may or may not be Array API compliant. Use at your own risk!
253+
254+
For `linalg`, several methods are missing, for example:
255+
- `cross`
256+
- `det`
257+
- `eigh`
258+
- `eigvalsh`
259+
- `matrix_power`
260+
- `pinv`
261+
- `slogdet`
262+
- `matrix_norm`
263+
- `matrix_rank`
264+
Other methods may only be partially implemented or return incorrect results at times.
265+
266+
The minimum supported Dask version is 2023.12.0.
267+
244268
## Vendoring
245269

246270
This library supports vendoring as an installation method. To vendor the

Diff for: array_api_compat/common/_helpers.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,16 @@ def _check_device(xp, device):
159159
if device not in ["cpu", None]:
160160
raise ValueError(f"Unsupported device for NumPy: {device!r}")
161161

162-
# device() is not on numpy.ndarray and to_device() is not on numpy.ndarray
162+
# Placeholder object to represent the dask device
163+
# when the array backend is not the CPU.
164+
# (since it is not easy to tell which device a dask array is on)
165+
class _dask_device:
166+
def __repr__(self):
167+
return "DASK_DEVICE"
168+
169+
_DASK_DEVICE = _dask_device()
170+
171+
# device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray
163172
# or cupy.ndarray. They are not included in array objects of this library
164173
# because this library just reuses the respective ndarray classes without
165174
# wrapping or subclassing them. These helper functions can be used instead of
@@ -181,7 +190,17 @@ def device(x: Array, /) -> Device:
181190
"""
182191
if is_numpy_array(x):
183192
return "cpu"
184-
if is_jax_array(x):
193+
elif is_dask_array(x):
194+
# Peek at the metadata of the dask array to determine type
195+
try:
196+
import numpy as np
197+
if isinstance(x._meta, np.ndarray):
198+
# Must be on CPU since backed by numpy
199+
return "cpu"
200+
except ImportError:
201+
pass
202+
return _DASK_DEVICE
203+
elif is_jax_array(x):
185204
# JAX has .device() as a method, but it is being deprecated so that it
186205
# can become a property, in accordance with the standard. In order for
187206
# this function to not break when JAX makes the flip, we check for

Diff for: array_api_compat/dask/__init__.py

Whitespace-only changes.

Diff for: array_api_compat/dask/array/_aliases.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@
3636
from typing import TYPE_CHECKING
3737
if TYPE_CHECKING:
3838
from typing import Optional, Union
39-
from ...common._typing import ndarray, Device, Dtype
39+
40+
from ...common._typing import Device, Dtype, Array
4041

4142
import dask.array as da
4243

@@ -60,7 +61,7 @@ def _dask_arange(
6061
dtype: Optional[Dtype] = None,
6162
device: Optional[Device] = None,
6263
**kwargs,
63-
) -> ndarray:
64+
) -> Array:
6465
_check_device(xp, device)
6566
args = [start]
6667
if stop is not None:

Diff for: array_api_compat/dask/array/linalg.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
from dask.array.linalg import svd
43
from ...common import _linalg
54
from ..._internal import get_xp
65

@@ -16,8 +15,7 @@
1615

1716
from typing import TYPE_CHECKING
1817
if TYPE_CHECKING:
19-
from typing import Union, Tuple
20-
from ...common._typing import ndarray
18+
from ...common._typing import Array
2119

2220
# cupy.linalg doesn't have __all__. If it is added, replace this with
2321
#
@@ -39,7 +37,16 @@
3937
matrix_rank = get_xp(da)(_linalg.matrix_rank)
4038
matrix_norm = get_xp(da)(_linalg.matrix_norm)
4139

42-
def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]:
40+
41+
# Wrap the svd functions to not pass full_matrices to dask
42+
# when full_matrices=False (as that is the default behavior for dask),
43+
# and dask doesn't have the full_matrices keyword
44+
def svd(x: Array, full_matrices: bool = True, **kwargs) -> SVDResult:
45+
if full_matrices:
46+
raise ValueError("full_matrices=True is not supported by dask.")
47+
return da.linalg.svd(x, **kwargs)
48+
49+
def svdvals(x: Array) -> Array:
4350
# TODO: can't avoid computing U or V for dask
4451
_, s, _ = svd(x)
4552
return s

Diff for: dask-xfails.txt

+15-1
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,20 @@ array_api_tests/test_linalg.py::test_tensordot
7878
# probably same reason for failing as numpy
7979
array_api_tests/test_linalg.py::test_trace
8080

81+
# AssertionError: out.dtype=uint64, but should be uint8 [tensordot(uint8, uint8)]
82+
array_api_tests/test_linalg.py::test_linalg_tensordot
83+
84+
# AssertionError: out.shape=(1,), but should be () [linalg.vector_norm(keepdims=True)]
85+
array_api_tests/test_linalg.py::test_vector_norm
86+
87+
# ZeroDivisionError in dask's normalize_chunks/auto_chunks internals
88+
array_api_tests/test_linalg.py::test_inv
89+
array_api_tests/test_linalg.py::test_matrix_power
90+
91+
# did not raise error for invalid shapes
92+
array_api_tests/test_linalg.py::test_matmul
93+
array_api_tests/test_linalg.py::test_linalg_matmul
94+
8195
# Linalg - these don't exist in dask
8296
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.cross]
8397
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.det]
@@ -88,6 +102,7 @@ array_api_tests/test_signatures.py::test_extension_func_signature[linalg.pinv]
88102
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.slogdet]
89103
array_api_tests/test_linalg.py::test_cross
90104
array_api_tests/test_linalg.py::test_det
105+
array_api_tests/test_linalg.py::test_eigh
91106
array_api_tests/test_linalg.py::test_eigvalsh
92107
array_api_tests/test_linalg.py::test_pinv
93108
array_api_tests/test_linalg.py::test_slogdet
@@ -112,7 +127,6 @@ array_api_tests/test_linalg.py::test_solve
112127
# missing full_matrices kw
113128
# https://github.com/dask/dask/issues/10389
114129
# also only supports 2-d inputs
115-
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.svd]
116130
array_api_tests/test_linalg.py::test_svd
117131

118132
# Missing dlpack stuff

Diff for: setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
setup(
99
name='array_api_compat',
1010
version=array_api_compat.__version__,
11-
packages=find_packages(include=['array_api_compat*']),
11+
packages=find_packages(include=["array_api_compat*"]),
1212
author="Consortium for Python Data API Standards",
1313
description="A wrapper around NumPy and other array libraries to make them compatible with the Array API standard",
1414
long_description=long_description,

Diff for: tests/test_common.py

-3
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,6 @@ def test_is_xp_array(library, func):
3131

3232
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
3333
def test_device(library):
34-
if library == "dask.array":
35-
pytest.xfail("device() needs to be fixed for dask")
36-
3734
xp = import_(library, wrapper=True)
3835

3936
# We can't test much for device() and to_device() other than that

0 commit comments

Comments
 (0)