|
21 | 21 | from xarray.core.arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic
|
22 | 22 | from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
|
23 | 23 | from xarray.core.concat import concat
|
24 |
| -from xarray.core.coordinates import Coordinates |
| 24 | +from xarray.core.coordinates import Coordinates, _coordinates_from_variable |
25 | 25 | from xarray.core.formatting import format_array_flat
|
26 | 26 | from xarray.core.indexes import (
|
27 |
| - PandasIndex, |
28 | 27 | PandasMultiIndex,
|
29 | 28 | filter_indexes_from_coords,
|
30 | 29 | )
|
| 30 | +from xarray.core.merge import merge_coords |
31 | 31 | from xarray.core.options import OPTIONS, _get_keep_attrs
|
32 | 32 | from xarray.core.types import (
|
33 | 33 | Dims,
|
@@ -851,7 +851,6 @@ def _flox_reduce(
|
851 | 851 | from flox.xarray import xarray_reduce
|
852 | 852 |
|
853 | 853 | from xarray.core.dataset import Dataset
|
854 |
| - from xarray.groupers import BinGrouper |
855 | 854 |
|
856 | 855 | obj = self._original_obj
|
857 | 856 | variables = (
|
@@ -901,13 +900,6 @@ def _flox_reduce(
|
901 | 900 | # set explicitly to avoid unnecessarily accumulating count
|
902 | 901 | kwargs["min_count"] = 0
|
903 | 902 |
|
904 |
| - unindexed_dims: tuple[Hashable, ...] = tuple( |
905 |
| - grouper.name |
906 |
| - for grouper in self.groupers |
907 |
| - if isinstance(grouper.group, _DummyGroup) |
908 |
| - and not isinstance(grouper.grouper, BinGrouper) |
909 |
| - ) |
910 |
| - |
911 | 903 | parsed_dim: tuple[Hashable, ...]
|
912 | 904 | if isinstance(dim, str):
|
913 | 905 | parsed_dim = (dim,)
|
@@ -963,26 +955,28 @@ def _flox_reduce(
|
963 | 955 | # we did end up reducing over dimension(s) that are
|
964 | 956 | # in the grouped variable
|
965 | 957 | group_dims = set(grouper.group.dims)
|
966 |
| - new_coords = {} |
| 958 | + new_coords = [] |
967 | 959 | if group_dims.issubset(set(parsed_dim)):
|
968 |
| - new_indexes = {} |
969 | 960 | for grouper in self.groupers:
|
970 | 961 | output_index = grouper.full_index
|
971 | 962 | if isinstance(output_index, pd.RangeIndex):
|
| 963 | + # flox always assigns an index so we must drop it here if we don't need it. |
| 964 | + result = result.drop_vars(grouper.name) |
972 | 965 | continue
|
973 |
| - name = grouper.name |
974 |
| - new_coords[name] = IndexVariable( |
975 |
| - dims=name, data=np.array(output_index), attrs=grouper.codes.attrs |
| 966 | + new_coords.append( |
| 967 | + # Using IndexVariable here ensures we reconstruct PandasMultiIndex with |
| 968 | + # all associated levels properly. |
| 969 | + _coordinates_from_variable( |
| 970 | + IndexVariable( |
| 971 | + dims=grouper.name, |
| 972 | + data=output_index, |
| 973 | + attrs=grouper.codes.attrs, |
| 974 | + ) |
| 975 | + ) |
976 | 976 | )
|
977 |
| - index_cls = ( |
978 |
| - PandasIndex |
979 |
| - if not isinstance(output_index, pd.MultiIndex) |
980 |
| - else PandasMultiIndex |
981 |
| - ) |
982 |
| - new_indexes[name] = index_cls(output_index, dim=name) |
983 | 977 | result = result.assign_coords(
|
984 |
| - Coordinates(new_coords, new_indexes) |
985 |
| - ).drop_vars(unindexed_dims) |
| 978 | + Coordinates._construct_direct(*merge_coords(new_coords)) |
| 979 | + ) |
986 | 980 |
|
987 | 981 | # broadcast any non-dim coord variables that don't
|
988 | 982 | # share all dimensions with the grouper
|
|
0 commit comments