Skip to content

Commit

Permalink
Fuse multiple inputs (#346)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored Jan 16, 2024
1 parent 41a1225 commit 5c8e0c7
Show file tree
Hide file tree
Showing 5 changed files with 710 additions and 2 deletions.
133 changes: 132 additions & 1 deletion cubed/core/optimization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
from cubed.primitive.blockwise import can_fuse_pipelines, fuse
import networkx as nx

from cubed.primitive.blockwise import (
can_fuse_multiple_pipelines,
can_fuse_pipelines,
fuse,
fuse_multiple,
)


def simple_optimize_dag(dag):
Expand Down Expand Up @@ -56,3 +63,127 @@ def can_fuse(n):
dag.remove_node(op1)

return dag


# Module-level counter that backs gensym(); bumped once per call.
sym_counter = 0


def gensym(name="op"):
    """Return a fresh symbol of the form ``<name>-NNN`` (zero-padded counter)."""
    global sym_counter
    sym_counter += 1
    suffix = format(sym_counter, "03")
    return f"{name}-{suffix}"


def predecessors(dag, name):
    """Yield the immediate predecessors of *name*, repeated once per in-edge."""
    for edge in dag.in_edges(name):
        yield edge[0]


def predecessor_ops(dag, name):
    """Yield the op predecessors of an op node, reached via its input nodes."""
    for input in predecessors(dag, name):
        yield from predecessors(dag, input)


def is_fusable(node_dict):
    """Return True if a node can be fused.

    A node is fusable exactly when its attribute dict carries a
    "pipeline" entry.
    """
    return "pipeline" in node_dict


def can_fuse_predecessors(dag, name, *, max_total_nargs=4):
    """Return True if the op node *name* can be fused with its predecessor ops.

    Fusion is rejected when the node itself carries no pipeline, when no
    predecessor op carries a pipeline, or when fusing more than one
    predecessor op would give the fused function more than
    *max_total_nargs* arguments.
    """
    nodes = dict(dag.nodes(data=True))

    # if node itself can't be fused then there is nothing to fuse
    if not is_fusable(nodes[name]):
        return False

    # materialize once: the original code re-walked this generator (and the
    # underlying graph edges) up to four separate times below
    predecessor_op_names = list(predecessor_ops(dag, name))

    # if no predecessor ops can be fused then there is nothing to fuse
    if all(not is_fusable(nodes[pre]) for pre in predecessor_op_names):
        return False

    # if there is more than a single predecessor op, and the total number of args to
    # the fused function would be more than an allowed maximum, then don't fuse
    if len(predecessor_op_names) > 1:
        # an unfusable predecessor contributes a single arg (its materialized
        # output); a fusable one contributes one arg per input it reads
        total_nargs = sum(
            len(list(predecessors(dag, pre))) if is_fusable(nodes[pre]) else 1
            for pre in predecessor_op_names
        )
        if total_nargs > max_total_nargs:
            return False

    predecessor_pipelines = [
        nodes[pre]["pipeline"]
        for pre in predecessor_op_names
        if is_fusable(nodes[pre])
    ]
    return can_fuse_multiple_pipelines(nodes[name]["pipeline"], *predecessor_pipelines)


def fuse_predecessors(dag, name):
    """Fuse a node with its immediate predecessors.

    Returns a new dag in which the pipelines of the fusable predecessor ops
    of `name` have been fused into `name`'s pipeline, and the dag has been
    re-wired accordingly. If no fusion is possible the original dag is
    returned unchanged.
    """

    # if can't fuse then return dag unchanged
    if not can_fuse_predecessors(dag, name):
        return dag

    nodes = dict(dag.nodes(data=True))

    pipeline = nodes[name]["pipeline"]

    # if a predecessor op has no pipeline then just use None
    # (fuse_multiple treats a None entry as an already-materialized input)
    predecessor_pipelines = [
        nodes[pre]["pipeline"] if is_fusable(nodes[pre]) else None
        for pre in predecessor_ops(dag, name)
    ]

    # if a predecessor op has no func then use 1 for nargs
    predecessor_funcs_nargs = [
        len(list(predecessors(dag, pre))) if is_fusable(nodes[pre]) else 1
        for pre in predecessor_ops(dag, name)
    ]

    fused_pipeline = fuse_multiple(
        pipeline,
        *predecessor_pipelines,
        predecessor_funcs_nargs=predecessor_funcs_nargs,
    )

    # work on a copy so the original dag is left intact on every path
    fused_dag = dag.copy()
    fused_nodes = dict(fused_dag.nodes(data=True))

    fused_nodes[name]["pipeline"] = fused_pipeline

    # re-wire dag to remove predecessor nodes that have been fused

    # 1. update edges to change inputs
    # NOTE: reads traverse the original `dag` while mutations go to
    # `fused_dag`, so iteration is not disturbed by the edits
    for input in predecessors(dag, name):
        # each input node is assumed to have exactly one producing op
        pre = next(predecessors(dag, input))
        if not is_fusable(fused_nodes[pre]):
            # if a predecessor is marked as not fusable then don't change the edge
            continue
        fused_dag.remove_edge(input, name)
    for pre in predecessor_ops(dag, name):
        if not is_fusable(fused_nodes[pre]):
            # if a predecessor is marked as not fusable then don't change the edge
            continue
        for input in predecessors(dag, pre):
            fused_dag.add_edge(input, name)

    # 2. remove predecessor nodes with no successors
    # (ones with successors are needed by other nodes)
    for input in predecessors(dag, name):
        if fused_dag.out_degree(input) == 0:
            # also remove the op that produced this now-unused input
            for pre in list(predecessors(fused_dag, input)):
                fused_dag.remove_node(pre)
            fused_dag.remove_node(input)

    return fused_dag


def multiple_inputs_optimize_dag(dag):
    """Fuse multiple inputs.

    Visits every node in topological order and attempts to fuse it with
    its predecessors, threading the (possibly replaced) dag through each
    step.
    """
    ordering = list(nx.topological_sort(dag))
    for node_name in ordering:
        dag = fuse_predecessors(dag, node_name)
    return dag
90 changes: 89 additions & 1 deletion cubed/primitive/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from cubed.runtime.types import CubedPipeline
from cubed.storage.zarr import T_ZarrArray, lazy_empty
from cubed.types import T_Chunks, T_DType, T_Shape, T_Store
from cubed.utils import chunk_memory, get_item, to_chunksize
from cubed.utils import chunk_memory, get_item, split_into, to_chunksize
from cubed.vendor.dask.array.core import normalize_chunks
from cubed.vendor.dask.blockwise import _get_coord_mapping, _make_dims, lol_product
from cubed.vendor.dask.core import flatten
Expand Down Expand Up @@ -265,6 +265,16 @@ def can_fuse_pipelines(pipeline1: CubedPipeline, pipeline2: CubedPipeline) -> bo
return False


def can_fuse_multiple_pipelines(
    pipeline: CubedPipeline, *predecessor_pipelines: CubedPipeline
) -> bool:
    """Return True if *pipeline* can be fused with all its predecessor pipelines.

    Every pipeline involved must be a fuse candidate, and each predecessor
    must have the same number of tasks as *pipeline*.
    """
    candidates = (pipeline,) + predecessor_pipelines
    if not all(is_fuse_candidate(p) for p in candidates):
        return False
    return all(p.num_tasks == pipeline.num_tasks for p in predecessor_pipelines)


def fuse(pipeline1: CubedPipeline, pipeline2: CubedPipeline) -> CubedPipeline:
"""
Fuse two blockwise pipelines into a single pipeline, avoiding writing to (or reading from) the target of the first pipeline.
Expand Down Expand Up @@ -304,6 +314,84 @@ def fused_func(*args):
)


def fuse_multiple(
    pipeline: CubedPipeline,
    *predecessor_pipelines: CubedPipeline,
    predecessor_funcs_nargs=None,
) -> CubedPipeline:
    """
    Fuse a blockwise pipeline and its predecessors into a single pipeline, avoiding writing to (or reading from) the targets of the predecessor pipelines.

    A ``None`` entry in *predecessor_pipelines* stands for a predecessor that
    is not being fused: its output is passed through as a plain argument.
    *predecessor_funcs_nargs* gives, per predecessor, how many arguments its
    function consumes (1 for a ``None`` predecessor).
    """

    # all fused pipelines must run the same number of tasks
    assert all(
        pipeline.num_tasks == p.num_tasks
        for p in predecessor_pipelines
        if p is not None
    )

    mappable = pipeline.mappable

    def apply_pipeline_block_func(pipeline, arg):
        # a None predecessor contributes its arg unchanged (as a 1-tuple,
        # so the flattening below treats both cases uniformly)
        if pipeline is None:
            return (arg,)
        return pipeline.config.block_function(arg)

    def fused_blockwise_func(out_key):
        # this will change when multiple outputs are supported
        args = pipeline.config.block_function(out_key)
        # NOTE(review): zip assumes args has one entry per predecessor
        # pipeline — presumably guaranteed by the caller; confirm
        # flatten one level of args as the fused_func adds back grouping structure
        func_args = tuple(
            item
            for p, a in zip(predecessor_pipelines, args)
            for item in apply_pipeline_block_func(p, a)
        )
        return func_args

    def apply_pipeline_func(pipeline, *args):
        # a None predecessor passes its single argument through untouched
        if pipeline is None:
            return args[0]
        return pipeline.config.function(*args)

    def fused_func(*args):
        # split all args to the fused function into groups, one for each predecessor function
        split_args = split_into(args, predecessor_funcs_nargs)
        func_args = [
            apply_pipeline_func(p, *a)
            for p, a in zip(predecessor_pipelines, split_args)
        ]
        return pipeline.config.function(*func_args)

    # merge read proxies from every fused pipeline; writes go only to the
    # final pipeline's target
    read_proxies = dict(pipeline.config.reads_map)
    for p in predecessor_pipelines:
        if p is not None:
            read_proxies.update(p.config.reads_map)
    write_proxy = pipeline.config.write
    spec = BlockwiseSpec(fused_blockwise_func, fused_func, read_proxies, write_proxy)

    target_array = pipeline.target_array
    # NOTE(review): if every predecessor is None the starred expansion is
    # empty and max() would raise TypeError — presumably callers guarantee
    # at least one fusable predecessor; confirm
    projected_mem = max(
        pipeline.projected_mem,
        *(p.projected_mem for p in predecessor_pipelines if p is not None),
    )
    reserved_mem = max(
        pipeline.reserved_mem,
        *(p.reserved_mem for p in predecessor_pipelines if p is not None),
    )
    num_tasks = pipeline.num_tasks

    return CubedPipeline(
        apply_blockwise,
        gensym("fused_apply_blockwise"),
        mappable,
        spec,
        target_array,
        projected_mem,
        reserved_mem,
        num_tasks,
        None,
    )


# blockwise functions


Expand Down
Loading

0 comments on commit 5c8e0c7

Please sign in to comment.