Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG in .groupby.apply when applying a function that has mixed data types and the user supplied function can fail on the grouping column #20959

Merged
merged 3 commits into from
May 8, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.23.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1320,6 +1320,7 @@ Groupby/Resample/Rolling
- Bug in :func:`DataFrame.resample` that dropped timezone information (:issue:`13238`)
- Bug in :func:`DataFrame.groupby` where transformations using ``np.all`` and ``np.any`` were raising a ``ValueError`` (:issue:`20653`)
- Bug in :func:`DataFrame.resample` where ``ffill``, ``bfill``, ``pad``, ``backfill``, ``fillna``, ``interpolate``, and ``asfreq`` were ignoring ``loffset``. (:issue:`20744`)
- Bug in :func:`DataFrame.groupby` where ``apply`` raised on a frame with mixed data types when the user-supplied function fails on the grouping column (:issue:`20949`)

Sparse
^^^^^^
Expand Down
108 changes: 69 additions & 39 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import warnings
import copy
from textwrap import dedent
from contextlib import contextmanager

from pandas.compat import (
zip, range, lzip,
Expand Down Expand Up @@ -549,6 +550,16 @@ def f(self):
return attr


@contextmanager
def _group_selection_context(groupby):
    """
    Set, then reset, the group selection of ``groupby`` around a block.

    Calls ``groupby._set_group_selection()`` on entry, yields ``groupby``
    itself, and calls ``groupby._reset_group_selection()`` on exit. The
    reset runs in a ``finally`` clause so it still happens when the body
    of the ``with`` statement raises.
    """
    groupby._set_group_selection()
    try:
        yield groupby
    finally:
        # always undo the selection, even on error, so that a failed
        # operation does not leave the selection set on the object
        groupby._reset_group_selection()


class _GroupBy(PandasObject, SelectionMixin):
_group_selection = None
_apply_whitelist = frozenset([])
Expand Down Expand Up @@ -696,26 +707,32 @@ def _reset_group_selection(self):
each group regardless of whether a group selection was previously set.
"""
if self._group_selection is not None:
self._group_selection = None
# GH12839 clear cached selection too when changing group selection
self._group_selection = None
self._reset_cache('_selected_obj')

def _set_group_selection(self):
"""
Create group based selection. Used when selection is not passed
directly but instead via a grouper.

NOTE: this should be paired with a call to _reset_group_selection
"""
grp = self.grouper
if self.as_index and getattr(grp, 'groupings', None) is not None and \
self.obj.ndim > 1:
ax = self.obj._info_axis
groupers = [g.name for g in grp.groupings
if g.level is None and g.in_axis]
if not (self.as_index and
getattr(grp, 'groupings', None) is not None and
self.obj.ndim > 1 and
self._group_selection is None):
return

ax = self.obj._info_axis
groupers = [g.name for g in grp.groupings
if g.level is None and g.in_axis]

if len(groupers):
self._group_selection = ax.difference(Index(groupers)).tolist()
# GH12839 clear selected obj cache when group selection changes
self._reset_cache('_selected_obj')
if len(groupers):
# GH12839 clear selected obj cache when group selection changes
self._group_selection = ax.difference(Index(groupers)).tolist()
self._reset_cache('_selected_obj')

def _set_result_index_ordered(self, result):
# set the result index on the passed values object and
Expand Down Expand Up @@ -781,10 +798,10 @@ def _make_wrapper(self, name):
type(self).__name__))
raise AttributeError(msg)

# need to setup the selection
# as are not passed directly but in the grouper
self._set_group_selection()

# need to setup the selection
# as are not passed directly but in the grouper
f = getattr(self._selected_obj, name)
if not isinstance(f, types.MethodType):
return self.apply(lambda self: getattr(self, name))
Expand Down Expand Up @@ -897,7 +914,22 @@ def f(g):

# ignore SettingWithCopy here in case the user mutates
with option_context('mode.chained_assignment', None):
return self._python_apply_general(f)
try:
result = self._python_apply_general(f)
except Exception:

# gh-20949
# try again, with .apply acting as a filtering
# operation, by excluding the grouping column
# This would normally not be triggered
# except if the udf is trying an operation that
# fails on *some* columns, e.g. a numeric operation
# on a string grouper column

with _group_selection_context(self):
return self._python_apply_general(f)

return result

def _python_apply_general(self, f):
keys, values, mutated = self.grouper.apply(f, self._selected_obj,
Expand Down Expand Up @@ -1275,9 +1307,9 @@ def mean(self, *args, **kwargs):
except GroupByError:
raise
except Exception: # pragma: no cover
self._set_group_selection()
f = lambda x: x.mean(axis=self.axis, **kwargs)
return self._python_agg_general(f)
with _group_selection_context(self):
f = lambda x: x.mean(axis=self.axis, **kwargs)
return self._python_agg_general(f)

@Substitution(name='groupby')
@Appender(_doc_template)
Expand All @@ -1293,13 +1325,12 @@ def median(self, **kwargs):
raise
except Exception: # pragma: no cover

self._set_group_selection()

def f(x):
if isinstance(x, np.ndarray):
x = Series(x)
return x.median(axis=self.axis, **kwargs)
return self._python_agg_general(f)
with _group_selection_context(self):
return self._python_agg_general(f)

@Substitution(name='groupby')
@Appender(_doc_template)
Expand Down Expand Up @@ -1336,9 +1367,9 @@ def var(self, ddof=1, *args, **kwargs):
if ddof == 1:
return self._cython_agg_general('var', **kwargs)
else:
self._set_group_selection()
f = lambda x: x.var(ddof=ddof, **kwargs)
return self._python_agg_general(f)
with _group_selection_context(self):
return self._python_agg_general(f)

@Substitution(name='groupby')
@Appender(_doc_template)
Expand Down Expand Up @@ -1384,6 +1415,7 @@ def f(self, **kwargs):
kwargs['numeric_only'] = numeric_only
if 'min_count' not in kwargs:
kwargs['min_count'] = min_count

self._set_group_selection()
try:
return self._cython_agg_general(
Expand Down Expand Up @@ -1453,11 +1485,11 @@ def ohlc(self):

@Appender(DataFrame.describe.__doc__)
def describe(self, **kwargs):
self._set_group_selection()
result = self.apply(lambda x: x.describe(**kwargs))
if self.axis == 1:
return result.T
return result.unstack()
with _group_selection_context(self):
result = self.apply(lambda x: x.describe(**kwargs))
if self.axis == 1:
return result.T
return result.unstack()

@Substitution(name='groupby')
@Appender(_doc_template)
Expand Down Expand Up @@ -1778,13 +1810,12 @@ def ngroup(self, ascending=True):
.cumcount : Number the rows in each group.
"""

self._set_group_selection()

index = self._selected_obj.index
result = Series(self.grouper.group_info[0], index)
if not ascending:
result = self.ngroups - 1 - result
return result
with _group_selection_context(self):
index = self._selected_obj.index
result = Series(self.grouper.group_info[0], index)
if not ascending:
result = self.ngroups - 1 - result
return result

@Substitution(name='groupby')
def cumcount(self, ascending=True):
Expand Down Expand Up @@ -1835,11 +1866,10 @@ def cumcount(self, ascending=True):
.ngroup : Number the groups themselves.
"""

self._set_group_selection()

index = self._selected_obj.index
cumcounts = self._cumcount_array(ascending=ascending)
return Series(cumcounts, index)
with _group_selection_context(self):
index = self._selected_obj.index
cumcounts = self._cumcount_array(ascending=ascending)
return Series(cumcounts, index)

@Substitution(name='groupby')
@Appender(_doc_template)
Expand Down Expand Up @@ -3768,7 +3798,6 @@ def nunique(self, dropna=True):

@Appender(Series.describe.__doc__)
def describe(self, **kwargs):
self._set_group_selection()
result = self.apply(lambda x: x.describe(**kwargs))
if self.axis == 1:
return result.T
Expand Down Expand Up @@ -4411,6 +4440,7 @@ def transform(self, func, *args, **kwargs):
return self._transform_general(func, *args, **kwargs)

obj = self._obj_with_exclusions

# nuisance columns
if not result.columns.equals(obj.columns):
return self._transform_general(func, *args, **kwargs)
Expand Down
13 changes: 13 additions & 0 deletions pandas/tests/groupby/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,3 +515,16 @@ def test_func(x):
index=index2)
tm.assert_frame_equal(result1, expected1)
tm.assert_frame_equal(result2, expected2)


def test_apply_with_mixed_types():
    """
    Regression test for gh-20949.

    Groups a frame by its string column and checks that both ``transform``
    and ``apply`` with a udf dividing each group by its sum return the
    normalized numeric columns only.
    """
    frame = pd.DataFrame({'A': ['a', 'a', 'b'],
                          'B': [1, 2, 3],
                          'C': [4, 6, 5]})
    grouped = frame.groupby('A')

    def normalize(x):
        return x / x.sum()

    expected = pd.DataFrame({'B': [1 / 3., 2 / 3., 1],
                             'C': [0.4, 0.6, 1.0]})

    transformed = grouped.transform(normalize)
    tm.assert_frame_equal(transformed, expected)

    applied = grouped.apply(normalize)
    tm.assert_frame_equal(applied, expected)