Skip to content

Commit

Permalink
BUG: df.agg, df.transform and df.apply use different methods when axi…
Browse files Browse the repository at this point in the history
…s=1 than when axis=0 (pandas-dev#21224)
  • Loading branch information
topper-123 authored and victor committed Sep 30, 2018
1 parent 10a95d5 commit 638b0ad
Show file tree
Hide file tree
Showing 8 changed files with 338 additions and 98 deletions.
4 changes: 3 additions & 1 deletion doc/source/whatsnew/v0.24.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,9 @@ Numeric
- Bug where :class:`Series` ``__rmatmul__`` did not support matrix-vector multiplication (:issue:`21530`)
- Bug where :func:`factorize` failed with a read-only array (:issue:`12813`)
- Fixed bug in :func:`unique` handled signed zeros inconsistently: for some inputs 0.0 and -0.0 were treated as equal and for some inputs as different. Now they are treated as equal for all inputs (:issue:`21866`)
-
- Bug in :meth:`DataFrame.agg`, :meth:`DataFrame.transform` and :meth:`DataFrame.apply` where,
when supplied with a list of functions and ``axis=1`` (e.g. ``df.apply(['sum', 'mean'], axis=1)``),
a ``TypeError`` was wrongly raised. For all three methods, such calculations are now performed correctly (:issue:`16679`).
-

Strings
Expand Down
55 changes: 55 additions & 0 deletions pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,26 @@ def spmatrix(request):
return getattr(sparse, request.param + '_matrix')


@pytest.fixture(params=[0, 1, 'index', 'columns'],
                ids="axis {!r}".format)
def axis(request):
    """
    Parametrized fixture yielding each DataFrame axis specifier in turn.

    Both the integer (0, 1) and the string ('index', 'columns') spellings
    are covered; the test id is the repr of the value prefixed by "axis ".
    """
    return request.param


# same fixture function, exposed under a frame-specific name
axis_frame = axis


@pytest.fixture(params=[0, 'index'], ids="axis {!r}".format)
def axis_series(request):
    """
    Parametrized fixture yielding each Series axis specifier in turn.

    Covers the integer (0) and the string ('index') spelling; the test id
    is the repr of the value prefixed by "axis ".
    """
    return request.param


@pytest.fixture
def ip():
"""
Expand Down Expand Up @@ -103,6 +123,41 @@ def all_arithmetic_operators(request):
return request.param


# Use sorted as dicts in py<3.6 have random order, which xdist doesn't like.
# Sorting on the type name of the key alone leaves every entry of the same
# type in (random) dict order, so the mapped method name and the callable's
# own name are used as tie-breakers to make the order fully deterministic.
_cython_table = sorted(
    pd.core.base.SelectionMixin._cython_table.items(),
    key=lambda item: (type(item[0]).__name__,
                      item[1],
                      getattr(item[0], '__name__', '')))


@pytest.fixture(params=_cython_table)
def cython_table_items(request):
    """
    Parametrized fixture yielding one entry of ``_cython_table`` per test.
    """
    item = request.param
    return item


def _get_cython_table_params(ndframe, func_names_and_expected):
    """
    Expand method names into test parameters using ``_cython_table``.

    For every ``(name, expected)`` pair, one entry is produced for the name
    itself, followed by one entry for each callable that ``_cython_table``
    maps to that name.

    Parameters
    ----------
    ndframe : DataFrame or Series
    func_names_and_expected : Sequence of two items
        The first item is a name of a NDFrame method ('sum', 'prod') etc.
        The second item is the expected return value

    Returns
    -------
    results : list
        List of three-item tuples (ndframe, function, expected result)
    """
    params = []
    for method_name, expected in func_names_and_expected:
        # callables registered in the table under this method name
        aliases = [func for func, name in _cython_table
                   if name == method_name]
        # the plain name comes first, then its aliases in table order
        params.extend((ndframe, target, expected)
                      for target in [method_name] + aliases)
    return params


@pytest.fixture(params=['__eq__', '__ne__', '__le__',
'__lt__', '__ge__', '__gt__'])
def all_compare_operators(request):
Expand Down
16 changes: 7 additions & 9 deletions pandas/core/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from pandas.core.dtypes.generic import ABCSeries
from pandas.core.dtypes.common import (
is_extension_type,
is_dict_like,
is_list_like,
is_sequence)
from pandas.util._decorators import cache_readonly

Expand Down Expand Up @@ -105,6 +107,11 @@ def agg_axis(self):
def get_result(self):
""" compute the results """

# dispatch to agg
if is_list_like(self.f) or is_dict_like(self.f):
return self.obj.aggregate(self.f, axis=self.axis,
*self.args, **self.kwds)

# all empty
if len(self.columns) == 0 and len(self.index) == 0:
return self.apply_empty_result()
Expand Down Expand Up @@ -308,15 +315,6 @@ def wrap_results(self):
class FrameRowApply(FrameApply):
axis = 0

def get_result(self):
    """
    Compute the result, sending list or dict ``f`` to ``aggregate``.

    Any other kind of ``f`` is handled by the parent implementation.
    """
    func = self.f
    if not isinstance(func, (list, dict)):
        return super(FrameRowApply, self).get_result()

    # dispatch to agg
    return self.obj.aggregate(func, axis=self.axis,
                              *self.args, **self.kwds)

def apply_broadcast(self):
    """Delegate to the parent ``apply_broadcast``, passing ``self.obj``."""
    parent = super(FrameRowApply, self)
    return parent.apply_broadcast(self.obj)

Expand Down
27 changes: 21 additions & 6 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -6070,19 +6070,34 @@ def _gotitem(self,
def aggregate(self, func, axis=0, *args, **kwargs):
    # Aggregate with ``func`` along ``axis``: try the ``_aggregate``
    # machinery first and fall back to ``apply`` when it yields nothing.
    # (Documentation is kept in comments so the function's ``__doc__`` is
    # left exactly as it was.)
    axis = self._get_axis_number(axis)

    result = None
    try:
        # One axis-aware call serves both axes; a separate ``axis == 0``
        # pre-call would only run the same aggregation a second time.
        result, _ = self._aggregate(func, axis=axis, *args, **kwargs)
    except TypeError:
        # deliberate best-effort: leave ``result`` as None so that
        # ``apply`` gets to handle ``func`` below
        pass
    if result is None:
        return self.apply(func, axis=axis, args=args, **kwargs)
    return result

def _aggregate(self, arg, axis=0, *args, **kwargs):
    """
    Run the parent ``_aggregate``, on the transposed frame when axis is 1.

    Returns the ``(result, how)`` pair produced by the parent; for
    ``axis == 1`` a non-None result is transposed back before returning.
    """
    if axis != 1:
        return super(DataFrame, self)._aggregate(arg, *args, **kwargs)

    # The parent returns a tuple, and only its first element (the result)
    # needs to be transposed back.
    transposed = self.T
    result, how = super(DataFrame, transposed)._aggregate(
        arg, *args, **kwargs)
    if result is not None:
        result = result.T
    return result, how

agg = aggregate

@Appender(_shared_docs['transform'] % _shared_doc_kwargs)
def transform(self, func, axis=0, *args, **kwargs):
    # No docstring here on purpose: the ``Appender`` decorator above
    # supplies the documentation from the shared 'transform' template.
    axis = self._get_axis_number(axis)
    if axis != 1:
        return super(DataFrame, self).transform(func, *args, **kwargs)

    # axis=1: run the parent transform on the transposed frame, then
    # transpose the outcome back
    flipped = super(DataFrame, self.T).transform(func, *args, **kwargs)
    return flipped.T

def apply(self, func, axis=0, broadcast=None, raw=False, reduce=None,
result_type=None, args=(), **kwds):
"""
Expand Down
16 changes: 7 additions & 9 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9193,16 +9193,14 @@ def ewm(self, com=None, span=None, halflife=None, alpha=None,

cls.ewm = ewm

@Appender(_shared_docs['transform'] % _shared_doc_kwargs)
def transform(self, func, *args, **kwargs):
result = self.agg(func, *args, **kwargs)
if is_scalar(result) or len(result) != len(self):
raise ValueError("transforms cannot produce "
"aggregated results")
@Appender(_shared_docs['transform'] % _shared_doc_kwargs)
def transform(self, func, *args, **kwargs):
    # No docstring here on purpose: the ``Appender`` decorator above
    # supplies the documentation from the shared 'transform' template.
    outcome = self.agg(func, *args, **kwargs)

    # a transform must hand back something as long as the caller itself;
    # the scalar check comes first so ``len`` is never taken of a scalar
    same_length = not is_scalar(outcome) and len(outcome) == len(self)
    if not same_length:
        raise ValueError("transforms cannot produce "
                         "aggregated results")

    return outcome

cls.transform = transform
return result

# ----------------------------------------------------------------------
# Misc methods
Expand Down
Loading

0 comments on commit 638b0ad

Please sign in to comment.