feat: add mapfilter decorator #551

Draft · @pfackeldey wants to merge 14 commits into main

Conversation

@pfackeldey (Collaborator) commented Nov 19, 2024

This PR adds a new decorator, mapfilter, that behaves similarly to dask_awkward.map_partitions but extends it with some features that are useful for e.g. HEP analyses:

  1. It can return multiple values and wraps them into dask collections, all with the same partitioning.
  2. It makes sure that all input dask collections have the same partitioning.
  3. The new needs argument can be used to touch additional columns.
  4. The meta argument allows mocking the return values. Combined with needs, this essentially allows skipping the tracing step.

dak.mapfilter

A decorated function is a single node in the compute graph when it returns a single value. For multiple return values it will be 2 nodes in the compute graph (one for the decorated function and one to select the return value). For nested return values the number of nodes corresponds to the nesting depth + 1.

An example of its usefulness is shown in the following:

import dask_awkward as dak
import awkward as ak
import numpy as np


ak_array = ak.zip({"foo": [1, 2, 3, 4], "bar": [1, 1, 1, 1]})
dak_array = dak.from_awkward(ak_array, 2)

class some: ...

@dak.mapfilter
def fun(x):
    y = x.foo + 1
    return y, (np.sum(y),), some(), ak.Array(np.ones(4))


# this is not possible with `dask_awkward.map_partitions`
y, y_sum, something, static = fun(dak_array)

# print the graph (HLG)
# We're seeing 3 nodes:
# 0. IO-layer
# 1. the decorated function `fun`
# 2. a "pick" layer that selection the correct value from the output tuple, here `y` is the 0-th element of all return values of `fun`
print(y.dask)
# >> HighLevelGraph with 3 layers.
# >> <dask.highlevelgraph.HighLevelGraph object at 0x106c881c0>
# >> 0. from-awkward-2669279954392e1535b365de1bfdef38
# >> 1. <dask-awkward.lib.core.ArgsKwargsPackedFunction ob-5f5b871945e30263d4972530f5679e79
# >> 2. <dask-awkward.lib.core.ArgsKwargsPackedFunction ob-5f5b871945e30263d4972530f5679e79-pick-0th

print(y_sum.compute())
# >> [array(5), array(9)]

# we can also track metadata per partition, e.g.:
print(something.compute())
# >> (<__main__.some at 0x10a7fa680>, <__main__.some at 0x10a7fa650>)

print(static.compute())
# >> <Array [1, 1, 1, 1, 1, 1, 1, 1] type='8 * float64'>

Untraceable functions

In a complex HEP analysis it may happen that some computation is not traceable (i.e. the user leaves the "awkward-array world"). For such cases, needs and meta exist:

import dask_awkward as dak
import awkward as ak
import numpy as np


ak_array = ak.zip({"pt": [10, 20, 30, 40], "eta": [1, 1, 1, 1]})
dak_array = dak.from_awkward(ak_array, 2)

def untraceable_fun(muons):
    # not traceable by ak.typetracer, because we switch to NumPy (non-awkward);
    # it needs the "pt" column from muons and returns a 1-element array (per partition)
    pt = ak.to_numpy(muons.pt)
    return ak.Array([np.sum(pt)])
  
dak.map_partitions(untraceable_fun, dak_array)
# >> TypeError: Converting from an nplike without known data to an nplike with known data is not supported


# This can be circumvented by mocking the output and specifying explicitly the columns that need to be read:
from functools import partial

@partial(
    dak.mapfilter,
    needs={"muons": ["pt"]},
    meta=ak.Array([0, 0]),
)
def untraceable_fun(muons):
    # not traceable by ak.typetracer, because we switch to NumPy (non-awkward);
    # it needs the "pt" column from muons and returns a 1-element array (per partition)
    pt = ak.to_numpy(muons.pt)
    return ak.Array([np.sum(pt)])
  
out = untraceable_fun(dak_array)
print(out.compute())
# >> <Array [30, 70] type='2 * int64'>

# check what needs to be read:
cols = next(iter(dak.report_necessary_columns(out).values()))
print(cols)
# >> frozenset({'pt'})

There are 3 cases that need to be considered:

  1. Typetracing is fine and no if conditions are present: dak.mapfilter works in the same way as dak.map_partitions.
  2. Typetracing is fine, but there is branched code (if conditions): dak.mapfilter can be used with needs to touch additional columns needed in the branched code (see the sketch after this list).
  3. Typetracing fails: there is not much one can do about it except skipping the tracing step. This can be done by providing needs with all needed columns and meta with the expected outputs of the function.
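
A minimal sketch of one reading of case 2 (the process function, the APPLY_WEIGHT flag, and the column names are illustrative assumptions, not from this PR): tracing only follows the branch that is actually taken, so a column used only in the other branch is declared via needs.

from functools import partial

import awkward as ak
import dask_awkward as dak

ak_array = ak.zip({"pt": [10, 20, 30, 40], "weight": [1.0, 1.0, 1.0, 1.0]})
dak_array = dak.from_awkward(ak_array, 2)

APPLY_WEIGHT = False  # imagine this flag changes between configurations

@partial(dak.mapfilter, needs={"events": ["weight"]})
def process(events):
    # branched code: with APPLY_WEIGHT = False the tracer never sees
    # `events.weight`, so the column is touched explicitly via `needs`
    if APPLY_WEIGHT:
        return events.pt * events.weight
    return events.pt

out = process(dak_array)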

dak.prerun

For complex untraceable functions, needs in particular may be cumbersome to provide by hand; for this, dak.prerun exists:

ak_array = ak.zip({"pt": [10, 20, 30, 40], "eta": [1, 1, 1, 1]})
dak_array = dak.from_awkward(ak_array, 2)

def untraceable_fun(muons):
    # not traceable by ak.typetracer, because we switch to NumPy (non-awkward);
    # it needs the "pt" column from muons and returns a 1-element array (per partition)
    pt = ak.to_numpy(muons.pt)
    return ak.Array([np.sum(pt)])

meta, needs = dak.prerun(untraceable_fun, muons=dak_array)
# >> UntraceableFunctionError: '<function untraceable_fun at 0x1056117e0>' is not traceable, an error occurred at line 7. 'dak.mapfilter' can circumvent this by providing 'needs' and 'meta' arguments to it.
#
# - 'needs': mapping where the keys point to input argument dask_awkward arrays and the values to columns that should be touched explicitly. The typetracing step could determine the following necessary columns until the exception occurred:
#
# needs={'muons': [('pt',)]}
#
# - 'meta': value(s) of what the wrapped function would return. For arrays, only the shape and type matter.

dak.prerun performs a typetracing step of a given function and tries to infer needs and meta from it. If the function is untraceable (as in this example), it reports all needs recorded up to the point where the tracing failed. This can be useful for providing needs by hand to an untraceable function.

In addition, providing meta skips running the type tracer through the computation of untraceable_fun entirely (similar to how map_partitions works), which can be beneficial if untraceable_fun is a computationally expensive operation (e.g. the evaluation of a neural network).
A useful trick here is to run meta, needs = dak.prerun(fun, *args, **kwargs) once, store meta and needs, and provide them to dak.mapfilter in consecutive runs in order to avoid multiple unnecessary and costly tracings, as sketched below.
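
A minimal sketch of that trick; expensive_fun here is a hypothetical, cheap stand-in (not from this PR) for a costly but traceable computation such as a neural-network evaluation:

from functools import partial

import awkward as ak
import dask_awkward as dak

dak_array = dak.from_awkward(ak.zip({"pt": [10, 20, 30, 40]}), 2)

def expensive_fun(muons):
    # stand-in for a costly (but traceable) per-partition computation
    return muons.pt * 2

# trace once (in a real workflow this is the costly step to avoid repeating) ...
meta, needs = dak.prerun(expensive_fun, muons=dak_array)

# ... store `meta` and `needs`, and reuse them in consecutive runs so that
# `dak.mapfilter` never has to trace `expensive_fun` again:
wrapped = partial(dak.mapfilter, needs=needs, meta=meta)(expensive_fun)
out = wrapped(dak_array)
print(out.compute())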

Other notes

Currently, only 2 types of dask collections can be returned: a dask_awkward.Array or a dask.bag.Bag. It would be nice if array-likes were correctly wrapped into dask.Arrays and dataframe-likes into dask.DataFrames; this is currently not supported.
Instead, it is recommended to wrap such objects into Python collections (which become dask Bags) or into awkward Arrays (which become dak.Arrays), as sketched below.
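
A minimal sketch of this recommendation, following the pattern of the first example; to_supported and numpy_result are illustrative names:

import awkward as ak
import dask_awkward as dak
import numpy as np

ak_array = ak.zip({"foo": [1, 2, 3, 4], "bar": [1, 1, 1, 1]})
dak_array = dak.from_awkward(ak_array, 2)

@dak.mapfilter
def to_supported(x):
    y = x.foo + 1
    numpy_result = np.ones(4)  # stand-in for some non-awkward, array-like result
    # returning `numpy_result` as-is would not become a dask.Array; instead ...
    as_awkward = ak.Array(numpy_result)  # ... wrap it in an awkward Array -> dak.Array
    as_python = (np.sum(y),)             # ... or use a Python container -> dask.bag.Bag
    return as_awkward, as_python

as_awkward, as_python = to_supported(dak_array)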


@codecov-commenter commented Nov 19, 2024


Codecov Report

Attention: Patch coverage is 70.67308% with 61 lines in your changes missing coverage. Please review.

Project coverage is 91.81%. Comparing base (8cb8994) to head (afb628c).
Report is 157 commits behind head on main.

Files with missing lines             Patch %   Lines
src/dask_awkward/lib/mapfilter.py    62.91%    56 Missing ⚠️
src/dask_awkward/lib/io/parquet.py   75.00%    3 Missing ⚠️
src/dask_awkward/lib/core.py         95.34%    2 Missing ⚠️


Additional details and impacted files
@@            Coverage Diff             @@
##             main     #551      +/-   ##
==========================================
- Coverage   93.06%   91.81%   -1.26%     
==========================================
  Files          23       23              
  Lines        3290     3557     +267     
==========================================
+ Hits         3062     3266     +204     
- Misses        228      291      +63     


@pfackeldey marked this pull request as draft November 19, 2024 18:36
@pfackeldey (Collaborator, Author) commented:

Hi @martindurant,
This PR is ready to be reviewed now.
I only need to bump the required awkward version once it is released 👍 (that's why I'm leaving it as a draft for now).

@@ -483,6 +483,7 @@ def __init__(
npartitions: int,
prefix: str | None = None,
storage_options: dict | None = None,
write_metadata: bool = False,
A collaborator commented on this diff:
This change (and lines below) leaked from another PR

@martindurant (Collaborator) commented:

# we can also track metadata per partition, e.g.:
print(something.compute())
# >> (<__main__.some at 0x10a7fa680>, <__main__.some at 0x10a7fa650>)

print(static.compute())
# >> <Array [1, 1, 1, 1, 1, 1, 1, 1] type='8 * float64'>

Quick question on usage: why are there two returns for the first one above, but only one for the second?

@pfackeldey (Collaborator, Author) commented:

(quoting @martindurant) Quick question on usage: why are there two returns for the first one above, but only one for the second?

The first one is a dask.Bag, while the second one is a dak.Array. In the dak.Array case the partitions are stacked by concatenating along the first axis, while a dask.Bag returns a list/tuple with 1 element per partition, as illustrated below.
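
For illustration, reusing fun and dak_array (2 partitions) from the first example above:

y, y_sum, something, static = fun(dak_array)

# dask.Bag: one entry per partition (2 partitions -> a 2-element list)
print(y_sum.compute())
# >> [array(5), array(9)]

# dak.Array: partitions concatenated along the first axis (2 x 4 elements -> 8)
print(static.compute())
# >> <Array [1, 1, 1, 1, 1, 1, 1, 1] type='8 * float64'>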

@pfackeldey (Collaborator, Author) commented:

I'd like to follow up on this PR after some work I'm doing related to #559; I found some synergy with mapfilter that would be nice to implement.
I'll update this PR afterwards.

3 participants