Skip to content
forked from pydata/xarray

Commit 5402e88

Browse files
committed
sliding_window_view: add new automatic_rechunk kwarg
Closes pydata#9550 xref pydata#4325
1 parent 0384363 commit 5402e88

File tree

6 files changed

+114
-23
lines changed

6 files changed

+114
-23
lines changed

doc/whats-new.rst

+3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ New Features
2929
- Support lazy grouping by dask arrays, and allow specifying ordered groups with ``UniqueGrouper(labels=["a", "b", "c"])``
3030
(:issue:`2852`, :issue:`757`).
3131
By `Deepak Cherian <https://github.com/dcherian>`_.
32+
- Add new ``automatic_rechunk`` kwarg to :py:meth:`DataArrayRolling.construct` and
33+
:py:meth:`DatasetRolling.construct`. This is only useful on ``dask>=2024.11.0``
34+
(:issue:`9550`). By `Deepak Cherian <https://github.com/dcherian>`_.
3235

3336
Breaking changes
3437
~~~~~~~~~~~~~~~~

xarray/core/duck_array_ops.py

+15-9
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@
3030
transpose,
3131
unravel_index,
3232
)
33-
from numpy.lib.stride_tricks import sliding_window_view # noqa
33+
from numpy.ma import masked_invalid # noqa
3434
from packaging.version import Version
3535
from pandas.api.types import is_extension_array_dtype
3636

37-
from xarray.core import dask_array_ops, dtypes, nputils
37+
from xarray.core import dask_array_compat, dask_array_ops, dtypes, nputils
3838
from xarray.core.options import OPTIONS
3939
from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available
4040
from xarray.namedarray import pycompat
@@ -92,11 +92,12 @@ def _dask_or_eager_func(
9292
name,
9393
eager_module=np,
9494
dask_module="dask.array",
95+
dask_only_kwargs=tuple(),
9596
):
9697
"""Create a function that dispatches to dask for dask array inputs."""
9798

9899
def f(*args, **kwargs):
99-
if any(is_duck_dask_array(a) for a in args):
100+
if dask_available and any(is_duck_dask_array(a) for a in args):
100101
mod = (
101102
import_module(dask_module)
102103
if isinstance(dask_module, str)
@@ -105,6 +106,8 @@ def f(*args, **kwargs):
105106
wrapped = getattr(mod, name)
106107
else:
107108
wrapped = getattr(eager_module, name)
109+
for kwarg in dask_only_kwargs:
110+
kwargs.pop(kwarg)
108111
return wrapped(*args, **kwargs)
109112

110113
return f
@@ -122,6 +125,15 @@ def fail_on_dask_array_input(values, msg=None, func_name=None):
122125
# Requires special-casing because pandas won't automatically dispatch to dask.isnull via NEP-18
123126
pandas_isnull = _dask_or_eager_func("isnull", eager_module=pd, dask_module="dask.array")
124127

128+
# sliding_window_view will not dispatch arbitrary kwargs (automatic_rechunk),
129+
# so we need to hand-code this.
130+
sliding_window_view = _dask_or_eager_func(
131+
"sliding_window_view",
132+
eager_module=np.lib.stride_tricks,
133+
dask_module=dask_array_compat,
134+
dask_only_kwargs=("automatic_rechunk",),
135+
)
136+
125137

126138
def round(array):
127139
xp = get_array_namespace(array)
@@ -170,12 +182,6 @@ def notnull(data):
170182
return ~isnull(data)
171183

172184

173-
# TODO replace with simply np.ma.masked_invalid once numpy/numpy#16022 is fixed
174-
masked_invalid = _dask_or_eager_func(
175-
"masked_invalid", eager_module=np.ma, dask_module="dask.array.ma"
176-
)
177-
178-
179185
def trapz(y, x, axis):
180186
if axis < 0:
181187
axis = y.ndim + axis

xarray/core/rolling.py

+57-11
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
module_available,
2121
)
2222
from xarray.namedarray import pycompat
23+
from xarray.util.deprecation_helpers import _deprecate_positional_args
2324

2425
try:
2526
import bottleneck
@@ -147,7 +148,10 @@ def ndim(self) -> int:
147148
return len(self.dim)
148149

149150
def _reduce_method( # type: ignore[misc]
150-
name: str, fillna: Any, rolling_agg_func: Callable | None = None
151+
name: str,
152+
fillna: Any,
153+
rolling_agg_func: Callable | None = None,
154+
automatic_rechunk: bool = False,
151155
) -> Callable[..., T_Xarray]:
152156
"""Constructs reduction methods built on a numpy reduction function (e.g. sum),
153157
a numbagg reduction function (e.g. move_sum), a bottleneck reduction function
@@ -157,6 +161,8 @@ def _reduce_method( # type: ignore[misc]
157161
_array_reduce. Arguably we could refactor this. But one constraint is that we
158162
need context of xarray options, of the functions each library offers, of
159163
the array (e.g. dtype).
164+
165+
Set automatic_rechunk=True when the reduction method makes a memory copy.
160166
"""
161167
if rolling_agg_func:
162168
array_agg_func = None
@@ -181,6 +187,7 @@ def method(self, keep_attrs=None, **kwargs):
181187
rolling_agg_func=rolling_agg_func,
182188
keep_attrs=keep_attrs,
183189
fillna=fillna,
190+
automatic_rechunk=automatic_rechunk,
184191
**kwargs,
185192
)
186193

@@ -198,16 +205,19 @@ def _mean(self, keep_attrs, **kwargs):
198205

199206
_mean.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name="mean")
200207

201-
argmax = _reduce_method("argmax", dtypes.NINF)
202-
argmin = _reduce_method("argmin", dtypes.INF)
208+
# automatic_rechunk is set to True for reductions that make a copy.
209+
# std, var could be optimized after which we can set it to False
210+
# See #4325
211+
argmax = _reduce_method("argmax", dtypes.NINF, automatic_rechunk=True)
212+
argmin = _reduce_method("argmin", dtypes.INF, automatic_rechunk=True)
203213
max = _reduce_method("max", dtypes.NINF)
204214
min = _reduce_method("min", dtypes.INF)
205215
prod = _reduce_method("prod", 1)
206216
sum = _reduce_method("sum", 0)
207217
mean = _reduce_method("mean", None, _mean)
208-
std = _reduce_method("std", None)
209-
var = _reduce_method("var", None)
210-
median = _reduce_method("median", None)
218+
std = _reduce_method("std", None, automatic_rechunk=True)
219+
var = _reduce_method("var", None, automatic_rechunk=True)
220+
median = _reduce_method("median", None, automatic_rechunk=True)
211221

212222
def _counts(self, keep_attrs: bool | None) -> T_Xarray:
213223
raise NotImplementedError()
@@ -311,12 +321,15 @@ def __iter__(self) -> Iterator[tuple[DataArray, DataArray]]:
311321

312322
yield (label, window)
313323

324+
@_deprecate_positional_args("v2024.11.0")
314325
def construct(
315326
self,
316327
window_dim: Hashable | Mapping[Any, Hashable] | None = None,
328+
*,
317329
stride: int | Mapping[Any, int] = 1,
318330
fill_value: Any = dtypes.NA,
319331
keep_attrs: bool | None = None,
332+
automatic_rechunk: bool = True,
320333
**window_dim_kwargs: Hashable,
321334
) -> DataArray:
322335
"""
@@ -335,6 +348,10 @@ def construct(
335348
If True, the attributes (``attrs``) will be copied from the original
336349
object to the new one. If False, the new object will be returned
337350
without attributes. If None uses the global default.
351+
automatic_rechunk: bool, default True
352+
Whether dask should automatically rechunk the output to avoid
353+
exploding chunk sizes. Importantly, each chunk will be a view of the data
354+
so large chunk sizes are only safe if *no* copies are made later.
338355
**window_dim_kwargs : Hashable, optional
339356
The keyword arguments form of ``window_dim`` {dim: new_name, ...}.
340357
@@ -383,16 +400,19 @@ def construct(
383400
stride=stride,
384401
fill_value=fill_value,
385402
keep_attrs=keep_attrs,
403+
automatic_rechunk=automatic_rechunk,
386404
**window_dim_kwargs,
387405
)
388406

389407
def _construct(
390408
self,
391409
obj: DataArray,
410+
*,
392411
window_dim: Hashable | Mapping[Any, Hashable] | None = None,
393412
stride: int | Mapping[Any, int] = 1,
394413
fill_value: Any = dtypes.NA,
395414
keep_attrs: bool | None = None,
415+
automatic_rechunk: bool = True,
396416
**window_dim_kwargs: Hashable,
397417
) -> DataArray:
398418
from xarray.core.dataarray import DataArray
@@ -412,7 +432,12 @@ def _construct(
412432
strides = self._mapping_to_list(stride, default=1)
413433

414434
window = obj.variable.rolling_window(
415-
self.dim, self.window, window_dims, self.center, fill_value=fill_value
435+
self.dim,
436+
self.window,
437+
window_dims,
438+
center=self.center,
439+
fill_value=fill_value,
440+
automatic_rechunk=automatic_rechunk,
416441
)
417442

418443
attrs = obj.attrs if keep_attrs else {}
@@ -429,10 +454,16 @@ def _construct(
429454
)
430455

431456
def reduce(
432-
self, func: Callable, keep_attrs: bool | None = None, **kwargs: Any
457+
self,
458+
func: Callable,
459+
keep_attrs: bool | None = None,
460+
*,
461+
automatic_rechunk: bool = True,
462+
**kwargs: Any,
433463
) -> DataArray:
434-
"""Reduce the items in this group by applying `func` along some
435-
dimension(s).
464+
"""Reduce each window by applying `func`.
465+
466+
Equivalent to ``.construct(...).reduce(func, ...)``.
436467
437468
Parameters
438469
----------
@@ -444,6 +475,10 @@ def reduce(
444475
If True, the attributes (``attrs``) will be copied from the original
445476
object to the new one. If False, the new object will be returned
446477
without attributes. If None uses the global default.
478+
automatic_rechunk: bool, default True
479+
Whether dask should automatically rechunk the output of ``construct`` to avoid
480+
exploding chunk sizes. Importantly, each chunk will be a view of the data
481+
so large chunk sizes are only safe if *no* copies are made in ``func``.
447482
**kwargs : dict
448483
Additional keyword arguments passed on to `func`.
449484
@@ -497,7 +532,11 @@ def reduce(
497532
else:
498533
obj = self.obj
499534
windows = self._construct(
500-
obj, rolling_dim, keep_attrs=keep_attrs, fill_value=fillna
535+
obj,
536+
window_dim=rolling_dim,
537+
keep_attrs=keep_attrs,
538+
fill_value=fillna,
539+
automatic_rechunk=automatic_rechunk,
501540
)
502541

503542
dim = list(rolling_dim.values())
@@ -821,12 +860,15 @@ def _array_reduce(
821860
**kwargs,
822861
)
823862

863+
@_deprecate_positional_args("v2024.11.0")
824864
def construct(
825865
self,
826866
window_dim: Hashable | Mapping[Any, Hashable] | None = None,
867+
*,
827868
stride: int | Mapping[Any, int] = 1,
828869
fill_value: Any = dtypes.NA,
829870
keep_attrs: bool | None = None,
871+
automatic_rechunk: bool = True,
830872
**window_dim_kwargs: Hashable,
831873
) -> Dataset:
832874
"""
@@ -842,6 +884,10 @@ def construct(
842884
size of stride for the rolling window.
843885
fill_value : Any, default: dtypes.NA
844886
Filling value to match the dimension size.
887+
automatic_rechunk: bool, default True
888+
Whether dask should automatically rechunk the output to avoid
889+
exploding chunk sizes. Importantly, each chunk will be a view of the data
890+
so large chunk sizes are only safe if *no* copies are made later.
845891
**window_dim_kwargs : {dim: new_name, ...}, optional
846892
The keyword arguments form of ``window_dim``.
847893

xarray/core/variable.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
)
4747
from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions
4848
from xarray.namedarray.pycompat import integer_types, is_0d_dask_array, to_duck_array
49-
from xarray.util.deprecation_helpers import deprecate_dims
49+
from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims
5050

5151
NON_NUMPY_SUPPORTED_ARRAY_TYPES = (
5252
indexing.ExplicitlyIndexed,
@@ -2010,8 +2010,16 @@ def rank(self, dim, pct=False):
20102010
ranked /= count
20112011
return ranked
20122012

2013+
@_deprecate_positional_args("v2024.11.0")
20132014
def rolling_window(
2014-
self, dim, window, window_dim, center=False, fill_value=dtypes.NA
2015+
self,
2016+
dim,
2017+
window,
2018+
window_dim,
2019+
*,
2020+
center=False,
2021+
fill_value=dtypes.NA,
2022+
automatic_rechunk: bool = True,
20152023
):
20162024
"""
20172025
Make a rolling_window along dim and add a new_dim to the last place.
@@ -2032,6 +2040,10 @@ def rolling_window(
20322040
of the axis.
20332041
fill_value
20342042
value to be filled.
2043+
automatic_rechunk: bool, default True
2044+
Whether dask should automatically rechunk the output to avoid
2045+
exploding chunk sizes. Importantly, each chunk will be a view of the data
2046+
so large chunk sizes are only safe if *no* copies are made later.
20352047
20362048
Returns
20372049
-------
@@ -2120,7 +2132,10 @@ def rolling_window(
21202132
return Variable(
21212133
new_dims,
21222134
duck_array_ops.sliding_window_view(
2123-
padded.data, window_shape=window, axis=axis
2135+
padded.data,
2136+
window_shape=window,
2137+
axis=axis,
2138+
automatic_rechunk=automatic_rechunk,
21242139
),
21252140
)
21262141

xarray/tests/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def _importorskip(
107107
has_h5netcdf, requires_h5netcdf = _importorskip("h5netcdf")
108108
has_cftime, requires_cftime = _importorskip("cftime")
109109
has_dask, requires_dask = _importorskip("dask")
110+
has_dask_ge_2024_11_0, requires_dask_ge_2024_11_0 = _importorskip("dask", "2024.11.0")
110111
with warnings.catch_warnings():
111112
warnings.filterwarnings(
112113
"ignore",

xarray/tests/test_rolling.py

+20
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
assert_identical,
1515
has_dask,
1616
requires_dask,
17+
requires_dask_ge_2024_11_0,
1718
requires_numbagg,
1819
)
1920

@@ -598,6 +599,25 @@ def test_rolling_properties(self, ds) -> None:
598599
):
599600
ds.rolling(foo=2)
600601

602+
@requires_dask_ge_2024_11_0
603+
def test_rolling_construct_automatic_rechunk():
604+
import dask
605+
606+
# Construct dataset with chunk size of (400, 400, 1) or 1.22 MiB
607+
da = DataArray(
608+
dims=["latitude", "longitude", "time"],
609+
data=dask.array.random.random((400, 400, 400), chunks=(-1, -1, 1)),
610+
)
611+
612+
# Dataset now has chunks of size (400, 400, 100, 100) or 11.92 GiB
613+
rechunked = da.rolling(time=100, center=True).construct(
614+
"window", automatic_rechunk=True
615+
)
616+
not_rechunked = da.rolling(time=100, center=True).construct(
617+
"window", automatic_rechunk=False
618+
)
619+
assert rechunked.chunks != not_rechunked.chunks
620+
601621
@pytest.mark.parametrize(
602622
"name", ("sum", "mean", "std", "var", "min", "max", "median")
603623
)

0 commit comments

Comments
 (0)