Skip to content

Commit 830ee6d

Browse files
authored
Support first, last with dask arrays (#7562)
* Support first, last with dask arrays.
  Use dask.array.reduction. For this we need to add support for the `keepdims` kwarg to `nanfirst` and `nanlast`. Even though the final result is always keepdims=False, dask runs the intermediate steps with keepdims=True.
* Don't provide meta. It would need to account for shape change.
1 parent 43ba095 commit 830ee6d

File tree

6 files changed

+101
-21
lines changed

6 files changed

+101
-21
lines changed

doc/whats-new.rst

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ New Features
2525

2626
- :py:meth:`xr.cov` and :py:meth:`xr.corr` now support complex-valued arrays (:issue:`7340`, :pull:`7392`).
2727
By `Michael Niklas <https://github.com/headtr1ck>`_.
28+
- Support dask arrays in ``first`` and ``last`` reductions.
29+
By `Deepak Cherian <https://github.com/dcherian>`_.
2830

2931
Breaking changes
3032
~~~~~~~~~~~~~~~~

xarray/core/dask_array_ops.py

+37
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
from __future__ import annotations
22

3+
from functools import partial
4+
5+
from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined]
6+
37
from xarray.core import dtypes, nputils
48

59

@@ -92,3 +96,36 @@ def _fill_with_last_one(a, b):
9296
axis=axis,
9397
dtype=array.dtype,
9498
)
99+
100+
101+
def _first_last_wrapper(array, *, axis, op, keepdims):
102+
return op(array, axis, keepdims=keepdims)
103+
104+
105+
def _first_or_last(darray, axis, op):
    """Build a lazy ``dask.array.reduction`` that applies ``op`` along ``axis``.

    Parameters
    ----------
    darray
        Array to reduce; its ``ndim`` and ``dtype`` attributes are read here.
    axis : int
        Axis to reduce over. It is validated against ``darray.ndim``.
    op : callable
        Reduction called as ``op(block, axis, keepdims=...)``; used for both
        the per-chunk step and the aggregate step.

    Returns
    -------
    The object returned by ``dask.array.reduction``.
    """
    import dask.array

    # This will raise the same error message seen for numpy
    axis = normalize_axis_index(axis, darray.ndim)

    # The same callable serves both stages. It must honour ``keepdims``
    # because dask runs the intermediate steps with keepdims=True even though
    # the final result is requested with keepdims=False.
    wrapped_op = partial(_first_last_wrapper, op=op)
    # ``meta`` is deliberately not provided: it would need to account for the
    # shape change caused by the reduction.
    return dask.array.reduction(
        darray,
        chunk=wrapped_op,
        aggregate=wrapped_op,
        axis=axis,
        dtype=darray.dtype,
        keepdims=False,  # match numpy version
    )
120+
121+
122+
def nanfirst(darray, axis):
    """Reduce ``darray`` along ``axis`` with ``duck_array_ops.nanfirst`` via ``_first_or_last``.

    Parameters
    ----------
    darray
        Array passed through to ``_first_or_last``.
    axis
        Axis passed through to ``_first_or_last``.
    """
    # Imported inside the function; inside this body the name ``nanfirst``
    # therefore refers to the duck_array_ops implementation, not to this wrapper.
    # NOTE(review): presumably local to avoid a circular import — confirm.
    from xarray.core.duck_array_ops import nanfirst

    return _first_or_last(darray, axis, op=nanfirst)
126+
127+
128+
def nanlast(darray, axis):
    """Reduce ``darray`` along ``axis`` with ``duck_array_ops.nanlast`` via ``_first_or_last``.

    Parameters
    ----------
    darray
        Array passed through to ``_first_or_last``.
    axis
        Axis passed through to ``_first_or_last``.
    """
    # Imported inside the function; inside this body the name ``nanlast``
    # therefore refers to the duck_array_ops implementation, not to this wrapper.
    # NOTE(review): presumably local to avoid a circular import — confirm.
    from xarray.core.duck_array_ops import nanlast

    return _first_or_last(darray, axis, op=nanlast)

xarray/core/duck_array_ops.py

+8-11
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import datetime
1010
import inspect
1111
import warnings
12-
from functools import partial
1312
from importlib import import_module
1413

1514
import numpy as np
@@ -637,27 +636,25 @@ def cumsum(array, axis=None, **kwargs):
637636
return _nd_cum_func(cumsum_1d, array, axis, **kwargs)
638637

639638

640-
_fail_on_dask_array_input_skipna = partial(
641-
fail_on_dask_array_input,
642-
msg="%r with skipna=True is not yet implemented on dask arrays",
643-
)
644-
645-
646639
def first(values, axis, skipna=None):
    """Return the first non-NA elements in this array along the given axis

    Parameters
    ----------
    values
        Array to reduce; its ``dtype.kind`` decides whether NA-skipping is attempted.
    axis
        Axis along which to pick the element.
    skipna : bool or None, optional
        ``None`` is treated like ``True``. When falsy (and not ``None``) the
        element at position 0 is taken regardless of missing values.
    """
    if (skipna or skipna is None) and values.dtype.kind not in "iSU":
        # only bother for dtypes that can hold NaN
        if is_duck_dask_array(values):
            # dask input goes through the dask-specific implementation
            return dask_array_ops.nanfirst(values, axis)
        return nanfirst(values, axis)
    return take(values, 0, axis=axis)
653648

654649

655650
def last(values, axis, skipna=None):
    """Return the last non-NA elements in this array along the given axis

    Parameters
    ----------
    values
        Array to reduce; its ``dtype.kind`` decides whether NA-skipping is attempted.
    axis
        Axis along which to pick the element.
    skipna : bool or None, optional
        ``None`` is treated like ``True``. When falsy (and not ``None``) the
        element at position -1 is taken regardless of missing values.
    """
    if (skipna or skipna is None) and values.dtype.kind not in "iSU":
        # only bother for dtypes that can hold NaN
        if is_duck_dask_array(values):
            # dask input goes through the dask-specific implementation
            return dask_array_ops.nanlast(values, axis)
        return nanlast(values, axis)
    return take(values, -1, axis=axis)
662659

663660

xarray/core/nputils.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,29 @@ def _select_along_axis(values, idx, axis):
2424
return values[sl]
2525

2626

27-
def nanfirst(values, axis, keepdims=False):
    """Select the first non-null element of ``values`` along ``axis``.

    Parameters
    ----------
    values
        Array to reduce.
    axis : int or 1-tuple of int
        Axis to reduce over. A tuple must hold exactly one axis; unpacking a
        tuple of any other length raises ``ValueError``.
    keepdims : bool, optional
        If true, the reduced axis is re-inserted with length one.

    Returns
    -------
    The selected elements, with ``axis`` removed unless ``keepdims`` is true.
    """
    if isinstance(axis, tuple):
        (axis,) = axis
    axis = normalize_axis_index(axis, values.ndim)
    # argmax over the "not null" mask gives the position of the first
    # non-null entry (index 0 when every entry along the axis is null).
    idx_first = np.argmax(~pd.isnull(values), axis=axis)
    result = _select_along_axis(values, idx_first, axis)
    if keepdims:
        return np.expand_dims(result, axis=axis)
    return result
3137

3238

33-
def nanlast(values, axis, keepdims=False):
    """Select the last non-null element of ``values`` along ``axis``.

    Parameters
    ----------
    values
        Array to reduce.
    axis : int or 1-tuple of int
        Axis to reduce over. A tuple must hold exactly one axis; unpacking a
        tuple of any other length raises ``ValueError``.
    keepdims : bool, optional
        If true, the reduced axis is re-inserted with length one.

    Returns
    -------
    The selected elements, with ``axis`` removed unless ``keepdims`` is true.
    """
    if isinstance(axis, tuple):
        (axis,) = axis
    axis = normalize_axis_index(axis, values.ndim)
    # Indexer that reverses ``values`` along ``axis`` only.
    rev = (slice(None),) * axis + (slice(None, None, -1),)
    # argmax on the reversed mask counts from the end; ``-1 - k`` converts
    # that count into a negative index into the un-reversed array.
    idx_last = -1 - np.argmax(~pd.isnull(values)[rev], axis=axis)
    result = _select_along_axis(values, idx_last, axis)
    if keepdims:
        return np.expand_dims(result, axis=axis)
    return result
3850

3951

4052
def inverse_permutation(indices):

xarray/tests/test_dask.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -549,17 +549,22 @@ def test_rolling(self):
549549
actual = v.rolling(x=2).mean()
550550
self.assertLazyAndAllClose(expected, actual)
551551

552-
@pytest.mark.parametrize("func", ["first", "last"])
def test_groupby_first_last(self, func):
    """Groupby ``first``/``last`` on a lazy array stays lazy and matches eager."""
    method = operator.methodcaller(func)
    u = self.eager_array
    v = self.lazy_array

    for coords in [u.coords, v.coords]:
        coords["ab"] = ("x", ["a", "a", "b", "b"])
    expected = method(u.groupby("ab"))

    # Default ``skipna``.
    with raise_if_dask_computes():
        actual = method(v.groupby("ab"))
    self.assertLazyAndAllClose(expected, actual)

    # Explicit ``skipna=False``, so this stanza is not a duplicate of the one
    # above. ``expected`` was computed with the default ``skipna``.
    # NOTE(review): assumes the fixture data has no missing values, so both
    # settings agree — confirm against the fixture.
    method_no_skipna = operator.methodcaller(func, skipna=False)
    with raise_if_dask_computes():
        actual = method_no_skipna(v.groupby("ab"))
    self.assertLazyAndAllClose(expected, actual)
564569

565570
def test_reindex(self):

xarray/tests/test_duck_array_ops.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,11 @@ class TestOps:
4848
def setUp(self):
4949
self.x = array(
5050
[
51-
[[nan, nan, 2.0, nan], [nan, 5.0, 6.0, nan], [8.0, 9.0, 10.0, nan]],
51+
[
52+
[nan, nan, 2.0, nan],
53+
[nan, 5.0, 6.0, nan],
54+
[8.0, 9.0, 10.0, nan],
55+
],
5256
[
5357
[nan, 13.0, 14.0, 15.0],
5458
[nan, 17.0, 18.0, nan],
@@ -128,6 +132,29 @@ def test_all_nan_arrays(self):
128132
assert np.isnan(mean([np.nan, np.nan]))
129133

130134

135+
@requires_dask
class TestDaskOps(TestOps):
    """Variant of ``TestOps`` whose ``self.x`` is a chunked dask array."""

    @pytest.fixture(autouse=True)
    def setUp(self):
        """Assign ``self.x`` a dask array built from a nested list with ``chunks=(2, 1, 2)``."""
        import dask.array

        self.x = dask.array.from_array(
            [
                [
                    [nan, nan, 2.0, nan],
                    [nan, 5.0, 6.0, nan],
                    [8.0, 9.0, 10.0, nan],
                ],
                [
                    [nan, 13.0, 14.0, 15.0],
                    [nan, 17.0, 18.0, nan],
                    [nan, 21.0, nan, nan],
                ],
            ],
            chunks=(2, 1, 2),
        )
156+
157+
131158
def test_cumsum_1d():
132159
inputs = np.array([0, 1, 2, 3])
133160
expected = np.array([0, 1, 3, 6])

0 commit comments

Comments
 (0)