Skip to content

ENH: Dask: sort and argsort #239

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 26, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 110 additions & 8 deletions array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

from ...common import _aliases
from typing import Callable

from ...common import _aliases, array_namespace

from ..._internal import get_xp

@@ -29,24 +31,32 @@
)

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Optional, Union

from ...common._typing import Device, Dtype, Array, NestedSequence, SupportsBufferProtocol
from ...common._typing import (
Device,
Dtype,
Array,
NestedSequence,
SupportsBufferProtocol,
)

import dask.array as da

isdtype = get_xp(np)(_aliases.isdtype)
unstack = get_xp(da)(_aliases.unstack)


# da.astype doesn't respect copy=True
def astype(
x: Array,
dtype: Dtype,
/,
*,
copy: bool = True,
device: Optional[Device] = None
device: Optional[Device] = None,
) -> Array:
"""
Array API compatibility wrapper for astype().
@@ -61,8 +71,10 @@ def astype(
x = x.astype(dtype)
return x.copy() if copy else x


# Common aliases


# This arange func is modified from the common one to
# not pass stop/step as keyword arguments, which will cause
# an error with dask
@@ -189,6 +201,7 @@ def asarray(
concatenate as concat,
)


# dask.array.clip does not work unless all three arguments are provided.
# Furthermore, the masking workaround in common._aliases.clip cannot work with
# dask (meaning uint64 promoting to float64 is going to just be unfixed for
@@ -205,8 +218,10 @@ def clip(
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""

def _isscalar(a):
return isinstance(a, (int, float, type(None)))

min_shape = () if _isscalar(min) else min.shape
max_shape = () if _isscalar(max) else max.shape

@@ -228,12 +243,99 @@ def _isscalar(a):

return astype(da.minimum(da.maximum(x, min), max), x.dtype)

# exclude these from all since dask.array has no sorting functions
_da_unsupported = ['sort', 'argsort']

_common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]
def _ensure_single_chunk(x: Array, axis: int) -> tuple[Array, Callable[[Array], Array]]:
"""
Make sure that Array is not broken into multiple chunks along axis.

Returns
-------
x : Array
The input Array with a single chunk along axis.
restore : Callable[Array, Array]
function to apply to the output to rechunk it back into reasonable chunks
"""
if axis < 0:
axis += x.ndim
if x.numblocks[axis] < 2:
return x, lambda x: x

# Break chunks on other axes in an attempt to keep chunk size low
x = x.rechunk({i: -1 if i == axis else "auto" for i in range(x.ndim)})

# Rather than reconstructing the original chunks, which can be a
# very expensive affair, just break down oversized chunks without
# incurring in any transfers over the network.
# This has the downside of a risk of overchunking if the array is
# then used in operations against other arrays that match the
# original chunking pattern.
return x, lambda x: x.rechunk()


def sort(
x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
) -> Array:
"""
Array API compatibility layer around the lack of sort() in Dask.

Warnings
--------
This function temporarily rechunks the array along `axis` to a single chunk.
This can be extremely inefficient and can lead to out-of-memory errors.

See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
x, restore = _ensure_single_chunk(x, axis)

meta_xp = array_namespace(x._meta)
x = da.map_blocks(
meta_xp.sort,
x,
axis=axis,
meta=x._meta,
dtype=x.dtype,
descending=descending,
stable=stable,
)

return restore(x)

__all__ = _common_aliases + ['__array_namespace_info__', 'asarray', 'astype', 'acos',

def argsort(
x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
) -> Array:
"""
Array API compatibility layer around the lack of argsort() in Dask.

See the corresponding documentation in the array library and/or the array API
specification for more details.

Warnings
--------
This function temporarily rechunks the array along `axis` into a single chunk.
This can be extremely inefficient and can lead to out-of-memory errors.
"""
x, restore = _ensure_single_chunk(x, axis)

meta_xp = array_namespace(x._meta)
dtype = meta_xp.argsort(x._meta).dtype
meta = meta_xp.astype(x._meta, dtype)
x = da.map_blocks(
meta_xp.argsort,
x,
axis=axis,
meta=meta,
dtype=dtype,
descending=descending,
stable=stable,
)

return restore(x)


__all__ = _aliases.__all__ + [
'__array_namespace_info__', 'asarray', 'astype', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow', 'iinfo', 'finfo', 'can_cast',
@@ -242,4 +344,4 @@ def _isscalar(a):
'complex64', 'complex128', 'iinfo', 'finfo',
'can_cast', 'result_type']

_all_ignore = ["get_xp", "da", "np"]
_all_ignore = ["Callable", "array_namespace", "get_xp", "da", "np"]
22 changes: 5 additions & 17 deletions dask-xfails.txt
Original file line number Diff line number Diff line change
@@ -23,17 +23,13 @@ array_api_tests/test_array_object.py::test_setitem_masking
# Various indexing errors
array_api_tests/test_array_object.py::test_getitem_masking

# asarray(copy=False) is not yet implemented
# copied from numpy xfails, TODO: should this pass with dask?
array_api_tests/test_creation_functions.py::test_asarray_arrays

# zero division error, and typeerror: tuple indices must be integers or slices not tuple
array_api_tests/test_creation_functions.py::test_eye

# finfo(float32).eps returns float32 but should return float
array_api_tests/test_data_type_functions.py::test_finfo[float32]

# out[-1]=dask.aray<getitem ...> but should be some floating number
# out[-1]=dask.array<getitem ...> but should be some floating number
# (I think the test is not forcing the op to be computed?)
array_api_tests/test_creation_functions.py::test_linspace

@@ -48,15 +44,7 @@ array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0
array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]

# No sorting in dask
array_api_tests/test_has_names.py::test_has_names[sorting-argsort]
array_api_tests/test_has_names.py::test_has_names[sorting-sort]
array_api_tests/test_sorting_functions.py::test_argsort
array_api_tests/test_sorting_functions.py::test_sort
array_api_tests/test_signatures.py::test_func_signature[argsort]
array_api_tests/test_signatures.py::test_func_signature[sort]

# Array methods and attributes not already on np.ndarray cannot be wrapped
# Array methods and attributes not already on da.Array cannot be wrapped
array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
array_api_tests/test_has_names.py::test_has_names[array_attribute-device]
@@ -76,6 +64,7 @@ array_api_tests/test_set_functions.py::test_unique_values
# fails for ndim > 2
array_api_tests/test_linalg.py::test_svdvals
array_api_tests/test_linalg.py::test_cholesky

# dtype mismatch got uint64, but should be uint8, NPY_PROMOTION_STATE=weak doesn't help :(
array_api_tests/test_linalg.py::test_tensordot

@@ -105,6 +94,8 @@ array_api_tests/test_linalg.py::test_cross
array_api_tests/test_linalg.py::test_det
array_api_tests/test_linalg.py::test_eigh
array_api_tests/test_linalg.py::test_eigvalsh
array_api_tests/test_linalg.py::test_matrix_norm
array_api_tests/test_linalg.py::test_matrix_rank
array_api_tests/test_linalg.py::test_pinv
array_api_tests/test_linalg.py::test_slogdet
array_api_tests/test_has_names.py::test_has_names[linalg-cross]
@@ -115,9 +106,6 @@ array_api_tests/test_has_names.py::test_has_names[linalg-matrix_power]
array_api_tests/test_has_names.py::test_has_names[linalg-pinv]
array_api_tests/test_has_names.py::test_has_names[linalg-slogdet]

array_api_tests/test_linalg.py::test_matrix_norm
array_api_tests/test_linalg.py::test_matrix_rank

# missing mode kw
# https://github.com/dask/dask/issues/10388
array_api_tests/test_linalg.py::test_qr
73 changes: 72 additions & 1 deletion tests/test_dask.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from contextlib import contextmanager

import array_api_strict
import dask
import numpy as np
import pytest
@@ -20,9 +21,10 @@ def assert_no_compute():
Context manager that raises if at any point inside it anything calls compute()
or persist(), e.g. as it can be triggered implicitly by __bool__, __array__, etc.
"""

def get(dsk, *args, **kwargs):
raise AssertionError("Called compute() or persist()")

with dask.config.set(scheduler=get):
yield

@@ -40,6 +42,7 @@ def test_assert_no_compute():

# Test no_compute for functions that use generic _aliases with xp=np


def test_unary_ops_no_compute(xp):
with assert_no_compute():
a = xp.asarray([1.5, -1.5])
@@ -59,6 +62,7 @@ def test_matmul_tensordot_no_compute(xp):

# Test no_compute for functions that are fully bespoke for dask


def test_asarray_no_compute(xp):
with assert_no_compute():
a = xp.arange(10)
@@ -88,6 +92,14 @@ def test_clip_no_compute(xp):
xp.clip(a, 1, 8)


@pytest.mark.parametrize("chunks", (5, 10))
def test_sort_argsort_nocompute(xp, chunks):
with assert_no_compute():
a = xp.arange(10, chunks=chunks)
xp.sort(a)
xp.argsort(a)


def test_generators_are_lazy(xp):
"""
Test that generator functions are fully lazy, e.g. that
@@ -106,3 +118,62 @@ def test_generators_are_lazy(xp):
xp.ones_like(a)
xp.empty_like(a)
xp.full_like(a, fill_value=123)


@pytest.mark.parametrize("axis", [0, 1])
@pytest.mark.parametrize("func", ["sort", "argsort"])
def test_sort_argsort_chunks(xp, func, axis):
"""Test that sort and argsort are functionally correct when
the array is chunked along the sort axis, e.g. the sort is
not just local to each chunk.
"""
a = da.random.random((10, 10), chunks=(5, 5))
actual = getattr(xp, func)(a, axis=axis)
expect = getattr(np, func)(a.compute(), axis=axis)
np.testing.assert_array_equal(actual, expect)


@pytest.mark.parametrize(
"shape,chunks",
[
# 3 GiB; 128 MiB per chunk; must rechunk before sorting.
# Sort chunks can be 128 MiB each; no need for final rechunk.
((20_000, 20_000), "auto"),
# 3 GiB; 128 MiB per chunk; must rechunk before sorting.
# Must sort on two 1.5 GiB chunks; benefits from final rechunk.
((2, 2**30 * 3 // 16), "auto"),
# 3 GiB; 1.5 GiB per chunk; no need to rechunk before sorting.
# Surely the user must know what they're doing, so don't
# perform the final rechunk.
((2, 2**30 * 3 // 16), (1, -1)),
],
)
@pytest.mark.parametrize("func", ["sort", "argsort"])
def test_sort_argsort_chunk_size(xp, func, shape, chunks):
"""
Test that sort and argsort produce reasonably-sized chunks
in the output array, even if they had to go through a singular
huge one to perform the operation.
"""
a = da.random.random(shape, chunks=chunks)
b = getattr(xp, func)(a)
max_chunk_size = max(b.chunks[0]) * max(b.chunks[1]) * b.dtype.itemsize
assert (
max_chunk_size <= 128 * 1024 * 1024 # 128 MiB
or b.chunks == a.chunks
)


@pytest.mark.parametrize("func", ["sort", "argsort"])
def test_sort_argsort_meta(xp, func):
"""Test meta-namespace other than numpy"""
typ = type(array_api_strict.asarray(0))
a = da.random.random(10)
b = a.map_blocks(array_api_strict.asarray)
assert isinstance(b._meta, typ)
c = getattr(xp, func)(b)
assert isinstance(c._meta, typ)
d = c.compute()
# Note: np.sort(array_api_strict.asarray(0)) would return a numpy array
assert isinstance(d, typ)
np.testing.assert_array_equal(d, getattr(np, func)(a.compute()))