diff --git a/flox/core.py b/flox/core.py
index 344e2b822..b67541ddb 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -736,13 +736,6 @@ def _squeeze_results(results: IntermediateDict, axis: T_Axes) -> IntermediateDict
     return newresults


-def _split_groups(array, j, slicer):
-    """Slices out chunks when split_out > 1"""
-    results = {"groups": array["groups"][..., slicer]}
-    results["intermediates"] = [v[..., slicer] for v in array["intermediates"]]
-    return results
-
-
 def _finalize_results(
     results: IntermediateDict,
     agg: Aggregation,
@@ -997,38 +990,6 @@ def _grouped_combine(
     return results


-def split_blocks(applied, split_out, expected_groups, split_name):
-    import dask.array
-    from dask.array.core import normalize_chunks
-    from dask.highlevelgraph import HighLevelGraph
-
-    chunk_tuples = tuple(itertools.product(*tuple(range(n) for n in applied.numblocks)))
-    ngroups = len(expected_groups)
-    group_chunks = normalize_chunks(np.ceil(ngroups / split_out), (ngroups,))
-    idx = tuple(np.cumsum((0,) + group_chunks[0]))
-
-    # split each block into `split_out` chunks
-    dsk = {}
-    for i in chunk_tuples:
-        for j in range(split_out):
-            dsk[(split_name, *i, j)] = (
-                _split_groups,
-                (applied.name, *i),
-                j,
-                slice(idx[j], idx[j + 1]),
-            )
-
-    # now construct an array that can be passed to _tree_reduce
-    intergraph = HighLevelGraph.from_collections(split_name, dsk, dependencies=(applied,))
-    intermediate = dask.array.Array(
-        intergraph,
-        name=split_name,
-        chunks=applied.chunks + ((1,) * split_out,),
-        meta=applied._meta,
-    )
-    return intermediate, group_chunks
-
-
 def _reduce_blockwise(
     array,
     by,
@@ -1169,7 +1130,6 @@ def dask_groupby_agg(
     agg: Aggregation,
     expected_groups: pd.Index | None,
     axis: T_Axes = (),
-    split_out: int = 1,
     fill_value: Any = None,
     method: T_Method = "map-reduce",
     reindex: bool = False,
@@ -1186,19 +1146,14 @@ def dask_groupby_agg(
     assert isinstance(axis, Sequence)
     assert all(ax >= 0 for ax in axis)

-    if method == "blockwise" and (split_out > 1 or not isinstance(by, np.ndarray)):
-        raise NotImplementedError
-
-    if split_out > 1 and expected_groups is None:
-        # This could be implemented using the "hash_split" strategy
-        # from dask.dataframe
+    if method == "blockwise" and not isinstance(by, np.ndarray):
         raise NotImplementedError

     inds = tuple(range(array.ndim))
     name = f"groupby_{agg.name}"
-    token = dask.base.tokenize(array, by, agg, expected_groups, axis, split_out)
+    token = dask.base.tokenize(array, by, agg, expected_groups, axis)

-    if expected_groups is None and (reindex or split_out > 1):
+    if expected_groups is None and reindex:
         expected_groups = _get_expected_groups(by, sort=sort)

     by_input = by
@@ -1229,9 +1184,7 @@ def dask_groupby_agg(
     # This allows us to discover groups at compute time, support argreductions, lower intermediate
     # memory usage (but method="cohorts" would also work to reduce memory in some cases)
-    do_simple_combine = (
-        method != "blockwise" and reindex and not _is_arg_reduction(agg) and split_out == 1
-    )
+    do_simple_combine = method != "blockwise" and reindex and not _is_arg_reduction(agg)

     if method == "blockwise":
         # use the "non dask" code path, but applied blockwise
         blockwise_method = partial(
@@ -1244,14 +1197,14 @@ def dask_groupby_agg(
             func=agg.chunk,
             fill_value=agg.fill_value["intermediate"],
             dtype=agg.dtype["intermediate"],
-            reindex=reindex or (split_out > 1),
+            reindex=reindex,
         )
         if do_simple_combine:
             # Add a dummy dimension that then gets reduced over
             blockwise_method = tlz.compose(_expand_dims, blockwise_method)

     # apply reduction on chunk
-    applied = dask.array.blockwise(
+    intermediate = dask.array.blockwise(
         partial(
             blockwise_method,
             axis=axis,
@@ -1271,18 +1224,14 @@ def dask_groupby_agg(
         token=f"{name}-chunk-{token}",
     )

-    if split_out > 1:
-        intermediate, group_chunks = split_blocks(
-            applied, split_out, expected_groups, split_name=f"{name}-split-{token}"
-        )
-    else:
-        intermediate = applied
-        if expected_groups is None:
-            if is_duck_dask_array(by_input):
-                expected_groups = None
-            else:
-                expected_groups = _get_expected_groups(by_input, sort=sort)
-        group_chunks = ((len(expected_groups),) if expected_groups is not None else (np.nan,),)
+    if expected_groups is None:
+        if is_duck_dask_array(by_input):
+            expected_groups = None
+        else:
+            expected_groups = _get_expected_groups(by_input, sort=sort)
+    group_chunks: tuple[tuple[Union[int, float], ...]] = (
+        (len(expected_groups),) if expected_groups is not None else (np.nan,),
+    )

     if method in ["map-reduce", "cohorts", "split-reduce"]:
         combine: Callable[..., IntermediateDict]
@@ -1311,9 +1260,7 @@ def dask_groupby_agg(
     if method == "map-reduce":
         reduced = tree_reduce(
             intermediate,
-            aggregate=partial(
-                aggregate, expected_groups=None if split_out > 1 else expected_groups
-            ),
+            aggregate=partial(aggregate, expected_groups=expected_groups),
         )
         if is_duck_dask_array(by_input) and expected_groups is None:
             groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
@@ -1380,7 +1327,7 @@ def dask_groupby_agg(
         raise ValueError(f"Unknown method={method}.")

     # extract results from the dict
-    output_chunks = reduced.chunks[: -(len(axis) + int(split_out > 1))] + group_chunks
+    output_chunks = reduced.chunks[: -len(axis)] + group_chunks
     ochunks = tuple(range(len(chunks_v)) for chunks_v in output_chunks)
     layer2: dict[tuple, tuple] = {}
     agg_name = f"{name}-{token}"
@@ -1392,10 +1339,7 @@ def dask_groupby_agg(
             nblocks = tuple(len(array.chunks[ax]) for ax in axis)
             inchunk = ochunk[:-1] + np.unravel_index(ochunk[-1], nblocks)
         else:
-            inchunk = ochunk[:-1] + (0,) * (len(axis) - 1)
-            if split_out > 1:
-                inchunk = inchunk + (0,)
-            inchunk = inchunk + (ochunk[-1],)
+            inchunk = ochunk[:-1] + (0,) * (len(axis) - 1) + (ochunk[-1],)

         layer2[(agg_name, *ochunk)] = (operator.getitem, (reduced.name, *inchunk), agg.name)

@@ -1516,7 +1460,6 @@ def groupby_reduce(
     fill_value=None,
     dtype: np.typing.DTypeLike = None,
     min_count: int | None = None,
-    split_out: int = 1,
    method: T_Method = "map-reduce",
    engine: T_Engine = "numpy",
    reindex: bool | None = None,
@@ -1555,8 +1498,6 @@ def groupby_reduce(
        fewer than min_count non-NA values are present the result will be NA. Only used if
        skipna is set to True or defaults to True for the array's dtype.
-    split_out : int, optional
-        Number of chunks along group axis in output (last axis)
    method : {"map-reduce", "blockwise", "cohorts", "split-reduce"}, optional
        Strategy for reduction of dask arrays only:
          * ``"map-reduce"``:
@@ -1750,7 +1691,7 @@ def groupby_reduce(
        if kwargs["fill_value"] is None:
            kwargs["fill_value"] = agg.fill_value[agg.name]

-        partial_agg = partial(dask_groupby_agg, split_out=split_out, **kwargs)
+        partial_agg = partial(dask_groupby_agg, **kwargs)

        if method == "blockwise" and by_.ndim == 1:
            array = rechunk_for_blockwise(array, axis=-1, labels=by_)
diff --git a/flox/xarray.py b/flox/xarray.py
index 100d5fb4a..55eefd812 100644
--- a/flox/xarray.py
+++ b/flox/xarray.py
@@ -62,7 +62,6 @@ def xarray_reduce(
    isbin: bool | Sequence[bool] = False,
    sort: bool = True,
    dim: Dims | ellipsis = None,
-    split_out: int = 1,
    fill_value=None,
    dtype: np.typing.DTypeLike = None,
    method: str = "map-reduce",
@@ -95,8 +94,6 @@ def xarray_reduce(
    dim : hashable
        dimension name along which to reduce. If None, reduces across all
        dimensions of `by`
-    split_out : int, optional
-        Number of output chunks along grouped dimension in output.
    fill_value
        Value used for missing groups in the output i.e. when one of the labels
        in ``expected_groups`` is not actually present in ``by``.
@@ -397,7 +394,6 @@ def wrapper(array, *by, func, skipna, **kwargs):
        "func": func,
        "axis": axis,
        "sort": sort,
-        "split_out": split_out,
        "fill_value": fill_value,
        "method": method,
        "min_count": min_count,
diff --git a/tests/test_core.py b/tests/test_core.py
index 53a71f808..f9d412182 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -79,7 +79,7 @@ def test_alignment_error():


 @pytest.mark.parametrize("dtype", (float, int))
-@pytest.mark.parametrize("chunk, split_out", [(False, 1), (True, 1), (True, 2), (True, 3)])
+@pytest.mark.parametrize("chunk", [False, True])
 @pytest.mark.parametrize("expected_groups", [None, [0, 1, 2], np.array([0, 1, 2])])
 @pytest.mark.parametrize(
     "func, array, by, expected",
@@ -114,7 +114,6 @@ def test_groupby_reduce(
    expected: list[float],
    expected_groups: T_ExpectedGroupsOpt,
    chunk: bool,
-    split_out: int,
    dtype: np.typing.DTypeLike,
 ) -> None:
    array = array.astype(dtype)
@@ -137,7 +136,6 @@ def test_groupby_reduce(
        func=func,
        expected_groups=expected_groups,
        fill_value=123,
-        split_out=split_out,
        engine=engine,
    )
    g_dtype = by.dtype if expected_groups is None else np.asarray(expected_groups).dtype
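Note for downstream callers (not part of the patch): a minimal sketch of the API after this change. With `split_out` removed from `groupby_reduce`, the grouped axis of a dask result arrives as a single chunk; where the old `split_out > 1` layout is still wanted, rechunking the result is one possible workaround. The arrays and labels below are invented for illustration, and the final `rechunk` call is a suggested substitute, not something this diff provides.

    import dask.array as da
    import numpy as np
    from flox import groupby_reduce

    array = da.from_array(np.arange(12.0), chunks=4)
    by = np.array([0, 1, 2] * 4)  # three group labels (illustrative data)

    # With split_out removed, the reduced output has one chunk along the
    # group axis.
    result, groups = groupby_reduce(array, by, func="sum", fill_value=0)
    assert result.chunks == ((3,),)

    # Roughly the layout split_out=3 used to produce: one group per chunk.
    result = result.rechunk((1,))
    assert result.chunks == ((1, 1, 1),)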