Skip to content

Commit

Permalink
Get extra_projected_mem right for partial_reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Jan 29, 2024
1 parent 39487d5 commit 9f11558
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,7 +1033,7 @@ def partial_reduce(x, func, initial_func=None, split_every=None, dtype=None):
def block_function(out_key):
out_coords = out_key[1:]

# return a tuple with a single item that is the list of input keys to be merged
# return a tuple with a single item that is an iterator of input keys to be merged
in_keys = [
list(
range(
Expand All @@ -1043,7 +1043,13 @@ def block_function(out_key):
)
for i, bi in enumerate(out_coords)
]
return ([(x.name,) + tuple(p) for p in product(*in_keys)],)
return (iter([(x.name,) + tuple(p) for p in product(*in_keys)]),)

# Since block_function returns an iterator input keys, the the array chunks passed to
# _partial_reduce are retrieved one at a time. However, we need an extra chunk of memory
# to stay within limits (maybe because the iterator doesn't free the previous object
# before getting the next).
extra_projected_mem = x.chunkmem

return general_blockwise(
_partial_reduce,
Expand All @@ -1052,7 +1058,7 @@ def block_function(out_key):
shape=shape,
dtype=dtype,
chunks=chunks,
extra_projected_mem=0,
extra_projected_mem=extra_projected_mem,
reduce_func=func,
initial_func=initial_func,
axis=axis,
Expand Down

0 comments on commit 9f11558

Please sign in to comment.