DataFrame subclass lost in groupby.agg with split_out set. #1024

Open
TomAugspurger opened this issue Apr 13, 2024 · 1 comment

Comments

Member

TomAugspurger commented Apr 13, 2024

Describe the issue:

As part of geopandas/dask-geopandas#285, we found that dask-expr will lose the type of a pandas DataFrame subclass in groupby.agg if (and only if?) the split_out parameter is used.

Minimal Complete Verifiable Example:

Given this file:

# file: test.py
import dask
import dask.dataframe.backends
import pandas as pd

import dask_expr as dx
import dask.dataframe as dd
from dask.dataframe.dispatch import make_meta_dispatch, meta_nonempty
from dask.dataframe.core import get_parallel_type


dask.config.set(scheduler="single-threaded")

class MySeries(pd.Series):
    @property
    def _constructor(self):
        return MySeries

    @property
    def _constructor_expanddim(self):
        return MyDataFrame


class MyDataFrame(pd.DataFrame):
    @property
    def _constructor(self):
        return MyDataFrame

    @property
    def _constructor_sliced(self):
        return MySeries


class MyIndex(pd.Index): ...


class MyDaskSeries(dx.Series):
    _partition_type = MySeries


class MyDaskDataFrame(dx.DataFrame):
    _partition_type = MyDataFrame


class MyDaskIndex(dx.Index):
    _partition_type = MyIndex


# Unclear if any of get_parallel_type and make_meta_dispatch are needed.
# Reproduces with or without them.
@get_parallel_type.register(MyDataFrame)
def get_parallel_type_dataframe(df):
    return MyDaskDataFrame


@get_parallel_type.register(MySeries)
def get_parallel_type_series(s):
    return MyDaskSeries


@get_parallel_type.register(MyIndex)
def get_parallel_type_index(ind):
    return MyDaskIndex


@make_meta_dispatch.register(MyDataFrame)
def make_meta_dataframe(df, index=None):
    return df.head(0)


@make_meta_dispatch.register(MySeries)
def make_meta_series(s, index=None):
    return s.head(0)


@make_meta_dispatch.register(MyIndex)
def make_meta_index(ind, index=None):
    return ind[:0]


@meta_nonempty.register(MyDataFrame)
def make_meta_nonempty_dataframe(x):
    return MyDataFrame(dask.dataframe.backends.meta_nonempty_dataframe(x))


df = dx.from_dict(
    {"a": [1, 1, 2, 2], "b": [1, 2, 3, 4]}, npartitions=4, constructor=MyDataFrame
)
a = df.groupby("a").agg("first")
b = df.groupby("a").agg("first", split_out=2)

print("split-out=None", type(a.compute()))
print("split-out=2   ", type(b.compute()))

Running that produces:

$ python test.py
split-out=None <class '__main__.MyDataFrame'>
split-out=2    <class 'pandas.core.frame.DataFrame'>

I would expect the type there to be __main__.MyDataFrame regardless of split_out.

Anything else we need to know?:

Environment:

dask               2024.4.1
dask-expr          1.0.11

Edit: I made one addition to the script: adding a @meta_nonempty.register(MyDataFrame). I noticed that in DecomposableGroupbyAggregation.combine and DecomposableGroupbyAggregation.aggregate the types were regular pandas DataFrames, instead of the subclass.

Registering that meta_nonempty does keep it as MyDataFrame initially. I put some print statements in those methods to print type(inputs[0]) and type(_concat(inputs)), and got:

combine <class '__main__.MyDataFrame'> <class '__main__.MyDataFrame'>
aggregate <class '__main__.MyDataFrame'> <class '__main__.MyDataFrame'>
aggregate <class '__main__.MyDataFrame'> <class '__main__.MyDataFrame'>
aggregate <class 'pandas.core.frame.DataFrame'> <class 'pandas.core.frame.DataFrame'>
aggregate <class 'pandas.core.frame.DataFrame'> <class 'pandas.core.frame.DataFrame'>
split-out=2    <class 'pandas.core.frame.DataFrame'>

So initially we're OK, but by the time we do the final aggregate we've lost the subclass.
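For reference, here is a minimal pandas-only sketch of the distinction the prints above hint at. It assumes nothing about dask internals: it only shows that pandas operations routed through `_constructor` keep a subclass, while rebuilding the frame through the plain `pd.DataFrame` constructor drops it. Whether the shuffle does something equivalent is an assumption, not something this snippet demonstrates.

```python
import pandas as pd


class MyDataFrame(pd.DataFrame):
    @property
    def _constructor(self):
        # pandas calls this when it builds the result of most operations.
        return MyDataFrame


left = MyDataFrame({"a": [1, 1], "b": [1, 2]})
right = MyDataFrame({"a": [2, 2], "b": [3, 4]})

# Operations that go through `_constructor` keep the subclass.
concatenated = pd.concat([left, right])
empty_head = left.head(0)

# Rebuilding through the base-class constructor returns a plain DataFrame.
rebuilt = pd.DataFrame(concatenated)

print("concat ", type(concatenated))
print("head(0)", type(empty_head))
print("rebuilt", type(rebuilt))
```

Any step that reconstructs partitions from raw data rather than from an existing subclass instance would behave like the `rebuilt` case.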

TomAugspurger changed the title from "DataFrame subclass in groupby.agg with split_out set." to "DataFrame subclass lost in groupby.agg with split_out set." on Apr 14, 2024
Collaborator

phofl commented Apr 15, 2024

This is a shuffle issue (and also present in the current implementation, if I am not mistaken?)

df.shuffle("a") will lose your type, and that is what we do under the hood if split_out != 1. shuffle_method="tasks" keeps it; "disk" and "p2p" lose it.

I can patch that so that your resulting DataFrame will have the correct type, but I don't know if we can guarantee that we keep whatever you might add to the subclass through shuffles without you overriding the shuffle-specific methods.
