Skip to content

Commit

Permalink
Fixed Issue Preventing Agg on RollingGroupBy Objects (pandas-dev#21323)
Browse files Browse the repository at this point in the history
  • Loading branch information
WillAyd authored and victor committed Sep 30, 2018
1 parent 9143599 commit 6bbcfc1
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 5 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.24.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,7 @@ Groupby/Resample/Rolling
- Bug in :meth:`Series.resample` when passing ``numpy.timedelta64`` to ``loffset`` kwarg (:issue:`7687`).
- Bug in :meth:`Resampler.asfreq` when frequency of ``TimedeltaIndex`` is a subperiod of a new frequency (:issue:`13022`).
- Bug in :meth:`SeriesGroupBy.mean` when values were integral but could not fit inside of int64, overflowing instead. (:issue:`22487`)
- :meth:`RollingGroupby.agg` and :meth:`ExpandingGroupby.agg` now support multiple aggregation functions as parameters (:issue:`15072`)

Sparse
^^^^^^
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,8 @@ def _obj_with_exclusions(self):

def __getitem__(self, key):
if self._selection is not None:
raise Exception('Column(s) {selection} already selected'
.format(selection=self._selection))
raise IndexError('Column(s) {selection} already selected'
.format(selection=self._selection))

if isinstance(key, (list, tuple, ABCSeries, ABCIndexClass,
np.ndarray)):
Expand Down
9 changes: 8 additions & 1 deletion pandas/core/groupby/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,15 @@ def _gotitem(self, key, ndim, subset=None):
# we need to make a shallow copy of ourselves
# with the same groupby
kwargs = {attr: getattr(self, attr) for attr in self._attributes}

# Try to select from a DataFrame, falling back to a Series
try:
groupby = self._groupby[key]
except IndexError:
groupby = self._groupby

self = self.__class__(subset,
groupby=self._groupby[key],
groupby=groupby,
parent=self,
**kwargs)
self._reset_cache()
Expand Down
10 changes: 8 additions & 2 deletions pandas/tests/groupby/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,8 +623,14 @@ def test_as_index_series_return_frame(df):
assert isinstance(result2, DataFrame)
assert_frame_equal(result2, expected2)

# corner case
pytest.raises(Exception, grouped['C'].__getitem__, 'D')

def test_as_index_series_column_slice_raises(df):
    """Check that re-selecting a column on a sliced groupby raises.

    After column ``'C'`` has been selected from ``df.groupby('A',
    as_index=False)``, asking the result for ``'D'`` is expected to raise
    ``IndexError`` with a message naming the already selected column.
    """
    # GH15072
    by_a = df.groupby('A', as_index=False)
    expected_msg = r"Column\(s\) C already selected"

    # Both selections happen inside the context manager, so any error
    # raised by either step is matched against the expected message.
    with tm.assert_raises_regex(IndexError, expected_msg):
        by_a['C']['D']


def test_groupby_as_index_cython(df):
Expand Down
48 changes: 48 additions & 0 deletions pandas/tests/test_window.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import OrderedDict
from itertools import product
import pytest
import warnings
Expand Down Expand Up @@ -314,6 +315,53 @@ def test_preserve_metadata(self):
assert s2.name == 'foo'
assert s3.name == 'foo'

@pytest.mark.parametrize("func,window_size,expected_vals", [
    ('rolling', 2, [[np.nan, np.nan, np.nan, np.nan],
                    [15., 20., 25., 20.],
                    [25., 30., 35., 30.],
                    [np.nan, np.nan, np.nan, np.nan],
                    [20., 30., 35., 30.],
                    [35., 40., 60., 40.],
                    [60., 80., 85., 80]]),
    ('expanding', None, [[10., 10., 20., 20.],
                         [15., 20., 25., 20.],
                         [20., 30., 30., 20.],
                         [10., 10., 30., 30.],
                         [20., 30., 35., 30.],
                         [26.666667, 40., 50., 30.],
                         [40., 80., 60., 30.]])])
def test_multiple_agg_funcs(self, func, window_size, expected_vals):
    """Check ``agg`` with several functions per column on grouped windows.

    Parameters
    ----------
    func : str
        Name of the window constructor looked up on the groupby object
        (``'rolling'`` or ``'expanding'`` in the parametrization).
    window_size : int or None
        Passed as the single positional argument to the constructor when
        truthy; otherwise the constructor is called with no arguments.
    expected_vals : list of list
        Row-wise values of the expected result frame.
    """
    # GH 15072
    frame = pd.DataFrame(
        [['A', 10, 20],
         ['A', 20, 30],
         ['A', 30, 40],
         ['B', 10, 30],
         ['B', 30, 40],
         ['B', 40, 80],
         ['B', 80, 90]],
        columns=['stock', 'low', 'high'])

    # Build the window object; only pass a size when one was supplied.
    make_window = getattr(frame.groupby('stock'), func)
    window_args = (window_size,) if window_size else ()
    window = make_window(*window_args)

    # Expected frame: rows keyed by (group, original position), columns
    # keyed by (source column, aggregation name).
    row_index = pd.MultiIndex.from_tuples(
        [('A', 0), ('A', 1), ('A', 2),
         ('B', 3), ('B', 4), ('B', 5), ('B', 6)],
        names=['stock', None])
    col_index = pd.MultiIndex.from_tuples(
        [('low', 'mean'), ('low', 'max'),
         ('high', 'mean'), ('high', 'min')])
    expected = pd.DataFrame(expected_vals, index=row_index,
                            columns=col_index)

    # An ordered mapping keeps the column order of the result fixed.
    aggregations = OrderedDict()
    aggregations['low'] = ['mean', 'max']
    aggregations['high'] = ['mean', 'min']
    result = window.agg(aggregations)

    tm.assert_frame_equal(result, expected)


@pytest.mark.filterwarnings("ignore:can't resolve package:ImportWarning")
class TestWindow(Base):
Expand Down

0 comments on commit 6bbcfc1

Please sign in to comment.