Support arbitrary block index functions in blockwise (#350)
tomwhite authored Jan 18, 2024
1 parent 667cf64 commit 49df637
Showing 5 changed files with 297 additions and 81 deletions.
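The commit turns the index-notation ``blockwise`` into a thin wrapper around a new ``general_blockwise`` primitive that accepts an arbitrary ``block_function`` mapping each output chunk index to the input chunk indexes it depends on. A minimal sketch of that contract (the ``"out"``/``"in_0"`` names follow the defaults used in the diff below; the identity mapping itself is illustrative, not part of the commit):

def identity_block_function(out_key):
    # out_key is ("out", i, j, ...): the output array name followed by the
    # block index being computed. Return the input chunk keys that func
    # needs, as (input_name, *block_index) tuples.
    return [("in_0",) + out_key[1:]]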
55 changes: 55 additions & 0 deletions cubed/core/ops.py
@@ -22,6 +22,7 @@
from cubed.core.array import CoreArray, check_array_specs, compute, gensym
from cubed.core.plan import Plan, new_temp_path
from cubed.primitive.blockwise import blockwise as primitive_blockwise
from cubed.primitive.blockwise import general_blockwise as primitive_general_blockwise
from cubed.primitive.rechunk import rechunk as primitive_rechunk
from cubed.utils import chunk_memory, get_item, to_chunksize
from cubed.vendor.dask.array.core import common_blockdim, normalize_chunks
@@ -295,6 +296,60 @@ def blockwise(
return Array(name, pipeline.target_array, spec, plan)


def general_blockwise(
func,
block_function,
*arrays,
shape,
dtype,
chunks,
target_store=None,
extra_func_kwargs=None,
**kwargs,
) -> "Array":
assert len(arrays) > 0

# replace arrays with zarr arrays
zargs = [a.zarray_maybe_lazy for a in arrays]
in_names = [a.name for a in arrays]

extra_source_arrays = kwargs.pop("extra_source_arrays", [])
source_arrays = list(arrays) + list(extra_source_arrays)

extra_projected_mem = kwargs.pop("extra_projected_mem", 0)

name = gensym()
spec = check_array_specs(arrays)
if target_store is None:
target_store = new_temp_path(name=name, spec=spec)
pipeline = primitive_general_blockwise(
func,
block_function,
*zargs,
allowed_mem=spec.allowed_mem,
reserved_mem=spec.reserved_mem,
extra_projected_mem=extra_projected_mem,
target_store=target_store,
shape=shape,
dtype=dtype,
chunks=chunks,
in_names=in_names,
extra_func_kwargs=extra_func_kwargs,
**kwargs,
)
plan = Plan._new(
name,
"blockwise",
pipeline.target_array,
pipeline,
False,
*source_arrays,
)
from cubed.array_api import Array

return Array(name, pipeline.target_array, spec, plan)


def elemwise(func, *args: "Array", dtype=None) -> "Array":
"""Apply a function elementwise to array arguments, respecting broadcasting."""
shapes = [arg.shape for arg in args]
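A hedged usage sketch of the new user-facing wrapper in ``cubed/core/ops.py`` above; the doubling function, the input array, and the chunking are assumptions for illustration only, not part of the commit:

import numpy as np

import cubed
import cubed.array_api as xp
from cubed.core.ops import general_blockwise

spec = cubed.Spec(allowed_mem="500MB")
a = xp.asarray(np.arange(16).reshape(4, 4), chunks=(2, 2), spec=spec)

def block_function(out_key):
    # Route each output block ("out", i, j) to the same block of the input.
    return [(a.name,) + out_key[1:]]

b = general_blockwise(
    lambda x: x * 2,  # func: maps one input chunk to one output chunk
    block_function,
    a,
    shape=a.shape,
    dtype=a.dtype,
    chunks=a.chunks,
)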
172 changes: 92 additions & 80 deletions cubed/primitive/blockwise.py
@@ -2,7 +2,7 @@
import math
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import toolz
import zarr
@@ -15,7 +15,7 @@
from cubed.runtime.types import CubedPipeline
from cubed.storage.zarr import T_ZarrArray, lazy_empty
from cubed.types import T_Chunks, T_DType, T_Shape, T_Store
from cubed.utils import chunk_memory, get_item, split_into, to_chunksize
from cubed.utils import chunk_memory, get_item, map_nested, split_into, to_chunksize
from cubed.vendor.dask.array.core import normalize_chunks
from cubed.vendor.dask.blockwise import _get_coord_mapping, _make_dims, lol_product
from cubed.vendor.dask.core import flatten
Expand All @@ -40,7 +40,7 @@ class BlockwiseSpec:
Attributes
----------
block_function : Callable
A function that maps input chunk indexes to an output chunk index.
A function that maps an output chunk index to one or more input chunk indexes.
function : Callable
A function that maps input chunks to an output chunk.
reads_map : Dict[str, CubedArrayProxy]
@@ -62,15 +62,13 @@ def apply_blockwise(out_key: List[int], *, config: BlockwiseSpec) -> None:
out_chunk_key = key_to_slices(
out_key_tuple, config.write.array, config.write.chunks
)

# get array chunks for input keys, preserving any nested list structure
args = []
get_chunk_config = partial(get_chunk, config=config)
name_chunk_inds = config.block_function(("out",) + out_key_tuple)
for name_chunk_ind in name_chunk_inds:
name = name_chunk_ind[0]
chunk_ind = name_chunk_ind[1:]
arr = config.reads_map[name].open()
chunk_key = key_to_slices(chunk_ind, arr)
arg = arr[chunk_key]
arg = numpy_array_to_backend_array(arg)
arg = map_nested(get_chunk_config, name_chunk_ind)
args.append(arg)

result = config.function(*args)
@@ -91,6 +89,17 @@ def key_to_slices(
return get_item(chunks, key)


def get_chunk(name_chunk_ind, config):
"""Read a chunk from the named array"""
name = name_chunk_ind[0]
chunk_ind = name_chunk_ind[1:]
arr = config.reads_map[name].open()
chunk_key = key_to_slices(chunk_ind, arr)
arg = arr[chunk_key]
arg = numpy_array_to_backend_array(arg)
return arg
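# Not part of the diff: ``map_nested`` (imported from cubed.utils above) is
# assumed to apply a function to every tuple leaf of an arbitrarily nested
# list while preserving the list structure, which is what lets a
# block_function return nested lists of input keys. A minimal sketch of that
# assumed behaviour:
def map_nested_sketch(func, seq):
    # Recurse into lists; apply func at the (non-list) leaves.
    if isinstance(seq, list):
        return [map_nested_sketch(func, item) for item in seq]
    return func(seq)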


def blockwise(
func: Callable[..., Any],
out_ind: Sequence[Union[str, int]],
@@ -110,6 +119,9 @@
):
"""Apply a function to multiple blocks from multiple inputs, expressed using concise indexing rules.
Unlike ``general_blockwise``, an index notation is used to specify the block mapping,
like in Dask Array.
Parameters
----------
func : callable
@@ -146,10 +158,8 @@
CubedPipeline to run the operation
"""

# Use dask's make_blockwise_graph
arrays: Sequence[T_ZarrArray] = args[::2]
array_names = in_names or [f"in_{i}" for i in range(len(arrays))]
array_map = {name: array for name, array in zip(array_names, arrays)}

inds: Sequence[Union[str, int]] = args[1::2]

@@ -164,11 +174,6 @@
for name, ind in zip(array_names, inds):
argindsstr.extend((name, ind))

# TODO: check output shape and chunks are consistent with inputs
chunks = normalize_chunks(chunks, shape=shape, dtype=dtype)

# block func

block_function = make_blockwise_function_flattened(
func,
out_name or "out",
@@ -178,28 +183,78 @@
new_axes=new_axes,
)

output_blocks_generator_fn = partial(
get_output_blocks,
return general_blockwise(
func,
out_name or "out",
out_ind,
*argindsstr,
numblocks=numblocks,
new_axes=new_axes,
block_function,
*arrays,
allowed_mem=allowed_mem,
reserved_mem=reserved_mem,
target_store=target_store,
shape=shape,
dtype=dtype,
chunks=chunks,
in_names=in_names,
extra_projected_mem=extra_projected_mem,
extra_func_kwargs=extra_func_kwargs,
**kwargs,
)
output_blocks = IterableFromGenerator(output_blocks_generator_fn)

num_tasks = num_output_blocks(
func,
out_name or "out",
out_ind,
*argindsstr,
numblocks=numblocks,
new_axes=new_axes,
)

# end block func
def general_blockwise(
func: Callable[..., Any],
block_function: Callable[..., Any],
*arrays: Any,
allowed_mem: int,
reserved_mem: int,
target_store: T_Store,
shape: T_Shape,
dtype: T_DType,
chunks: T_Chunks,
in_names: Optional[List[str]] = None,
extra_projected_mem: int = 0,
extra_func_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
):
"""A more general form of ``blockwise`` that uses a function to specify the block
mapping, rather than an index notation.
Parameters
----------
func : callable
Function to apply to individual tuples of blocks
block_function : callable
A function that maps an output chunk index to one or more input chunk indexes.
*arrays : sequence of Array
The input arrays.
allowed_mem : int
The memory available to a worker for running a task, in bytes. Includes ``reserved_mem``.
reserved_mem : int
The memory reserved on a worker for non-data use when running a task, in bytes.
target_store : string or zarr.Array
Path to output Zarr store, or Zarr array
shape : tuple
The shape of the output array.
dtype : np.dtype
The ``dtype`` of the output array.
chunks : tuple
The chunks of the output array.
extra_projected_mem : int
Extra memory projected to be needed (in bytes) in addition to the memory used reading
the input arrays and writing the output.
extra_func_kwargs : dict
Extra keyword arguments to pass to function that can't be passed as regular keyword arguments
since they clash with other blockwise arguments (such as dtype).
**kwargs : dict
Extra keyword arguments to pass to function
Returns
-------
CubedPipeline to run the operation
"""
array_names = in_names or [f"in_{i}" for i in range(len(arrays))]
array_map = {name: array for name, array in zip(array_names, arrays)}

chunks = normalize_chunks(chunks, shape=shape, dtype=dtype)
chunksize = to_chunksize(chunks)
if isinstance(target_store, zarr.Array):
target_array = target_store
@@ -236,6 +291,10 @@ def blockwise(
f"Projected blockwise memory ({projected_mem}) exceeds allowed_mem ({allowed_mem}), including reserved_mem ({reserved_mem})"
)

# this must be an iterator of lists, not of tuples, otherwise lithops breaks
output_blocks = map(list, itertools.product(*[range(len(c)) for c in chunks]))
num_tasks = math.prod(len(c) for c in chunks)

return CubedPipeline(
apply_blockwise,
gensym("apply_blockwise"),
@@ -474,50 +533,3 @@ def blockwise_fn_flattened(out_key):
return name_chunk_inds

return blockwise_fn_flattened


def get_output_blocks(
func: Callable[..., Any],
output: str,
out_indices: Sequence[Union[str, int]],
*arrind_pairs: Any,
numblocks: Optional[Dict[str, Tuple[int, ...]]] = None,
new_axes: Optional[Dict[int, int]] = None,
) -> Iterator[List[int]]:
if numblocks is None:
raise ValueError("Missing required numblocks argument.")
new_axes = new_axes or {}
argpairs = list(toolz.partition(2, arrind_pairs))

# Dictionary mapping {i: 3, j: 4, ...} for i, j, ... the dimensions
dims = _make_dims(argpairs, numblocks, new_axes)

# return a list of lists, not of tuples, otherwise lithops breaks
for tup in itertools.product(*[range(dims[i]) for i in out_indices]):
yield list(tup)


class IterableFromGenerator:
def __init__(self, generator_fn: Callable[[], Iterator[List[int]]]):
self.generator_fn = generator_fn

def __iter__(self):
return self.generator_fn()


def num_output_blocks(
func: Callable[..., Any],
output: str,
out_indices: Sequence[Union[str, int]],
*arrind_pairs: Any,
numblocks: Optional[Dict[str, Tuple[int, ...]]] = None,
new_axes: Optional[Dict[int, int]] = None,
) -> int:
if numblocks is None:
raise ValueError("Missing required numblocks argument.")
new_axes = new_axes or {}
argpairs = list(toolz.partition(2, arrind_pairs))

# Dictionary mapping {i: 3, j: 4, ...} for i, j, ... the dimensions
dims = _make_dims(argpairs, numblocks, new_axes)
return math.prod(dims[i] for i in out_indices)
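The deleted ``get_output_blocks``/``num_output_blocks`` helpers above derived output block indices from the index-notation dimensions; ``general_blockwise`` now enumerates them directly from the normalized chunks. A self-contained check of that enumeration, on an assumed 4 x 6 array with 2 x 2 chunks:

import itertools
import math

chunks = ((2, 2), (2, 2, 2))  # normalized chunks: 2 blocks along axis 0, 3 along axis 1

# Lists rather than tuples, matching the lithops note in the diff above.
output_blocks = list(map(list, itertools.product(*[range(len(c)) for c in chunks])))
num_tasks = math.prod(len(c) for c in chunks)

assert output_blocks == [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]]
assert num_tasks == 6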