diff --git a/cubed/core/ops.py b/cubed/core/ops.py index e6d1ebc9..694325b9 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -5,7 +5,7 @@ from itertools import product from numbers import Integral, Number from operator import add -from typing import TYPE_CHECKING, Any, Sequence, Union +from typing import TYPE_CHECKING, Any, Callable, Sequence, Union from warnings import warn import ndindex @@ -22,6 +22,7 @@ from cubed.core.plan import Plan, new_temp_path from cubed.primitive.blockwise import blockwise as primitive_blockwise from cubed.primitive.blockwise import general_blockwise as primitive_general_blockwise +from cubed.primitive.blockwise import key_to_slices from cubed.primitive.rechunk import rechunk as primitive_rechunk from cubed.spec import spec_from_config from cubed.storage.backend import open_backend_array @@ -1442,3 +1443,120 @@ def smallest_blockdim(blockdims): m = ntd[0] out = ntd return out + + +def wrapper_binop( + out: np.ndarray, + left: Array, + right: Array, + *, + binop: Callable, + block_id: tuple[int, ...], + axis: int, + identity: Any, +) -> Array: + # print(type(out), out.shape) + # print(block_id) + # print("left", left) + # print("right", right) + left_slicer = key_to_slices(block_id, left) + right_slicer = list(left_slicer) + + # For the first block, we add the identity element + # For all other blocks `k`, we add the `k-1` element along `axis` + right_slicer[axis] = slice(block_id[axis] - 1, block_id[axis]) + right_slicer = tuple(right_slicer) + right_ = right[right_slicer] if block_id[axis] > 0 else identity + # print("left", left[left_slicer].shape) + # print("right", right_.shape) + return binop(left[left_slicer], right_) + + +def scan( + array: "Array", + func: Callable, + *, + preop: Callable, + binop: Callable, + identity: Any, + axis: int, + dtype=None, +) -> Array: + """ + Generic parallel scan. + + Parameters + ---------- + x: Cubed Array + func: callable + Scan or cumulative function like np.cumsum or np.cumprod + preop: callable + Function applied blockwise that reduces each block to a single value + along ``axis``. For ``np.cumsum`` this is ``np.sum`` and for ``np.cumprod`` this is ``np.prod``. + binop: callable + Associated binary operator like ``np.cumsum->add`` or ``np.cumprod->mul`` + identity: Any + Associated identity element more scan like 0 for ``np.cumsum`` and 1 for ``np.cumprod``. + axis: int + dtype: dtype + + Notes + ----- + This method uses a variant of the Blelloch (1989) alogrithm. + + Returns + ------- + Array + + See also + -------- + cumsum + cumprod + """ + # Blelloch (1990) out-of-core algorithm. + # 1. First, scan blockwise + scanned = blockwise(func, "ij", array, "ij", axis=axis) + # If there is only a single chunk, we can be done + if array.numblocks[-1] == 1: + return scanned + + # 2. Calculate the blockwise reduction using `preop` + # TODO: could also merge(1,2) by returning {"scan": np.cumsum(array), "preop": np.sum(array)} in `scanned` + reduced = blockwise( + preop, "ij", array, "ij", axis=axis, adjust_chunks={"j": 1}, keepdims=True + ) + + # 3. Now scan `reduced` to generate the increments for each block of `scanned`. + # Here we diverge from Blelloch, who runs a balanced tree algorithm to calculate the scan. + # Instead we generalize recursively apply the scan to `reduced`. + # 3a. First we merge to a decent intermediate chunksize since reduced.chunksize[axis] == 1 + new_chunksize = min(reduced.shape[axis], reduced.chunksize[axis] * 5) + new_chunks = reduced.chunksize[:-1] + (new_chunksize,) + merged = merge_chunks(reduced, new_chunks) + + # 3b. Recursively scan this merged array to generate the increment for each block of `scanned` + increment = scan( + merged, func, preop=preop, binop=binop, identity=identity, axis=axis + ) + + # 4. Back to Blelloch. Now that we have the increment, add it to the blocks of `scanned`. + # Use map_direct since the chunks of increment and scanned aren't aligned anymore. + assert increment.shape[axis] == scanned.numblocks[axis] + # 5. Bada-bing, bada-boom. + return map_direct( + partial(wrapper_binop, binop=binop, axis=axis, identity=identity), + scanned, + increment, + shape=scanned.shape, + dtype=scanned.dtype, + chunks=scanned.chunks, + extra_projected_mem=scanned.chunkmem * 2, # arbitrary + ) + + +# result = scan( +# array, preop=np.sum, func=np.cumsum, binop=np.add, identity=0, axis=-1 +# ) +# print(result) +# print(result.compute()) +# np.testing.assert_equal(result, np.cumsum(array.compute(), axis=-1))