Skip to content

Commit

Permalink
Merge pull request #203 from jorenham/datasets/improvements
Browse files Browse the repository at this point in the history
`datasets`: Improved type-hints with `TypedDict`s and `Literal`s
  • Loading branch information
jorenham authored Nov 26, 2024
2 parents 97d44e2 + 3dd0b4c commit 1cc893b
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 26 deletions.
2 changes: 1 addition & 1 deletion scipy-stubs/datasets/_download_all.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from os import PathLike

def download_all(path: str | PathLike[str] | None = None) -> None: ...
def main() -> None: ...
def main() -> None: ... # undocumented
25 changes: 15 additions & 10 deletions scipy-stubs/datasets/_fetchers.pyi
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
from typing import Final, Literal, TypeAlias, overload
from typing import Final, Literal as L, overload
from typing_extensions import LiteralString

import numpy as np
from ._typing import AscentDataset, CanFetch, ECGDataset, Face2Dataset, Face3Dataset

# TODO: stub `pooch` (this should be a `pooch.core.Pooch`)
_DataFetcher: TypeAlias = object
data_fetcher: Final[_DataFetcher]
###

def fetch_data(dataset_name: LiteralString, data_fetcher: _DataFetcher = ...) -> LiteralString: ...
def ascent() -> np.ndarray[tuple[Literal[512], Literal[512]], np.dtype[np.uint8]]: ...
def electrocardiogram() -> np.ndarray[tuple[Literal[108_000]], np.dtype[np.float64]]: ...
data_fetcher: Final[CanFetch | None] = ... # undocumented

# Undocumented helper. `dataset_name` is limited to the three data file names that
# are also the keys of the registry `TypedDict` in `._registry`, and `data_fetcher`
# only has to provide a `fetch` method (see `CanFetch` in `._typing`).
# NOTE(review): the returned string is presumably the local path of the fetched
# file -- confirm against the runtime implementation.
def fetch_data(
dataset_name: L["ascent.dat", "ecg.dat", "face.dat"],
data_fetcher: CanFetch | None = None,
) -> LiteralString: ... # undocumented

#
# Parameterless dataset loaders. The return aliases are defined in `._typing`:
# `AscentDataset` is a (512, 512) `uint8` array and `ECGDataset` is a
# (108_000,) `float64` array.
def ascent() -> AscentDataset: ...
def electrocardiogram() -> ECGDataset: ...
@overload
def face(gray: Literal[False] = False) -> np.ndarray[tuple[Literal[768], Literal[1_024], Literal[3]], np.dtype[np.uint8]]: ...
def face(gray: L[True, 1]) -> Face2Dataset: ...
@overload
def face(gray: Literal[True]) -> np.ndarray[tuple[Literal[768], Literal[1_024]], np.dtype[np.uint8]]: ...
def face(gray: L[False, 0] = False) -> Face3Dataset: ...
16 changes: 12 additions & 4 deletions scipy-stubs/datasets/_registry.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
from typing import Final
from typing import Final, Literal as L, TypedDict, type_check_only
from typing_extensions import LiteralString

registry: Final[dict[LiteralString, LiteralString]]
registry_urls: Final[dict[LiteralString, LiteralString]]
method_files_map: Final[dict[LiteralString, list[LiteralString]]]
# Shape of `method_files_map`: one key per dataset function (`ascent`,
# `electrocardiogram`, `face`), each mapped to a list whose items are that
# dataset's data file name as a string literal.
# `@type_check_only` marks the class as existing in the stubs only, not at runtime.
@type_check_only
class _MethodRegistry(TypedDict):
ascent: list[L["ascent.dat"]]
electrocardiogram: list[L["ecg.dat"]]
face: list[L["face.dat"]]

# The functional `TypedDict` syntax is required here: the keys contain a `.`, so
# they are not valid identifiers and cannot be declared with the class syntax.
_DataRegistry = TypedDict("_DataRegistry", {"ascent.dat": LiteralString, "ecg.dat": LiteralString, "face.dat": LiteralString})

# `registry` and `registry_urls` share the same per-file key set.
# NOTE(review): presumably the values are file hashes and download URLs
# respectively -- not visible from the stubs, confirm against the runtime module.
registry: Final[_DataRegistry] = ...
registry_urls: Final[_DataRegistry] = ...
method_files_map: Final[_MethodRegistry] = ...
31 changes: 31 additions & 0 deletions scipy-stubs/datasets/_typing.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from collections.abc import Callable
from typing import Literal as L, Protocol, TypeAlias, overload, type_check_only
from typing_extensions import LiteralString

import numpy as np
import optype.numpy as onp

__all__ = "AscentDataset", "CanFetch", "Dataset", "ECGDataset", "Face2Dataset", "Face3Dataset", "Fetcher"

# Structural (duck-typed) stand-in for the object used as `data_fetcher`: any
# object with a `fetch` method that takes the dataset name positionally and
# returns a string satisfies it.
# NOTE(review): presumably modelled on pooch's `Pooch.fetch` -- confirm.
@type_check_only
class CanFetch(Protocol):
def fetch(self, dataset_name: LiteralString, /) -> LiteralString: ...

# Array types of the bundled datasets, spelled as
# `onp.Array[<shape tuple>, <scalar type>]` with literal axis lengths.
AscentDataset: TypeAlias = onp.Array[tuple[L[512], L[512]], np.uint8]
ECGDataset: TypeAlias = onp.Array[tuple[L[108_000]], np.float64]
# `face` has two variants: 2-d, and 3-d with a trailing axis of length 3.
Face2Dataset: TypeAlias = onp.Array[tuple[L[768], L[1_024]], np.uint8]
Face3Dataset: TypeAlias = onp.Array[tuple[L[768], L[1_024], L[3]], np.uint8]
_FaceDataset: TypeAlias = Face2Dataset | Face3Dataset
# Union of every dataset array type above.
Dataset: TypeAlias = AscentDataset | ECGDataset | _FaceDataset

# Call signatures of the dataset functions that take no arguments.
_AscentFetcher: TypeAlias = Callable[[], AscentDataset]
_ECGFetcher: TypeAlias = Callable[[], ECGDataset]

# Callable protocol describing the signature of `face`. A plain `Callable[...]`
# cannot express overloads or a named parameter with a default, hence the
# overloaded `__call__`:
#   - `gray` of `True` / `1`                 -> `Face2Dataset`
#   - `gray` of `False` / `0` (the default)  -> `Face3Dataset`
@type_check_only
class _FaceFetcher(Protocol):
@overload
def __call__(self, /, gray: L[True, 1]) -> Face2Dataset: ...
@overload
def __call__(self, /, gray: L[False, 0] = False) -> Face3Dataset: ...

# Any one of the dataset functions; this is the type `clear_cache` in `._utils`
# accepts, either singly or inside a `list` / `tuple`.
Fetcher: TypeAlias = _AscentFetcher | _ECGFetcher | _FaceFetcher
14 changes: 3 additions & 11 deletions scipy-stubs/datasets/_utils.pyi
Original file line number Diff line number Diff line change
@@ -1,12 +1,4 @@
from collections.abc import Callable
from typing import TypeAlias
from typing_extensions import TypeVar
from ._typing import Fetcher

import numpy as np

_ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...], default=tuple[int] | tuple[int, int] | tuple[int, int, int])
_DT = TypeVar("_DT", bound=np.dtype[np.generic], default=np.dtype[np.float64] | np.dtype[np.uint8])

_AnyDataset: TypeAlias = Callable[[], np.ndarray[_ShapeT, _DT]]

def clear_cache(datasets: list[_AnyDataset] | tuple[_AnyDataset, ...] | None = None) -> None: ...
# NOTE: the implementation explicitly checks for `list` and `tuple` types, so
# those two container types are spelled out here instead of a general sequence type.
def clear_cache(datasets: Fetcher | list[Fetcher] | tuple[Fetcher, ...] | None = None) -> None: ...

0 comments on commit 1cc893b

Please sign in to comment.