Skip to content

Commit 6288622

Browse files
authored
Revert "REF: remove putmask_preserve, putmask_without_repeat (#44328)"
This reverts commit 0861da6.
1 parent 5304048 commit 6288622

File tree

4 files changed

+57
-24
lines changed

4 files changed

+57
-24
lines changed

Diff for: pandas/core/array_algos/putmask.py

+46-4
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,7 @@ def putmask_smart(values: np.ndarray, mask: npt.NDArray[np.bool_], new) -> np.nd
126126

127127
if values.dtype.kind == new.dtype.kind:
128128
# preserves dtype if possible
129-
np.putmask(values, mask, new)
130-
return values
129+
return _putmask_preserve(values, new, mask)
131130

132131
dtype = find_common_type([values.dtype, new.dtype])
133132
# error: Argument 1 to "astype" of "_ArrayOrScalarCommon" has incompatible type
@@ -136,8 +135,51 @@ def putmask_smart(values: np.ndarray, mask: npt.NDArray[np.bool_], new) -> np.nd
136135
# List[Any], _DTypeDict, Tuple[Any, Any]]]"
137136
values = values.astype(dtype) # type: ignore[arg-type]
138137

139-
np.putmask(values, mask, new)
140-
return values
138+
return _putmask_preserve(values, new, mask)
139+
140+
141+
def _putmask_preserve(new_values: np.ndarray, new, mask: npt.NDArray[np.bool_]):  # set `new` into `new_values` where `mask` selects
142+
try:
143+
new_values[mask] = new[mask]  # first try indexing `new` by `mask` and using only its selected entries
144+
except (IndexError, ValueError):  # `new` could not be indexed by `mask`, or the assignment did not fit
145+
new_values[mask] = new  # fall back to assigning `new` as-is
146+
return new_values  # the object that was passed in, after the assignment above
147+
148+
149+
def putmask_without_repeat(
150+
values: np.ndarray, mask: npt.NDArray[np.bool_], new: Any
151+
) -> None:
152+
"""
153+
np.putmask will truncate or repeat if `new` is a listlike with
154+
len(new) != len(values). We require an exact match.
155+
156+
Parameters
157+
----------
158+
values : np.ndarray
159+
mask : np.ndarray[bool]
160+
new : Any
161+
"""
162+
if getattr(new, "ndim", 0) >= 1:
163+
new = new.astype(values.dtype, copy=False)  # cast array-like `new` to the dtype of `values`
164+
165+
# TODO: this probably needs some better checking for 2D cases
166+
nlocs = mask.sum()  # number of positions selected by the boolean mask
167+
if nlocs > 0 and is_list_like(new) and getattr(new, "ndim", 1) == 1:
168+
if nlocs == len(new):
169+
# GH#30567
170+
# If the length of ``new`` is less than the length of ``values``,
171+
# `np.putmask` would first repeat the ``new`` array and then
172+
# assign the masked values, which produces an incorrect result.
173+
# `np.place`, on the other hand, uses the ``new`` values as they are
174+
# to fill the masked locations of ``values``.
175+
np.place(values, mask, new)
176+
# i.e. values[mask] = new
177+
elif mask.shape[-1] == len(new) or len(new) == 1:  # `new` matches the mask's last axis, or has a single element
178+
np.putmask(values, mask, new)
179+
else:
180+
raise ValueError("cannot assign mismatch length to masked array")
181+
else:
182+
np.putmask(values, mask, new)  # nothing selected, or `new` is not a 1-D list-like: defer to np.putmask
141183

142184

143185
def validate_putmask(

Diff for: pandas/core/internals/blocks.py

+10-15
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151
is_extension_array_dtype,
5252
is_interval_dtype,
5353
is_list_like,
54-
is_object_dtype,
5554
is_string_dtype,
5655
)
5756
from pandas.core.dtypes.dtypes import (
@@ -77,6 +76,7 @@
7776
extract_bool_array,
7877
putmask_inplace,
7978
putmask_smart,
79+
putmask_without_repeat,
8080
setitem_datetimelike_compat,
8181
validate_putmask,
8282
)
@@ -960,7 +960,10 @@ def putmask(self, mask, new) -> list[Block]:
960960
new = self.fill_value
961961

962962
if self._can_hold_element(new):
963-
np.putmask(self.values.T, mask, new)
963+
964+
# error: Argument 1 to "putmask_without_repeat" has incompatible type
965+
# "Union[ndarray, ExtensionArray]"; expected "ndarray"
966+
putmask_without_repeat(self.values.T, mask, new) # type: ignore[arg-type]
964967
return [self]
965968

966969
elif noop:
@@ -1412,16 +1415,15 @@ def putmask(self, mask, new) -> list[Block]:
14121415

14131416
new_values = self.values
14141417

1418+
if isinstance(new, (np.ndarray, ExtensionArray)) and len(new) == len(mask):
1419+
new = new[mask]
1420+
14151421
if mask.ndim == new_values.ndim + 1:
14161422
# TODO(EA2D): unnecessary with 2D EAs
14171423
mask = mask.reshape(new_values.shape)
14181424

14191425
try:
1420-
if isinstance(new, (np.ndarray, ExtensionArray)):
1421-
# Caller is responsible for ensuring matching lengths
1422-
new_values[mask] = new[mask]
1423-
else:
1424-
new_values[mask] = new
1426+
new_values[mask] = new
14251427
except TypeError:
14261428
if not is_interval_dtype(self.dtype):
14271429
# Discussion about what we want to support in the general
@@ -1479,14 +1481,7 @@ def setitem(self, indexer, value):
14791481
# we are always 1-D
14801482
indexer = indexer[0]
14811483

1482-
try:
1483-
check_setitem_lengths(indexer, value, self.values)
1484-
except ValueError:
1485-
# If we are object dtype (e.g. PandasDtype[object]) then
1486-
# we can hold nested data, so can ignore this mismatch.
1487-
if not is_object_dtype(self.dtype):
1488-
raise
1489-
1484+
check_setitem_lengths(indexer, value, self.values)
14901485
self.values[indexer] = value
14911486
return self
14921487

Diff for: pandas/core/series.py

+1
Original file line numberDiff line numberDiff line change
@@ -1101,6 +1101,7 @@ def __setitem__(self, key, value) -> None:
11011101
is_list_like(value)
11021102
and len(value) != len(self)
11031103
and not isinstance(value, Series)
1104+
and not is_object_dtype(self.dtype)
11041105
):
11051106
# Series will be reindexed to have matching length inside
11061107
# _where call below

Diff for: pandas/tests/extension/test_numpy.py

-5
Original file line numberDiff line numberDiff line change
@@ -363,11 +363,6 @@ def test_concat(self, data, in_frame):
363363

364364

365365
class TestSetitem(BaseNumPyTests, base.BaseSetitemTests):
366-
@skip_nested
367-
def test_setitem_sequence_mismatched_length_raises(self, data, as_array):
368-
# doesn't raise bc object dtype holds nested data
369-
super().test_setitem_sequence_mismatched_length_raises(data, as_array)
370-
371366
@skip_nested
372367
def test_setitem_invalid(self, data, invalid_scalar):
373368
# object dtype can hold anything, so doesn't raise

0 commit comments

Comments
 (0)