Skip to content

BUG: make Index.where behavior mirror Index.putmask behavior #39412

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Feb 3, 2021
10 changes: 10 additions & 0 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1413,6 +1413,16 @@ def to_tuples(self, na_tuple=True):

# ---------------------------------------------------------------------

def putmask(self, mask: np.ndarray, value) -> None:
    """
    Write ``value`` into the positions of both sides selected by ``mask``.

    Parameters
    ----------
    mask : np.ndarray
        Forwarded unchanged to the underlying ``putmask`` calls.
    value
        Passed through ``self._validate_setitem_value``, which must yield
        a ``(left, right)`` pair; any exception it raises propagates.

    Notes
    -----
    When the left side is a NumPy array, both sides are updated in place
    with ``np.putmask``; otherwise the work is delegated to each side's own
    ``putmask`` method.
    """
    fill_left, fill_right = self._validate_setitem_value(value)

    # The type of the left side alone selects the code path for both sides.
    if not isinstance(self._left, np.ndarray):
        self._left.putmask(mask, fill_left)
        self._right.putmask(mask, fill_right)
        return

    np.putmask(self._left, mask, fill_left)
    np.putmask(self._right, mask, fill_right)

def delete(self: IntervalArrayT, loc) -> IntervalArrayT:
if isinstance(self._left, np.ndarray):
new_left = np.delete(self._left, loc)
Expand Down
53 changes: 32 additions & 21 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@

from pandas.core.dtypes.cast import (
find_common_type,
infer_dtype_from,
maybe_cast_to_integer_array,
maybe_promote,
validate_numeric_casting,
Expand Down Expand Up @@ -87,7 +88,7 @@
ABCTimedeltaIndex,
)
from pandas.core.dtypes.inference import is_dict_like
from pandas.core.dtypes.missing import array_equivalent, isna
from pandas.core.dtypes.missing import array_equivalent, is_valid_nat_for_dtype, isna

from pandas.core import missing, ops
from pandas.core.accessor import CachedAccessor
Expand All @@ -114,7 +115,7 @@
)

if TYPE_CHECKING:
from pandas import MultiIndex, RangeIndex, Series
from pandas import IntervalIndex, MultiIndex, RangeIndex, Series
from pandas.core.indexes.datetimelike import DatetimeIndexOpsMixin


Expand Down Expand Up @@ -4317,19 +4318,8 @@ def where(self, cond, other=None):
>>> idx.where(idx.isin(['car', 'train']), 'other')
Index(['car', 'other', 'train', 'other'], dtype='object')
"""
if other is None:
other = self._na_value

values = self.values

try:
self._validate_fill_value(other)
except (ValueError, TypeError):
return self.astype(object).where(cond, other)

values = np.where(cond, values, other)

return Index(values, name=self.name)
cond = np.asarray(cond, dtype=bool)
return self.putmask(~cond, other)

# construction helpers
@final
Expand Down Expand Up @@ -4542,16 +4532,24 @@ def putmask(self, mask, value):
numpy.ndarray.putmask : Changes elements of an array
based on conditional and input values.
"""
values = self._values.copy()
mask = np.asarray(mask, dtype=bool)
if mask.shape != self.shape:
raise ValueError("putmask: mask and data must be the same size")
if not mask.any():
return self.copy()

if value is None:
value = self._na_value
try:
converted = self._validate_fill_value(value)
except (ValueError, TypeError) as err:
if is_object_dtype(self):
raise err

# coerces to object
return self.astype(object).putmask(mask, value)
dtype = self._find_common_type_compat(value)
return self.astype(dtype).putmask(mask, value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

copy=False?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we'll always be making a copy here since dtype != self.dtype


values = self._values.copy()
np.putmask(values, mask, converted)
return self._shallow_copy(values)

Expand Down Expand Up @@ -5189,18 +5187,31 @@ def _maybe_promote(self, other: Index):

return self, other

def _find_common_type_compat(self, target: Index) -> DtypeObj:
@final
def _find_common_type_compat(self, target) -> DtypeObj:
"""
Implementation of find_common_type that adjusts for Index-specific
special cases.
"""
dtype = find_common_type([self.dtype, target.dtype])
if is_interval_dtype(self.dtype) and is_valid_nat_for_dtype(target, self.dtype):
# e.g. setting NA value into IntervalArray[int64]
self = cast("IntervalIndex", self)
return IntervalDtype(np.float64, closed=self.closed)

target_dtype, _ = infer_dtype_from(target, pandas_dtype=True)
dtype = find_common_type([self.dtype, target_dtype])
if dtype.kind in ["i", "u"]:
# TODO: what about reversed with self being categorical?
if is_categorical_dtype(target.dtype) and target.hasnans:
if (
isinstance(target, Index)
and is_categorical_dtype(target.dtype)
and target.hasnans
):
# FIXME: find_common_type incorrect with Categorical GH#38240
# FIXME: some cases where float64 cast can be lossy?
dtype = np.dtype(np.float64)
if dtype.kind == "c":
dtype = np.dtype(object)
return dtype

@final
Expand Down
32 changes: 14 additions & 18 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,29 +799,22 @@ def length(self):
return Index(self._data.length, copy=False)

def putmask(self, mask, value):
arr = self._data.copy()
mask = np.asarray(mask, dtype=bool)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't you share or reuse the IntervalArray putmask code here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do: L814-815 call IntervalArray.putmask directly. Everything before that point handles the ways in which the Index method behaves differently from the array method.

if mask.shape != self.shape:
raise ValueError("putmask: mask and data must be the same size")
if not mask.any():
return self.copy()

try:
value_left, value_right = arr._validate_setitem_value(value)
self._validate_fill_value(value)
except (ValueError, TypeError):
return self.astype(object).putmask(mask, value)
dtype = self._find_common_type_compat(value)
return self.astype(dtype).putmask(mask, value)

if isinstance(self._data._left, np.ndarray):
np.putmask(arr._left, mask, value_left)
np.putmask(arr._right, mask, value_right)
else:
# TODO: special case not needed with __array_function__
arr._left.putmask(mask, value_left)
arr._right.putmask(mask, value_right)
arr = self._data.copy()
arr.putmask(mask, value)
return type(self)._simple_new(arr, name=self.name)

@Appender(Index.where.__doc__)
def where(self, cond, other=None):
    # No docstring here: Appender supplies it from Index.where.
    # A missing replacement falls back to this index's own NA value.
    fill = self._na_value if other is None else other
    # Keep entries where ``cond`` holds, take ``fill`` elsewhere, then
    # rebuild an IntervalArray from the selected values.
    selected = np.where(cond, self._values, fill)
    return type(self)._simple_new(IntervalArray(selected), name=self.name)

def insert(self, loc, item):
"""
Return a new IntervalIndex inserting new item at location. Follows
Expand Down Expand Up @@ -998,6 +991,9 @@ def func(self, other, sort=sort):

# --------------------------------------------------------------------

def _validate_fill_value(self, value):
    """
    Validate ``value`` by delegating to the backing array.

    Returns whatever ``self._data._validate_setitem_value`` returns and
    lets any exception it raises propagate to the caller.
    """
    validate = self._data._validate_setitem_value
    return validate(value)

@property
def _is_all_dates(self) -> bool:
"""
Expand Down
9 changes: 6 additions & 3 deletions pandas/core/indexes/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,12 @@ def _validate_fill_value(self, value):
raise TypeError
value = int(value)

elif hasattr(value, "dtype") and value.dtype.kind in ["m", "M"]:
# TODO: if we're checking arraylike here, do so systematically
raise TypeError
elif hasattr(value, "dtype"):
if value.dtype.kind in ["m", "M"]:
raise TypeError
if value.dtype.kind == "f" and self.dtype.kind in ["i", "u"]:
# TODO: maybe OK if value is castable?
raise TypeError

return value

Expand Down
11 changes: 7 additions & 4 deletions pandas/tests/series/indexing/test_setitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,20 +331,23 @@ def test_index_where(self, obj, key, expected, request):
mask = np.zeros(obj.shape, dtype=bool)
mask[key] = True

if obj.dtype == bool and not mask.all():
# When mask is all True, casting behavior does not apply
if obj.dtype == bool:
msg = "Index/Series casting behavior inconsistent GH#38692"
mark = pytest.mark.xfail(reason=msg)
request.node.add_marker(mark)

res = Index(obj).where(~mask, np.nan)
tm.assert_index_equal(res, Index(expected))

@pytest.mark.xfail(reason="Index/Series casting behavior inconsistent GH#38692")
def test_index_putmask(self, obj, key, expected):
def test_index_putmask(self, obj, key, expected, request):
mask = np.zeros(obj.shape, dtype=bool)
mask[key] = True

if obj.dtype == bool:
msg = "Index/Series casting behavior inconsistent GH#38692"
mark = pytest.mark.xfail(reason=msg)
request.node.add_marker(mark)

res = Index(obj).putmask(mask, np.nan)
tm.assert_index_equal(res, Index(expected))

Expand Down