Skip to content
forked from pydata/xarray

Commit 1bdd60b

Browse files
committed
sliding_window_view: add new automatic_rechunk kwarg
Closes pydata#9550 xref pydata#4325
1 parent 0384363 commit 1bdd60b

File tree

7 files changed

+141
-23
lines changed

7 files changed

+141
-23
lines changed

doc/whats-new.rst

+3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ New Features
2929
- Support lazy grouping by dask arrays, and allow specifying ordered groups with ``UniqueGrouper(labels=["a", "b", "c"])``
3030
(:issue:`2852`, :issue:`757`).
3131
By `Deepak Cherian <https://github.com/dcherian>`_.
32+
- Add new ``automatic_rechunk`` kwarg to :py:meth:`DataArrayRolling.construct` and
33+
:py:meth:`DatasetRolling.construct`. This is only useful on ``dask>=2024.11.0``
34+
(:issue:`9550`). By `Deepak Cherian <https://github.com/dcherian>`_.
3235

3336
Breaking changes
3437
~~~~~~~~~~~~~~~~

xarray/core/dask_array_compat.py

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from xarray.namedarray.utils import module_available
2+
3+
4+
def sliding_window_view(
    x, window_shape, axis=None, *, automatic_rechunk=True, **kwargs
):
    """Dispatch to dask's ``sliding_window_view``, forwarding ``automatic_rechunk``
    only when the installed dask version understands it.

    Backcompat shim: delete once xarray requires dask>=2024.11.0.
    Extra ``**kwargs`` (e.g. ``subok``, ``writeable``) are accepted for numpy
    signature compatibility but ignored — dask does not support them.
    """
    from dask.array.lib.stride_tricks import sliding_window_view as dask_sliding_window_view

    if not module_available("dask", "2024.11.0"):
        # Older dask: the automatic_rechunk kwarg does not exist yet, so drop it.
        return dask_sliding_window_view(x, window_shape=window_shape, axis=axis)

    return dask_sliding_window_view(
        x,
        window_shape=window_shape,
        axis=axis,
        automatic_rechunk=automatic_rechunk,
    )

xarray/core/duck_array_ops.py

+15-9
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@
3030
transpose,
3131
unravel_index,
3232
)
33-
from numpy.lib.stride_tricks import sliding_window_view # noqa
33+
from numpy.ma import masked_invalid # noqa
3434
from packaging.version import Version
3535
from pandas.api.types import is_extension_array_dtype
3636

37-
from xarray.core import dask_array_ops, dtypes, nputils
37+
from xarray.core import dask_array_compat, dask_array_ops, dtypes, nputils
3838
from xarray.core.options import OPTIONS
3939
from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available
4040
from xarray.namedarray import pycompat
@@ -92,11 +92,12 @@ def _dask_or_eager_func(
9292
name,
9393
eager_module=np,
9494
dask_module="dask.array",
95+
dask_only_kwargs=tuple(),
9596
):
9697
"""Create a function that dispatches to dask for dask array inputs."""
9798

9899
def f(*args, **kwargs):
99-
if any(is_duck_dask_array(a) for a in args):
100+
if dask_available and any(is_duck_dask_array(a) for a in args):
100101
mod = (
101102
import_module(dask_module)
102103
if isinstance(dask_module, str)
@@ -105,6 +106,8 @@ def f(*args, **kwargs):
105106
wrapped = getattr(mod, name)
106107
else:
107108
wrapped = getattr(eager_module, name)
109+
for kwarg in dask_only_kwargs:
110+
kwargs.pop(kwarg)
108111
return wrapped(*args, **kwargs)
109112

110113
return f
@@ -122,6 +125,15 @@ def fail_on_dask_array_input(values, msg=None, func_name=None):
122125
# Requires special-casing because pandas won't automatically dispatch to dask.isnull via NEP-18
123126
pandas_isnull = _dask_or_eager_func("isnull", eager_module=pd, dask_module="dask.array")
124127

128+
# sliding_window_view will not dispatch arbitrary kwargs (automatic_rechunk),
# so we need to hand-code this.
# Dispatches to numpy's stride_tricks for eager arrays and to the
# dask_array_compat shim (which handles the dask>=2024.11.0 version gate)
# for dask arrays; `automatic_rechunk` is stripped before the eager call
# because numpy's sliding_window_view does not accept it.
sliding_window_view = _dask_or_eager_func(
    "sliding_window_view",
    eager_module=np.lib.stride_tricks,
    dask_module=dask_array_compat,
    dask_only_kwargs=("automatic_rechunk",),
)
136+
125137

126138
def round(array):
127139
xp = get_array_namespace(array)
@@ -170,12 +182,6 @@ def notnull(data):
170182
return ~isnull(data)
171183

172184

173-
# TODO replace with simply np.ma.masked_invalid once numpy/numpy#16022 is fixed
174-
masked_invalid = _dask_or_eager_func(
175-
"masked_invalid", eager_module=np.ma, dask_module="dask.array.ma"
176-
)
177-
178-
179185
def trapz(y, x, axis):
180186
if axis < 0:
181187
axis = y.ndim + axis

xarray/core/rolling.py

+67-11
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
module_available,
2121
)
2222
from xarray.namedarray import pycompat
23+
from xarray.util.deprecation_helpers import _deprecate_positional_args
2324

2425
try:
2526
import bottleneck
@@ -147,7 +148,10 @@ def ndim(self) -> int:
147148
return len(self.dim)
148149

149150
def _reduce_method( # type: ignore[misc]
150-
name: str, fillna: Any, rolling_agg_func: Callable | None = None
151+
name: str,
152+
fillna: Any,
153+
rolling_agg_func: Callable | None = None,
154+
automatic_rechunk: bool = False,
151155
) -> Callable[..., T_Xarray]:
152156
"""Constructs reduction methods built on a numpy reduction function (e.g. sum),
153157
a numbagg reduction function (e.g. move_sum), a bottleneck reduction function
@@ -157,6 +161,8 @@ def _reduce_method( # type: ignore[misc]
157161
_array_reduce. Arguably we could refactor this. But one constraint is that we
158162
need context of xarray options, of the functions each library offers, of
159163
the array (e.g. dtype).
164+
165+
Set automatic_rechunk=True when the reduction method makes a memory copy.
160166
"""
161167
if rolling_agg_func:
162168
array_agg_func = None
@@ -181,6 +187,7 @@ def method(self, keep_attrs=None, **kwargs):
181187
rolling_agg_func=rolling_agg_func,
182188
keep_attrs=keep_attrs,
183189
fillna=fillna,
190+
automatic_rechunk=automatic_rechunk,
184191
**kwargs,
185192
)
186193

@@ -198,16 +205,19 @@ def _mean(self, keep_attrs, **kwargs):
198205

199206
_mean.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name="mean")
200207

201-
argmax = _reduce_method("argmax", dtypes.NINF)
202-
argmin = _reduce_method("argmin", dtypes.INF)
208+
# automatic_rechunk is set to True for reductions that make a copy.
209+
# std, var could be optimized after which we can set it to False
210+
# See #4325
211+
argmax = _reduce_method("argmax", dtypes.NINF, automatic_rechunk=True)
212+
argmin = _reduce_method("argmin", dtypes.INF, automatic_rechunk=True)
203213
max = _reduce_method("max", dtypes.NINF)
204214
min = _reduce_method("min", dtypes.INF)
205215
prod = _reduce_method("prod", 1)
206216
sum = _reduce_method("sum", 0)
207217
mean = _reduce_method("mean", None, _mean)
208-
std = _reduce_method("std", None)
209-
var = _reduce_method("var", None)
210-
median = _reduce_method("median", None)
218+
std = _reduce_method("std", None, automatic_rechunk=True)
219+
var = _reduce_method("var", None, automatic_rechunk=True)
220+
median = _reduce_method("median", None, automatic_rechunk=True)
211221

212222
def _counts(self, keep_attrs: bool | None) -> T_Xarray:
213223
raise NotImplementedError()
@@ -311,12 +321,15 @@ def __iter__(self) -> Iterator[tuple[DataArray, DataArray]]:
311321

312322
yield (label, window)
313323

324+
@_deprecate_positional_args("v2024.11.0")
314325
def construct(
315326
self,
316327
window_dim: Hashable | Mapping[Any, Hashable] | None = None,
328+
*,
317329
stride: int | Mapping[Any, int] = 1,
318330
fill_value: Any = dtypes.NA,
319331
keep_attrs: bool | None = None,
332+
automatic_rechunk: bool = True,
320333
**window_dim_kwargs: Hashable,
321334
) -> DataArray:
322335
"""
@@ -335,6 +348,10 @@ def construct(
335348
If True, the attributes (``attrs``) will be copied from the original
336349
object to the new one. If False, the new object will be returned
337350
without attributes. If None uses the global default.
351+
automatic_rechunk : bool, default: True
352+
Whether dask should automatically rechunk the output to avoid
353+
exploding chunk sizes. Importantly, each chunk will be a view of the data
354+
so large chunk sizes are only safe if *no* copies are made later.
338355
**window_dim_kwargs : Hashable, optional
339356
The keyword arguments form of ``window_dim`` {dim: new_name, ...}.
340357
@@ -343,6 +360,11 @@ def construct(
343360
DataArray that is a view of the original array. The returned array is
344361
not writeable.
345362
363+
See Also
364+
--------
365+
numpy.lib.stride_tricks.sliding_window_view
366+
dask.array.lib.stride_tricks.sliding_window_view
367+
346368
Examples
347369
--------
348370
>>> da = xr.DataArray(np.arange(8).reshape(2, 4), dims=("a", "b"))
@@ -383,16 +405,19 @@ def construct(
383405
stride=stride,
384406
fill_value=fill_value,
385407
keep_attrs=keep_attrs,
408+
automatic_rechunk=automatic_rechunk,
386409
**window_dim_kwargs,
387410
)
388411

389412
def _construct(
390413
self,
391414
obj: DataArray,
415+
*,
392416
window_dim: Hashable | Mapping[Any, Hashable] | None = None,
393417
stride: int | Mapping[Any, int] = 1,
394418
fill_value: Any = dtypes.NA,
395419
keep_attrs: bool | None = None,
420+
automatic_rechunk: bool = True,
396421
**window_dim_kwargs: Hashable,
397422
) -> DataArray:
398423
from xarray.core.dataarray import DataArray
@@ -412,7 +437,12 @@ def _construct(
412437
strides = self._mapping_to_list(stride, default=1)
413438

414439
window = obj.variable.rolling_window(
415-
self.dim, self.window, window_dims, self.center, fill_value=fill_value
440+
self.dim,
441+
self.window,
442+
window_dims,
443+
center=self.center,
444+
fill_value=fill_value,
445+
automatic_rechunk=automatic_rechunk,
416446
)
417447

418448
attrs = obj.attrs if keep_attrs else {}
@@ -429,10 +459,16 @@ def _construct(
429459
)
430460

431461
def reduce(
432-
self, func: Callable, keep_attrs: bool | None = None, **kwargs: Any
462+
self,
463+
func: Callable,
464+
keep_attrs: bool | None = None,
465+
*,
466+
automatic_rechunk: bool = True,
467+
**kwargs: Any,
433468
) -> DataArray:
434-
"""Reduce the items in this group by applying `func` along some
435-
dimension(s).
469+
"""Reduce each window by applying `func`.
470+
471+
Equivalent to ``.construct(...).reduce(func, ...)``.
436472
437473
Parameters
438474
----------
@@ -444,6 +480,10 @@ def reduce(
444480
If True, the attributes (``attrs``) will be copied from the original
445481
object to the new one. If False, the new object will be returned
446482
without attributes. If None uses the global default.
483+
automatic_rechunk : bool, default: True
484+
Whether dask should automatically rechunk the output of ``construct`` to avoid
485+
exploding chunk sizes. Importantly, each chunk will be a view of the data
486+
so large chunk sizes are only safe if *no* copies are made in ``func``.
447487
**kwargs : dict
448488
Additional keyword arguments passed on to `func`.
449489
@@ -497,7 +537,11 @@ def reduce(
497537
else:
498538
obj = self.obj
499539
windows = self._construct(
500-
obj, rolling_dim, keep_attrs=keep_attrs, fill_value=fillna
540+
obj,
541+
window_dim=rolling_dim,
542+
keep_attrs=keep_attrs,
543+
fill_value=fillna,
544+
automatic_rechunk=automatic_rechunk,
501545
)
502546

503547
dim = list(rolling_dim.values())
@@ -821,12 +865,15 @@ def _array_reduce(
821865
**kwargs,
822866
)
823867

868+
@_deprecate_positional_args("v2024.11.0")
824869
def construct(
825870
self,
826871
window_dim: Hashable | Mapping[Any, Hashable] | None = None,
872+
*,
827873
stride: int | Mapping[Any, int] = 1,
828874
fill_value: Any = dtypes.NA,
829875
keep_attrs: bool | None = None,
876+
automatic_rechunk: bool = True,
830877
**window_dim_kwargs: Hashable,
831878
) -> Dataset:
832879
"""
@@ -842,12 +889,21 @@ def construct(
842889
size of stride for the rolling window.
843890
fill_value : Any, default: dtypes.NA
844891
Filling value to match the dimension size.
892+
automatic_rechunk : bool, default: True
893+
Whether dask should automatically rechunk the output to avoid
894+
exploding chunk sizes. Importantly, each chunk will be a view of the data
895+
so large chunk sizes are only safe if *no* copies are made later.
845896
**window_dim_kwargs : {dim: new_name, ...}, optional
846897
The keyword arguments form of ``window_dim``.
847898
848899
Returns
849900
-------
850901
Dataset with variables converted from rolling object.
902+
903+
See Also
904+
--------
905+
numpy.lib.stride_tricks.sliding_window_view
906+
dask.array.lib.stride_tricks.sliding_window_view
851907
"""
852908

853909
from xarray.core.dataset import Dataset

xarray/core/variable.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
)
4747
from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions
4848
from xarray.namedarray.pycompat import integer_types, is_0d_dask_array, to_duck_array
49-
from xarray.util.deprecation_helpers import deprecate_dims
49+
from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims
5050

5151
NON_NUMPY_SUPPORTED_ARRAY_TYPES = (
5252
indexing.ExplicitlyIndexed,
@@ -2010,8 +2010,16 @@ def rank(self, dim, pct=False):
20102010
ranked /= count
20112011
return ranked
20122012

2013+
@_deprecate_positional_args("v2024.11.0")
20132014
def rolling_window(
2014-
self, dim, window, window_dim, center=False, fill_value=dtypes.NA
2015+
self,
2016+
dim,
2017+
window,
2018+
window_dim,
2019+
*,
2020+
center=False,
2021+
fill_value=dtypes.NA,
2022+
automatic_rechunk: bool = True,
20152023
):
20162024
"""
20172025
Make a rolling_window along dim and add a new_dim to the last place.
@@ -2032,6 +2040,10 @@ def rolling_window(
20322040
of the axis.
20332041
fill_value
20342042
value to be filled.
2043+
automatic_rechunk : bool, default: True
2044+
Whether dask should automatically rechunk the output to avoid
2045+
exploding chunk sizes. Importantly, each chunk will be a view of the data
2046+
so large chunk sizes are only safe if *no* copies are made later.
20352047
20362048
Returns
20372049
-------
@@ -2120,7 +2132,10 @@ def rolling_window(
21202132
return Variable(
21212133
new_dims,
21222134
duck_array_ops.sliding_window_view(
2123-
padded.data, window_shape=window, axis=axis
2135+
padded.data,
2136+
window_shape=window,
2137+
axis=axis,
2138+
automatic_rechunk=automatic_rechunk,
21242139
),
21252140
)
21262141

xarray/tests/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def _importorskip(
107107
has_h5netcdf, requires_h5netcdf = _importorskip("h5netcdf")
108108
has_cftime, requires_cftime = _importorskip("cftime")
109109
has_dask, requires_dask = _importorskip("dask")
110+
has_dask_ge_2024_11_0, requires_dask_ge_2024_11_0 = _importorskip("dask", "2024.11.0")
110111
with warnings.catch_warnings():
111112
warnings.filterwarnings(
112113
"ignore",

0 commit comments

Comments
 (0)