From e564811dfb1f0fe6ef1f02d96e39c57753e9a566 Mon Sep 17 00:00:00 2001
From: Angus Hollands
Date: Thu, 14 Sep 2023 18:53:53 +0100
Subject: [PATCH 1/4] refactor: clarify logic

---
 src/dask_awkward/lib/_utils.py   |  6 ++-
 src/dask_awkward/lib/optimize.py | 79 +++++++++++++++++++-------------
 2 files changed, 53 insertions(+), 32 deletions(-)

diff --git a/src/dask_awkward/lib/_utils.py b/src/dask_awkward/lib/_utils.py
index 30d433f4..195a3fe0 100644
--- a/src/dask_awkward/lib/_utils.py
+++ b/src/dask_awkward/lib/_utils.py
@@ -1,7 +1,11 @@
 from __future__ import annotations
 
+from typing import Final
+
 from awkward.forms.form import Form
 
+LIST_KEY: Final = "__list__"
+
 
 def set_form_keys(form: Form, *, key: str) -> Form:
     """Recursive function to apply key labels to `form`.
@@ -35,7 +39,7 @@ def set_form_keys(form: Form, *, key: str) -> Form:
     # touched the offsets and not the data buffer for this kind of
     # identified form; keep recursing
     elif form.is_list:
-        form.form_key = f"{key}.__list__"
+        form.form_key = f"{key}.{LIST_KEY}"
         set_form_keys(form.content, key=key)
 
     # NumPy like array is easy
diff --git a/src/dask_awkward/lib/optimize.py b/src/dask_awkward/lib/optimize.py
index 70435c0b..3158d4ba 100644
--- a/src/dask_awkward/lib/optimize.py
+++ b/src/dask_awkward/lib/optimize.py
@@ -13,6 +13,7 @@
 from dask.local import get_sync
 
 from dask_awkward.layers import AwkwardBlockwiseLayer, AwkwardInputLayer
+from dask_awkward.lib._utils import LIST_KEY
 
 log = logging.getLogger(__name__)
 
@@ -419,23 +420,31 @@ def _get_column_reports(dsk: HighLevelGraph) -> dict[str, Any]:
 
 def _necessary_columns(dsk: HighLevelGraph) -> dict[str, list[str]]:
     """Pair layer names with lists of necessary columns."""
-    kv = {}
+    layer_to_columns = {}
     for name, report in _get_column_reports(dsk).items():
-        cols = {_ for _ in report.data_touched if _ is not None}
+        touched_data_keys = {_ for _ in report.data_touched if _ is not None}
+        print(set(report.shape_touched), set(report.data_touched))
-        select = []
-        for col in sorted(cols):
-            if col == name:
+        necessary_columns = []
+        for key in sorted(touched_data_keys):
+            if key == name:
                 continue
-            n, c = col.split(".", 1)
-            if n == name:
-                if c.endswith("__list__"):
-                    cnew = c[:-9].rstrip(".")
-                    if cnew not in select:
-                        select.append(f"{cnew}.*")
-                else:
-                    select.append(c)
-        kv[name] = select
-    return kv
+
+            layer, column = key.split(".", 1)
+            if layer != name:
+                continue
+
+            # List offsets are tagged as {key}.{LIST_KEY}. This routine resolves
+            # _columns_, so we use a wildcard to indicate that we want to load
+            # *any* child column of this list. If the list contains no records,
+            # then we load the list column itself.
+            if column.endswith(LIST_KEY):
+                list_parent_path = column[: -(len(LIST_KEY) + 1)].rstrip(".")
+                if list_parent_path not in necessary_columns:
+                    necessary_columns.append(f"{list_parent_path}.*")
+            else:
+                necessary_columns.append(column)
+        layer_to_columns[name] = necessary_columns
+    return layer_to_columns
 
 
 def _prune_wildcards(columns: list[str], meta: AwkwardArray) -> list[str]:
@@ -478,33 +487,41 @@ def _prune_wildcards(columns: list[str], meta: AwkwardArray) -> list[str]:
 
     good_columns: list[str] = []
     wildcard_columns: list[str] = []
-    for col in columns:
-        if ".*" in col:
-            wildcard_columns.append(col)
+    for column in columns:
+        if column.endswith(".*"):
+            wildcard_columns.append(column)
         else:
-            good_columns.append(col)
+            good_columns.append(column)
 
-    for col in wildcard_columns:
+    for column in wildcard_columns:
         # each time we meet a wildcard column we need to start back
         # with the original meta array.
         imeta = meta
-        colsplit = col.split(".")[:-1]
-        parts = list(reversed(colsplit))
-        while parts:
-            part = parts.pop()
+        reverse_column_parts = [*column.split(".")[:-1]]
+        reverse_column_parts.reverse()
+
+        while reverse_column_parts:
+            part = reverse_column_parts.pop()
             # for unnamed roots part may be an empty string, so we
             # need this if statement.
             if part:
                 imeta = imeta[part]
 
-        for field in imeta.fields:
-            wholecol = f"{col[:-2]}.{field}"
-            if wholecol in good_columns:
-                break
+        definite_column = column[:-2]
+        # The given wildcard column contains no sub-columns, so load
+        # the column itself
+        if not imeta.fields:
+            good_columns.append(definite_column)
+
+        # Otherwise, prefer a column that we already need to load
         else:
-            if imeta.fields:
-                good_columns.append(f"{col[:-2]}.{imeta.fields[0]}")
+            for field in imeta.fields:
+                field_column = f"{definite_column}.{field}"
+                if field_column in good_columns:
+                    break
+            # Or, pick an arbitrary (first) column if no other fields are yet
+            # required
             else:
-                good_columns.append(col[:-2])
+                good_columns.append(f"{definite_column}.{imeta.fields[0]}")
 
     return good_columns

From 5ac5903227ebfb18e04f5743a2852b29356ac53d Mon Sep 17 00:00:00 2001
From: Angus Hollands
Date: Thu, 14 Sep 2023 18:54:14 +0100
Subject: [PATCH 2/4] fix: support pickle-5 for placeholders

---
 src/dask_awkward/pickle.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/src/dask_awkward/pickle.py b/src/dask_awkward/pickle.py
index c57f9ce6..06fee32a 100644
--- a/src/dask_awkward/pickle.py
+++ b/src/dask_awkward/pickle.py
@@ -2,9 +2,17 @@
 
 __all__ = ("plugin",)
 
-import pickle
+from pickle import PickleBuffer
 
 import awkward as ak
+from awkward.typetracer import PlaceholderArray
+
+
+def maybe_make_pickle_buffer(buffer) -> PlaceholderArray | PickleBuffer:
+    if isinstance(buffer, PlaceholderArray):
+        return buffer
+    else:
+        return PickleBuffer(buffer)
 
 
 def pickle_record(record: ak.Record, protocol: int) -> tuple:
@@ -18,7 +26,7 @@ def pickle_record(record: ak.Record, protocol: int) -> tuple:
 
     # For pickle >= 5, we can avoid copying the buffers
     if protocol >= 5:
-        container = {k: pickle.PickleBuffer(v) for k, v in container.items()}
+        container = {k: maybe_make_pickle_buffer(v) for k, v in container.items()}
 
     if record.behavior is ak.behavior:
         behavior = None
@@ -43,7 +51,7 @@ def pickle_array(array: ak.Array, protocol: int) -> tuple:
 
     # For pickle >= 5, we can avoid copying the buffers
    if protocol >= 5:
-        container = {k: pickle.PickleBuffer(v) for k, v in container.items()}
+        container = {k: maybe_make_pickle_buffer(v) for k, v in container.items()}
 
     if array.behavior is ak.behavior:
         behavior = None

From 1e54a19768f6ef738a57820cd4bcf292a75db32f Mon Sep 17 00:00:00 2001
From: Angus Hollands
Date: Thu, 14 Sep 2023 18:54:59 +0100
Subject: [PATCH 3/4] refactor: clarify logic

---
 src/dask_awkward/lib/optimize.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/dask_awkward/lib/optimize.py b/src/dask_awkward/lib/optimize.py
index 3158d4ba..48166765 100644
--- a/src/dask_awkward/lib/optimize.py
+++ b/src/dask_awkward/lib/optimize.py
@@ -494,10 +494,11 @@ def _prune_wildcards(columns: list[str], meta: AwkwardArray) -> list[str]:
             good_columns.append(column)
 
     for column in wildcard_columns:
+        definite_column = column[:-2]
         # each time we meet a wildcard column we need to start back
         # with the original meta array.
         imeta = meta
-        reverse_column_parts = [*column.split(".")[:-1]]
+        reverse_column_parts = [*definite_column.split(".")]
         reverse_column_parts.reverse()
 
         while reverse_column_parts:
             part = reverse_column_parts.pop()
             # for unnamed roots part may be an empty string, so we
             # need this if statement.
             if part:
                 imeta = imeta[part]
-
-        definite_column = column[:-2]
         # The given wildcard column contains no sub-columns, so load
         # the column itself
         if not imeta.fields:
             good_columns.append(definite_column)

From bac88eda16df1a53950bc161023da33177af6aa4 Mon Sep 17 00:00:00 2001
From: Angus Hollands
Date: Thu, 14 Sep 2023 18:57:38 +0100
Subject: [PATCH 4/4] chore: drop debug print

---
 src/dask_awkward/lib/optimize.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/dask_awkward/lib/optimize.py b/src/dask_awkward/lib/optimize.py
index 48166765..1b8707fb 100644
--- a/src/dask_awkward/lib/optimize.py
+++ b/src/dask_awkward/lib/optimize.py
@@ -423,7 +423,7 @@ def _necessary_columns(dsk: HighLevelGraph) -> dict[str, list[str]]:
     layer_to_columns = {}
     for name, report in _get_column_reports(dsk).items():
         touched_data_keys = {_ for _ in report.data_touched if _ is not None}
-        print(set(report.shape_touched), set(report.data_touched))
+
         necessary_columns = []
         for key in sorted(touched_data_keys):
             if key == name:
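
Below is a small standalone sketch of the buffer handling that patch 2 introduces. It is not part of the patches: the helper is restated so it can be run on its own, and it assumes only NumPy, Python >= 3.8 (pickle protocol 5), and awkward's typetracer PlaceholderArray. Real buffers become zero-copy PickleBuffer objects for out-of-band serialisation, while placeholder buffers (which carry no data) pass through untouched.

from pickle import PickleBuffer

import numpy as np
from awkward.typetracer import PlaceholderArray


def maybe_make_pickle_buffer(buffer):
    # Placeholders have no backing data, so there is nothing to expose
    # out-of-band; anything else is wrapped without copying.
    if isinstance(buffer, PlaceholderArray):
        return buffer
    return PickleBuffer(buffer)


offsets = np.arange(5, dtype=np.int64)
wrapped = maybe_make_pickle_buffer(offsets)
print(type(wrapped).__name__)                  # PickleBuffer
print(wrapped.raw().nbytes == offsets.nbytes)  # True: same memory, no copy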
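
And a sketch of the column bookkeeping from patches 1 and 3, using invented form keys ("layer0", "points", ...) rather than a real dask-awkward graph. A touched key ending in "__list__" means only the list's offsets were read, so it is recorded as a "parent.*" wildcard that the pruning step later narrows to a concrete column.

LIST_KEY = "__list__"
layer_name = "layer0"
# Hypothetical touched buffers reported by the typetracer for one input layer.
touched = {"layer0", "layer0.points.x", f"layer0.points.{LIST_KEY}"}

necessary_columns = []
for key in sorted(touched):
    if key == layer_name:
        continue
    layer, column = key.split(".", 1)
    if layer != layer_name:
        continue
    if column.endswith(LIST_KEY):
        # Only the offsets of "points" were touched: any child column will do.
        parent = column[: -(len(LIST_KEY) + 1)].rstrip(".")
        if parent not in necessary_columns:
            necessary_columns.append(f"{parent}.*")
    else:
        necessary_columns.append(column)

print(necessary_columns)  # ['points.*', 'points.x']

Because "points.x" is already required, the wildcard-pruning pass would collapse "points.*" onto it rather than reading an extra field; only a record-free list would be loaded as the bare "points" column itself.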