Skip to content

Commit

Permalink
Merge pull request #366 from dask-contrib/agoose77/fix-pickle-placeholder
Browse files Browse the repository at this point in the history

fix: support placeholder arrays in pickling
  • Loading branch information
agoose77 authored Sep 18, 2023
2 parents b60a25e + c7165e7 commit 3125bf4
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 35 deletions.
6 changes: 5 additions & 1 deletion src/dask_awkward/lib/_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from __future__ import annotations

from typing import Final

from awkward.forms.form import Form

LIST_KEY: Final = "__list__"


def set_form_keys(form: Form, *, key: str) -> Form:
"""Recursive function to apply key labels to `form`.
Expand Down Expand Up @@ -35,7 +39,7 @@ def set_form_keys(form: Form, *, key: str) -> Form:
# touched the offsets and not the data buffer for this kind of
# identified form; keep recursing
elif form.is_list:
form.form_key = f"{key}.__list__"
form.form_key = f"{key}.{LIST_KEY}"
set_form_keys(form.content, key=key)

# NumPy like array is easy
Expand Down
78 changes: 47 additions & 31 deletions src/dask_awkward/lib/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dask.local import get_sync

from dask_awkward.layers import AwkwardBlockwiseLayer, AwkwardInputLayer
from dask_awkward.lib._utils import LIST_KEY

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -419,23 +420,31 @@ def _get_column_reports(dsk: HighLevelGraph) -> dict[str, Any]:

def _necessary_columns(dsk: HighLevelGraph) -> dict[str, list[str]]:
    """Pair layer names with lists of necessary columns.

    Parameters
    ----------
    dsk : HighLevelGraph
        Task graph whose input layers are inspected for touched buffers.

    Returns
    -------
    dict[str, list[str]]
        Mapping of input-layer name to the columns that the graph touches.
        A list whose offsets (but no specific child) were touched is
        recorded as a ``"{path}.*"`` wildcard entry.
    """
    layer_to_columns = {}
    for name, report in _get_column_reports(dsk).items():
        # Drop the ``None`` placeholders from the touch report.
        touched_data_keys = {_ for _ in report.data_touched if _ is not None}

        necessary_columns = []
        for key in sorted(touched_data_keys):
            # The bare layer name is not a column reference.
            if key == name:
                continue

            layer, column = key.split(".", 1)
            # Only keep keys that belong to this layer.
            if layer != name:
                continue

            # List offsets are tagged as {key}.{LIST_KEY}. This routine
            # resolves _columns_, so we use a wildcard to indicate that we
            # want to load *any* child column of this list. If the list
            # contains no records, the wildcard is later resolved (in
            # ``_prune_wildcards``) to the list column itself.
            if column.endswith(LIST_KEY):
                list_parent_path = column[: -(len(LIST_KEY) + 1)].rstrip(".")
                # Skip the wildcard when the parent column is already being
                # loaded outright (sorted order guarantees the plain column
                # key is seen before its list-offset key).
                if list_parent_path not in necessary_columns:
                    necessary_columns.append(f"{list_parent_path}.*")
            else:
                necessary_columns.append(column)
        layer_to_columns[name] = necessary_columns
    return layer_to_columns


def _prune_wildcards(columns: list[str], meta: AwkwardArray) -> list[str]:
Expand Down Expand Up @@ -478,33 +487,40 @@ def _prune_wildcards(columns: list[str], meta: AwkwardArray) -> list[str]:

good_columns: list[str] = []
wildcard_columns: list[str] = []
for col in columns:
if ".*" in col:
wildcard_columns.append(col)
for column in columns:
if column.endswith(".*"):
wildcard_columns.append(column)
else:
good_columns.append(col)
good_columns.append(column)

for col in wildcard_columns:
for column in wildcard_columns:
definite_column = column[:-2]
# each time we meet a wildcard column we need to start back
# with the original meta array.
imeta = meta
colsplit = col.split(".")[:-1]
parts = list(reversed(colsplit))
while parts:
part = parts.pop()
reverse_column_parts = [*definite_column.split(".")]
reverse_column_parts.reverse()

while reverse_column_parts:
part = reverse_column_parts.pop()
# for unnamed roots part may be an empty string, so we
# need this if statement.
if part:
imeta = imeta[part]
# The given wildcard column contains no sub-columns, so load
# the column itself
if not imeta.fields:
good_columns.append(definite_column)

for field in imeta.fields:
wholecol = f"{col[:-2]}.{field}"
if wholecol in good_columns:
break
# Otherwise, prefer a column that we already need to load
else:
if imeta.fields:
good_columns.append(f"{col[:-2]}.{imeta.fields[0]}")
for field in imeta.fields:
field_column = f"{definite_column}.{field}"
if field_column in good_columns:
break
# Or, pick an arbitrary (first) column if no other fields are yet
# required
else:
good_columns.append(col[:-2])
good_columns.append(f"{definite_column}.{imeta.fields[0]}")

return good_columns
14 changes: 11 additions & 3 deletions src/dask_awkward/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,17 @@

__all__ = ("plugin",)

import pickle
from pickle import PickleBuffer

import awkward as ak
from awkward.typetracer import PlaceholderArray


def maybe_make_pickle_buffer(buffer) -> PlaceholderArray | PickleBuffer:
    """Prepare *buffer* for out-of-band pickling.

    ``PlaceholderArray`` objects carry no materialised data, so they are
    passed through untouched; any other buffer is wrapped in a
    ``PickleBuffer`` so that pickle protocol 5 can serialise it without
    copying.
    """
    return buffer if isinstance(buffer, PlaceholderArray) else PickleBuffer(buffer)


def pickle_record(record: ak.Record, protocol: int) -> tuple:
Expand All @@ -18,7 +26,7 @@ def pickle_record(record: ak.Record, protocol: int) -> tuple:

# For pickle >= 5, we can avoid copying the buffers
if protocol >= 5:
container = {k: pickle.PickleBuffer(v) for k, v in container.items()}
container = {k: maybe_make_pickle_buffer(v) for k, v in container.items()}

if record.behavior is ak.behavior:
behavior = None
Expand All @@ -43,7 +51,7 @@ def pickle_array(array: ak.Array, protocol: int) -> tuple:

# For pickle >= 5, we can avoid copying the buffers
if protocol >= 5:
container = {k: pickle.PickleBuffer(v) for k, v in container.items()}
container = {k: maybe_make_pickle_buffer(v) for k, v in container.items()}

if array.behavior is ak.behavior:
behavior = None
Expand Down

0 comments on commit 3125bf4

Please sign in to comment.