Skip to content

Commit 2783255

Browse files
dcherianmax-sixty
andauthored
GroupBy(multiple strings) (#9414)
* Group by multiple strings Closes #9396 * Fix typing * some more * fix * cleanup * Update xarray/core/dataarray.py * Update docs * Revert "Update xarray/core/dataarray.py" This reverts commit fafd960. * update docstring * Add docstring examples * Update xarray/core/dataarray.py Co-authored-by: Maximilian Roos <[email protected]> * Update xarray/core/dataset.py * fix assert warning / error * fix assert warning / error * Silence RTD warnings --------- Co-authored-by: Maximilian Roos <[email protected]> Co-authored-by: Maximilian Roos <[email protected]>
1 parent a8c9896 commit 2783255

File tree

6 files changed

+199
-59
lines changed

6 files changed

+199
-59
lines changed

doc/user-guide/groupby.rst

+6
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,12 @@ Use grouper objects to group by multiple dimensions:
305305
306306
from xarray.groupers import UniqueGrouper
307307
308+
da.groupby(["lat", "lon"]).sum()
309+
310+
The above is sugar for using ``UniqueGrouper`` objects directly:
311+
312+
.. ipython:: python
313+
308314
da.groupby(lat=UniqueGrouper(), lon=UniqueGrouper()).sum()
309315
310316

xarray/core/dataarray.py

+51-26
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@
103103
Dims,
104104
ErrorOptions,
105105
ErrorOptionsWithWarn,
106+
GroupInput,
106107
InterpOptions,
107108
PadModeOptions,
108109
PadReflectOptions,
@@ -6707,9 +6708,7 @@ def interp_calendar(
67076708
@_deprecate_positional_args("v2024.07.0")
67086709
def groupby(
67096710
self,
6710-
group: (
6711-
Hashable | DataArray | IndexVariable | Mapping[Any, Grouper] | None
6712-
) = None,
6711+
group: GroupInput = None,
67136712
*,
67146713
squeeze: Literal[False] = False,
67156714
restore_coord_dims: bool = False,
@@ -6719,7 +6718,7 @@ def groupby(
67196718
67206719
Parameters
67216720
----------
6722-
group : Hashable or DataArray or IndexVariable or mapping of Hashable to Grouper
6721+
group : str or DataArray or IndexVariable or sequence of hashable or mapping of hashable to Grouper
67236722
Array whose unique values should be used to group this array. If a
67246723
Hashable, must be the name of a coordinate contained in this dataarray. If a dictionary,
67256724
must map an existing variable name to a :py:class:`Grouper` instance.
@@ -6770,6 +6769,52 @@ def groupby(
67706769
Coordinates:
67716770
* dayofyear (dayofyear) int64 3kB 1 2 3 4 5 6 7 ... 361 362 363 364 365 366
67726771
6772+
>>> da = xr.DataArray(
6773+
... data=np.arange(12).reshape((4, 3)),
6774+
... dims=("x", "y"),
6775+
... coords={"x": [10, 20, 30, 40], "letters": ("x", list("abba"))},
6776+
... )
6777+
6778+
Grouping by a single variable is easy
6779+
6780+
>>> da.groupby("letters")
6781+
<DataArrayGroupBy, grouped over 1 grouper(s), 2 groups in total:
6782+
'letters': 2 groups with labels 'a', 'b'>
6783+
6784+
Execute a reduction
6785+
6786+
>>> da.groupby("letters").sum()
6787+
<xarray.DataArray (letters: 2, y: 3)> Size: 48B
6788+
array([[ 9., 11., 13.],
6789+
[ 9., 11., 13.]])
6790+
Coordinates:
6791+
* letters (letters) object 16B 'a' 'b'
6792+
Dimensions without coordinates: y
6793+
6794+
Grouping by multiple variables
6795+
6796+
>>> da.groupby(["letters", "x"])
6797+
<DataArrayGroupBy, grouped over 2 grouper(s), 8 groups in total:
6798+
'letters': 2 groups with labels 'a', 'b'
6799+
'x': 4 groups with labels 10, 20, 30, 40>
6800+
6801+
Use Grouper objects to express more complicated GroupBy operations
6802+
6803+
>>> from xarray.groupers import BinGrouper, UniqueGrouper
6804+
>>>
6805+
>>> da.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()).sum()
6806+
<xarray.DataArray (x_bins: 2, letters: 2, y: 3)> Size: 96B
6807+
array([[[ 0., 1., 2.],
6808+
[nan, nan, nan]],
6809+
<BLANKLINE>
6810+
[[nan, nan, nan],
6811+
[ 3., 4., 5.]]])
6812+
Coordinates:
6813+
* x_bins (x_bins) object 16B (5, 15] (15, 25]
6814+
* letters (letters) object 16B 'a' 'b'
6815+
Dimensions without coordinates: y
6816+
6817+
67736818
See Also
67746819
--------
67756820
:ref:`groupby`
@@ -6791,32 +6836,12 @@ def groupby(
67916836
"""
67926837
from xarray.core.groupby import (
67936838
DataArrayGroupBy,
6794-
ResolvedGrouper,
6839+
_parse_group_and_groupers,
67956840
_validate_groupby_squeeze,
67966841
)
6797-
from xarray.groupers import UniqueGrouper
67986842

67996843
_validate_groupby_squeeze(squeeze)
6800-
6801-
if isinstance(group, Mapping):
6802-
groupers = either_dict_or_kwargs(group, groupers, "groupby") # type: ignore
6803-
group = None
6804-
6805-
rgroupers: tuple[ResolvedGrouper, ...]
6806-
if group is not None:
6807-
if groupers:
6808-
raise ValueError(
6809-
"Providing a combination of `group` and **groupers is not supported."
6810-
)
6811-
rgroupers = (ResolvedGrouper(UniqueGrouper(), group, self),)
6812-
else:
6813-
if not groupers:
6814-
raise ValueError("Either `group` or `**groupers` must be provided.")
6815-
rgroupers = tuple(
6816-
ResolvedGrouper(grouper, group, self)
6817-
for group, grouper in groupers.items()
6818-
)
6819-
6844+
rgroupers = _parse_group_and_groupers(self, group, groupers)
68206845
return DataArrayGroupBy(self, rgroupers, restore_coord_dims=restore_coord_dims)
68216846

68226847
@_deprecate_positional_args("v2024.07.0")

xarray/core/dataset.py

+50-25
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@
155155
DsCompatible,
156156
ErrorOptions,
157157
ErrorOptionsWithWarn,
158+
GroupInput,
158159
InterpOptions,
159160
JoinOptions,
160161
PadModeOptions,
@@ -10332,9 +10333,7 @@ def interp_calendar(
1033210333
@_deprecate_positional_args("v2024.07.0")
1033310334
def groupby(
1033410335
self,
10335-
group: (
10336-
Hashable | DataArray | IndexVariable | Mapping[Any, Grouper] | None
10337-
) = None,
10336+
group: GroupInput = None,
1033810337
*,
1033910338
squeeze: Literal[False] = False,
1034010339
restore_coord_dims: bool = False,
@@ -10344,7 +10343,7 @@ def groupby(
1034410343
1034510344
Parameters
1034610345
----------
10347-
group : Hashable or DataArray or IndexVariable or mapping of Hashable to Grouper
10346+
group : str or DataArray or IndexVariable or sequence of hashable or mapping of hashable to Grouper
1034810347
Array whose unique values should be used to group this array. If a
1034910348
Hashable, must be the name of a coordinate contained in this dataarray. If a dictionary,
1035010349
must map an existing variable name to a :py:class:`Grouper` instance.
@@ -10366,6 +10365,51 @@ def groupby(
1036610365
A `DatasetGroupBy` object patterned after `pandas.GroupBy` that can be
1036710366
iterated over in the form of `(unique_value, grouped_array)` pairs.
1036810367
10368+
Examples
10369+
--------
10370+
>>> ds = xr.Dataset(
10371+
... {"foo": (("x", "y"), np.arange(12).reshape((4, 3)))},
10372+
... coords={"x": [10, 20, 30, 40], "letters": ("x", list("abba"))},
10373+
... )
10374+
10375+
Grouping by a single variable is easy
10376+
10377+
>>> ds.groupby("letters")
10378+
<DatasetGroupBy, grouped over 1 grouper(s), 2 groups in total:
10379+
'letters': 2 groups with labels 'a', 'b'>
10380+
10381+
Execute a reduction
10382+
10383+
>>> ds.groupby("letters").sum()
10384+
<xarray.Dataset> Size: 64B
10385+
Dimensions: (letters: 2, y: 3)
10386+
Coordinates:
10387+
* letters (letters) object 16B 'a' 'b'
10388+
Dimensions without coordinates: y
10389+
Data variables:
10390+
foo (letters, y) float64 48B 9.0 11.0 13.0 9.0 11.0 13.0
10391+
10392+
Grouping by multiple variables
10393+
10394+
>>> ds.groupby(["letters", "x"])
10395+
<DatasetGroupBy, grouped over 2 grouper(s), 8 groups in total:
10396+
'letters': 2 groups with labels 'a', 'b'
10397+
'x': 4 groups with labels 10, 20, 30, 40>
10398+
10399+
Use Grouper objects to express more complicated GroupBy operations
10400+
10401+
>>> from xarray.groupers import BinGrouper, UniqueGrouper
10402+
>>>
10403+
>>> ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()).sum()
10404+
<xarray.Dataset> Size: 128B
10405+
Dimensions: (y: 3, x_bins: 2, letters: 2)
10406+
Coordinates:
10407+
* x_bins (x_bins) object 16B (5, 15] (15, 25]
10408+
* letters (letters) object 16B 'a' 'b'
10409+
Dimensions without coordinates: y
10410+
Data variables:
10411+
foo (y, x_bins, letters) float64 96B 0.0 nan nan 3.0 ... nan nan 5.0
10412+
1036910413
See Also
1037010414
--------
1037110415
:ref:`groupby`
@@ -10387,31 +10431,12 @@ def groupby(
1038710431
"""
1038810432
from xarray.core.groupby import (
1038910433
DatasetGroupBy,
10390-
ResolvedGrouper,
10434+
_parse_group_and_groupers,
1039110435
_validate_groupby_squeeze,
1039210436
)
10393-
from xarray.groupers import UniqueGrouper
1039410437

1039510438
_validate_groupby_squeeze(squeeze)
10396-
10397-
if isinstance(group, Mapping):
10398-
groupers = either_dict_or_kwargs(group, groupers, "groupby") # type: ignore
10399-
group = None
10400-
10401-
rgroupers: tuple[ResolvedGrouper, ...]
10402-
if group is not None:
10403-
if groupers:
10404-
raise ValueError(
10405-
"Providing a combination of `group` and **groupers is not supported."
10406-
)
10407-
rgroupers = (ResolvedGrouper(UniqueGrouper(), group, self),)
10408-
else:
10409-
if not groupers:
10410-
raise ValueError("Either `group` or `**groupers` must be provided.")
10411-
rgroupers = tuple(
10412-
ResolvedGrouper(grouper, group, self)
10413-
for group, grouper in groupers.items()
10414-
)
10439+
rgroupers = _parse_group_and_groupers(self, group, groupers)
1041510440

1041610441
return DatasetGroupBy(self, rgroupers, restore_coord_dims=restore_coord_dims)
1041710442

xarray/core/groupby.py

+49-4
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import warnings
77
from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence
88
from dataclasses import dataclass, field
9-
from typing import TYPE_CHECKING, Any, Generic, Literal, Union
9+
from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast
1010

1111
import numpy as np
1212
import pandas as pd
@@ -54,7 +54,7 @@
5454

5555
from xarray.core.dataarray import DataArray
5656
from xarray.core.dataset import Dataset
57-
from xarray.core.types import GroupIndex, GroupIndices, GroupKey
57+
from xarray.core.types import GroupIndex, GroupIndices, GroupInput, GroupKey
5858
from xarray.core.utils import Frozen
5959
from xarray.groupers import EncodedGroups, Grouper
6060

@@ -319,6 +319,51 @@ def __len__(self) -> int:
319319
return len(self.encoded.full_index)
320320

321321

322+
def _parse_group_and_groupers(
323+
obj: T_Xarray, group: GroupInput, groupers: dict[str, Grouper]
324+
) -> tuple[ResolvedGrouper, ...]:
325+
from xarray.core.dataarray import DataArray
326+
from xarray.core.variable import Variable
327+
from xarray.groupers import UniqueGrouper
328+
329+
if group is not None and groupers:
330+
raise ValueError(
331+
"Providing a combination of `group` and **groupers is not supported."
332+
)
333+
334+
if group is None and not groupers:
335+
raise ValueError("Either `group` or `**groupers` must be provided.")
336+
337+
if isinstance(group, np.ndarray | pd.Index):
338+
raise TypeError(
339+
f"`group` must be a DataArray. Received {type(group).__name__!r} instead"
340+
)
341+
342+
if isinstance(group, Mapping):
343+
grouper_mapping = either_dict_or_kwargs(group, groupers, "groupby")
344+
group = None
345+
346+
rgroupers: tuple[ResolvedGrouper, ...]
347+
if isinstance(group, DataArray | Variable):
348+
rgroupers = (ResolvedGrouper(UniqueGrouper(), group, obj),)
349+
else:
350+
if group is not None:
351+
if TYPE_CHECKING:
352+
assert isinstance(group, str | Sequence)
353+
group_iter: Sequence[Hashable] = (
354+
(group,) if isinstance(group, str) else group
355+
)
356+
grouper_mapping = {g: UniqueGrouper() for g in group_iter}
357+
elif groupers:
358+
grouper_mapping = cast("Mapping[Hashable, Grouper]", groupers)
359+
360+
rgroupers = tuple(
361+
ResolvedGrouper(grouper, group, obj)
362+
for group, grouper in grouper_mapping.items()
363+
)
364+
return rgroupers
365+
366+
322367
def _validate_groupby_squeeze(squeeze: Literal[False]) -> None:
323368
# While we don't generally check the type of every arg, passing
324369
# multiple dimensions as multiple arguments is common enough, and the
@@ -327,7 +372,7 @@ def _validate_groupby_squeeze(squeeze: Literal[False]) -> None:
327372
# A future version could make squeeze kwarg only, but would face
328373
# backward-compat issues.
329374
if squeeze is not False:
330-
raise TypeError(f"`squeeze` must be False, but {squeeze} was supplied.")
375+
raise TypeError(f"`squeeze` must be False, but {squeeze!r} was supplied.")
331376

332377

333378
def _resolve_group(
@@ -626,7 +671,7 @@ def __repr__(self) -> str:
626671
for grouper in self.groupers:
627672
coord = grouper.unique_coord
628673
labels = ", ".join(format_array_flat(coord, 30).split())
629-
text += f"\n\t{grouper.name!r}: {coord.size} groups with labels {labels}"
674+
text += f"\n {grouper.name!r}: {coord.size} groups with labels {labels}"
630675
return text + ">"
631676

632677
def _iter_grouped(self) -> Iterator[T_Xarray]:

xarray/core/types.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,17 @@
4343
from xarray.core.dataset import Dataset
4444
from xarray.core.indexes import Index, Indexes
4545
from xarray.core.utils import Frozen
46-
from xarray.core.variable import Variable
47-
from xarray.groupers import TimeResampler
46+
from xarray.core.variable import IndexVariable, Variable
47+
from xarray.groupers import Grouper, TimeResampler
48+
49+
GroupInput: TypeAlias = (
50+
str
51+
| DataArray
52+
| IndexVariable
53+
| Sequence[Hashable]
54+
| Mapping[Any, Grouper]
55+
| None
56+
)
4857

4958
try:
5059
from dask.array import Array as DaskArray

0 commit comments

Comments
 (0)