Skip to content

Commit

Permalink
sparse.linalg: complete LaplacianNd & accept in dtypes in `Line…
Browse files Browse the repository at this point in the history
…arOperator` (#301)
  • Loading branch information
jorenham authored Dec 11, 2024
1 parent 89cf589 commit e8e3700
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 47 deletions.
27 changes: 20 additions & 7 deletions scipy-stubs/sparse/linalg/_interface.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,18 @@ _NumberT = TypeVar("_NumberT", bound=np.number[Any])
_Matrix: TypeAlias = np.matrix[Any, np.dtype[_NumberT]]

_ToShape: TypeAlias = Iterable[op.CanIndex]
_ToDType: TypeAlias = type[_SCT] | onp.HasDType[np.dtype[_SCT]] | np.dtype[_SCT]
_ToDType: TypeAlias = type[_SCT] | np.dtype[_SCT] | onp.HasDType[np.dtype[_SCT]]

_JustFloat: TypeAlias = opt.Just[float]
_JustComplex: TypeAlias = opt.Just[complex]

_FunMatVec: TypeAlias = Callable[[onp.Array1D[np.number[Any]] | onp.Array2D[np.number[Any]]], onp.ToComplex1D | onp.ToComplex2D]
_FunMatMat: TypeAlias = Callable[[onp.Array2D[np.number[Any]]], onp.ToComplex2D]

_SCT = TypeVar("_SCT", bound=np.inexact[Any])
_SCT_co = TypeVar("_SCT_co", bound=np.inexact[Any], default=np.inexact[Any], covariant=True)
_SCT1_co = TypeVar("_SCT1_co", bound=np.inexact[Any], default=np.inexact[Any], covariant=True)
_SCT2_co = TypeVar("_SCT2_co", bound=np.inexact[Any], default=_SCT1_co, covariant=True)
_SCT = TypeVar("_SCT", bound=np.number[Any])
_SCT_co = TypeVar("_SCT_co", bound=np.number[Any], default=np.inexact[Any], covariant=True)
_SCT1_co = TypeVar("_SCT1_co", bound=np.number[Any], default=np.inexact[Any], covariant=True)
_SCT2_co = TypeVar("_SCT2_co", bound=np.number[Any], default=_SCT1_co, covariant=True)
_FunMatVecT_co = TypeVar("_FunMatVecT_co", bound=_FunMatVec, default=_FunMatVec, covariant=True)

###
Expand All @@ -57,7 +58,19 @@ class LinearOperator(Generic[_SCT_co]):
@overload
def __init__(self, /, dtype: _ToDType[_SCT_co], shape: _ToShape) -> None: ...
@overload
def __init__(self: LinearOperator[np.float64], /, dtype: onp.AnyFloat64DType | type[float], shape: _ToShape) -> None: ...
def __init__(
self: LinearOperator[np.intp],
/,
dtype: onp.AnyIntPDType | type[opt.JustInt],
shape: _ToShape,
) -> None: ...
@overload
def __init__(
self: LinearOperator[np.float64],
/,
dtype: onp.AnyFloat64DType | type[_JustFloat],
shape: _ToShape,
) -> None: ...
@overload
def __init__(
self: LinearOperator[np.complex128],
Expand Down Expand Up @@ -97,7 +110,7 @@ class LinearOperator(Generic[_SCT_co]):
@overload
def dot(self, /, x: onp.ToFloat) -> _ScaledLinearOperator[_SCT_co]: ...
@overload
def dot(self, /, x: onp.ToComplex) -> _ScaledLinearOperator: ...
def dot(self, /, x: onp.ToComplex) -> _ScaledLinearOperator[_SCT_co | np.complex128]: ...
@overload
def dot(self, /, x: onp.ToFloatStrict1D) -> onp.Array1D[_SCT_co]: ...
@overload
Expand Down
85 changes: 45 additions & 40 deletions scipy-stubs/sparse/linalg/_special_sparse_arrays.pyi
Original file line number Diff line number Diff line change
@@ -1,44 +1,49 @@
from scipy._typing import Untyped
from typing import Any, Final, Generic, Literal, TypeAlias, overload
from typing_extensions import TypeVar

import numpy as np
import optype.numpy as onp
from scipy.sparse import bsr_array, coo_array, csc_array, csr_array, dia_array, dok_array, lil_array
from scipy.sparse.linalg import LinearOperator

__all__ = ["LaplacianNd"]

class LaplacianNd(LinearOperator):
grid_shape: Untyped
boundary_conditions: Untyped

def __init__(self, /, grid_shape: Untyped, *, boundary_conditions: str = "neumann", dtype: Untyped = ...) -> None: ...
def eigenvalues(self, /, m: Untyped | None = None) -> Untyped: ...
def eigenvectors(self, /, m: Untyped | None = None) -> Untyped: ...
def toarray(self, /) -> Untyped: ...
def tosparse(self, /) -> Untyped: ...

class Sakurai(LinearOperator):
n: Untyped
def __init__(self, /, n: Untyped, dtype: Untyped = ...) -> None: ...
def eigenvalues(self, /, m: Untyped | None = None) -> Untyped: ...
def tobanded(self, /) -> Untyped: ...
def tosparse(self, /) -> Untyped: ...
def toarray(self, /) -> Untyped: ...

class MikotaM(LinearOperator):
def __init__(self, /, shape: Untyped, dtype: Untyped = ...) -> None: ...
def tobanded(self, /) -> Untyped: ...
def tosparse(self, /) -> Untyped: ...
def toarray(self, /) -> Untyped: ...

class MikotaK(LinearOperator):
def __init__(self, /, shape: Untyped, dtype: Untyped = ...) -> None: ...
def tobanded(self, /) -> Untyped: ...
def tosparse(self, /) -> Untyped: ...
def toarray(self, /) -> Untyped: ...

class MikotaPair:
n: Untyped
dtype: Untyped
shape: Untyped
m: Untyped
k: Untyped

def __init__(self, /, n: Untyped, dtype: Untyped = ...) -> None: ...
def eigenvalues(self, /, m: Untyped | None = None) -> Untyped: ...
_SCT = TypeVar("_SCT", bound=np.number[Any])
_SCT_co = TypeVar("_SCT_co", bound=np.number[Any], default=np.int8, covariant=True)

_BoundaryConditions: TypeAlias = Literal["dirichlet", "neumann", "periodic"]
_ToDType: TypeAlias = type[_SCT] | np.dtype[_SCT] | onp.HasDType[np.dtype[_SCT]]

# because `scipy.sparse.sparray` does not implement anything :(
_SpArray: TypeAlias = bsr_array | coo_array | csc_array | csr_array | dia_array | dok_array | lil_array

###

class LaplacianNd(LinearOperator[_SCT_co], Generic[_SCT_co]):
grid_shape: Final[onp.AtLeast1D]
boundary_conditions: Final[_BoundaryConditions]

@overload # default dtype (int8)
def __init__(
self: LaplacianNd[np.int8],
/,
grid_shape: onp.AtLeast1D,
*,
boundary_conditions: _BoundaryConditions = "neumann",
dtype: _ToDType[np.int8] = ..., # default: np.int8
) -> None: ...
@overload # know dtype
def __init__(
self,
/,
grid_shape: onp.AtLeast1D,
*,
boundary_conditions: _BoundaryConditions = "neumann",
dtype: _ToDType[_SCT_co] = ..., # default: np.int8
) -> None: ...

#
def eigenvalues(self, /, m: onp.ToJustInt | None = None) -> onp.Array1D[np.float64]: ...
def eigenvectors(self, /, m: onp.ToJustInt | None = None) -> onp.Array2D[np.float64]: ...
def toarray(self, /) -> onp.Array2D[_SCT_co]: ...
def tosparse(self, /) -> _SpArray: ...

0 comments on commit e8e3700

Please sign in to comment.