Add handling for nested dicts in dask-cudf groupby (#9054)
Closes #9017 

Adds handling for nested-dict ("renamed") aggregations supplied to dask-cudf's groupby: the user-requested output names are recorded while standardizing the `aggs` input and then applied to the result columns in `_finalize_gb_agg()`.
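
For reference, a minimal sketch of the kind of aggregation spec this enables (mirroring the new test added below; the frame contents and the output names "foo"/"bar" are illustrative):

    import cudf
    import dask_cudf

    df = cudf.DataFrame({"x": [0, 1, 0, 1], "y": [1.0, 2.0, 3.0, 4.0]})
    gddf = dask_cudf.from_cudf(df, npartitions=2)

    # The nested dict renames the aggregations: the sum of "y" becomes
    # column ("y", "foo") and its count becomes ("y", "bar") in the result
    out = gddf.groupby("x").agg({"y": {"foo": "sum", "bar": "count"}}).compute()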

Authors:
  - Charles Blackmon-Luca (https://github.com/charlesbluca)

Approvers:
  - Marlene (https://github.com/marlenezw)
  - Benjamin Zaitlen (https://github.com/quasiben)

URL: #9054
charlesbluca authored Aug 26, 2021
1 parent 263190a commit 4e0584b
Showing 3 changed files with 78 additions and 33 deletions.
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/dataframe.py
@@ -2585,7 +2585,7 @@ def columns(self, columns):
 
         if not len(columns) == len(self._data.names):
             raise ValueError(
-                f"Length mismatch: expected {len(self._data.names)} elements ,"
+                f"Length mismatch: expected {len(self._data.names)} elements, "
                 f"got {len(columns)} elements"
             )
 
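
The dataframe.py hunk above only moves a misplaced space in this error message. A sketch of code that would raise it (frame contents are illustrative):

    import cudf

    df = cudf.DataFrame({"a": [1, 2], "b": [3, 4]})
    df.columns = ["a"]
    # ValueError: Length mismatch: expected 2 elements, got 1 elements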
80 changes: 48 additions & 32 deletions python/dask_cudf/dask_cudf/groupby.py
@@ -198,34 +198,8 @@ def groupby_agg(
     in `dask.dataframe`, because it allows the cudf backend to
     perform multiple aggregations at once.
     """
-
-    # Deal with default split_out and split_every params
-    if split_every is False:
-        split_every = ddf.npartitions
-    split_every = split_every or 8
-    split_out = split_out or 1
-
-    # Standardize `gb_cols` and `columns` lists
-    aggs = _redirect_aggs(aggs_in.copy())
-    if isinstance(gb_cols, str):
-        gb_cols = [gb_cols]
-    columns = [c for c in ddf.columns if c not in gb_cols]
-    str_cols_out = False
-    if isinstance(aggs, dict):
-        # Use `str_cols_out` to specify if the output columns
-        # will have str (rather than MultiIndex/tuple) names.
-        # This happens when all values in the `aggs` dict are
-        # strings (no lists)
-        str_cols_out = True
-        for col in aggs:
-            if isinstance(aggs[col], str) or callable(aggs[col]):
-                aggs[col] = [aggs[col]]
-            else:
-                str_cols_out = False
-            if col in gb_cols:
-                columns.append(col)
-
     # Assert that aggregations are supported
+    aggs = _redirect_aggs(aggs_in)
     _supported = {
         "count",
         "mean",
@@ -244,10 +218,39 @@
             f"Aggregations must be specified with dict or list syntax."
         )
 
-    # Always convert aggs to dict for consistency
+    # Deal with default split_out and split_every params
+    if split_every is False:
+        split_every = ddf.npartitions
+    split_every = split_every or 8
+    split_out = split_out or 1
+
+    # Standardize `gb_cols`, `columns`, and `aggs`
+    if isinstance(gb_cols, str):
+        gb_cols = [gb_cols]
+    columns = [c for c in ddf.columns if c not in gb_cols]
     if isinstance(aggs, list):
         aggs = {col: aggs for col in columns}
 
+    # Assert if our output will have a MultiIndex; this will be the case if
+    # any value in the `aggs` dict is not a string (i.e. multiple/named
+    # aggregations per column)
+    str_cols_out = True
+    aggs_renames = {}
+    for col in aggs:
+        if isinstance(aggs[col], str) or callable(aggs[col]):
+            aggs[col] = [aggs[col]]
+        elif isinstance(aggs[col], dict):
+            str_cols_out = False
+            col_aggs = []
+            for k, v in aggs[col].items():
+                aggs_renames[col, v] = k
+                col_aggs.append(v)
+            aggs[col] = col_aggs
+        else:
+            str_cols_out = False
+        if col in gb_cols:
+            columns.append(col)
+
     # Begin graph construction
     dsk = {}
     token = tokenize(ddf, gb_cols, aggs)
@@ -314,6 +317,13 @@
         for col in aggs:
             _aggs[col] = _aggs[col][0]
     _meta = ddf._meta.groupby(gb_cols, as_index=as_index).agg(_aggs)
+    if aggs_renames:
+        col_array = []
+        agg_array = []
+        for col, agg in _meta.columns:
+            col_array.append(col)
+            agg_array.append(aggs_renames.get((col, agg), agg))
+        _meta.columns = pd.MultiIndex.from_arrays([col_array, agg_array])
     for s in range(split_out):
         dsk[(gb_agg_name, s)] = (
             _finalize_gb_agg,
@@ -326,6 +336,7 @@
             sort,
             sep,
             str_cols_out,
+            aggs_renames,
         )
 
     divisions = [None] * (split_out + 1)
@@ -350,6 +361,10 @@ def _redirect_aggs(arg):
         for col in arg:
             if isinstance(arg[col], list):
                 new_arg[col] = [redirects.get(agg, agg) for agg in arg[col]]
+            elif isinstance(arg[col], dict):
+                new_arg[col] = {
+                    k: redirects.get(v, v) for k, v in arg[col].items()
+                }
             else:
                 new_arg[col] = redirects.get(arg[col], arg[col])
     return new_arg
@@ -367,6 +382,8 @@ def _is_supported(arg, supported: set):
         for col in arg:
             if isinstance(arg[col], list):
                 _global_set = _global_set.union(set(arg[col]))
+            elif isinstance(arg[col], dict):
+                _global_set = _global_set.union(set(arg[col].values()))
            else:
                 _global_set.add(arg[col])
     else:
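
For intuition: `_is_supported` flattens the aggregation spec into a set of aggregation names, which (per the collapsed remainder of the function) is presumably checked against the supported set. The new branch makes a nested dict contribute its values rather than its rename keys, traced here on an illustrative spec:

    arg = {"y": {"foo": "sum", "bar": "count"}}
    # For col == "y", arg[col] is a dict, so:
    #     _global_set |= set(arg["y"].values())  # -> {"sum", "count"}
    # The rename keys "foo"/"bar" are never tested against `supported`.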
@@ -460,10 +477,8 @@ def _tree_node_agg(dfs, gb_cols, split_out, dropna, sort, sep):
         agg = col.split(sep)[-1]
         if agg in ("count", "sum"):
             agg_dict[col] = ["sum"]
-        elif agg in ("min", "max"):
+        elif agg in ("min", "max", "collect"):
             agg_dict[col] = [agg]
-        elif agg == "collect":
-            agg_dict[col] = ["collect"]
         else:
             raise ValueError(f"Unexpected aggregation: {agg}")
 
@@ -508,6 +523,7 @@ def _finalize_gb_agg(
     sort,
     sep,
     str_cols_out,
+    aggs_renames,
 ):
     """ Final aggregation task.
@@ -564,7 +580,7 @@
         else:
             name, agg = col.split(sep)
             col_array.append(name)
-            agg_array.append(agg)
+            agg_array.append(aggs_renames.get((name, agg), agg))
     if str_cols_out:
         gb.columns = col_array
     else:
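
Tying the groupby.py changes together: the standardization loop flattens each nested dict into a plain list of aggregations and records the requested output names in `aggs_renames`, keyed by `(column, aggregation)`; the `_meta` fix-up and `_finalize_gb_agg` then map the result columns back through that dict. A sketch of the intermediate state for an illustrative spec, traced from the diff rather than captured at runtime:

    aggs_in = {"y": {"foo": "sum", "bar": "count"}}
    # After the standardization loop in `groupby_agg`:
    #     aggs         == {"y": ["sum", "count"]}
    #     aggs_renames == {("y", "sum"): "foo", ("y", "count"): "bar"}
    #     str_cols_out == False
    # `_finalize_gb_agg` later rewrites each output column via
    #     agg_array.append(aggs_renames.get((name, agg), agg))
    # yielding the MultiIndex columns ("y", "foo") and ("y", "bar").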
29 changes: 29 additions & 0 deletions python/dask_cudf/dask_cudf/tests/test_groupby.py
@@ -645,3 +645,32 @@ def test_groupby_with_list_of_series():
     dd.assert_eq(
         gdf.groupby([ggs]).agg(["sum"]), ddf.groupby([pgs]).agg(["sum"])
     )
+
+
+@pytest.mark.parametrize(
+    "func",
+    [
+        lambda df: df.groupby("x").agg({"y": {"foo": "sum"}}),
+        lambda df: df.groupby("x").agg({"y": {"foo": "sum", "bar": "count"}}),
+    ],
+)
+def test_groupby_nested_dict(func):
+    pdf = pd.DataFrame(
+        {
+            "x": np.random.randint(0, 5, size=10000),
+            "y": np.random.normal(size=10000),
+        }
+    )
+
+    ddf = dd.from_pandas(pdf, npartitions=5)
+    c_ddf = ddf.map_partitions(cudf.from_pandas)
+
+    a = func(ddf).compute()
+    b = func(c_ddf).compute().to_pandas()
+
+    a.index.name = None
+    a.name = None
+    b.index.name = None
+    b.name = None
+
+    dd.assert_eq(a, b)
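
For the two-aggregation case, both frames should end up with the user-supplied names as the second column level (expected shape per the rename logic above, not captured output):

    # a.columns == b.columns -> MultiIndex([("y", "foo"), ("y", "bar")])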
