20
20
module_available ,
21
21
)
22
22
from xarray .namedarray import pycompat
23
+ from xarray .util .deprecation_helpers import _deprecate_positional_args
23
24
24
25
try :
25
26
import bottleneck
@@ -147,7 +148,10 @@ def ndim(self) -> int:
147
148
return len (self .dim )
148
149
149
150
def _reduce_method ( # type: ignore[misc]
150
- name : str , fillna : Any , rolling_agg_func : Callable | None = None
151
+ name : str ,
152
+ fillna : Any ,
153
+ rolling_agg_func : Callable | None = None ,
154
+ automatic_rechunk : bool = False ,
151
155
) -> Callable [..., T_Xarray ]:
152
156
"""Constructs reduction methods built on a numpy reduction function (e.g. sum),
153
157
a numbagg reduction function (e.g. move_sum), a bottleneck reduction function
@@ -157,6 +161,8 @@ def _reduce_method( # type: ignore[misc]
157
161
_array_reduce. Arguably we could refactor this. But one constraint is that we
158
162
need context of xarray options, of the functions each library offers, of
159
163
the array (e.g. dtype).
164
+
165
+ Set automatic_rechunk=True when the reduction method makes a memory copy.
160
166
"""
161
167
if rolling_agg_func :
162
168
array_agg_func = None
@@ -181,6 +187,7 @@ def method(self, keep_attrs=None, **kwargs):
181
187
rolling_agg_func = rolling_agg_func ,
182
188
keep_attrs = keep_attrs ,
183
189
fillna = fillna ,
190
+ automatic_rechunk = automatic_rechunk ,
184
191
** kwargs ,
185
192
)
186
193
@@ -198,16 +205,19 @@ def _mean(self, keep_attrs, **kwargs):
198
205
199
206
_mean .__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE .format (name = "mean" )
200
207
201
- argmax = _reduce_method ("argmax" , dtypes .NINF )
202
- argmin = _reduce_method ("argmin" , dtypes .INF )
208
+ # automatic_rechunk is set to True for reductions that make a copy.
209
+ # std and var could be optimized, after which we can set it to False.
210
+ # See #4325
211
+ argmax = _reduce_method ("argmax" , dtypes .NINF , automatic_rechunk = True )
212
+ argmin = _reduce_method ("argmin" , dtypes .INF , automatic_rechunk = True )
203
213
max = _reduce_method ("max" , dtypes .NINF )
204
214
min = _reduce_method ("min" , dtypes .INF )
205
215
prod = _reduce_method ("prod" , 1 )
206
216
sum = _reduce_method ("sum" , 0 )
207
217
mean = _reduce_method ("mean" , None , _mean )
208
- std = _reduce_method ("std" , None )
209
- var = _reduce_method ("var" , None )
210
- median = _reduce_method ("median" , None )
218
+ std = _reduce_method ("std" , None , automatic_rechunk = True )
219
+ var = _reduce_method ("var" , None , automatic_rechunk = True )
220
+ median = _reduce_method ("median" , None , automatic_rechunk = True )
211
221
212
222
def _counts (self , keep_attrs : bool | None ) -> T_Xarray :
213
223
raise NotImplementedError ()
@@ -311,12 +321,15 @@ def __iter__(self) -> Iterator[tuple[DataArray, DataArray]]:
311
321
312
322
yield (label , window )
313
323
324
+ @_deprecate_positional_args ("v2024.11.0" )
314
325
def construct (
315
326
self ,
316
327
window_dim : Hashable | Mapping [Any , Hashable ] | None = None ,
328
+ * ,
317
329
stride : int | Mapping [Any , int ] = 1 ,
318
330
fill_value : Any = dtypes .NA ,
319
331
keep_attrs : bool | None = None ,
332
+ automatic_rechunk : bool = True ,
320
333
** window_dim_kwargs : Hashable ,
321
334
) -> DataArray :
322
335
"""
@@ -335,6 +348,10 @@ def construct(
335
348
If True, the attributes (``attrs``) will be copied from the original
336
349
object to the new one. If False, the new object will be returned
337
350
without attributes. If None uses the global default.
351
+ automatic_rechunk : bool, default: True
352
+ Whether dask should automatically rechunk the output to avoid
353
+ exploding chunk sizes. Importantly, each chunk will be a view of the data
354
+ so large chunk sizes are only safe if *no* copies are made later.
338
355
**window_dim_kwargs : Hashable, optional
339
356
The keyword arguments form of ``window_dim`` {dim: new_name, ...}.
340
357
@@ -383,16 +400,19 @@ def construct(
383
400
stride = stride ,
384
401
fill_value = fill_value ,
385
402
keep_attrs = keep_attrs ,
403
+ automatic_rechunk = automatic_rechunk ,
386
404
** window_dim_kwargs ,
387
405
)
388
406
389
407
def _construct (
390
408
self ,
391
409
obj : DataArray ,
410
+ * ,
392
411
window_dim : Hashable | Mapping [Any , Hashable ] | None = None ,
393
412
stride : int | Mapping [Any , int ] = 1 ,
394
413
fill_value : Any = dtypes .NA ,
395
414
keep_attrs : bool | None = None ,
415
+ automatic_rechunk : bool = True ,
396
416
** window_dim_kwargs : Hashable ,
397
417
) -> DataArray :
398
418
from xarray .core .dataarray import DataArray
@@ -412,7 +432,12 @@ def _construct(
412
432
strides = self ._mapping_to_list (stride , default = 1 )
413
433
414
434
window = obj .variable .rolling_window (
415
- self .dim , self .window , window_dims , self .center , fill_value = fill_value
435
+ self .dim ,
436
+ self .window ,
437
+ window_dims ,
438
+ center = self .center ,
439
+ fill_value = fill_value ,
440
+ automatic_rechunk = automatic_rechunk ,
416
441
)
417
442
418
443
attrs = obj .attrs if keep_attrs else {}
@@ -429,10 +454,16 @@ def _construct(
429
454
)
430
455
431
456
def reduce (
432
- self , func : Callable , keep_attrs : bool | None = None , ** kwargs : Any
457
+ self ,
458
+ func : Callable ,
459
+ keep_attrs : bool | None = None ,
460
+ * ,
461
+ automatic_rechunk : bool = True ,
462
+ ** kwargs : Any ,
433
463
) -> DataArray :
434
- """Reduce the items in this group by applying `func` along some
435
- dimension(s).
464
+ """Reduce each window by applying `func`.
465
+
466
+ Equivalent to ``.construct(...).reduce(func, ...)``.
436
467
437
468
Parameters
438
469
----------
@@ -444,6 +475,10 @@ def reduce(
444
475
If True, the attributes (``attrs``) will be copied from the original
445
476
object to the new one. If False, the new object will be returned
446
477
without attributes. If None uses the global default.
478
+ automatic_rechunk : bool, default: True
479
+ Whether dask should automatically rechunk the output of ``construct`` to avoid
480
+ exploding chunk sizes. Importantly, each chunk will be a view of the data
481
+ so large chunk sizes are only safe if *no* copies are made in ``func``.
447
482
**kwargs : dict
448
483
Additional keyword arguments passed on to `func`.
449
484
@@ -497,7 +532,11 @@ def reduce(
497
532
else :
498
533
obj = self .obj
499
534
windows = self ._construct (
500
- obj , rolling_dim , keep_attrs = keep_attrs , fill_value = fillna
535
+ obj ,
536
+ window_dim = rolling_dim ,
537
+ keep_attrs = keep_attrs ,
538
+ fill_value = fillna ,
539
+ automatic_rechunk = automatic_rechunk ,
501
540
)
502
541
503
542
dim = list (rolling_dim .values ())
@@ -821,12 +860,15 @@ def _array_reduce(
821
860
** kwargs ,
822
861
)
823
862
863
+ @_deprecate_positional_args ("v2024.11.0" )
824
864
def construct (
825
865
self ,
826
866
window_dim : Hashable | Mapping [Any , Hashable ] | None = None ,
867
+ * ,
827
868
stride : int | Mapping [Any , int ] = 1 ,
828
869
fill_value : Any = dtypes .NA ,
829
870
keep_attrs : bool | None = None ,
871
+ automatic_rechunk : bool = True ,
830
872
** window_dim_kwargs : Hashable ,
831
873
) -> Dataset :
832
874
"""
@@ -842,6 +884,10 @@ def construct(
842
884
size of stride for the rolling window.
843
885
fill_value : Any, default: dtypes.NA
844
886
Filling value to match the dimension size.
887
+ automatic_rechunk : bool, default: True
888
+ Whether dask should automatically rechunk the output to avoid
889
+ exploding chunk sizes. Importantly, each chunk will be a view of the data
890
+ so large chunk sizes are only safe if *no* copies are made later.
845
891
**window_dim_kwargs : {dim: new_name, ...}, optional
846
892
The keyword arguments form of ``window_dim``.
847
893
0 commit comments