Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

datasets: Improved type-hints with TypedDicts and Literals #203

Merged
merged 1 commit into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scipy-stubs/datasets/_download_all.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from os import PathLike

def download_all(path: str | PathLike[str] | None = None) -> None: ...
def main() -> None: ...
def main() -> None: ... # undocumented
25 changes: 15 additions & 10 deletions scipy-stubs/datasets/_fetchers.pyi
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
from typing import Final, Literal, TypeAlias, overload
from typing import Final, Literal as L, overload
from typing_extensions import LiteralString

import numpy as np
from ._typing import AscentDataset, CanFetch, ECGDataset, Face2Dataset, Face3Dataset

# TODO: stub `pooch` (this should be a `pooch.code.Pooch`)
_DataFetcher: TypeAlias = object
data_fetcher: Final[_DataFetcher]
###

def fetch_data(dataset_name: LiteralString, data_fetcher: _DataFetcher = ...) -> LiteralString: ...
def ascent() -> np.ndarray[tuple[Literal[512], Literal[512]], np.dtype[np.uint8]]: ...
def electrocardiogram() -> np.ndarray[tuple[Literal[108_000]], np.dtype[np.float64]]: ...
data_fetcher: Final[CanFetch | None] = ... # undocumented

def fetch_data(
dataset_name: L["ascent.dat", "ecg.dat", "face.dat"],
data_fetcher: CanFetch | None = None,
) -> LiteralString: ... # undocumented

#
def ascent() -> AscentDataset: ...
def electrocardiogram() -> ECGDataset: ...
@overload
def face(gray: Literal[False] = False) -> np.ndarray[tuple[Literal[768], Literal[1_024], Literal[3]], np.dtype[np.uint8]]: ...
def face(gray: L[True, 1]) -> Face2Dataset: ...
@overload
def face(gray: Literal[True]) -> np.ndarray[tuple[Literal[768], Literal[1_024]], np.dtype[np.uint8]]: ...
def face(gray: L[False, 0] = False) -> Face3Dataset: ...
16 changes: 12 additions & 4 deletions scipy-stubs/datasets/_registry.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
from typing import Final
from typing import Final, Literal as L, TypedDict, type_check_only
from typing_extensions import LiteralString

registry: Final[dict[LiteralString, LiteralString]]
registry_urls: Final[dict[LiteralString, LiteralString]]
method_files_map: Final[dict[LiteralString, list[LiteralString]]]
@type_check_only
class _MethodRegistry(TypedDict):
ascent: list[L["ascent.dat"]]
electrocardiogram: list[L["ecg.dat"]]
face: list[L["face.dat"]]

_DataRegistry = TypedDict("_DataRegistry", {"ascent.dat": LiteralString, "ecg.dat": LiteralString, "face.dat": LiteralString})

registry: Final[_DataRegistry] = ...
registry_urls: Final[_DataRegistry] = ...
method_files_map: Final[_MethodRegistry] = ...
31 changes: 31 additions & 0 deletions scipy-stubs/datasets/_typing.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from collections.abc import Callable
from typing import Literal as L, Protocol, TypeAlias, overload, type_check_only
from typing_extensions import LiteralString

import numpy as np
import optype.numpy as onp

__all__ = "AscentDataset", "CanFetch", "Dataset", "ECGDataset", "Face2Dataset", "Face3Dataset", "Fetcher"

@type_check_only
class CanFetch(Protocol):
def fetch(self, dataset_name: LiteralString, /) -> LiteralString: ...

AscentDataset: TypeAlias = onp.Array[tuple[L[512], L[512]], np.uint8]
ECGDataset: TypeAlias = onp.Array[tuple[L[108_000]], np.float64]
Face2Dataset: TypeAlias = onp.Array[tuple[L[768], L[1_024]], np.uint8]
Face3Dataset: TypeAlias = onp.Array[tuple[L[768], L[1_024], L[3]], np.uint8]
_FaceDataset: TypeAlias = Face2Dataset | Face3Dataset
Dataset: TypeAlias = AscentDataset | ECGDataset | _FaceDataset

_AscentFetcher: TypeAlias = Callable[[], AscentDataset]
_ECGFetcher: TypeAlias = Callable[[], ECGDataset]

@type_check_only
class _FaceFetcher(Protocol):
@overload
def __call__(self, /, gray: L[True, 1]) -> Face2Dataset: ...
@overload
def __call__(self, /, gray: L[False, 0] = False) -> Face3Dataset: ...

Fetcher: TypeAlias = _AscentFetcher | _ECGFetcher | _FaceFetcher
14 changes: 3 additions & 11 deletions scipy-stubs/datasets/_utils.pyi
Original file line number Diff line number Diff line change
@@ -1,12 +1,4 @@
from collections.abc import Callable
from typing import TypeAlias
from typing_extensions import TypeVar
from ._typing import Fetcher

import numpy as np

_ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...], default=tuple[int] | tuple[int, int] | tuple[int, int, int])
_DT = TypeVar("_DT", bound=np.dtype[np.generic], default=np.dtype[np.float64] | np.dtype[np.uint8])

_AnyDataset: TypeAlias = Callable[[], np.ndarray[_ShapeT, _DT]]

def clear_cache(datasets: list[_AnyDataset] | tuple[_AnyDataset, ...] | None = None) -> None: ...
# NOTE: the implementation explcitily checks for `list` and `tuple` types
def clear_cache(datasets: Fetcher | list[Fetcher] | tuple[Fetcher, ...] | None = None) -> None: ...