
Commit 8ee1613

Add dask to array-api-compat
1 parent 874c2ff commit 8ee1613

12 files changed: +208 -6 lines

Diff for: .github/workflows/array-api-tests-dask.yml

+9

@@ -0,0 +1,9 @@
+name: Array API Tests (Dask)
+
+on: [push, pull_request]
+
+jobs:
+  array-api-tests-dask:
+    uses: ./.github/workflows/array-api-tests.yml
+    with:
+      package-name: dask

Diff for: array_api_compat/common/_aliases.py

+5 -1

@@ -303,6 +303,8 @@ def _asarray(
         import numpy as xp
     elif namespace == 'cupy':
         import cupy as xp
+    elif namespace == 'dask':
+        import dask.array as xp
     else:
         raise ValueError("Unrecognized namespace argument to asarray()")

@@ -322,7 +324,9 @@ def _asarray(
     if copy in COPY_FALSE:
         # copy=False is not yet implemented in xp.asarray
         raise NotImplementedError("copy=False is not yet implemented")
-    if isinstance(obj, xp.ndarray):
+    # TODO: This feels wrong (__array__ is not in the standard)
+    # Dask doesn't support DLPack, though, so, this'll do
+    if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)) or hasattr(obj, "__array__"):
         if dtype is not None and obj.dtype != dtype:
             copy = True
         if copy in COPY_TRUE:
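Note: with the new namespace='dask' branch, the same compat asarray() entry point can build dask arrays, and the __array__ fallback is what lets NumPy-convertible inputs through. A minimal usage sketch, not part of the diff, assuming dask is installed and using the partial-bound asarray defined in array_api_compat/dask/_aliases.py below:

import numpy as np
import dask.array as da
from array_api_compat.dask import asarray  # partial(_asarray, namespace='dask')

x = asarray(np.arange(6))        # NumPy input is accepted via the __array__ check
y = asarray([1.0, 2.0, 3.0])     # plain sequences fall through to da.asarray
assert isinstance(x, da.Array) and isinstance(y, da.Array)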

Diff for: array_api_compat/common/_helpers.py

+24

@@ -40,6 +40,16 @@ def _is_torch_array(x):
     # TODO: Should we reject ndarray subclasses?
     return isinstance(x, torch.Tensor)

+def _is_dask_array(x):
+    # Avoid importing dask if it isn't already
+    if 'dask.array' not in sys.modules:
+        return False
+
+    import dask.array
+
+    # TODO: Should we reject ndarray subclasses?
+    return isinstance(x, dask.array.Array)
+
 def is_array_api_obj(x):
     """
     Check if x is an array API compatible array object.

@@ -97,6 +107,13 @@ def your_function(x, y):
            else:
                import torch
                namespaces.add(torch)
+        elif _is_dask_array(x):
+            _check_api_version(api_version)
+            if _use_compat:
+                from .. import dask as dask_namespace
+                namespaces.add(dask_namespace)
+            else:
+                raise TypeError("_use_compat cannot be False if input array is a dask array!")
        else:
            # TODO: Support Python scalars?
            raise TypeError("The input is not a supported array type")

@@ -219,6 +236,13 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A
        return _cupy_to_device(x, device, stream=stream)
    elif _is_torch_array(x):
        return _torch_to_device(x, device, stream=stream)
+    elif _is_dask_array(x):
+        if stream is not None:
+            raise ValueError("The stream argument to to_device() is not supported")
+        # TODO: What if our array is on the GPU already?
+        if device == 'cpu':
+            return x
+        raise ValueError(f"Unsupported device {device!r}")
    return x.to_device(device, stream=stream)

def size(x):
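With _is_dask_array() wired into array_namespace() and to_device(), code written against the compat helpers dispatches to the new dask namespace without special-casing. A rough sketch of the intended behaviour (illustrative only, not part of the diff; assumes dask is installed and that the helpers are exposed at the package top level as exercised by the tests below):

import dask.array as da
from array_api_compat import array_namespace, to_device

x = da.arange(10, chunks=5)
xp = array_namespace(x)          # resolves to the array_api_compat.dask namespace
y = xp.reshape(x, (2, 5))

y = to_device(y, 'cpu')          # no-op: dask arrays are treated as CPU-resident
# to_device(y, 'gpu')            # would raise ValueError("Unsupported device 'gpu'")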

Diff for: array_api_compat/dask/__init__.py

+6

@@ -0,0 +1,6 @@
+from dask.array import *
+
+# These imports may overwrite names from the import * above.
+from ._aliases import *
+
+__array_api_version__ = '2022.12'
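The resulting array_api_compat.dask module behaves like dask.array with the standard's names layered on top of it. A small illustrative sketch (assumes dask is installed; the names used are those defined in _aliases.py below):

import array_api_compat.dask as xp

a = xp.asarray([[1.0, 2.0], [3.0, 4.0]])
b = xp.permute_dims(a, (1, 0))      # standard name wrapping dask.array.transpose
c = xp.concat([a, b], axis=0)       # dask.array.concatenate under its standard name
print(xp.__array_api_version__)     # '2022.12'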

Diff for: array_api_compat/dask/_aliases.py

+88

@@ -0,0 +1,88 @@
+from ..common import _aliases
+
+from .._internal import get_xp
+
+import numpy as np
+from numpy import (
+    # Constants
+    e,
+    inf,
+    nan,
+    pi,
+    newaxis,
+    # Dtypes
+    bool_ as bool,
+    float32,
+    float64,
+    int8,
+    int16,
+    int32,
+    int64,
+    uint8,
+    uint16,
+    uint32,
+    uint64,
+    complex64,
+    complex128,
+    iinfo,
+    finfo,
+    can_cast,
+    result_type,
+)
+
+import dask.array as da
+
+isdtype = get_xp(np)(_aliases.isdtype)
+astype = _aliases.astype
+
+# Common aliases
+arange = get_xp(da)(_aliases.arange)
+
+from functools import partial
+asarray = partial(_aliases._asarray, namespace='dask')
+asarray.__doc__ = _aliases._asarray.__doc__
+
+linspace = get_xp(da)(_aliases.linspace)
+eye = get_xp(da)(_aliases.eye)
+UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult)
+UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult)
+UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult)
+unique_all = get_xp(da)(_aliases.unique_all)
+unique_counts = get_xp(da)(_aliases.unique_counts)
+unique_inverse = get_xp(da)(_aliases.unique_inverse)
+unique_values = get_xp(da)(_aliases.unique_values)
+permute_dims = get_xp(da)(_aliases.permute_dims)
+std = get_xp(da)(_aliases.std)
+var = get_xp(da)(_aliases.var)
+empty = get_xp(da)(_aliases.empty)
+empty_like = get_xp(da)(_aliases.empty_like)
+full = get_xp(da)(_aliases.full)
+full_like = get_xp(da)(_aliases.full_like)
+ones = get_xp(da)(_aliases.ones)
+ones_like = get_xp(da)(_aliases.ones_like)
+zeros = get_xp(da)(_aliases.zeros)
+zeros_like = get_xp(da)(_aliases.zeros_like)
+reshape = get_xp(da)(_aliases.reshape)
+matrix_transpose = get_xp(da)(_aliases.matrix_transpose)
+vecdot = get_xp(da)(_aliases.vecdot)
+
+
+
+from dask.array import (
+    # Element wise aliases
+    arccos as acos,
+    arccosh as acosh,
+    arcsin as asin,
+    arcsinh as asinh,
+    arctan as atan,
+    arctan2 as atan2,
+    arctanh as atanh,
+    left_shift as bitwise_left_shift,
+    right_shift as bitwise_right_shift,
+    invert as bitwise_invert,
+    power as pow,
+    # Other
+    concatenate as concat,
+
+)
+
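Most of these names reuse the shared implementations in common/_aliases.py; get_xp(da) binds the xp keyword so each wrapper delegates to dask.array. A simplified sketch of that pattern (the real helper lives in array_api_compat/_internal.py and also fixes up signatures and docstrings):

from functools import wraps

def get_xp(xp):
    # Return a decorator that injects the wrapped namespace as the
    # keyword-only `xp` argument of the common implementation.
    def decorator(f):
        @wraps(f)
        def wrapper(*args, **kwargs):
            return f(*args, xp=xp, **kwargs)
        return wrapper
    return decorator

So arange = get_xp(da)(_aliases.arange) yields a function whose array creation is carried out by dask.array.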

Diff for: dask-skips.txt

Whitespace-only changes.

Diff for: dask-xfails.txt

+47

@@ -0,0 +1,47 @@
+# finfo(float32).eps returns float32 but should return float
+array_api_tests/test_data_type_functions.py::test_finfo[float32]
+
+# No sorting in dask
+array_api_tests/test_has_names.py::test_has_names[sorting-argsort]
+array_api_tests/test_has_names.py::test_has_names[sorting-sort]
+
+# Array methods and attributes not already on np.ndarray cannot be wrapped
+array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
+array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
+array_api_tests/test_has_names.py::test_has_names[array_attribute-device]
+array_api_tests/test_has_names.py::test_has_names[array_attribute-mT]
+
+#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)] - AssertionError: out[0]=0, but should be (x1 + x2[0])=65536 [__add__()]
+#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_bitwise_and(ctx=BinaryParamContext(<__and__(x1, x2)>), data=data(...)) produces unreliable results: Falsified on the ...
+#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__rshift__(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_bitwise_right_shift(ctx=BinaryParamContext(<__rshift__(x1, x2)>), data=data(...)) produces unreliable results: Falsif...
+#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[bitwise_xor(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 994.44ms, which exceeds the deadline of 800.00ms
+#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__xor__(x1, x2)] - ValueError: Inferred dtype from function 'xor' was 'uint64' but got 'int16', which can't be cast using casting='same_kind'
+#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_ceil - exceptiongroup.ExceptionGroup: Hypothesis found 2 distinct failures. (2 sub-exceptions)
+#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_divide[divide(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_divide(ctx=BinaryParamContext(<divide(x1, x2)>), data=data(...)) produces unreliable results: Falsified on the first ...
+#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_floor - exceptiongroup.ExceptionGroup: Hypothesis found 2 distinct failures. (2 sub-exceptions)
+#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 1015.43ms, which exceeds the deadline of 800.00ms
+#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 900.99ms, which exceeds the deadline of 800.00ms
+#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 1106.49ms, which exceeds the deadline of 800.00ms
+#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_greater(ctx=BinaryParamContext(<__gt__(x1, x2)>), data=data(...)) produces unreliable results: Falsified on the first...
+#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_greater_equal(ctx=BinaryParamContext(<greater_equal(x1, x2)>), data=data(...)) produces unreliable results: Falsified...
+#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_greater_equal(ctx=BinaryParamContext(<__ge__(x1, x2)>), data=data(...)) produces unreliable results: Falsified on the...
+#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_less[less(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 961.84ms, which exceeds the deadline of 800.00ms
+#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 980.68ms, which exceeds the deadline of 800.00ms
+#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[less_equal(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 1043.47ms, which exceeds the deadline of 800.00ms
+#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[__le__(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 1011.76ms, which exceeds the deadline of 800.00ms
+#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[multiply(x1, x2)] - AssertionError: out[0]=0, but should be (x1 * x2[0])=256 [multiply()]
+#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)] - AssertionError: out[0]=2, but should be (x1 * x2[0])=258 [__mul__()]
+#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtract(x1, x2)] - exceptiongroup.ExceptionGroup: Hypothesis found 2 distinct failures. (2 sub-exceptions)
+#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 1034.64ms, which exceeds the deadline of 800.00ms
+#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_trunc - exceptiongroup.ExceptionGroup: Hypothesis found 2 distinct failures. (2 sub-exceptions)
+
+#FAILED array_api_tests/test_searching_functions.py::test_nonzero_zerodim_error - Failed: DID NOT RAISE <class 'Exception'>
+
+# Fails because shape is NaN since we don't materialize it yet
+#FAILED array_api_tests/test_searching_functions.py::test_nonzero - AssertionError: prod(out[0].shape)=nan, but should be prod(out[0].shape)=nan
+#FAILED array_api_tests/test_set_functions.py::test_unique_all - AssertionError: out.indices.shape=(nan,), but should be out.values.shape=(nan,)
+#FAILED array_api_tests/test_set_functions.py::test_unique_counts - AssertionError: out.counts.shape=(nan,), but should be out.values.shape=(nan,)
+
+# Needs investigation
+#FAILED array_api_tests/test_set_functions.py::test_unique_inverse - TypeError: 'float' object cannot be interpreted as an integer
+#FAILED array_api_tests/test_set_functions.py::test_unique_values - TypeError: 'float' object cannot be interpreted as an integer
Diff for: tests/test_array_namespace.py

+3 -2

@@ -5,8 +5,7 @@

 import pytest

-
-@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
+@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
 @pytest.mark.parametrize("api_version", [None, '2021.12'])
 def test_array_namespace(library, api_version):
     lib = import_(library)

@@ -17,6 +16,8 @@ def test_array_namespace(library, api_version):
     if 'array_api' in library:
         assert namespace == lib
     else:
+        if library == "dask.array":
+            library = "dask"
         assert namespace == getattr(array_api_compat, library)

Diff for: tests/test_common.py

+1 -1

@@ -5,7 +5,7 @@
 import numpy as np
 from numpy.testing import assert_allclose

-@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
+@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask"])
 def test_to_device_host(library):
     # different libraries have different semantics
     # for DtoH transfers; ensure that we support a portable

Diff for: tests/test_isdtype.py

+2 -2

@@ -64,7 +64,7 @@ def isdtype_(dtype_, kind):
     assert type(res) is bool
     return res

-@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
+@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask"])
 def test_isdtype_spec_dtypes(library):
     xp = import_('array_api_compat.' + library)

@@ -98,7 +98,7 @@ def test_isdtype_spec_dtypes(library):
     'bfloat16',
 ]

-@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
+@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask"])
 @pytest.mark.parametrize("dtype_", additional_dtypes)
 def test_isdtype_additional_dtypes(library, dtype_):
     xp = import_('array_api_compat.' + library)

Diff for: tests/test_vendoring.py

+4

@@ -17,3 +17,7 @@ def test_vendoring_cupy():
 def test_vendoring_torch():
     from vendor_test import uses_torch
     uses_torch._test_torch()
+
+def test_vendoring_dask():
+    from vendor_test import uses_dask
+    uses_dask._test_dask()

Diff for: vendor_test/uses_dask.py

+19

@@ -0,0 +1,19 @@
+# Basic test that vendoring works
+
+from .vendored._compat import dask as dask_compat
+
+import dask.array as da
+import numpy as np
+
+def _test_dask():
+    a = dask_compat.asarray([1., 2., 3.])
+    b = dask_compat.arange(3, dtype=dask_compat.float32)
+
+    # np.pow does not exist. Update this to use something else if it is added
+    res = dask_compat.pow(a, b)
+    assert res.dtype == dask_compat.float64 == np.float64
+    assert isinstance(a, da.Array)
+    assert isinstance(b, da.Array)
+    assert isinstance(res, da.Array)
+
+    np.testing.assert_allclose(res, [1., 2., 9.])
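The vendoring test assumes the package has been copied into the test tree (vendor_test/vendored/_compat mirrors array_api_compat), so the dask namespace is reached through a relative import. A hypothetical downstream layout, for illustration only (directory names are illustrative, not prescribed by the commit):

# mypkg/_vendored/array_api_compat/   <- copy of this package
# mypkg/some_module.py:
from ._vendored.array_api_compat import dask as dask_compat

x = dask_compat.asarray([0.0, 1.0, 2.0])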
