Add optional initial_func to partial_reduce to reduce initial chunks in same task
tomwhite committed Jan 29, 2024
1 parent abc0947 commit 39487d5
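
The change drops the separate blockwise step that previously reduced every input chunk before the combine stage: the per-chunk reduction is now passed to partial_reduce as initial_func, so it runs in the same task as the first round of combining. The snippet below is a minimal, numpy-only sketch of that idea — it is not cubed code; chunks, initial_func, combine_func and split_every are illustrative stand-ins that mirror the names in the diff.

from functools import partial

import numpy as np

# four 2x2 input chunks of an (8, 2) array, to be summed along axis 0
chunks = [np.arange(i * 4, (i + 1) * 4).reshape(2, 2) for i in range(4)]
split_every = 2
initial_func = partial(np.sum, axis=0, keepdims=True)  # per-chunk reduce
combine_func = partial(np.sum, axis=0, keepdims=True)  # combine intermediates

# one first-level "task" per group of split_every chunks: each task now reduces
# its raw input chunks (initial_func) and immediately combines the results
# (combine_func), instead of relying on an earlier blockwise pass over every chunk
intermediates = []
for i in range(0, len(chunks), split_every):
    group = chunks[i : i + split_every]
    reduced = [initial_func(c) for c in group]
    intermediates.append(combine_func(np.concatenate(reduced, axis=0)))

print(np.concatenate(intermediates, axis=0))  # two rows: [[12 16], [44 48]]
# later partial_reduce levels would keep combining until a single chunk remains
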
Showing 1 changed file with 30 additions and 19 deletions.
49 changes: 30 additions & 19 deletions cubed/core/ops.py
@@ -926,24 +926,16 @@ def reduction_new(
    if intermediate_dtype is None:
        intermediate_dtype = dtype

    inds = tuple(range(x.ndim))

    result = x
    split_every = _normalize_split_every(split_every, axis)

    # reduce initial chunks
    args = (result, inds)
    adjust_chunks = {
        i: (1,) * len(c) if i in axis else c for i, c in enumerate(result.chunks)
    }
    result = blockwise(
        func,
        inds,
        *args,
        axis=axis,
        keepdims=True,
    result = partial_reduce(
        x,
        partial(combine_func, **(extra_func_kwargs or {})),
        initial_func=partial(
            func, axis=axis, keepdims=True, **(extra_func_kwargs or {})
        ),
        split_every=split_every,
        dtype=intermediate_dtype,
        adjust_chunks=adjust_chunks,
        extra_func_kwargs=extra_func_kwargs,
    )

    # combine intermediates
@@ -952,6 +944,7 @@ def reduction_new(
        partial(combine_func, **(extra_func_kwargs or {})),
        axis=axis,
        dtype=intermediate_dtype,
        split_every=split_every,
    )

# aggregate final chunks
@@ -1012,8 +1005,23 @@ def tree_reduce(
    return x


def partial_reduce(x, func, split_every, dtype=None):
    """Apply a reduction function to multiple blocks across multiple axes."""
def partial_reduce(x, func, initial_func=None, split_every=None, dtype=None):
    """Apply a reduction function to multiple blocks across multiple axes.

    Parameters
    ----------
    x: Array
        Array being reduced along one or more axes
    func: callable
        Reduction function to apply to each chunk of data, resulting in a chunk
        with size one in each of the reduction axes.
    initial_func: callable, optional
        Function to apply to each chunk of data before reduction.
    split_every: int >= 2 or dict(axis: int), optional
        The depth of the recursive aggregation.
    dtype: DType
        Output data type.
    """
    # map over output chunks
    chunks = [
        (1,) * math.ceil(len(c) / split_every[i]) if i in split_every else c
@@ -1046,14 +1054,17 @@ def block_function(out_key):
        chunks=chunks,
        extra_projected_mem=0,
        reduce_func=func,
        initial_func=initial_func,
        axis=axis,
    )


def _partial_reduce(arrays, reduce_func=None, axis=None):
def _partial_reduce(arrays, reduce_func=None, initial_func=None, axis=None):
    # reduce each array in turn, accumulating in result
    result = None
    for array in arrays:
        if initial_func is not None:
            array = initial_func(array)
        reduced_chunk = reduce_func(array, axis=axis, keepdims=True)
        if result is None:
            result = reduced_chunk

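For reference, here is a simplified, numpy-only sketch of the per-task flow that _partial_reduce now implements: initial_func reduces each raw chunk first, then reduce_func folds the already-reduced chunks into a running result, all inside a single task. The accumulation branch below is an assumption (the hunk above is truncated before that point) and the helper name is illustrative; it is not cubed's implementation.

from functools import partial

import numpy as np


def _partial_reduce_sketch(arrays, reduce_func, initial_func=None, axis=0):
    result = None
    for array in arrays:
        if initial_func is not None:
            array = initial_func(array)  # reduce the raw chunk in the same task
        reduced_chunk = reduce_func(array, axis=axis, keepdims=True)
        if result is None:
            result = reduced_chunk
        else:
            # assumed accumulation step: combine the running result with the
            # newly reduced chunk
            result = reduce_func(
                np.concatenate([result, reduced_chunk], axis=axis),
                axis=axis,
                keepdims=True,
            )
    return result


chunks = [np.arange(6).reshape(2, 3), np.arange(6, 12).reshape(2, 3)]
out = _partial_reduce_sketch(
    chunks, np.sum, initial_func=partial(np.sum, axis=0, keepdims=True)
)
print(out)  # [[18 22 26]] — column sums across both chunks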