Skip to content

Commit

Permalink
fix: harden typing in the code base; add pre-commit check; improve docs (#383)
Browse files Browse the repository at this point in the history
  • Loading branch information
douglasdavis authored Oct 12, 2023
1 parent 8c9ff67 commit a1b24fc
Show file tree
Hide file tree
Showing 15 changed files with 90 additions and 62 deletions.
13 changes: 12 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ repos:
- id: mixed-line-ending
- id: trailing-whitespace

- repo: https://github.com/psf/black
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 23.9.1
hooks:
- id: black
Expand Down Expand Up @@ -54,3 +54,14 @@ repos:
- id: blacken-docs
additional_dependencies:
- black

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.6.0
hooks:
- id: mypy
args: [--ignore-missing-imports]
additional_dependencies:
- dask
- types-PyYAML
- pytest
- numpy
8 changes: 8 additions & 0 deletions docs/dev/contributing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ Commit your changes, push your branch to your fork, and open a Pull
Request. We suggest that you install `pre-commit <precommit_>`_ to run
some checks locally when creating new commits.

Typing
------

We include a pre-commit hook that runs ``mypy`` for static type
checking. Code added to dask-awkward is *not required* to be typed,
but the pre-commit check does enforce correctness when type hints are
present.

Adding documentation
--------------------

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ src_paths = ["src", "tests"]
[tool.mypy]
python_version = "3.9"
files = ["src", "tests"]
strict = false
warn_unused_configs = true
show_error_codes = true
allow_incomplete_defs = false
Expand Down
6 changes: 3 additions & 3 deletions src/dask_awkward/layers/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __repr__(self) -> str:


class ImplementsIOFunction(Protocol):
def __call__(self, *args, **kwargs) -> AwkwardArray:
def __call__(self, *args: Any, **kwargs: Any) -> AwkwardArray:
...


Expand Down Expand Up @@ -79,7 +79,7 @@ def __getstate__(self) -> dict:
state["_meta"] = None
return state

def __call__(self, *args, **kwargs) -> AwkwardArray:
def __call__(self, *args, **kwargs):
return self._io_func(*args, **kwargs)

def mock(self) -> AwkwardArray:
Expand Down Expand Up @@ -220,7 +220,7 @@ def project(
self,
report: TypeTracerReport,
state: T,
):
) -> AwkwardInputLayer:
assert self.is_projectable
return AwkwardInputLayer(
name=self.name,
Expand Down
5 changes: 2 additions & 3 deletions src/dask_awkward/lib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@

from dask_awkward.layers import AwkwardBlockwiseLayer, AwkwardMaterializedLayer
from dask_awkward.lib.optimize import all_optimizations
from dask_awkward.typing import AwkwardDaskCollection
from dask_awkward.utils import (
DaskAwkwardNotImplemented,
IncompatiblePartitions,
Expand Down Expand Up @@ -370,7 +369,7 @@ def _check_meta(self, m: Any | None) -> Any | None:
raise TypeError(f"meta must be a Record typetracer object, not a {type(m)}")
return m

def __getitem__(self, where: str) -> AwkwardDaskCollection:
def __getitem__(self, where):
token = tokenize(self, where)
new_name = f"{where}-{token}"
new_meta = self._meta[where]
Expand All @@ -397,7 +396,7 @@ def __getitem__(self, where: str) -> AwkwardDaskCollection:
else:
return new_scalar_object(hlg, new_name, meta=new_meta)

def __getattr__(self, attr: str) -> Any:
def __getattr__(self, attr):
if attr not in (self.fields or []):
raise AttributeError(f"{attr} not in fields.")
try:
Expand Down
24 changes: 18 additions & 6 deletions src/dask_awkward/lib/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def report_necessary_buffers(
*args: Any, traverse: bool = True
) -> dict[str, NecessaryBuffers | None]:
r"""Determine the buffer keys necessary to compute a collection.
Parameters
----------
*args : Dask collections or HighLevelGraphs
Expand All @@ -31,24 +32,30 @@ def report_necessary_buffers(
traverse : bool, optional
If True (default), builtin Python collections are traversed
looking for any Dask collections they might contain.
Returns
-------
dict[str, NecessaryBuffers | None]
Mapping that pairs the input layers in the graph to objects
describing the data and shape buffers that have been tagged
as required by column optimisation of the given layer.
Examples
--------
If we have a hypothetical parquet dataset (``ds``) with the fields
- "foo"
- "bar"
- "baz"
And the "baz" field has fields
- "x"
- "y"
The calculation of ``ds.bar + ds.baz.x`` will only require the
``bar`` and ``baz.x`` columns from the parquet file.
>>> import dask_awkward as dak
>>> ds = dak.from_parquet("some-dataset")
>>> ds.fields
Expand Down Expand Up @@ -96,6 +103,7 @@ def report_necessary_buffers(
shape_only = frozenset(report.shape_touched) - data_and_shape

# Update set of touched keys
assert existing_buffers is not None
name_to_necessary_buffers[name] = NecessaryBuffers(
data_and_shape=existing_buffers.data_and_shape | data_and_shape,
shape_only=existing_buffers.shape_only | shape_only,
Expand Down Expand Up @@ -161,7 +169,7 @@ def report_necessary_columns(

seen_names = set()

name_to_necessary_columns: dict[str, frozenset] = {}
name_to_necessary_columns: dict[str, frozenset | None] = {}
for obj in collections:
dsk = obj if isinstance(obj, HighLevelGraph) else obj.dask
projection_data = o._prepare_buffer_projection(dsk)
Expand All @@ -183,6 +191,7 @@ def report_necessary_columns(

existing_columns = name_to_necessary_columns.setdefault(name, frozenset())

assert existing_columns is not None
# Update set of touched keys
name_to_necessary_columns[
name
Expand Down Expand Up @@ -237,11 +246,14 @@ def sample(
rows will remain.
"""
if not (factor is None) ^ (probability is None):
if (factor is None and probability is None) or (
factor is not None and probability is not None
):
raise ValueError("Give exactly one of factor or probability")
if factor:
return arr.map_partitions(lambda x: x[::factor], meta=arr._meta)
else:
return arr.map_partitions(
lambda x: x[_random_boolean_like(x, probability)], meta=arr._meta
)
assert probability is not None
proba = float(probability)
return arr.map_partitions(
lambda x: x[_random_boolean_like(x, proba)], meta=arr._meta
)
13 changes: 8 additions & 5 deletions src/dask_awkward/lib/io/columnar.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Protocol, TypeVar, cast
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast

import awkward as ak
from awkward import Array as AwkwardArray
from awkward.forms import Form

from dask_awkward.layers.layers import ImplementsNecessaryColumns
from dask_awkward.layers.layers import ImplementsIOFunction, ImplementsNecessaryColumns
from dask_awkward.lib.utils import (
METADATA_ATTRIBUTES,
FormStructure,
Expand Down Expand Up @@ -39,6 +39,9 @@ def behavior(self) -> dict | None:
def project_columns(self: T, columns: frozenset[str]) -> T:
...

def __call__(self, *args: Any, **kwargs: Any) -> AwkwardArray:
...


S = TypeVar("S", bound=ImplementsColumnProjectionMixin)

Expand Down Expand Up @@ -84,7 +87,7 @@ def necessary_columns(
form_key_to_parent_form_key = state["form_key_to_parent_form_key"]
form_key_to_child_form_keys: dict[str, list[str]] = {}
for child_key, parent_key in form_key_to_parent_form_key.items():
form_key_to_child_form_keys.setdefault(parent_key, []).append(child_key)
form_key_to_child_form_keys.setdefault(parent_key, []).append(child_key) # type: ignore
form_key_to_form = state["form_key_to_form"]
# Buffer hierarchy information
form_key_to_buffer_keys = state["form_key_to_buffer_keys"]
Expand Down Expand Up @@ -151,8 +154,8 @@ def project(
self: S,
report: TypeTracerReport,
state: FormStructure,
) -> S:
if not self.use_optimization:
) -> ImplementsIOFunction:
if not self.use_optimization: # type: ignore[attr-defined]
return self

return self.project_columns(self.necessary_columns(report, state))
2 changes: 1 addition & 1 deletion src/dask_awkward/lib/io/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import warnings
from collections.abc import Iterable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Mapping, Protocol, cast
from typing import TYPE_CHECKING, Any, Callable, cast

import awkward as ak
import numpy as np
Expand Down
2 changes: 1 addition & 1 deletion src/dask_awkward/lib/io/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def use_optimization(self) -> bool:
and self.schema is None
)

def project_columns(self, columns: set[str]):
def project_columns(self, columns):
form = self.form.select_columns(columns)
assert form is not None
schema = layout_to_jsonschema(form.length_zero_array(highlevel=False))
Expand Down
6 changes: 3 additions & 3 deletions src/dask_awkward/lib/io/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __call__(self, source: Any) -> ak.Array:
...

@abc.abstractmethod
def project_columns(self, columns: set[str]):
def project_columns(self, columns):
...

@property
Expand Down Expand Up @@ -122,7 +122,7 @@ def __call__(self, source: Any) -> Any:
)
return ak.Array(unproject_layout(self.original_form, array.layout))

def project_columns(self, columns: set[str]):
def project_columns(self, columns):
return _FromParquetFileWiseFn(
fs=self.fs,
form=self.form.select_columns(columns),
Expand Down Expand Up @@ -172,7 +172,7 @@ def __call__(self, pair: Any) -> ak.Array:
)
return ak.Array(unproject_layout(self.original_form, array.layout))

def project_columns(self, columns: set[str]):
def project_columns(self, columns):
return _FromParquetFragmentWiseFn(
fs=self.fs,
form=self.form.select_columns(columns),
Expand Down
2 changes: 1 addition & 1 deletion src/dask_awkward/lib/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,5 +431,5 @@ def _recursive_replace(args, layer, parent, indices):

def _buffer_keys_for_layer(
buffer_keys: Iterable[str], known_buffer_keys: frozenset[str]
):
) -> set[str]:
return {k for k in buffer_keys if k in known_buffer_keys}
47 changes: 27 additions & 20 deletions src/dask_awkward/lib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

__all__ = ("trace_form_structure", "buffer_keys_required_to_compute_shapes")

from collections.abc import Callable, Iterable, Iterator
from collections.abc import Callable, Iterable, Iterator, Mapping, MutableMapping
from typing import TYPE_CHECKING, TypedDict, TypeVar

import awkward as ak
Expand All @@ -17,19 +17,23 @@


class FormStructure(TypedDict):
form_key_to_form: dict[str, Form]
form_key_to_parent_form_key: dict[str, str]
form_key_to_path: dict[str, tuple[str, ...]]
form_key_to_buffer_keys: dict[str, tuple[str, ...]]
form_key_to_form: MutableMapping[str, Form]
form_key_to_parent_form_key: MutableMapping[str, str | None]
form_key_to_path: MutableMapping[str, tuple[str, ...]]
form_key_to_buffer_keys: MutableMapping[str, tuple[str, ...]]


def trace_form_structure(form: Form, buffer_key: Callable) -> FormStructure:
form_key_to_form: dict[str, Form] = {}
form_key_to_parent_form_key: dict[str, str | None] = {}
form_key_to_path: dict[str, tuple[str, ...]] = {}
form_key_to_buffer_keys: dict[str, tuple[str, ...]] = {}

def impl_with_parent(form: Form, parent_form: Form | None, column_path):
form_key_to_form: MutableMapping[str, Form] = {}
form_key_to_parent_form_key: MutableMapping[str, str | None] = {}
form_key_to_path: MutableMapping[str, tuple[str, ...]] = {}
form_key_to_buffer_keys: MutableMapping[str, tuple[str, ...]] = {}

def impl_with_parent(
form: Form,
parent_form: Form | None,
column_path: tuple[str, ...],
) -> None:
# Associate child form key with parent form key
form_key_to_parent_form_key[form.form_key] = (
None if parent_form is None else parent_form.form_key
Expand Down Expand Up @@ -74,13 +78,13 @@ def impl_with_parent(form: Form, parent_form: Form | None, column_path):
T = TypeVar("T")


def walk_bijective_graph(node: T, graph: dict[T, T | None]) -> Iterator[T]:
while (node := graph.get(node)) is not None:
def walk_bijective_graph(node: T, graph: Mapping[T, T | None]) -> Iterator[T]:
while (node := graph.get(node)) is not None: # type: ignore[assignment]
yield node


def walk_graph_breadth_first(
node: T, graph: dict[T, Iterable[T] | None]
node: T, graph: Mapping[T, Iterable[T] | None]
) -> Iterator[T]:
children = graph.get(node)
if children is None:
Expand All @@ -90,7 +94,9 @@ def walk_graph_breadth_first(
yield from walk_graph_breadth_first(node, graph)


def walk_graph_depth_first(node: T, graph: dict[T, Iterable[T] | None]) -> Iterator[T]:
def walk_graph_depth_first(
node: T, graph: Mapping[T, Iterable[T] | None]
) -> Iterator[T]:
children = graph.get(node)
if children is None:
return
Expand All @@ -102,9 +108,9 @@ def walk_graph_depth_first(node: T, graph: dict[T, Iterable[T] | None]) -> Itera
def buffer_keys_required_to_compute_shapes(
parse_buffer_key: Callable[[str], tuple[str, str]],
shape_buffers: Iterable[str],
form_key_to_parent_key: dict[str, str],
form_key_to_buffer_keys: dict[str, Iterable[str]],
):
form_key_to_parent_key: Mapping[str, str | None],
form_key_to_buffer_keys: Mapping[str, Iterable[str]],
) -> Iterable[str]:
# Buffers needing known shapes must traverse all the way up the tree.
for buffer_key in shape_buffers:
form_key, attribute = parse_buffer_key(buffer_key)
Expand All @@ -126,11 +132,12 @@ def render_buffer_key(form: Form, form_key: str, attribute: str) -> str:


def parse_buffer_key(buffer_key: str) -> tuple[str, str]:
return buffer_key.rsplit("-", maxsplit=1)
head, tail = buffer_key.rsplit("-", maxsplit=1)
return head, tail


def form_with_unique_keys(form: Form, key: str) -> Form:
def impl(form: Form, key: str):
def impl(form: Form, key: str) -> None:
# Set form key
form.form_key = key

Expand Down
5 changes: 1 addition & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from __future__ import annotations

try:
import ujson as json
except ImportError:
import json # type: ignore[no-redef]
import json

import awkward as ak
import fsspec
Expand Down
Loading

0 comments on commit a1b24fc

Please sign in to comment.