20
20
module_available ,
21
21
)
22
22
from xarray .namedarray import pycompat
23
+ from xarray .util .deprecation_helpers import _deprecate_positional_args
23
24
24
25
try :
25
26
import bottleneck
@@ -147,7 +148,10 @@ def ndim(self) -> int:
147
148
return len (self .dim )
148
149
149
150
def _reduce_method ( # type: ignore[misc]
150
- name : str , fillna : Any , rolling_agg_func : Callable | None = None
151
+ name : str ,
152
+ fillna : Any ,
153
+ rolling_agg_func : Callable | None = None ,
154
+ automatic_rechunk : bool = False ,
151
155
) -> Callable [..., T_Xarray ]:
152
156
"""Constructs reduction methods built on a numpy reduction function (e.g. sum),
153
157
a numbagg reduction function (e.g. move_sum), a bottleneck reduction function
@@ -157,6 +161,8 @@ def _reduce_method( # type: ignore[misc]
157
161
_array_reduce. Arguably we could refactor this. But one constraint is that we
158
162
need context of xarray options, of the functions each library offers, of
159
163
the array (e.g. dtype).
164
+
165
+ Set automatic_rechunk=True when the reduction method makes a memory copy.
160
166
"""
161
167
if rolling_agg_func :
162
168
array_agg_func = None
@@ -181,6 +187,7 @@ def method(self, keep_attrs=None, **kwargs):
181
187
rolling_agg_func = rolling_agg_func ,
182
188
keep_attrs = keep_attrs ,
183
189
fillna = fillna ,
190
+ automatic_rechunk = automatic_rechunk ,
184
191
** kwargs ,
185
192
)
186
193
@@ -198,16 +205,19 @@ def _mean(self, keep_attrs, **kwargs):
198
205
199
206
_mean .__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE .format (name = "mean" )
200
207
201
- argmax = _reduce_method ("argmax" , dtypes .NINF )
202
- argmin = _reduce_method ("argmin" , dtypes .INF )
208
+ # automatic_rechunk is set to True for reductions that make a copy.
209
+ # std, var could be optimized after which we can set it to False
210
+ # See #4325
211
+ argmax = _reduce_method ("argmax" , dtypes .NINF , automatic_rechunk = True )
212
+ argmin = _reduce_method ("argmin" , dtypes .INF , automatic_rechunk = True )
203
213
max = _reduce_method ("max" , dtypes .NINF )
204
214
min = _reduce_method ("min" , dtypes .INF )
205
215
prod = _reduce_method ("prod" , 1 )
206
216
sum = _reduce_method ("sum" , 0 )
207
217
mean = _reduce_method ("mean" , None , _mean )
208
- std = _reduce_method ("std" , None )
209
- var = _reduce_method ("var" , None )
210
- median = _reduce_method ("median" , None )
218
+ std = _reduce_method ("std" , None , automatic_rechunk = True )
219
+ var = _reduce_method ("var" , None , automatic_rechunk = True )
220
+ median = _reduce_method ("median" , None , automatic_rechunk = True )
211
221
212
222
def _counts (self , keep_attrs : bool | None ) -> T_Xarray :
213
223
raise NotImplementedError ()
@@ -311,12 +321,15 @@ def __iter__(self) -> Iterator[tuple[DataArray, DataArray]]:
311
321
312
322
yield (label , window )
313
323
324
+ @_deprecate_positional_args ("v2024.11.0" )
314
325
def construct (
315
326
self ,
316
327
window_dim : Hashable | Mapping [Any , Hashable ] | None = None ,
328
+ * ,
317
329
stride : int | Mapping [Any , int ] = 1 ,
318
330
fill_value : Any = dtypes .NA ,
319
331
keep_attrs : bool | None = None ,
332
+ automatic_rechunk : bool = True ,
320
333
** window_dim_kwargs : Hashable ,
321
334
) -> DataArray :
322
335
"""
@@ -335,6 +348,10 @@ def construct(
335
348
If True, the attributes (``attrs``) will be copied from the original
336
349
object to the new one. If False, the new object will be returned
337
350
without attributes. If None uses the global default.
351
+ automatic_rechunk: bool, default True
352
+ Whether dask should automatically rechunk the output to avoid
353
+ exploding chunk sizes. Importantly, each chunk will be a view of the data
354
+ so large chunk sizes are only safe if *no* copies are made later.
338
355
**window_dim_kwargs : Hashable, optional
339
356
The keyword arguments form of ``window_dim`` {dim: new_name, ...}.
340
357
@@ -343,6 +360,11 @@ def construct(
343
360
DataArray that is a view of the original array. The returned array is
344
361
not writeable.
345
362
363
+ See Also
364
+ --------
365
+ numpy.lib.stride_tricks.sliding_window_view
366
+ dask.array.lib.stride_tricks.sliding_window_view
367
+
346
368
Examples
347
369
--------
348
370
>>> da = xr.DataArray(np.arange(8).reshape(2, 4), dims=("a", "b"))
@@ -383,16 +405,19 @@ def construct(
383
405
stride = stride ,
384
406
fill_value = fill_value ,
385
407
keep_attrs = keep_attrs ,
408
+ automatic_rechunk = automatic_rechunk ,
386
409
** window_dim_kwargs ,
387
410
)
388
411
389
412
def _construct (
390
413
self ,
391
414
obj : DataArray ,
415
+ * ,
392
416
window_dim : Hashable | Mapping [Any , Hashable ] | None = None ,
393
417
stride : int | Mapping [Any , int ] = 1 ,
394
418
fill_value : Any = dtypes .NA ,
395
419
keep_attrs : bool | None = None ,
420
+ automatic_rechunk : bool = True ,
396
421
** window_dim_kwargs : Hashable ,
397
422
) -> DataArray :
398
423
from xarray .core .dataarray import DataArray
@@ -412,7 +437,12 @@ def _construct(
412
437
strides = self ._mapping_to_list (stride , default = 1 )
413
438
414
439
window = obj .variable .rolling_window (
415
- self .dim , self .window , window_dims , self .center , fill_value = fill_value
440
+ self .dim ,
441
+ self .window ,
442
+ window_dims ,
443
+ center = self .center ,
444
+ fill_value = fill_value ,
445
+ automatic_rechunk = automatic_rechunk ,
416
446
)
417
447
418
448
attrs = obj .attrs if keep_attrs else {}
@@ -429,10 +459,16 @@ def _construct(
429
459
)
430
460
431
461
def reduce (
432
- self , func : Callable , keep_attrs : bool | None = None , ** kwargs : Any
462
+ self ,
463
+ func : Callable ,
464
+ keep_attrs : bool | None = None ,
465
+ * ,
466
+ automatic_rechunk : bool = True ,
467
+ ** kwargs : Any ,
433
468
) -> DataArray :
434
- """Reduce the items in this group by applying `func` along some
435
- dimension(s).
469
+ """Reduce each window by applying `func`.
470
+
471
+ Equivalent to ``.construct(...).reduce(func, ...)``.
436
472
437
473
Parameters
438
474
----------
@@ -444,6 +480,10 @@ def reduce(
444
480
If True, the attributes (``attrs``) will be copied from the original
445
481
object to the new one. If False, the new object will be returned
446
482
without attributes. If None uses the global default.
483
+ automatic_rechunk: bool, default True
484
+ Whether dask should automatically rechunk the output of ``construct`` to avoid
485
+ exploding chunk sizes. Importantly, each chunk will be a view of the data
486
+ so large chunk sizes are only safe if *no* copies are made in ``func``.
447
487
**kwargs : dict
448
488
Additional keyword arguments passed on to `func`.
449
489
@@ -497,7 +537,11 @@ def reduce(
497
537
else :
498
538
obj = self .obj
499
539
windows = self ._construct (
500
- obj , rolling_dim , keep_attrs = keep_attrs , fill_value = fillna
540
+ obj ,
541
+ window_dim = rolling_dim ,
542
+ keep_attrs = keep_attrs ,
543
+ fill_value = fillna ,
544
+ automatic_rechunk = automatic_rechunk ,
501
545
)
502
546
503
547
dim = list (rolling_dim .values ())
@@ -821,12 +865,15 @@ def _array_reduce(
821
865
** kwargs ,
822
866
)
823
867
868
+ @_deprecate_positional_args ("v2024.11.0" )
824
869
def construct (
825
870
self ,
826
871
window_dim : Hashable | Mapping [Any , Hashable ] | None = None ,
872
+ * ,
827
873
stride : int | Mapping [Any , int ] = 1 ,
828
874
fill_value : Any = dtypes .NA ,
829
875
keep_attrs : bool | None = None ,
876
+ automatic_rechunk : bool = True ,
830
877
** window_dim_kwargs : Hashable ,
831
878
) -> Dataset :
832
879
"""
@@ -842,12 +889,21 @@ def construct(
842
889
size of stride for the rolling window.
843
890
fill_value : Any, default: dtypes.NA
844
891
Filling value to match the dimension size.
892
+ automatic_rechunk: bool, default True
893
+ Whether dask should automatically rechunk the output to avoid
894
+ exploding chunk sizes. Importantly, each chunk will be a view of the data
895
+ so large chunk sizes are only safe if *no* copies are made later.
845
896
**window_dim_kwargs : {dim: new_name, ...}, optional
846
897
The keyword arguments form of ``window_dim``.
847
898
848
899
Returns
849
900
-------
850
901
Dataset with variables converted from rolling object.
902
+
903
+ See Also
904
+ --------
905
+ numpy.lib.stride_tricks.sliding_window_view
906
+ dask.array.lib.stride_tricks.sliding_window_view
851
907
"""
852
908
853
909
from xarray .core .dataset import Dataset
0 commit comments