Skip to content

Commit

Permalink
Adding argmax and argmin with proper behavior (pandas-dev#16830)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lucas Kushner committed Jul 15, 2017
1 parent 2784c3f commit 49a512d
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 11 deletions.
60 changes: 59 additions & 1 deletion pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5329,7 +5329,7 @@ def idxmin(self, axis=0, skipna=True):

def idxmax(self, axis=0, skipna=True):
"""
Return index of first occurrence of maximum over requested axis.
Return label of first occurrence of maximum over requested axis.
NA/null values are excluded.
Parameters
Expand Down Expand Up @@ -5358,6 +5358,64 @@ def idxmax(self, axis=0, skipna=True):
result = [index[i] if i >= 0 else NA for i in indices]
return Series(result, index=self._get_agg_axis(axis))

def argmin(self, axis=0, skipna=True):
"""
Return index of first occurrence of minimum over requested axis.
NA/null values are excluded.
Parameters
----------
axis : {0 or 'index', 1 or 'columns'}, default 0
0 or 'index' for row-wise, 1 or 'columns' for column-wise
skipna : boolean, default True
Exclude NA/null values. If an entire row/column is NA, the result
will be NA
Returns
-------
argmin : Series
Notes
-----
This method is the DataFrame version of ``ndarray.argmin``.
See Also
--------
Series.idxmin
"""
axis = self._get_axis_number(axis)
indices = nanops.nanargmin(self.values, axis=axis, skipna=skipna)
return Series(indices, index=self._get_agg_axis(axis))

def argmax(self, axis=0, skipna=True):
"""
Return index of first occurrence of maximum over requested axis.
NA/null values are excluded.
Parameters
----------
axis : {0 or 'index', 1 or 'columns'}, default 0
0 or 'index' for row-wise, 1 or 'columns' for column-wise
skipna : boolean, default True
Exclude NA/null values. If an entire row/column is NA, the result
will be first index.
Returns
-------
argmax : Series
Notes
-----
This method is the DataFrame version of ``ndarray.argmax``.
See Also
--------
Series.argmax
"""
axis = self._get_axis_number(axis)
indices = nanops.nanargmax(self.values, axis=axis, skipna=skipna)
return Series(indices, index=self._get_agg_axis(axis))

def _get_agg_axis(self, axis_num):
""" let's be explict about this """
if axis_num == 0:
Expand Down
66 changes: 56 additions & 10 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
import pandas.core.nanops as nanops
import pandas.io.formats.format as fmt
from pandas.util._decorators import (
Appender, deprecate, deprecate_kwarg, Substitution)
Appender, deprecate_kwarg, Substitution)
from pandas.util._validators import validate_bool_kwarg

from pandas._libs import index as libindex, tslib as libts, lib, iNaT
Expand Down Expand Up @@ -1239,7 +1239,7 @@ def duplicated(self, keep='first'):

def idxmin(self, axis=None, skipna=True, *args, **kwargs):
"""
Index of first occurrence of minimum of values.
Label of first occurrence of minimum of values.
Parameters
----------
Expand All @@ -1259,15 +1259,14 @@ def idxmin(self, axis=None, skipna=True, *args, **kwargs):
DataFrame.idxmin
numpy.ndarray.argmin
"""
skipna = nv.validate_argmin_with_skipna(skipna, args, kwargs)
i = nanops.nanargmin(_values_from_object(self), skipna=skipna)
i = self.argmin(axis, skipna, *args, **kwargs)
if i == -1:
return np.nan
return self.index[i]

def idxmax(self, axis=None, skipna=True, *args, **kwargs):
"""
Index of first occurrence of maximum of values.
Label of first occurrence of maximum of values.
Parameters
----------
Expand All @@ -1287,15 +1286,62 @@ def idxmax(self, axis=None, skipna=True, *args, **kwargs):
DataFrame.idxmax
numpy.ndarray.argmax
"""
skipna = nv.validate_argmax_with_skipna(skipna, args, kwargs)
i = nanops.nanargmax(_values_from_object(self), skipna=skipna)
i = self.argmax(axis, skipna, *args, **kwargs)
if i == -1:
return np.nan
return self.index[i]

# ndarray compat
argmin = deprecate('argmin', idxmin)
argmax = deprecate('argmax', idxmax)
def argmin(self, axis=None, skipna=True, *args, **kwargs):
"""
Index of first occurrence of minimum of values.
Parameters
----------
skipna : boolean, default True
Exclude NA/null values
Returns
-------
idxmin : Index of minimum of values
Notes
-----
This method is the Series version of ``ndarray.argmin``.
See Also
--------
DataFrame.argmin
numpy.ndarray.argmin
"""
skipna = nv.validate_argmin_with_skipna(skipna, args, kwargs)
i = nanops.nanargmin(_values_from_object(self), skipna=skipna)
return i

def argmax(self, axis=None, skipna=True, *args, **kwargs):
"""
Index of first occurrence of maximum of values.
Parameters
----------
skipna : boolean, default True
Exclude NA/null values
Returns
-------
idxmax : Index of maximum of values
Notes
-----
This method is the Series version of ``ndarray.argmax``.
See Also
--------
DataFrame.argmax
numpy.ndarray.argmax
"""
skipna = nv.validate_argmax_with_skipna(skipna, args, kwargs)
i = nanops.nanargmax(_values_from_object(self), skipna=skipna)
return i

def round(self, decimals=0, *args, **kwargs):
"""
Expand Down
28 changes: 28 additions & 0 deletions pandas/tests/frame/test_analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,6 +1031,34 @@ def test_idxmax(self):

pytest.raises(ValueError, frame.idxmax, axis=2)

def test_argmin(self):
frame = self.frame
frame.loc[5:10] = np.nan
frame.loc[15:20, -2:] = np.nan
for skipna in [True, False]:
for axis in [0, 1]:
for df in [frame, self.intframe]:
result = df.argmin(axis=axis, skipna=skipna)
expected = df.apply(Series.argmin, axis=axis,
skipna=skipna)
tm.assert_series_equal(result, expected)

pytest.raises(ValueError, frame.argmin, axis=2)

def test_argmax(self):
frame = self.frame
frame.loc[5:10] = np.nan
frame.loc[15:20, -2:] = np.nan
for skipna in [True, False]:
for axis in [0, 1]:
for df in [frame, self.intframe]:
result = df.argmax(axis=axis, skipna=skipna)
expected = df.apply(Series.argmax, axis=axis,
skipna=skipna)
tm.assert_series_equal(result, expected)

pytest.raises(ValueError, frame.argmax, axis=2)

# ----------------------------------------------------------------------
# Logical reductions

Expand Down
2 changes: 2 additions & 0 deletions pandas/tests/series/test_analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,6 +1215,7 @@ def test_numpy_argmin(self):
data = np.random.randint(0, 11, size=10)
result = np.argmin(Series(data))
assert result == np.argmin(data)
assert result == Series(data).argmin()

if not _np_version_under1p10:
msg = "the 'out' parameter is not supported"
Expand Down Expand Up @@ -1271,6 +1272,7 @@ def test_numpy_argmax(self):
data = np.random.randint(0, 11, size=10)
result = np.argmax(Series(data))
assert result == np.argmax(data)
assert result == Series(data).argmax()

if not _np_version_under1p10:
msg = "the 'out' parameter is not supported"
Expand Down

0 comments on commit 49a512d

Please sign in to comment.