Skip to content

Commit 400a06c

Browse files
committed
flox: Properly propagate multiindex
Closes #9648
1 parent b9780e7 commit 400a06c

File tree

5 files changed

+49
-37
lines changed

5 files changed

+49
-37
lines changed

doc/whats-new.rst

+1-2
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ Deprecations
5050

5151
Bug fixes
5252
~~~~~~~~~
53-
5453
- Make path-like variable names illegal when constructing a DataTree from a Dataset
5554
(:issue:`9339`, :pull:`9378`)
5655
By `Etienne Schalk <https://github.com/etienneschalk>`_.
@@ -63,7 +62,7 @@ Bug fixes
6362
the non-missing times could in theory be encoded with integers
6463
(:issue:`9488`, :pull:`9497`). By `Spencer Clark
6564
<https://github.com/spencerkclark>`_.
66-
- Fix a few bugs affecting groupby reductions with `flox`. (:issue:`8090`, :issue:`9398`).
65+
- Fix a few bugs affecting groupby reductions with `flox`. (:issue:`8090`, :issue:`9398`, :issue:`9648`).
6766
By `Deepak Cherian <https://github.com/dcherian>`_.
6867
- Fix the safe_chunks validation option on the to_zarr method
6968
(:issue:`5511`, :pull:`9559`). By `Joseph Nowak

xarray/core/coordinates.py

+11
Original file line numberDiff line numberDiff line change
@@ -1116,3 +1116,14 @@ def create_coords_with_default_indexes(
11161116
new_coords = Coordinates._construct_direct(coords=variables, indexes=indexes)
11171117

11181118
return new_coords
1119+
1120+
1121+
def _coordinates_from_variable(variable: Variable) -> Coordinates:
1122+
from xarray.core.indexes import create_default_index_implicit
1123+
1124+
(name,) = variable.dims
1125+
new_index, index_vars = create_default_index_implicit(variable)
1126+
indexes = {k: new_index for k in index_vars}
1127+
new_vars = new_index.create_variables()
1128+
new_vars[name].attrs = variable.attrs
1129+
return Coordinates(new_vars, indexes)

xarray/core/groupby.py

+17-23
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@
2121
from xarray.core.arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic
2222
from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
2323
from xarray.core.concat import concat
24-
from xarray.core.coordinates import Coordinates
24+
from xarray.core.coordinates import Coordinates, _coordinates_from_variable
2525
from xarray.core.formatting import format_array_flat
2626
from xarray.core.indexes import (
27-
PandasIndex,
2827
PandasMultiIndex,
2928
filter_indexes_from_coords,
3029
)
30+
from xarray.core.merge import merge_coords
3131
from xarray.core.options import OPTIONS, _get_keep_attrs
3232
from xarray.core.types import (
3333
Dims,
@@ -851,7 +851,6 @@ def _flox_reduce(
851851
from flox.xarray import xarray_reduce
852852

853853
from xarray.core.dataset import Dataset
854-
from xarray.groupers import BinGrouper
855854

856855
obj = self._original_obj
857856
variables = (
@@ -901,13 +900,6 @@ def _flox_reduce(
901900
# set explicitly to avoid unnecessarily accumulating count
902901
kwargs["min_count"] = 0
903902

904-
unindexed_dims: tuple[Hashable, ...] = tuple(
905-
grouper.name
906-
for grouper in self.groupers
907-
if isinstance(grouper.group, _DummyGroup)
908-
and not isinstance(grouper.grouper, BinGrouper)
909-
)
910-
911903
parsed_dim: tuple[Hashable, ...]
912904
if isinstance(dim, str):
913905
parsed_dim = (dim,)
@@ -963,26 +955,28 @@ def _flox_reduce(
963955
# we did end up reducing over dimension(s) that are
964956
# in the grouped variable
965957
group_dims = set(grouper.group.dims)
966-
new_coords = {}
958+
new_coords = []
967959
if group_dims.issubset(set(parsed_dim)):
968-
new_indexes = {}
969960
for grouper in self.groupers:
970961
output_index = grouper.full_index
971962
if isinstance(output_index, pd.RangeIndex):
963+
# flox always assigns an index so we must drop it here if we don't need it.
964+
result = result.drop_vars(grouper.name)
972965
continue
973-
name = grouper.name
974-
new_coords[name] = IndexVariable(
975-
dims=name, data=np.array(output_index), attrs=grouper.codes.attrs
966+
new_coords.append(
967+
# Using IndexVariable here ensures we reconstruct PandasMultiIndex with
968+
# all associated levels properly.
969+
_coordinates_from_variable(
970+
IndexVariable(
971+
dims=grouper.name,
972+
data=output_index,
973+
attrs=grouper.codes.attrs,
974+
)
975+
)
976976
)
977-
index_cls = (
978-
PandasIndex
979-
if not isinstance(output_index, pd.MultiIndex)
980-
else PandasMultiIndex
981-
)
982-
new_indexes[name] = index_cls(output_index, dim=name)
983977
result = result.assign_coords(
984-
Coordinates(new_coords, new_indexes)
985-
).drop_vars(unindexed_dims)
978+
Coordinates._construct_direct(*merge_coords(new_coords))
979+
)
986980

987981
# broadcast any non-dim coord variables that don't
988982
# share all dimensions with the grouper

xarray/groupers.py

+1-12
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq
1818
from xarray.core import duck_array_ops
19-
from xarray.core.coordinates import Coordinates
19+
from xarray.core.coordinates import Coordinates, _coordinates_from_variable
2020
from xarray.core.dataarray import DataArray
2121
from xarray.core.groupby import T_Group, _DummyGroup
2222
from xarray.core.indexes import safe_cast_to_index
@@ -42,17 +42,6 @@
4242
RESAMPLE_DIM = "__resample_dim__"
4343

4444

45-
def _coordinates_from_variable(variable: Variable) -> Coordinates:
46-
from xarray.core.indexes import create_default_index_implicit
47-
48-
(name,) = variable.dims
49-
new_index, index_vars = create_default_index_implicit(variable)
50-
indexes = {k: new_index for k in index_vars}
51-
new_vars = new_index.create_variables()
52-
new_vars[name].attrs = variable.attrs
53-
return Coordinates(new_vars, indexes)
54-
55-
5645
@dataclass(init=False)
5746
class EncodedGroups:
5847
"""

xarray/tests/test_groupby.py

+19
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,25 @@ def test_multi_index_groupby_sum() -> None:
145145
assert_equal(expected, actual)
146146

147147

148+
def test_multi_index_propagation():
149+
# regression test for GH9648
150+
times = pd.date_range("2023-01-01", periods=4)
151+
locations = ["A", "B"]
152+
data = [[0.5, 0.7], [0.6, 0.5], [0.4, 0.6], [0.4, 0.9]]
153+
154+
da = xr.DataArray(
155+
data, dims=["time", "location"], coords={"time": times, "location": locations}
156+
)
157+
da = da.stack(multiindex=["time", "location"])
158+
grouped = da.groupby("multiindex")
159+
160+
with xr.set_options(use_flox=True):
161+
actual = grouped.sum()
162+
with xr.set_options(use_flox=False):
163+
expected = grouped.first()
164+
assert_identical(actual, expected)
165+
166+
148167
def test_groupby_da_datetime() -> None:
149168
# test groupby with a DataArray of dtype datetime for GH1132
150169
# create test data

0 commit comments

Comments
 (0)