Skip to content

REF: collect boilerplate in _datetimelike_compat #37723

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 10, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 30 additions & 23 deletions pandas/core/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,32 @@ def _wrap_results(result, dtype: np.dtype, fill_value=None):
return result


def _datetimelike_compat(func):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI i don't think this needs to be private, but nbd.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also as a follow up need to ensure decorator preserves types #33455

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will do follow-up to annotate new_func. can we start typing mask as np.ndarray[bool] or something like that?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The important thing is that the decorated function signature is preserved. It may not be so important, in the grand scheme of things, to type check this function. though can do as well.

"""
If we have datetime64 or timedelta64 values, ensure we have a correct
mask before calling the wrapped function, then cast back afterwards.
"""

@functools.wraps(func)
def new_func(values, *, axis=None, skipna=True, mask=None, **kwargs):
orig_values = values

datetimelike = values.dtype.kind in ["m", "M"]
if datetimelike and mask is None:
mask = isna(values)

result = func(values, axis=axis, skipna=skipna, mask=mask, **kwargs)

if datetimelike:
result = _wrap_results(result, orig_values.dtype, fill_value=iNaT)
if not skipna:
result = _mask_datetimelike_result(result, axis, mask, orig_values)

return result

return new_func


def _na_for_min_count(
values: np.ndarray, axis: Optional[int]
) -> Union[Scalar, np.ndarray]:
Expand Down Expand Up @@ -480,6 +506,7 @@ def nanall(


@disallow("M8")
@_datetimelike_compat
def nansum(
values: np.ndarray,
*,
Expand Down Expand Up @@ -511,25 +538,18 @@ def nansum(
>>> nanops.nansum(s)
3.0
"""
orig_values = values

values, mask, dtype, dtype_max, _ = _get_values(
values, skipna, fill_value=0, mask=mask
)
dtype_sum = dtype_max
datetimelike = False
if is_float_dtype(dtype):
dtype_sum = dtype
elif is_timedelta64_dtype(dtype):
datetimelike = True
dtype_sum = np.float64

the_sum = values.sum(axis, dtype=dtype_sum)
the_sum = _maybe_null_out(the_sum, axis, mask, values.shape, min_count=min_count)

the_sum = _wrap_results(the_sum, dtype)
if datetimelike and not skipna:
the_sum = _mask_datetimelike_result(the_sum, axis, mask, orig_values)
return the_sum


Expand All @@ -552,6 +572,7 @@ def _mask_datetimelike_result(

@disallow(PeriodDtype)
@bottleneck_switch()
@_datetimelike_compat
def nanmean(
values: np.ndarray,
*,
Expand Down Expand Up @@ -583,18 +604,14 @@ def nanmean(
>>> nanops.nanmean(s)
1.5
"""
orig_values = values

values, mask, dtype, dtype_max, _ = _get_values(
values, skipna, fill_value=0, mask=mask
)
dtype_sum = dtype_max
dtype_count = np.float64

# not using needs_i8_conversion because that includes period
datetimelike = False
if dtype.kind in ["m", "M"]:
datetimelike = True
dtype_sum = np.float64
elif is_integer_dtype(dtype):
dtype_sum = np.float64
Expand All @@ -616,9 +633,6 @@ def nanmean(
else:
the_mean = the_sum / count if count > 0 else np.nan

the_mean = _wrap_results(the_mean, dtype)
if datetimelike and not skipna:
the_mean = _mask_datetimelike_result(the_mean, axis, mask, orig_values)
return the_mean


Expand Down Expand Up @@ -875,7 +889,7 @@ def nanvar(values, *, axis=None, skipna=True, ddof=1, mask=None):
# precision as the original values array.
if is_float_dtype(dtype):
result = result.astype(dtype)
return _wrap_results(result, values.dtype)
return result


@disallow("M8", "m8")
Expand Down Expand Up @@ -930,6 +944,7 @@ def nansem(

def _nanminmax(meth, fill_value_typ):
@bottleneck_switch(name="nan" + meth)
@_datetimelike_compat
def reduction(
values: np.ndarray,
*,
Expand All @@ -938,13 +953,10 @@ def reduction(
mask: Optional[np.ndarray] = None,
) -> Dtype:

orig_values = values
values, mask, dtype, dtype_max, fill_value = _get_values(
values, skipna, fill_value_typ=fill_value_typ, mask=mask
)

datetimelike = orig_values.dtype.kind in ["m", "M"]

if (axis is not None and values.shape[axis] == 0) or values.size == 0:
try:
result = getattr(values, meth)(axis, dtype=dtype_max)
Expand All @@ -954,12 +966,7 @@ def reduction(
else:
result = getattr(values, meth)(axis)

result = _wrap_results(result, dtype, fill_value)
result = _maybe_null_out(result, axis, mask, values.shape)

if datetimelike and not skipna:
result = _mask_datetimelike_result(result, axis, mask, orig_values)

return result

return reduction
Expand Down