Skip to content

Commit

Permalink
Force reindex to be bool always
Browse files Browse the repository at this point in the history
Closes #155

Turns out we weren't using the more efficient simple_combine with
map_reduce in all cases because do_simple_combine was None when reindex
was None.

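Now the default for map-reduce (i.e. when reindex is not passed explicitly) is
1. reindex=True when (expected_groups is not None)
   or (expected_groups is None and by_is_dask is False)
2. reindex=False otherwise, i.e. when expected_groups is None and by_is_dask is True
   (arg reductions still default to reindex=False in every case)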
  • Loading branch information
dcherian committed Oct 19, 2022
1 parent 6897240 commit 5348d47
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 17 deletions.
31 changes: 20 additions & 11 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1388,27 +1388,35 @@ def dask_groupby_agg(
return (result, groups)


def _validate_reindex(reindex: bool | None, func, method: T_Method, expected_groups) -> bool | None:
def _validate_reindex(
reindex: bool | None, func, method: T_Method, expected_groups, by_is_dask: bool
) -> bool:
if reindex is True:
if _is_arg_reduction(func):
raise NotImplementedError
if method == "blockwise":
raise NotImplementedError

if method == "blockwise" or _is_arg_reduction(func):
reindex = False
if reindex is None:
if method == "blockwise" or _is_arg_reduction(func):
reindex = False

if reindex is None and expected_groups is not None:
reindex = True
elif expected_groups is not None:
reindex = True

elif method in ["split-reduce", "cohorts"]:
reindex = True

elif method == "map-reduce":
if expected_groups is None and by_is_dask:
reindex = False
else:
reindex = True

if method in ["split-reduce", "cohorts"] and reindex is False:
raise NotImplementedError

if method in ["split-reduce", "cohorts"] and reindex is None:
reindex = True

# TODO: Should reindex be a bool-only at this point? Would've been nice but
# None's are relied on after this function as well.
assert isinstance(reindex, bool)
return reindex


Expand Down Expand Up @@ -1597,7 +1605,6 @@ def groupby_reduce(
"argreductions not supported for engine='flox' yet."
"Try engine='numpy' or engine='numba' instead."
)
reindex = _validate_reindex(reindex, func, method, expected_groups)

bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
nby = len(bys)
Expand All @@ -1606,6 +1613,8 @@ def groupby_reduce(
if method in ["split-reduce", "cohorts"] and by_is_dask:
raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")

reindex = _validate_reindex(reindex, func, method, expected_groups, by_is_dask)

if not is_duck_array(array):
array = np.asarray(array)
is_bool_array = np.issubdtype(array.dtype, bool)
Expand Down
19 changes: 13 additions & 6 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,12 +223,19 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
for method in ["map-reduce", "cohorts", "split-reduce"]:
if "arg" in func and method != "map-reduce":
continue
actual, *groups = groupby_reduce(array, *by, method=method, **flox_kwargs)
for actual_group, expect in zip(groups, expected_groups):
assert_equal(actual_group, expect, tolerance)
if "arg" in func:
assert actual.dtype.kind == "i"
assert_equal(actual, expected, tolerance)
if method == "map-reduce":
reindexes = [True, False, None]
else:
reindexes = [None]
for reindex in reindexes:
actual, *groups = groupby_reduce(
array, *by, method=method, reindex=reindex, **flox_kwargs
)
for actual_group, expect in zip(groups, expected_groups):
assert_equal(actual_group, expect, tolerance)
if "arg" in func:
assert actual.dtype.kind == "i"
assert_equal(actual, expected, tolerance)


@requires_dask
Expand Down

0 comments on commit 5348d47

Please sign in to comment.