Skip to content

Commit

Permalink
feat-computable-report-for-read-failure
Browse files Browse the repository at this point in the history
  • Loading branch information
douglasdavis committed Nov 14, 2023
1 parent fa07194 commit d8e3fae
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 17 deletions.
77 changes: 65 additions & 12 deletions src/dask_awkward/lib/io/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import math
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, cast, overload

import awkward as ak
import numpy as np
Expand All @@ -16,14 +16,6 @@
from dask.utils import funcname, is_integer, parse_bytes
from fsspec.utils import infer_compression

try:
from distributed.queues import Queue
from distributed.worker import get_worker
except ImportError:
Queue = None
get_worker = None


from dask_awkward.layers.layers import (
AwkwardBlockwiseLayer,
AwkwardInputLayer,
Expand All @@ -40,6 +32,7 @@
new_array_object,
typetracer_array,
)
from dask_awkward.utils import first, second

if TYPE_CHECKING:
from dask.array.core import Array as DaskArray
Expand Down Expand Up @@ -473,6 +466,28 @@ def __call__(self, packed_arg):
)


@dataclass
class ReadFailure:
args: tuple[Any, ...] | None
kwargs: dict[str, Any] | None
exception: Any | None
error: Any | None

def as_array(self):
return ak.Array(
[
{
"args": str(self.args),
"kwargs": str(self.kwargs),
"exception": str(self.exception.__name__)
if self.exception is not None
else "None",
"error": str(self.error),
}
]
)


def return_empty_on_raise(
fn: Callable,
allowed_exceptions: tuple[type[BaseException], ...],
Expand All @@ -481,8 +496,9 @@ def return_empty_on_raise(
@functools.wraps(fn)
def wrapped(*args, **kwargs):
try:
return fn(*args, **kwargs)
return fn(*args, **kwargs), ReadFailure(None, None, None, None).as_array()
except allowed_exceptions as err:
rf = ReadFailure(args, kwargs, type(err), err)
logmsg = (
"%s call failed with args %s and kwargs %s; empty array returned. %s"
% (
Expand All @@ -493,11 +509,43 @@ def wrapped(*args, **kwargs):
)
)
logger.info(logmsg)
return fn.mock_empty(backend)
return fn.mock_empty(backend), rf.as_array()

return wrapped


# Typing-only signature for calls that leave ``empty_on_raise`` and
# ``empty_backend`` unset (both ``None``): annotated as returning a single
# ``Array``.
@overload
def from_map(
    func: Callable,
    *iterables: Iterable,
    args: tuple[Any, ...] | None = None,
    label: str | None = None,
    token: str | None = None,
    divisions: tuple[int, ...] | tuple[None, ...] | None = None,
    meta: ak.Array | None = None,
    empty_on_raise: None = None,
    empty_backend: None = None,
    **kwargs: Any,
) -> Array:
    ...


# Typing-only signature for calls that pass both ``empty_on_raise`` (a tuple
# of exception types) and ``empty_backend``; here both are required
# keyword-only arguments and the call is annotated as returning a pair of
# ``Array`` objects.
@overload
def from_map(
    func: Callable,
    *iterables: Iterable,
    empty_on_raise: tuple[type[BaseException], ...],
    empty_backend: BackendT,
    args: tuple[Any, ...] | None = None,
    label: str | None = None,
    token: str | None = None,
    divisions: tuple[int, ...] | tuple[None, ...] | None = None,
    meta: ak.Array | None = None,
    **kwargs: Any,
) -> tuple[Array, Array]:
    ...


def from_map(
func: Callable,
*iterables: Iterable,
Expand All @@ -509,7 +557,7 @@ def from_map(
empty_on_raise: tuple[type[BaseException], ...] | None = None,
empty_backend: BackendT | None = None,
**kwargs: Any,
) -> Array:
) -> Array | tuple[Array, Array]:
"""Create an Array collection from a custom mapping.
Parameters
Expand Down Expand Up @@ -639,6 +687,11 @@ def from_map(
else:
result = new_array_object(hlg, name, meta=array_meta, npartitions=len(inputs))

if empty_on_raise and empty_backend:
res = result.map_partitions(first, meta=array_meta)
rep = result.map_partitions(second, meta=empty_typetracer())
return res, rep

return result


Expand Down
6 changes: 6 additions & 0 deletions src/dask_awkward/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,9 @@ def first(seq: Iterable[T]) -> T:
"""
return next(iter(seq))


def second(seq: Iterable[T]) -> T:
    """Return the second item produced by iterating over ``seq``.

    The first item is consumed and discarded. ``StopIteration`` propagates
    from ``next`` when ``seq`` yields fewer than two items.
    """
    items = iter(seq)
    # Tuple elements evaluate left to right: skip one item, keep the next.
    _, wanted = next(items), next(items)
    return wanted
15 changes: 10 additions & 5 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def test_random_fail_from_lists():
divs = (0, *np.cumsum(list(map(len, many))))
form = ak.Array(many[0]).layout.form

array = from_map(
array, report = from_map(
RandomFailFromListsFn(form),
many,
meta=typetracer_array(ak.Array(many[0])),
Expand All @@ -357,8 +357,13 @@ def test_random_fail_from_lists():
)
assert len(array.compute()) < (len(single) * len(many))

computed_report = report.compute()
assert len(computed_report[computed_report["args"] == "None"]) < len(
computed_report
)

with pytest.raises(OSError, match="BAD"):
array = from_map(
array, report = from_map(
RandomFailFromListsFn(form),
many,
meta=typetracer_array(ak.Array(many[0])),
Expand All @@ -380,7 +385,7 @@ def test_random_fail_from_lists():
array.compute()

with pytest.raises(ValueError, match="must be used together"):
array = from_map(
from_map(
RandomFailFromListsFn(form),
many,
meta=typetracer_array(ak.Array(many[0])),
Expand All @@ -390,7 +395,7 @@ def test_random_fail_from_lists():
)

with pytest.raises(ValueError, match="must be used together"):
array = from_map(
from_map(
RandomFailFromListsFn(form),
many,
meta=typetracer_array(ak.Array(many[0])),
Expand All @@ -410,7 +415,7 @@ def __call__(self, *args):
return self.x * args[0]

with pytest.raises(ValueError, match="must implement"):
array = from_map(
from_map(
NoMockEmpty(5),
many,
meta=typetracer_array(ak.Array(many[0])),
Expand Down

0 comments on commit d8e3fae

Please sign in to comment.