Skip to content

Commit 15ffa81

Browse files
committed
Fix non-dim coord grouping by multiple variables.
xref pydata/xarray#9372
1 parent 4dbadae commit 15ffa81

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

flox/xarray.py

+9
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,14 @@ def xarray_reduce(
250250
else:
251251
ds = obj._to_temp_dataset()
252252

253+
# These will need to be broadcast/reduced as data_vars
254+
reset_non_dim_coords = [
255+
name
256+
for name in ds._coord_names
257+
if any(dim in ds._variables[name].dims for dim in grouper_dims)
258+
]
259+
ds = ds.reset_coords(reset_non_dim_coords)
260+
253261
try:
254262
from xarray.indexes import PandasMultiIndex
255263
except ImportError:
@@ -475,6 +483,7 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
475483
if all(d not in ds_broad[var].dims for d in dim_tuple):
476484
actual[var] = ds_broad[var]
477485

486+
actual = actual.set_coords(reset_non_dim_coords)
478487
for newdim in newdims:
479488
actual.coords[newdim.name] = newdim.values if newdim.is_scalar else np.array(newdim.values)
480489

tests/test_xarray.py

+17
Original file line numberDiff line numberDiff line change
@@ -749,3 +749,20 @@ def test_direct_reduction(func):
749749
with xr.set_options(use_flox=False):
750750
expected = getattr(data.groupby("x", squeeze=False), func)(**kwargs)
751751
xr.testing.assert_identical(expected, actual)
752+
753+
754+
def test_non_dim_coords_with_core_dim():
755+
coords = {"a": ("x", [0, 0, 1, 1]), "b": ("y", [0, 0, 1, 1])}
756+
square = xr.DataArray(np.arange(16).reshape(4, 4), coords=coords, dims=["x", "y"])
757+
actual = xarray_reduce(square, "a", "b", func="mean")
758+
expected = xr.DataArray(
759+
np.array([[2.5, 4.5], [10.5, 12.5]]),
760+
dims=("a", "b"),
761+
coords={"a": [0, 1], "b": [0, 1]},
762+
)
763+
xr.testing.assert_identical(actual, expected)
764+
765+
actual = xarray_reduce(square, "x", "y", func="mean")
766+
expected = square.astype(np.float64).copy()
767+
expected["a"], expected["b"] = xr.broadcast(square.a, square.b)
768+
xr.testing.assert_identical(actual, expected)

0 commit comments

Comments
 (0)