Skip to content

Commit fa558f2

Browse files
authored
Merge pull request #239 from crusaderky/dask_sort
2 parents 0d66f80 + e6dbf3a commit fa558f2

File tree

3 files changed

+187
-26
lines changed

3 files changed

+187
-26
lines changed

Diff for: array_api_compat/dask/array/_aliases.py

+110-8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

3-
from ...common import _aliases
3+
from typing import Callable
4+
5+
from ...common import _aliases, array_namespace
46

57
from ..._internal import get_xp
68

@@ -29,24 +31,32 @@
2931
)
3032

3133
from typing import TYPE_CHECKING
34+
3235
if TYPE_CHECKING:
3336
from typing import Optional, Union
3437

35-
from ...common._typing import Device, Dtype, Array, NestedSequence, SupportsBufferProtocol
38+
from ...common._typing import (
39+
Device,
40+
Dtype,
41+
Array,
42+
NestedSequence,
43+
SupportsBufferProtocol,
44+
)
3645

3746
import dask.array as da
3847

3948
isdtype = get_xp(np)(_aliases.isdtype)
4049
unstack = get_xp(da)(_aliases.unstack)
4150

51+
4252
# da.astype doesn't respect copy=True
4353
def astype(
4454
x: Array,
4555
dtype: Dtype,
4656
/,
4757
*,
4858
copy: bool = True,
49-
device: Optional[Device] = None
59+
device: Optional[Device] = None,
5060
) -> Array:
5161
"""
5262
Array API compatibility wrapper for astype().
@@ -61,8 +71,10 @@ def astype(
6171
x = x.astype(dtype)
6272
return x.copy() if copy else x
6373

74+
6475
# Common aliases
6576

77+
6678
# This arange func is modified from the common one to
6779
# not pass stop/step as keyword arguments, which will cause
6880
# an error with dask
@@ -189,6 +201,7 @@ def asarray(
189201
concatenate as concat,
190202
)
191203

204+
192205
# dask.array.clip does not work unless all three arguments are provided.
193206
# Furthermore, the masking workaround in common._aliases.clip cannot work with
194207
# dask (meaning uint64 promoting to float64 is going to just be unfixed for
@@ -205,8 +218,10 @@ def clip(
205218
See the corresponding documentation in the array library and/or the array API
206219
specification for more details.
207220
"""
221+
208222
def _isscalar(a):
209223
return isinstance(a, (int, float, type(None)))
224+
210225
min_shape = () if _isscalar(min) else min.shape
211226
max_shape = () if _isscalar(max) else max.shape
212227

@@ -228,12 +243,99 @@ def _isscalar(a):
228243

229244
return astype(da.minimum(da.maximum(x, min), max), x.dtype)
230245

231-
# exclude these from all since dask.array has no sorting functions
232-
_da_unsupported = ['sort', 'argsort']
233246

234-
_common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]
247+
def _ensure_single_chunk(x: Array, axis: int) -> tuple[Array, Callable[[Array], Array]]:
248+
"""
249+
Make sure that Array is not broken into multiple chunks along axis.
250+
251+
Returns
252+
-------
253+
x : Array
254+
The input Array with a single chunk along axis.
255+
restore : Callable[Array, Array]
256+
function to apply to the output to rechunk it back into reasonable chunks
257+
"""
258+
if axis < 0:
259+
axis += x.ndim
260+
if x.numblocks[axis] < 2:
261+
return x, lambda x: x
262+
263+
# Break chunks on other axes in an attempt to keep chunk size low
264+
x = x.rechunk({i: -1 if i == axis else "auto" for i in range(x.ndim)})
265+
266+
# Rather than reconstructing the original chunks, which can be a
267+
# very expensive affair, just break down oversized chunks without
268+
# incurring in any transfers over the network.
269+
# This has the downside of a risk of overchunking if the array is
270+
# then used in operations against other arrays that match the
271+
# original chunking pattern.
272+
return x, lambda x: x.rechunk()
273+
274+
275+
def sort(
276+
x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
277+
) -> Array:
278+
"""
279+
Array API compatibility layer around the lack of sort() in Dask.
280+
281+
Warnings
282+
--------
283+
This function temporarily rechunks the array along `axis` to a single chunk.
284+
This can be extremely inefficient and can lead to out-of-memory errors.
285+
286+
See the corresponding documentation in the array library and/or the array API
287+
specification for more details.
288+
"""
289+
x, restore = _ensure_single_chunk(x, axis)
290+
291+
meta_xp = array_namespace(x._meta)
292+
x = da.map_blocks(
293+
meta_xp.sort,
294+
x,
295+
axis=axis,
296+
meta=x._meta,
297+
dtype=x.dtype,
298+
descending=descending,
299+
stable=stable,
300+
)
301+
302+
return restore(x)
235303

236-
__all__ = _common_aliases + ['__array_namespace_info__', 'asarray', 'astype', 'acos',
304+
305+
def argsort(
306+
x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
307+
) -> Array:
308+
"""
309+
Array API compatibility layer around the lack of argsort() in Dask.
310+
311+
See the corresponding documentation in the array library and/or the array API
312+
specification for more details.
313+
314+
Warnings
315+
--------
316+
This function temporarily rechunks the array along `axis` into a single chunk.
317+
This can be extremely inefficient and can lead to out-of-memory errors.
318+
"""
319+
x, restore = _ensure_single_chunk(x, axis)
320+
321+
meta_xp = array_namespace(x._meta)
322+
dtype = meta_xp.argsort(x._meta).dtype
323+
meta = meta_xp.astype(x._meta, dtype)
324+
x = da.map_blocks(
325+
meta_xp.argsort,
326+
x,
327+
axis=axis,
328+
meta=meta,
329+
dtype=dtype,
330+
descending=descending,
331+
stable=stable,
332+
)
333+
334+
return restore(x)
335+
336+
337+
__all__ = _aliases.__all__ + [
338+
'__array_namespace_info__', 'asarray', 'astype', 'acos',
237339
'acosh', 'asin', 'asinh', 'atan', 'atan2',
238340
'atanh', 'bitwise_left_shift', 'bitwise_invert',
239341
'bitwise_right_shift', 'concat', 'pow', 'iinfo', 'finfo', 'can_cast',
@@ -242,4 +344,4 @@ def _isscalar(a):
242344
'complex64', 'complex128', 'iinfo', 'finfo',
243345
'can_cast', 'result_type']
244346

245-
_all_ignore = ["get_xp", "da", "np"]
347+
_all_ignore = ["Callable", "array_namespace", "get_xp", "da", "np"]

Diff for: dask-xfails.txt

+5-17
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,13 @@ array_api_tests/test_array_object.py::test_setitem_masking
2323
# Various indexing errors
2424
array_api_tests/test_array_object.py::test_getitem_masking
2525

26-
# asarray(copy=False) is not yet implemented
27-
# copied from numpy xfails, TODO: should this pass with dask?
28-
array_api_tests/test_creation_functions.py::test_asarray_arrays
29-
3026
# zero division error, and typeerror: tuple indices must be integers or slices not tuple
3127
array_api_tests/test_creation_functions.py::test_eye
3228

3329
# finfo(float32).eps returns float32 but should return float
3430
array_api_tests/test_data_type_functions.py::test_finfo[float32]
3531

36-
# out[-1]=dask.aray<getitem ...> but should be some floating number
32+
# out[-1]=dask.array<getitem ...> but should be some floating number
3733
# (I think the test is not forcing the op to be computed?)
3834
array_api_tests/test_creation_functions.py::test_linspace
3935

@@ -48,15 +44,7 @@ array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0
4844
array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
4945
array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
5046

51-
# No sorting in dask
52-
array_api_tests/test_has_names.py::test_has_names[sorting-argsort]
53-
array_api_tests/test_has_names.py::test_has_names[sorting-sort]
54-
array_api_tests/test_sorting_functions.py::test_argsort
55-
array_api_tests/test_sorting_functions.py::test_sort
56-
array_api_tests/test_signatures.py::test_func_signature[argsort]
57-
array_api_tests/test_signatures.py::test_func_signature[sort]
58-
59-
# Array methods and attributes not already on np.ndarray cannot be wrapped
47+
# Array methods and attributes not already on da.Array cannot be wrapped
6048
array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
6149
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
6250
array_api_tests/test_has_names.py::test_has_names[array_attribute-device]
@@ -76,6 +64,7 @@ array_api_tests/test_set_functions.py::test_unique_values
7664
# fails for ndim > 2
7765
array_api_tests/test_linalg.py::test_svdvals
7866
array_api_tests/test_linalg.py::test_cholesky
67+
7968
# dtype mismatch got uint64, but should be uint8, NPY_PROMOTION_STATE=weak doesn't help :(
8069
array_api_tests/test_linalg.py::test_tensordot
8170

@@ -105,6 +94,8 @@ array_api_tests/test_linalg.py::test_cross
10594
array_api_tests/test_linalg.py::test_det
10695
array_api_tests/test_linalg.py::test_eigh
10796
array_api_tests/test_linalg.py::test_eigvalsh
97+
array_api_tests/test_linalg.py::test_matrix_norm
98+
array_api_tests/test_linalg.py::test_matrix_rank
10899
array_api_tests/test_linalg.py::test_pinv
109100
array_api_tests/test_linalg.py::test_slogdet
110101
array_api_tests/test_has_names.py::test_has_names[linalg-cross]
@@ -115,9 +106,6 @@ array_api_tests/test_has_names.py::test_has_names[linalg-matrix_power]
115106
array_api_tests/test_has_names.py::test_has_names[linalg-pinv]
116107
array_api_tests/test_has_names.py::test_has_names[linalg-slogdet]
117108

118-
array_api_tests/test_linalg.py::test_matrix_norm
119-
array_api_tests/test_linalg.py::test_matrix_rank
120-
121109
# missing mode kw
122110
# https://github.com/dask/dask/issues/10388
123111
array_api_tests/test_linalg.py::test_qr

Diff for: tests/test_dask.py

+72-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from contextlib import contextmanager
22

3+
import array_api_strict
34
import dask
45
import numpy as np
56
import pytest
@@ -20,9 +21,10 @@ def assert_no_compute():
2021
Context manager that raises if at any point inside it anything calls compute()
2122
or persist(), e.g. as it can be triggered implicitly by __bool__, __array__, etc.
2223
"""
24+
2325
def get(dsk, *args, **kwargs):
2426
raise AssertionError("Called compute() or persist()")
25-
27+
2628
with dask.config.set(scheduler=get):
2729
yield
2830

@@ -40,6 +42,7 @@ def test_assert_no_compute():
4042

4143
# Test no_compute for functions that use generic _aliases with xp=np
4244

45+
4346
def test_unary_ops_no_compute(xp):
4447
with assert_no_compute():
4548
a = xp.asarray([1.5, -1.5])
@@ -59,6 +62,7 @@ def test_matmul_tensordot_no_compute(xp):
5962

6063
# Test no_compute for functions that are fully bespoke for dask
6164

65+
6266
def test_asarray_no_compute(xp):
6367
with assert_no_compute():
6468
a = xp.arange(10)
@@ -88,6 +92,14 @@ def test_clip_no_compute(xp):
8892
xp.clip(a, 1, 8)
8993

9094

95+
@pytest.mark.parametrize("chunks", (5, 10))
96+
def test_sort_argsort_nocompute(xp, chunks):
97+
with assert_no_compute():
98+
a = xp.arange(10, chunks=chunks)
99+
xp.sort(a)
100+
xp.argsort(a)
101+
102+
91103
def test_generators_are_lazy(xp):
92104
"""
93105
Test that generator functions are fully lazy, e.g. that
@@ -106,3 +118,62 @@ def test_generators_are_lazy(xp):
106118
xp.ones_like(a)
107119
xp.empty_like(a)
108120
xp.full_like(a, fill_value=123)
121+
122+
123+
@pytest.mark.parametrize("axis", [0, 1])
124+
@pytest.mark.parametrize("func", ["sort", "argsort"])
125+
def test_sort_argsort_chunks(xp, func, axis):
126+
"""Test that sort and argsort are functionally correct when
127+
the array is chunked along the sort axis, e.g. the sort is
128+
not just local to each chunk.
129+
"""
130+
a = da.random.random((10, 10), chunks=(5, 5))
131+
actual = getattr(xp, func)(a, axis=axis)
132+
expect = getattr(np, func)(a.compute(), axis=axis)
133+
np.testing.assert_array_equal(actual, expect)
134+
135+
136+
@pytest.mark.parametrize(
137+
"shape,chunks",
138+
[
139+
# 3 GiB; 128 MiB per chunk; must rechunk before sorting.
140+
# Sort chunks can be 128 MiB each; no need for final rechunk.
141+
((20_000, 20_000), "auto"),
142+
# 3 GiB; 128 MiB per chunk; must rechunk before sorting.
143+
# Must sort on two 1.5 GiB chunks; benefits from final rechunk.
144+
((2, 2**30 * 3 // 16), "auto"),
145+
# 3 GiB; 1.5 GiB per chunk; no need to rechunk before sorting.
146+
# Surely the user must know what they're doing, so don't
147+
# perform the final rechunk.
148+
((2, 2**30 * 3 // 16), (1, -1)),
149+
],
150+
)
151+
@pytest.mark.parametrize("func", ["sort", "argsort"])
152+
def test_sort_argsort_chunk_size(xp, func, shape, chunks):
153+
"""
154+
Test that sort and argsort produce reasonably-sized chunks
155+
in the output array, even if they had to go through a singular
156+
huge one to perform the operation.
157+
"""
158+
a = da.random.random(shape, chunks=chunks)
159+
b = getattr(xp, func)(a)
160+
max_chunk_size = max(b.chunks[0]) * max(b.chunks[1]) * b.dtype.itemsize
161+
assert (
162+
max_chunk_size <= 128 * 1024 * 1024 # 128 MiB
163+
or b.chunks == a.chunks
164+
)
165+
166+
167+
@pytest.mark.parametrize("func", ["sort", "argsort"])
168+
def test_sort_argsort_meta(xp, func):
169+
"""Test meta-namespace other than numpy"""
170+
typ = type(array_api_strict.asarray(0))
171+
a = da.random.random(10)
172+
b = a.map_blocks(array_api_strict.asarray)
173+
assert isinstance(b._meta, typ)
174+
c = getattr(xp, func)(b)
175+
assert isinstance(c._meta, typ)
176+
d = c.compute()
177+
# Note: np.sort(array_api_strict.asarray(0)) would return a numpy array
178+
assert isinstance(d, typ)
179+
np.testing.assert_array_equal(d, getattr(np, func)(a.compute()))

0 commit comments

Comments
 (0)