From 9f11558f69c217dacf2004f116dcbdecc55ee161 Mon Sep 17 00:00:00 2001 From: Tom White Date: Thu, 25 Jan 2024 15:27:09 +0000 Subject: [PATCH] Get extra_projected_mem right for partial_reduce --- cubed/core/ops.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/cubed/core/ops.py b/cubed/core/ops.py index 7746616df..123be2b8a 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -1033,7 +1033,7 @@ def partial_reduce(x, func, initial_func=None, split_every=None, dtype=None): def block_function(out_key): out_coords = out_key[1:] - # return a tuple with a single item that is the list of input keys to be merged + # return a tuple with a single item that is an iterator of input keys to be merged in_keys = [ list( range( @@ -1043,7 +1043,13 @@ def block_function(out_key): ) for i, bi in enumerate(out_coords) ] - return ([(x.name,) + tuple(p) for p in product(*in_keys)],) + return (iter([(x.name,) + tuple(p) for p in product(*in_keys)]),) + + # Since block_function returns an iterator of input keys, the array chunks passed to + # _partial_reduce are retrieved one at a time. However, we need an extra chunk of memory + # to stay within limits (maybe because the iterator doesn't free the previous object + # before getting the next). + extra_projected_mem = x.chunkmem return general_blockwise( _partial_reduce, @@ -1052,7 +1058,7 @@ def block_function(out_key): shape=shape, dtype=dtype, chunks=chunks, - extra_projected_mem=0, + extra_projected_mem=extra_projected_mem, reduce_func=func, initial_func=initial_func, axis=axis,