From 2c3197801f25268ae4d8c29eeb030253efa198da Mon Sep 17 00:00:00 2001 From: Doug Davis Date: Mon, 30 Oct 2023 19:38:34 -0500 Subject: [PATCH] allow fn passed to from_map raise exceptions provided by caller --- src/dask_awkward/lib/io/io.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/src/dask_awkward/lib/io/io.py b/src/dask_awkward/lib/io/io.py index 6153021b6..68c01cdbe 100644 --- a/src/dask_awkward/lib/io/io.py +++ b/src/dask_awkward/lib/io/io.py @@ -1,10 +1,11 @@ from __future__ import annotations +import functools import math import warnings -from collections.abc import Iterable +from collections.abc import Callable, Iterable from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, cast +from typing import TYPE_CHECKING, Any, Type, cast import awkward as ak import numpy as np @@ -13,6 +14,7 @@ from dask.highlevelgraph import HighLevelGraph from dask.utils import funcname, is_integer, parse_bytes from fsspec.utils import infer_compression +from typing_extensions import ParamSpec from dask_awkward.layers import ( AwkwardBlockwiseLayer, @@ -41,6 +43,8 @@ from dask_awkward.lib.core import Array +P = ParamSpec("P") + class _FromAwkwardFn: def __init__(self, arr: ak.Array) -> None: @@ -460,6 +464,21 @@ def __call__(self, packed_arg): ) +def return_empty_on_raise( + fn: Callable[P, Array], + allowed_exceptions: tuple[Type[BaseException], ...], +) -> Callable[P, Array]: + @functools.wraps(fn) + def wrapped(*args, **kwargs): + try: + result = fn(*args, **kwargs) + return result + except allowed_exceptions: + return fn.form.length_zero_array() + + return wrapped + + def from_map( func: ImplementsIOFunction, *iterables: Iterable, @@ -468,6 +487,7 @@ def from_map( token: str | None = None, divisions: tuple[int, ...] | tuple[None, ...] | None = None, meta: ak.Array | None = None, + empty_on_raise: tuple[Type[BaseException], ...] | None = None, behavior: dict | None = None, **kwargs: Any, ) -> Array: @@ -574,6 +594,9 @@ def from_map( io_func = func array_meta = None + if empty_on_raise: + io_func = return_empty_on_raise(io_func, allowed_exceptions=empty_on_raise) + dsk = AwkwardInputLayer(name=name, inputs=inputs, io_func=io_func) if behavior is not None: