Skip to content

Commit 4f6164f

Browse files
authored
Support first, last with datetime, timedelta (#402)
1 parent 672be8c commit 4f6164f

File tree

5 files changed

+73
-27
lines changed

5 files changed

+73
-27
lines changed

flox/aggregate_numbagg.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
"nanmean": {np.int_: np.float64},
3131
"nanvar": {np.int_: np.float64},
3232
"nanstd": {np.int_: np.float64},
33+
"nanfirst": {np.datetime64: np.int64, np.timedelta64: np.int64},
34+
"nanlast": {np.datetime64: np.int64, np.timedelta64: np.int64},
3335
}
3436

3537

@@ -51,7 +53,7 @@ def _numbagg_wrapper(
5153
if cast_to:
5254
for from_, to_ in cast_to.items():
5355
if np.issubdtype(array.dtype, from_):
54-
array = array.astype(to_)
56+
array = array.astype(to_, copy=False)
5557

5658
func_ = getattr(numbagg.grouped, f"group_{func}")
5759

flox/core.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@
4545
)
4646
from .cache import memoize
4747
from .xrutils import (
48+
_contains_cftime_datetimes,
49+
_datetime_nanmin,
50+
_to_pytimedelta,
51+
datetime_to_numeric,
4852
is_chunked_array,
4953
is_duck_array,
5054
is_duck_cubed_array,
@@ -2473,7 +2477,8 @@ def groupby_reduce(
24732477
has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
24742478
has_cubed = is_duck_cubed_array(array) or is_duck_cubed_array(by_)
24752479

2476-
if _is_first_last_reduction(func):
2480+
is_first_last = _is_first_last_reduction(func)
2481+
if is_first_last:
24772482
if has_dask and nax != 1:
24782483
raise ValueError(
24792484
"For dask arrays: first, last, nanfirst, nanlast reductions are "
@@ -2486,6 +2491,24 @@ def groupby_reduce(
24862491
"along a single axis or when reducing across all dimensions of `by`."
24872492
)
24882493

2494+
# Flox's count works with non-numeric and its faster than converting.
2495+
is_npdatetime = array.dtype.kind in "Mm"
2496+
is_cftime = _contains_cftime_datetimes(array)
2497+
requires_numeric = (
2498+
(func not in ["count", "any", "all"] and not is_first_last)
2499+
or (func == "count" and engine != "flox")
2500+
or (is_first_last and is_cftime)
2501+
)
2502+
if requires_numeric:
2503+
if is_npdatetime:
2504+
offset = _datetime_nanmin(array)
2505+
# xarray always uses np.datetime64[ns] for np.datetime64 data
2506+
dtype = "timedelta64[ns]"
2507+
array = datetime_to_numeric(array, offset)
2508+
elif is_cftime:
2509+
offset = array.min()
2510+
array = datetime_to_numeric(array, offset, datetime_unit="us")
2511+
24892512
if nax == 1 and by_.ndim > 1 and expected_ is None:
24902513
# When we reduce along all axes, we are guaranteed to see all
24912514
# groups in the final combine stage, so everything works.
@@ -2671,6 +2694,14 @@ def groupby_reduce(
26712694

26722695
if is_bool_array and (_is_minmax_reduction(func) or _is_first_last_reduction(func)):
26732696
result = result.astype(bool)
2697+
2698+
# Output of count has an int dtype.
2699+
if requires_numeric and func != "count":
2700+
if is_npdatetime:
2701+
return result.astype(dtype) + offset
2702+
elif is_cftime:
2703+
return _to_pytimedelta(result, unit="us") + offset
2704+
26742705
return (result, *groups)
26752706

26762707

flox/xarray.py

-25
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import pandas as pd
88
import xarray as xr
99
from packaging.version import Version
10-
from xarray.core.duck_array_ops import _datetime_nanmin
1110

1211
from .aggregations import Aggregation, Dim, _atleast_1d, quantile_new_dims_func
1312
from .core import (
@@ -18,7 +17,6 @@
1817
)
1918
from .core import rechunk_for_blockwise as rechunk_array_for_blockwise
2019
from .core import rechunk_for_cohorts as rechunk_array_for_cohorts
21-
from .xrutils import _contains_cftime_datetimes, _to_pytimedelta, datetime_to_numeric
2220

2321
if TYPE_CHECKING:
2422
from xarray.core.types import T_DataArray, T_Dataset
@@ -366,22 +364,6 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
366364
if "nan" not in func and func not in ["all", "any", "count"]:
367365
func = f"nan{func}"
368366

369-
# Flox's count works with non-numeric and its faster than converting.
370-
requires_numeric = func not in ["count", "any", "all"] or (
371-
func == "count" and kwargs["engine"] != "flox"
372-
)
373-
if requires_numeric:
374-
is_npdatetime = array.dtype.kind in "Mm"
375-
is_cftime = _contains_cftime_datetimes(array)
376-
if is_npdatetime:
377-
offset = _datetime_nanmin(array)
378-
# xarray always uses np.datetime64[ns] for np.datetime64 data
379-
dtype = "timedelta64[ns]"
380-
array = datetime_to_numeric(array, offset)
381-
elif is_cftime:
382-
offset = array.min()
383-
array = datetime_to_numeric(array, offset, datetime_unit="us")
384-
385367
result, *groups = groupby_reduce(array, *by, func=func, **kwargs)
386368

387369
# Transpose the new quantile dimension to the end. This is ugly.
@@ -395,13 +377,6 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
395377
# output dim order: (*broadcast_dims, *group_dims, quantile_dim)
396378
result = np.moveaxis(result, 0, -1)
397379

398-
# Output of count has an int dtype.
399-
if requires_numeric and func != "count":
400-
if is_npdatetime:
401-
return result.astype(dtype) + offset
402-
elif is_cftime:
403-
return _to_pytimedelta(result, unit="us") + offset
404-
405380
return result
406381

407382
# These data variables do not have any of the core dimension,

flox/xrutils.py

+22
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,28 @@ def _contains_cftime_datetimes(array) -> bool:
345345
return False
346346

347347

348+
def _datetime_nanmin(array):
349+
"""nanmin() function for datetime64.
350+
351+
Caveats that this function deals with:
352+
353+
- In numpy < 1.18, min() on datetime64 incorrectly ignores NaT
354+
- numpy nanmin() don't work on datetime64 (all versions at the moment of writing)
355+
- dask min() does not work on datetime64 (all versions at the moment of writing)
356+
"""
357+
from .xrdtypes import is_datetime_like
358+
359+
dtype = array.dtype
360+
assert is_datetime_like(dtype)
361+
# (NaT).astype(float) does not produce NaN...
362+
array = np.where(pd.isnull(array), np.nan, array.astype(float))
363+
array = min(array, skipna=True)
364+
if isinstance(array, float):
365+
array = np.array(array)
366+
# ...but (NaN).astype("M8") does produce NaT
367+
return array.astype(dtype)
368+
369+
348370
def _select_along_axis(values, idx, axis):
349371
other_ind = np.ix_(*[np.arange(s) for s in idx.shape])
350372
sl = other_ind[:axis] + (idx,) + other_ind[axis:]

tests/test_core.py

+16
Original file line numberDiff line numberDiff line change
@@ -2006,3 +2006,19 @@ def test_blockwise_avoid_rechunk():
20062006
actual, groups = groupby_reduce(array, by, func="first")
20072007
assert_equal(groups, ["", "0", "1"])
20082008
assert_equal(actual, np.array([0, 0, 0], dtype=np.int64))
2009+
2010+
2011+
@pytest.mark.parametrize("func", ["first", "last", "nanfirst", "nanlast"])
2012+
def test_datetime_timedelta_first_last(engine, func):
2013+
import flox
2014+
2015+
idx = 0 if "first" in func else -1
2016+
2017+
dt = pd.date_range("2001-01-01", freq="d", periods=5).values
2018+
by = np.ones(dt.shape, dtype=int)
2019+
actual, _ = flox.groupby_reduce(dt, by, func=func, engine=engine)
2020+
assert_equal(actual, dt[[idx]])
2021+
2022+
dt = dt - dt[0]
2023+
actual, _ = flox.groupby_reduce(dt, by, func=func, engine=engine)
2024+
assert_equal(actual, dt[[idx]])

0 commit comments

Comments
 (0)