Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support arbitrary block index functions in blockwise #350

Merged
merged 1 commit into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from cubed.core.array import CoreArray, check_array_specs, compute, gensym
from cubed.core.plan import Plan, new_temp_path
from cubed.primitive.blockwise import blockwise as primitive_blockwise
from cubed.primitive.blockwise import general_blockwise as primitive_general_blockwise
from cubed.primitive.rechunk import rechunk as primitive_rechunk
from cubed.utils import chunk_memory, get_item, to_chunksize
from cubed.vendor.dask.array.core import common_blockdim, normalize_chunks
Expand Down Expand Up @@ -295,6 +296,60 @@ def blockwise(
return Array(name, pipeline.target_array, spec, plan)


def general_blockwise(
    func,
    block_function,
    *arrays,
    shape,
    dtype,
    chunks,
    target_store=None,
    extra_func_kwargs=None,
    **kwargs,
) -> "Array":
    """Apply ``func`` across blocks of the input arrays to produce a new array.

    Unlike ``blockwise``, the block mapping is given directly by
    ``block_function``, which maps an output chunk index to the input chunk
    indexes it needs.

    Parameters
    ----------
    func : callable
        Function applied to tuples of input blocks to produce an output block.
    block_function : callable
        Maps an output chunk index to one or more input chunk indexes.
    *arrays : Array
        Input cubed arrays; at least one is required.
    shape, dtype, chunks
        Shape, dtype and chunking of the output array.
    target_store : optional
        Output Zarr store; a new temporary path is created when omitted.
    extra_func_kwargs : dict, optional
        Extra keyword arguments forwarded to ``func``.
    **kwargs
        Forwarded to the primitive; ``extra_source_arrays`` and
        ``extra_projected_mem`` are consumed here.

    Returns
    -------
    Array
        The lazily-computed output array.
    """
    assert len(arrays) > 0

    # The primitive layer works on (possibly lazy) zarr arrays, not cubed Arrays.
    zarr_args = [array.zarray_maybe_lazy for array in arrays]
    input_names = [array.name for array in arrays]

    # Extra arrays are tracked as plan dependencies without being passed to
    # the primitive as regular inputs.
    source_arrays = [*arrays, *kwargs.pop("extra_source_arrays", [])]
    extra_projected_mem = kwargs.pop("extra_projected_mem", 0)

    name = gensym()
    spec = check_array_specs(arrays)
    if target_store is None:
        target_store = new_temp_path(name=name, spec=spec)

    pipeline = primitive_general_blockwise(
        func,
        block_function,
        *zarr_args,
        allowed_mem=spec.allowed_mem,
        reserved_mem=spec.reserved_mem,
        extra_projected_mem=extra_projected_mem,
        target_store=target_store,
        shape=shape,
        dtype=dtype,
        chunks=chunks,
        in_names=input_names,
        extra_func_kwargs=extra_func_kwargs,
        **kwargs,
    )
    plan = Plan._new(
        name,
        "blockwise",
        pipeline.target_array,
        pipeline,
        False,
        *source_arrays,
    )

    # Local import — presumably avoids a circular import with cubed.array_api.
    from cubed.array_api import Array

    return Array(name, pipeline.target_array, spec, plan)


def elemwise(func, *args: "Array", dtype=None) -> "Array":
shapes = [arg.shape for arg in args]
out_ndim = len(np.broadcast_shapes(*shapes))
Expand Down
172 changes: 92 additions & 80 deletions cubed/primitive/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import math
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import toolz
import zarr
Expand All @@ -15,7 +15,7 @@
from cubed.runtime.types import CubedPipeline
from cubed.storage.zarr import T_ZarrArray, lazy_empty
from cubed.types import T_Chunks, T_DType, T_Shape, T_Store
from cubed.utils import chunk_memory, get_item, split_into, to_chunksize
from cubed.utils import chunk_memory, get_item, map_nested, split_into, to_chunksize
from cubed.vendor.dask.array.core import normalize_chunks
from cubed.vendor.dask.blockwise import _get_coord_mapping, _make_dims, lol_product
from cubed.vendor.dask.core import flatten
Expand All @@ -40,7 +40,7 @@ class BlockwiseSpec:
Attributes
----------
block_function : Callable
A function that maps input chunk indexes to an output chunk index.
A function that maps an output chunk index to one or more input chunk indexes.
function : Callable
A function that maps input chunks to an output chunk.
reads_map : Dict[str, CubedArrayProxy]
Expand All @@ -62,15 +62,13 @@ def apply_blockwise(out_key: List[int], *, config: BlockwiseSpec) -> None:
out_chunk_key = key_to_slices(
out_key_tuple, config.write.array, config.write.chunks
)

# get array chunks for input keys, preserving any nested list structure
args = []
get_chunk_config = partial(get_chunk, config=config)
name_chunk_inds = config.block_function(("out",) + out_key_tuple)
for name_chunk_ind in name_chunk_inds:
name = name_chunk_ind[0]
chunk_ind = name_chunk_ind[1:]
arr = config.reads_map[name].open()
chunk_key = key_to_slices(chunk_ind, arr)
arg = arr[chunk_key]
arg = numpy_array_to_backend_array(arg)
arg = map_nested(get_chunk_config, name_chunk_ind)
args.append(arg)

result = config.function(*args)
Expand All @@ -91,6 +89,17 @@ def key_to_slices(
return get_item(chunks, key)


def get_chunk(name_chunk_ind, config):
    """Read a single chunk from the named input array.

    ``name_chunk_ind`` is a tuple whose first element names the input array
    (a key of ``config.reads_map``) and whose remaining elements form the
    chunk index within that array.
    """
    array_name, *chunk_index = name_chunk_ind
    source = config.reads_map[array_name].open()
    selection = key_to_slices(tuple(chunk_index), source)
    # Convert to the active backend's array type before handing to func.
    return numpy_array_to_backend_array(source[selection])


def blockwise(
func: Callable[..., Any],
out_ind: Sequence[Union[str, int]],
Expand All @@ -110,6 +119,9 @@ def blockwise(
):
"""Apply a function across blocks from multiple source Zarr arrays.

    Unlike ``general_blockwise``, an index notation is used to specify the block mapping,
like in Dask Array.

Parameters
----------
func : callable
Expand Down Expand Up @@ -146,10 +158,8 @@ def blockwise(
CubedPipeline to run the operation
"""

# Use dask's make_blockwise_graph
arrays: Sequence[T_ZarrArray] = args[::2]
array_names = in_names or [f"in_{i}" for i in range(len(arrays))]
array_map = {name: array for name, array in zip(array_names, arrays)}

inds: Sequence[Union[str, int]] = args[1::2]

Expand All @@ -164,11 +174,6 @@ def blockwise(
for name, ind in zip(array_names, inds):
argindsstr.extend((name, ind))

# TODO: check output shape and chunks are consistent with inputs
chunks = normalize_chunks(chunks, shape=shape, dtype=dtype)

# block func

block_function = make_blockwise_function_flattened(
func,
out_name or "out",
Expand All @@ -178,28 +183,78 @@ def blockwise(
new_axes=new_axes,
)

output_blocks_generator_fn = partial(
get_output_blocks,
return general_blockwise(
func,
out_name or "out",
out_ind,
*argindsstr,
numblocks=numblocks,
new_axes=new_axes,
block_function,
*arrays,
allowed_mem=allowed_mem,
reserved_mem=reserved_mem,
target_store=target_store,
shape=shape,
dtype=dtype,
chunks=chunks,
in_names=in_names,
extra_projected_mem=extra_projected_mem,
extra_func_kwargs=extra_func_kwargs,
**kwargs,
)
output_blocks = IterableFromGenerator(output_blocks_generator_fn)

num_tasks = num_output_blocks(
func,
out_name or "out",
out_ind,
*argindsstr,
numblocks=numblocks,
new_axes=new_axes,
)

# end block func
def general_blockwise(
func: Callable[..., Any],
block_function: Callable[..., Any],
*arrays: Any,
allowed_mem: int,
reserved_mem: int,
target_store: T_Store,
shape: T_Shape,
dtype: T_DType,
chunks: T_Chunks,
in_names: Optional[List[str]] = None,
extra_projected_mem: int = 0,
extra_func_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
):
"""A more general form of ``blockwise`` that uses a function to specify the block
mapping, rather than an index notation.

Parameters
----------
func : callable
Function to apply to individual tuples of blocks
block_function : callable
A function that maps an output chunk index to one or more input chunk indexes.
*arrays : sequence of Array
The input arrays.
allowed_mem : int
The memory available to a worker for running a task, in bytes. Includes ``reserved_mem``.
reserved_mem : int
The memory reserved on a worker for non-data use when running a task, in bytes
target_store : string or zarr.Array
Path to output Zarr store, or Zarr array
shape : tuple
The shape of the output array.
dtype : np.dtype
The ``dtype`` of the output array.
chunks : tuple
The chunks of the output array.
extra_projected_mem : int
Extra memory projected to be needed (in bytes) in addition to the memory used reading
the input arrays and writing the output.
extra_func_kwargs : dict
Extra keyword arguments to pass to function that can't be passed as regular keyword arguments
since they clash with other blockwise arguments (such as dtype).
**kwargs : dict
Extra keyword arguments to pass to function

Returns
-------
CubedPipeline to run the operation
"""
array_names = in_names or [f"in_{i}" for i in range(len(arrays))]
array_map = {name: array for name, array in zip(array_names, arrays)}

chunks = normalize_chunks(chunks, shape=shape, dtype=dtype)
chunksize = to_chunksize(chunks)
if isinstance(target_store, zarr.Array):
target_array = target_store
Expand Down Expand Up @@ -236,6 +291,10 @@ def blockwise(
f"Projected blockwise memory ({projected_mem}) exceeds allowed_mem ({allowed_mem}), including reserved_mem ({reserved_mem})"
)

# this must be an iterator of lists, not of tuples, otherwise lithops breaks
output_blocks = map(list, itertools.product(*[range(len(c)) for c in chunks]))
num_tasks = math.prod(len(c) for c in chunks)

return CubedPipeline(
apply_blockwise,
gensym("apply_blockwise"),
Expand Down Expand Up @@ -474,50 +533,3 @@ def blockwise_fn_flattened(out_key):
return name_chunk_inds

return blockwise_fn_flattened


def get_output_blocks(
    func: Callable[..., Any],
    output: str,
    out_indices: Sequence[Union[str, int]],
    *arrind_pairs: Any,
    numblocks: Optional[Dict[str, Tuple[int, ...]]] = None,
    new_axes: Optional[Dict[int, int]] = None,
) -> Iterator[List[int]]:
    """Yield the chunk index of every output block, in row-major order."""
    if numblocks is None:
        raise ValueError("Missing required numblocks argument.")
    argpairs = list(toolz.partition(2, arrind_pairs))

    # Dictionary mapping {i: 3, j: 4, ...} for i, j, ... the dimensions
    dims = _make_dims(argpairs, numblocks, new_axes or {})

    # yield a list of lists, not of tuples, otherwise lithops breaks
    axis_ranges = (range(dims[i]) for i in out_indices)
    yield from map(list, itertools.product(*axis_ranges))


class IterableFromGenerator:
    """Wrap a zero-argument generator function as a reusable iterable.

    Each ``iter()`` call invokes ``generator_fn`` afresh, so the wrapped
    sequence can be traversed more than once (unlike a bare generator).
    """

    def __init__(self, generator_fn: Callable[[], Iterator[List[int]]]):
        # Store the factory; a fresh iterator is produced per __iter__ call.
        self.generator_fn = generator_fn

    def __iter__(self) -> Iterator[List[int]]:
        return self.generator_fn()


def num_output_blocks(
    func: Callable[..., Any],
    output: str,
    out_indices: Sequence[Union[str, int]],
    *arrind_pairs: Any,
    numblocks: Optional[Dict[str, Tuple[int, ...]]] = None,
    new_axes: Optional[Dict[int, int]] = None,
) -> int:
    """Return the total number of output blocks (the task count)."""
    if numblocks is None:
        raise ValueError("Missing required numblocks argument.")
    argpairs = list(toolz.partition(2, arrind_pairs))

    # Dictionary mapping {i: 3, j: 4, ...} for i, j, ... the dimensions
    dims = _make_dims(argpairs, numblocks, new_axes or {})
    return math.prod(dims[i] for i in out_indices)
Loading
Loading