Skip to content
forked from pydata/xarray

Commit 418a5a5

Browse files
committed
Switch to sliding_window_kwargs
1 parent 88e56d8 commit 418a5a5

File tree

5 files changed

+70
-34
lines changed

5 files changed

+70
-34
lines changed

xarray/core/dask_array_compat.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ def sliding_window_view(
55
x, window_shape, axis=None, *, automatic_rechunk=True, **kwargs
66
):
77
# Backcompat for handling `automatic_rechunk`, delete when dask>=2024.11.0
8-
# subok, writeable are unsupported by dask
8+
# Note that subok, writeable are unsupported by dask, so we ignore those in kwargs
99
from dask.array.lib.stride_tricks import sliding_window_view
1010

1111
if module_available("dask", "2024.11.0"):

xarray/core/duck_array_ops.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
transpose,
3131
unravel_index,
3232
)
33-
from numpy.ma import masked_invalid # noqa
3433
from packaging.version import Version
3534
from pandas.api.types import is_extension_array_dtype
3635

@@ -107,7 +106,7 @@ def f(*args, **kwargs):
107106
else:
108107
wrapped = getattr(eager_module, name)
109108
for kwarg in dask_only_kwargs:
110-
kwargs.pop(kwarg)
109+
kwargs.pop(kwarg, None)
111110
return wrapped(*args, **kwargs)
112111

113112
return f
@@ -125,6 +124,12 @@ def fail_on_dask_array_input(values, msg=None, func_name=None):
125124
# Requires special-casing because pandas won't automatically dispatch to dask.isnull via NEP-18
126125
pandas_isnull = _dask_or_eager_func("isnull", eager_module=pd, dask_module="dask.array")
127126

127+
# TODO replace with simply np.ma.masked_invalid once numpy/numpy#16022 is fixed
128+
# TODO: replacing breaks iris + dask tests
129+
masked_invalid = _dask_or_eager_func(
130+
"masked_invalid", eager_module=np.ma, dask_module="dask.array.ma"
131+
)
132+
128133
# sliding_window_view will not dispatch arbitrary kwargs (automatic_rechunk),
129134
# so we need to hand-code this.
130135
sliding_window_view = _dask_or_eager_func(

xarray/core/rolling.py

+50-20
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def method(self, keep_attrs=None, **kwargs):
187187
rolling_agg_func=rolling_agg_func,
188188
keep_attrs=keep_attrs,
189189
fillna=fillna,
190-
automatic_rechunk=automatic_rechunk,
190+
sliding_window_kwargs=dict(automatic_rechunk=automatic_rechunk),
191191
**kwargs,
192192
)
193193

@@ -329,7 +329,7 @@ def construct(
329329
stride: int | Mapping[Any, int] = 1,
330330
fill_value: Any = dtypes.NA,
331331
keep_attrs: bool | None = None,
332-
automatic_rechunk: bool = True,
332+
sliding_window_kwargs: Mapping[Any, Any] | None = None,
333333
**window_dim_kwargs: Hashable,
334334
) -> DataArray:
335335
"""
@@ -348,10 +348,9 @@ def construct(
348348
If True, the attributes (``attrs``) will be copied from the original
349349
object to the new one. If False, the new object will be returned
350350
without attributes. If None uses the global default.
351-
automatic_rechunk: bool, default True
352-
Whether dask should automatically rechunk the output to avoid
353-
exploding chunk sizes. Importantly, each chunk will be a view of the data
354-
so large chunk sizes are only safe if *no* copies are made later.
351+
sliding_window_kwargs : Mapping
352+
Keyword arguments that should be passed to the underlying array type's
353+
``sliding_window_view`` function.
355354
**window_dim_kwargs : Hashable, optional
356355
The keyword arguments form of ``window_dim`` {dim: new_name, ...}.
357356
@@ -365,6 +364,15 @@ def construct(
365364
numpy.lib.stride_tricks.sliding_window_view
366365
dask.array.lib.stride_tricks.sliding_window_view
367366
367+
Notes
368+
-----
369+
With dask arrays, it's possible to pass the ``automatic_rechunk`` kwarg as
370+
``sliding_window_kwargs={"automatic_rechunk": True}``. This controls
371+
whether dask should automatically rechunk the output to avoid
372+
exploding chunk sizes. Automatically rechunking is the default behaviour.
373+
Importantly, each chunk will be a view of the data so large chunk sizes are
374+
only safe if *no* copies are made later.
375+
368376
Examples
369377
--------
370378
>>> da = xr.DataArray(np.arange(8).reshape(2, 4), dims=("a", "b"))
@@ -399,13 +407,15 @@ def construct(
399407
400408
"""
401409

410+
if sliding_window_kwargs is None:
411+
sliding_window_kwargs = {}
402412
return self._construct(
403413
self.obj,
404414
window_dim=window_dim,
405415
stride=stride,
406416
fill_value=fill_value,
407417
keep_attrs=keep_attrs,
408-
automatic_rechunk=automatic_rechunk,
418+
sliding_window_kwargs=sliding_window_kwargs,
409419
**window_dim_kwargs,
410420
)
411421

@@ -417,11 +427,14 @@ def _construct(
417427
stride: int | Mapping[Any, int] = 1,
418428
fill_value: Any = dtypes.NA,
419429
keep_attrs: bool | None = None,
420-
automatic_rechunk: bool = True,
430+
sliding_window_kwargs: Mapping[Any, Any] | None = None,
421431
**window_dim_kwargs: Hashable,
422432
) -> DataArray:
423433
from xarray.core.dataarray import DataArray
424434

435+
if sliding_window_kwargs is None:
436+
sliding_window_kwargs = {}
437+
425438
keep_attrs = self._get_keep_attrs(keep_attrs)
426439

427440
if window_dim is None:
@@ -442,7 +455,7 @@ def _construct(
442455
window_dims,
443456
center=self.center,
444457
fill_value=fill_value,
445-
automatic_rechunk=automatic_rechunk,
458+
**sliding_window_kwargs,
446459
)
447460

448461
attrs = obj.attrs if keep_attrs else {}
@@ -463,7 +476,7 @@ def reduce(
463476
func: Callable,
464477
keep_attrs: bool | None = None,
465478
*,
466-
automatic_rechunk: bool = True,
479+
sliding_window_kwargs: Mapping[Any, Any] | None = None,
467480
**kwargs: Any,
468481
) -> DataArray:
469482
"""Reduce each window by applying `func`.
@@ -480,10 +493,9 @@ def reduce(
480493
If True, the attributes (``attrs``) will be copied from the original
481494
object to the new one. If False, the new object will be returned
482495
without attributes. If None uses the global default.
483-
automatic_rechunk: bool, default True
484-
Whether dask should automatically rechunk the output of ``construct`` to avoid
485-
exploding chunk sizes. Importantly, each chunk will be a view of the data
486-
so large chunk sizes are only safe if *no* copies are made in ``func``.
496+
sliding_window_kwargs
497+
Keyword arguments that should be passed to the underlying array type's
498+
``sliding_window_view`` function.
487499
**kwargs : dict
488500
Additional keyword arguments passed on to `func`.
489501
@@ -492,6 +504,15 @@ def reduce(
492504
reduced : DataArray
493505
Array with summarized data.
494506
507+
Notes
508+
-----
509+
With dask arrays, it's possible to pass the ``automatic_rechunk`` kwarg as
510+
``sliding_window_kwargs={"automatic_rechunk": True}``. This controls
511+
whether dask should automatically rechunk the output to avoid
512+
exploding chunk sizes. Automatically rechunking is the default behaviour.
513+
Importantly, each chunk will be a view of the data so large chunk sizes are
514+
only safe if *no* copies are made later.
515+
495516
Examples
496517
--------
497518
>>> da = xr.DataArray(np.arange(8).reshape(2, 4), dims=("a", "b"))
@@ -541,7 +562,7 @@ def reduce(
541562
window_dim=rolling_dim,
542563
keep_attrs=keep_attrs,
543564
fill_value=fillna,
544-
automatic_rechunk=automatic_rechunk,
565+
sliding_window_kwargs=sliding_window_kwargs,
545566
)
546567

547568
dim = list(rolling_dim.values())
@@ -873,7 +894,7 @@ def construct(
873894
stride: int | Mapping[Any, int] = 1,
874895
fill_value: Any = dtypes.NA,
875896
keep_attrs: bool | None = None,
876-
automatic_rechunk: bool = True,
897+
sliding_window_kwargs: Mapping[Any, Any] | None = None,
877898
**window_dim_kwargs: Hashable,
878899
) -> Dataset:
879900
"""
@@ -889,10 +910,9 @@ def construct(
889910
size of stride for the rolling window.
890911
fill_value : Any, default: dtypes.NA
891912
Filling value to match the dimension size.
892-
automatic_rechunk: bool, default True
893-
Whether dask should automatically rechunk the output to avoid
894-
exploding chunk sizes. Importantly, each chunk will be a view of the data
895-
so large chunk sizes are only safe if *no* copies are made later.
913+
sliding_window_kwargs
914+
Keyword arguments that should be passed to the underlying array type's
915+
``sliding_window_view`` function.
896916
**window_dim_kwargs : {dim: new_name, ...}, optional
897917
The keyword arguments form of ``window_dim``.
898918
@@ -904,6 +924,15 @@ def construct(
904924
--------
905925
numpy.lib.stride_tricks.sliding_window_view
906926
dask.array.lib.stride_tricks.sliding_window_view
927+
928+
Notes
929+
-----
930+
With dask arrays, it's possible to pass the ``automatic_rechunk`` kwarg as
931+
``sliding_window_kwargs={"automatic_rechunk": True}``. This controls
932+
whether dask should automatically rechunk the output to avoid
933+
exploding chunk sizes. Automatically rechunking is the default behaviour.
934+
Importantly, each chunk will be a view of the data so large chunk sizes are
935+
only safe if *no* copies are made later.
907936
"""
908937

909938
from xarray.core.dataset import Dataset
@@ -935,6 +964,7 @@ def construct(
935964
fill_value=fill_value,
936965
stride=st,
937966
keep_attrs=keep_attrs,
967+
sliding_window_kwargs=sliding_window_kwargs,
938968
)
939969
else:
940970
dataset[key] = da.copy()

xarray/core/variable.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -2019,7 +2019,7 @@ def rolling_window(
20192019
*,
20202020
center=False,
20212021
fill_value=dtypes.NA,
2022-
automatic_rechunk: bool = True,
2022+
**kwargs,
20232023
):
20242024
"""
20252025
Make a rolling_window along dim and add a new_dim to the last place.
@@ -2040,10 +2040,9 @@ def rolling_window(
20402040
of the axis.
20412041
fill_value
20422042
value to be filled.
2043-
automatic_rechunk: bool, default True
2044-
Whether dask should automatically rechunk the output to avoid
2045-
exploding chunk sizes. Importantly, each chunk will be a view of the data
2046-
so large chunk sizes are only safe if *no* copies are made later.
2043+
**kwargs
2044+
Keyword arguments that should be passed to the underlying array type's
2045+
``sliding_window_view`` function.
20472046
20482047
Returns
20492048
-------
@@ -2052,6 +2051,11 @@ def rolling_window(
20522051
The return dim: self.dims + (window_dim, )
20532052
The return shape: self.shape + (window, )
20542053
2054+
See Also
2055+
--------
2056+
numpy.lib.stride_tricks.sliding_window_view
2057+
dask.array.lib.stride_tricks.sliding_window_view
2058+
20552059
Examples
20562060
--------
20572061
>>> v = Variable(("a", "b"), np.arange(8).reshape((2, 4)))
@@ -2132,10 +2136,7 @@ def rolling_window(
21322136
return Variable(
21332137
new_dims,
21342138
duck_array_ops.sliding_window_view(
2135-
padded.data,
2136-
window_shape=window,
2137-
axis=axis,
2138-
automatic_rechunk=automatic_rechunk,
2139+
padded.data, window_shape=window, axis=axis, **kwargs
21392140
),
21402141
)
21412142

xarray/tests/test_rolling.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -611,10 +611,10 @@ def test_rolling_construct_automatic_rechunk(self):
611611

612612
# Dataset now has chunks of size (400, 400, 100 100) or 11.92 GiB
613613
rechunked = da.rolling(time=100, center=True).construct(
614-
"window", automatic_rechunk=True
614+
"window", sliding_window_kwargs=dict(automatic_rechunk=True)
615615
)
616616
not_rechunked = da.rolling(time=100, center=True).construct(
617-
"window", automatic_rechunk=False
617+
"window", sliding_window_kwargs=dict(automatic_rechunk=False)
618618
)
619619
assert rechunked.chunks != not_rechunked.chunks
620620

0 commit comments

Comments
 (0)