Skip to content

Commit

Permalink
Backport PR #45363: BUG: correctly instantiate subclassed DataFrame/Series in groupby apply (#45397)
Browse files Browse the repository at this point in the history
  • Loading branch information
meeseeksmachine authored Jan 16, 2022
1 parent 502dbdf commit 219811f
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 13 deletions.
16 changes: 4 additions & 12 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,7 @@ def apply(
zipped = zip(group_keys, splitter)

for key, group in zipped:
group = group.__finalize__(data, method="groupby")
object.__setattr__(group, "name", key)

# group might be modified
Expand Down Expand Up @@ -1000,6 +1001,7 @@ def _aggregate_series_pure_python(
splitter = get_splitter(obj, ids, ngroups, axis=0)

for i, group in enumerate(splitter):
group = group.__finalize__(obj, method="groupby")
res = func(group)
res = libreduction.extract_result(res)

Expand Down Expand Up @@ -1243,13 +1245,7 @@ def _chop(self, sdata: Series, slice_obj: slice) -> Series:
# fastpath equivalent to `sdata.iloc[slice_obj]`
mgr = sdata._mgr.get_slice(slice_obj)
# __finalize__ not called here, must be applied by caller if applicable

# fastpath equivalent to:
# `return sdata._constructor(mgr, name=sdata.name, fastpath=True)`
obj = type(sdata)._from_mgr(mgr)
object.__setattr__(obj, "_flags", sdata._flags)
object.__setattr__(obj, "_name", sdata._name)
return obj
return sdata._constructor(mgr, name=sdata.name, fastpath=True)


class FrameSplitter(DataSplitter):
Expand All @@ -1261,11 +1257,7 @@ def _chop(self, sdata: DataFrame, slice_obj: slice) -> DataFrame:
# return sdata.iloc[:, slice_obj]
mgr = sdata._mgr.get_slice(slice_obj, axis=1 - self.axis)
# __finalize__ not called here, must be applied by caller if applicable

# fastpath equivalent to `return sdata._constructor(mgr)`
obj = type(sdata)._from_mgr(mgr)
object.__setattr__(obj, "_flags", sdata._flags)
return obj
return sdata._constructor(mgr)


def get_splitter(
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/generic/test_finalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,7 @@ def test_categorical_accessor(method):
"method",
[
operator.methodcaller("sum"),
lambda x: x.apply(lambda y: y),
lambda x: x.agg("sum"),
lambda x: x.agg("mean"),
lambda x: x.agg("median"),
Expand All @@ -764,7 +765,6 @@ def test_groupby_finalize(obj, method):
"method",
[
lambda x: x.agg(["sum", "count"]),
lambda x: x.apply(lambda y: y),
lambda x: x.agg("std"),
lambda x: x.agg("var"),
lambda x: x.agg("sem"),
Expand Down
23 changes: 23 additions & 0 deletions pandas/tests/groupby/test_groupby_subclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from pandas import (
DataFrame,
Index,
Series,
)
import pandas._testing as tm
Expand Down Expand Up @@ -64,6 +65,28 @@ def test_groupby_preserves_metadata():
for _, group_df in custom_df.groupby("c"):
assert group_df.testattr == "hello"

# GH-45314
def func(group):
assert isinstance(group, tm.SubclassedDataFrame)
assert hasattr(group, "testattr")
return group.testattr

result = custom_df.groupby("c").apply(func)
expected = tm.SubclassedSeries(["hello"] * 3, index=Index([7, 8, 9], name="c"))
tm.assert_series_equal(result, expected)

def func2(group):
assert isinstance(group, tm.SubclassedSeries)
assert hasattr(group, "testattr")
return group.testattr

custom_series = tm.SubclassedSeries([1, 2, 3])
custom_series.testattr = "hello"
result = custom_series.groupby(custom_df["c"]).apply(func2)
tm.assert_series_equal(result, expected)
result = custom_series.groupby(custom_df["c"]).agg(func2)
tm.assert_series_equal(result, expected)


@pytest.mark.parametrize("obj", [DataFrame, tm.SubclassedDataFrame])
def test_groupby_resample_preserves_subclass(obj):
Expand Down

0 comments on commit 219811f

Please sign in to comment.