Skip to content

Commit

Permalink
sparse.linalg: complete _dsolve (#302)
Browse files Browse the repository at this point in the history
  • Loading branch information
jorenham authored Dec 11, 2024
1 parent e8e3700 commit f81e1b5
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 65 deletions.
78 changes: 68 additions & 10 deletions scipy-stubs/sparse/linalg/_dsolve/_superlu.pyi
Original file line number Diff line number Diff line change
@@ -1,21 +1,79 @@
from typing import Any, final
from collections.abc import Callable, Mapping
from typing import Any, Literal, TypeAlias, final, overload

import numpy as np
import optype as op
import optype.numpy as onp
from scipy._typing import Untyped
from scipy.sparse import csc_matrix
from scipy.sparse import csc_matrix, csr_matrix

_Int1D: TypeAlias = onp.Array1D[np.int32]
_Float1D: TypeAlias = onp.Array1D[np.float64]
_Float2D: TypeAlias = onp.Array2D[np.float64]
_Complex1D: TypeAlias = onp.Array1D[np.complex128]
_Complex2D: TypeAlias = onp.Array2D[np.complex128]
_Inexact2D: TypeAlias = onp.Array2D[np.float32 | np.float64 | np.complex64 | np.complex128]

###

@final
class SuperLU:
L: csc_matrix
U: csc_matrix
shape: tuple[int, int]
nnz: int
perm_r: onp.Array1D[np.intp]
perm_c: onp.Array1D[np.intp]
shape: tuple[int, ...]
L: csc_matrix
U: csc_matrix

@overload
def solve(self, /, rhs: onp.Array1D[np.integer[Any] | np.floating[Any]]) -> _Float1D: ...
@overload
def solve(self, /, rhs: onp.Array1D[np.complexfloating[Any, Any]]) -> _Complex1D: ...
@overload
def solve(self, /, rhs: onp.Array2D[np.integer[Any] | np.floating[Any]]) -> _Float2D: ...
@overload
def solve(self, /, rhs: onp.Array2D[np.complexfloating[Any, Any]]) -> _Complex2D: ...
@overload
def solve(self, /, rhs: onp.ArrayND[np.integer[Any] | np.floating[Any]]) -> _Float1D | _Float2D: ...
@overload
def solve(self, /, rhs: onp.ArrayND[np.complexfloating[Any, Any]]) -> _Complex1D | _Complex2D: ...
@overload
def solve(self, /, rhs: onp.ArrayND[np.number[Any]]) -> _Float1D | _Complex1D | _Float2D | _Complex2D: ...

def gssv(
N: op.CanIndex,
nnz: op.CanIndex,
nzvals: _Inexact2D,
colind: _Int1D,
rowptr: _Int1D,
B: _Inexact2D,
csc: onp.ToBool = 0,
options: Mapping[str, object] = ...,
) -> tuple[csc_matrix | csr_matrix, int]: ...

def solve(self, /, rhs: onp.ArrayND[np.number[Any]]) -> onp.ArrayND[np.number[Any]]: ...
#
def gstrf(
N: op.CanIndex,
nnz: op.CanIndex,
nzvals: _Inexact2D,
colind: _Int1D,
rowptr: _Int1D,
csc_construct_func: type[csc_matrix] | Callable[..., csc_matrix],
ilu: onp.ToBool = 0,
options: Mapping[str, object] = ...,
) -> SuperLU: ...

def gssv(*args: Untyped, **kwargs: Untyped) -> Untyped: ...
def gstrf(*args: Untyped, **kwargs: Untyped) -> Untyped: ...
def gstrs(*args: Untyped, **kwargs: Untyped) -> Untyped: ...
#
def gstrs(
trans: Literal["N", "T"],
L_n: op.CanIndex,
L_nnz: op.CanIndex,
L_nzvals: _Inexact2D,
L_rowind: _Int1D,
L_colptr: _Int1D,
U_n: op.CanIndex,
U_nnz: op.CanIndex,
U_nzvals: _Inexact2D,
U_rowind: _Int1D,
U_colptr: _Int1D,
B: _Inexact2D,
) -> tuple[_Float1D | _Complex1D | _Float2D | _Complex2D, int]: ...
104 changes: 79 additions & 25 deletions scipy-stubs/sparse/linalg/_dsolve/linsolve.pyi
Original file line number Diff line number Diff line change
@@ -1,36 +1,90 @@
from scipy._typing import Untyped
from collections.abc import Mapping
from typing import Any, Literal, Protocol, TypeAlias, TypeVar, overload, type_check_only

import numpy as np
import optype.numpy as onp
from scipy.sparse._base import _spbase
from ._superlu import SuperLU

__all__ = ["MatrixRankWarning", "factorized", "spilu", "splu", "spsolve", "spsolve_triangular", "use_solver"]

_SparseT = TypeVar("_SparseT", bound=_spbase)

_PermcSpec: TypeAlias = Literal["COLAMD", "NATURAL", "MMD_ATA", "MMD_AT_PLUS_A"]
_Float1D: TypeAlias = onp.Array1D[np.float64]
_Float2D: TypeAlias = onp.Array2D[np.float64]
_Complex1D: TypeAlias = onp.Array1D[np.complex128]
_Complex2D: TypeAlias = onp.Array2D[np.complex128]

@type_check_only
class _Solve(Protocol):
@overload
def __call__(self, b: onp.Array1D[np.integer[Any] | np.floating[Any]], /) -> _Float1D: ...
@overload
def __call__(self, b: onp.Array1D[np.complexfloating[Any, Any]], /) -> _Complex1D: ...
@overload
def __call__(self, b: onp.Array2D[np.integer[Any] | np.floating[Any]], /) -> _Float2D: ...
@overload
def __call__(self, b: onp.Array2D[np.complexfloating[Any, Any]], /) -> _Complex2D: ...
@overload
def __call__(self, b: onp.ArrayND[np.integer[Any] | np.floating[Any]], /) -> _Float1D | _Float2D: ...
@overload
def __call__(self, b: onp.ArrayND[np.complexfloating[Any, Any]], /) -> _Complex1D | _Complex2D: ...
@overload
def __call__(self, b: onp.ArrayND[np.number[Any]], /) -> _Float1D | _Complex1D | _Float2D | _Complex2D: ...

###

class MatrixRankWarning(UserWarning): ...

def use_solver(*, useUmfpack: bool = ..., assumeSortedIndices: bool = ...) -> None: ...
def spsolve(A: Untyped, b: Untyped, permc_spec: Untyped | None = None, use_umfpack: bool = True) -> Untyped: ...
def splu(
A: Untyped,
permc_spec: Untyped | None = None,
diag_pivot_thresh: Untyped | None = None,
relax: Untyped | None = None,
panel_size: Untyped | None = None,
options: Untyped | None = {},
) -> Untyped: ...
def spilu(
A: Untyped,
drop_tol: Untyped | None = None,
fill_factor: Untyped | None = None,
drop_rule: Untyped | None = None,
permc_spec: Untyped | None = None,
diag_pivot_thresh: Untyped | None = None,
relax: Untyped | None = None,
panel_size: Untyped | None = None,
options: Untyped | None = None,
) -> Untyped: ...
def factorized(A: Untyped) -> Untyped: ...
def factorized(A: _spbase | onp.ToComplex2D) -> _Solve: ...

#
@overload
def spsolve(
A: _spbase | onp.ToComplex2D,
b: _SparseT,
permc_spec: _PermcSpec | None = None,
use_umfpack: bool = True,
) -> _SparseT: ...
@overload
def spsolve(
A: _spbase | onp.ToComplex2D,
b: onp.ToComplex2D | onp.ToComplex1D,
permc_spec: _PermcSpec | None = None,
use_umfpack: bool = True,
) -> _Float1D | _Complex1D | _Float2D | _Complex2D: ...

#
def spsolve_triangular(
A: Untyped,
b: Untyped,
A: _spbase | onp.ToComplex2D,
b: _spbase | onp.ToComplex2D | onp.ToComplex1D,
lower: bool = True,
overwrite_A: bool = False,
overwrite_b: bool = False,
unit_diagonal: bool = False,
) -> Untyped: ...
) -> _Float1D | _Complex1D | _Float2D | _Complex2D: ...

#
def splu(
A: _spbase | onp.ToComplex2D,
permc_spec: _PermcSpec | None = None,
diag_pivot_thresh: onp.ToFloat | None = None,
relax: int | None = None,
panel_size: int | None = None,
options: Mapping[str, object] | None = {},
) -> SuperLU: ...

#
def spilu(
A: _spbase | onp.ToComplex2D,
drop_tol: onp.ToFloat | None = None,
fill_factor: onp.ToFloat | None = None,
drop_rule: str | None = None,
permc_spec: _PermcSpec | None = None,
diag_pivot_thresh: onp.ToFloat | None = None,
relax: int | None = None,
panel_size: int | None = None,
options: Mapping[str, object] | None = None,
) -> SuperLU: ...
30 changes: 0 additions & 30 deletions scipy-stubs/sparse/linalg/_svdp.pyi

This file was deleted.

0 comments on commit f81e1b5

Please sign in to comment.