Skip to content

Commit

Permalink
misc: str module typing (#367)
Browse files Browse the repository at this point in the history
  • Loading branch information
douglasdavis authored Sep 18, 2023
1 parent 3125bf4 commit dff9a37
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 49 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ classifiers = [
dependencies = [
"awkward >=2.4.0",
"dask >=2023.04.0",
"typing_extensions>=4.8.0; python_version < \"3.11\"",
]
dynamic = ["version"]

Expand Down
107 changes: 59 additions & 48 deletions src/dask_awkward/lib/str.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
from __future__ import annotations

import functools
import sys
from collections.abc import Callable
from typing import Any, TypeVar

import awkward.operations.str as akstr

if sys.version_info < (3, 11, 0):
from typing_extensions import ParamSpec
else:
from typing import ParamSpec

from dask_awkward.lib.core import Array, map_partitions

T = TypeVar("T")
P = ParamSpec("P")


def always_highlevel(fn):
def always_highlevel(fn: Callable[P, T]) -> Callable[P, T]:
@functools.wraps(fn)
def wrapper(*args, **kwargs):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
if not kwargs.get("highlevel", True):
raise ValueError("dask-awkward supports only highlevel awkward arrays.")
return fn(*args, **kwargs)
Expand All @@ -35,8 +46,8 @@ def capitalize(
@always_highlevel
def center(
array: Array,
width,
padding=" ",
width: int,
padding: str | bytes = " ",
*,
highlevel: bool = True,
behavior: dict | None = None,
Expand Down Expand Up @@ -128,9 +139,9 @@ def extract_regex(
@always_highlevel
def find_substring(
array: Array,
pattern,
pattern: str | bytes,
*,
ignore_case=False,
ignore_case: bool = False,
highlevel: bool = True,
behavior: dict | None = None,
) -> Array:
Expand All @@ -147,9 +158,9 @@ def find_substring(
@always_highlevel
def find_substring_regex(
array: Array,
pattern,
pattern: str | bytes,
*,
ignore_case=False,
ignore_case: bool = False,
highlevel: bool = True,
behavior: dict | None = None,
) -> Array:
Expand All @@ -166,9 +177,9 @@ def find_substring_regex(
@always_highlevel
def index_in(
array: Array,
value_set,
value_set: Any,
*,
skip_nones=False,
skip_nones: bool = False,
highlevel: bool = True,
behavior: dict | None = None,
) -> Array:
Expand Down Expand Up @@ -260,9 +271,9 @@ def is_digit(
@always_highlevel
def is_in(
array: Array,
value_set,
value_set: Any,
*,
skip_nones=False,
skip_nones: bool = False,
highlevel: bool = True,
behavior: dict | None = None,
) -> Array:
Expand Down Expand Up @@ -369,7 +380,7 @@ def is_upper(
@always_highlevel
def join(
array: Array,
separator,
separator: Any,
*,
highlevel: bool = True,
behavior: dict | None = None,
Expand All @@ -385,7 +396,7 @@ def join(

@always_highlevel
def join_element_wise(
*arrays,
*arrays: Array,
highlevel: bool = True,
behavior: dict | None = None,
) -> Array:
Expand Down Expand Up @@ -430,8 +441,8 @@ def lower(
@always_highlevel
def lpad(
array: Array,
width,
padding=" ",
width: int,
padding: str | bytes = " ",
*,
highlevel: bool = True,
behavior: dict | None = None,
Expand All @@ -449,7 +460,7 @@ def lpad(
@always_highlevel
def ltrim(
array: Array,
characters,
characters: str | bytes,
*,
highlevel: bool = True,
behavior: dict | None = None,
Expand Down Expand Up @@ -481,9 +492,9 @@ def ltrim_whitespace(
@always_highlevel
def match_like(
array: Array,
pattern,
pattern: str | bytes,
*,
ignore_case=False,
ignore_case: bool = False,
highlevel: bool = True,
behavior: dict | None = None,
) -> Array:
Expand All @@ -500,9 +511,9 @@ def match_like(
@always_highlevel
def match_substring(
array: Array,
pattern,
pattern: str | bytes,
*,
ignore_case=False,
ignore_case: bool = False,
highlevel: bool = True,
behavior: dict | None = None,
) -> Array:
Expand All @@ -519,9 +530,9 @@ def match_substring(
@always_highlevel
def match_substring_regex(
array: Array,
pattern,
pattern: str | bytes,
*,
ignore_case=False,
ignore_case: bool = False,
highlevel: bool = True,
behavior: dict | None = None,
) -> Array:
Expand All @@ -538,7 +549,7 @@ def match_substring_regex(
@always_highlevel
def repeat(
array: Array,
num_repeats,
num_repeats: Any,
*,
highlevel: bool = True,
behavior: dict | None = None,
Expand All @@ -555,9 +566,9 @@ def repeat(
@always_highlevel
def replace_slice(
array: Array,
start,
stop,
replacement,
start: int,
stop: int,
replacement: str | bytes,
*,
highlevel: bool = True,
behavior: dict | None = None,
Expand All @@ -576,10 +587,10 @@ def replace_slice(
@always_highlevel
def replace_substring(
array: Array,
pattern,
replacement,
pattern: str,
replacement: str | bytes,
*,
max_replacements=None,
max_replacements: int | None = None,
highlevel: bool = True,
behavior: dict | None = None,
) -> Array:
Expand All @@ -597,10 +608,10 @@ def replace_substring(
@always_highlevel
def replace_substring_regex(
array: Array,
pattern,
replacement,
pattern: str,
replacement: str | bytes,
*,
max_replacements=None,
max_replacements: int | None = None,
highlevel: bool = True,
behavior: dict | None = None,
) -> Array:
Expand Down Expand Up @@ -633,8 +644,8 @@ def reverse(
@always_highlevel
def rpad(
array: Array,
width,
padding=" ",
width: int,
padding: str | bytes = " ",
*,
highlevel: bool = True,
behavior: dict | None = None,
Expand All @@ -652,7 +663,7 @@ def rpad(
@always_highlevel
def rtrim(
array: Array,
characters,
characters: str | bytes,
*,
highlevel: bool = True,
behavior: dict | None = None,
Expand Down Expand Up @@ -684,9 +695,9 @@ def rtrim_whitespace(
@always_highlevel
def slice(
array: Array,
start,
stop=None,
step=1,
start: int,
stop: int | None = None,
step: int = 1,
*,
highlevel: bool = True,
behavior: dict | None = None,
Expand All @@ -705,10 +716,10 @@ def slice(
@always_highlevel
def split_pattern(
array: Array,
pattern,
pattern: str | bytes,
*,
max_splits=None,
reverse=False,
max_splits: int | None = None,
reverse: bool = False,
highlevel: bool = True,
behavior: dict | None = None,
) -> Array:
Expand All @@ -726,10 +737,10 @@ def split_pattern(
@always_highlevel
def split_pattern_regex(
array: Array,
pattern,
pattern: str | bytes,
*,
max_splits=None,
reverse=False,
max_splits: int | None = None,
reverse: bool = False,
highlevel: bool = True,
behavior: dict | None = None,
) -> Array:
Expand Down Expand Up @@ -765,9 +776,9 @@ def split_whitespace(
@always_highlevel
def starts_with(
array: Array,
pattern,
pattern: str | bytes,
*,
ignore_case=False,
ignore_case: bool = False,
highlevel: bool = True,
behavior: dict | None = None,
) -> Array:
Expand Down Expand Up @@ -829,7 +840,7 @@ def to_categorical(
@always_highlevel
def trim(
array: Array,
characters,
characters: str | bytes,
*,
highlevel: bool = True,
behavior: dict | None = None,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def test_rtrim_whitespace() -> None:
(0, None, 3),
],
)
def test_slice(args) -> None:
def test_slice(args: tuple) -> None:
start, stop, step = args
assert_eq(akstr.slice(daa, start, stop, step), akstr.slice(caa, start, stop, step))

Expand Down

0 comments on commit dff9a37

Please sign in to comment.