Skip to content

Commit

Permalink
use a context manager
Browse files Browse the repository at this point in the history
  • Loading branch information
jreback committed May 5, 2018
1 parent cf9ee69 commit 4a0ca5e
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 27 deletions.
60 changes: 35 additions & 25 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import warnings
import copy
from textwrap import dedent
from contextlib import contextmanager

from pandas.compat import (
zip, range, lzip,
Expand Down Expand Up @@ -549,6 +550,16 @@ def f(self):
return attr


@contextmanager
def _group_selection_context(groupby):
"""
set / reset the _group_selection_context
"""
groupby._set_group_selection()
yield groupby
groupby._reset_group_selection()


class _GroupBy(PandasObject, SelectionMixin):
_group_selection = None
_apply_whitelist = frozenset([])
Expand Down Expand Up @@ -704,6 +715,8 @@ def _set_group_selection(self):
"""
Create group based selection. Used when selection is not passed
directly but instead via a grouper.
NOTE: this should be paired with a call to _reset_group_selection
"""
grp = self.grouper
if not (self.as_index and
Expand Down Expand Up @@ -785,10 +798,10 @@ def _make_wrapper(self, name):
type(self).__name__))
raise AttributeError(msg)

# need to setup the selection
# as are not passed directly but in the grouper
self._set_group_selection()

# need to setup the selection
# as are not passed directly but in the grouper
f = getattr(self._selected_obj, name)
if not isinstance(f, types.MethodType):
return self.apply(lambda self: getattr(self, name))
Expand Down Expand Up @@ -913,9 +926,8 @@ def f(g):
# fails on *some* columns, e.g. a numeric operation
# on a string grouper column

self._set_group_selection()
result = self._python_apply_general(f)
self._reset_group_selection()
with _group_selection_context(self):
return self._python_apply_general(f)

return result

Expand Down Expand Up @@ -1295,9 +1307,9 @@ def mean(self, *args, **kwargs):
except GroupByError:
raise
except Exception: # pragma: no cover
self._set_group_selection()
f = lambda x: x.mean(axis=self.axis, **kwargs)
return self._python_agg_general(f)
with _group_selection_context(self):
f = lambda x: x.mean(axis=self.axis, **kwargs)
return self._python_agg_general(f)

@Substitution(name='groupby')
@Appender(_doc_template)
Expand All @@ -1313,13 +1325,12 @@ def median(self, **kwargs):
raise
except Exception: # pragma: no cover

self._set_group_selection()

def f(x):
if isinstance(x, np.ndarray):
x = Series(x)
return x.median(axis=self.axis, **kwargs)
return self._python_agg_general(f)
with _group_selection_context(self):
return self._python_agg_general(f)

@Substitution(name='groupby')
@Appender(_doc_template)
Expand Down Expand Up @@ -1356,9 +1367,9 @@ def var(self, ddof=1, *args, **kwargs):
if ddof == 1:
return self._cython_agg_general('var', **kwargs)
else:
self._set_group_selection()
f = lambda x: x.var(ddof=ddof, **kwargs)
return self._python_agg_general(f)
with _group_selection_context(self):
return self._python_agg_general(f)

@Substitution(name='groupby')
@Appender(_doc_template)
Expand Down Expand Up @@ -1404,6 +1415,7 @@ def f(self, **kwargs):
kwargs['numeric_only'] = numeric_only
if 'min_count' not in kwargs:
kwargs['min_count'] = min_count

self._set_group_selection()
try:
return self._cython_agg_general(
Expand Down Expand Up @@ -1797,13 +1809,12 @@ def ngroup(self, ascending=True):
.cumcount : Number the rows in each group.
"""

self._set_group_selection()

index = self._selected_obj.index
result = Series(self.grouper.group_info[0], index)
if not ascending:
result = self.ngroups - 1 - result
return result
with _group_selection_context(self):
index = self._selected_obj.index
result = Series(self.grouper.group_info[0], index)
if not ascending:
result = self.ngroups - 1 - result
return result

@Substitution(name='groupby')
def cumcount(self, ascending=True):
Expand Down Expand Up @@ -1854,11 +1865,10 @@ def cumcount(self, ascending=True):
.ngroup : Number the groups themselves.
"""

self._set_group_selection()

index = self._selected_obj.index
cumcounts = self._cumcount_array(ascending=ascending)
return Series(cumcounts, index)
with _group_selection_context(self):
index = self._selected_obj.index
cumcounts = self._cumcount_array(ascending=ascending)
return Series(cumcounts, index)

@Substitution(name='groupby')
@Appender(_doc_template)
Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/groupby/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,11 +519,11 @@ def test_func(x):

def test_apply_with_mixed_types():
# gh-20949
df = pd.DataFrame({'A': 'a a b'.split(), 'B': [1,2,3], 'C': [4, 6, 5]})
df = pd.DataFrame({'A': 'a a b'.split(), 'B': [1, 2, 3], 'C': [4, 6, 5]})
g = df.groupby('A')

result = g.transform(lambda x: x / x.sum())
expected = pd.DataFrame({'B': [1/3., 2/3., 1], 'C': [0.4, 0.6, 1.0]})
expected = pd.DataFrame({'B': [1 / 3., 2 / 3., 1], 'C': [0.4, 0.6, 1.0]})
tm.assert_frame_equal(result, expected)

result = g.apply(lambda x: x / x.sum())
Expand Down

0 comments on commit 4a0ca5e

Please sign in to comment.