Skip to content

Commit

Permalink
BUG/API: consistency in .agg with nested dicts pandas-dev#9052
Browse files Browse the repository at this point in the history
  • Loading branch information
jreback committed Dec 19, 2015
1 parent 3c23dc9 commit 36fb835
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 28 deletions.
46 changes: 46 additions & 0 deletions doc/source/whatsnew/v0.18.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,55 @@ New features
~~~~~~~~~~~~


.. _whatsnew_0180.enhancements.moments:

Computational moments are now methods
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Computational moments have been refactored to be methods on ``Series/DataFrame`` objects, rather than top-level functions, which are now deprecated. This allows these window-type functions to have a similar API to that of ``.groupby``. See the full documentation :ref:`here <stats.moments>` (:issue:`11603`)

.. ipython:: python

np.random.seed(1234)
df = DataFrame({'A' : range(10), 'B' : np.random.randn(10)})
df

Previous Behavior:

.. code-block:: python

In [8]: pd.rolling_mean(df,window=3)
Out[8]:
A B
0 NaN NaN
1 NaN NaN
2 1 0.237722
3 2 -0.023640
4 3 0.133155
5 4 -0.048693
6 5 0.342054
7 6 0.370076
8 7 0.079587
9 8 -0.954504

New Behavior:

.. ipython:: python

r = df.rolling(window=3)

# descriptive repr
r

# operate on this Rolling object itself
r.mean()

# getitem access
r['A'].mean()

# aggregates
r.agg({'A' : {'ra' : ['mean','std']},
'B' : {'rb' : ['mean','std']}})

.. _whatsnew_0180.enhancements.other:

Expand Down Expand Up @@ -195,6 +240,7 @@ Bug Fixes
- Bug in ``Period.end_time`` when a multiple of time period is requested (:issue:`11738`)
- Regression in ``.clip`` with tz-aware datetimes (:issue:`11838`)
- Bug in ``date_range`` when the boundaries fell on the frequency (:issue:`11804`)
- Bug in consistency of passing nested dicts to ``.groupby(...).agg(...)`` (:issue:`9052`)



Expand Down
33 changes: 26 additions & 7 deletions pandas/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ class SelectionMixin(object):
sub-classes need to define: obj, exclusions
"""
_selection = None
_internal_names = ['_cache']
_internal_names = ['_cache','__setstate__']
_internal_names_set = set(_internal_names)
_builtin_table = {
builtins.sum: np.sum,
Expand Down Expand Up @@ -368,6 +368,13 @@ def _aggregate(self, arg, *args, **kwargs):
"""
provide an implementation for the aggregators
Parameters
----------
arg : string, dict, function
*args : args to pass on to the function
**kwargs : kwargs to pass on to the function
Returns
-------
tuple of result, how
Expand All @@ -378,6 +385,7 @@ def _aggregate(self, arg, *args, **kwargs):
None if not required
"""

_level = kwargs.pop('_level',None)
if isinstance(arg, compat.string_types):
return getattr(self, arg)(*args, **kwargs), None

Expand All @@ -403,24 +411,24 @@ def _aggregate(self, arg, *args, **kwargs):

for fname, agg_how in compat.iteritems(arg):
colg = self._gotitem(self._selection, ndim=1, subset=subset)
result[fname] = colg.aggregate(agg_how)
result[fname] = colg.aggregate(agg_how, _level=None)
keys.append(fname)
else:
for col, agg_how in compat.iteritems(arg):
colg = self._gotitem(col, ndim=1)
result[col] = colg.aggregate(agg_how)
result[col] = colg.aggregate(agg_how, _level=(_level or 0) + 1)
keys.append(col)

if isinstance(list(result.values())[0], com.ABCDataFrame):
from pandas.tools.merge import concat
result = concat([result[k] for k in keys], keys=keys, axis=1)
result = concat([ result[k] for k in keys ], keys=keys, axis=1)
else:
from pandas import DataFrame
result = DataFrame(result)

return result, True
elif hasattr(arg, '__iter__'):
return self._aggregate_multiple_funcs(arg), None
return self._aggregate_multiple_funcs(arg, _level=_level), None
else:
result = None

Expand All @@ -431,7 +439,7 @@ def _aggregate(self, arg, *args, **kwargs):
# caller can react
return result, True

def _aggregate_multiple_funcs(self, arg):
def _aggregate_multiple_funcs(self, arg, _level):
from pandas.tools.merge import concat

if self.axis != 0:
Expand All @@ -447,7 +455,15 @@ def _aggregate_multiple_funcs(self, arg):
try:
colg = self._gotitem(obj.name, ndim=1, subset=obj)
results.append(colg.aggregate(a))
keys.append(getattr(a,'name',a))

# find a good name, this could be a function that we don't recognize
name = self._is_cython_func(a) or a
if not isinstance(name, compat.string_types):
name = getattr(a,name,a)
if not isinstance(name, compat.string_types):
name = getattr(a,func_name,a)

keys.append(name)
except (TypeError, DataError):
pass
except SpecificationError:
Expand All @@ -464,6 +480,9 @@ def _aggregate_multiple_funcs(self, arg):
pass
except SpecificationError:
raise

if _level:
keys = None
result = concat(results, keys=keys, axis=1)

return result
Expand Down
15 changes: 12 additions & 3 deletions pandas/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2362,6 +2362,7 @@ def aggregate(self, func_or_funcs, *args, **kwargs):
-------
Series or DataFrame
"""
_level = kwargs.pop('_level',None)
if isinstance(func_or_funcs, compat.string_types):
return getattr(self, func_or_funcs)(*args, **kwargs)

Expand Down Expand Up @@ -2411,11 +2412,18 @@ def _aggregate_multiple_funcs(self, arg):

results = {}
for name, func in arg:
obj = self
if name in results:
raise SpecificationError('Function names must be unique, '
'found multiple named %s' % name)

results[name] = self.aggregate(func)
# reset the cache so that we
# only include the named selection
if name in self._selected_obj:
obj = copy.copy(obj)
obj._reset_cache()
obj._selection = name
results[name] = obj.aggregate(func)

return DataFrame(results, columns=columns)

Expand Down Expand Up @@ -2856,7 +2864,8 @@ def _post_process_cython_aggregate(self, obj):
@Appender(SelectionMixin._agg_doc)
def aggregate(self, arg, *args, **kwargs):

result, how = self._aggregate(arg, *args, **kwargs)
_level = kwargs.pop('_level',None)
result, how = self._aggregate(arg, _level=_level, *args, **kwargs)
if how is None:
return result

Expand All @@ -2870,7 +2879,7 @@ def aggregate(self, arg, *args, **kwargs):
# try to treat as if we are passing a list
try:
assert not args and not kwargs
result = self._aggregate_multiple_funcs([arg])
result = self._aggregate_multiple_funcs([arg], _level=_level)
result.columns = Index(result.columns.levels[0],
name=self._selected_obj.columns.name)
except:
Expand Down
14 changes: 7 additions & 7 deletions pandas/core/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from collections import defaultdict

import pandas as pd
from pandas.lib import isscalar
from pandas.core.base import PandasObject, SelectionMixin, AbstractMethodError
import pandas.core.common as com
import pandas.algos as algos
Expand Down Expand Up @@ -64,11 +65,12 @@ def _gotitem(self, key, ndim, subset=None):
# create a new object to prevent aliasing
if subset is None:
subset = self.obj
new_self = self._shallow_copy(subset)
if ndim==2 and key in subset:
new_self._selection = key
new_self._reset_cache()
return new_self
self = self._shallow_copy(subset)
self._reset_cache()
if subset.ndim==2:
if isscalar(key) and key in subset or com.is_list_like(key):
self._selection = key
return self

def __getattr__(self, attr):
if attr in self._internal_names_set:
Expand Down Expand Up @@ -191,8 +193,6 @@ def _convert_freq(self):
@Appender(SelectionMixin._agg_doc)
def aggregate(self, arg, *args, **kwargs):
result, how = self._aggregate(arg, *args, **kwargs)
if result is None:
import pdb; pdb.set_trace()
return result

class Window(_Window):
Expand Down
42 changes: 42 additions & 0 deletions pandas/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1443,6 +1443,48 @@ def test_frame_set_name_single(self):
result = grouped['C'].agg({'foo': np.mean, 'bar': np.std})
self.assertEqual(result.index.name, 'A')

def test_aggregate_api_consistency(self):
# GH 9052
# make sure that the aggregates via dict
# are consistent


def compare(result, expected):
# if we are passing dicts then ordering is not guaranteed for output columns
assert_frame_equal(result.reindex_like(expected), expected)


df = DataFrame({'A' : ['foo', 'bar', 'foo', 'bar',
'foo', 'bar', 'foo', 'foo'],
'B' : ['one', 'one', 'two', 'three',
'two', 'two', 'one', 'three'],
'C' : np.random.randn(8),
'D' : np.random.randn(8)})

grouped = df.groupby(['A', 'B'])
result = grouped[['D','C']].agg({'r':np.sum, 'r2':np.mean})
expected = pd.concat([grouped[['D','C']].sum(),
grouped[['D','C']].mean()],
keys=['r','r2'],
axis=1).stack(level=1)
compare(result, expected)

result = grouped[['D','C']].agg({'r': { 'C' : np.sum }, 'r2' : { 'D' : np.mean }})
expected = pd.concat([grouped[['C']].sum(),
grouped[['D']].mean()],
axis=1)
expected.columns = MultiIndex.from_tuples([('r','C'),('r2','D')])
compare(result, expected)

result = grouped[['D','C']].agg([np.sum, np.mean])
expected = pd.concat([grouped['D'].sum(),
grouped['D'].mean(),
grouped['C'].sum(),
grouped['C'].mean()],
axis=1)
expected.columns = MultiIndex.from_product([['D','C'],['sum','mean']])
compare(result, expected)

def test_multi_iter(self):
s = Series(np.arange(6))
k1 = np.array(['a', 'a', 'a', 'b', 'b', 'b'])
Expand Down
66 changes: 55 additions & 11 deletions pandas/tests/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,18 +103,62 @@ def test_agg(self):
'B' : range(0,10,2)})

r = df.rolling(window=3)
a_mean = r['A'].mean()
a_std = r['A'].std()
a_sum = r['A'].sum()
b_mean = r['B'].mean()
b_std = r['B'].std()
b_sum = r['B'].sum()

def compare(result, expected):
# if we are using dicts, the ordering is not guaranteed
assert_frame_equal(result.reindex_like(expected), expected)

result = r.aggregate([np.mean, np.std])
expected = pd.concat([a_mean,a_std,b_mean,b_std],axis=1)
expected.columns = pd.MultiIndex.from_product([['A','B'],['mean','std']])
assert_frame_equal(result, expected)

result = r.aggregate({'A': np.mean,
'B': np.std})
expected = pd.concat([a_mean,b_std],axis=1)
compare(result, expected)

result = r.aggregate({'A': ['mean','std']})
expected = pd.concat([a_mean,a_std],axis=1)
expected.columns = pd.MultiIndex.from_product([['A'],['mean','std']])
assert_frame_equal(result, expected)

result = r['A'].aggregate(['mean','sum'])
expected = pd.concat([a_mean,a_sum],axis=1)
expected.columns = pd.MultiIndex.from_product([['A'],['mean','sum']])
assert_frame_equal(result, expected)

import pdb; pdb.set_trace()
agged = r.aggregate([np.mean, np.std])
agged = r.aggregate({'A': np.mean,
'B': np.std})
agged = r.aggregate({'A': ['mean','sum']})
agged = r['A'].aggregate(['mean','sum'])
agged = r.aggregate({'A': { 'mean' : 'mean', 'sum' : 'sum' } })
agged = r.aggregate({'A': { 'mean' : 'mean', 'sum' : 'sum' },
'B': { 'mean2' : 'mean', 'sum2' : 'sum' }})
agged = r.aggregate({'r1': { 'A' : ['mean','sum'] },
'r2' : { 'B' : ['mean','sum'] }})
result = r.aggregate({'A': { 'mean' : 'mean', 'sum' : 'sum' } })
expected = pd.concat([a_mean,a_sum],axis=1)
expected.columns = pd.MultiIndex.from_product([['A'],['mean','sum']])
compare(result, expected)

result = r.aggregate({'A': { 'mean' : 'mean', 'sum' : 'sum' },
'B': { 'mean2' : 'mean', 'sum2' : 'sum' }})
expected = pd.concat([a_mean,a_sum,b_mean,b_sum],axis=1)
expected.columns = pd.MultiIndex.from_tuples([('A','mean'),('A','sum'),
('B','mean2'),('B','sum2')])
compare(result, expected)

result = r.aggregate({'r1' : { 'A' : ['mean','sum'] },
'r2' : { 'B' : ['mean','sum'] }})
expected = pd.concat([a_mean,a_sum,b_mean,b_sum],axis=1)
expected.columns = pd.MultiIndex.from_tuples([('r1','A','mean'),('r1','A','sum'),
('r2','B','mean'),('r2','B','sum')])
compare(result, expected)

result = r.agg({'A' : {'ra' : ['mean','std']},
'B' : {'rb' : ['mean','std']}})
expected = pd.concat([a_mean,a_std,b_mean,b_std],axis=1)
expected.columns = pd.MultiIndex.from_tuples([('A','ra','mean'),('A','ra','std'),
('B','rb','mean'),('B','rb','std')])
compare(result, expected)

class TestMoments(Base):

Expand Down

0 comments on commit 36fb835

Please sign in to comment.