7
7
from . import dtypes , duck_array_ops , utils
8
8
from .dask_array_ops import dask_rolling_wrapper
9
9
from .ops import inject_reduce_methods
10
+ from .options import _get_keep_attrs
10
11
from .pycompat import dask_array_type
11
12
12
13
try :
@@ -42,10 +43,10 @@ class Rolling:
42
43
DataArray.rolling
43
44
"""
44
45
45
- __slots__ = ("obj" , "window" , "min_periods" , "center" , "dim" )
46
- _attributes = ("window" , "min_periods" , "center" , "dim" )
46
+ __slots__ = ("obj" , "window" , "min_periods" , "center" , "dim" , "keep_attrs" )
47
+ _attributes = ("window" , "min_periods" , "center" , "dim" , "keep_attrs" )
47
48
48
- def __init__ (self , obj , windows , min_periods = None , center = False ):
49
+ def __init__ (self , obj , windows , min_periods = None , center = False , keep_attrs = None ):
49
50
"""
50
51
Moving window object.
51
52
@@ -65,6 +66,10 @@ def __init__(self, obj, windows, min_periods=None, center=False):
65
66
setting min_periods equal to the size of the window.
66
67
center : boolean, default False
67
68
Set the labels at the center of the window.
69
+ keep_attrs : bool, optional
70
+ If True, the object's attributes (`attrs`) will be copied from
71
+ the original object to the new one. If False (default), the new
72
+ object will be returned without attributes.
68
73
69
74
Returns
70
75
-------
@@ -89,6 +94,10 @@ def __init__(self, obj, windows, min_periods=None, center=False):
89
94
self .center = center
90
95
self .dim = dim
91
96
97
+ if keep_attrs is None :
98
+ keep_attrs = _get_keep_attrs (default = False )
99
+ self .keep_attrs = keep_attrs
100
+
92
101
@property
93
102
def _min_periods (self ):
94
103
return self .min_periods if self .min_periods is not None else self .window
@@ -143,7 +152,7 @@ def count(self):
143
152
class DataArrayRolling (Rolling ):
144
153
__slots__ = ("window_labels" ,)
145
154
146
- def __init__ (self , obj , windows , min_periods = None , center = False ):
155
+ def __init__ (self , obj , windows , min_periods = None , center = False , keep_attrs = None ):
147
156
"""
148
157
Moving window object for DataArray.
149
158
You should use DataArray.rolling() method to construct this object
@@ -165,6 +174,10 @@ def __init__(self, obj, windows, min_periods=None, center=False):
165
174
setting min_periods equal to the size of the window.
166
175
center : boolean, default False
167
176
Set the labels at the center of the window.
177
+ keep_attrs : bool, optional
178
+ If True, the object's attributes (`attrs`) will be copied from
179
+ the original object to the new one. If False (default), the new
180
+ object will be returned without attributes.
168
181
169
182
Returns
170
183
-------
@@ -177,7 +190,11 @@ def __init__(self, obj, windows, min_periods=None, center=False):
177
190
Dataset.rolling
178
191
Dataset.groupby
179
192
"""
180
- super ().__init__ (obj , windows , min_periods = min_periods , center = center )
193
+ if keep_attrs is None :
194
+ keep_attrs = _get_keep_attrs (default = False )
195
+ super ().__init__ (
196
+ obj , windows , min_periods = min_periods , center = center , keep_attrs = keep_attrs
197
+ )
181
198
182
199
self .window_labels = self .obj [self .dim ]
183
200
@@ -374,7 +391,7 @@ def _numpy_or_bottleneck_reduce(
374
391
class DatasetRolling (Rolling ):
375
392
__slots__ = ("rollings" ,)
376
393
377
- def __init__ (self , obj , windows , min_periods = None , center = False ):
394
+ def __init__ (self , obj , windows , min_periods = None , center = False , keep_attrs = None ):
378
395
"""
379
396
Moving window object for Dataset.
380
397
You should use Dataset.rolling() method to construct this object
@@ -396,6 +413,10 @@ def __init__(self, obj, windows, min_periods=None, center=False):
396
413
setting min_periods equal to the size of the window.
397
414
center : boolean, default False
398
415
Set the labels at the center of the window.
416
+ keep_attrs : bool, optional
417
+ If True, the object's attributes (`attrs`) will be copied from
418
+ the original object to the new one. If False (default), the new
419
+ object will be returned without attributes.
399
420
400
421
Returns
401
422
-------
@@ -408,15 +429,17 @@ def __init__(self, obj, windows, min_periods=None, center=False):
408
429
Dataset.groupby
409
430
DataArray.groupby
410
431
"""
411
- super ().__init__ (obj , windows , min_periods , center )
432
+ super ().__init__ (obj , windows , min_periods , center , keep_attrs )
412
433
if self .dim not in self .obj .dims :
413
434
raise KeyError (self .dim )
414
435
# Keep each Rolling object as a dictionary
415
436
self .rollings = {}
416
437
for key , da in self .obj .data_vars .items ():
417
438
# keeps rollings only for the dataset depending on slf.dim
418
439
if self .dim in da .dims :
419
- self .rollings [key ] = DataArrayRolling (da , windows , min_periods , center )
440
+ self .rollings [key ] = DataArrayRolling (
441
+ da , windows , min_periods , center , keep_attrs
442
+ )
420
443
421
444
def _dataset_implementation (self , func , ** kwargs ):
422
445
from .dataset import Dataset
@@ -427,7 +450,8 @@ def _dataset_implementation(self, func, **kwargs):
427
450
reduced [key ] = func (self .rollings [key ], ** kwargs )
428
451
else :
429
452
reduced [key ] = self .obj [key ]
430
- return Dataset (reduced , coords = self .obj .coords )
453
+ attrs = self .obj .attrs if self .keep_attrs else {}
454
+ return Dataset (reduced , coords = self .obj .coords , attrs = attrs )
431
455
432
456
def reduce (self , func , ** kwargs ):
433
457
"""Reduce the items in this group by applying `func` along some
@@ -466,7 +490,7 @@ def _numpy_or_bottleneck_reduce(
466
490
** kwargs ,
467
491
)
468
492
469
- def construct (self , window_dim , stride = 1 , fill_value = dtypes .NA ):
493
+ def construct (self , window_dim , stride = 1 , fill_value = dtypes .NA , keep_attrs = None ):
470
494
"""
471
495
Convert this rolling object to xr.Dataset,
472
496
where the window dimension is stacked as a new dimension
@@ -487,6 +511,9 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA):
487
511
488
512
from .dataset import Dataset
489
513
514
+ if keep_attrs is None :
515
+ keep_attrs = _get_keep_attrs (default = True )
516
+
490
517
dataset = {}
491
518
for key , da in self .obj .data_vars .items ():
492
519
if self .dim in da .dims :
@@ -509,10 +536,18 @@ class Coarsen:
509
536
DataArray.coarsen
510
537
"""
511
538
512
- __slots__ = ("obj" , "boundary" , "coord_func" , "windows" , "side" , "trim_excess" )
539
+ __slots__ = (
540
+ "obj" ,
541
+ "boundary" ,
542
+ "coord_func" ,
543
+ "windows" ,
544
+ "side" ,
545
+ "trim_excess" ,
546
+ "keep_attrs" ,
547
+ )
513
548
_attributes = ("windows" , "side" , "trim_excess" )
514
549
515
- def __init__ (self , obj , windows , boundary , side , coord_func ):
550
+ def __init__ (self , obj , windows , boundary , side , coord_func , keep_attrs ):
516
551
"""
517
552
Moving window object.
518
553
@@ -541,6 +576,7 @@ def __init__(self, obj, windows, boundary, side, coord_func):
541
576
self .windows = windows
542
577
self .side = side
543
578
self .boundary = boundary
579
+ self .keep_attrs = keep_attrs
544
580
545
581
absent_dims = [dim for dim in windows .keys () if dim not in self .obj .dims ]
546
582
if absent_dims :
@@ -626,6 +662,11 @@ def _reduce_method(cls, func: Callable, include_skipna: bool, numeric_only: bool
626
662
def wrapped_func (self , ** kwargs ):
627
663
from .dataset import Dataset
628
664
665
+ if self .keep_attrs :
666
+ attrs = self .obj .attrs
667
+ else :
668
+ attrs = {}
669
+
629
670
reduced = {}
630
671
for key , da in self .obj .data_vars .items ():
631
672
reduced [key ] = da .variable .coarsen (
@@ -644,7 +685,7 @@ def wrapped_func(self, **kwargs):
644
685
)
645
686
else :
646
687
coords [c ] = v .variable
647
- return Dataset (reduced , coords = coords )
688
+ return Dataset (reduced , coords = coords , attrs = attrs )
648
689
649
690
return wrapped_func
650
691
0 commit comments