Skip to content

Commit

Permalink
fix: correct LSP
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 committed Oct 3, 2023
1 parent 30d2cf8 commit e0827e9
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 32 deletions.
14 changes: 9 additions & 5 deletions src/dask_awkward/layers/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ def __getstate__(self) -> dict:
return d


T = TypeVar("T")


class ImplementsIOFunction(Protocol):
def __call__(self, *args, **kwargs) -> AwkwardArray:
...


T = TypeVar("T")


class ImplementsProjection(Protocol):
@property
def meta(self) -> AwkwardArray:
Expand Down Expand Up @@ -83,7 +83,7 @@ def __call__(self, *args, **kwargs) -> AwkwardArray:
return self._io_func(*args, **kwargs)

@property
def meta(self):
def meta(self) -> AwkwardArray:
return self._meta

def prepare_for_projection(self) -> tuple[AwkwardArray, None]:
Expand Down Expand Up @@ -144,7 +144,11 @@ def is_projectable(self) -> bool:
# isinstance(self.io_func, ImplementsProjection)
return io_func_implements_project(self.io_func)

def mock(self) -> tuple[AwkwardInputLayer, T]:
def mock(self) -> AwkwardInputLayer:
layer, _ = self.prepare_for_projection()
return layer

def prepare_for_projection(self) -> tuple[AwkwardInputLayer, T]:
assert self.is_projectable
new_meta_array, state = self.io_func.prepare_for_projection()

Expand Down
35 changes: 8 additions & 27 deletions src/dask_awkward/lib/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,15 @@ def optimize_columns(dsk: HighLevelGraph) -> HighLevelGraph:

layer_to_projection_state: dict[str, Any] = {}
projection_layers = dsk.layers.copy() # type:
projectable = _projectable_input_layer_names(dsk)
for name, lay in dsk.layers.items():
if name in projectable:
# Insert mocked array into layers, replacing generation func
# Keep track of mocked state
projection_layers[name], layer_to_projection_state[name] = lay.mock()
if isinstance(lay, AwkwardInputLayer):
if lay.is_projectable:
# Insert mocked array into layers, replacing generation func
# Keep track of mocked state
(
projection_layers[name],
layer_to_projection_state[name],
) = lay.prepare_for_projection()
elif hasattr(lay, "mock"):
projection_layers[name] = lay.mock()

Expand Down Expand Up @@ -185,28 +188,6 @@ def optimize_columns(dsk: HighLevelGraph) -> HighLevelGraph:
return HighLevelGraph(layers, dsk.dependencies)


def _projectable_input_layer_names(dsk: HighLevelGraph) -> list[str]:
"""Get list of column-projectable AwkwardInputLayer names.
Parameters
----------
dsk : HighLevelGraph
Task graph of interest
Returns
-------
list[str]
Names of the AwkwardInputLayers in the graph that are
column-projectable.
"""
return [
n
for n, v in dsk.layers.items()
if isinstance(v, AwkwardInputLayer) and v.is_projectable
]


def _layers_with_annotation(dsk: HighLevelGraph, key: str) -> list[str]:
return [n for n, v in dsk.layers.items() if (v.annotations or {}).get(key)]

Expand Down

0 comments on commit e0827e9

Please sign in to comment.