Skip to content

Commit b0cabf3

Browse files
authored
Fix direct reductions of Xarray objects (#339)
* Fix direct reductions of Xarray objects. Closes pydata/xarray#8819
* Fix doctest
1 parent 41372e0 commit b0cabf3

File tree

4 files changed

+74
-38
lines changed

4 files changed

+74
-38
lines changed

flox/xarray.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,10 @@ def xarray_reduce(
201201
>>> da = da = xr.ones_like(labels)
202202
>>> # Sum all values in da that matches the elements in the group index:
203203
>>> xarray_reduce(da, labels, func="sum")
204-
<xarray.DataArray 'label' (label: 4)>
204+
<xarray.DataArray 'label' (label: 4)> Size: 32B
205205
array([3, 2, 2, 2])
206206
Coordinates:
207-
* label (label) int64 0 1 2 3
207+
* label (label) int64 32B 0 1 2 3
208208
"""
209209

210210
if skipna is not None and isinstance(func, Aggregation):
@@ -303,14 +303,16 @@ def xarray_reduce(
303303
# reducing along a dimension along which groups do not vary
304304
# This is really just a normal reduction.
305305
# This is not right when binning so we exclude.
306-
if isinstance(func, str):
307-
dsfunc = func[3:] if skipna else func
308-
else:
306+
if isinstance(func, str) and func.startswith("nan"):
307+
raise ValueError(f"Specify func={func[3:]}, skipna=True instead of func={func}")
308+
elif isinstance(func, Aggregation):
309309
raise NotImplementedError(
310310
"func must be a string when reducing along a dimension not present in `by`"
311311
)
312-
# TODO: skipna needs test
313-
result = getattr(ds_broad, dsfunc)(dim=dim_tuple, skipna=skipna)
312+
# skipna is not supported for all reductions
313+
# https://github.com/pydata/xarray/issues/8819
314+
kwargs = {"skipna": skipna} if skipna is not None else {}
315+
result = getattr(ds_broad, func)(dim=dim_tuple, **kwargs)
314316
if isinstance(obj, xr.DataArray):
315317
return obj._from_temp_dataset(result)
316318
else:

tests/__init__.py

+32
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,35 @@ def assert_equal_tuple(a, b):
124124
np.testing.assert_array_equal(a_, b_)
125125
else:
126126
assert a_ == b_
127+
128+
129+
SCIPY_STATS_FUNCS = ("mode", "nanmode")
130+
BLOCKWISE_FUNCS = ("median", "nanmedian", "quantile", "nanquantile") + SCIPY_STATS_FUNCS
131+
ALL_FUNCS = (
132+
"sum",
133+
"nansum",
134+
"argmax",
135+
"nanfirst",
136+
"nanargmax",
137+
"prod",
138+
"nanprod",
139+
"mean",
140+
"nanmean",
141+
"var",
142+
"nanvar",
143+
"std",
144+
"nanstd",
145+
"max",
146+
"nanmax",
147+
"min",
148+
"nanmin",
149+
"argmin",
150+
"nanargmin",
151+
"any",
152+
"all",
153+
"nanlast",
154+
"median",
155+
"nanmedian",
156+
"quantile",
157+
"nanquantile",
158+
) + tuple(pytest.param(func, marks=requires_scipy) for func in SCIPY_STATS_FUNCS)

tests/test_core.py

+3-31
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,14 @@
3131
)
3232

3333
from . import (
34+
ALL_FUNCS,
35+
BLOCKWISE_FUNCS,
36+
SCIPY_STATS_FUNCS,
3437
assert_equal,
3538
assert_equal_tuple,
3639
has_dask,
3740
raise_if_dask_computes,
3841
requires_dask,
39-
requires_scipy,
4042
)
4143

4244
logger = logging.getLogger("flox")
@@ -60,36 +62,6 @@ def dask_array_ones(*args):
6062

6163

6264
DEFAULT_QUANTILE = 0.9
63-
SCIPY_STATS_FUNCS = ("mode", "nanmode")
64-
BLOCKWISE_FUNCS = ("median", "nanmedian", "quantile", "nanquantile") + SCIPY_STATS_FUNCS
65-
ALL_FUNCS = (
66-
"sum",
67-
"nansum",
68-
"argmax",
69-
"nanfirst",
70-
"nanargmax",
71-
"prod",
72-
"nanprod",
73-
"mean",
74-
"nanmean",
75-
"var",
76-
"nanvar",
77-
"std",
78-
"nanstd",
79-
"max",
80-
"nanmax",
81-
"min",
82-
"nanmin",
83-
"argmin",
84-
"nanargmin",
85-
"any",
86-
"all",
87-
"nanlast",
88-
"median",
89-
"nanmedian",
90-
"quantile",
91-
"nanquantile",
92-
) + tuple(pytest.param(func, marks=requires_scipy) for func in SCIPY_STATS_FUNCS)
9365

9466
if TYPE_CHECKING:
9567
from flox.core import T_Agg, T_Engine, T_ExpectedGroupsOpt, T_Method

tests/test_xarray.py

+30
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from flox.xarray import rechunk_for_blockwise, xarray_reduce
1010

1111
from . import (
12+
ALL_FUNCS,
1213
assert_equal,
1314
has_dask,
1415
raise_if_dask_computes,
@@ -710,3 +711,32 @@ def test_multiple_quantiles(q, chunk, by_ndim, skipna):
710711
with xr.set_options(use_flox=False):
711712
expected = da.groupby(by).quantile(q, skipna=skipna)
712713
xr.testing.assert_allclose(expected, actual)
714+
715+
716+
@pytest.mark.parametrize("func", ALL_FUNCS)
717+
def test_direct_reduction(func):
718+
if "arg" in func or "mode" in func:
719+
pytest.skip()
720+
# regression test for https://github.com/pydata/xarray/issues/8819
721+
rand = np.random.choice([True, False], size=(2, 3))
722+
if func not in ["any", "all"]:
723+
rand = rand.astype(float)
724+
725+
if "nan" in func:
726+
func = func[3:]
727+
kwargs = {"skipna": True}
728+
else:
729+
kwargs = {}
730+
731+
if "first" not in func and "last" not in func:
732+
kwargs["dim"] = "y"
733+
734+
if "quantile" in func:
735+
kwargs["q"] = 0.9
736+
737+
data = xr.DataArray(rand, dims=("x", "y"), coords={"x": [10, 20], "y": [0, 1, 2]})
738+
with xr.set_options(use_flox=True):
739+
actual = getattr(data.groupby("x", squeeze=False), func)(**kwargs)
740+
with xr.set_options(use_flox=False):
741+
expected = getattr(data.groupby("x", squeeze=False), func)(**kwargs)
742+
xr.testing.assert_identical(expected, actual)

0 commit comments

Comments
 (0)