Skip to content

Commit

Permalink
refactor: use new public API
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 committed Nov 13, 2023
1 parent 07c3464 commit bcc8844
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 29 deletions.
18 changes: 8 additions & 10 deletions src/dask_awkward/lib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
import dask.config
import numpy as np
from awkward._do import remove_structure as ak_do_remove_structure
from awkward._nplikes.typetracer import (
from awkward.highlevel import NDArrayOperatorsMixin, _dir_pattern
from awkward.typetracer import (
MaybeNone,
OneOf,
TypeTracerArray,
create_unknown_scalar,
is_unknown_scalar,
)
from awkward.highlevel import NDArrayOperatorsMixin, _dir_pattern
from dask.base import (
DaskMethodsMixin,
dont_optimize,
Expand Down Expand Up @@ -140,7 +141,9 @@ def key(self) -> Key:
return (self._name, 0)

def _check_meta(self, m: Any) -> Any | None:
if isinstance(m, (MaybeNone, OneOf)) or is_unknown_scalar(m):
if m is None:
return m
elif isinstance(m, (MaybeNone, OneOf)) or is_unknown_scalar(m):
return m
elif isinstance(m, ak.Array) and len(m) == 1:
return m
Expand Down Expand Up @@ -348,12 +351,9 @@ def new_scalar_object(dsk: HighLevelGraph, name: str, *, meta: Any) -> Scalar:
Resulting collection.
"""
if meta is None:
meta = ak.Array(TypeTracerArray._new(dtype=np.dtype(None), shape=()))

if isinstance(meta, MaybeNone):
meta = ak.Array(meta.content)
else:
elif meta is not None:
try:
if ak.backend(meta) != "typetracer":
raise TypeError(
Expand Down Expand Up @@ -411,9 +411,7 @@ def new_known_scalar(
dtype = np.dtype(dtype)
llg = AwkwardMaterializedLayer({(name, 0): s}, previous_layer_names=[])
hlg = HighLevelGraph.from_collections(name, llg, dependencies=())
return Scalar(
hlg, name, meta=TypeTracerArray._new(dtype=dtype, shape=()), known_value=s
)
return Scalar(hlg, name, meta=create_unknown_scalar(dtype), known_value=s)


class Record(Scalar):
Expand Down
31 changes: 17 additions & 14 deletions src/dask_awkward/lib/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import awkward as ak
import numpy as np
from awkward._nplikes.typetracer import TypeTracerArray
from awkward.typetracer import create_unknown_scalar
from dask.base import is_dask_collection, tokenize
from dask.highlevelgraph import HighLevelGraph

Expand Down Expand Up @@ -94,7 +94,7 @@ def argcartesian(
with_name: str | None = None,
highlevel: bool = True,
behavior: Mapping | None = None,
attrs: Mapping[str, Any] = None,
attrs: Mapping[str, Any] | None = None,
) -> Array:
if not highlevel:
raise ValueError("Only highlevel=True is supported")
Expand Down Expand Up @@ -135,7 +135,7 @@ def argcombinations(
with_name: str | None = None,
highlevel: bool = True,
behavior: Mapping | None = None,
attrs: Mapping[str, Any] = None,
attrs: Mapping[str, Any] | None = None,
) -> Array:
if not highlevel:
raise ValueError("Only highlevel=True is supported")
Expand Down Expand Up @@ -184,7 +184,7 @@ def argsort(
stable: bool = True,
highlevel: bool = True,
behavior: Mapping | None = None,
attrs: Mapping[str, Any] = None,
attrs: Mapping[str, Any] | None = None,
) -> Array:
if not highlevel:
raise ValueError("Only highlevel=True is supported")
Expand Down Expand Up @@ -362,7 +362,7 @@ def drop_none(
axis: int | None = None,
highlevel: bool = True,
behavior: Mapping | None = None,
attrs: Mapping[str, Any] = None,
attrs: Mapping[str, Any] | None = None,
) -> Array:
if not highlevel:
raise ValueError("Only highlevel=True is supported")
Expand Down Expand Up @@ -491,6 +491,7 @@ def isclose(
equal_nan: bool = False,
highlevel: bool = True,
behavior: Mapping | None = None,
attrs: Mapping[str, Any] | None = None,
) -> Array:
if not highlevel:
raise ValueError("Only highlevel=True is supported")
Expand Down Expand Up @@ -534,6 +535,7 @@ def local_index(
axis: int = -1,
highlevel: bool = True,
behavior: Mapping | None = None,
attrs: Mapping[str, Any] | None = None,
) -> Array:
if not highlevel:
raise ValueError("Only highlevel=True is supported")
Expand All @@ -555,6 +557,7 @@ def mask(
valid_when: bool = True,
highlevel: bool = True,
behavior: Mapping | None = None,
attrs: Mapping[str, Any] | None = None,
) -> Array:
if partition_compatibility(array, mask) == PartitionCompatibility.NO:
raise IncompatiblePartitions("mask", array, mask)
Expand Down Expand Up @@ -603,7 +606,7 @@ def num(
axis: int = 1,
highlevel: bool = True,
behavior: dict | None = None,
attrs: Mapping[str, Any] = None,
attrs: Mapping[str, Any] | None = None,
) -> Any:
if not highlevel:
raise ValueError("Only highlevel=True is supported")
Expand All @@ -623,11 +626,7 @@ def num(
{(name, 0): (_numaxis0, *keys)}, previous_layer_names=[per_axis.name]
)
hlg = HighLevelGraph.from_collections(name, matlayer, dependencies=(per_axis,))
return new_scalar_object(
hlg,
name,
meta=TypeTracerArray._new(dtype=np.dtype(np.int64), shape=()),
)
return new_scalar_object(hlg, name, meta=create_unknown_scalar(np.int64))
else:
return map_partitions(
ak.num,
Expand All @@ -644,7 +643,7 @@ def ones_like(
array: Array,
highlevel: bool = True,
behavior: dict | None = None,
attrs: Mapping[str, Any] = None,
attrs: Mapping[str, Any] | None = None,
dtype: DTypeLike | None = None,
) -> Array:
if not highlevel:
Expand Down Expand Up @@ -697,6 +696,7 @@ def pad_none(
clip: bool = False,
highlevel: bool = True,
behavior: Mapping | None = None,
attrs: Mapping[str, Any] | None = None,
) -> Array:
if not highlevel:
raise ValueError("Only highlevel=True is supported")
Expand All @@ -721,7 +721,7 @@ def ravel(
array: Array,
highlevel: bool = True,
behavior: dict | None = None,
attrs: Mapping[str, Any] = None,
attrs: Mapping[str, Any] | None = None,
) -> Array:
if not highlevel:
raise ValueError("Only highlevel=True is supported")
Expand All @@ -743,7 +743,7 @@ def run_lengths(
array: Array,
highlevel: bool = True,
behavior: dict | None = None,
attrs: Mapping[str, Any] = None,
attrs: Mapping[str, Any] | None = None,
) -> Array:
if not highlevel:
raise ValueError("Only highlevel=True is supported")
Expand Down Expand Up @@ -863,6 +863,7 @@ def unflatten(
axis: int = 0,
highlevel: bool = True,
behavior: Mapping | None = None,
attrs: Mapping[str, Any] | None = None,
) -> Array:
if not highlevel:
raise ValueError("Only highlevel=True is supported")
Expand Down Expand Up @@ -911,6 +912,7 @@ def values_astype(
to: np.dtype | str,
highlevel: bool = True,
behavior: Mapping | None = None,
attrs: Mapping[str, Any] | None = None,
) -> Array:
if not highlevel:
raise ValueError("Only highlevel=True is supported")
Expand Down Expand Up @@ -1005,6 +1007,7 @@ def with_field(
where: str | Sequence[str] | None = None,
highlevel: bool = True,
behavior: Mapping | None = None,
attrs: Mapping[str, Any] | None = None,
) -> Array:
if not highlevel:
raise ValueError("Only highlevel=True is supported")
Expand Down
6 changes: 1 addition & 5 deletions tests/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,11 +388,7 @@ def test_sort(daa, caa, ascending):


def test_copy(daa):
with pytest.raises(
DaskAwkwardNotImplemented,
match="This function is not necessary in the context of dask-awkward.",
):
dak.copy(daa)
assert dak.copy(daa) is daa


@pytest.mark.parametrize(
Expand Down

0 comments on commit bcc8844

Please sign in to comment.