diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 671142a98a6b36..08ab621f5e3cef 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -6,6 +6,7 @@ import warnings import copy from textwrap import dedent +from contextlib import contextmanager from pandas.compat import ( zip, range, lzip, @@ -549,6 +550,16 @@ def f(self): return attr +@contextmanager +def _group_selection_context(groupby): + """ + set / reset the _group_selection_context + """ + groupby._set_group_selection() + yield groupby + groupby._reset_group_selection() + + class _GroupBy(PandasObject, SelectionMixin): _group_selection = None _apply_whitelist = frozenset([]) @@ -704,6 +715,8 @@ def _set_group_selection(self): """ Create group based selection. Used when selection is not passed directly but instead via a grouper. + + NOTE: this should be paired with a call to _reset_group_selection """ grp = self.grouper if not (self.as_index and @@ -785,10 +798,10 @@ def _make_wrapper(self, name): type(self).__name__)) raise AttributeError(msg) - # need to setup the selection - # as are not passed directly but in the grouper self._set_group_selection() + # need to setup the selection + # as are not passed directly but in the grouper f = getattr(self._selected_obj, name) if not isinstance(f, types.MethodType): return self.apply(lambda self: getattr(self, name)) @@ -913,9 +926,8 @@ def f(g): # fails on *some* columns, e.g. a numeric operation # on a string grouper column - self._set_group_selection() - result = self._python_apply_general(f) - self._reset_group_selection() + with _group_selection_context(self): + return self._python_apply_general(f) return result @@ -1295,9 +1307,9 @@ def mean(self, *args, **kwargs): except GroupByError: raise except Exception: # pragma: no cover - self._set_group_selection() - f = lambda x: x.mean(axis=self.axis, **kwargs) - return self._python_agg_general(f) + with _group_selection_context(self): + f = lambda x: x.mean(axis=self.axis, **kwargs) + return self._python_agg_general(f) @Substitution(name='groupby') @Appender(_doc_template) @@ -1313,13 +1325,12 @@ def median(self, **kwargs): raise except Exception: # pragma: no cover - self._set_group_selection() - def f(x): if isinstance(x, np.ndarray): x = Series(x) return x.median(axis=self.axis, **kwargs) - return self._python_agg_general(f) + with _group_selection_context(self): + return self._python_agg_general(f) @Substitution(name='groupby') @Appender(_doc_template) @@ -1356,9 +1367,9 @@ def var(self, ddof=1, *args, **kwargs): if ddof == 1: return self._cython_agg_general('var', **kwargs) else: - self._set_group_selection() f = lambda x: x.var(ddof=ddof, **kwargs) - return self._python_agg_general(f) + with _group_selection_context(self): + return self._python_agg_general(f) @Substitution(name='groupby') @Appender(_doc_template) @@ -1404,6 +1415,7 @@ def f(self, **kwargs): kwargs['numeric_only'] = numeric_only if 'min_count' not in kwargs: kwargs['min_count'] = min_count + self._set_group_selection() try: return self._cython_agg_general( @@ -1797,13 +1809,12 @@ def ngroup(self, ascending=True): .cumcount : Number the rows in each group. """ - self._set_group_selection() - - index = self._selected_obj.index - result = Series(self.grouper.group_info[0], index) - if not ascending: - result = self.ngroups - 1 - result - return result + with _group_selection_context(self): + index = self._selected_obj.index + result = Series(self.grouper.group_info[0], index) + if not ascending: + result = self.ngroups - 1 - result + return result @Substitution(name='groupby') def cumcount(self, ascending=True): @@ -1854,11 +1865,10 @@ def cumcount(self, ascending=True): .ngroup : Number the groups themselves. """ - self._set_group_selection() - - index = self._selected_obj.index - cumcounts = self._cumcount_array(ascending=ascending) - return Series(cumcounts, index) + with _group_selection_context(self): + index = self._selected_obj.index + cumcounts = self._cumcount_array(ascending=ascending) + return Series(cumcounts, index) @Substitution(name='groupby') @Appender(_doc_template) diff --git a/pandas/tests/groupby/test_apply.py b/pandas/tests/groupby/test_apply.py index ecb4e2af530424..07eef2d87feb30 100644 --- a/pandas/tests/groupby/test_apply.py +++ b/pandas/tests/groupby/test_apply.py @@ -519,11 +519,11 @@ def test_func(x): def test_apply_with_mixed_types(): # gh-20949 - df = pd.DataFrame({'A': 'a a b'.split(), 'B': [1,2,3], 'C': [4, 6, 5]}) + df = pd.DataFrame({'A': 'a a b'.split(), 'B': [1, 2, 3], 'C': [4, 6, 5]}) g = df.groupby('A') result = g.transform(lambda x: x / x.sum()) - expected = pd.DataFrame({'B': [1/3., 2/3., 1], 'C': [0.4, 0.6, 1.0]}) + expected = pd.DataFrame({'B': [1 / 3., 2 / 3., 1], 'C': [0.4, 0.6, 1.0]}) tm.assert_frame_equal(result, expected) result = g.apply(lambda x: x / x.sum())