From ffa20fb750cf8f707168838baa4bc98faa58797c Mon Sep 17 00:00:00 2001 From: Doug Davis Date: Mon, 25 Sep 2023 10:39:07 -0500 Subject: [PATCH] misc: typing (#373) --- pyproject.toml | 10 +++- src/dask_awkward/layers/layers.py | 5 +- src/dask_awkward/lib/core.py | 75 +++++++++--------------- src/dask_awkward/lib/inspect.py | 9 ++- src/dask_awkward/lib/io/io.py | 45 +++++++------- src/dask_awkward/lib/io/json.py | 66 ++++----------------- src/dask_awkward/lib/io/parquet.py | 71 +--------------------- src/dask_awkward/lib/io/text.py | 12 ++-- src/dask_awkward/lib/operations.py | 5 +- src/dask_awkward/lib/optimize.py | 42 ++++++------- src/dask_awkward/lib/str.py | 7 +-- src/dask_awkward/lib/structure.py | 4 +- src/dask_awkward/lib/unproject_layout.py | 5 +- src/dask_awkward/pickle.py | 6 +- src/dask_awkward/utils.py | 7 ++- tests/conftest.py | 6 +- tests/test_core.py | 16 +++-- tests/test_io.py | 8 +-- tests/test_io_json.py | 3 +- tests/test_optimize.py | 4 +- tests/test_parquet.py | 2 +- 21 files changed, 148 insertions(+), 260 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 193c4db3..71022383 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ classifiers = [ dependencies = [ "awkward >=2.4.0", "dask >=2023.04.0", - "typing_extensions>=4.8.0; python_version < \"3.11\"", + "typing_extensions >=4.8.0", ] dynamic = ["version"] @@ -144,6 +144,14 @@ warn_unreachable = true module = ["tlz.*"] ignore_missing_imports = true +[[tool.mypy.overrides]] + module = ["uproot.*"] + ignore_missing_imports = true + +[[tool.mypy.overrides]] + module = ["cloudpickle.*"] + ignore_missing_imports = true + [tool.pyright] include = ["src"] pythonVersion = "3.9" diff --git a/src/dask_awkward/layers/layers.py b/src/dask_awkward/layers/layers.py index 002d078b..5af5c971 100644 --- a/src/dask_awkward/layers/layers.py +++ b/src/dask_awkward/layers/layers.py @@ -214,7 +214,7 @@ def __init__( super().__init__(mapping, **kwargs) def mock(self) -> tuple[MaterializedLayer, Any | None]: - mapping = self.mapping.copy() + mapping = copy.copy(self.mapping) if not mapping: # no partitions at all return self, None @@ -256,9 +256,6 @@ def mock(self) -> tuple[MaterializedLayer, Any | None]: task = (self.fn, *name0s) return MaterializedLayer({(name, 0): task}), None - # failed to cull during column opt - return self, None - class AwkwardTreeReductionLayer(DataFrameTreeReduction): def mock(self) -> tuple[AwkwardTreeReductionLayer, Any | None]: diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py index da984384..90c97f60 100644 --- a/src/dask_awkward/lib/core.py +++ b/src/dask_awkward/lib/core.py @@ -16,6 +16,7 @@ import awkward as ak import dask.config import numpy as np +from awkward._do import remove_structure as ak_do_remove_structure from awkward._nplikes.typetracer import ( MaybeNone, OneOf, @@ -833,26 +834,22 @@ def _getitem_trivial_map_partitions( label=label, ) - def _getitem_outer_bool_or_int_lazy_array( - self, where: Array | tuple[Any, ...] - ) -> Any: + def _getitem_outer_bool_or_int_lazy_array(self, where): ba = where if isinstance(where, Array) else where[0] if partition_compatibility(self, ba) == PartitionCompatibility.NO: raise IncompatiblePartitions("getitem", self, ba) - new_meta: Any | None = None - if self._meta is not None: - if isinstance(where, tuple): - raise DaskAwkwardNotImplemented( - "tuple style input boolean/int selection is not supported." 
- ) - elif isinstance(where, Array): - new_meta = self._meta[where._meta] - return self.map_partitions( - operator.getitem, - where, - meta=new_meta, - ) + if isinstance(where, tuple): + raise DaskAwkwardNotImplemented( + "tuple style input boolean/int selection is not supported." + ) + + new_meta = self._meta[where._meta] + return self.map_partitions( + operator.getitem, + where, + meta=new_meta, + ) def _getitem_outer_str_or_list( self, @@ -942,9 +939,9 @@ def _getitem_outer_int(self, where: int | tuple[Any, ...]) -> Any: else: return new_scalar_object(hlg, name, meta=new_meta) - def _getitem_slice_on_zero(self, where: tuple[slice, ...]): + def _getitem_slice_on_zero(self, where): # normalise - sl: slice = where[0] + sl = where[0] rest = tuple(where[1:]) step = sl.step or 1 start = sl.start or 0 @@ -1014,7 +1011,7 @@ def _getitem_slice_on_zero(self, where: tuple[slice, ...]): divisions=tuple(divisions), ) - def _getitem_tuple(self, where: tuple[Any, ...]) -> Array: + def _getitem_tuple(self, where): if isinstance(where[0], int): return self._getitem_outer_int(where) @@ -1052,7 +1049,7 @@ def _getitem_tuple(self, where: tuple[Any, ...]) -> Array: f"Array.__getitem__ doesn't support multi object: {where}" ) - def _getitem_single(self, where: Any) -> Array: + def _getitem_single(self, where): # a single string if isinstance(where, str): return self._getitem_outer_str_or_list(where, label=where) @@ -1089,17 +1086,7 @@ def _getitem_single(self, where: Any) -> Array: raise DaskAwkwardNotImplemented(f"__getitem__ doesn't support where={where}.") - @overload - def __getitem__(self, where: Array | str | Sequence[str] | slice) -> Array: - ... - - @overload - def __getitem__(self, where: int) -> Scalar: - ... - - def __getitem__( - self, where: Array | str | Sequence[str] | int | slice - ) -> Array | Scalar: + def __getitem__(self, where): """Select items from the collection. Heavily under construction. @@ -1369,9 +1356,7 @@ def head(self, nrow=10, compute=True): By default this is then processed eagerly and returned. """ - out: Array = self.partitions[0].map_partitions( - lambda x: x[:nrow], meta=self._meta - ) + out = self.partitions[0].map_partitions(lambda x: x[:nrow], meta=self._meta) if compute: return out.compute() if self.known_divisions: @@ -1727,16 +1712,13 @@ def map_partitions( ) -PartialReductionType = ak.Array - - def _chunk_reducer_non_positional( - chunk: ak.Array | PartialReductionType, + chunk: ak.Array, is_axis_none: bool, *, reducer: Callable, mask_identity: bool, -) -> PartialReductionType: +) -> ak.Array: return reducer( chunk, keepdims=True, @@ -1746,14 +1728,14 @@ def _chunk_reducer_non_positional( def _concat_reducer_non_positional( - partials: list[PartialReductionType], is_axis_none: bool + partials: list[ak.Array], is_axis_none: bool ) -> ak.Array: concat_axis = -1 if is_axis_none else 0 return ak.concatenate(partials, axis=concat_axis) def _finalise_reducer_non_positional( - partial: PartialReductionType, + partial: ak.Array, is_axis_none: bool, *, reducer: Callable, @@ -1771,7 +1753,7 @@ def _finalise_reducer_non_positional( def _prepare_axis_none_chunk(chunk: ak.Array) -> ak.Array: # TODO: this is private Awkward code. 
We should figure out how to export it # if needed - (layout,) = ak._do.remove_structure( + (layout,) = ak_do_remove_structure( ak.to_layout(chunk), flatten_records=False, drop_nones=False, @@ -1785,7 +1767,7 @@ def non_trivial_reduction( *, label: str, array: Array, - axis: Literal[0] | None, + axis: int | None, is_positional: bool, keepdims: bool, mask_identity: bool, @@ -1794,7 +1776,7 @@ def non_trivial_reduction( token: str | None = None, dtype: Any | None = None, split_every: int | bool | None = None, -): +) -> Array | Scalar: if is_positional: raise NotImplementedError("positional reducers at axis=0 or axis=None") @@ -1807,8 +1789,9 @@ def non_trivial_reduction( if combiner is None: combiner = reducer - if is_positional: - assert combiner is reducer + # is_positional == True is not implemented + # if is_positional: + # assert combiner is reducer # For `axis=None`, we prepare each array to have the following structure: # [[[ ... [x1 x2 x3 ... xN] ... ]]] (length-1 outer lists) diff --git a/src/dask_awkward/lib/inspect.py b/src/dask_awkward/lib/inspect.py index cf760c43..280d9162 100644 --- a/src/dask_awkward/lib/inspect.py +++ b/src/dask_awkward/lib/inspect.py @@ -87,7 +87,11 @@ def necessary_columns(*args: Any, traverse: bool = True) -> dict[str, list[str]] return out -def sample(arr, factor: int | None = None, probability: float | None = None) -> Array: +def sample( + arr: Array, + factor: int | None = None, + probability: float | None = None, +) -> Array: """Decimate the data to a smaller number of rows. Must give either `factor` or `probability`. @@ -111,5 +115,6 @@ def sample(arr, factor: int | None = None, probability: float | None = None) -> return arr.map_partitions(lambda x: x[::factor], meta=arr._meta) else: return arr.map_partitions( - lambda x: x[np.random.random(len(x)) < probability], meta=arr._meta + lambda x: x[np.random.random(len(x)) < probability], # type: ignore + meta=arr._meta, ) diff --git a/src/dask_awkward/lib/io/io.py b/src/dask_awkward/lib/io/io.py index 7307767c..baa7f655 100644 --- a/src/dask_awkward/lib/io/io.py +++ b/src/dask_awkward/lib/io/io.py @@ -93,10 +93,10 @@ def from_awkward( """ nrows = len(source) if nrows == 0: - locs = [None, None] + locs: tuple[None, ...] | tuple[int, ...] = (None, None) else: chunksize = int(math.ceil(nrows / npartitions)) - locs = list(range(0, nrows, chunksize)) + [nrows] + locs = tuple(list(range(0, nrows, chunksize)) + [nrows]) starts = locs[:-1] stops = locs[1:] meta = typetracer_array(source) @@ -106,7 +106,7 @@ def from_awkward( stops, label=label or "from-awkward", token=tokenize(source, npartitions), - divisions=tuple(locs), + divisions=locs, meta=meta, behavior=behavior, ) @@ -116,11 +116,11 @@ class _FromListsFn: def __init__(self, behavior: dict | None = None): self.behavior = behavior - def __call__(self, x, **kwargs): + def __call__(self, x: list) -> ak.Array: return ak.Array(x, behavior=self.behavior) -def from_lists(source: list[list[Any]], behavior: dict | None = None) -> Array: +def from_lists(source: list, behavior: dict | None = None) -> Array: """Create an Array collection from a list of lists. 
Parameters @@ -149,7 +149,7 @@ def from_lists(source: list[list[Any]], behavior: dict | None = None) -> Array: lists = list(source) divs = (0, *np.cumsum(list(map(len, lists)))) return from_map( - _FromListsFn(), + _FromListsFn(behavior=behavior), lists, meta=typetracer_array(ak.Array(lists[0])), divisions=divs, @@ -383,7 +383,7 @@ def from_dask_array(array: DaskArray, behavior: dict | None = None) -> Array: def to_dataframe( - array, + array: Array, optimize_graph: bool = True, **kwargs: Any, ) -> DaskDataFrame: @@ -463,7 +463,7 @@ def from_map( args: tuple[Any, ...] | None = None, label: str | None = None, token: str | None = None, - divisions: tuple[int, ...] | None = None, + divisions: tuple[int, ...] | tuple[None, ...] | None = None, meta: ak.Array | None = None, behavior: dict | None = None, **kwargs: Any, @@ -603,7 +603,7 @@ def _bytes_with_sample( compression: str | None, delimiter: bytes, not_zero: bool, - blocksize: str | int, + blocksize: str | int | None, sample: str | int | bool, ) -> tuple[list[list[_BytesReadingInstructions]], bytes]: """Generate instructions for reading bytes from paths in a filesystem. @@ -653,7 +653,7 @@ def _bytes_with_sample( if blocksize is None: offsets = [[0]] * len(paths) - lengths = [[None]] * len(paths) + lengths: list = [[None]] * len(paths) else: offsets = [] lengths = [] @@ -717,21 +717,16 @@ def _bytes_with_sample( sample_size = parse_bytes(sample) if isinstance(sample, str) else sample with fs.open(paths[0], compression=compression) as f: # read block without seek (because we start at zero) - if delimiter is None: - sample_bytes = f.read(sample_size) - else: - sample_buff = f.read(sample_size) - while True: - new = f.read(sample_size) - if not new: - break - if delimiter in new: - sample_buff = ( - sample_buff + new.split(delimiter, 1)[0] + delimiter - ) - break - sample_buff = sample_buff + new - sample_bytes = sample_buff + sample_buff = f.read(sample_size) + while True: + new = f.read(sample_size) + if not new: + break + if delimiter in new: + sample_buff = sample_buff + new.split(delimiter, 1)[0] + delimiter + break + sample_buff = sample_buff + new + sample_bytes = sample_buff rfind = sample_bytes.rfind(delimiter) if rfind > 0: diff --git a/src/dask_awkward/lib/io/json.py b/src/dask_awkward/lib/io/json.py index e94ad3b6..65f8ade4 100644 --- a/src/dask_awkward/lib/io/json.py +++ b/src/dask_awkward/lib/io/json.py @@ -418,13 +418,13 @@ def _from_json_bytes( ) bytes_ingredients, the_sample_bytes = _bytes_with_sample( - fs, - paths, - compression, - delimiter, - not_zero, - blocksize, - sample_bytes, + fs=fs, + paths=paths, + compression=compression, + delimiter=delimiter, + not_zero=not_zero, + blocksize=blocksize, + sample=sample_bytes, ) sample_array = ak.from_json(the_sample_bytes, line_delimited=True, **kwargs) @@ -478,7 +478,7 @@ def from_json( resize: float = 8, highlevel: bool = True, behavior: dict | None = None, - blocksize: str | None = None, + blocksize: int | str | None = None, delimiter: bytes | None = None, compression: str | None = "infer", storage_options: dict[str, Any] | None = None, @@ -526,10 +526,10 @@ def from_json( dask-awkward. behavior : dict, optional See :func:`ak.from_json` - blocksize : str, optional + blocksize : int, str, optional If ``None`` (default), the collection will be partitioned on a per-file bases. If defined, this sets the size (in bytes) of - each partition. + each partition. Can be a string of the form ``"10 MiB"``. 
delimiter : bytes, optional Delimiter to use for separating blocks; if ``blocksize`` is defined but this argument is not defined, the default is the @@ -701,50 +701,10 @@ def __call__(self, array: ak.Array, block_index: tuple[int]) -> None: return None -@overload def to_json( array: Array, path: str, - line_delimited: bool | str = True, - num_indent_spaces: int | None = None, - num_readability_spaces: int = 0, - nan_string: str | None = None, - posinf_string: str | None = None, - neginf_string: str | None = None, - complex_record_fields: tuple[str, str] | None = None, - convert_bytes: Callable | None = None, - convert_other: Callable | None = None, - storage_options: dict[str, Any] | None = None, - compression: str | None = None, - compute: Literal[False] = False, -) -> Scalar: - ... - - -@overload -def to_json( - array: Array, - path: str, - line_delimited: bool | str, - num_indent_spaces: int | None, - num_readability_spaces: int, - nan_string: str | None, - posinf_string: str | None, - neginf_string: str | None, - complex_record_fields: tuple[str, str] | None, - convert_bytes: Callable | None, - convert_other: Callable | None, - storage_options: dict[str, Any] | None, - compression: str | None, - compute: Literal[True], -) -> None: - ... - - -def to_json( - array: Array, - path: str, - line_delimited: bool | str = True, + line_delimited: bool = True, num_indent_spaces: int | None = None, num_readability_spaces: int = 0, nan_string: str | None = None, @@ -755,7 +715,7 @@ def to_json( convert_other: Callable | None = None, storage_options: dict[str, Any] | None = None, compression: str | None = None, - compute: bool = False, + compute: bool = True, ) -> Scalar | None: """Store Array collection in JSON text. @@ -767,7 +727,7 @@ def to_json( Root directory to save data; interpreted by filesystem-spec (can be a remote filesystem path, for example an s3 bucket: ``"s3://bucket/data"``). - line_delimited : bool | str + line_delimited : bool See docstring for :py:func:`ak.to_json`. num_indent_spaces : int, optional See docstring for :py:func:`ak.to_json`. 
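The `@overload` pair deleted from `to_json` above used `Literal`-typed `compute` flags so a type checker could narrow the return type (a `Scalar` task when lazy, `None` when eager) from the call site; the same pattern is removed from `to_parquet` below. A minimal self-contained sketch of that pattern, with a hypothetical `Task`/`write` pair standing in for the real writers (not dask-awkward's API):

    from __future__ import annotations

    from typing import Literal, overload


    class Task:
        """Stand-in for a deferred write (illustrative only)."""

        def __init__(self, path: str) -> None:
            self.path = path

        def run(self) -> None:
            print(f"writing {self.path}")


    @overload
    def write(path: str, compute: Literal[True] = ...) -> None: ...


    @overload
    def write(path: str, compute: Literal[False]) -> Task: ...


    def write(path: str, compute: bool = True) -> Task | None:
        # Eager mode runs the write and returns nothing; lazy mode hands the
        # deferred task back to the caller to compute later.
        task = Task(path)
        if compute:
            task.run()
            return None
        return task

Dropping the overloads trades that call-site narrowing for a single permissive signature, which is simpler to keep in sync when the default changes (as `compute` does here, `False` to `True`).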
diff --git a/src/dask_awkward/lib/io/parquet.py b/src/dask_awkward/lib/io/parquet.py index 567d639e..8b840c54 100644 --- a/src/dask_awkward/lib/io/parquet.py +++ b/src/dask_awkward/lib/io/parquet.py @@ -6,7 +6,7 @@ import math import operator from collections.abc import Sequence -from typing import Any, Literal, overload +from typing import Any, Literal import awkward as ak import awkward.operations.ak_from_parquet as ak_from_parquet @@ -463,78 +463,9 @@ def __call__(self, data, block_index): ) -@overload def to_parquet( array: Array, destination: str, - *, - list_to32: bool, - string_to32: bool, - bytestring_to32: bool, - emptyarray_to: Any | None, - categorical_as_dictionary: bool, - extensionarray: bool, - count_nulls: bool, - compression: str | dict | None, - compression_level: int | dict | None, - row_group_size: int | None, - data_page_size: int | None, - parquet_flavor: Literal["spark"] | None, - parquet_version: Literal["1.0"] | Literal["2.4"] | Literal["2.6"], - parquet_page_version: Literal["1.0"] | Literal["2.0"], - parquet_metadata_statistics: bool | dict, - parquet_dictionary_encoding: bool | dict, - parquet_byte_stream_split: bool | dict, - parquet_coerce_timestamps: Literal["ms"] | Literal["us"] | None, - parquet_old_int96_timestamps: bool | None, - parquet_compliant_nested: bool, - parquet_extra_options: dict | None, - storage_options: dict[str, Any] | None, - write_metadata: bool, - compute: Literal[True], - prefix: str | None, -) -> None: - ... - - -@overload -def to_parquet( - array: Array, - destination: str, - *, - list_to32: bool, - string_to32: bool, - bytestring_to32: bool, - emptyarray_to: Any | None, - categorical_as_dictionary: bool, - extensionarray: bool, - count_nulls: bool, - compression: str | dict | None, - compression_level: int | dict | None, - row_group_size: int | None, - data_page_size: int | None, - parquet_flavor: Literal["spark"] | None, - parquet_version: Literal["1.0"] | Literal["2.4"] | Literal["2.6"], - parquet_page_version: Literal["1.0"] | Literal["2.0"], - parquet_metadata_statistics: bool | dict, - parquet_dictionary_encoding: bool | dict, - parquet_byte_stream_split: bool | dict, - parquet_coerce_timestamps: Literal["ms"] | Literal["us"] | None, - parquet_old_int96_timestamps: bool | None, - parquet_compliant_nested: bool, - parquet_extra_options: dict | None, - storage_options: dict[str, Any] | None, - write_metadata: bool, - compute: Literal[False], - prefix: str | None, -) -> Scalar: - ... - - -def to_parquet( - array: Array, - destination: str, - *, list_to32: bool = False, string_to32: bool = True, bytestring_to32: bool = True, diff --git a/src/dask_awkward/lib/io/text.py b/src/dask_awkward/lib/io/text.py index eb82f5eb..a45996ba 100644 --- a/src/dask_awkward/lib/io/text.py +++ b/src/dask_awkward/lib/io/text.py @@ -100,12 +100,12 @@ def from_text( bytes_ingredients, _ = _bytes_with_sample( fs, - paths, - compression, - delimiter, - False, - blocksize, - False, + paths=paths, + compression=compression, + delimiter=delimiter, + not_zero=False, + blocksize=blocksize, + sample=False, ) # meta is _always_ an unknown length array of strings. 
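Both `from_json` and `from_text` forward `blocksize` to `_bytes_with_sample`, now as explicit keyword arguments; an int is taken as a byte count, and a string of the form `"10 MiB"` is the human-readable equivalent (cf. `dask.utils.parse_bytes`). A short usage sketch — the `data/*.jsonl` glob path is illustrative, not part of this patch:

    import dask_awkward as dak
    from dask.utils import parse_bytes

    # Default: one partition per input file (blocksize=None).
    a = dak.from_json("data/*.jsonl")

    # Fixed-size byte partitions; the string and int forms are equivalent.
    b = dak.from_json("data/*.jsonl", blocksize="10 MiB")
    c = dak.from_json("data/*.jsonl", blocksize=10 * 1024**2)
    assert parse_bytes("10 MiB") == 10 * 1024**2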
diff --git a/src/dask_awkward/lib/operations.py b/src/dask_awkward/lib/operations.py
index 9abbef3c..32e94479 100644
--- a/src/dask_awkward/lib/operations.py
+++ b/src/dask_awkward/lib/operations.py
@@ -50,14 +50,15 @@ def concatenate(
         i += 1
 
     meta = ak.concatenate(metas)
+    assert isinstance(meta, ak.Array)
 
     prev_names = [iarr.name for iarr in arrays]
-    g = AwkwardMaterializedLayer(
+    aml = AwkwardMaterializedLayer(
         g,
         previous_layer_names=prev_names,
         fn=_concatenate_axis0_multiarg,
     )
-    hlg = HighLevelGraph.from_collections(name, g, dependencies=arrays)
+    hlg = HighLevelGraph.from_collections(name, aml, dependencies=arrays)
     return new_array_object(hlg, name, meta=meta, npartitions=npartitions)
 
     if axis > 0:
diff --git a/src/dask_awkward/lib/optimize.py b/src/dask_awkward/lib/optimize.py
index 50ce634e..58586438 100644
--- a/src/dask_awkward/lib/optimize.py
+++ b/src/dask_awkward/lib/optimize.py
@@ -66,7 +66,7 @@ def all_optimizations(
 
 
 def optimize(
-    dsk: Mapping,
+    dsk: HighLevelGraph,
     keys: Hashable | list[Hashable] | set[Hashable],
     **_: Any,
 ) -> Mapping:
@@ -79,9 +79,9 @@
     if dask.config.get("awkward.optimization.enabled"):
         which = dask.config.get("awkward.optimization.which")
         if "columns" in which:
-            dsk = optimize_columns(dsk)  # type: ignore
+            dsk = optimize_columns(dsk)
         if "layer-chains" in which:
-            dsk = rewrite_layer_chains(dsk)
+            dsk = rewrite_layer_chains(dsk, keys)
 
     return dsk
 
@@ -224,12 +224,12 @@ def _touch_and_call(layer):
     return new_layer
 
 
-def rewrite_layer_chains(dsk: HighLevelGraph) -> HighLevelGraph:
+def rewrite_layer_chains(dsk: HighLevelGraph, keys: Any) -> HighLevelGraph:
     # dask.optimization.fuse_linear for blockwise layers
     import copy
 
     chains = []
-    deps = dsk.dependencies.copy()
+    deps = copy.copy(dsk.dependencies)
 
     layers = {}
     # find chains; each chain list is at least two keys long
@@ -285,32 +285,32 @@
         outkey = chain[-1]
         layer0 = dsk.layers[chain[0]]
         outlayer = layers[outkey]
-        numblocks = [nb[0] for nb in layer0.numblocks.values() if nb[0] is not None][0]
-        deps[outkey] = deps[chain[0]]
-        [deps.pop(ch) for ch in chain[:-1]]
+        numblocks = [nb[0] for nb in layer0.numblocks.values() if nb[0] is not None][0]  # type: ignore
+        deps[outkey] = deps[chain[0]]  # type: ignore
+        [deps.pop(ch) for ch in chain[:-1]]  # type: ignore
 
-        subgraph = layer0.dsk.copy()
-        indices = list(layer0.indices)
+        subgraph = layer0.dsk.copy()  # type: ignore
+        indices = list(layer0.indices)  # type: ignore
         parent = chain[0]
 
-        outlayer.io_deps = layer0.io_deps
+        outlayer.io_deps = layer0.io_deps  # type: ignore
         for chain_member in chain[1:]:
             layer = dsk.layers[chain_member]
-            for k in layer.io_deps:
-                outlayer.io_deps[k] = layer.io_deps[k]
+            for k in layer.io_deps:  # type: ignore
+                outlayer.io_deps[k] = layer.io_deps[k]  # type: ignore
-            func, *args = layer.dsk[chain_member]
+            func, *args = layer.dsk[chain_member]  # type: ignore
             args2 = _recursive_replace(args, layer, parent, indices)
             subgraph[chain_member] = (func,) + tuple(args2)
             parent = chain_member
 
-        outlayer.numblocks = {i[0]: (numblocks,) for i in indices if i[1] is not None}
-        outlayer.dsk = subgraph
+        outlayer.numblocks = {i[0]: (numblocks,) for i in indices if i[1] is not None}  # type: ignore
+        outlayer.dsk = subgraph  # type: ignore
         if hasattr(outlayer, "_dims"):
             del outlayer._dims
-        outlayer.indices = tuple(
+        outlayer.indices = tuple(  # type: ignore
             (i[0], (".0",) if i[1] is not None else None) for i in indices
         )
-        outlayer.output_indices = (".0",)
-        outlayer.inputs 
= getattr(layer0, "inputs", set()) + outlayer.output_indices = (".0",) # type: ignore + outlayer.inputs = getattr(layer0, "inputs", set()) # type: ignore if hasattr(outlayer, "_cached_dict"): del outlayer._cached_dict # reset, since original can be mutated return HighLevelGraph(layers, deps) @@ -356,8 +356,8 @@ def _get_column_reports(dsk: HighLevelGraph) -> dict[str, Any]: # make labelled report projectable = _projectable_input_layer_names(dsk) - for name, lay in dsk.layers.copy().items(): - if name in projectable: + for name, lay in dsk.layers.items(): + if name in projectable and hasattr(lay, "mock"): layers[name], report = lay.mock() reports[name] = report elif hasattr(lay, "mock"): diff --git a/src/dask_awkward/lib/str.py b/src/dask_awkward/lib/str.py index e71ce125..85324ae9 100644 --- a/src/dask_awkward/lib/str.py +++ b/src/dask_awkward/lib/str.py @@ -1,16 +1,11 @@ from __future__ import annotations import functools -import sys from collections.abc import Callable from typing import Any, TypeVar import awkward.operations.str as akstr - -if sys.version_info < (3, 11, 0): - from typing_extensions import ParamSpec -else: - from typing import ParamSpec +from typing_extensions import ParamSpec from dask_awkward.lib.core import Array, map_partitions diff --git a/src/dask_awkward/lib/structure.py b/src/dask_awkward/lib/structure.py index 5e601686..7dd8c1a9 100644 --- a/src/dask_awkward/lib/structure.py +++ b/src/dask_awkward/lib/structure.py @@ -820,7 +820,7 @@ def unzip( if len(fields) == 0: return (array,) else: - return tuple(array[field] for field in fields) # type: ignore + return tuple(array[field] for field in fields) @borrow_docstring(ak.values_astype) @@ -1141,7 +1141,7 @@ def _repartition_func(*stuff): return ak.concatenate(data) -def repartition_layer(arr: Array, key: str, divisions: tuple[int, ...]): +def repartition_layer(arr: Array, key: str, divisions: tuple[int, ...]) -> dict: layer = {} indivs = arr.defined_divisions diff --git a/src/dask_awkward/lib/unproject_layout.py b/src/dask_awkward/lib/unproject_layout.py index 1235a418..21c06aa5 100644 --- a/src/dask_awkward/lib/unproject_layout.py +++ b/src/dask_awkward/lib/unproject_layout.py @@ -1,6 +1,7 @@ from __future__ import annotations import math +from typing import Any import awkward as ak import numpy as np @@ -53,7 +54,7 @@ } -def dummy_index_of(typecode: str, length: int, nplike) -> ak.index.Index: +def dummy_index_of(typecode: str, length: int, nplike: Any) -> ak.index.Index: index_cls = index_of[typecode] dtype = dtype_of[typecode] return index_cls(PlaceholderArray(nplike, (length,), dtype), nplike=nplike) @@ -118,6 +119,8 @@ def compatible(form: Form, layout: Content) -> bool: else: return False + return False + def _unproject_layout(form, layout, length, backend): if layout is None: diff --git a/src/dask_awkward/pickle.py b/src/dask_awkward/pickle.py index 06fee32a..a53236c6 100644 --- a/src/dask_awkward/pickle.py +++ b/src/dask_awkward/pickle.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Any + __all__ = ("plugin",) from pickle import PickleBuffer @@ -8,7 +10,7 @@ from awkward.typetracer import PlaceholderArray -def maybe_make_pickle_buffer(buffer) -> PlaceholderArray | PickleBuffer: +def maybe_make_pickle_buffer(buffer: Any) -> PlaceholderArray | PickleBuffer: if isinstance(buffer, PlaceholderArray): return buffer else: @@ -65,7 +67,7 @@ def pickle_array(array: ak.Array, protocol: int) -> tuple: ) -def plugin(obj, protocol: int) -> tuple | NotImplemented: +def plugin(obj: Any, 
protocol: int) -> tuple: if isinstance(obj, ak.Record): return pickle_record(obj, protocol) elif isinstance(obj, ak.Array): diff --git a/src/dask_awkward/utils.py b/src/dask_awkward/utils.py index a90b12cb..a1b02bf6 100644 --- a/src/dask_awkward/utils.py +++ b/src/dask_awkward/utils.py @@ -3,11 +3,14 @@ from collections.abc import Callable, Mapping from typing import TYPE_CHECKING, Any, TypeVar +from typing_extensions import ParamSpec + if TYPE_CHECKING: from dask_awkward.lib.core import Array T = TypeVar("T") +P = ParamSpec("P") class DaskAwkwardNotImplemented(NotImplementedError): @@ -68,8 +71,8 @@ def keys(self): return ((i,) for i in range(len(self.inputs))) -def borrow_docstring(original: Callable[..., T]) -> Callable[..., T]: - def wrapper(method): +def borrow_docstring(original: Callable) -> Callable: + def wrapper(method: Callable[P, T]) -> Callable[P, T]: method.__doc__ = ( f"Partitioned version of ak.{original.__name__}\n" f"{original.__doc__}" ) diff --git a/tests/conftest.py b/tests/conftest.py index d5e8a3c8..09d629cd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -68,9 +68,9 @@ def daa_old(ndjson_points1: str) -> dak.Array: @pytest.fixture(scope="session") -def pq_points_dir(daa_old, tmp_path_factory) -> str: +def pq_points_dir(daa_old: dak.Array, tmp_path_factory: pytest.TempPathFactory) -> str: pqdir = tmp_path_factory.mktemp("pqfiles") - dak.to_parquet(daa_old, str(pqdir), compute=True) + dak.to_parquet(daa_old, str(pqdir)) return str(pqdir) @@ -168,7 +168,7 @@ def L4() -> list[list[dict[str, float]] | None]: @pytest.fixture(scope="session") def caa_parquet(caa: ak.Array, tmp_path_factory: pytest.TempPathFactory) -> str: - fname = tmp_path_factory.mktemp("parquet_data") / "caa.parquet" # type: ignore + fname = tmp_path_factory.mktemp("parquet_data") / "caa.parquet" ak.to_parquet(caa, str(fname), extensionarray=False) return str(fname) diff --git a/tests/test_core.py b/tests/test_core.py index 7d1d5ed7..a0a0789e 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -145,7 +145,7 @@ def test_partitions_divisions(ndjson_points_file: str) -> None: assert not t1.known_divisions t2 = daa.partitions[1] assert t2.known_divisions - assert t2.divisions == (0, divs[2] - divs[1]) + assert t2.divisions == (0, divs[2] - divs[1]) # type: ignore def test_array_rebuild(ndjson_points_file: str) -> None: @@ -178,7 +178,7 @@ def test_typestr(daa: Array) -> None: assert len(daa._typestr(max=20)) == 20 + extras -def test_head(daa: Array): +def test_head(daa: Array) -> None: out = daa.head(1) assert out.tolist() == daa.compute()[:1].tolist() @@ -233,7 +233,7 @@ def test_scalar_getitem_getattr() -> None: slice(None, None, 3), ], ) -def test_getitem_zero_slice_single(daa: Array, where): +def test_getitem_zero_slice_single(daa: Array, where: slice) -> None: out = daa[where] assert out.compute().tolist() == daa.compute()[where].tolist() assert len(out) == len(daa.compute()[where]) @@ -257,7 +257,11 @@ def test_getitem_zero_slice_single(daa: Array, where): ], ) @pytest.mark.parametrize("rest", [slice(None, None, None), slice(0, 1)]) -def test_getitem_zero_slice_tuple(daa: Array, where, rest): +def test_getitem_zero_slice_tuple( + daa: Array, + where: slice, + rest: slice, +) -> None: out = daa[where, rest] assert out.compute().tolist() == daa.compute()[where, rest].tolist() assert len(out) == len(daa.compute()[where, rest]) @@ -476,7 +480,7 @@ def test_compatible_partitions_after_slice() -> None: assert_eq(lazy, ccrt) # sanity - assert dak.compatible_partitions(lazy, lazy + 2) + 
assert dak.compatible_partitions(lazy, lazy + 2) # type: ignore assert dak.compatible_partitions(lazy, dak.num(lazy, axis=1) > 2) assert not dak.compatible_partitions(lazy[:-2], lazy) @@ -666,7 +670,7 @@ def test_optimize_chain_single(daa): arr = ((daa.points.x + 1) + 6).map_partitions(lambda x: x + 1) # first a simple test by calling the one optimisation directly - dsk2 = rewrite_layer_chains(arr.dask) + dsk2 = rewrite_layer_chains(arr.dask, arr.keys) (out,) = dask.compute(arr, optimize_graph=False) arr._dask = dsk2 (out2,) = dask.compute(arr, optimize_graph=False) diff --git a/tests/test_io.py b/tests/test_io.py index 173a3f6e..63a50f76 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -164,10 +164,10 @@ def f(a, b): dak.from_map(f, [1, 2], [3, 4, 5]) with pytest.raises(ValueError, match="must be `callable`"): - dak.from_map(5, [1], [2]) + dak.from_map(5, [1], [2]) # type: ignore with pytest.raises(ValueError, match="must be Iterable"): - dak.from_map(f, 1, [1, 2]) + dak.from_map(f, 1, [1, 2]) # type: ignore with pytest.raises(ValueError, match="non-zero length"): dak.from_map(f, [], [], []) @@ -252,7 +252,7 @@ def test_to_dataframe(daa: dak.Array, caa: ak.Array, optimize_graph: bool) -> No from dask.dataframe.utils import assert_eq daa = daa["points", ["x", "y"]] - caa = caa["points", ["x", "y"]] + caa = caa["points", ["x", "y"]] # pyright: ignore dd = dak.to_dataframe(daa, optimize_graph=optimize_graph) df = ak.to_dataframe(caa) @@ -277,7 +277,7 @@ def test_to_dataframe_str( assert_eq(dd, df, check_index=False) -def test_from_awkward_empty_array(daa) -> None: +def test_from_awkward_empty_array(daa: dak.Array) -> None: # no form c1 = ak.Array([]) assert len(c1) == 0 diff --git a/tests/test_io_json.py b/tests/test_io_json.py index 244b44aa..578358ef 100644 --- a/tests/test_io_json.py +++ b/tests/test_io_json.py @@ -190,7 +190,7 @@ def test_to_and_from_json( p1 = os.path.join(tdir, "z", "z") - dak.to_json(array=daa, path=p1, compute=True) + dak.to_json(daa, p1) paths = list((Path(tdir) / "z" / "z").glob("part*.json")) assert len(paths) == daa.npartitions arrays = ak.concatenate([ak.from_json(p, line_delimited=True) for p in paths]) @@ -205,6 +205,7 @@ def test_to_and_from_json( compression=compression, compute=False, ) + assert isinstance(s, dak.Scalar) s.compute() suffix = "gz" if compression == "gzip" else compression r = dak.from_json(os.path.join(tdir, f"*.json.{suffix}")) diff --git a/tests/test_optimize.py b/tests/test_optimize.py index 66108048..afb2da88 100644 --- a/tests/test_optimize.py +++ b/tests/test_optimize.py @@ -6,7 +6,7 @@ import dask_awkward as dak -def test_multiple_computes(pq_points_dir) -> None: +def test_multiple_computes(pq_points_dir: str) -> None: ds1 = dak.from_parquet(pq_points_dir) # add a columns= argument to force a new tokenize result in # from_parquet so we get two unique collections. 
@@ -26,4 +26,4 @@ def test_multiple_computes(pq_points_dir) -> None: assert len(things3[1]) < len(things3[0]) things = dask.compute(ds1.points, ds2.points.x, ds2.points.y, ds1.points.y, ds3) - assert things[-1].tolist() == ak.Array(lists[0] + lists[1]).tolist() + assert things[-1].tolist() == ak.Array(lists[0] + lists[1]).tolist() # type: ignore diff --git a/tests/test_parquet.py b/tests/test_parquet.py index b1da940e..544ab0f1 100644 --- a/tests/test_parquet.py +++ b/tests/test_parquet.py @@ -191,7 +191,7 @@ def test_to_parquet_with_prefix( tmp_path: pathlib.Path, prefix: str | None, ) -> None: - dak.to_parquet(daa, str(tmp_path), prefix=prefix, compute=True) + dak.to_parquet(daa, str(tmp_path), prefix=prefix) files = list(tmp_path.glob("*")) for ifile in files: fname = ifile.parts[-1]
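With `typing_extensions` now an unconditional dependency (see the pyproject.toml hunk), `ParamSpec` is imported directly rather than behind a `sys.version_info` check, and `utils.borrow_docstring` uses it to type its inner wrapper. A self-contained sketch of that decorator pattern; the `add`/`lazy_add` toy functions are illustrative only:

    from __future__ import annotations

    from collections.abc import Callable
    from typing import TypeVar

    from typing_extensions import ParamSpec

    T = TypeVar("T")
    P = ParamSpec("P")


    def borrow_docstring(original: Callable) -> Callable:
        # Typing the wrapper as Callable[P, T] -> Callable[P, T] preserves the
        # decorated function's signature for type checkers instead of erasing
        # it to Callable[..., Any].
        def wrapper(method: Callable[P, T]) -> Callable[P, T]:
            method.__doc__ = (
                f"Partitioned version of ak.{original.__name__}\n{original.__doc__}"
            )
            return method

        return wrapper


    def add(x: int, y: int) -> int:
        """Add two numbers."""
        return x + y


    @borrow_docstring(add)
    def lazy_add(x: int, y: int) -> int:
        return x + y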