From 8dcaba5c33294d4eadabc90e9b7065bd71f6800c Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 19 Oct 2024 22:29:20 -0600 Subject: [PATCH 1/4] flox: Properly propagate multiindex Closes #9648 --- doc/whats-new.rst | 2 +- xarray/core/coordinates.py | 11 ++++++++++ xarray/core/groupby.py | 40 +++++++++++++++--------------------- xarray/groupers.py | 13 +----------- xarray/tests/test_groupby.py | 19 +++++++++++++++++ 5 files changed, 49 insertions(+), 36 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index f1b6b4fe061..3938d44bcb1 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -63,7 +63,7 @@ Bug fixes the non-missing times could in theory be encoded with integers (:issue:`9488`, :pull:`9497`). By `Spencer Clark `_. -- Fix a few bugs affecting groupby reductions with `flox`. (:issue:`8090`, :issue:`9398`). +- Fix a few bugs affecting groupby reductions with `flox`. (:issue:`8090`, :issue:`9398`, :issue:`9648`). By `Deepak Cherian `_. - Fix the safe_chunks validation option on the to_zarr method (:issue:`5511`, :pull:`9559`). By `Joseph Nowak diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 91ef9b6ccad..c4a082f22b7 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -1116,3 +1116,14 @@ def create_coords_with_default_indexes( new_coords = Coordinates._construct_direct(coords=variables, indexes=indexes) return new_coords + + +def _coordinates_from_variable(variable: Variable) -> Coordinates: + from xarray.core.indexes import create_default_index_implicit + + (name,) = variable.dims + new_index, index_vars = create_default_index_implicit(variable) + indexes = {k: new_index for k in index_vars} + new_vars = new_index.create_variables() + new_vars[name].attrs = variable.attrs + return Coordinates(new_vars, indexes) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index b09d7cf852c..313ceb37429 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -21,13 +21,13 @@ from xarray.core.arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce from xarray.core.concat import concat -from xarray.core.coordinates import Coordinates +from xarray.core.coordinates import Coordinates, _coordinates_from_variable from xarray.core.formatting import format_array_flat from xarray.core.indexes import ( - PandasIndex, PandasMultiIndex, filter_indexes_from_coords, ) +from xarray.core.merge import merge_coords from xarray.core.options import OPTIONS, _get_keep_attrs from xarray.core.types import ( Dims, @@ -851,7 +851,6 @@ def _flox_reduce( from flox.xarray import xarray_reduce from xarray.core.dataset import Dataset - from xarray.groupers import BinGrouper obj = self._original_obj variables = ( @@ -901,13 +900,6 @@ def _flox_reduce( # set explicitly to avoid unnecessarily accumulating count kwargs["min_count"] = 0 - unindexed_dims: tuple[Hashable, ...] = tuple( - grouper.name - for grouper in self.groupers - if isinstance(grouper.group, _DummyGroup) - and not isinstance(grouper.grouper, BinGrouper) - ) - parsed_dim: tuple[Hashable, ...] if isinstance(dim, str): parsed_dim = (dim,) @@ -963,26 +955,28 @@ def _flox_reduce( # we did end up reducing over dimension(s) that are # in the grouped variable group_dims = set(grouper.group.dims) - new_coords = {} + new_coords = [] if group_dims.issubset(set(parsed_dim)): - new_indexes = {} for grouper in self.groupers: output_index = grouper.full_index if isinstance(output_index, pd.RangeIndex): + # flox always assigns an index so we must drop it here if we don't need it. + result = result.drop_vars(grouper.name) continue - name = grouper.name - new_coords[name] = IndexVariable( - dims=name, data=np.array(output_index), attrs=grouper.codes.attrs + new_coords.append( + # Using IndexVariable here ensures we reconstruct PandasMultiIndex with + # all associated levels properly. + _coordinates_from_variable( + IndexVariable( + dims=grouper.name, + data=output_index, + attrs=grouper.codes.attrs, + ) + ) ) - index_cls = ( - PandasIndex - if not isinstance(output_index, pd.MultiIndex) - else PandasMultiIndex - ) - new_indexes[name] = index_cls(output_index, dim=name) result = result.assign_coords( - Coordinates(new_coords, new_indexes) - ).drop_vars(unindexed_dims) + Coordinates._construct_direct(*merge_coords(new_coords)) + ) # broadcast any non-dim coord variables that don't # share all dimensions with the grouper diff --git a/xarray/groupers.py b/xarray/groupers.py index e4cb884e6de..996f86317b9 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -16,7 +16,7 @@ from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq from xarray.core import duck_array_ops -from xarray.core.coordinates import Coordinates +from xarray.core.coordinates import Coordinates, _coordinates_from_variable from xarray.core.dataarray import DataArray from xarray.core.groupby import T_Group, _DummyGroup from xarray.core.indexes import safe_cast_to_index @@ -42,17 +42,6 @@ RESAMPLE_DIM = "__resample_dim__" -def _coordinates_from_variable(variable: Variable) -> Coordinates: - from xarray.core.indexes import create_default_index_implicit - - (name,) = variable.dims - new_index, index_vars = create_default_index_implicit(variable) - indexes = {k: new_index for k in index_vars} - new_vars = new_index.create_variables() - new_vars[name].attrs = variable.attrs - return Coordinates(new_vars, indexes) - - @dataclass(init=False) class EncodedGroups: """ diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index dc869cc3a34..164818db620 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -145,6 +145,25 @@ def test_multi_index_groupby_sum() -> None: assert_equal(expected, actual) +def test_multi_index_propagation(): + # regression test for GH9648 + times = pd.date_range("2023-01-01", periods=4) + locations = ["A", "B"] + data = [[0.5, 0.7], [0.6, 0.5], [0.4, 0.6], [0.4, 0.9]] + + da = xr.DataArray( + data, dims=["time", "location"], coords={"time": times, "location": locations} + ) + da = da.stack(multiindex=["time", "location"]) + grouped = da.groupby("multiindex") + + with xr.set_options(use_flox=True): + actual = grouped.sum() + with xr.set_options(use_flox=False): + expected = grouped.first() + assert_identical(actual, expected) + + def test_groupby_da_datetime() -> None: # test groupby with a DataArray of dtype datetime for GH1132 # create test data From 1c0b39ed9d63ac92600f4e85f1853df9ace01d38 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 21 Oct 2024 07:06:24 -0600 Subject: [PATCH 2/4] skip test on old pandas --- xarray/tests/test_groupby.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 164818db620..3c321166619 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -35,6 +35,7 @@ requires_dask, requires_flox, requires_flox_0_9_12, + requires_pandas_ge_2_2, requires_scipy, ) @@ -145,6 +146,7 @@ def test_multi_index_groupby_sum() -> None: assert_equal(expected, actual) +@requires_pandas_ge_2_2 def test_multi_index_propagation(): # regression test for GH9648 times = pd.date_range("2023-01-01", periods=4) From 667785a81924ebc16106a9468be1f16d3d104dbe Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 21 Oct 2024 07:08:21 -0600 Subject: [PATCH 3/4] small optimization --- xarray/core/groupby.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 313ceb37429..5536c5d2e26 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -956,12 +956,13 @@ def _flox_reduce( # in the grouped variable group_dims = set(grouper.group.dims) new_coords = [] + to_drop = [] if group_dims.issubset(set(parsed_dim)): for grouper in self.groupers: output_index = grouper.full_index if isinstance(output_index, pd.RangeIndex): # flox always assigns an index so we must drop it here if we don't need it. - result = result.drop_vars(grouper.name) + to_drop.append(grouper.name) continue new_coords.append( # Using IndexVariable here ensures we reconstruct PandasMultiIndex with @@ -976,7 +977,7 @@ def _flox_reduce( ) result = result.assign_coords( Coordinates._construct_direct(*merge_coords(new_coords)) - ) + ).drop_vars(to_drop) # broadcast any non-dim coord variables that don't # share all dimensions with the grouper From 42a236f6166ae4c6721c5b23400602e4cabf3111 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 21 Oct 2024 07:12:15 -0600 Subject: [PATCH 4/4] fix --- xarray/tests/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index bd7ec6297b9..a0ac8d51f95 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -135,7 +135,7 @@ def _importorskip( has_pint, requires_pint = _importorskip("pint") has_numexpr, requires_numexpr = _importorskip("numexpr") has_flox, requires_flox = _importorskip("flox") -has_pandas_ge_2_2, __ = _importorskip("pandas", "2.2") +has_pandas_ge_2_2, requires_pandas_ge_2_2 = _importorskip("pandas", "2.2") has_pandas_3, requires_pandas_3 = _importorskip("pandas", "3.0.0.dev0")