Patch summary (free text ahead of the first file header; skipped by patch tools):
  * Adds a typing Protocol, ImplementsColumnProjectionMixin, declaring the
    `form` and `behavior` properties and a `project_columns` method.
  * Adds TypeVar S bound to that Protocol.
  * Annotates `self` as S on ColumnProjectionMixin.mock,
    prepare_for_projection, necessary_columns and project; `project` now
    returns S instead of T.
  NOTE(review): line structure was restored from a single-line paste; blank
  context lines and indentation were reconstructed from the hunk line counts --
  confirm against the upstream commit before applying.

diff --git a/src/dask_awkward/lib/io/columnar.py b/src/dask_awkward/lib/io/columnar.py
index 345d1c7e..88b8bb02 100644
--- a/src/dask_awkward/lib/io/columnar.py
+++ b/src/dask_awkward/lib/io/columnar.py
@@ -1,10 +1,11 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING, TypeVar, cast
+from typing import TYPE_CHECKING, Protocol, TypeVar, cast
 
 import awkward as ak
 from awkward import Array as AwkwardArray
+from awkward.forms import Form
 
 from dask_awkward.layers.layers import ImplementsNecessaryColumns
 from dask_awkward.lib.utils import (
@@ -26,6 +27,22 @@
 T = TypeVar("T")
 
 
+class ImplementsColumnProjectionMixin(ImplementsNecessaryColumns, Protocol):
+    @property
+    def form(self) -> Form:
+        ...
+
+    @property
+    def behavior(self) -> dict | None:
+        ...
+
+    def project_columns(self: T, columns: frozenset[str]) -> T:
+        ...
+
+
+S = TypeVar("S", bound=ImplementsColumnProjectionMixin)
+
+
 class ColumnProjectionMixin(ImplementsNecessaryColumns[FormStructure]):
     """A mixin to add column-centric buffer projection to an IO function.
 
@@ -35,11 +52,11 @@ class ColumnProjectionMixin(ImplementsNecessaryColumns[FormStructure]):
     when only metadata buffers are required.
     """
 
-    def mock(self) -> AwkwardArray:
+    def mock(self: S) -> AwkwardArray:
         return ak.typetracer.typetracer_from_form(self.form, behavior=self.behavior)
 
     def prepare_for_projection(
-        self,
+        self: S,
     ) -> tuple[AwkwardArray, TypeTracerReport, FormStructure]:
         form = form_with_unique_keys(self.form, "@")
 
@@ -58,7 +75,7 @@ def prepare_for_projection(
         )
 
     def necessary_columns(
-        self,
+        self: S,
         report: TypeTracerReport,
         state: FormStructure,
     ) -> frozenset[str]:
@@ -131,10 +148,10 @@ def necessary_columns(
         return frozenset({".".join(p) for p in paths if p})
 
     def project(
-        self: T,
+        self: S,
         report: TypeTracerReport,
         state: FormStructure,
-    ) -> T:
+    ) -> S:
         if not self.use_optimization:
             return self
 