diff --git a/src/dask_awkward/lib/optimize.py b/src/dask_awkward/lib/optimize.py index 80e49152..f52067f6 100644 --- a/src/dask_awkward/lib/optimize.py +++ b/src/dask_awkward/lib/optimize.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any import dask.config +from awkward.typetracer import touch_data from dask.blockwise import fuse_roots, optimize_blockwise from dask.core import flatten from dask.highlevelgraph import HighLevelGraph @@ -107,12 +108,12 @@ def optimize_columns(dsk: HighLevelGraph) -> HighLevelGraph: Parameters ---------- dsk : HighLevelGraph - Original high level dask graph + Task graph to optimize. Returns ------- HighLevelGraph - New dask graph with a modified ``AwkwardInputLayer``. + New, optimized task graph with column-projected ``AwkwardInputLayer``. """ layers = dsk.layers.copy() # type: ignore @@ -128,111 +129,40 @@ def optimize_columns(dsk: HighLevelGraph) -> HighLevelGraph: return HighLevelGraph(layers, deps) -def _projectable_input_layer_names(dsk: HighLevelGraph) -> list[str]: - """Get list of column-projectable AwkwardInputLayer names. - - Parameters - ---------- - dsk : HighLevelGraph - Task graph of interest - - Returns - ------- - list[str] - Names of the AwkwardInputLayers in the graph that are - column-projectable. - - """ - return [ - n - for n, v in dsk.layers.items() - if isinstance(v, AwkwardInputLayer) and hasattr(v.io_func, "project_columns") - # following condition means dep/pickled layers cannot be optimised - and hasattr(v, "_meta") - ] - - -def _layers_with_annotation(dsk: HighLevelGraph, key: str) -> list[str]: - return [n for n, v in dsk.layers.items() if (v.annotations or {}).get(key)] - - -def _ak_output_layer_names(dsk: HighLevelGraph) -> list[str]: - """Get a list output layer names. - - Output layer names are annotated with 'ak_output'. +def rewrite_layer_chains(dsk: HighLevelGraph, keys: Any) -> HighLevelGraph: + """Smush chains of blockwise layers into a single layer. 
+ + The logic here identifies chains by popping layers (in arbitrary + order) from a set of all layers in the task graph and walking + through the dependencies (parent layers) and dependents (child + layers). If a multi-layer chain is discovered we compress it into + a single layer with the second loop below (for chain in chains; + that step rewrites the graph). In the chain-building logic, if a + layer exists in the `keys` argument (the keys necessary for the + compute that we are optimizing for), we short-circuit the logic to + ensure we do not chain layers that contain a necessary key inside + (these layers are called `required_layers` below). Parameters ---------- dsk : HighLevelGraph - Graph of interest. + Task graph to optimize. + keys : Any + Keys that are requested by the compute that is being + optimized. Returns ------- - list[str] - Names of the output layers. + HighLevelGraph + New, optimized task graph. """ - return _layers_with_annotation(dsk, "ak_output") - - -def _opt_touch_all_layer_names(dsk: HighLevelGraph) -> list[str]: - return [n for n, v in dsk.layers.items() if hasattr(v, "_opt_touch_all")] - # return _layers_with_annotation(dsk, "ak_touch_all") - - -def _has_projectable_awkward_io_layer(dsk: HighLevelGraph) -> bool: - """Check if a graph at least one AwkwardInputLayer that is project-able.""" - for _, v in dsk.layers.items(): - if isinstance(v, AwkwardInputLayer) and hasattr(v.io_func, "project_columns"): - return True - return False - - -def _touch_all_data(*args, **kwargs): - """Mock writing an ak.Array to disk by touching data buffers.""" - import awkward as ak - - for arg in args + tuple(kwargs.values()): - ak.typetracer.touch_data(arg) - - -def _mock_output(layer): - """Update a layer to run the _touch_all_data.""" - assert len(layer.dsk) == 1 - - new_layer = copy.deepcopy(layer) - mp = new_layer.dsk.copy() - for k in iter(mp.keys()): - mp[k] = (_touch_all_data,) + mp[k][1:] - new_layer.dsk = mp - return new_layer - - -def 
_touch_and_call_fn(fn, *args, **kwargs): - _touch_all_data(*args, **kwargs) - return fn(*args, **kwargs) - - -def _touch_and_call(layer): - assert len(layer.dsk) == 1 - - new_layer = copy.deepcopy(layer) - mp = new_layer.dsk.copy() - for k in iter(mp.keys()): - mp[k] = (_touch_and_call_fn,) + mp[k] - new_layer.dsk = mp - return new_layer - - -def rewrite_layer_chains(dsk: HighLevelGraph, keys: Any) -> HighLevelGraph: # dask.optimization.fuse_liner for blockwise layers import copy chains = [] deps = copy.copy(dsk.dependencies) - # TODO: add some comments to the chaining algorithm w.r.t. when we - # use it and when we don't. required_layers = {k[0] for k in keys} layers = {} # find chains; each chain list is at least two keys long @@ -321,6 +251,100 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Any) -> HighLevelGraph: return HighLevelGraph(layers, deps) +def _projectable_input_layer_names(dsk: HighLevelGraph) -> list[str]: + """Get list of column-projectable AwkwardInputLayer names. + + Parameters + ---------- + dsk : HighLevelGraph + Task graph of interest + + Returns + ------- + list[str] + Names of the AwkwardInputLayers in the graph that are + column-projectable. + + """ + return [ + n + for n, v in dsk.layers.items() + if isinstance(v, AwkwardInputLayer) and hasattr(v.io_func, "project_columns") + # following condition means dep/pickled layers cannot be optimised + and hasattr(v, "_meta") + ] + + +def _layers_with_annotation(dsk: HighLevelGraph, key: str) -> list[str]: + return [n for n, v in dsk.layers.items() if (v.annotations or {}).get(key)] + + +def _ak_output_layer_names(dsk: HighLevelGraph) -> list[str]: + """Get a list of output layer names. + + Output layer names are annotated with 'ak_output'. + + Parameters + ---------- + dsk : HighLevelGraph + Graph of interest. + + Returns + ------- + list[str] + Names of the output layers. 
+ + """ + return _layers_with_annotation(dsk, "ak_output") + + +def _opt_touch_all_layer_names(dsk: HighLevelGraph) -> list[str]: + return [n for n, v in dsk.layers.items() if hasattr(v, "_opt_touch_all")] + # return _layers_with_annotation(dsk, "ak_touch_all") + + +def _has_projectable_awkward_io_layer(dsk: HighLevelGraph) -> bool: + """Check if a graph has at least one AwkwardInputLayer that is project-able.""" + for _, v in dsk.layers.items(): + if isinstance(v, AwkwardInputLayer) and hasattr(v.io_func, "project_columns"): + return True + return False + + +def _touch_all_data(*args, **kwargs): + """Mock writing an ak.Array to disk by touching data buffers.""" + for arg in args + tuple(kwargs.values()): + touch_data(arg) + + +def _mock_output(layer): + """Update a layer to run _touch_all_data.""" + assert len(layer.dsk) == 1 + + new_layer = copy.deepcopy(layer) + mp = new_layer.dsk.copy() + for k in iter(mp.keys()): + mp[k] = (_touch_all_data,) + mp[k][1:] + new_layer.dsk = mp + return new_layer + + +def _touch_and_call_fn(fn, *args, **kwargs): + _touch_all_data(*args, **kwargs) + return fn(*args, **kwargs) + + +def _touch_and_call(layer): + assert len(layer.dsk) == 1 + + new_layer = copy.deepcopy(layer) + mp = new_layer.dsk.copy() + for k in iter(mp.keys()): + mp[k] = (_touch_and_call_fn,) + mp[k] + new_layer.dsk = mp + return new_layer + + + def _recursive_replace(args, layer, parent, indices): args2 = [] for arg in args: @@ -398,7 +422,7 @@ def _get_column_reports(dsk: HighLevelGraph) -> dict[str, Any]: results = get_sync(hlg, leaf_layers_keys) for out in results: if isinstance(out, (ak.Array, ak.Record)): - ak.typetracer.touch_data(out) + touch_data(out) except Exception as err: on_fail = dask.config.get("awkward.optimization.on-fail") # this is the default, throw a warning but skip the optimization.