Skip to content

Commit

Permalink
allow fn passed to from_map raise exceptions provided by caller
Browse files Browse the repository at this point in the history
  • Loading branch information
douglasdavis committed Oct 31, 2023
1 parent c9f10ab commit 2c31978
Showing 1 changed file with 25 additions and 2 deletions.
27 changes: 25 additions & 2 deletions src/dask_awkward/lib/io/io.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

import functools
import math
import warnings
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, cast
from typing import TYPE_CHECKING, Any, Type, cast

import awkward as ak
import numpy as np
Expand All @@ -13,6 +14,7 @@
from dask.highlevelgraph import HighLevelGraph
from dask.utils import funcname, is_integer, parse_bytes
from fsspec.utils import infer_compression
from typing_extensions import ParamSpec

from dask_awkward.layers import (
AwkwardBlockwiseLayer,
Expand Down Expand Up @@ -41,6 +43,8 @@

from dask_awkward.lib.core import Array

P = ParamSpec("P")


class _FromAwkwardFn:
def __init__(self, arr: ak.Array) -> None:
Expand Down Expand Up @@ -460,6 +464,21 @@ def __call__(self, packed_arg):
)


def return_empty_on_raise(
fn: Callable[P, Array],
allowed_exceptions: tuple[Type[BaseException], ...],
) -> Callable[P, Array]:
@functools.wraps(fn)
def wrapped(*args, **kwargs):
try:
result = fn(*args, **kwargs)
return result
except allowed_exceptions:
return fn.form.length_zero_array()

return wrapped


def from_map(
func: ImplementsIOFunction,
*iterables: Iterable,
Expand All @@ -468,6 +487,7 @@ def from_map(
token: str | None = None,
divisions: tuple[int, ...] | tuple[None, ...] | None = None,
meta: ak.Array | None = None,
empty_on_raise: tuple[Type[BaseException], ...] | None = None,
behavior: dict | None = None,
**kwargs: Any,
) -> Array:
Expand Down Expand Up @@ -574,6 +594,9 @@ def from_map(
io_func = func
array_meta = None

if empty_on_raise:
io_func = return_empty_on_raise(io_func, allowed_exceptions=empty_on_raise)

dsk = AwkwardInputLayer(name=name, inputs=inputs, io_func=io_func)

if behavior is not None:
Expand Down

0 comments on commit 2c31978

Please sign in to comment.