diff --git a/src/dask_awkward/layers/__init__.py b/src/dask_awkward/layers/__init__.py index df0ba5270..d4ba4c5eb 100644 --- a/src/dask_awkward/layers/__init__.py +++ b/src/dask_awkward/layers/__init__.py @@ -5,8 +5,8 @@ AwkwardTreeReductionLayer, ImplementsIOFunction, ImplementsProjection, - IOFunctionWithMeta, - io_func_implements_project, + IOFunctionWithMocking, + io_func_implements_projection, ) __all__ = ( @@ -16,6 +16,6 @@ "AwkwardTreeReductionLayer", "ImplementsProjection", "ImplementsIOFunction", - "IOFunctionWithMeta", - "io_func_implements_project", + "IOFunctionWithMocking", + "io_func_implements_projection", ) diff --git a/src/dask_awkward/layers/layers.py b/src/dask_awkward/layers/layers.py index eb3a317bd..a1dcc2119 100644 --- a/src/dask_awkward/layers/layers.py +++ b/src/dask_awkward/layers/layers.py @@ -54,11 +54,12 @@ def __call__(self, *args, **kwargs) -> AwkwardArray: T = TypeVar("T") -class ImplementsProjection(Protocol): - @property - def meta(self) -> AwkwardArray: +class ImplementsMocking(Protocol): + def mock(self) -> AwkwardArray: ... + +class ImplementsProjection(ImplementsMocking, Protocol): def prepare_for_projection(self) -> tuple[AwkwardArray, T]: ... @@ -66,15 +67,21 @@ def project(self, state: T) -> ImplementsIOFunction: ... -# IO functions may not end up performing buffer projection, so they -# should also support directly returning the result +# IO functions can implement full-blown projection class ImplementsIOFunctionWithProjection( ImplementsProjection, ImplementsIOFunction, Protocol ): ... -class IOFunctionWithMeta(ImplementsIOFunctionWithProjection): +# Or they can implement simple mocking +class ImplementsIOFunctionWithMocking( + ImplementsMocking, ImplementsIOFunction, Protocol +): + ... + + +class IOFunctionWithMocking(ImplementsIOFunctionWithProjection): def __init__(self, meta: AwkwardArray, io_func: ImplementsIOFunction): self._meta = meta self._io_func = io_func @@ -86,15 +93,13 @@ def __call__(self, *args, **kwargs) -> AwkwardArray: def meta(self) -> AwkwardArray: return self._meta - def prepare_for_projection(self) -> tuple[AwkwardArray, None]: - return self._meta, None - def project(self, state: None): - return self._io_func +def io_func_implements_projection(func: ImplementsIOFunction) -> bool: + return hasattr(func, "prepare_for_projection") -def io_func_implements_project(func: ImplementsIOFunction) -> bool: - return hasattr(func, "project") +def io_func_implements_mocking(func: ImplementsIOFunction) -> bool: + return hasattr(func, "mock") class AwkwardInputLayer(AwkwardBlockwiseLayer): @@ -108,7 +113,9 @@ def __init__( *, name: str, inputs: Any, - io_func: ImplementsIOFunction | ImplementsIOFunctionWithProjection, + io_func: ImplementsIOFunction + | ImplementsIOFunctionWithMocking + | ImplementsIOFunctionWithProjection, label: str | None = None, produces_tasks: bool = False, creation_info: dict | None = None, @@ -142,11 +149,25 @@ def __repr__(self) -> str: @property def is_projectable(self) -> bool: # isinstance(self.io_func, ImplementsProjection) - return io_func_implements_project(self.io_func) + return io_func_implements_projection(self.io_func) + + @property + def is_mockable(self) -> bool: + # isinstance(self.io_func, ImplementsMocking) + return io_func_implements_mocking(self.io_func) def mock(self) -> AwkwardInputLayer: - layer, _ = self.prepare_for_projection() - return layer + assert self.is_mockable + + return AwkwardInputLayer( + name=self.name, + inputs=[None][: int(list(self.numblocks.values())[0][0])], + io_func=lambda *_, **__: self.io_func.mock(), + label=self.label, + produces_tasks=self.produces_tasks, + creation_info=self.creation_info, + annotations=self.annotations, + ) def prepare_for_projection(self) -> tuple[AwkwardInputLayer, T]: """Mock the input layer as starting with a data-less typetracer. diff --git a/src/dask_awkward/lib/io/io.py b/src/dask_awkward/lib/io/io.py index e733a0b82..be207a24f 100644 --- a/src/dask_awkward/lib/io/io.py +++ b/src/dask_awkward/lib/io/io.py @@ -4,7 +4,7 @@ import warnings from collections.abc import Iterable from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Mapping, Protocol +from typing import TYPE_CHECKING, Any, Callable, Mapping, Protocol, cast import awkward as ak import numpy as np @@ -19,10 +19,13 @@ AwkwardInputLayer, ImplementsIOFunction, ImplementsProjection, - IOFunctionWithMeta, - io_func_implements_project, ) -from dask_awkward.layers.layers import AwkwardMaterializedLayer +from dask_awkward.layers.layers import ( + AwkwardMaterializedLayer, + ImplementsMocking, + IOFunctionWithMocking, + io_func_implements_mocking, +) from dask_awkward.lib.core import ( empty_typetracer, map_partitions, @@ -566,18 +569,19 @@ def from_map( packed=packed, ) - # Special `io_func` implementations can - if io_func_implements_project(func): + # Special `io_func` implementations can implement mocking and optionally + # support buffer projection. + if io_func_implements_mocking(func): io_func = func - array_meta = func.meta + array_meta = cast(ImplementsMocking, func).mock() + # If we know the meta, we can spoof mocking + elif meta is not None: + io_func = IOFunctionWithMocking(meta, func) + array_meta = meta # Without `meta`, the meta will be computed by executing the graph - elif meta is None: + else: io_func = func array_meta = None - # If we know the meta, we can spoof projection - else: - io_func = IOFunctionWithMeta(meta, func) - array_meta = meta dsk = AwkwardInputLayer(name=name, inputs=inputs, io_func=io_func) diff --git a/src/dask_awkward/lib/io/json.py b/src/dask_awkward/lib/io/json.py index 94e0e9365..6824e5eeb 100644 --- a/src/dask_awkward/lib/io/json.py +++ b/src/dask_awkward/lib/io/json.py @@ -74,8 +74,7 @@ def __init__( def __call__(self, source: Any) -> ak.Array: ... - @property - def meta(self) -> AwkwardArray: + def mock(self) -> AwkwardArray: return ak.typetracer.typetracer_from_form(self.form, behavior=self.behavior) def prepare_for_projection(self) -> tuple[AwkwardArray, dict]: diff --git a/src/dask_awkward/lib/io/parquet.py b/src/dask_awkward/lib/io/parquet.py index 48d8e62e1..febd52f29 100644 --- a/src/dask_awkward/lib/io/parquet.py +++ b/src/dask_awkward/lib/io/parquet.py @@ -68,6 +68,9 @@ def __init__( def __call__(self, source: Any) -> ak.Array: ... + def mock(self) -> AwkwardArray: + return ak.typetracer.typetracer_from_form(self.form, behavior=self.behavior) + def prepare_for_projection(self) -> tuple[AwkwardArray, dict]: form = form_with_unique_keys(self.form, "") @@ -84,10 +87,6 @@ def prepare_for_projection(self) -> tuple[AwkwardArray, dict]: "report": report, } - @property - def meta(self) -> AwkwardArray: - return ak.typetracer.typetracer_from_form(self.form, behavior=self.behavior) - @abc.abstractmethod def project(self, state: dict) -> _FromParquetFn: ... diff --git a/src/dask_awkward/lib/optimize.py b/src/dask_awkward/lib/optimize.py index 3a11f8c7e..a388ad785 100644 --- a/src/dask_awkward/lib/optimize.py +++ b/src/dask_awkward/lib/optimize.py @@ -126,6 +126,8 @@ def optimize_columns(dsk: HighLevelGraph) -> HighLevelGraph: projection_layers[name], layer_to_projection_state[name], ) = lay.prepare_for_projection() + elif lay.is_mockable: + projection_layers[name] = lay.mock() elif hasattr(lay, "mock"): projection_layers[name] = lay.mock()