From b5f341332f05d88c0ec64926f39b18a356e60461 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 9 Nov 2020 12:58:53 -0800 Subject: [PATCH] REF: collect boilerplate in _datetimelike_compat --- pandas/core/nanops.py | 53 ++++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 23 deletions(-) diff --git a/pandas/core/nanops.py b/pandas/core/nanops.py index cfb02e5b1e987..d9e51810f1445 100644 --- a/pandas/core/nanops.py +++ b/pandas/core/nanops.py @@ -367,6 +367,32 @@ def _wrap_results(result, dtype: np.dtype, fill_value=None): return result +def _datetimelike_compat(func): + """ + If we have datetime64 or timedelta64 values, ensure we have a correct + mask before calling the wrapped function, then cast back afterwards. + """ + + @functools.wraps(func) + def new_func(values, *, axis=None, skipna=True, mask=None, **kwargs): + orig_values = values + + datetimelike = values.dtype.kind in ["m", "M"] + if datetimelike and mask is None: + mask = isna(values) + + result = func(values, axis=axis, skipna=skipna, mask=mask, **kwargs) + + if datetimelike: + result = _wrap_results(result, orig_values.dtype, fill_value=iNaT) + if not skipna: + result = _mask_datetimelike_result(result, axis, mask, orig_values) + + return result + + return new_func + + def _na_for_min_count( values: np.ndarray, axis: Optional[int] ) -> Union[Scalar, np.ndarray]: @@ -480,6 +506,7 @@ def nanall( @disallow("M8") +@_datetimelike_compat def nansum( values: np.ndarray, *, @@ -511,25 +538,18 @@ def nansum( >>> nanops.nansum(s) 3.0 """ - orig_values = values - values, mask, dtype, dtype_max, _ = _get_values( values, skipna, fill_value=0, mask=mask ) dtype_sum = dtype_max - datetimelike = False if is_float_dtype(dtype): dtype_sum = dtype elif is_timedelta64_dtype(dtype): - datetimelike = True dtype_sum = np.float64 the_sum = values.sum(axis, dtype=dtype_sum) the_sum = _maybe_null_out(the_sum, axis, mask, values.shape, min_count=min_count) - the_sum = _wrap_results(the_sum, dtype) - if datetimelike and not skipna: - the_sum = _mask_datetimelike_result(the_sum, axis, mask, orig_values) return the_sum @@ -552,6 +572,7 @@ def _mask_datetimelike_result( @disallow(PeriodDtype) @bottleneck_switch() +@_datetimelike_compat def nanmean( values: np.ndarray, *, @@ -583,8 +604,6 @@ def nanmean( >>> nanops.nanmean(s) 1.5 """ - orig_values = values - values, mask, dtype, dtype_max, _ = _get_values( values, skipna, fill_value=0, mask=mask ) @@ -592,9 +611,7 @@ def nanmean( dtype_count = np.float64 # not using needs_i8_conversion because that includes period - datetimelike = False if dtype.kind in ["m", "M"]: - datetimelike = True dtype_sum = np.float64 elif is_integer_dtype(dtype): dtype_sum = np.float64 @@ -616,9 +633,6 @@ def nanmean( else: the_mean = the_sum / count if count > 0 else np.nan - the_mean = _wrap_results(the_mean, dtype) - if datetimelike and not skipna: - the_mean = _mask_datetimelike_result(the_mean, axis, mask, orig_values) return the_mean @@ -875,7 +889,7 @@ def nanvar(values, *, axis=None, skipna=True, ddof=1, mask=None): # precision as the original values array. if is_float_dtype(dtype): result = result.astype(dtype) - return _wrap_results(result, values.dtype) + return result @disallow("M8", "m8") @@ -930,6 +944,7 @@ def nansem( def _nanminmax(meth, fill_value_typ): @bottleneck_switch(name="nan" + meth) + @_datetimelike_compat def reduction( values: np.ndarray, *, @@ -938,13 +953,10 @@ def reduction( mask: Optional[np.ndarray] = None, ) -> Dtype: - orig_values = values values, mask, dtype, dtype_max, fill_value = _get_values( values, skipna, fill_value_typ=fill_value_typ, mask=mask ) - datetimelike = orig_values.dtype.kind in ["m", "M"] - if (axis is not None and values.shape[axis] == 0) or values.size == 0: try: result = getattr(values, meth)(axis, dtype=dtype_max) @@ -954,12 +966,7 @@ def reduction( else: result = getattr(values, meth)(axis) - result = _wrap_results(result, dtype, fill_value) result = _maybe_null_out(result, axis, mask, values.shape) - - if datetimelike and not skipna: - result = _mask_datetimelike_result(result, axis, mask, orig_values) - return result return reduction