Skip to content

Commit

Permalink
Force reindex to be bool always (#176)
Browse files Browse the repository at this point in the history
* Force reindex to be bool always

Closes #155

It turns out we weren't using the more efficient simple_combine with
method="map-reduce" in all cases, because do_simple_combine was None when
reindex was None.

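Now the default for map-reduce is reindex=True when (expected_groups is not None)
or (expected_groups is None and by_is_dask is False). In the remaining case
(expected_groups is None and by is a dask array) the default is reindex=False,
so reindex is always a bool after validation.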
  • Loading branch information
dcherian authored Oct 19, 2022
1 parent 6897240 commit 47e0b38
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 20 deletions.
31 changes: 20 additions & 11 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1388,27 +1388,35 @@ def dask_groupby_agg(
return (result, groups)


def _validate_reindex(reindex: bool | None, func, method: T_Method, expected_groups) -> bool | None:
def _validate_reindex(
reindex: bool | None, func, method: T_Method, expected_groups, by_is_dask: bool
) -> bool:
if reindex is True:
if _is_arg_reduction(func):
raise NotImplementedError
if method == "blockwise":
raise NotImplementedError

if method == "blockwise" or _is_arg_reduction(func):
reindex = False
if reindex is None:
if method == "blockwise" or _is_arg_reduction(func):
reindex = False

if reindex is None and expected_groups is not None:
reindex = True
elif expected_groups is not None:
reindex = True

elif method in ["split-reduce", "cohorts"]:
reindex = True

elif method == "map-reduce":
if expected_groups is None and by_is_dask:
reindex = False
else:
reindex = True

if method in ["split-reduce", "cohorts"] and reindex is False:
raise NotImplementedError

if method in ["split-reduce", "cohorts"] and reindex is None:
reindex = True

# TODO: Should reindex be a bool-only at this point? Would've been nice but
# None's are relied on after this function as well.
assert isinstance(reindex, bool)
return reindex


Expand Down Expand Up @@ -1597,7 +1605,6 @@ def groupby_reduce(
"argreductions not supported for engine='flox' yet."
"Try engine='numpy' or engine='numba' instead."
)
reindex = _validate_reindex(reindex, func, method, expected_groups)

bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
nby = len(bys)
Expand All @@ -1606,6 +1613,8 @@ def groupby_reduce(
if method in ["split-reduce", "cohorts"] and by_is_dask:
raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")

reindex = _validate_reindex(reindex, func, method, expected_groups, by_is_dask)

if not is_duck_array(array):
array = np.asarray(array)
is_bool_array = np.issubdtype(array.dtype, bool)
Expand Down
61 changes: 52 additions & 9 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from functools import reduce
from functools import partial, reduce
from typing import TYPE_CHECKING

import numpy as np
Expand All @@ -13,6 +13,7 @@
_convert_expected_groups_to_index,
_get_optimal_chunks_for_groups,
_normalize_indexes,
_validate_reindex,
factorize_,
find_group_cohorts,
groupby_reduce,
Expand Down Expand Up @@ -221,14 +222,26 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
if not has_dask:
continue
for method in ["map-reduce", "cohorts", "split-reduce"]:
if "arg" in func and method != "map-reduce":
continue
actual, *groups = groupby_reduce(array, *by, method=method, **flox_kwargs)
for actual_group, expect in zip(groups, expected_groups):
assert_equal(actual_group, expect, tolerance)
if "arg" in func:
assert actual.dtype.kind == "i"
assert_equal(actual, expected, tolerance)
if method == "map-reduce":
reindexes = [True, False, None]
else:
reindexes = [None]
for reindex in reindexes:
call = partial(
groupby_reduce, array, *by, method=method, reindex=reindex, **flox_kwargs
)
if "arg" in func:
if method != "map-reduce" or reindex is True:
with pytest.raises(NotImplementedError):
call()
continue

actual, *groups = call()
for actual_group, expect in zip(groups, expected_groups):
assert_equal(actual_group, expect, tolerance)
if "arg" in func:
assert actual.dtype.kind == "i"
assert_equal(actual, expected, tolerance)


@requires_dask
Expand Down Expand Up @@ -1125,3 +1138,33 @@ def test_subset_block_2d(flatblocks, expectidx):
subset = subset_to_blocks(array, flatblocks)
assert len(subset.dask.layers) == 2
assert_equal(subset, array.compute()[expectidx])


@pytest.mark.parametrize("method", ["map-reduce", "cohorts"])
@pytest.mark.parametrize(
    "expected, reindex, func, expected_groups, by_is_dask",
    [
        # arg reductions can only resolve to False
        [False, None, "argmax", None, False],
        # True when by is numpy but expected_groups is None
        [True, None, "sum", None, False],
        # False when by is dask but expected_groups is None
        [False, None, "sum", None, True],
        # if expected_groups is given then always True
        [True, None, "sum", [1, 2, 3], False],
        [True, None, "sum", ([1], [2]), False],
        [True, None, "sum", ([1], [2]), True],
        [True, None, "sum", ([1], None), False],
        [True, None, "sum", ([1], None), True],
    ],
)
def test_validate_reindex(expected, reindex, func, method, expected_groups, by_is_dask):
    """Check the value ``_validate_reindex`` resolves for each parametrized case."""
    if method == "cohorts" and by_is_dask:
        # This should error elsewhere
        pytest.skip()

    resolve = partial(_validate_reindex, reindex, func, method, expected_groups, by_is_dask)

    # An arg reduction with "cohorts" is expected to be rejected outright.
    must_raise = method == "cohorts" and "arg" in func
    if not must_raise:
        assert resolve() == expected
        return

    with pytest.raises(NotImplementedError):
        resolve()

0 comments on commit 47e0b38

Please sign in to comment.