
Commit 4da732f

Squashed commit of the following:
commit 583a3d2
Author: Deepak Cherian <[email protected]>
Date:   Wed Mar 19 12:55:54 2025 -0600

    fix mypy

commit 699c3b8
Author: Deepak Cherian <[email protected]>
Date:   Wed Mar 19 09:30:38 2025 -0600

    Preserve label ordering for multi-variable GroupBy
1 parent b67118a commit 4da732f

3 files changed: +61 −15 lines changed

xarray/core/groupby.py  (+22 −5)

@@ -536,6 +536,11 @@ def factorize(self) -> EncodedGroups:
             list(grouper.full_index.values for grouper in groupers),
             names=tuple(grouper.name for grouper in groupers),
         )
+        if not full_index.is_unique:
+            raise ValueError(
+                "The output index for the GroupBy is non-unique. "
+                "This is a bug in the Grouper provided."
+            )
         # This will be unused when grouping by dask arrays, so skip..
         if not is_chunked_array(_flatcodes):
             # Constructing an index from the product is wrong when there are missing groups
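A brief illustration of what the new full_index.is_unique check guards against. This is a hedged sketch using plain pandas, not code from this commit: if a Grouper reports a duplicate label in its full_index, the product index built above stops being unique.

import pandas as pd

# Hypothetical example: a full_index with a repeated label ("a") makes the
# product MultiIndex non-unique, which the check above now rejects.
bad_level = pd.Index(["a", "a", "b"])
other_level = pd.Index([1, 2])
full_index = pd.MultiIndex.from_product([bad_level, other_level])
print(full_index.is_unique)  # False -> factorize() would now raise ValueError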
@@ -947,17 +952,29 @@ def _binary_op(self, other, f, reflexive=False):
     def _restore_dim_order(self, stacked):
         raise NotImplementedError
 
-    def _maybe_restore_empty_groups(self, combined):
-        """Our index contained empty groups (e.g., from a resampling or binning). If we
+    def _maybe_reindex(self, combined):
+        """Reindexing is needed in two cases:
+        1. Our index contained empty groups (e.g., from a resampling or binning). If we
         reduced on that dimension, we want to restore the full index.
+
+        2. We use a MultiIndex for multi-variable GroupBy.
+        The MultiIndex stores each level's labels in sorted order
+        which are then assigned on unstacking. So we need to restore
+        the correct order here.
         """
         has_missing_groups = (
             self.encoded.unique_coord.size != self.encoded.full_index.size
         )
         indexers = {}
         for grouper in self.groupers:
-            if has_missing_groups and grouper.name in combined._indexes:
+            index = combined._indexes.get(grouper.name, None)
+            if has_missing_groups and index is not None:
                 indexers[grouper.name] = grouper.full_index
+            elif len(self.groupers) > 1:
+                if not isinstance(
+                    grouper.full_index, pd.RangeIndex
+                ) and not index.index.equals(grouper.full_index):
+                    indexers[grouper.name] = grouper.full_index
         if indexers:
             combined = combined.reindex(**indexers)
         return combined
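The reordering described in the new docstring comes from pandas itself. A small sketch of that assumed pandas behaviour (an illustration, not code from this commit) shows why a reindex is needed after unstacking:

import pandas as pd

# MultiIndex.from_product stores each level's labels in sorted order,
# regardless of the order in which they were supplied.
mi = pd.MultiIndex.from_product(
    [[2002, 2001], ["JJA", "DJF"]], names=["year", "season"]
)
print(mi.levels[0])  # Index([2001, 2002]) -- sorted, not [2002, 2001]
print(mi.levels[1])  # Index(['DJF', 'JJA']) -- sorted, not ['JJA', 'DJF']
# Unstacking assigns coordinates from these sorted levels, so _maybe_reindex
# restores the order of each grouper.full_index afterwards.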
@@ -1597,7 +1614,7 @@ def _combine(self, applied, shortcut=False):
         if dim not in applied_example.dims:
             combined = combined.assign_coords(self.encoded.coords)
         combined = self._maybe_unstack(combined)
-        combined = self._maybe_restore_empty_groups(combined)
+        combined = self._maybe_reindex(combined)
         return combined
 
     def reduce(

@@ -1753,7 +1770,7 @@ def _combine(self, applied):
         if dim not in applied_example.dims:
             combined = combined.assign_coords(self.encoded.coords)
         combined = self._maybe_unstack(combined)
-        combined = self._maybe_restore_empty_groups(combined)
+        combined = self._maybe_reindex(combined)
         return combined
 
     def reduce(

xarray/groupers.py  (+1 −1)

@@ -536,7 +536,7 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]:
         counts = grouped.count()
         # This way we generate codes for the final output index: full_index.
         # So for _flox_reduce we avoid one reindex and copy by avoiding
-        # _maybe_restore_empty_groups
+        # _maybe_reindex
         codes = np.repeat(np.arange(len(first_items)), counts)
         return first_items, codes

xarray/tests/test_groupby.py  (+38 −9)

@@ -158,7 +158,7 @@ def test_multi_index_groupby_sum() -> None:
 
 
 @requires_pandas_ge_2_2
-def test_multi_index_propagation():
+def test_multi_index_propagation() -> None:
     # regression test for GH9648
     times = pd.date_range("2023-01-01", periods=4)
     locations = ["A", "B"]

@@ -2295,7 +2295,7 @@ def test_resample_origin(self) -> None:
         times = pd.date_range("2000-01-01T02:03:01", freq="6h", periods=10)
         array = DataArray(np.arange(10), [("time", times)])
 
-        origin = "start"
+        origin: Literal["start"] = "start"
         actual = array.resample(time="24h", origin=origin).mean()
         expected = DataArray(array.to_series().resample("24h", origin=origin).mean())
         assert_identical(expected, actual)

@@ -2700,7 +2700,7 @@ def test_default_flox_method() -> None:
 
 @requires_cftime
 @pytest.mark.filterwarnings("ignore")
-def test_cftime_resample_gh_9108():
+def test_cftime_resample_gh_9108() -> None:
     import cftime
 
     ds = Dataset(

@@ -3050,7 +3050,7 @@ def test_gappy_resample_reductions(reduction):
     assert_identical(expected, actual)
 
 
-def test_groupby_transpose():
+def test_groupby_transpose() -> None:
     # GH5361
     data = xr.DataArray(
         np.random.randn(4, 2),

@@ -3110,7 +3110,7 @@ def test_lazy_grouping(grouper, expect_index):
 
 
 @requires_dask
-def test_lazy_grouping_errors():
+def test_lazy_grouping_errors() -> None:
     import dask.array
 
     data = DataArray(

@@ -3136,15 +3136,15 @@ def test_lazy_grouping_errors():
 
 
 @requires_dask
-def test_lazy_int_bins_error():
+def test_lazy_int_bins_error() -> None:
     import dask.array
 
     with pytest.raises(ValueError, match="Bin edges must be provided"):
         with raise_if_dask_computes():
             _ = BinGrouper(bins=4).factorize(DataArray(dask.array.arange(3)))
 
 
-def test_time_grouping_seasons_specified():
+def test_time_grouping_seasons_specified() -> None:
     time = xr.date_range("2001-01-01", "2002-01-01", freq="D")
     ds = xr.Dataset({"foo": np.arange(time.size)}, coords={"time": ("time", time)})
     labels = ["DJF", "MAM", "JJA", "SON"]

@@ -3153,7 +3153,36 @@ def test_time_grouping_seasons_specified():
     assert_identical(actual, expected.reindex(season=labels))
 
 
-def test_groupby_multiple_bin_grouper_missing_groups():
+def test_multiple_grouper_unsorted_order() -> None:
+    time = xr.date_range("2001-01-01", "2003-01-01", freq="MS")
+    ds = xr.Dataset({"foo": np.arange(time.size)}, coords={"time": ("time", time)})
+    labels = ["DJF", "MAM", "JJA", "SON"]
+    actual = ds.groupby(
+        {
+            "time.season": UniqueGrouper(labels=labels),
+            "time.year": UniqueGrouper(labels=[2002, 2001]),
+        }
+    ).sum()
+    expected = (
+        ds.groupby({"time.season": UniqueGrouper(), "time.year": UniqueGrouper()})
+        .sum()
+        .reindex(season=labels, year=[2002, 2001])
+    )
+    assert_identical(actual, expected.reindex(season=labels))
+
+    b = xr.DataArray(
+        np.random.default_rng(0).random((2, 3, 4)),
+        coords={"x": [0, 1], "y": [0, 1, 2]},
+        dims=["x", "y", "z"],
+    )
+    actual2 = b.groupby(
+        x=UniqueGrouper(labels=[1, 0]), y=UniqueGrouper(labels=[2, 0, 1])
+    ).sum()
+    expected2 = b.reindex(x=[1, 0], y=[2, 0, 1]).transpose("z", ...)
+    assert_identical(actual2, expected2)
+
+
+def test_groupby_multiple_bin_grouper_missing_groups() -> None:
     from numpy import nan
 
     ds = xr.Dataset(
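For context, a minimal user-facing sketch of the behaviour the new test_multiple_grouper_unsorted_order covers, assuming this commit is applied (the example mirrors the test rather than reproducing it exactly):

import numpy as np
import xarray as xr
from xarray.groupers import UniqueGrouper

b = xr.DataArray(
    np.arange(24.0).reshape(2, 3, 4),
    coords={"x": [0, 1], "y": [0, 1, 2]},
    dims=["x", "y", "z"],
)
result = b.groupby(
    x=UniqueGrouper(labels=[1, 0]), y=UniqueGrouper(labels=[2, 0, 1])
).sum()
print(result.x.values)  # [1 0] -- follows the labels order, not sorted order
print(result.y.values)  # [2 0 1]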
@@ -3230,7 +3259,7 @@ def test_shuffle_by(chunks, expected_chunks):
 
 
 @requires_dask
-def test_groupby_dask_eager_load_warnings():
+def test_groupby_dask_eager_load_warnings() -> None:
     ds = xr.Dataset(
         {"foo": (("z"), np.arange(12))},
         coords={"x": ("z", np.arange(12)), "y": ("z", np.arange(12))},
