Skip to content

BUG: idxmax/min (and argmax/min) for Series with underlying ExtensionArray #37924

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Jan 1, 2021
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
c000668
fix idxmax/min for Series with underlying 'Int' datatype
tonyyyyip Nov 17, 2020
1632081
test added
tonyyyyip Nov 18, 2020
8018586
editted test
tonyyyyip Nov 18, 2020
b076d23
added test for argmax argmin
tonyyyyip Nov 19, 2020
664e4ec
added validations
tonyyyyip Nov 20, 2020
8403c38
The overhaul
tonyyyyip Nov 20, 2020
741c97a
2nd attempt
tonyyyyip Nov 24, 2020
131ae83
2nd attempt
tonyyyyip Nov 24, 2020
3269fb5
simplified idxmaxmin and added parameterised tests
tonyyyyip Nov 25, 2020
0cfb621
fixed newbie mistake
tonyyyyip Nov 26, 2020
d4b13ac
fixed newbie mistake
tonyyyyip Nov 26, 2020
5648eb9
used any_numeric_dtype in test
tonyyyyip Nov 26, 2020
9a1b332
does this solve the pre-commit check failure now?
tonyyyyip Nov 26, 2020
9f5e683
moved EA's skipna logic from Series.argmin/max to EA.argmin/max
tonyyyyip Nov 26, 2020
e73a3d1
moved EA's skipna logic back to Series
tonyyyyip Nov 27, 2020
d78e28c
moved EA's skipna logic back to Series
tonyyyyip Nov 27, 2020
66f3187
added whatsnew entry and extra tests
tonyyyyip Dec 5, 2020
6f01069
moved tests to tests/reductions/rest_reductions.py
tonyyyyip Dec 5, 2020
3540797
added 1.3 whatsnew entry
tonyyyyip Dec 28, 2020
6d0b68e
moved and restructured tests
tonyyyyip Dec 30, 2020
da5bf06
Merge branch 'master' into fix-argmax
tonyyyyip Dec 30, 2020
4f6d111
moved tests to pandas/tests/extensions/methods.py
tonyyyyip Dec 31, 2020
aa909e9
moved tests to pandas/tests/extensions/methods.py
tonyyyyip Dec 31, 2020
dd0ecde
Merge branch 'fix-argmax' of https://github.com/tonyyyyip/pandas into…
tonyyyyip Dec 31, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/source/whatsnew/v1.3.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,9 @@ Sparse

ExtensionArray
^^^^^^^^^^^^^^

- Bug in :meth:`DataFrame.where` when ``other`` is a :class:`Series` with ExtensionArray dtype (:issue:`38729`)
-
- Bug in :meth:`Series.idxmax`, :meth:`Series.idxmin` and ``argmax/min`` fail when the underlying data is :class:`ExtensionArray` (:issue:`32749`, :issue:`33719`, :issue:`36566`)
-

Other
Expand Down
24 changes: 20 additions & 4 deletions pandas/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,9 +715,17 @@ def argmax(self, axis=None, skipna: bool = True, *args, **kwargs) -> int:
the minimum cereal calories is the first element,
since series is zero-indexed.
"""
delegate = self._values
nv.validate_minmax_axis(axis)
nv.validate_argmax_with_skipna(skipna, args, kwargs)
return nanops.nanargmax(self._values, skipna=skipna)
skipna = nv.validate_argmax_with_skipna(skipna, args, kwargs)

if isinstance(delegate, ExtensionArray):
if not skipna and delegate.isna().any():
return -1
else:
return delegate.argmax()
else:
return nanops.nanargmax(delegate, skipna=skipna)

def min(self, axis=None, skipna: bool = True, *args, **kwargs):
"""
Expand Down Expand Up @@ -765,9 +773,17 @@ def min(self, axis=None, skipna: bool = True, *args, **kwargs):

@doc(argmax, op="min", oppose="max", value="smallest")
def argmin(self, axis=None, skipna=True, *args, **kwargs) -> int:
delegate = self._values
nv.validate_minmax_axis(axis)
nv.validate_argmax_with_skipna(skipna, args, kwargs)
return nanops.nanargmin(self._values, skipna=skipna)
skipna = nv.validate_argmin_with_skipna(skipna, args, kwargs)

if isinstance(delegate, ExtensionArray):
if not skipna and delegate.isna().any():
return -1
else:
return delegate.argmin()
else:
return nanops.nanargmin(delegate, skipna=skipna)

def tolist(self):
"""
Expand Down
6 changes: 2 additions & 4 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2076,8 +2076,7 @@ def idxmin(self, axis=0, skipna=True, *args, **kwargs):
>>> s.idxmin(skipna=False)
nan
"""
skipna = nv.validate_argmin_with_skipna(skipna, args, kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no longer validating args/kwargs?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry missed this comment

cc @tonyyyyip

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. I think if we let idxmax call self.argmax(axis=axis, skipna=skipna, *args, **kwargs) then the args and kwargs can be validated by argmax's validator. I can make another PR if this is desirable.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that would be great @tonyyyyip

i = nanops.nanargmin(self._values, skipna=skipna)
i = self.argmin(None, skipna=skipna)
if i == -1:
return np.nan
return self.index[i]
Expand Down Expand Up @@ -2147,8 +2146,7 @@ def idxmax(self, axis=0, skipna=True, *args, **kwargs):
>>> s.idxmax(skipna=False)
nan
"""
skipna = nv.validate_argmax_with_skipna(skipna, args, kwargs)
i = nanops.nanargmax(self._values, skipna=skipna)
i = self.argmax(None, skipna=skipna)
if i == -1:
return np.nan
return self.index[i]
Expand Down
1 change: 1 addition & 0 deletions pandas/tests/extension/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class TestMyDtype(BaseDtypeTests):
)
from .printing import BasePrintingTests # noqa
from .reduce import ( # noqa
BaseArgReduceTests,
BaseBooleanReduceTests,
BaseNoReduceTests,
BaseNumericReduceTests,
Expand Down
8 changes: 8 additions & 0 deletions pandas/tests/extension/base/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,11 @@ def test_reduce_series(self, data, all_boolean_reductions, skipna):
op_name = all_boolean_reductions
s = pd.Series(data)
self.check_reduce(s, op_name, skipna)


class BaseArgReduceTests(BaseReduceTests):
@pytest.mark.parametrize("skipna", [True, False])
@pytest.mark.parametrize("op_name", ["argmin", "argmax", "idxmin", "idxmax"])
def test_reduce_series(self, data, op_name, skipna):
s = pd.Series(data)
self.check_reduce(s, op_name, skipna)
4 changes: 4 additions & 0 deletions pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ class TestBooleanReduce(Reduce, base.BaseBooleanReduceTests):
pass


class TestArgReduce(base.BaseArgReduceTests):
pass


class TestMethods(BaseDecimal, base.BaseMethodsTests):
@pytest.mark.parametrize("dropna", [True, False])
@pytest.mark.xfail(reason="value_counts not implemented yet.")
Expand Down
4 changes: 4 additions & 0 deletions pandas/tests/extension/test_boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,10 @@ class TestBooleanReduce(base.BaseBooleanReduceTests):
pass


class TestArgReduce(base.BaseArgReduceTests):
pass


class TestPrinting(base.BasePrintingTests):
pass

Expand Down
14 changes: 14 additions & 0 deletions pandas/tests/extension/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pandas.core.dtypes.dtypes import DatetimeTZDtype

import pandas as pd
import pandas._testing as tm
from pandas.core.arrays import DatetimeArray
from pandas.tests.extension import base

Expand Down Expand Up @@ -223,3 +224,16 @@ class TestGroupby(BaseDatetimeTests, base.BaseGroupbyTests):

class TestPrinting(BaseDatetimeTests, base.BasePrintingTests):
pass


class TestArgReduce(base.BaseArgReduceTests):
def check_reduce(self, s, op_name, skipna):
result = getattr(s, op_name)(skipna=skipna)
if not skipna and s.isna().any():
if op_name in ["argmin", "argmax"]:
expected = -1
else:
expected = np.nan
else:
expected = getattr(s.dropna().astype("int64"), op_name)(skipna=skipna)
tm.assert_almost_equal(result, expected)
4 changes: 4 additions & 0 deletions pandas/tests/extension/test_floating.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,10 @@ class TestBooleanReduce(base.BaseBooleanReduceTests):
pass


class TestArgReduce(base.BaseArgReduceTests):
pass


class TestPrinting(base.BasePrintingTests):
pass

Expand Down
4 changes: 4 additions & 0 deletions pandas/tests/extension/test_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,10 @@ class TestBooleanReduce(base.BaseBooleanReduceTests):
pass


class TestArgReduce(base.BaseArgReduceTests):
pass


class TestPrinting(base.BasePrintingTests):
pass

Expand Down
5 changes: 5 additions & 0 deletions pandas/tests/extension/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,11 @@ class TestBooleanReduce(BaseNumPyTests, base.BaseBooleanReduceTests):
pass


@skip_nested
class TestArgReduce(base.BaseArgReduceTests):
pass


class TestMissing(BaseNumPyTests, base.BaseMissingTests):
@skip_nested
def test_fillna_scalar(self, data_missing):
Expand Down