Skip to content

Commit 1c5e1cd

Browse files
authored
Coarsen keep attrs 3376 (pydata#3801)
* Add test of DataWithCoords.coarsen() for pydata#3376 * Add test of Variable.coarsen() for pydata#3376 * Add keep_attrs kwarg to DataWithCoords.coarsen() for pydata#3376 * Style and spelling fixes (pydata#3376) * Fix test_coarsen_keep_attrs by removing self from input * Pass keep_attrs through to _coarsen_cls and _rolling_cls returns (pydata#3376) * Move keyword from coarsen to mean in test_coarsen_keep_attrs * Start handling keep_attrs in rolling class constructors (pydata#3376) * Update Coarsen constructor and DatasetCoarsen class method (GH3376) Assign keep_attrs keyword value to Coarsen objects in constructor Add conditional inside _reduce_method.wrapped_func branching on self.keep_attrs and pass back to returned Dataset * Incorporate code review from @max-sixty * Fix Dataset.coarsen and Variable.coarsen for GH3376 Handle global keep_attrs setting inside Variable._coarsen_reshape Pass attrs through consistently inside DatasetCoarsen._reduce_method Don't pass Variable.coarsen a keyword argument it doesn't expect inside DataArrayCoarsen._reduce_method * Update tests for GH3376 * Incorporate review changes to test_dataset for GH3376 Remove commented-out test from test_coarsen_keep_attrs Add test_rolling_keep_attrs * Change Rolling._dataset_implementation for GH3376 Return a Dataset object that results in test_rolling_keep_attrs Passing * style fixes * Remove duplicate variable assignment and document change (GH3776)
1 parent b155853 commit 1c5e1cd

File tree

6 files changed

+165
-17
lines changed

6 files changed

+165
-17
lines changed

doc/whats-new.rst

+5
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ Bug fixes
6161
- xarray now respects the over, under and bad colors if set on a provided colormap.
6262
(:issue:`3590`, :pull:`3601`)
6363
By `johnomotani <https://github.com/johnomotani>`_.
64+
- :py:func:`coarsen` now respects ``xr.set_options(keep_attrs=True)``
65+
to preserve attributes. :py:meth:`Dataset.coarsen` accepts a keyword
66+
argument ``keep_attrs`` to change this setting. (:issue:`3376`,
67+
:pull:`3801`) By `Andrew Thomas <https://github.com/amcnicho>`_.
68+
6469
- Fix :py:meth:`xarray.core.dataset.Dataset.to_zarr` when using `append_dim` and `group`
6570
simultaneously. (:issue:`3170`). By `Matthias Meyer <https://github.com/niowniow>`_.
6671

xarray/core/common.py

+26-3
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,7 @@ def rolling(
753753
dim: Mapping[Hashable, int] = None,
754754
min_periods: int = None,
755755
center: bool = False,
756+
keep_attrs: bool = None,
756757
**window_kwargs: int,
757758
):
758759
"""
@@ -769,6 +770,10 @@ def rolling(
769770
setting min_periods equal to the size of the window.
770771
center : boolean, default False
771772
Set the labels at the center of the window.
773+
keep_attrs : bool, optional
774+
If True, the object's attributes (`attrs`) will be copied from
775+
the original object to the new one. If False (default), the new
776+
object will be returned without attributes.
772777
**window_kwargs : optional
773778
The keyword arguments form of ``dim``.
774779
One of dim or window_kwargs must be provided.
@@ -810,8 +815,13 @@ def rolling(
810815
core.rolling.DataArrayRolling
811816
core.rolling.DatasetRolling
812817
"""
818+
if keep_attrs is None:
819+
keep_attrs = _get_keep_attrs(default=False)
820+
813821
dim = either_dict_or_kwargs(dim, window_kwargs, "rolling")
814-
return self._rolling_cls(self, dim, min_periods=min_periods, center=center)
822+
return self._rolling_cls(
823+
self, dim, min_periods=min_periods, center=center, keep_attrs=keep_attrs
824+
)
815825

816826
def rolling_exp(
817827
self,
@@ -859,6 +869,7 @@ def coarsen(
859869
boundary: str = "exact",
860870
side: Union[str, Mapping[Hashable, str]] = "left",
861871
coord_func: str = "mean",
872+
keep_attrs: bool = None,
862873
**window_kwargs: int,
863874
):
864875
"""
@@ -879,8 +890,12 @@ def coarsen(
879890
multiple of the window size. If 'trim', the excess entries are
880891
dropped. If 'pad', NA will be padded.
881892
side : 'left' or 'right' or mapping from dimension to 'left' or 'right'
882-
coord_func : function (name) that is applied to the coordintes,
893+
coord_func : function (name) that is applied to the coordinates,
883894
or a mapping from coordinate name to function (name).
895+
keep_attrs : bool, optional
896+
If True, the object's attributes (`attrs`) will be copied from
897+
the original object to the new one. If False (default), the new
898+
object will be returned without attributes.
884899
885900
Returns
886901
-------
@@ -915,9 +930,17 @@ def coarsen(
915930
core.rolling.DataArrayCoarsen
916931
core.rolling.DatasetCoarsen
917932
"""
933+
if keep_attrs is None:
934+
keep_attrs = _get_keep_attrs(default=False)
935+
918936
dim = either_dict_or_kwargs(dim, window_kwargs, "coarsen")
919937
return self._coarsen_cls(
920-
self, dim, boundary=boundary, side=side, coord_func=coord_func
938+
self,
939+
dim,
940+
boundary=boundary,
941+
side=side,
942+
coord_func=coord_func,
943+
keep_attrs=keep_attrs,
921944
)
922945

923946
def resample(

xarray/core/rolling.py

+54-13
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from . import dtypes, duck_array_ops, utils
88
from .dask_array_ops import dask_rolling_wrapper
99
from .ops import inject_reduce_methods
10+
from .options import _get_keep_attrs
1011
from .pycompat import dask_array_type
1112

1213
try:
@@ -42,10 +43,10 @@ class Rolling:
4243
DataArray.rolling
4344
"""
4445

45-
__slots__ = ("obj", "window", "min_periods", "center", "dim")
46-
_attributes = ("window", "min_periods", "center", "dim")
46+
__slots__ = ("obj", "window", "min_periods", "center", "dim", "keep_attrs")
47+
_attributes = ("window", "min_periods", "center", "dim", "keep_attrs")
4748

48-
def __init__(self, obj, windows, min_periods=None, center=False):
49+
def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None):
4950
"""
5051
Moving window object.
5152
@@ -65,6 +66,10 @@ def __init__(self, obj, windows, min_periods=None, center=False):
6566
setting min_periods equal to the size of the window.
6667
center : boolean, default False
6768
Set the labels at the center of the window.
69+
keep_attrs : bool, optional
70+
If True, the object's attributes (`attrs`) will be copied from
71+
the original object to the new one. If False (default), the new
72+
object will be returned without attributes.
6873
6974
Returns
7075
-------
@@ -89,6 +94,10 @@ def __init__(self, obj, windows, min_periods=None, center=False):
8994
self.center = center
9095
self.dim = dim
9196

97+
if keep_attrs is None:
98+
keep_attrs = _get_keep_attrs(default=False)
99+
self.keep_attrs = keep_attrs
100+
92101
@property
93102
def _min_periods(self):
94103
return self.min_periods if self.min_periods is not None else self.window
@@ -143,7 +152,7 @@ def count(self):
143152
class DataArrayRolling(Rolling):
144153
__slots__ = ("window_labels",)
145154

146-
def __init__(self, obj, windows, min_periods=None, center=False):
155+
def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None):
147156
"""
148157
Moving window object for DataArray.
149158
You should use DataArray.rolling() method to construct this object
@@ -165,6 +174,10 @@ def __init__(self, obj, windows, min_periods=None, center=False):
165174
setting min_periods equal to the size of the window.
166175
center : boolean, default False
167176
Set the labels at the center of the window.
177+
keep_attrs : bool, optional
178+
If True, the object's attributes (`attrs`) will be copied from
179+
the original object to the new one. If False (default), the new
180+
object will be returned without attributes.
168181
169182
Returns
170183
-------
@@ -177,7 +190,11 @@ def __init__(self, obj, windows, min_periods=None, center=False):
177190
Dataset.rolling
178191
Dataset.groupby
179192
"""
180-
super().__init__(obj, windows, min_periods=min_periods, center=center)
193+
if keep_attrs is None:
194+
keep_attrs = _get_keep_attrs(default=False)
195+
super().__init__(
196+
obj, windows, min_periods=min_periods, center=center, keep_attrs=keep_attrs
197+
)
181198

182199
self.window_labels = self.obj[self.dim]
183200

@@ -374,7 +391,7 @@ def _numpy_or_bottleneck_reduce(
374391
class DatasetRolling(Rolling):
375392
__slots__ = ("rollings",)
376393

377-
def __init__(self, obj, windows, min_periods=None, center=False):
394+
def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None):
378395
"""
379396
Moving window object for Dataset.
380397
You should use Dataset.rolling() method to construct this object
@@ -396,6 +413,10 @@ def __init__(self, obj, windows, min_periods=None, center=False):
396413
setting min_periods equal to the size of the window.
397414
center : boolean, default False
398415
Set the labels at the center of the window.
416+
keep_attrs : bool, optional
417+
If True, the object's attributes (`attrs`) will be copied from
418+
the original object to the new one. If False (default), the new
419+
object will be returned without attributes.
399420
400421
Returns
401422
-------
@@ -408,15 +429,17 @@ def __init__(self, obj, windows, min_periods=None, center=False):
408429
Dataset.groupby
409430
DataArray.groupby
410431
"""
411-
super().__init__(obj, windows, min_periods, center)
432+
super().__init__(obj, windows, min_periods, center, keep_attrs)
412433
if self.dim not in self.obj.dims:
413434
raise KeyError(self.dim)
414435
# Keep each Rolling object as a dictionary
415436
self.rollings = {}
416437
for key, da in self.obj.data_vars.items():
417438
# keeps rollings only for the dataset depending on slf.dim
418439
if self.dim in da.dims:
419-
self.rollings[key] = DataArrayRolling(da, windows, min_periods, center)
440+
self.rollings[key] = DataArrayRolling(
441+
da, windows, min_periods, center, keep_attrs
442+
)
420443

421444
def _dataset_implementation(self, func, **kwargs):
422445
from .dataset import Dataset
@@ -427,7 +450,8 @@ def _dataset_implementation(self, func, **kwargs):
427450
reduced[key] = func(self.rollings[key], **kwargs)
428451
else:
429452
reduced[key] = self.obj[key]
430-
return Dataset(reduced, coords=self.obj.coords)
453+
attrs = self.obj.attrs if self.keep_attrs else {}
454+
return Dataset(reduced, coords=self.obj.coords, attrs=attrs)
431455

432456
def reduce(self, func, **kwargs):
433457
"""Reduce the items in this group by applying `func` along some
@@ -466,7 +490,7 @@ def _numpy_or_bottleneck_reduce(
466490
**kwargs,
467491
)
468492

469-
def construct(self, window_dim, stride=1, fill_value=dtypes.NA):
493+
def construct(self, window_dim, stride=1, fill_value=dtypes.NA, keep_attrs=None):
470494
"""
471495
Convert this rolling object to xr.Dataset,
472496
where the window dimension is stacked as a new dimension
@@ -487,6 +511,9 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA):
487511

488512
from .dataset import Dataset
489513

514+
if keep_attrs is None:
515+
keep_attrs = _get_keep_attrs(default=True)
516+
490517
dataset = {}
491518
for key, da in self.obj.data_vars.items():
492519
if self.dim in da.dims:
@@ -509,10 +536,18 @@ class Coarsen:
509536
DataArray.coarsen
510537
"""
511538

512-
__slots__ = ("obj", "boundary", "coord_func", "windows", "side", "trim_excess")
539+
__slots__ = (
540+
"obj",
541+
"boundary",
542+
"coord_func",
543+
"windows",
544+
"side",
545+
"trim_excess",
546+
"keep_attrs",
547+
)
513548
_attributes = ("windows", "side", "trim_excess")
514549

515-
def __init__(self, obj, windows, boundary, side, coord_func):
550+
def __init__(self, obj, windows, boundary, side, coord_func, keep_attrs):
516551
"""
517552
Moving window object.
518553
@@ -541,6 +576,7 @@ def __init__(self, obj, windows, boundary, side, coord_func):
541576
self.windows = windows
542577
self.side = side
543578
self.boundary = boundary
579+
self.keep_attrs = keep_attrs
544580

545581
absent_dims = [dim for dim in windows.keys() if dim not in self.obj.dims]
546582
if absent_dims:
@@ -626,6 +662,11 @@ def _reduce_method(cls, func: Callable, include_skipna: bool, numeric_only: bool
626662
def wrapped_func(self, **kwargs):
627663
from .dataset import Dataset
628664

665+
if self.keep_attrs:
666+
attrs = self.obj.attrs
667+
else:
668+
attrs = {}
669+
629670
reduced = {}
630671
for key, da in self.obj.data_vars.items():
631672
reduced[key] = da.variable.coarsen(
@@ -644,7 +685,7 @@ def wrapped_func(self, **kwargs):
644685
)
645686
else:
646687
coords[c] = v.variable
647-
return Dataset(reduced, coords=coords)
688+
return Dataset(reduced, coords=coords, attrs=attrs)
648689

649690
return wrapped_func
650691

xarray/core/variable.py

+3
Original file line numberDiff line numberDiff line change
@@ -1949,6 +1949,9 @@ def _coarsen_reshape(self, windows, boundary, side):
19491949
else:
19501950
shape.append(variable.shape[i])
19511951

1952+
keep_attrs = _get_keep_attrs(default=False)
1953+
variable.attrs = variable._attrs if keep_attrs else {}
1954+
19521955
return variable.data.reshape(shape), tuple(axes)
19531956

19541957
@property

xarray/tests/test_dataset.py

+56
Original file line numberDiff line numberDiff line change
@@ -5664,6 +5664,62 @@ def test_coarsen_coords_cftime():
56645664
np.testing.assert_array_equal(actual.time, expected_times)
56655665

56665666

5667+
def test_coarsen_keep_attrs():
5668+
_attrs = {"units": "test", "long_name": "testing"}
5669+
5670+
var1 = np.linspace(10, 15, 100)
5671+
var2 = np.linspace(5, 10, 100)
5672+
coords = np.linspace(1, 10, 100)
5673+
5674+
ds = Dataset(
5675+
data_vars={"var1": ("coord", var1), "var2": ("coord", var2)},
5676+
coords={"coord": coords},
5677+
attrs=_attrs,
5678+
)
5679+
5680+
# Test dropped attrs
5681+
dat = ds.coarsen(coord=5).mean()
5682+
assert dat.attrs == {}
5683+
5684+
# Test kept attrs using dataset keyword
5685+
dat = ds.coarsen(coord=5, keep_attrs=True).mean()
5686+
assert dat.attrs == _attrs
5687+
5688+
# Test kept attrs using global option
5689+
with set_options(keep_attrs=True):
5690+
dat = ds.coarsen(coord=5).mean()
5691+
assert dat.attrs == _attrs
5692+
5693+
5694+
def test_rolling_keep_attrs():
5695+
_attrs = {"units": "test", "long_name": "testing"}
5696+
5697+
var1 = np.linspace(10, 15, 100)
5698+
var2 = np.linspace(5, 10, 100)
5699+
coords = np.linspace(1, 10, 100)
5700+
5701+
ds = Dataset(
5702+
data_vars={"var1": ("coord", var1), "var2": ("coord", var2)},
5703+
coords={"coord": coords},
5704+
attrs=_attrs,
5705+
)
5706+
5707+
# Test dropped attrs
5708+
dat = ds.rolling(dim={"coord": 5}, min_periods=None, center=False).mean()
5709+
assert dat.attrs == {}
5710+
5711+
# Test kept attrs using dataset keyword
5712+
dat = ds.rolling(
5713+
dim={"coord": 5}, min_periods=None, center=False, keep_attrs=True
5714+
).mean()
5715+
assert dat.attrs == _attrs
5716+
5717+
# Test kept attrs using global option
5718+
with set_options(keep_attrs=True):
5719+
dat = ds.rolling(dim={"coord": 5}, min_periods=None, center=False).mean()
5720+
assert dat.attrs == _attrs
5721+
5722+
56675723
def test_rolling_properties(ds):
56685724
# catching invalid args
56695725
with pytest.raises(ValueError, match="exactly one dim/window should"):

xarray/tests/test_variable.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pytz
1010

1111
from xarray import Coordinate, Dataset, IndexVariable, Variable, set_options
12-
from xarray.core import dtypes, indexing
12+
from xarray.core import dtypes, duck_array_ops, indexing
1313
from xarray.core.common import full_like, ones_like, zeros_like
1414
from xarray.core.indexing import (
1515
BasicIndexer,
@@ -1879,6 +1879,26 @@ def test_coarsen_2d(self):
18791879
expected = self.cls(("x", "y"), [[10, 18], [42, 35]])
18801880
assert_equal(actual, expected)
18811881

1882+
# perhaps @pytest.mark.parametrize("operation", [f for f in duck_array_ops])
1883+
def test_coarsen_keep_attrs(self, operation="mean"):
1884+
_attrs = {"units": "test", "long_name": "testing"}
1885+
1886+
test_func = getattr(duck_array_ops, operation, None)
1887+
1888+
# Test dropped attrs
1889+
with set_options(keep_attrs=False):
1890+
new = Variable(["coord"], np.linspace(1, 10, 100), attrs=_attrs).coarsen(
1891+
windows={"coord": 1}, func=test_func, boundary="exact", side="left"
1892+
)
1893+
assert new.attrs == {}
1894+
1895+
# Test kept attrs
1896+
with set_options(keep_attrs=True):
1897+
new = Variable(["coord"], np.linspace(1, 10, 100), attrs=_attrs).coarsen(
1898+
windows={"coord": 1}, func=test_func, boundary="exact", side="left"
1899+
)
1900+
assert new.attrs == _attrs
1901+
18821902

18831903
@requires_dask
18841904
class TestVariableWithDask(VariableSubclassobjects):

0 commit comments

Comments
 (0)