Skip to content

Commit

Permalink
BUG in .groupby.apply when applying a function that has mixed data ty…
Browse files Browse the repository at this point in the history
…pes and the user supplied function can fail on the grouping column (#20959)
  • Loading branch information
jreback authored May 8, 2018
1 parent e051303 commit 620784f
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 39 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.23.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1325,6 +1325,7 @@ Groupby/Resample/Rolling
- Bug in :func:`DataFrame.resample` that dropped timezone information (:issue:`13238`)
- Bug in :func:`DataFrame.groupby` where transformations using ``np.all`` and ``np.any`` were raising a ``ValueError`` (:issue:`20653`)
- Bug in :func:`DataFrame.resample` where ``ffill``, ``bfill``, ``pad``, ``backfill``, ``fillna``, ``interpolate``, and ``asfreq`` were ignoring ``loffset``. (:issue:`20744`)
- Bug in :func:`DataFrame.groupby` where ``apply`` on data with mixed data types raised when the user-supplied function failed on the grouping column (:issue:`20949`)

Sparse
^^^^^^
Expand Down
108 changes: 69 additions & 39 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import warnings
import copy
from textwrap import dedent
from contextlib import contextmanager

from pandas.compat import (
zip, range, lzip,
Expand Down Expand Up @@ -549,6 +550,16 @@ def f(self):
return attr


@contextmanager
def _group_selection_context(groupby):
    """
    Set the group selection on ``groupby`` for the duration of a ``with``
    block and reset it afterwards.

    Calls ``groupby._set_group_selection()`` on entry, yields ``groupby``
    itself, and calls ``groupby._reset_group_selection()`` on exit.

    Parameters
    ----------
    groupby : object
        Any object exposing ``_set_group_selection`` and
        ``_reset_group_selection``.

    Yields
    ------
    The same ``groupby`` object that was passed in.
    """
    groupby._set_group_selection()
    try:
        yield groupby
    finally:
        # Always undo the selection, even when the body of the ``with``
        # block raises; otherwise the selection would leak into later
        # operations on the same groupby object.
        groupby._reset_group_selection()


class _GroupBy(PandasObject, SelectionMixin):
_group_selection = None
_apply_whitelist = frozenset([])
Expand Down Expand Up @@ -696,26 +707,32 @@ def _reset_group_selection(self):
each group regardless of whether a group selection was previously set.
"""
if self._group_selection is not None:
self._group_selection = None
# GH12839 clear cached selection too when changing group selection
self._group_selection = None
self._reset_cache('_selected_obj')

def _set_group_selection(self):
"""
Create group based selection. Used when selection is not passed
directly but instead via a grouper.
NOTE: this should be paired with a call to _reset_group_selection
"""
grp = self.grouper
if self.as_index and getattr(grp, 'groupings', None) is not None and \
self.obj.ndim > 1:
ax = self.obj._info_axis
groupers = [g.name for g in grp.groupings
if g.level is None and g.in_axis]
if not (self.as_index and
getattr(grp, 'groupings', None) is not None and
self.obj.ndim > 1 and
self._group_selection is None):
return

ax = self.obj._info_axis
groupers = [g.name for g in grp.groupings
if g.level is None and g.in_axis]

if len(groupers):
self._group_selection = ax.difference(Index(groupers)).tolist()
# GH12839 clear selected obj cache when group selection changes
self._reset_cache('_selected_obj')
if len(groupers):
# GH12839 clear selected obj cache when group selection changes
self._group_selection = ax.difference(Index(groupers)).tolist()
self._reset_cache('_selected_obj')

def _set_result_index_ordered(self, result):
# set the result index on the passed values object and
Expand Down Expand Up @@ -781,10 +798,10 @@ def _make_wrapper(self, name):
type(self).__name__))
raise AttributeError(msg)

# need to setup the selection
# as are not passed directly but in the grouper
self._set_group_selection()

# need to setup the selection
# as are not passed directly but in the grouper
f = getattr(self._selected_obj, name)
if not isinstance(f, types.MethodType):
return self.apply(lambda self: getattr(self, name))
Expand Down Expand Up @@ -897,7 +914,22 @@ def f(g):

# ignore SettingWithCopy here in case the user mutates
with option_context('mode.chained_assignment', None):
return self._python_apply_general(f)
try:
result = self._python_apply_general(f)
except Exception:

# gh-20949
# try again, with .apply acting as a filtering
# operation, by excluding the grouping column
# This would normally not be triggered
# except if the udf is trying an operation that
# fails on *some* columns, e.g. a numeric operation
# on a string grouper column

with _group_selection_context(self):
return self._python_apply_general(f)

return result

def _python_apply_general(self, f):
keys, values, mutated = self.grouper.apply(f, self._selected_obj,
Expand Down Expand Up @@ -1275,9 +1307,9 @@ def mean(self, *args, **kwargs):
except GroupByError:
raise
except Exception: # pragma: no cover
self._set_group_selection()
f = lambda x: x.mean(axis=self.axis, **kwargs)
return self._python_agg_general(f)
with _group_selection_context(self):
f = lambda x: x.mean(axis=self.axis, **kwargs)
return self._python_agg_general(f)

@Substitution(name='groupby')
@Appender(_doc_template)
Expand All @@ -1293,13 +1325,12 @@ def median(self, **kwargs):
raise
except Exception: # pragma: no cover

self._set_group_selection()

def f(x):
if isinstance(x, np.ndarray):
x = Series(x)
return x.median(axis=self.axis, **kwargs)
return self._python_agg_general(f)
with _group_selection_context(self):
return self._python_agg_general(f)

@Substitution(name='groupby')
@Appender(_doc_template)
Expand Down Expand Up @@ -1336,9 +1367,9 @@ def var(self, ddof=1, *args, **kwargs):
if ddof == 1:
return self._cython_agg_general('var', **kwargs)
else:
self._set_group_selection()
f = lambda x: x.var(ddof=ddof, **kwargs)
return self._python_agg_general(f)
with _group_selection_context(self):
return self._python_agg_general(f)

@Substitution(name='groupby')
@Appender(_doc_template)
Expand Down Expand Up @@ -1384,6 +1415,7 @@ def f(self, **kwargs):
kwargs['numeric_only'] = numeric_only
if 'min_count' not in kwargs:
kwargs['min_count'] = min_count

self._set_group_selection()
try:
return self._cython_agg_general(
Expand Down Expand Up @@ -1453,11 +1485,11 @@ def ohlc(self):

@Appender(DataFrame.describe.__doc__)
def describe(self, **kwargs):
self._set_group_selection()
result = self.apply(lambda x: x.describe(**kwargs))
if self.axis == 1:
return result.T
return result.unstack()
with _group_selection_context(self):
result = self.apply(lambda x: x.describe(**kwargs))
if self.axis == 1:
return result.T
return result.unstack()

@Substitution(name='groupby')
@Appender(_doc_template)
Expand Down Expand Up @@ -1778,13 +1810,12 @@ def ngroup(self, ascending=True):
.cumcount : Number the rows in each group.
"""

self._set_group_selection()

index = self._selected_obj.index
result = Series(self.grouper.group_info[0], index)
if not ascending:
result = self.ngroups - 1 - result
return result
with _group_selection_context(self):
index = self._selected_obj.index
result = Series(self.grouper.group_info[0], index)
if not ascending:
result = self.ngroups - 1 - result
return result

@Substitution(name='groupby')
def cumcount(self, ascending=True):
Expand Down Expand Up @@ -1835,11 +1866,10 @@ def cumcount(self, ascending=True):
.ngroup : Number the groups themselves.
"""

self._set_group_selection()

index = self._selected_obj.index
cumcounts = self._cumcount_array(ascending=ascending)
return Series(cumcounts, index)
with _group_selection_context(self):
index = self._selected_obj.index
cumcounts = self._cumcount_array(ascending=ascending)
return Series(cumcounts, index)

@Substitution(name='groupby')
@Appender(_doc_template)
Expand Down Expand Up @@ -3768,7 +3798,6 @@ def nunique(self, dropna=True):

@Appender(Series.describe.__doc__)
def describe(self, **kwargs):
self._set_group_selection()
result = self.apply(lambda x: x.describe(**kwargs))
if self.axis == 1:
return result.T
Expand Down Expand Up @@ -4411,6 +4440,7 @@ def transform(self, func, *args, **kwargs):
return self._transform_general(func, *args, **kwargs)

obj = self._obj_with_exclusions

# nuisance columns
if not result.columns.equals(obj.columns):
return self._transform_general(func, *args, **kwargs)
Expand Down
13 changes: 13 additions & 0 deletions pandas/tests/groupby/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,3 +515,16 @@ def test_func(x):
index=index2)
tm.assert_frame_equal(result1, expected1)
tm.assert_frame_equal(result2, expected2)


def test_apply_with_mixed_types():
    """
    transform and apply should give the same frame when grouping by a
    string column and dividing each group by its sum.
    """
    # gh-20949
    frame = pd.DataFrame({'A': 'a a b'.split(), 'B': [1, 2, 3], 'C': [4, 6, 5]})
    grouped = frame.groupby('A')

    def share_of_group(x):
        # each value as a fraction of its group's total
        return x / x.sum()

    expected = pd.DataFrame({'B': [1 / 3., 2 / 3., 1], 'C': [0.4, 0.6, 1.0]})

    result = grouped.transform(share_of_group)
    tm.assert_frame_equal(result, expected)

    result = grouped.apply(share_of_group)
    tm.assert_frame_equal(result, expected)

0 comments on commit 620784f

Please sign in to comment.