Skip to content

Commit d409035

Browse files
DGradyDaniel Grady
authored and
Daniel Grady
committed
BUG: Fix behavior of argmax and argmin with inf (#16449)
Closes #13595 The implementations of `nanargmin` and `nanargmax` in `nanops` were forcing the `_get_values` utility function to always mask out infinite values. For example, in `nanargmax`, >>> nanops._get_values(np.array([1, np.nan, np.inf]), True, isfinite=True, fill_value_typ='-inf') (array([ 1., -inf, -inf]), array([False, True, True], dtype=bool), dtype('float64'), numpy.float64) The first element of the result tuple (the masked version of the values array) is used for actually finding the max or min argument. As a result, infinite values could never be correctly recognized as the maximum or minimum values in an array. This also affects the behavior of `DataFrame.groupby.idxmax`: since `Series.idxmax` previously raised a `ValueError` with string data, the group by would silently drop columns that contained strings.
1 parent 18f7b1c commit d409035

File tree

4 files changed

+72
-7
lines changed

4 files changed

+72
-7
lines changed

Diff for: doc/source/whatsnew/v0.21.0.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ Sparse
127127
Reshaping
128128
^^^^^^^^^
129129

130-
130+
- `argmin`, `argmax`, `idxmin`, and `idxmax` on Series, DataFrame, and GroupBy objects work correctly with floating point data that contains infinite values (:issue:`13595`). These functions now also work with string data, as long as there are no missing values.
131131

132132
Numeric
133133
^^^^^^^

Diff for: pandas/core/nanops.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -474,8 +474,7 @@ def nanargmax(values, axis=None, skipna=True):
474474
"""
475475
Returns -1 in the NA case
476476
"""
477-
values, mask, dtype, _ = _get_values(values, skipna, fill_value_typ='-inf',
478-
isfinite=True)
477+
values, mask, dtype, _ = _get_values(values, skipna, fill_value_typ='-inf')
479478
result = values.argmax(axis)
480479
result = _maybe_arg_null_out(result, axis, mask, skipna)
481480
return result
@@ -485,8 +484,7 @@ def nanargmin(values, axis=None, skipna=True):
485484
"""
486485
Returns -1 in the NA case
487486
"""
488-
values, mask, dtype, _ = _get_values(values, skipna, fill_value_typ='+inf',
489-
isfinite=True)
487+
values, mask, dtype, _ = _get_values(values, skipna, fill_value_typ='+inf')
490488
result = values.argmin(axis)
491489
result = _maybe_arg_null_out(result, axis, mask, skipna)
492490
return result

Diff for: pandas/tests/groupby/test_groupby.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2339,7 +2339,8 @@ def test_non_cython_api(self):
23392339
assert_frame_equal(result, expected)
23402340

23412341
# idxmax
2342-
expected = DataFrame([[0], [nan]], columns=['B'], index=[1, 3])
2342+
expected = DataFrame([[0.0, 0.0], [nan, 2.0]], columns=['B', 'C'],
2343+
index=[1, 3])
23432344
expected.index.name = 'A'
23442345
result = g.idxmax()
23452346
assert_frame_equal(result, expected)

Diff for: pandas/tests/series/test_operators.py

+67-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from pandas.core.indexes.timedeltas import Timedelta
2121
import pandas.core.nanops as nanops
2222

23-
from pandas.compat import range, zip
23+
from pandas.compat import range, zip, PY3
2424
from pandas import compat
2525
from pandas.util.testing import (assert_series_equal, assert_almost_equal,
2626
assert_frame_equal, assert_index_equal)
@@ -1857,3 +1857,69 @@ def test_op_duplicate_index(self):
18571857
result = s1 + s2
18581858
expected = pd.Series([11, 12, np.nan], index=[1, 1, 2])
18591859
assert_series_equal(result, expected)
1860+
1861+
def test_argminmax(self):
1862+
# Series.argmin, Series.argmax are aliased to Series.idxmin,
1863+
# Series.idxmax
1864+
1865+
# Expected behavior for empty Series
1866+
s = pd.Series([])
1867+
1868+
with pytest.raises(ValueError):
1869+
s.argmin()
1870+
with pytest.raises(ValueError):
1871+
s.argmin(skipna=False)
1872+
with pytest.raises(ValueError):
1873+
s.argmax()
1874+
with pytest.raises(ValueError):
1875+
s.argmax(skipna=False)
1876+
1877+
# For numeric data with NA and Inf (GH #13595)
1878+
s = pd.Series([0, -np.inf, np.inf, np.nan])
1879+
1880+
assert s.argmin() == 1
1881+
assert np.isnan(s.argmin(skipna=False))
1882+
1883+
assert s.argmax() == 2
1884+
assert np.isnan(s.argmax(skipna=False))
1885+
1886+
# Using old-style behavior that treats floating point nan, -inf, and
1887+
# +inf as missing
1888+
s = pd.Series([0, -np.inf, np.inf, np.nan])
1889+
1890+
with pd.option_context('mode.use_inf_as_null', True):
1891+
assert s.argmin() == 0
1892+
assert np.isnan(s.argmin(skipna=False))
1893+
assert s.argmax() == 0
1894+
np.isnan(s.argmax(skipna=False))
1895+
1896+
# For non-NA strings
1897+
s = pd.Series(['foo', 'foo', 'bar', 'bar', 'baz'])
1898+
1899+
assert s.argmin() == 2
1900+
assert s.argmin(skipna=False) == 2
1901+
1902+
assert s.argmax() == 0
1903+
assert s.argmax(skipna=False) == 0
1904+
1905+
# For mixed string and NA
1906+
# This works differently under Python 2 and 3: under Python 2,
1907+
# comparing strings and None, for example, is valid, and we can
1908+
# compute an argmax. Under Python 3, such comparisons are not valid
1909+
# and raise a TypeError.
1910+
s = pd.Series(['foo', 'foo', 'bar', 'bar', None, np.nan, 'baz'])
1911+
1912+
if PY3:
1913+
with pytest.raises(TypeError):
1914+
s.argmin()
1915+
with pytest.raises(TypeError):
1916+
s.argmin(skipna=False)
1917+
with pytest.raises(TypeError):
1918+
s.argmax()
1919+
with pytest.raises(TypeError):
1920+
s.argmax(skipna=False)
1921+
else:
1922+
assert s.argmin() == 4
1923+
assert np.isnan(s.argmin(skipna=False))
1924+
assert s.argmax() == 0
1925+
assert np.isnan(s.argmax(skipna=False))

0 commit comments

Comments
 (0)