Support reindexing in simple_combine
For 1D combine, this is a great improvement for cohorts-type reductions.
It uses more memory but takes similar time for map-reduce.

Note that the map-reduce intermediates are a worst case where there are
no shared groups between the chunks being combined.
That case is actually optimized in _grouped_combine, where reindexing is
skipped when reducing along a single axis.

[ 68.75%] ··· =========== ========= =========
              --                combine
              ----------- -------------------
                  kind     grouped   combine
              =========== ========= =========
                cohorts      760M      631M
               mapreduce     981M     1.81G
              =========== ========= =========

[ 75.00%] ··· =========== ========== ===========
              --                 combine
              ----------- ----------------------
                  kind     grouped     combine
              =========== ========== ===========
                cohorts    393±10ms    137±10ms
               mapreduce   652±10ms   611±400ms
              =========== ========== ===========
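In rough terms, the change lets _simple_combine handle intermediates that were not reindexed at the blockwise step: find the union of groups across the chunks being combined, reindex each chunk's intermediates to that union, and then reduce along a dummy concatenation axis instead of re-doing a groupby. The sketch below illustrates that idea with a deliberately simplified intermediate layout and helper names (reindex_to, simple_combine are illustrative, not the actual flox internals):

import numpy as np


def reindex_to(groups, values, target_groups, fill_value=0.0):
    # Place per-group `values` into an array aligned with sorted `target_groups`.
    out = np.full(len(target_groups), fill_value, dtype=values.dtype)
    out[np.searchsorted(target_groups, groups)] = values
    return out


def simple_combine(chunks):
    # Combine per-chunk {"groups", "sum"} intermediates by aligning, then summing.
    # 1. union of all groups seen across the chunks being combined
    union = np.unique(np.concatenate([c["groups"] for c in chunks]))
    # 2. reindex every chunk's intermediate to that union so the arrays line up
    aligned = [reindex_to(c["groups"], c["sum"], union) for c in chunks]
    # 3. stack along a dummy axis and apply a plain numpy reduction,
    #    instead of running another groupby over concatenated intermediates
    return {"groups": union, "sum": np.stack(aligned).sum(axis=0)}


chunk_a = {"groups": np.array([1, 2]), "sum": np.array([10.0, 20.0])}
chunk_b = {"groups": np.array([2, 3]), "sum": np.array([5.0, 7.0])}
print(simple_combine([chunk_a, chunk_b]))
# {'groups': array([1, 2, 3]), 'sum': array([10., 25.,  7.])}

Here missing groups are filled with 0 because the reduction is a sum; flox's Aggregation objects carry the appropriate fill value per reduction. For cohorts-style combines the chunks already share the same groups, so the reindex step is essentially a no-op, which is consistent with the improvement shown in the tables above.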

Fix bug in unique
dcherian committed Oct 20, 2022
1 parent c370b5d commit 394925e
Showing 3 changed files with 118 additions and 85 deletions.
29 changes: 19 additions & 10 deletions asv_bench/benchmarks/combine.py
@@ -1,3 +1,5 @@
from functools import partial

import numpy as np

import flox
@@ -7,26 +9,31 @@
N = 1000


def _get_combine(combine):
if combine == "grouped":
return partial(flox.core._grouped_combine, engine="numpy")
else:
return partial(flox.core._simple_combine, reindex=False)


class Combine:
def setup(self, *args, **kwargs):
raise NotImplementedError

@parameterized("kind", ("cohorts", "mapreduce"))
def time_combine(self, kind):
flox.core._grouped_combine(
@parameterized(("kind", "combine"), (("reindexed", "not_reindexed"), ("grouped", "simple")))
def time_combine(self, kind, combine):
_get_combine(combine)(
getattr(self, f"x_chunk_{kind}"),
**self.kwargs,
keepdims=True,
engine="numpy",
)

@parameterized("kind", ("cohorts", "mapreduce"))
def peakmem_combine(self, kind):
flox.core._grouped_combine(
@parameterized(("kind", "combine"), (("reindexed", "not_reindexed"), ("grouped", "simple")))
def peakmem_combine(self, kind, combine):
_get_combine(combine)(
getattr(self, f"x_chunk_{kind}"),
**self.kwargs,
keepdims=True,
engine="numpy",
)


@@ -47,7 +54,7 @@ def construct_member(groups):
}

# motivated by
self.x_chunk_mapreduce = [
self.x_chunk_not_reindexed = [
construct_member(groups)
for groups in [
np.array((1, 2, 3, 4)),
@@ -57,5 +64,7 @@ def construct_member(groups):
* 2
]

self.x_chunk_cohorts = [construct_member(groups) for groups in [np.array((1, 2, 3, 4))] * 4]
self.x_chunk_reindexed = [
construct_member(groups) for groups in [np.array((1, 2, 3, 4))] * 4
]
self.kwargs = {"agg": flox.aggregations.mean, "axis": (3,)}
103 changes: 59 additions & 44 deletions flox/core.py
@@ -816,8 +816,25 @@ def _expand_dims(results: IntermediateDict) -> IntermediateDict:
return results


def _find_unique_groups(x_chunk):
from dask.base import flatten
from dask.utils import deepmap

unique_groups = _unique(np.asarray(tuple(flatten(deepmap(listify_groups, x_chunk)))))
unique_groups = unique_groups[~isnull(unique_groups)]

if len(unique_groups) == 0:
unique_groups = [np.nan]
return unique_groups


def _simple_combine(
x_chunk, agg: Aggregation, axis: T_Axes, keepdims: bool, is_aggregate: bool = False
x_chunk,
agg: Aggregation,
axis: T_Axes,
keepdims: bool,
reindex: bool,
is_aggregate: bool = False,
) -> IntermediateDict:
"""
'Simple' combination of blockwise results.
@@ -830,8 +847,19 @@ def _simple_combine(
4. At the final aggregate step, we squeeze out DUMMY_AXIS
"""
from dask.array.core import deepfirst
from dask.utils import deepmap

if not reindex:
# We didn't reindex at the blockwise step
# So now reindex before combining by reducing along DUMMY_AXIS
unique_groups = _find_unique_groups(x_chunk)
x_chunk = deepmap(
partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk
)
else:
unique_groups = deepfirst(x_chunk)["groups"]

results: IntermediateDict = {"groups": deepfirst(x_chunk)["groups"]}
results: IntermediateDict = {"groups": unique_groups}
results["intermediates"] = []
axis_ = axis[:-1] + (DUMMY_AXIS,)
for idx, combine in enumerate(agg.combine):
@@ -886,7 +914,6 @@ def _grouped_combine(
sort: bool = True,
) -> IntermediateDict:
"""Combine intermediates step of tree reduction."""
from dask.base import flatten
from dask.utils import deepmap

if isinstance(x_chunk, dict):
@@ -897,11 +924,7 @@
# when there's only a single axis of reduction, we can just concatenate later,
# reindexing is unnecessary
# I bet we can minimize the amount of reindexing for mD reductions too, but it's complicated
unique_groups = _unique(np.array(tuple(flatten(deepmap(listify_groups, x_chunk)))))
unique_groups = unique_groups[~isnull(unique_groups)]
if len(unique_groups) == 0:
unique_groups = [np.nan]

unique_groups = _find_unique_groups(x_chunk)
x_chunk = deepmap(
partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk
)
Expand Down Expand Up @@ -1216,7 +1239,8 @@ def dask_groupby_agg(
# This allows us to discover groups at compute time, support argreductions, lower intermediate
# memory usage (but method="cohorts" would also work to reduce memory in some cases)

do_simple_combine = method != "blockwise" and reindex and not _is_arg_reduction(agg)
do_simple_combine = not _is_arg_reduction(agg)

if method == "blockwise":
# use the "non dask" code path, but applied blockwise
blockwise_method = partial(
Expand Down Expand Up @@ -1268,31 +1292,32 @@ def dask_groupby_agg(
if method in ["map-reduce", "cohorts"]:
combine: Callable[..., IntermediateDict]
if do_simple_combine:
combine = _simple_combine
combine = partial(_simple_combine, reindex=reindex)
combine_name = "simple-combine"
else:
combine = partial(_grouped_combine, engine=engine, sort=sort)
combine_name = "grouped-combine"

# Each chunk of `reduced` is really a dict mapping
# 1. reduction name to array
# 2. "groups" to an array of group labels
# Note: it does not make sense to interpret axis relative to
# shape of intermediate results after the blockwise call
tree_reduce = partial(
dask.array.reductions._tree_reduce,
combine=partial(combine, agg=agg),
name=f"{name}-reduce-{method}",
name=f"{name}-reduce-{method}-{combine_name}",
dtype=array.dtype,
axis=axis,
keepdims=True,
concatenate=False,
)
aggregate = partial(
_aggregate, combine=combine, agg=agg, fill_value=fill_value, reindex=reindex
)
aggregate = partial(_aggregate, combine=combine, agg=agg, fill_value=fill_value)

# Each chunk of `reduced` is really a dict mapping
# 1. reduction name to array
# 2. "groups" to an array of group labels
# Note: it does not make sense to interpret axis relative to
# shape of intermediate results after the blockwise call
if method == "map-reduce":
reduced = tree_reduce(
intermediate,
aggregate=partial(aggregate, expected_groups=expected_groups),
combine=partial(combine, agg=agg),
aggregate=partial(aggregate, expected_groups=expected_groups, reindex=reindex),
)
if is_duck_dask_array(by_input) and expected_groups is None:
groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
@@ -1310,23 +1335,17 @@
reduced_ = []
groups_ = []
for blks, cohort in chunks_cohorts.items():
index = pd.Index(cohort)
subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis) :])
if do_simple_combine:
# reindex so that reindex can be set to True later
reindexed = dask.array.map_blocks(
reindex_intermediates,
subset,
agg=agg,
unique_groups=cohort,
meta=subset._meta,
)
else:
reindexed = subset

reindexed = dask.array.map_blocks(
reindex_intermediates, subset, agg=agg, unique_groups=index, meta=subset._meta
)
# now that we have reindexed, we can set reindex=True explicitly
reduced_.append(
tree_reduce(
reindexed,
aggregate=partial(aggregate, expected_groups=cohort, reindex=reindex),
combine=partial(combine, agg=agg, reindex=True),
aggregate=partial(aggregate, expected_groups=index, reindex=True),
)
)
groups_.append(cohort)
@@ -1382,28 +1401,24 @@ def _validate_reindex(
if reindex is True:
if _is_arg_reduction(func):
raise NotImplementedError
if method == "blockwise":
raise NotImplementedError
if method in ["blockwise", "cohorts"]:
raise ValueError(
"reindex=True is not a valid choice for method='blockwise' or method='cohorts'."
)

if reindex is None:
if method == "blockwise" or _is_arg_reduction(func):
reindex = False

elif expected_groups is not None:
reindex = True

elif method in ["split-reduce", "cohorts"]:
reindex = True
elif method == "cohorts":
reindex = False

elif method == "map-reduce":
if expected_groups is None and by_is_dask:
reindex = False
else:
reindex = True

if method in ["split-reduce", "cohorts"] and reindex is False:
raise NotImplementedError

assert isinstance(reindex, bool)
return reindex

71 changes: 40 additions & 31 deletions tests/test_core.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import itertools
from functools import partial, reduce
from typing import TYPE_CHECKING

@@ -219,29 +220,31 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
assert actual.dtype.kind == "i"
assert_equal(actual, expected, tolerance)

if not has_dask:
if not has_dask or chunks is None:
continue
for method in ["map-reduce", "cohorts", "split-reduce"]:
if method == "map-reduce":
reindexes = [True, False, None]

params = list(itertools.product(["map-reduce"], [True, False, None]))
params.extend(itertools.product(["cohorts"], [False, None]))
for method, reindex in params:
call = partial(
groupby_reduce, array, *by, method=method, reindex=reindex, **flox_kwargs
)
if "arg" in func and reindex is True:
# simple_combine with argreductions not supported right now
with pytest.raises(NotImplementedError):
call()
continue
actual, *groups = call()
if "arg" not in func:
# make sure we use simple combine
assert any("simple-combine" in key for key in actual.dask.layers.keys())
else:
reindexes = [None]
for reindex in reindexes:
call = partial(
groupby_reduce, array, *by, method=method, reindex=reindex, **flox_kwargs
)
if "arg" in func:
if method != "map-reduce" or reindex is True:
with pytest.raises(NotImplementedError):
call()
continue

actual, *groups = call()
for actual_group, expect in zip(groups, expected_groups):
assert_equal(actual_group, expect, tolerance)
if "arg" in func:
assert actual.dtype.kind == "i"
assert_equal(actual, expected, tolerance)
assert any("grouped-combine" in key for key in actual.dask.layers.keys())
for actual_group, expect in zip(groups, expected_groups):
assert_equal(actual_group, expect, tolerance)
if "arg" in func:
assert actual.dtype.kind == "i"
assert_equal(actual, expected, tolerance)


@requires_dask
@@ -1140,7 +1143,6 @@ def test_subset_block_2d(flatblocks, expectidx):
assert_equal(subset, array.compute()[expectidx])


@pytest.mark.parametrize("method", ["map-reduce", "cohorts"])
@pytest.mark.parametrize(
"expected, reindex, func, expected_groups, by_is_dask",
[
@@ -1158,13 +1160,20 @@ def test_subset_block_2d(flatblocks, expectidx):
[True, None, "sum", ([1], None), True],
],
)
def test_validate_reindex(expected, reindex, func, method, expected_groups, by_is_dask):
if by_is_dask and method == "cohorts":
# This should error elsewhere
pytest.skip()
call = partial(_validate_reindex, reindex, func, method, expected_groups, by_is_dask)
if "arg" in func and method == "cohorts":
def test_validate_reindex_map_reduce(expected, reindex, func, expected_groups, by_is_dask):
actual = _validate_reindex(reindex, func, "map-reduce", expected_groups, by_is_dask)
assert actual == expected


def test_validate_reindex():
for method in ["map-reduce", "cohorts"]:
with pytest.raises(NotImplementedError):
call()
else:
assert call() == expected
_validate_reindex(True, "argmax", method, expected_groups=None, by_is_dask=False)

for method in ["blockwise", "cohorts"]:
with pytest.raises(ValueError):
_validate_reindex(True, "sum", method, expected_groups=None, by_is_dask=False)

for func in ["sum", "argmax"]:
actual = _validate_reindex(None, func, method, expected_groups=None, by_is_dask=False)
assert actual is False
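For reference, a usage sketch of how these paths are reached through the public API, mirroring the calls in the tests above. The method and reindex keyword names come from the test code; the rest of the call (flox.groupby_reduce signature, func="mean") is an assumption about the public interface:

# Usage sketch: exercising the simple-combine path from the public API.
# Keyword names follow the tests above; treat the exact signature as an assumption.
import dask.array as da
import numpy as np

import flox

array = da.random.random((4, 100), chunks=(4, 25))
by = np.random.randint(0, 5, size=100)

# map-reduce with reindex=True at the blockwise step uses _simple_combine,
# so the graph should contain "simple-combine" layer names (see the test assertion).
result, groups = flox.groupby_reduce(
    array, by, func="mean", method="map-reduce", reindex=True
)

# reindex=True is rejected for method="blockwise" and method="cohorts"
# (see _validate_reindex above); reindex=None picks the default instead.
result, groups = flox.groupby_reduce(array, by, func="mean", method="cohorts")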
