Add scan. #531

Draft: wants to merge 1 commit into base: main
120 changes: 119 additions & 1 deletion cubed/core/ops.py
@@ -5,7 +5,7 @@
from itertools import product
from numbers import Integral, Number
from operator import add
from typing import TYPE_CHECKING, Any, Sequence, Union
from typing import TYPE_CHECKING, Any, Callable, Sequence, Union
from warnings import warn

import ndindex
@@ -22,6 +22,7 @@
from cubed.core.plan import Plan, new_temp_path
from cubed.primitive.blockwise import blockwise as primitive_blockwise
from cubed.primitive.blockwise import general_blockwise as primitive_general_blockwise
from cubed.primitive.blockwise import key_to_slices
from cubed.primitive.rechunk import rechunk as primitive_rechunk
from cubed.spec import spec_from_config
from cubed.storage.backend import open_backend_array
@@ -1442,3 +1443,120 @@ def smallest_blockdim(blockdims):
m = ntd[0]
out = ntd
return out


def wrapper_binop(
Review comment (Member):
Maybe call something like _scan_binop to link it to the scan implementation? I've been using a naming convention like that elsewhere in the file.

out: np.ndarray,
left: Array,
right: Array,
*,
binop: Callable,
block_id: tuple[int, ...],
axis: int,
identity: Any,
) -> np.ndarray:
left_slicer = key_to_slices(block_id, left)
right_slicer = list(left_slicer)

# For the first block, we add the identity element
# For all other blocks `k`, we add the `k-1` element along `axis`
right_slicer[axis] = slice(block_id[axis] - 1, block_id[axis])
right_slicer = tuple(right_slicer)
right_ = right[right_slicer] if block_id[axis] > 0 else identity
return binop(left[left_slicer], right_)


def scan(
array: "Array",
func: Callable,
*,
preop: Callable,
binop: Callable,
identity: Any,
axis: int,
dtype=None,
) -> Array:
"""
Generic parallel scan.

Parameters
----------
array: Cubed Array
func: callable
Scan or cumulative function like np.cumsum or np.cumprod
preop: callable
Function applied blockwise that reduces each block to a single value
along ``axis``. For ``np.cumsum`` this is ``np.sum`` and for ``np.cumprod`` this is ``np.prod``.
binop: callable
Binary operator associated with the scan, e.g. ``add`` for ``np.cumsum`` or ``mul`` for ``np.cumprod``.
identity: Any
Identity element for the scan, e.g. 0 for ``np.cumsum`` and 1 for ``np.cumprod``.
axis: int
dtype: dtype

Notes
-----
This method uses a variant of the Blelloch (1990) algorithm.

Returns
-------
Array

See also
--------
cumsum
cumprod
"""
# Blelloch (1990) out-of-core algorithm.
# 1. First, scan blockwise
scanned = blockwise(func, "ij", array, "ij", axis=axis)
Review comment (Member):
Using map_blocks would be simpler and avoid the 2D assumption

# If there is only a single chunk, we can be done
if array.numblocks[axis] == 1:
return scanned

# 2. Calculate the blockwise reduction using `preop`
# TODO: could also merge(1,2) by returning {"scan": np.cumsum(array), "preop": np.sum(array)} in `scanned`
reduced = blockwise(
preop, "ij", array, "ij", axis=axis, adjust_chunks={"j": 1}, keepdims=True
)

# 3. Now scan `reduced` to generate the increments for each block of `scanned`.
# Here we diverge from Blelloch, who runs a balanced tree algorithm to calculate the scan.
# Instead we generalize by recursively applying the scan to `reduced`.
# 3a. First we merge to a decent intermediate chunksize since reduced.chunksize[axis] == 1
new_chunksize = min(reduced.shape[axis], reduced.chunksize[axis] * 5)
Review comment (Author):
Need input here on choosing a new intermediate chunksize to rechunk to based on memory info.

Review comment (Member):
There are a couple of things to consider here: the number of chunks to combine at each stage, and the memory limits.

The first is like split_every in reduction, where the default is 4, although 6 or 8 may be better for larger workloads.

For the second, we should make sure the new chunksize is no larger than (x.spec.allowed_mem - x.spec.reserved_mem) // 4, where the factor of 4 comes about because of the {compressed, uncompressed} * {input, output} copies.

There is an error case where this memory constraint means the new chunksize is no larger than the existing one, so the computation can't proceed. The user can fix this either by reducing the chunksize or by increasing the memory. This is similar to this case (cubed/core/ops.py, lines 985 to 991 at 88c5dc4):

    # single axis: see how many result chunks fit in max_mem
    # factor of 4 is memory for {compressed, uncompressed} x {input, output}
    target_chunk_size = (max_mem - chunk_mem) // (chunk_mem * 4)
    if target_chunk_size <= 1:
        raise ValueError(
            f"Not enough memory for reduction. Increase allowed_mem ({allowed_mem}) or decrease chunk size"
        )
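The constraint the reviewer describes could be sketched roughly as follows. The helper name `target_scan_chunksize`, its parameters, and the element-count arithmetic are all illustrative assumptions, not cubed API; real code would budget by chunk memory, as the reduction snippet above does.

```python
def target_scan_chunksize(chunksize, shape_along_axis, itemsize,
                          allowed_mem, reserved_mem, split_every=4):
    # Hypothetical sketch: treat the chunk as one-dimensional for simplicity.
    # Factor of 4 covers {compressed, uncompressed} x {input, output} copies.
    max_elems = (allowed_mem - reserved_mem) // (4 * itemsize)
    # Grow the chunk by split_every blocks per stage, but stay within budget.
    new_chunksize = min(shape_along_axis, chunksize * split_every, max_elems)
    if new_chunksize <= chunksize and chunksize < shape_along_axis:
        raise ValueError(
            "Not enough memory for scan. Increase allowed_mem or decrease chunk size"
        )
    return new_chunksize
```

With a generous budget the chunk grows by `split_every`; with a tight budget the error case the reviewer mentions is raised instead.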

new_chunks = reduced.chunksize[:-1] + (new_chunksize,)
merged = merge_chunks(reduced, new_chunks)

# 3b. Recursively scan this merged array to generate the increment for each block of `scanned`
increment = scan(
merged, func, preop=preop, binop=binop, identity=identity, axis=axis
)

# 4. Back to Blelloch. Now that we have the increment, add it to the blocks of `scanned`.
# Use map_direct since the chunks of increment and scanned aren't aligned anymore.
assert increment.shape[axis] == scanned.numblocks[axis]
# 5. Bada-bing, bada-boom.
return map_direct(
partial(wrapper_binop, binop=binop, axis=axis, identity=identity),
scanned,
increment,
shape=scanned.shape,
dtype=scanned.dtype,
chunks=scanned.chunks,
extra_projected_mem=scanned.chunkmem * 2, # arbitrary
Review comment (Author):
Need input here too.

Review comment (Member):
This should be the memory allocated to read from the side inputs (scanned and increment here). We double the chunk memory to account for reading the compressed Zarr chunk, so the result would be

    extra_projected_mem=scanned.chunkmem * 2 + increment.chunkmem * 2

(There's an open issue #288 to make this a bit more transparent.)

)


# result = scan(
# array, preop=np.sum, func=np.cumsum, binop=np.add, identity=0, axis=-1
# )
# print(result)
# print(result.compute())
# np.testing.assert_equal(result, np.cumsum(array.compute(), axis=-1))
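The blocked scan implemented in the diff can be illustrated with a plain-NumPy sketch that mirrors the same steps on a 1-D array split into blocks. The helper name `blocked_cumsum` and the list-based blocking are illustrative only, not part of the cubed API:

```python
import numpy as np

def blocked_cumsum(x, blocksize):
    """Blelloch-style blocked scan: cumsum each block, scan the block
    totals, then add each block's offset back in."""
    blocks = [x[i:i + blocksize] for i in range(0, len(x), blocksize)]
    # 1. Scan each block independently (the blockwise `func` step).
    scanned = [np.cumsum(b) for b in blocks]
    # 2. Reduce each block to a single total (the `preop` step).
    totals = np.array([b.sum() for b in blocks])
    # 3. Exclusive scan of the totals gives each block's increment,
    #    starting from the identity element 0.
    increments = np.concatenate([[0], np.cumsum(totals)[:-1]])
    # 4. Apply `binop` (here, add) between each block and its increment.
    return np.concatenate([s + inc for s, inc in zip(scanned, increments)])

x = np.arange(10)
assert np.array_equal(blocked_cumsum(x, 3), np.cumsum(x))
```

Steps 1 and 2 correspond to the two `blockwise` calls in `scan`, step 3 to the recursive scan of `reduced`, and step 4 to the final `map_direct` with `wrapper_binop`.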