diff --git a/cubed/array_api/statistical_functions.py b/cubed/array_api/statistical_functions.py index 53022cb2..21a50d2c 100644 --- a/cubed/array_api/statistical_functions.py +++ b/cubed/array_api/statistical_functions.py @@ -25,7 +25,7 @@ def max(x, /, *, axis=None, keepdims=False): return reduction(x, nxp.max, axis=axis, dtype=x.dtype, keepdims=keepdims) -def mean(x, /, *, axis=None, keepdims=False, use_new_impl=False): +def mean(x, /, *, axis=None, keepdims=False, use_new_impl=False, split_every=None): if x.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in mean") # This implementation uses NumPy and Zarr's structured arrays to store a @@ -47,6 +47,7 @@ def mean(x, /, *, axis=None, keepdims=False, use_new_impl=False): dtype=dtype, keepdims=keepdims, use_new_impl=use_new_impl, + split_every=split_every, extra_func_kwargs=extra_func_kwargs, ) diff --git a/cubed/core/ops.py b/cubed/core/ops.py index bb71c1b2..41db51dd 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -27,10 +27,16 @@ from cubed.primitive.blockwise import blockwise as primitive_blockwise from cubed.primitive.blockwise import general_blockwise as primitive_general_blockwise from cubed.primitive.rechunk import rechunk as primitive_rechunk -from cubed.utils import chunk_memory, get_item, offset_to_block_id, to_chunksize +from cubed.utils import ( + _concatenate2, + chunk_memory, + get_item, + offset_to_block_id, + to_chunksize, +) from cubed.vendor.dask.array.core import common_blockdim, normalize_chunks from cubed.vendor.dask.array.utils import validate_axis -from cubed.vendor.dask.blockwise import broadcast_dimensions +from cubed.vendor.dask.blockwise import broadcast_dimensions, lol_product from cubed.vendor.dask.utils import has_keyword if TYPE_CHECKING: @@ -266,6 +272,7 @@ def blockwise( extra_projected_mem = kwargs.pop("extra_projected_mem", 0) fusable = kwargs.pop("fusable", True) + num_input_blocks = kwargs.pop("num_input_blocks", None) name = gensym() spec = check_array_specs(arrays) @@ -287,6 +294,7 @@ def blockwise( out_name=name, extra_func_kwargs=extra_func_kwargs, fusable=fusable, + num_input_blocks=num_input_blocks, **kwargs, ) plan = Plan._new( @@ -324,6 +332,8 @@ def general_blockwise( extra_projected_mem = kwargs.pop("extra_projected_mem", 0) + num_input_blocks = kwargs.pop("num_input_blocks", None) + name = gensym() spec = check_array_specs(arrays) if target_store is None: @@ -341,6 +351,7 @@ def general_blockwise( chunks=chunks, in_names=in_names, extra_func_kwargs=extra_func_kwargs, + num_input_blocks=num_input_blocks, **kwargs, ) plan = Plan._new( @@ -759,6 +770,7 @@ def rechunk(x, chunks, target_store=None): def merge_chunks(x, chunks): + """Merge multiple chunks into one.""" target_chunksize = chunks if len(target_chunksize) != x.ndim: raise ValueError( @@ -787,6 +799,56 @@ def _copy_chunk(e, x, target_chunks=None, block_id=None): return out +def merge_chunks_new(x, chunks): + # new implementation that uses general_blockwise rather than map_direct + target_chunksize = chunks + if len(target_chunksize) != x.ndim: + raise ValueError( + f"Chunks {target_chunksize} must have same number of dimensions as array ({x.ndim})" + ) + if not all(c1 % c0 == 0 for c0, c1 in zip(x.chunksize, target_chunksize)): + raise ValueError( + f"Chunks {target_chunksize} must be a multiple of array's chunks {x.chunksize}" + ) + + target_chunks = normalize_chunks(chunks, x.shape, dtype=x.dtype) + axes = [ + i for (i, (c0, c1)) in enumerate(zip(x.chunksize, target_chunksize)) if c0 != c1 + ] + + def block_function(out_key): + out_coords = out_key[1:] + + in_keys = [] + for i, (c0, c1) in enumerate(zip(x.chunksize, target_chunksize)): + k = c1 // c0 # number of blocks to merge in axis i + if k == 1: + in_keys.append(out_coords[i]) + else: + start = out_coords[i] * k + stop = min(start + k, x.numblocks[i]) + in_keys.append(list(range(start, stop))) + + # return a tuple with a single item that is the list of input keys to be merged + return (lol_product((x.name,), in_keys),) + + num_input_blocks = int( + np.prod([c1 // c0 for (c0, c1) in zip(x.chunksize, target_chunksize)]) + ) + + return general_blockwise( + _concatenate2, + block_function, + x, + shape=x.shape, + dtype=x.dtype, + chunks=target_chunks, + extra_projected_mem=0, + num_input_blocks=(num_input_blocks,), + axes=axes, + ) + + def reduction( x: "Array", func, @@ -1059,6 +1121,7 @@ def block_function(out_key): dtype=dtype, chunks=chunks, extra_projected_mem=extra_projected_mem, + num_input_blocks=(sum(split_every.values()),), reduce_func=func, initial_func=initial_func, axis=axis, diff --git a/cubed/core/plan.py b/cubed/core/plan.py index 9b8deb28..12b7ad56 100644 --- a/cubed/core/plan.py +++ b/cubed/core/plan.py @@ -330,6 +330,11 @@ def visualize( # remove pipeline attribute since it is a long string that causes graphviz to fail if "pipeline" in d: + pipeline = d["pipeline"] + if pipeline.config is not None: + tooltip += ( + f"\nnum input blocks: {pipeline.config.num_input_blocks}" + ) del d["pipeline"] if "stack_summaries" in d and d["stack_summaries"] is not None: diff --git a/cubed/primitive/blockwise.py b/cubed/primitive/blockwise.py index 5ea1a623..14238278 100644 --- a/cubed/primitive/blockwise.py +++ b/cubed/primitive/blockwise.py @@ -1,5 +1,6 @@ import itertools import math +from collections.abc import Iterator from dataclasses import dataclass from functools import partial from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union @@ -45,6 +46,8 @@ class BlockwiseSpec: A function that maps input chunks to an output chunk. function_nargs: int The number of array arguments that ``function`` takes. + num_input_blocks: Tuple[int, ...] + The number of input blocks read from each input array. reads_map : Dict[str, CubedArrayProxy] Read proxy dictionary keyed by array name. write : CubedArrayProxy @@ -54,6 +57,7 @@ class BlockwiseSpec: block_function: Callable[..., Any] function: Callable[..., Any] function_nargs: int + num_input_blocks: Tuple[int, ...] reads_map: Dict[str, CubedArrayProxy] write: CubedArrayProxy @@ -119,6 +123,7 @@ def blockwise( extra_projected_mem: int = 0, extra_func_kwargs: Optional[Dict[str, Any]] = None, fusable: bool = True, + num_input_blocks: Optional[Tuple[int, ...]] = None, **kwargs, ): """Apply a function to multiple blocks from multiple inputs, expressed using concise indexing rules. @@ -201,6 +206,7 @@ def blockwise( extra_projected_mem=extra_projected_mem, extra_func_kwargs=extra_func_kwargs, fusable=fusable, + num_input_blocks=num_input_blocks, **kwargs, ) @@ -219,6 +225,7 @@ def general_blockwise( extra_projected_mem: int = 0, extra_func_kwargs: Optional[Dict[str, Any]] = None, fusable: bool = True, + num_input_blocks: Optional[Tuple[int, ...]] = None, **kwargs, ): """A more general form of ``blockwise`` that uses a function to specify the block @@ -271,12 +278,18 @@ def general_blockwise( func_kwargs = extra_func_kwargs or {} func_with_kwargs = partial(func, **{**kwargs, **func_kwargs}) + num_input_blocks = num_input_blocks or (1,) * len(arrays) read_proxies = { name: CubedArrayProxy(array, array.chunks) for name, array in array_map.items() } write_proxy = CubedArrayProxy(target_array, chunksize) spec = BlockwiseSpec( - block_function, func_with_kwargs, len(arrays), read_proxies, write_proxy + block_function, + func_with_kwargs, + len(arrays), + num_input_blocks, + read_proxies, + write_proxy, ) # calculate projected memory @@ -344,10 +357,17 @@ def can_fuse_multiple_primitive_ops( if is_fuse_candidate(primitive_op) and all( is_fuse_candidate(p) for p in predecessor_primitive_ops ): - # if the peak projected memory for running all the predecessor ops in order is - # larger than allowed_mem then we can't fuse + # If the peak projected memory for running all the predecessor ops in + # order is larger than allowed_mem then we can't fuse. if peak_projected_mem(predecessor_primitive_ops) > primitive_op.allowed_mem: return False + # If the number of input blocks for each input is not uniform, then we + # can't fuse. (This should never happen since all operations are + # currently uniform, and fused operations are too if fuse is applied in + # topological order.) + num_input_blocks = primitive_op.pipeline.config.num_input_blocks + if not all(num_input_blocks[0] == n for n in num_input_blocks): + return False return all( primitive_op.num_tasks == p.num_tasks for p in predecessor_primitive_ops ) @@ -390,8 +410,17 @@ def fused_func(*args): function_nargs = pipeline1.config.function_nargs read_proxies = pipeline1.config.reads_map write_proxy = pipeline2.config.write + num_input_blocks = tuple( + n * pipeline2.config.num_input_blocks[0] + for n in pipeline1.config.num_input_blocks + ) spec = BlockwiseSpec( - fused_blockwise_func, fused_func, function_nargs, read_proxies, write_proxy + fused_blockwise_func, + fused_func, + function_nargs, + num_input_blocks, + read_proxies, + write_proxy, ) target_array = primitive_op2.target_array @@ -424,12 +453,6 @@ def fuse_multiple( Fuse a blockwise operation and its predecessors into a single operation, avoiding writing to (or reading from) the targets of the predecessor operations. """ - assert all( - primitive_op.num_tasks == p.num_tasks - for p in predecessor_primitive_ops - if p is not None - ) - pipeline = primitive_op.pipeline predecessor_pipelines = [ primitive_op.pipeline if primitive_op is not None else None @@ -444,10 +467,19 @@ def fuse_multiple( mappable = pipeline.mappable - def apply_pipeline_block_func(pipeline, arg): + def apply_pipeline_block_func(pipeline, n_input_blocks, arg): if pipeline is None: return (arg,) - return pipeline.config.block_function(arg) + if n_input_blocks == 1: + assert isinstance(arg, tuple) + return pipeline.config.block_function(arg) + else: + # more than one input block is being read from arg + assert isinstance(arg, (list, Iterator)) + return tuple( + list(item) + for item in zip(*(pipeline.config.block_function(a) for a in arg)) + ) def fused_blockwise_func(out_key): # this will change when multiple outputs are supported @@ -455,31 +487,54 @@ def fused_blockwise_func(out_key): # split all args to the fused function into groups, one for each predecessor function func_args = tuple( item - for p, a in zip(predecessor_pipelines, args) - for item in apply_pipeline_block_func(p, a) + for i, (p, a) in enumerate(zip(predecessor_pipelines, args)) + for item in apply_pipeline_block_func( + p, pipeline.config.num_input_blocks[i], a + ) ) return split_into(func_args, predecessor_funcs_nargs) - def apply_pipeline_func(pipeline, *args): + def apply_pipeline_func(pipeline, n_input_blocks, *args): if pipeline is None: return args[0] - return pipeline.config.function(*args) + if n_input_blocks == 1: + ret = pipeline.config.function(*args) + else: + # more than one input block is being read from this group of args to primitive op + ret = [pipeline.config.function(*item) for item in list(zip(*args))] + return ret def fused_func(*args): # args are grouped appropriately so they can be called by each predecessor function func_args = [ - apply_pipeline_func(p, *a) for p, a in zip(predecessor_pipelines, args) + apply_pipeline_func(p, pipeline.config.num_input_blocks[i], *a) + for i, (p, a) in enumerate(zip(predecessor_pipelines, args)) ] return pipeline.config.function(*func_args) - function_nargs = pipeline.config.function_nargs + fused_function_nargs = pipeline.config.function_nargs + # ok to get num_input_blocks[0] since it is uniform (see check in can_fuse_multiple_primitive_ops) + fused_num_input_blocks = tuple( + pipeline.config.num_input_blocks[0] * n + for n in itertools.chain( + *( + p.pipeline.config.num_input_blocks if p is not None else (1,) + for p in predecessor_primitive_ops + ) + ) + ) read_proxies = dict(pipeline.config.reads_map) for p in predecessor_pipelines: if p is not None: read_proxies.update(p.config.reads_map) write_proxy = pipeline.config.write spec = BlockwiseSpec( - fused_blockwise_func, fused_func, function_nargs, read_proxies, write_proxy + fused_blockwise_func, + fused_func, + fused_function_nargs, + fused_num_input_blocks, + read_proxies, + write_proxy, ) target_array = primitive_op.target_array diff --git a/cubed/tests/test_core.py b/cubed/tests/test_core.py index 94071ad5..ea62fdaf 100644 --- a/cubed/tests/test_core.py +++ b/cubed/tests/test_core.py @@ -12,7 +12,7 @@ import cubed.random from cubed.backend_array_api import namespace as nxp from cubed.core.ops import merge_chunks, partial_reduce, tree_reduce -from cubed.core.optimization import fuse_all_optimize_dag +from cubed.core.optimization import fuse_all_optimize_dag, multiple_inputs_optimize_dag from cubed.tests.utils import ( ALL_EXECUTORS, MAIN_EXECUTORS, @@ -531,10 +531,19 @@ def test_plan_quad_means(tmp_path, t_length): u = cubed.random.random((t_length, 1, 987, 1920), chunks=(10, 1, -1, -1), spec=spec) v = cubed.random.random((t_length, 1, 987, 1920), chunks=(10, 1, -1, -1), spec=spec) uv = u * v - m = xp.mean(uv, axis=0) + m = xp.mean(uv, axis=0, split_every=10, use_new_impl=True) assert m.plan.num_tasks() > 0 - m.visualize(filename=tmp_path / "quad_means") + m.visualize( + filename=tmp_path / "quad_means_unoptimized", + optimize_graph=False, + show_hidden=True, + ) + m.visualize( + filename=tmp_path / "quad_means", + optimize_function=multiple_inputs_optimize_dag, + show_hidden=True, + ) def quad_means(tmp_path, t_length): diff --git a/cubed/tests/test_optimization.py b/cubed/tests/test_optimization.py index d13a8c35..4cf69b01 100644 --- a/cubed/tests/test_optimization.py +++ b/cubed/tests/test_optimization.py @@ -8,7 +8,7 @@ import cubed import cubed.array_api as xp from cubed.backend_array_api import namespace as nxp -from cubed.core.ops import elemwise +from cubed.core.ops import elemwise, merge_chunks_new from cubed.core.optimization import ( fuse_all_optimize_dag, fuse_only_optimize_dag, @@ -138,9 +138,18 @@ def custom_optimize_function(dag): ) -def fuse_one_level(arr): +def get_num_input_blocks(dag, arr_name): + op_name = next(dag.predecessors(arr_name)) + return dag.nodes(data=True)[op_name]["pipeline"].config.num_input_blocks + + +def fuse_one_level(arr, *, always_fuse=None): # use fuse_predecessors to test one level of fusion - return partial(fuse_predecessors, name=next(arr.plan.dag.predecessors(arr.name))) + return partial( + fuse_predecessors, + name=next(arr.plan.dag.predecessors(arr.name)), + always_fuse=always_fuse, + ) def fuse_multiple_levels(*, max_total_source_arrays=4): @@ -224,10 +233,10 @@ def test_fuse_unary_op(spec): expected_fused_dag = create_dag() add_placeholder_op(expected_fused_dag, (), (a,)) add_placeholder_op(expected_fused_dag, (a,), (c,)) - assert structurally_equivalent( - c.plan.optimize(optimize_function=opt_fn).dag, - expected_fused_dag, - ) + optimized_dag = c.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) + assert get_num_input_blocks(c.plan.dag, c.name) == (1,) + assert get_num_input_blocks(optimized_dag, c.name) == (1,) num_created_arrays = 2 # b, c assert c.plan.num_tasks(optimize_graph=False) == num_created_arrays + 2 @@ -265,9 +274,10 @@ def test_fuse_binary_op(spec): add_placeholder_op(expected_fused_dag, (), (a,)) add_placeholder_op(expected_fused_dag, (), (b,)) add_placeholder_op(expected_fused_dag, (a, b), (e,)) - assert structurally_equivalent( - e.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = e.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) + assert get_num_input_blocks(e.plan.dag, e.name) == (1, 1) + assert get_num_input_blocks(optimized_dag, e.name) == (1, 1) num_created_arrays = 3 # c, d, e assert e.plan.num_tasks(optimize_graph=False) == num_created_arrays + 3 @@ -307,9 +317,10 @@ def test_fuse_unary_and_binary_op(spec): add_placeholder_op(expected_fused_dag, (), (b,)) add_placeholder_op(expected_fused_dag, (), (c,)) add_placeholder_op(expected_fused_dag, (a, b, c), (f,)) - assert structurally_equivalent( - f.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = f.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) + assert get_num_input_blocks(f.plan.dag, f.name) == (1, 1) + assert get_num_input_blocks(optimized_dag, f.name) == (1, 1, 1) result = f.compute(optimize_function=opt_fn) assert_array_equal(result, np.ones((2, 2))) @@ -340,9 +351,10 @@ def test_fuse_mixed_levels(spec): add_placeholder_op(expected_fused_dag, (), (b,)) add_placeholder_op(expected_fused_dag, (), (c,)) add_placeholder_op(expected_fused_dag, (a, b, c), (e,)) - assert structurally_equivalent( - e.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = e.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) + assert get_num_input_blocks(e.plan.dag, e.name) == (1, 1) + assert get_num_input_blocks(optimized_dag, e.name) == (1, 1, 1) result = e.compute(optimize_function=opt_fn) assert_array_equal(result, 3 * np.ones((2, 2))) @@ -370,9 +382,10 @@ def test_fuse_diamond(spec): expected_fused_dag = create_dag() add_placeholder_op(expected_fused_dag, (), (a,)) add_placeholder_op(expected_fused_dag, (a, a), (d,)) - assert structurally_equivalent( - d.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = d.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) + assert get_num_input_blocks(d.plan.dag, d.name) == (1, 1) + assert get_num_input_blocks(optimized_dag, d.name) == (1, 1) result = d.compute(optimize_function=opt_fn) assert_array_equal(result, 2 * np.ones((2, 2))) @@ -404,9 +417,10 @@ def test_fuse_mixed_levels_and_diamond(spec): add_placeholder_op(expected_fused_dag, (), (a,)) add_placeholder_op(expected_fused_dag, (a,), (b,)) add_placeholder_op(expected_fused_dag, (a, b), (d,)) - assert structurally_equivalent( - d.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = d.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) + assert get_num_input_blocks(d.plan.dag, d.name) == (1, 1) + assert get_num_input_blocks(optimized_dag, d.name) == (1, 1) result = d.compute(optimize_function=opt_fn) assert_array_equal(result, 2 * np.ones((2, 2))) @@ -434,9 +448,10 @@ def test_fuse_repeated_argument(spec): expected_fused_dag = create_dag() add_placeholder_op(expected_fused_dag, (), (a,)) add_placeholder_op(expected_fused_dag, (a, a), (c,)) - assert structurally_equivalent( - c.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = c.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) + assert get_num_input_blocks(c.plan.dag, c.name) == (1, 1) + assert get_num_input_blocks(optimized_dag, c.name) == (1, 1) result = c.compute(optimize_function=opt_fn) assert_array_equal(result, -2 * np.ones((2, 2))) @@ -469,9 +484,10 @@ def test_fuse_other_dependents(spec): add_placeholder_op(expected_fused_dag, (a,), (c,)) add_placeholder_op(expected_fused_dag, (b,), (d,)) plan = arrays_to_plan(c, d) - assert structurally_equivalent( - plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) + assert get_num_input_blocks(c.plan.dag, c.name) == (1,) + assert get_num_input_blocks(optimized_dag, c.name) == (1,) c_result, d_result = cubed.compute(c, d, optimize_function=opt_fn) assert_array_equal(c_result, np.ones((2, 2))) @@ -535,9 +551,10 @@ def stack_add(*a): ), (j,), ) - assert structurally_equivalent( - j.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = j.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) + assert get_num_input_blocks(j.plan.dag, j.name) == (1,) + assert get_num_input_blocks(optimized_dag, j.name) == (1,) * 8 result = j.compute(optimize_function=opt_fn) assert_array_equal(result, -8 * np.ones((2, 2))) @@ -593,9 +610,10 @@ def test_fuse_large_fan_in_default(spec): add_placeholder_op(expected_fused_dag, (a, b, c, d), (n,)) add_placeholder_op(expected_fused_dag, (e, f, g, h), (o,)) add_placeholder_op(expected_fused_dag, (n, o), (p,)) - assert structurally_equivalent( - p.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = p.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) + assert get_num_input_blocks(p.plan.dag, p.name) == (1, 1) + assert get_num_input_blocks(optimized_dag, p.name) == (1, 1) result = p.compute(optimize_function=opt_fn) assert_array_equal(result, 8 * np.ones((2, 2))) @@ -662,9 +680,10 @@ def test_fuse_large_fan_in_override(spec): ), (p,), ) - assert structurally_equivalent( - p.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = p.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) + assert get_num_input_blocks(p.plan.dag, p.name) == (1, 1) + assert get_num_input_blocks(optimized_dag, p.name) == (1,) * 8 result = p.compute(optimize_function=opt_fn) assert_array_equal(result, 8 * np.ones((2, 2))) @@ -672,15 +691,145 @@ def test_fuse_large_fan_in_override(spec): # now force everything to be fused with fuse_all_optimize_dag # note that max_total_source_arrays is *not* set opt_fn = fuse_all_optimize_dag - - assert structurally_equivalent( - p.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = p.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) result = p.compute(optimize_function=opt_fn) assert_array_equal(result, 8 * np.ones((2, 2))) +# merge chunks with same number of tasks (unary) +# +# a -> a +# | 3 | 3 +# b c +# | 1 +# c +# +def test_fuse_with_merge_chunks_unary(spec): + a = xp.ones((3, 2), chunks=(1, 2), spec=spec) + b = merge_chunks_new(a, chunks=(3, 2)) + c = xp.negative(b) + + opt_fn = fuse_one_level(c) + + c.visualize(optimize_function=opt_fn) + + # check structure of optimized dag + expected_fused_dag = create_dag() + add_placeholder_op(expected_fused_dag, (), (a,)) + add_placeholder_op(expected_fused_dag, (a,), (c,)) + optimized_dag = c.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) + assert get_num_input_blocks(b.plan.dag, b.name) == (3,) + assert get_num_input_blocks(c.plan.dag, c.name) == (1,) + assert get_num_input_blocks(optimized_dag, c.name) == (3,) + + result = c.compute(optimize_function=opt_fn) + assert_array_equal(result, -np.ones((3, 2))) + + +# merge chunks with same number of tasks (binary) +# +# a b -> a b +# 3 | | 1 3 \ / 1 +# c d e +# 1 \ / 1 +# e +# +def test_fuse_with_merge_chunks_binary(spec): + a = xp.ones((3, 2), chunks=(1, 2), spec=spec) + b = xp.ones((3, 2), chunks=(3, 2), spec=spec) + c = merge_chunks_new(a, chunks=(3, 2)) + d = xp.negative(b) + e = xp.add(c, d) + + opt_fn = fuse_one_level(e) + + e.visualize(optimize_function=opt_fn) + + # check structure of optimized dag + expected_fused_dag = create_dag() + add_placeholder_op(expected_fused_dag, (), (a,)) + add_placeholder_op(expected_fused_dag, (), (b,)) + add_placeholder_op(expected_fused_dag, (a, b), (e,)) + optimized_dag = e.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) + assert get_num_input_blocks(e.plan.dag, e.name) == (1, 1) + assert get_num_input_blocks(optimized_dag, e.name) == (3, 1) + + result = e.compute(optimize_function=opt_fn) + assert_array_equal(result, np.zeros((3, 2))) + + +# merge chunks with different number of tasks (b has more tasks than c) +# +# a -> a +# | 1 | 3 +# b c +# | 3 +# c +# +def test_fuse_merge_chunks_unary(spec): + a = xp.ones((3, 2), chunks=(1, 2), spec=spec) + b = xp.negative(a) + c = merge_chunks_new(b, chunks=(3, 2)) + + # force c to fuse + last_op = sorted(c.plan.dag.nodes())[-1] + opt_fn = fuse_one_level(c, always_fuse=[last_op]) + + c.visualize(optimize_function=opt_fn) + + # check structure of optimized dag + expected_fused_dag = create_dag() + add_placeholder_op(expected_fused_dag, (), (a,)) + add_placeholder_op(expected_fused_dag, (a,), (c,)) + optimized_dag = c.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) + assert get_num_input_blocks(b.plan.dag, b.name) == (1,) + assert get_num_input_blocks(c.plan.dag, c.name) == (3,) + assert get_num_input_blocks(optimized_dag, c.name) == (3,) + + result = c.compute(optimize_function=opt_fn) + assert_array_equal(result, -np.ones((3, 2))) + + +# merge chunks with different number of tasks (c has more tasks than d) +# +# a b -> a b +# 1 \ / 1 3 \ / 3 +# c d +# | 3 +# d +# +def test_fuse_merge_chunks_binary(spec): + a = xp.ones((3, 2), chunks=(1, 2), spec=spec) + b = xp.ones((3, 2), chunks=(1, 2), spec=spec) + c = xp.add(a, b) + d = merge_chunks_new(c, chunks=(3, 2)) + + # force d to fuse + last_op = sorted(d.plan.dag.nodes())[-1] + opt_fn = fuse_one_level(d, always_fuse=[last_op]) + + d.visualize(optimize_function=opt_fn) + + # check structure of optimized dag + expected_fused_dag = create_dag() + add_placeholder_op(expected_fused_dag, (), (a,)) + add_placeholder_op(expected_fused_dag, (), (b,)) + add_placeholder_op(expected_fused_dag, (a, b), (d,)) + optimized_dag = d.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) + assert get_num_input_blocks(c.plan.dag, c.name) == (1, 1) + assert get_num_input_blocks(d.plan.dag, d.name) == (3,) + assert get_num_input_blocks(optimized_dag, d.name) == (3, 3) + + result = d.compute(optimize_function=opt_fn) + assert_array_equal(result, 2 * np.ones((3, 2))) + + def test_fuse_only_optimize_dag(spec): a = xp.ones((2, 2), chunks=(2, 2), spec=spec) b = xp.negative(a) @@ -699,10 +848,10 @@ def test_fuse_only_optimize_dag(spec): add_placeholder_op(expected_fused_dag, (), (a,)) add_placeholder_op(expected_fused_dag, (a,), (b,)) add_placeholder_op(expected_fused_dag, (b,), (d,)) - assert structurally_equivalent( - d.plan.optimize(optimize_function=opt_fn).dag, - expected_fused_dag, - ) + optimized_dag = d.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) + assert get_num_input_blocks(d.plan.dag, d.name) == (1,) + assert get_num_input_blocks(optimized_dag, d.name) == (1,) result = d.compute(optimize_function=opt_fn) assert_array_equal(result, -np.ones((2, 2))) diff --git a/cubed/utils.py b/cubed/utils.py index ff36192c..84b1cfc3 100644 --- a/cubed/utils.py +++ b/cubed/utils.py @@ -310,3 +310,68 @@ def broadcast_trick(func): inner.__doc__ = func.__doc__ inner.__name__ = func.__name__ return inner + + +# From dask.array.core, but changed to use nxp namespace +def _concatenate2(arrays, axes=None): + """Recursively concatenate nested lists of arrays along axes + + Each entry in axes corresponds to each level of the nested list. The + length of axes should correspond to the level of nesting of arrays. + If axes is an empty list or tuple, return arrays, or arrays[0] if + arrays is a list. + + >>> x = np.array([[1, 2], [3, 4]]) + >>> _concatenate2([x, x], axes=[0]) + array([[1, 2], + [3, 4], + [1, 2], + [3, 4]]) + + >>> _concatenate2([x, x], axes=[1]) + array([[1, 2, 1, 2], + [3, 4, 3, 4]]) + + >>> _concatenate2([[x, x], [x, x]], axes=[0, 1]) + array([[1, 2, 1, 2], + [3, 4, 3, 4], + [1, 2, 1, 2], + [3, 4, 3, 4]]) + + Supports Iterators + >>> _concatenate2(iter([x, x]), axes=[1]) + array([[1, 2, 1, 2], + [3, 4, 3, 4]]) + + Special Case + >>> _concatenate2([x, x], axes=()) + array([[1, 2], + [3, 4]]) + """ + if axes is None: + axes = [] + + if axes == (): + if isinstance(arrays, list): + return arrays[0] + else: + return arrays + + if isinstance(arrays, Iterator): + arrays = list(arrays) + if not isinstance(arrays, (list, tuple)): + return arrays + if len(axes) > 1: + arrays = [_concatenate2(a, axes=axes[1:]) for a in arrays] + concatenate = nxp.concat + if isinstance(arrays[0], dict): + # Handle concatenation of `dict`s, used as a replacement for structured + # arrays when that's not supported by the array library (e.g., CuPy). + keys = list(arrays[0].keys()) + assert all(list(a.keys()) == keys for a in arrays) + ret = dict() + for k in keys: + ret[k] = concatenate(list(a[k] for a in arrays), axis=axes[0]) + return ret + else: + return concatenate(arrays, axis=axes[0])