Skip to content

Commit bc28eda

Browse files
authored
Add tests for groupby math (#6137)
1 parent: 18703ba · commit: bc28eda

File tree

1 file changed

+69
-14
lines changed

1 file changed

+69
-14
lines changed

xarray/tests/test_groupby.py

Lines changed: 69 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -663,30 +663,33 @@ def test_groupby_dataset_reduce() -> None:
663663
assert_allclose(expected, actual)
664664

665665

666-
def test_groupby_dataset_math() -> None:
666+
@pytest.mark.parametrize("squeeze", [True, False])
667+
def test_groupby_dataset_math(squeeze) -> None:
667668
def reorder_dims(x):
668669
return x.transpose("dim1", "dim2", "dim3", "time")
669670

670671
ds = create_test_data()
671672
ds["dim1"] = ds["dim1"]
672-
for squeeze in [True, False]:
673-
grouped = ds.groupby("dim1", squeeze=squeeze)
673+
grouped = ds.groupby("dim1", squeeze=squeeze)
674674

675-
expected = reorder_dims(ds + ds.coords["dim1"])
676-
actual = grouped + ds.coords["dim1"]
677-
assert_identical(expected, reorder_dims(actual))
675+
expected = reorder_dims(ds + ds.coords["dim1"])
676+
actual = grouped + ds.coords["dim1"]
677+
assert_identical(expected, reorder_dims(actual))
678678

679-
actual = ds.coords["dim1"] + grouped
680-
assert_identical(expected, reorder_dims(actual))
679+
actual = ds.coords["dim1"] + grouped
680+
assert_identical(expected, reorder_dims(actual))
681681

682-
ds2 = 2 * ds
683-
expected = reorder_dims(ds + ds2)
684-
actual = grouped + ds2
685-
assert_identical(expected, reorder_dims(actual))
682+
ds2 = 2 * ds
683+
expected = reorder_dims(ds + ds2)
684+
actual = grouped + ds2
685+
assert_identical(expected, reorder_dims(actual))
686686

687-
actual = ds2 + grouped
688-
assert_identical(expected, reorder_dims(actual))
687+
actual = ds2 + grouped
688+
assert_identical(expected, reorder_dims(actual))
689689

690+
691+
def test_groupby_math_more() -> None:
692+
ds = create_test_data()
690693
grouped = ds.groupby("numbers")
691694
zeros = DataArray([0, 0, 0, 0], [("numbers", range(4))])
692695
expected = (ds + Variable("dim3", np.zeros(10))).transpose(
@@ -719,6 +722,58 @@ def reorder_dims(x):
719722
ds + ds.groupby("time.month")
720723

721724

725+
@pytest.mark.parametrize("indexed_coord", [True, False])
726+
def test_groupby_bins_math(indexed_coord) -> None:
727+
N = 7
728+
da = DataArray(np.random.random((N, N)), dims=("x", "y"))
729+
if indexed_coord:
730+
da["x"] = np.arange(N)
731+
da["y"] = np.arange(N)
732+
g = da.groupby_bins("x", np.arange(0, N + 1, 3))
733+
mean = g.mean()
734+
expected = da.isel(x=slice(1, None)) - mean.isel(x_bins=("x", [0, 0, 0, 1, 1, 1]))
735+
actual = g - mean
736+
assert_identical(expected, actual)
737+
738+
739+
def test_groupby_math_nD_group() -> None:
740+
N = 40
741+
da = DataArray(
742+
np.random.random((N, N)),
743+
dims=("x", "y"),
744+
coords={
745+
"labels": (
746+
"x",
747+
np.repeat(["a", "b", "c", "d", "e", "f", "g", "h"], repeats=N // 8),
748+
),
749+
},
750+
)
751+
da["labels2d"] = xr.broadcast(da.labels, da)[0]
752+
753+
g = da.groupby("labels2d")
754+
mean = g.mean()
755+
expected = da - mean.sel(labels2d=da.labels2d)
756+
expected["labels"] = expected.labels.broadcast_like(expected.labels2d)
757+
actual = g - mean
758+
assert_identical(expected, actual)
759+
760+
da["num"] = (
761+
"x",
762+
np.repeat([1, 2, 3, 4, 5, 6, 7, 8], repeats=N // 8),
763+
)
764+
da["num2d"] = xr.broadcast(da.num, da)[0]
765+
g = da.groupby_bins("num2d", bins=[0, 4, 6])
766+
mean = g.mean()
767+
idxr = np.digitize(da.num2d, bins=(0, 4, 6), right=True)[:30, :] - 1
768+
expanded_mean = mean.drop("num2d_bins").isel(num2d_bins=(("x", "y"), idxr))
769+
expected = da.isel(x=slice(30)) - expanded_mean
770+
expected["labels"] = expected.labels.broadcast_like(expected.labels2d)
771+
expected["num"] = expected.num.broadcast_like(expected.num2d)
772+
expected["num2d_bins"] = (("x", "y"), mean.num2d_bins.data[idxr])
773+
actual = g - mean
774+
assert_identical(expected, actual)
775+
776+
722777
def test_groupby_dataset_math_virtual() -> None:
723778
ds = Dataset({"x": ("t", [1, 2, 3])}, {"t": pd.date_range("20100101", periods=3)})
724779
grouped = ds.groupby("t.day")

0 commit comments

Comments
 (0)