From 3dd0b4c4b78cd829bb82e330697ce1b0a2caa90f Mon Sep 17 00:00:00 2001 From: jorenham Date: Tue, 26 Nov 2024 19:25:21 +0100 Subject: [PATCH] `datasets`: Improved type-hints with `TypedDict`s and `Literal`s --- scipy-stubs/datasets/_download_all.pyi | 2 +- scipy-stubs/datasets/_fetchers.pyi | 25 ++++++++++++--------- scipy-stubs/datasets/_registry.pyi | 16 +++++++++---- scipy-stubs/datasets/_typing.pyi | 31 ++++++++++++++++++++++++++ scipy-stubs/datasets/_utils.pyi | 14 +++--------- 5 files changed, 62 insertions(+), 26 deletions(-) create mode 100644 scipy-stubs/datasets/_typing.pyi diff --git a/scipy-stubs/datasets/_download_all.pyi b/scipy-stubs/datasets/_download_all.pyi index bde6e799..c82de2f2 100644 --- a/scipy-stubs/datasets/_download_all.pyi +++ b/scipy-stubs/datasets/_download_all.pyi @@ -1,4 +1,4 @@ from os import PathLike def download_all(path: str | PathLike[str] | None = None) -> None: ... -def main() -> None: ... +def main() -> None: ... # undocumented diff --git a/scipy-stubs/datasets/_fetchers.pyi b/scipy-stubs/datasets/_fetchers.pyi index 2b85158d..ec8cdf36 100644 --- a/scipy-stubs/datasets/_fetchers.pyi +++ b/scipy-stubs/datasets/_fetchers.pyi @@ -1,16 +1,21 @@ -from typing import Final, Literal, TypeAlias, overload +from typing import Final, Literal as L, overload from typing_extensions import LiteralString -import numpy as np +from ._typing import AscentDataset, CanFetch, ECGDataset, Face2Dataset, Face3Dataset -# TODO: stub `pooch` (this should be a `pooch.code.Pooch`) -_DataFetcher: TypeAlias = object -data_fetcher: Final[_DataFetcher] +### -def fetch_data(dataset_name: LiteralString, data_fetcher: _DataFetcher = ...) -> LiteralString: ... -def ascent() -> np.ndarray[tuple[Literal[512], Literal[512]], np.dtype[np.uint8]]: ... -def electrocardiogram() -> np.ndarray[tuple[Literal[108_000]], np.dtype[np.float64]]: ... +data_fetcher: Final[CanFetch | None] = ... 
# undocumented + +def fetch_data( + dataset_name: L["ascent.dat", "ecg.dat", "face.dat"], + data_fetcher: CanFetch | None = None, +) -> LiteralString: ... # undocumented + +# +def ascent() -> AscentDataset: ... +def electrocardiogram() -> ECGDataset: ... @overload -def face(gray: Literal[False] = False) -> np.ndarray[tuple[Literal[768], Literal[1_024], Literal[3]], np.dtype[np.uint8]]: ... +def face(gray: L[True, 1]) -> Face2Dataset: ... @overload -def face(gray: Literal[True]) -> np.ndarray[tuple[Literal[768], Literal[1_024]], np.dtype[np.uint8]]: ... +def face(gray: L[False, 0] = False) -> Face3Dataset: ... diff --git a/scipy-stubs/datasets/_registry.pyi b/scipy-stubs/datasets/_registry.pyi index 177f68ad..041535e6 100644 --- a/scipy-stubs/datasets/_registry.pyi +++ b/scipy-stubs/datasets/_registry.pyi @@ -1,6 +1,14 @@ -from typing import Final +from typing import Final, Literal as L, TypedDict, type_check_only from typing_extensions import LiteralString -registry: Final[dict[LiteralString, LiteralString]] -registry_urls: Final[dict[LiteralString, LiteralString]] -method_files_map: Final[dict[LiteralString, list[LiteralString]]] +@type_check_only +class _MethodRegistry(TypedDict): + ascent: list[L["ascent.dat"]] + electrocardiogram: list[L["ecg.dat"]] + face: list[L["face.dat"]] + +_DataRegistry = TypedDict("_DataRegistry", {"ascent.dat": LiteralString, "ecg.dat": LiteralString, "face.dat": LiteralString}) + +registry: Final[_DataRegistry] = ... +registry_urls: Final[_DataRegistry] = ... +method_files_map: Final[_MethodRegistry] = ... 
diff --git a/scipy-stubs/datasets/_typing.pyi b/scipy-stubs/datasets/_typing.pyi new file mode 100644 index 00000000..11b99931 --- /dev/null +++ b/scipy-stubs/datasets/_typing.pyi @@ -0,0 +1,31 @@ +from collections.abc import Callable +from typing import Literal as L, Protocol, TypeAlias, overload, type_check_only +from typing_extensions import LiteralString + +import numpy as np +import optype.numpy as onp + +__all__ = "AscentDataset", "CanFetch", "Dataset", "ECGDataset", "Face2Dataset", "Face3Dataset", "Fetcher" + +@type_check_only +class CanFetch(Protocol): + def fetch(self, dataset_name: LiteralString, /) -> LiteralString: ... + +AscentDataset: TypeAlias = onp.Array[tuple[L[512], L[512]], np.uint8] +ECGDataset: TypeAlias = onp.Array[tuple[L[108_000]], np.float64] +Face2Dataset: TypeAlias = onp.Array[tuple[L[768], L[1_024]], np.uint8] +Face3Dataset: TypeAlias = onp.Array[tuple[L[768], L[1_024], L[3]], np.uint8] +_FaceDataset: TypeAlias = Face2Dataset | Face3Dataset +Dataset: TypeAlias = AscentDataset | ECGDataset | _FaceDataset + +_AscentFetcher: TypeAlias = Callable[[], AscentDataset] +_ECGFetcher: TypeAlias = Callable[[], ECGDataset] + +@type_check_only +class _FaceFetcher(Protocol): + @overload + def __call__(self, /, gray: L[True, 1]) -> Face2Dataset: ... + @overload + def __call__(self, /, gray: L[False, 0] = False) -> Face3Dataset: ... 
+ +Fetcher: TypeAlias = _AscentFetcher | _ECGFetcher | _FaceFetcher diff --git a/scipy-stubs/datasets/_utils.pyi b/scipy-stubs/datasets/_utils.pyi index eb15ee20..3b7d9cbe 100644 --- a/scipy-stubs/datasets/_utils.pyi +++ b/scipy-stubs/datasets/_utils.pyi @@ -1,12 +1,4 @@ -from collections.abc import Callable -from typing import TypeAlias -from typing_extensions import TypeVar +from ._typing import Fetcher -import numpy as np - -_ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...], default=tuple[int] | tuple[int, int] | tuple[int, int, int]) -_DT = TypeVar("_DT", bound=np.dtype[np.generic], default=np.dtype[np.float64] | np.dtype[np.uint8]) - -_AnyDataset: TypeAlias = Callable[[], np.ndarray[_ShapeT, _DT]] - -def clear_cache(datasets: list[_AnyDataset] | tuple[_AnyDataset, ...] | None = None) -> None: ... +# NOTE: the implementation explicitly checks for `list` and `tuple` types +def clear_cache(datasets: Fetcher | list[Fetcher] | tuple[Fetcher, ...] | None = None) -> None: ...