Skip to content

Commit 1816c16

Browse files
committed
ENH: Dask: sort and argsort
1 parent 8a79994 commit 1816c16

File tree

3 files changed

+157
-23
lines changed

3 files changed

+157
-23
lines changed

Diff for: array_api_compat/dask/array/_aliases.py

+81-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

3-
from ...common import _aliases
3+
from typing import Literal
4+
5+
from ...common import _aliases, array_namespace
46

57
from ..._internal import get_xp
68

@@ -228,10 +230,84 @@ def _isscalar(a):
228230

229231
return astype(da.minimum(da.maximum(x, min), max), x.dtype)
230232

231-
# exclude these from all since dask.array has no sorting functions
232-
_da_unsupported = ['sort', 'argsort']
233233

234-
_common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]
234+
def sort(
235+
x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
236+
) -> Array:
237+
"""
238+
Array API compatibility layer around the lack of sort() in Dask.
239+
240+
Warnings
241+
--------
242+
This function temporarily rechunks the array along `axis` to a single chunk.
243+
This can be extremely inefficient and can lead to out-of-memory errors.
244+
245+
See the corresponding documentation in the array library and/or the array API
246+
specification for more details.
247+
"""
248+
return _sort_argsort("sort", x, axis=axis, descending=descending, stable=stable)
249+
250+
251+
def argsort(
252+
x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
253+
) -> Array:
254+
"""
255+
Array API compatibility layer around the lack of argsort() in Dask.
256+
257+
See the corresponding documentation in the array library and/or the array API
258+
specification for more details.
259+
260+
Warnings
261+
--------
262+
This function temporarily rechunks the array along `axis` into a single chunk.
263+
This can be extremely inefficient and can lead to out-of-memory errors.
264+
"""
265+
return _sort_argsort("argsort", x, axis=axis, descending=descending, stable=stable)
266+
267+
268+
def _sort_argsort(
269+
func: Literal["sort", "argsort"],
270+
x: Array,
271+
/,
272+
*,
273+
axis: int,
274+
descending: bool,
275+
stable: bool,
276+
) -> Array:
277+
"""
278+
Implementation of sort() and argsort()
279+
280+
TODO Implement sort and argsort properly in Dask on top of the shuffle subsystem.
281+
"""
282+
if axis < 0:
283+
axis += x.ndim
284+
rechunk = False
285+
if x.numblocks[axis] > 1:
286+
rechunk = True
287+
# Break chunks on other axes in an attempt to keep chunk size low
288+
x = x.rechunk({i: -1 if i == axis else "auto" for i in range(x.ndim)})
289+
meta_xp = array_namespace(x._meta)
290+
x = da.map_blocks(
291+
getattr(meta_xp, func),
292+
x,
293+
axis=axis,
294+
descending=descending,
295+
stable=stable,
296+
dtype=x.dtype,
297+
meta=x._meta,
298+
)
299+
if rechunk:
300+
# rather than reconstructing the original chunks, which can be a
301+
# very expensive affair, just break down oversized chunks without
302+
# incurring in any transfers over the network.
303+
# This has the downside of a risk of overchunking if the array is
304+
# then used in operations against other arrays that match the
305+
# original chunking pattern.
306+
x = x.rechunk()
307+
return x
308+
309+
310+
_common_aliases = _aliases.__all__
235311

236312
__all__ = _common_aliases + ['__array_namespace_info__', 'asarray', 'acos',
237313
'acosh', 'asin', 'asinh', 'atan', 'atan2',
@@ -242,4 +318,4 @@ def _isscalar(a):
242318
'complex64', 'complex128', 'iinfo', 'finfo',
243319
'can_cast', 'result_type']
244320

245-
_all_ignore = ["get_xp", "da", "np"]
321+
_all_ignore = ["Literal", "array_namespace", "get_xp", "da", "np"]

Diff for: dask-xfails.txt

+5-17
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,13 @@ array_api_tests/test_array_object.py::test_setitem_masking
2323
# Various indexing errors
2424
array_api_tests/test_array_object.py::test_getitem_masking
2525

26-
# asarray(copy=False) is not yet implemented
27-
# copied from numpy xfails, TODO: should this pass with dask?
28-
array_api_tests/test_creation_functions.py::test_asarray_arrays
29-
3026
# zero division error, and typeerror: tuple indices must be integers or slices not tuple
3127
array_api_tests/test_creation_functions.py::test_eye
3228

3329
# finfo(float32).eps returns float32 but should return float
3430
array_api_tests/test_data_type_functions.py::test_finfo[float32]
3531

36-
# out[-1]=dask.aray<getitem ...> but should be some floating number
32+
# out[-1]=dask.array<getitem ...> but should be some floating number
3733
# (I think the test is not forcing the op to be computed?)
3834
array_api_tests/test_creation_functions.py::test_linspace
3935

@@ -48,15 +44,7 @@ array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0
4844
array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
4945
array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
5046

51-
# No sorting in dask
52-
array_api_tests/test_has_names.py::test_has_names[sorting-argsort]
53-
array_api_tests/test_has_names.py::test_has_names[sorting-sort]
54-
array_api_tests/test_sorting_functions.py::test_argsort
55-
array_api_tests/test_sorting_functions.py::test_sort
56-
array_api_tests/test_signatures.py::test_func_signature[argsort]
57-
array_api_tests/test_signatures.py::test_func_signature[sort]
58-
59-
# Array methods and attributes not already on np.ndarray cannot be wrapped
47+
# Array methods and attributes not already on da.Array cannot be wrapped
6048
array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
6149
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
6250
array_api_tests/test_has_names.py::test_has_names[array_attribute-device]
@@ -76,6 +64,7 @@ array_api_tests/test_set_functions.py::test_unique_values
7664
# fails for ndim > 2
7765
array_api_tests/test_linalg.py::test_svdvals
7866
array_api_tests/test_linalg.py::test_cholesky
67+
7968
# dtype mismatch got uint64, but should be uint8, NPY_PROMOTION_STATE=weak doesn't help :(
8069
array_api_tests/test_linalg.py::test_tensordot
8170

@@ -105,6 +94,8 @@ array_api_tests/test_linalg.py::test_cross
10594
array_api_tests/test_linalg.py::test_det
10695
array_api_tests/test_linalg.py::test_eigh
10796
array_api_tests/test_linalg.py::test_eigvalsh
97+
array_api_tests/test_linalg.py::test_matrix_norm
98+
array_api_tests/test_linalg.py::test_matrix_rank
10899
array_api_tests/test_linalg.py::test_pinv
109100
array_api_tests/test_linalg.py::test_slogdet
110101
array_api_tests/test_has_names.py::test_has_names[linalg-cross]
@@ -115,9 +106,6 @@ array_api_tests/test_has_names.py::test_has_names[linalg-matrix_power]
115106
array_api_tests/test_has_names.py::test_has_names[linalg-pinv]
116107
array_api_tests/test_has_names.py::test_has_names[linalg-slogdet]
117108

118-
array_api_tests/test_linalg.py::test_matrix_norm
119-
array_api_tests/test_linalg.py::test_matrix_rank
120-
121109
# missing mode kw
122110
# https://github.com/dask/dask/issues/10388
123111
array_api_tests/test_linalg.py::test_qr

Diff for: tests/test_dask.py

+71-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from contextlib import contextmanager
22

3+
import array_api_strict
34
import dask
45
import numpy as np
56
import pytest
@@ -20,9 +21,10 @@ def assert_no_compute():
2021
Context manager that raises if at any point inside it anything calls compute()
2122
or persist(), e.g. as it can be triggered implicitly by __bool__, __array__, etc.
2223
"""
24+
2325
def get(dsk, *args, **kwargs):
2426
raise AssertionError("Called compute() or persist()")
25-
27+
2628
with dask.config.set(scheduler=get):
2729
yield
2830

@@ -40,6 +42,7 @@ def test_assert_no_compute():
4042

4143
# Test no_compute for functions that use generic _aliases with xp=np
4244

45+
4346
def test_unary_ops_no_compute(xp):
4447
with assert_no_compute():
4548
a = xp.asarray([1.5, -1.5])
@@ -59,6 +62,7 @@ def test_matmul_tensordot_no_compute(xp):
5962

6063
# Test no_compute for functions that are fully bespoke for dask
6164

65+
6266
def test_asarray_no_compute(xp):
6367
with assert_no_compute():
6468
a = xp.arange(10)
@@ -88,6 +92,14 @@ def test_clip_no_compute(xp):
8892
xp.clip(a, 1, 8)
8993

9094

95+
@pytest.mark.parametrize("chunks", (5, 10))
96+
def test_sort_argsort_nocompute(xp, chunks):
97+
with assert_no_compute():
98+
a = xp.arange(10, chunks=chunks)
99+
xp.sort(a)
100+
xp.argsort(a)
101+
102+
91103
def test_generators_are_lazy(xp):
92104
"""
93105
Test that generator functions are fully lazy, e.g. that
@@ -106,3 +118,61 @@ def test_generators_are_lazy(xp):
106118
xp.ones_like(a)
107119
xp.empty_like(a)
108120
xp.full_like(a, fill_value=123)
121+
122+
123+
@pytest.mark.parametrize("axis", [0, 1])
124+
@pytest.mark.parametrize("func", ["sort", "argsort"])
125+
def test_sort_argsort_chunks(xp, func, axis):
126+
"""Test that sort and argsort are functionally correct when
127+
the array is chunked along the sort axis, e.g. the sort is
128+
not just local to each chunk.
129+
"""
130+
a = da.random.random((10, 10), chunks=(5, 5))
131+
actual = getattr(xp, func)(a, axis=axis)
132+
expect = getattr(np, func)(a.compute(), axis=axis)
133+
np.testing.assert_array_equal(actual, expect)
134+
135+
136+
@pytest.mark.parametrize(
137+
"shape,chunks",
138+
[
139+
# 3 GiB; 128 MiB per chunk; must rechunk before sorting.
140+
# Sort chunks can be 128 MiB each; no need for final rechunk.
141+
((20_000, 20_000), "auto"),
142+
# 3 GiB; 128 MiB per chunk; must rechunk before sorting.
143+
# Must sort on two 1.5 GiB chunks; benefits from final rechunk.
144+
((2, 2**30 * 3 // 16), "auto"),
145+
# 3 GiB; 1.5 GiB per chunk; no need to rechunk before sorting.
146+
# Surely the user must know what they're doing, so don't
147+
# perform the final rechunk.
148+
((2, 2**30 * 3 // 16), (1, -1)),
149+
],
150+
)
151+
@pytest.mark.parametrize("func", ["sort", "argsort"])
152+
def test_sort_argsort_chunk_size(xp, func, shape, chunks):
153+
"""
154+
Test that sort and argsort produce reasonably-sized chunks
155+
in the output array, even if they had to go through a singular
156+
huge one to perform the operation.
157+
"""
158+
a = da.random.random(shape, chunks=chunks)
159+
b = getattr(xp, func)(a)
160+
max_chunk_size = max(b.chunks[0]) * max(b.chunks[1]) * b.dtype.itemsize
161+
assert (
162+
max_chunk_size <= 128 * 1024 * 1024 # 128 MiB
163+
or b.chunks == a.chunks
164+
)
165+
166+
167+
@pytest.mark.parametrize("func", ["sort", "argsort"])
168+
def test_sort_argsort_meta(xp, func):
169+
"""Test meta-namespace other than numpy"""
170+
typ = type(array_api_strict.asarray(0))
171+
a = da.random.random(10)
172+
b = a.map_blocks(array_api_strict.asarray)
173+
assert isinstance(b._meta, typ)
174+
c = getattr(xp, func)(b)
175+
assert isinstance(c._meta, typ)
176+
d = c.compute()
177+
assert isinstance(d, typ)
178+
np.testing.assert_array_equal(d, getattr(np, func)(a.compute()))

0 commit comments

Comments
 (0)