Skip to content

Commit

Permalink
cleaner typing in scipy.stats.qmc
Browse files Browse the repository at this point in the history
  • Loading branch information
jorenham committed Nov 8, 2024
1 parent 1eacdc0 commit e27f9a0
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 84 deletions.
109 changes: 63 additions & 46 deletions scipy-stubs/stats/_qmc.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import abc
import numbers
from collections.abc import Callable, Mapping, Sequence
from typing import Any, ClassVar, Final, Literal, Protocol, TypeAlias, overload, type_check_only
from typing_extensions import Self, TypeVar
from typing_extensions import Self, TypeVar, override

import numpy as np
import numpy.typing as npt
Expand All @@ -28,30 +28,31 @@ __all__ = [
]

_RNGT = TypeVar("_RNGT", bound=np.random.Generator | np.random.RandomState)
_SCT = TypeVar("_SCT", bound=np.generic)
_SCT0 = TypeVar("_SCT0", bound=np.generic, default=np.float64)
_SCT_co = TypeVar("_SCT_co", covariant=True, bound=np.generic)
_SCT_fc = TypeVar("_SCT_fc", bound=np.inexact[Any])
_ArrayT_f = TypeVar("_ArrayT_f", bound=npt.NDArray[np.floating[Any]])
_N = TypeVar("_N", bound=int)

# the `__len__` ensures that scalar types like `np.generic` are excluded
@type_check_only
class _CanLenArray(Protocol[_SCT_co]):
def __len__(self, /) -> int: ...
def __array__(self, /) -> npt.NDArray[_SCT_co]: ...

_Scalar_f_co: TypeAlias = np.floating[Any] | np.integer[Any] | np.bool_
_ScalarLike_f: TypeAlias = float | np.floating[Any]

_Array1D: TypeAlias = onpt.Array[tuple[int], _SCT]
_Array1D_f8: TypeAlias = _Array1D[np.float64]
_Array2D: TypeAlias = onpt.Array[tuple[int, int], _SCT]
_Array2D_f8: TypeAlias = _Array2D[np.float64]
_Array1D: TypeAlias = onpt.Array[tuple[int], _SCT0]
_Array2D: TypeAlias = onpt.Array[tuple[int, int], _SCT0]
_Array1D_f_co: TypeAlias = _Array1D[_Scalar_f_co]

_Any1D_f: TypeAlias = _CanLenArray[np.floating[Any]] | Sequence[float | np.floating[Any]]
_Any1D_f: TypeAlias = _CanLenArray[np.floating[Any]] | Sequence[_ScalarLike_f]
_Any1D_f_co: TypeAlias = _CanLenArray[_Scalar_f_co] | Sequence[AnyReal]
_Any2D_f: TypeAlias = _CanLenArray[np.floating[Any]] | Sequence[Sequence[float | np.floating[Any]]] | Sequence[_Any1D_f]
_Any2D_f: TypeAlias = _CanLenArray[np.floating[Any]] | Sequence[Sequence[_ScalarLike_f]] | Sequence[_Any1D_f]
_Any2D_f_co: TypeAlias = _CanLenArray[_Scalar_f_co] | Sequence[Sequence[AnyReal]] | Sequence[_Any1D_f_co]

_MethodOptimize: TypeAlias = Literal["random-cd", "lloyd"]
_MethodQMC: TypeAlias = Literal["random-cd", "lloyd"]
_MethodDisc: TypeAlias = Literal["CD", "WD", "MD", "L2-star"]
_MethodDist: TypeAlias = Literal["mindist", "mst"]
_MetricDist: TypeAlias = _MetricKind | _MetricCallback
Expand All @@ -68,8 +69,8 @@ class QMCEngine(abc.ABC):
num_generated: int

@abc.abstractmethod
def __init__(self, /, d: AnyInt, *, optimization: _MethodOptimize | None = None, seed: Seed | None = None) -> None: ...
def random(self, /, n: opt.AnyInt = 1, *, workers: AnyInt = 1) -> _Array2D_f8: ...
def __init__(self, /, d: AnyInt, *, optimization: _MethodQMC | None = None, seed: Seed | None = None) -> None: ...
def random(self, /, n: opt.AnyInt = 1, *, workers: AnyInt = 1) -> _Array2D: ...
def integers(
self,
/,
Expand All @@ -93,13 +94,13 @@ class Halton(QMCEngine):
d: AnyInt,
*,
scramble: bool = True,
optimization: _MethodOptimize | None = None,
optimization: _MethodQMC | None = None,
seed: Seed | None = None,
) -> None: ...

class LatinHypercube(QMCEngine):
scramble: bool
lhs_method: Callable[[int | np.integer[Any]], _Array2D_f8]
lhs_method: Callable[[int | np.integer[Any]], _Array2D]

def __init__(
self,
Expand All @@ -108,7 +109,7 @@ class LatinHypercube(QMCEngine):
*,
scramble: bool = True,
strength: int = 1,
optimization: _MethodOptimize | None = None,
optimization: _MethodQMC | None = None,
seed: Seed | None = None,
) -> None: ...

Expand All @@ -126,10 +127,10 @@ class Sobol(QMCEngine):
*,
scramble: op.CanBool = True,
bits: AnyInt | None = None,
optimization: _MethodQMC | None = None,
seed: Seed | None = None,
optimization: _MethodOptimize | None = None,
) -> None: ...
def random_base2(self, /, m: AnyInt) -> _Array2D_f8: ...
def random_base2(self, /, m: AnyInt) -> _Array2D: ...

@type_check_only
class _HypersphereMethod(Protocol):
Expand All @@ -139,7 +140,7 @@ class _HypersphereMethod(Protocol):
center: npt.NDArray[_Scalar_f_co],
radius: AnyReal,
candidates: AnyInt = 1,
) -> _Array2D_f8: ...
) -> _Array2D: ...

class PoissonDisk(QMCEngine):
hypersphere_method: Final[_HypersphereMethod]
Expand All @@ -150,7 +151,7 @@ class PoissonDisk(QMCEngine):
cell_size: Final[np.float64]
grid_size: Final[_Array1D[np.int_]]

sample_pool: list[_Array1D_f8]
sample_pool: list[_Array1D]
sample_grid: npt.NDArray[np.float32]

def __init__(
Expand All @@ -161,13 +162,19 @@ class PoissonDisk(QMCEngine):
radius: AnyReal = 0.05,
hypersphere: Literal["volume", "surface"] = "volume",
ncandidates: AnyInt = 30,
optimization: _MethodOptimize | None = None,
optimization: _MethodQMC | None = None,
seed: Seed | None = None,
) -> None: ...
def fill_space(self, /) -> _Array2D_f8: ...
def fill_space(self, /) -> _Array2D: ...

class MultivariateNormalQMC:
engine: Final[QMCEngine]
@type_check_only
class _QMCDistribution:
engine: Final[QMCEngine] # defaults to `Sobol`
def __init__(self, /, *, engine: QMCEngine | None = None, seed: Seed | None = None) -> None: ...
def random(self, /, n: AnyInt = 1) -> _Array2D: ...

class MultivariateNormalQMC(_QMCDistribution):
@override
def __init__(
self,
/,
Expand All @@ -179,13 +186,12 @@ class MultivariateNormalQMC:
engine: QMCEngine | None = None,
seed: Seed | None = None,
) -> None: ...
def random(self, /, n: AnyInt = 1) -> _Array2D_f8: ...

class MultinomialQMC:
class MultinomialQMC(_QMCDistribution):
pvals: Final[_Array1D[np.floating[Any]]]
n_trials: Final[AnyInt]
engine: Final[QMCEngine]

@override
def __init__(
self,
/,
Expand All @@ -195,55 +201,64 @@ class MultinomialQMC:
engine: QMCEngine | None = None,
seed: Seed | None = None,
) -> None: ...
def random(self, /, n: AnyInt = 1) -> _Array2D_f8: ...

#
@overload
def check_random_state(seed: int | np.integer[Any] | numbers.Integral | None = None) -> np.random.Generator: ...
@overload
def check_random_state(seed: _RNGT) -> _RNGT: ...

#
def scale(
sample: _Any2D_f,
l_bounds: _Any1D_f_co | AnyReal,
u_bounds: _Any1D_f_co | AnyReal,
*,
reverse: op.CanBool = False,
) -> _Array2D_f8: ...
) -> _Array2D: ...

#
def discrepancy(
sample: _Any2D_f,
*,
iterative: op.CanBool = False,
method: _MethodDisc = "CD",
workers: op.CanInt = 1,
) -> float: ...
def geometric_discrepancy(sample: _Any2D_f, method: _MethodDist = "mindist", metric: _MetricDist = "euclidean") -> np.float64: ...
def update_discrepancy(x_new: _Any1D_f, sample: _Any2D_f, initial_disc: op.CanFloat) -> float: ...
workers: opt.AnyInt = 1,
) -> float | np.float64: ...

#
def geometric_discrepancy(
sample: _Any2D_f,
method: _MethodDist = "mindist",
metric: _MetricDist = "euclidean",
) -> float | np.float64: ...
def update_discrepancy(x_new: _Any1D_f, sample: _Any2D_f, initial_disc: opt.AnyFloat) -> float: ...
def primes_from_2_to(n: AnyInt) -> _Array1D[np.int_]: ...
def n_primes(n: AnyInt) -> list[int] | _Array1D[np.int_]: ...

#
def _select_optimizer(optimization: _MethodOptimize | None, config: Mapping[str, object]) -> _FuncOptimize | None: ...
def _select_optimizer(optimization: _MethodQMC | None, config: Mapping[str, object]) -> _FuncOptimize | None: ...
def _random_cd(best_sample: _ArrayT_f, n_iters: AnyInt, n_nochange: AnyInt, rng: RNG) -> _ArrayT_f: ...
def _l1_norm(sample: _Any2D_f) -> np.float64: ...
def _l1_norm(sample: _Any2D_f) -> float | np.float64: ...
def _lloyd_iteration(sample: _ArrayT_f, decay: AnyReal, qhull_options: str | None) -> _ArrayT_f: ...
def _lloyd_centroidal_voronoi_tessellation(
sample: _Any2D_f,
*,
tol: AnyReal = 1e-5,
maxiter: AnyInt = 10,
qhull_options: str | None = None,
) -> _Array2D_f8: ...
) -> _Array2D: ...
def _ensure_in_unit_hypercube(sample: _Any2D_f) -> _Array2D: ...

#
def _ensure_in_unit_hypercube(sample: _Any2D_f) -> _Array2D_f8: ...
@overload
def _perturb_discrepancy(
sample: _Array2D[np.integer[Any] | np.bool_],
i1: op.CanIndex,
i2: op.CanIndex,
k: op.CanIndex,
disc: AnyReal,
) -> np.float64: ...
) -> float | np.float64: ...
@overload
def _perturb_discrepancy(
sample: _Array2D[_SCT_fc],
Expand All @@ -252,10 +267,14 @@ def _perturb_discrepancy(
k: op.CanIndex,
disc: AnyReal,
) -> _SCT_fc: ...

#
@overload
def _van_der_corput_permutation(base: op.CanIndex, *, random_state: Seed | None = None) -> _Array2D[np.int_]: ...
@overload
def _van_der_corput_permutation(base: op.CanFloat, *, random_state: Seed | None = None) -> _Array2D_f8: ...
def _van_der_corput_permutation(base: op.CanFloat, *, random_state: Seed | None = None) -> _Array2D: ...

#
def van_der_corput(
n: op.CanInt,
base: AnyInt = 2,
Expand All @@ -264,18 +283,16 @@ def van_der_corput(
scramble: op.CanBool = False,
permutations: _ArrayLikeInt | None = None,
seed: Seed | None = None,
workers: op.CanInt = 1,
) -> _Array1D_f8: ...
workers: opt.AnyInt = 1,
) -> _Array1D: ...

#
@overload
def _validate_workers(workers: op.CanInt[Literal[1]] | op.CanIndex[Literal[1]] | Literal[1] = 1) -> Literal[1]: ...
def _validate_workers(workers: Literal[1] = 1) -> Literal[1]: ...
@overload
def _validate_workers(workers: _N) -> _N: ...
@overload
def _validate_workers(workers: op.CanInt[_N] | op.CanIndex[_N]) -> _N: ...
def _validate_bounds(
l_bounds: _Any1D_f_co,
u_bounds: _Any1D_f_co,
d: AnyInt,
) -> tuple[_Array1D[_Scalar_f_co], _Array1D[_Scalar_f_co]]: ...
def _validate_workers(workers: opt.AnyInt[_N]) -> _N: ...

#
def _validate_bounds(l_bounds: _Any1D_f_co, u_bounds: _Any1D_f_co, d: AnyInt) -> tuple[_Array1D_f_co, _Array1D_f_co]: ...
58 changes: 20 additions & 38 deletions scipy-stubs/stats/_qmc_cy.pyi
Original file line number Diff line number Diff line change
@@ -1,42 +1,24 @@
from typing import TypeAlias

import numpy as np
import optype as op
import optype.numpy as onpt
from optype import CanBool, CanFloat, CanInt
import optype.typing as opt

_Vector_i8: TypeAlias = onpt.Array[tuple[int, int], np.int64]
_Vector_f8: TypeAlias = onpt.Array[tuple[int], np.float64]
_Matrix_f8: TypeAlias = onpt.Array[tuple[int, int], np.float64]

def _cy_wrapper_centered_discrepancy(
sample: onpt.Array[tuple[int, int], np.float64],
iterative: CanBool,
workers: CanInt,
) -> float: ...
def _cy_wrapper_wrap_around_discrepancy(
sample: onpt.Array[tuple[int, int], np.float64],
iterative: CanBool,
workers: CanInt,
) -> float: ...
def _cy_wrapper_mixture_discrepancy(
sample: onpt.Array[tuple[int, int], np.float64],
iterative: CanBool,
workers: CanInt,
) -> float: ...
def _cy_wrapper_l2_star_discrepancy(
sample: onpt.Array[tuple[int, int], np.float64],
iterative: CanBool,
workers: CanInt,
) -> float: ...
def _cy_wrapper_update_discrepancy(
x_new_view: onpt.Array[tuple[int], np.float64],
sample_view: onpt.Array[tuple[int, int], np.float64],
initial_disc: CanFloat,
) -> float: ...
def _cy_van_der_corput(
n: CanInt,
base: CanInt,
start_index: CanInt,
workers: CanInt,
) -> onpt.Array[tuple[int], np.float64]: ...
def _cy_wrapper_centered_discrepancy(sample: _Matrix_f8, iterative: op.CanBool, workers: opt.AnyInt) -> float: ...
def _cy_wrapper_wrap_around_discrepancy(sample: _Matrix_f8, iterative: op.CanBool, workers: opt.AnyInt) -> float: ...
def _cy_wrapper_mixture_discrepancy(sample: _Matrix_f8, iterative: op.CanBool, workers: opt.AnyInt) -> float: ...
def _cy_wrapper_l2_star_discrepancy(sample: _Matrix_f8, iterative: op.CanBool, workers: opt.AnyInt) -> float: ...
def _cy_wrapper_update_discrepancy(x_new_view: _Vector_f8, sample_view: _Matrix_f8, initial_disc: opt.AnyFloat) -> float: ...
def _cy_van_der_corput(n: opt.AnyInt, base: opt.AnyInt, start_index: opt.AnyInt, workers: opt.AnyInt) -> _Vector_f8: ...
def _cy_van_der_corput_scrambled(
n: CanInt,
base: CanInt,
start_index: CanInt,
permutations: onpt.Array[tuple[int, int], np.int64],
workers: CanInt,
) -> onpt.Array[tuple[int], np.float64]: ...
n: opt.AnyInt,
base: opt.AnyInt,
start_index: opt.AnyInt,
permutations: _Vector_i8,
workers: opt.AnyInt,
) -> _Vector_f8: ...

0 comments on commit e27f9a0

Please sign in to comment.