Skip to content

Commit 9acc411

Browse files
max-sixtypre-commit-ci[bot]dcherian
authored
Add Cumulative aggregation (#8512)
* Add Cumulative aggregation Closes #5215 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * whatsnew * Update xarray/core/dataarray.py Co-authored-by: Deepak Cherian <[email protected]> * Update xarray/core/dataset.py * min_periods defaults to 1 --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deepak Cherian <[email protected]>
1 parent da08288 commit 9acc411

File tree

5 files changed

+174
-2
lines changed

5 files changed

+174
-2
lines changed

doc/api.rst

+2
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ Computation
182182
Dataset.groupby_bins
183183
Dataset.rolling
184184
Dataset.rolling_exp
185+
Dataset.cumulative
185186
Dataset.weighted
186187
Dataset.coarsen
187188
Dataset.resample
@@ -379,6 +380,7 @@ Computation
379380
DataArray.groupby_bins
380381
DataArray.rolling
381382
DataArray.rolling_exp
383+
DataArray.cumulative
382384
DataArray.weighted
383385
DataArray.coarsen
384386
DataArray.resample

doc/whats-new.rst

+6
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ New Features
7171
example a 1D array — it's about the same speed as bottleneck, and 2-5x faster
7272
than pandas' default functions. (:pull:`8493`). numbagg is an optional
7373
dependency, so requires installing separately.
74+
- Add :py:meth:`DataArray.cumulative` & :py:meth:`Dataset.cumulative` to compute
75+
cumulative aggregations, such as ``sum``, along a dimension — for example
76+
``da.cumulative('time').sum()``. This is similar to pandas' ``.expanding``,
77+
and mostly equivalent to ``.cumsum`` methods, or to
78+
:py:meth:`DataArray.rolling` with a window length equal to the dimension size.
79+
(:pull:`8512`).
7480
By `Maximilian Roos <https://github.com/max-sixty>`_.
7581
- Use a concise format when plotting datetime arrays. (:pull:`8449`).
7682
By `Jimmy Westling <https://github.com/illviljan>`_.

xarray/core/dataarray.py

+77-1
Original file line numberDiff line numberDiff line change
@@ -6923,14 +6923,90 @@ def rolling(
69236923
69246924
See Also
69256925
--------
6926-
core.rolling.DataArrayRolling
6926+
DataArray.cumulative
69276927
Dataset.rolling
6928+
core.rolling.DataArrayRolling
69286929
"""
69296930
from xarray.core.rolling import DataArrayRolling
69306931

69316932
dim = either_dict_or_kwargs(dim, window_kwargs, "rolling")
69326933
return DataArrayRolling(self, dim, min_periods=min_periods, center=center)
69336934

6935+
def cumulative(
6936+
self,
6937+
dim: str | Iterable[Hashable],
6938+
min_periods: int = 1,
6939+
) -> DataArrayRolling:
6940+
"""
6941+
Accumulating object for DataArrays.
6942+
6943+
Parameters
6944+
----------
6945+
dims : iterable of hashable
6946+
The name(s) of the dimensions to create the cumulative window along
6947+
min_periods : int, default: 1
6948+
Minimum number of observations in window required to have a value
6949+
(otherwise result is NA). The default is 1 (note this is different
6950+
from ``Rolling``, whose default is the size of the window).
6951+
6952+
Returns
6953+
-------
6954+
core.rolling.DataArrayRolling
6955+
6956+
Examples
6957+
--------
6958+
Create rolling seasonal average of monthly data e.g. DJF, JFM, ..., SON:
6959+
6960+
>>> da = xr.DataArray(
6961+
... np.linspace(0, 11, num=12),
6962+
... coords=[
6963+
... pd.date_range(
6964+
... "1999-12-15",
6965+
... periods=12,
6966+
... freq=pd.DateOffset(months=1),
6967+
... )
6968+
... ],
6969+
... dims="time",
6970+
... )
6971+
6972+
>>> da
6973+
<xarray.DataArray (time: 12)>
6974+
array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.])
6975+
Coordinates:
6976+
* time (time) datetime64[ns] 1999-12-15 2000-01-15 ... 2000-11-15
6977+
6978+
>>> da.cumulative("time").sum()
6979+
<xarray.DataArray (time: 12)>
6980+
array([ 0., 1., 3., 6., 10., 15., 21., 28., 36., 45., 55., 66.])
6981+
Coordinates:
6982+
* time (time) datetime64[ns] 1999-12-15 2000-01-15 ... 2000-11-15
6983+
6984+
See Also
6985+
--------
6986+
DataArray.rolling
6987+
Dataset.cumulative
6988+
core.rolling.DataArrayRolling
6989+
"""
6990+
from xarray.core.rolling import DataArrayRolling
6991+
6992+
# Could we abstract this "normalize and check 'dim'" logic? It's currently shared
6993+
# with the same method in Dataset.
6994+
if isinstance(dim, str):
6995+
if dim not in self.dims:
6996+
raise ValueError(
6997+
f"Dimension {dim} not found in data dimensions: {self.dims}"
6998+
)
6999+
dim = {dim: self.sizes[dim]}
7000+
else:
7001+
missing_dims = set(dim) - set(self.dims)
7002+
if missing_dims:
7003+
raise ValueError(
7004+
f"Dimensions {missing_dims} not found in data dimensions: {self.dims}"
7005+
)
7006+
dim = {d: self.sizes[d] for d in dim}
7007+
7008+
return DataArrayRolling(self, dim, min_periods=min_periods, center=False)
7009+
69347010
def coarsen(
69357011
self,
69367012
dim: Mapping[Any, int] | None = None,

xarray/core/dataset.py

+47-1
Original file line numberDiff line numberDiff line change
@@ -10369,14 +10369,60 @@ def rolling(
1036910369
1037010370
See Also
1037110371
--------
10372-
core.rolling.DatasetRolling
10372+
Dataset.cumulative
1037310373
DataArray.rolling
10374+
core.rolling.DatasetRolling
1037410375
"""
1037510376
from xarray.core.rolling import DatasetRolling
1037610377

1037710378
dim = either_dict_or_kwargs(dim, window_kwargs, "rolling")
1037810379
return DatasetRolling(self, dim, min_periods=min_periods, center=center)
1037910380

10381+
def cumulative(
10382+
self,
10383+
dim: str | Iterable[Hashable],
10384+
min_periods: int = 1,
10385+
) -> DatasetRolling:
10386+
"""
10387+
Accumulating object for Datasets
10388+
10389+
Parameters
10390+
----------
10391+
dims : iterable of hashable
10392+
The name(s) of the dimensions to create the cumulative window along
10393+
min_periods : int, default: 1
10394+
Minimum number of observations in window required to have a value
10395+
(otherwise result is NA). The default is 1 (note this is different
10396+
from ``Rolling``, whose default is the size of the window).
10397+
10398+
Returns
10399+
-------
10400+
core.rolling.DatasetRolling
10401+
10402+
See Also
10403+
--------
10404+
Dataset.rolling
10405+
DataArray.cumulative
10406+
core.rolling.DatasetRolling
10407+
"""
10408+
from xarray.core.rolling import DatasetRolling
10409+
10410+
if isinstance(dim, str):
10411+
if dim not in self.dims:
10412+
raise ValueError(
10413+
f"Dimension {dim} not found in data dimensions: {self.dims}"
10414+
)
10415+
dim = {dim: self.sizes[dim]}
10416+
else:
10417+
missing_dims = set(dim) - set(self.dims)
10418+
if missing_dims:
10419+
raise ValueError(
10420+
f"Dimensions {missing_dims} not found in data dimensions: {self.dims}"
10421+
)
10422+
dim = {d: self.sizes[d] for d in dim}
10423+
10424+
return DatasetRolling(self, dim, min_periods=min_periods, center=False)
10425+
1038010426
def coarsen(
1038110427
self,
1038210428
dim: Mapping[Any, int] | None = None,

xarray/tests/test_rolling.py

+42
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,29 @@ def test_rolling_exp_keep_attrs(self, da, func) -> None:
485485
):
486486
da.rolling_exp(time=10, keep_attrs=True)
487487

488+
@pytest.mark.parametrize("func", ["mean", "sum"])
489+
@pytest.mark.parametrize("min_periods", [1, 20])
490+
def test_cumulative(self, da, func, min_periods) -> None:
491+
# One dim
492+
result = getattr(da.cumulative("time", min_periods=min_periods), func)()
493+
expected = getattr(
494+
da.rolling(time=da.time.size, min_periods=min_periods), func
495+
)()
496+
assert_identical(result, expected)
497+
498+
# Multiple dim
499+
result = getattr(da.cumulative(["time", "a"], min_periods=min_periods), func)()
500+
expected = getattr(
501+
da.rolling(time=da.time.size, a=da.a.size, min_periods=min_periods),
502+
func,
503+
)()
504+
assert_identical(result, expected)
505+
506+
def test_cumulative_vs_cum(self, da) -> None:
507+
result = da.cumulative("time").sum()
508+
expected = da.cumsum("time")
509+
assert_identical(result, expected)
510+
488511

489512
class TestDatasetRolling:
490513
@pytest.mark.parametrize(
@@ -809,6 +832,25 @@ def test_raise_no_warning_dask_rolling_assert_close(self, ds, name) -> None:
809832
expected = getattr(getattr(ds.rolling(time=4), name)().rolling(x=3), name)()
810833
assert_allclose(actual, expected)
811834

835+
@pytest.mark.parametrize("func", ["mean", "sum"])
836+
@pytest.mark.parametrize("ds", (2,), indirect=True)
837+
@pytest.mark.parametrize("min_periods", [1, 10])
838+
def test_cumulative(self, ds, func, min_periods) -> None:
839+
# One dim
840+
result = getattr(ds.cumulative("time", min_periods=min_periods), func)()
841+
expected = getattr(
842+
ds.rolling(time=ds.time.size, min_periods=min_periods), func
843+
)()
844+
assert_identical(result, expected)
845+
846+
# Multiple dim
847+
result = getattr(ds.cumulative(["time", "x"], min_periods=min_periods), func)()
848+
expected = getattr(
849+
ds.rolling(time=ds.time.size, x=ds.x.size, min_periods=min_periods),
850+
func,
851+
)()
852+
assert_identical(result, expected)
853+
812854

813855
@requires_numbagg
814856
class TestDatasetRollingExp:

0 commit comments

Comments
 (0)