diff --git a/cubed/core/ops.py b/cubed/core/ops.py index c2a4b76e..7746616d 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -926,24 +926,16 @@ def reduction_new( if intermediate_dtype is None: intermediate_dtype = dtype - inds = tuple(range(x.ndim)) - - result = x + split_every = _normalize_split_every(split_every, axis) - # reduce initial chunks - args = (result, inds) - adjust_chunks = { - i: (1,) * len(c) if i in axis else c for i, c in enumerate(result.chunks) - } - result = blockwise( - func, - inds, - *args, - axis=axis, - keepdims=True, + result = partial_reduce( + x, + partial(combine_func, **(extra_func_kwargs or {})), + initial_func=partial( + func, axis=axis, keepdims=True, **(extra_func_kwargs or {}) + ), + split_every=split_every, dtype=intermediate_dtype, - adjust_chunks=adjust_chunks, - extra_func_kwargs=extra_func_kwargs, ) # combine intermediates @@ -952,6 +944,7 @@ def reduction_new( partial(combine_func, **(extra_func_kwargs or {})), axis=axis, dtype=intermediate_dtype, + split_every=split_every, ) # aggregate final chunks @@ -1012,8 +1005,23 @@ def tree_reduce( return x -def partial_reduce(x, func, split_every, dtype=None): - """Apply a reduction function to multiple blocks across multiple axes.""" +def partial_reduce(x, func, initial_func=None, split_every=None, dtype=None): + """Apply a reduction function to multiple blocks across multiple axes. + + Parameters + ---------- + x: Array + Array being reduced along one or more axes + func: callable + Reduction function to apply to each chunk of data, resulting in a chunk + with size one in each of the reduction axes. + initial_func: callable, optional + Function to apply to each chunk of data before reduction. + split_every: int >= 2 or dict(axis: int), optional + The depth of the recursive aggregation. + dtype: DType + Output data type. + """ # map over output chunks chunks = [ (1,) * math.ceil(len(c) / split_every[i]) if i in split_every else c @@ -1046,14 +1054,17 @@ def block_function(out_key): chunks=chunks, extra_projected_mem=0, reduce_func=func, + initial_func=initial_func, axis=axis, ) -def _partial_reduce(arrays, reduce_func=None, axis=None): +def _partial_reduce(arrays, reduce_func=None, initial_func=None, axis=None): # reduce each array in turn, accumulating in result result = None for array in arrays: + if initial_func is not None: + array = initial_func(array) reduced_chunk = reduce_func(array, axis=axis, keepdims=True) if result is None: result = reduced_chunk