Fuse operations with different numbers of tasks #368

Merged: 8 commits, Feb 5, 2024
cubed/array_api/statistical_functions.py (2 additions & 1 deletion)
@@ -25,7 +25,7 @@ def max(x, /, *, axis=None, keepdims=False):
return reduction(x, nxp.max, axis=axis, dtype=x.dtype, keepdims=keepdims)


def mean(x, /, *, axis=None, keepdims=False, use_new_impl=False):
def mean(x, /, *, axis=None, keepdims=False, use_new_impl=False, split_every=None):
if x.dtype not in _real_floating_dtypes:
raise TypeError("Only real floating-point dtypes are allowed in mean")
# This implementation uses NumPy and Zarr's structured arrays to store a
@@ -47,6 +47,7 @@ def mean(x, /, *, axis=None, keepdims=False, use_new_impl=False):
dtype=dtype,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,

Member:
Is this split_every argument a temporary implementation detail whilst we figure out fusing heuristics? It would be nice to relate the meaning of this argument back to the discussion in #284 (as I currently don't quite understand what it means)

Member Author:
The split_every argument is what we referred to as "fan-in" in #284: the number of chunks read by one task doing the reduction step. It's the same as in Dask, and it can be a dictionary indicating the number of chunks to read in each dimension.

I hope it is something that we can get better heuristics for (or at least good defaults), possibly by measuring trade-offs like they did in the Primula paper; see #331.
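
For concreteness, a minimal usage sketch (the shapes, chunks and Spec values here are made up for illustration; split_every and use_new_impl are the arguments added or threaded through in this PR, and the dict form follows the description above):

```python
import cubed
import cubed.array_api as xp
import cubed.random

spec = cubed.Spec(allowed_mem="2GB")  # hypothetical memory budget
a = cubed.random.random((1000, 100), chunks=(10, 100), spec=spec)

# fan-in of 10: each task in the reduction step reads 10 chunks
m = xp.mean(a, axis=0, use_new_impl=True, split_every=10)

# or per-axis, as a dict mapping axis -> number of chunks read by one task
m2 = xp.mean(a, axis=0, use_new_impl=True, split_every={0: 10})
```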

extra_func_kwargs=extra_func_kwargs,
)

cubed/core/ops.py (65 additions & 2 deletions)
@@ -27,10 +27,16 @@
from cubed.primitive.blockwise import blockwise as primitive_blockwise
from cubed.primitive.blockwise import general_blockwise as primitive_general_blockwise
from cubed.primitive.rechunk import rechunk as primitive_rechunk
from cubed.utils import chunk_memory, get_item, offset_to_block_id, to_chunksize
from cubed.utils import (
_concatenate2,
chunk_memory,
get_item,
offset_to_block_id,
to_chunksize,
)
from cubed.vendor.dask.array.core import common_blockdim, normalize_chunks
from cubed.vendor.dask.array.utils import validate_axis
from cubed.vendor.dask.blockwise import broadcast_dimensions
from cubed.vendor.dask.blockwise import broadcast_dimensions, lol_product
from cubed.vendor.dask.utils import has_keyword

if TYPE_CHECKING:
@@ -266,6 +272,7 @@ def blockwise(
extra_projected_mem = kwargs.pop("extra_projected_mem", 0)

fusable = kwargs.pop("fusable", True)
num_input_blocks = kwargs.pop("num_input_blocks", (1,) * len(source_arrays))

name = gensym()
spec = check_array_specs(arrays)
@@ -287,6 +294,7 @@
out_name=name,
extra_func_kwargs=extra_func_kwargs,
fusable=fusable,
num_input_blocks=num_input_blocks,
**kwargs,
)
plan = Plan._new(
@@ -324,6 +332,8 @@ def general_blockwise(

extra_projected_mem = kwargs.pop("extra_projected_mem", 0)

num_input_blocks = kwargs.pop("num_input_blocks", (1,) * len(source_arrays))

name = gensym()
spec = check_array_specs(arrays)
if target_store is None:
@@ -341,6 +351,7 @@
chunks=chunks,
in_names=in_names,
extra_func_kwargs=extra_func_kwargs,
num_input_blocks=num_input_blocks,
**kwargs,
)
plan = Plan._new(
@@ -759,6 +770,7 @@ def rechunk(x, chunks, target_store=None):


def merge_chunks(x, chunks):
"""Merge multiple chunks into one."""
target_chunksize = chunks
if len(target_chunksize) != x.ndim:
raise ValueError(
@@ -787,6 +799,56 @@ def _copy_chunk(e, x, target_chunks=None, block_id=None):
return out


def merge_chunks_new(x, chunks):
# new implementation that uses general_blockwise rather than map_direct
target_chunksize = chunks
if len(target_chunksize) != x.ndim:
raise ValueError(
f"Chunks {target_chunksize} must have same number of dimensions as array ({x.ndim})"
)
if not all(c1 % c0 == 0 for c0, c1 in zip(x.chunksize, target_chunksize)):
raise ValueError(
f"Chunks {target_chunksize} must be a multiple of array's chunks {x.chunksize}"
)

target_chunks = normalize_chunks(chunks, x.shape, dtype=x.dtype)
axes = [
i for (i, (c0, c1)) in enumerate(zip(x.chunksize, target_chunksize)) if c0 != c1
]

def block_function(out_key):
out_coords = out_key[1:]

in_keys = []
for i, (c0, c1) in enumerate(zip(x.chunksize, target_chunksize)):
k = c1 // c0 # number of blocks to merge in axis i
if k == 1:
in_keys.append(out_coords[i])
else:
start = out_coords[i] * k
stop = min(start + k, x.numblocks[i])
in_keys.append(list(range(start, stop)))

# return a tuple with a single item that is the list of input keys to be merged
return (lol_product((x.name,), in_keys),)

num_input_blocks = int(
np.prod([c1 // c0 for (c0, c1) in zip(x.chunksize, target_chunksize)])
)

return general_blockwise(
_concatenate2,
block_function,
x,
shape=x.shape,
dtype=x.dtype,
chunks=target_chunks,
extra_projected_mem=0,
num_input_blocks=(num_input_blocks,),
axes=axes,
)


def reduction(
x: "Array",
func,
@@ -1059,6 +1121,7 @@ def block_function(out_key):
dtype=dtype,
chunks=chunks,
extra_projected_mem=extra_projected_mem,
num_input_blocks=(sum(split_every.values()),),
reduce_func=func,
initial_func=initial_func,
axis=axis,
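
As an aside on merge_chunks_new above: the nested key lists it builds come from the lol_product helper vendored from dask (newly added to the imports in ops.py). A small standalone sketch of what block_function produces, with a made-up array name and chunk sizes:

```python
from cubed.vendor.dask.blockwise import lol_product

# Suppose x has chunksize (5, 10) and the target chunksize is (5, 20), so
# axis 1 merges k = 20 // 10 = 2 input blocks into each output block.
# For output block (0, 0), block_function assembles:
in_keys = [0, [0, 1]]  # axis 0: a single block; axis 1: blocks 0 and 1

print(lol_product(("array-001",), in_keys))
# [('array-001', 0, 0), ('array-001', 0, 1)]
# block_function returns this list wrapped in a 1-tuple, and num_input_blocks
# is the product of the per-axis merge factors (here 1 * 2 = 2).
```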
cubed/core/plan.py (4 additions & 0 deletions)
@@ -326,6 +326,10 @@ def visualize(
tooltip += f"\ntasks: {primitive_op.num_tasks}"
if primitive_op.write_chunks is not None:
tooltip += f"\nwrite chunks: {primitive_op.write_chunks}"
if primitive_op.num_input_blocks is not None:
tooltip += (
f"\nnum input blocks: {primitive_op.num_input_blocks}"
)
del d["primitive_op"]

# remove pipeline attribute since it is a long string that causes graphviz to fail
cubed/primitive/blockwise.py (40 additions & 13 deletions)
@@ -1,5 +1,6 @@
import itertools
import math
from collections.abc import Iterator
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
@@ -119,6 +120,7 @@ def blockwise(
extra_projected_mem: int = 0,
extra_func_kwargs: Optional[Dict[str, Any]] = None,
fusable: bool = True,
num_input_blocks: Optional[Tuple[int, ...]] = None,
**kwargs,
):
"""Apply a function to multiple blocks from multiple inputs, expressed using concise indexing rules.
@@ -201,6 +203,7 @@ def blockwise(
extra_projected_mem=extra_projected_mem,
extra_func_kwargs=extra_func_kwargs,
fusable=fusable,
num_input_blocks=num_input_blocks,
**kwargs,
)

@@ -219,6 +222,7 @@ def general_blockwise(
extra_projected_mem: int = 0,
extra_func_kwargs: Optional[Dict[str, Any]] = None,
fusable: bool = True,
num_input_blocks: Optional[Tuple[int, ...]] = None,
**kwargs,
):
"""A more general form of ``blockwise`` that uses a function to specify the block
@@ -317,6 +321,7 @@ def general_blockwise(
reserved_mem=reserved_mem,
num_tasks=num_tasks,
fusable=fusable,
num_input_blocks=num_input_blocks,
)


@@ -399,6 +404,9 @@ def fused_func(*args):
allowed_mem = primitive_op2.allowed_mem
reserved_mem = primitive_op2.reserved_mem
num_tasks = primitive_op2.num_tasks
num_input_blocks = tuple(
n * primitive_op2.num_input_blocks[0] for n in primitive_op1.num_input_blocks
)

pipeline = CubedPipeline(
apply_blockwise,
@@ -414,6 +422,7 @@ def fused_func(*args):
reserved_mem=reserved_mem,
num_tasks=num_tasks,
fusable=True,
num_input_blocks=num_input_blocks,
)


@@ -424,12 +433,6 @@ def fuse_multiple(
Fuse a blockwise operation and its predecessors into a single operation, avoiding writing to (or reading from) the targets of the predecessor operations.
"""

assert all(
primitive_op.num_tasks == p.num_tasks
for p in predecessor_primitive_ops
if p is not None
)

pipeline = primitive_op.pipeline
predecessor_pipelines = [
primitive_op.pipeline if primitive_op is not None else None
@@ -444,31 +447,48 @@

mappable = pipeline.mappable

def apply_pipeline_block_func(pipeline, arg):
def apply_pipeline_block_func(pipeline, n_input_blocks, arg):
if pipeline is None:
return (arg,)
return pipeline.config.block_function(arg)
if n_input_blocks == 1:
assert isinstance(arg, tuple)
return pipeline.config.block_function(arg)
else:
# more than one input block is being read from arg
assert isinstance(arg, (list, Iterator))
return tuple(
list(item)
for item in zip(*(pipeline.config.block_function(a) for a in arg))
)

def fused_blockwise_func(out_key):
# this will change when multiple outputs are supported
args = pipeline.config.block_function(out_key)
# split all args to the fused function into groups, one for each predecessor function
func_args = tuple(
item
for p, a in zip(predecessor_pipelines, args)
for item in apply_pipeline_block_func(p, a)
for i, (p, a) in enumerate(zip(predecessor_pipelines, args))
for item in apply_pipeline_block_func(
p, primitive_op.num_input_blocks[i], a
)
)
return split_into(func_args, predecessor_funcs_nargs)

def apply_pipeline_func(pipeline, *args):
def apply_pipeline_func(pipeline, n_input_blocks, *args):
if pipeline is None:
return args[0]
return pipeline.config.function(*args)
if n_input_blocks == 1:
ret = pipeline.config.function(*args)
else:
# more than one input block is being read from this group of args to primitive op
ret = [pipeline.config.function(*item) for item in list(zip(*args))]
return ret

def fused_func(*args):
# args are grouped appropriately so they can be called by each predecessor function
func_args = [
apply_pipeline_func(p, *a) for p, a in zip(predecessor_pipelines, args)
apply_pipeline_func(p, primitive_op.num_input_blocks[i], *a)
for i, (p, a) in enumerate(zip(predecessor_pipelines, args))
]
return pipeline.config.function(*func_args)

@@ -490,6 +510,12 @@ def fused_func(*args):
allowed_mem = primitive_op.allowed_mem
reserved_mem = primitive_op.reserved_mem
num_tasks = primitive_op.num_tasks
tmp = [
p.num_input_blocks if p is not None else (1,) for p in predecessor_primitive_ops
]
num_input_blocks = tuple(
primitive_op.num_input_blocks[0] * n for n in itertools.chain(*tmp)
)

fused_pipeline = CubedPipeline(
apply_blockwise,
@@ -505,6 +531,7 @@
reserved_mem=reserved_mem,
num_tasks=num_tasks,
fusable=True,
num_input_blocks=num_input_blocks,
)


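
A toy illustration (plain Python, not cubed code) of the num_input_blocks bookkeeping in fuse_multiple above: each predecessor's per-input block counts are multiplied by the number of blocks the successor op reads per input (the first entry of its num_input_blocks):

```python
import itertools

# e.g. a partial reduce that reads 4 blocks per task, fused with two
# elementwise predecessors that each read 1 block per task
successor_num_input_blocks = (4,)
predecessor_num_input_blocks = [(1,), (1,)]  # (1,) also stands in for a missing predecessor

fused = tuple(
    successor_num_input_blocks[0] * n
    for n in itertools.chain(*predecessor_num_input_blocks)
)
print(fused)  # (4, 4): each fused task reads 4 blocks from each original input
```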
cubed/primitive/types.py (4 additions & 1 deletion)
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any, Optional
from typing import Any, Optional, Tuple

import zarr

@@ -37,6 +37,9 @@ class PrimitiveOperation:
fusable: bool = True
"""Whether this operation should be considered for fusion."""

num_input_blocks: Optional[Tuple[int, ...]] = None
"""The number of input blocks read from each input array."""

write_chunks: Optional[T_RegularChunks] = None
"""The chunk size used by this operation."""

cubed/tests/test_core.py (12 additions & 3 deletions)
@@ -12,7 +12,7 @@
import cubed.random
from cubed.backend_array_api import namespace as nxp
from cubed.core.ops import merge_chunks, partial_reduce, tree_reduce
from cubed.core.optimization import fuse_all_optimize_dag
from cubed.core.optimization import fuse_all_optimize_dag, multiple_inputs_optimize_dag
from cubed.tests.utils import (
ALL_EXECUTORS,
MAIN_EXECUTORS,
@@ -531,10 +531,19 @@ def test_plan_quad_means(tmp_path, t_length):
u = cubed.random.random((t_length, 1, 987, 1920), chunks=(10, 1, -1, -1), spec=spec)
v = cubed.random.random((t_length, 1, 987, 1920), chunks=(10, 1, -1, -1), spec=spec)
uv = u * v
m = xp.mean(uv, axis=0)
m = xp.mean(uv, axis=0, split_every=10, use_new_impl=True)

assert m.plan.num_tasks() > 0
m.visualize(filename=tmp_path / "quad_means")
m.visualize(
filename=tmp_path / "quad_means_unoptimized",
optimize_graph=False,
show_hidden=True,
)
m.visualize(
filename=tmp_path / "quad_means",
optimize_function=multiple_inputs_optimize_dag,
show_hidden=True,
)


def quad_means(tmp_path, t_length):
Expand Down
Loading
Loading