Skip to content

Commit

Permalink
🎨 sparse: various linalg.LinearOperator tweaks and improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
jorenham committed Dec 15, 2024
1 parent 8bbae07 commit 40fe461
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 71 deletions.
3 changes: 2 additions & 1 deletion scipy-stubs/sparse/linalg/_expm_multiply.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ from typing import Any, TypeVar, overload
import numpy as np
import optype as op
import optype.numpy as onp
from scipy.sparse._typing import Scalar
from scipy.sparse.linalg._interface import LinearOperator

__all__ = ["expm_multiply"]

_SCT = TypeVar("_SCT", bound=np.inexact[Any])
_SCT = TypeVar("_SCT", bound=Scalar)

###

Expand Down
110 changes: 43 additions & 67 deletions scipy-stubs/sparse/linalg/_interface.pyi
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# mypy: disable-error-code="override"
# pyright: reportInconsistentConstructor=false
# pyright: reportIncompatibleMethodOverride=false
# pyright: reportIncompatibleVariableOverride=false
# pyright: reportUnannotatedClassAttribute=false
# pyright: reportOverlappingOverload=false

from collections.abc import Callable, Iterable
from typing import Any, ClassVar, Final, Generic, Literal, Protocol, TypeAlias, final, overload, type_check_only
Expand All @@ -11,28 +9,30 @@ from typing_extensions import Self, TypeVar, override
import numpy as np
import optype as op
import optype.numpy as onp
import optype.typing as opt
from scipy.sparse._base import _spbase
from scipy.sparse._typing import Complex, Float, Int, Matrix, Scalar, ToDTypeComplex, ToDTypeFloat, ToDTypeInt

__all__ = ["LinearOperator", "aslinearoperator"]

_NumberT = TypeVar("_NumberT", bound=np.number[Any])
_Matrix: TypeAlias = np.matrix[Any, np.dtype[_NumberT]]
_Real: TypeAlias = np.bool_ | Int | Float
_Inexact: TypeAlias = Float | Complex
_Number: TypeAlias = Int | _Inexact

_SCT = TypeVar("_SCT", bound=Scalar)
_SCT_co = TypeVar("_SCT_co", bound=Scalar, default=_Inexact, covariant=True)
_SCT1_co = TypeVar("_SCT1_co", bound=Scalar, default=_Inexact, covariant=True)
_SCT2_co = TypeVar("_SCT2_co", bound=Scalar, default=_SCT1_co, covariant=True)
_FunMatVecT_co = TypeVar("_FunMatVecT_co", bound=_FunMatVec, default=_FunMatVec, covariant=True)

_InexactT = TypeVar("_InexactT", bound=_Inexact)

_ToShape: TypeAlias = Iterable[op.CanIndex]
_ToDType: TypeAlias = type[_SCT] | np.dtype[_SCT] | onp.HasDType[np.dtype[_SCT]]

_JustFloat: TypeAlias = opt.Just[float]
_JustComplex: TypeAlias = opt.Just[complex]
_FunMatVec: TypeAlias = Callable[[onp.Array1D[_Number] | onp.Array2D[_Number]], onp.ToComplex1D | onp.ToComplex2D]
_FunMatMat: TypeAlias = Callable[[onp.Array2D[_Number]], onp.ToComplex2D]

_FunMatVec: TypeAlias = Callable[[onp.Array1D[np.number[Any]] | onp.Array2D[np.number[Any]]], onp.ToComplex1D | onp.ToComplex2D]
_FunMatMat: TypeAlias = Callable[[onp.Array2D[np.number[Any]]], onp.ToComplex2D]

_SCT = TypeVar("_SCT", bound=np.number[Any])
_SCT_co = TypeVar("_SCT_co", bound=np.number[Any], default=np.inexact[Any], covariant=True)
_SCT1_co = TypeVar("_SCT1_co", bound=np.number[Any], default=np.inexact[Any], covariant=True)
_SCT2_co = TypeVar("_SCT2_co", bound=np.number[Any], default=_SCT1_co, covariant=True)
_FunMatVecT_co = TypeVar("_FunMatVecT_co", bound=_FunMatVec, default=_FunMatVec, covariant=True)
_Array1D2D: TypeAlias = onp.Array1D[_SCT] | onp.Array2D[_SCT]

###

Expand All @@ -58,46 +58,31 @@ class LinearOperator(Generic[_SCT_co]):
@overload
def __init__(self, /, dtype: _ToDType[_SCT_co], shape: _ToShape) -> None: ...
@overload
def __init__(
self: LinearOperator[np.intp],
/,
dtype: onp.AnyIntPDType | type[opt.JustInt],
shape: _ToShape,
) -> None: ...
def __init__(self: LinearOperator[np.int_], /, dtype: ToDTypeInt, shape: _ToShape) -> None: ...
@overload
def __init__(
self: LinearOperator[np.float64],
/,
dtype: onp.AnyFloat64DType | type[_JustFloat],
shape: _ToShape,
) -> None: ...
def __init__(self: LinearOperator[np.float64], /, dtype: ToDTypeFloat, shape: _ToShape) -> None: ...
@overload
def __init__(
self: LinearOperator[np.complex128],
/,
dtype: onp.AnyComplex128DType | type[_JustComplex],
shape: _ToShape,
) -> None: ...
def __init__(self: LinearOperator[np.complex128], /, dtype: ToDTypeComplex, shape: _ToShape) -> None: ...
@overload
def __init__(self, /, dtype: onp.AnyInexactDType | None, shape: _ToShape) -> None: ...

#
@overload # float array 1d
def matvec(self, /, x: onp.ToFloatStrict1D) -> onp.Array1D[_SCT_co]: ...
@overload # float matrix
def matvec(self, /, x: _Matrix[np.floating[Any] | np.integer[Any]]) -> _Matrix[_SCT_co]: ...
def matvec(self, /, x: Matrix[_Real]) -> Matrix[_SCT_co]: ...
@overload # float array 2d
def matvec(self, /, x: onp.ToFloatStrict2D) -> onp.Array2D[_SCT_co]: ...
@overload # complex array 1d
def matvec(self, /, x: onp.ToComplexStrict1D) -> onp.Array1D[_SCT_co | np.complex128]: ...
@overload # complex matrix
def matvec(self, /, x: _Matrix[np.number[Any]]) -> _Matrix[_SCT_co | np.complex128]: ...
def matvec(self, /, x: Matrix[_Number]) -> Matrix[_SCT_co | np.complex128]: ...
@overload # complex array 2d
def matvec(self, /, x: onp.ToComplexStrict2D) -> onp.Array2D[_SCT_co | np.complex128]: ...
@overload # float array
def matvec(self, /, x: onp.ToFloat2D) -> onp.Array1D[_SCT_co] | onp.Array2D[_SCT_co]: ...
@overload # complex array
def matvec(self, /, x: onp.ToComplex2D) -> onp.Array1D[_SCT_co | np.complex128] | onp.Array2D[_SCT_co | np.complex128]: ...
def matvec(self, /, x: onp.ToComplex2D) -> _Array1D2D[_SCT_co | np.complex128]: ...
rmatvec = matvec

#
Expand All @@ -122,7 +107,7 @@ class LinearOperator(Generic[_SCT_co]):
@overload
def dot(self, /, x: onp.ToFloatND) -> onp.Array1D[_SCT_co] | onp.Array2D[_SCT_co]: ...
@overload
def dot(self, /, x: onp.ToComplexND) -> onp.Array1D[_SCT_co | np.complex128] | onp.Array2D[_SCT_co | np.complex128]: ...
def dot(self, /, x: onp.ToComplexND) -> _Array1D2D[_SCT_co | np.complex128]: ...
__mul__ = dot
__rmul__ = dot
__call__ = dot
Expand All @@ -141,11 +126,7 @@ class LinearOperator(Generic[_SCT_co]):
@overload
def __matmul__(self, /, x: onp.ToFloatND) -> onp.Array1D[_SCT_co] | onp.Array2D[_SCT_co]: ...
@overload
def __matmul__(
self,
/,
x: onp.ToComplexND,
) -> onp.Array1D[_SCT_co | np.complex128] | onp.Array2D[_SCT_co | np.complex128]: ...
def __matmul__(self, /, x: onp.ToComplexND) -> _Array1D2D[_SCT_co | np.complex128]: ...
__rmatmul__ = __matmul__

#
Expand Down Expand Up @@ -206,7 +187,7 @@ class _CustomLinearOperator(LinearOperator[_SCT_co], Generic[_SCT_co, _FunMatVec
matvec: _FunMatVec,
rmatvec: _FunMatVec | None,
matmat: _FunMatMat | None,
dtype: onp.AnyFloat64DType | type[float],
dtype: ToDTypeFloat,
rmatmat: _FunMatMat | None = None,
) -> None: ...
@overload # dtype-like float64 (keyword)
Expand All @@ -218,7 +199,7 @@ class _CustomLinearOperator(LinearOperator[_SCT_co], Generic[_SCT_co, _FunMatVec
rmatvec: _FunMatVec | None = None,
matmat: _FunMatMat | None = None,
*,
dtype: onp.AnyFloat64DType | type[float],
dtype: ToDTypeFloat,
rmatmat: _FunMatMat | None = None,
) -> None: ...
@overload # dtype-like complex128 (positional)
Expand All @@ -229,7 +210,7 @@ class _CustomLinearOperator(LinearOperator[_SCT_co], Generic[_SCT_co, _FunMatVec
matvec: _FunMatVec,
rmatvec: _FunMatVec | None,
matmat: _FunMatMat | None,
dtype: onp.AnyComplex128DType | type[opt.Just[complex]],
dtype: ToDTypeComplex,
rmatmat: _FunMatMat | None = None,
) -> None: ...
@overload # dtype-like complex128 (keyword)
Expand All @@ -241,14 +222,15 @@ class _CustomLinearOperator(LinearOperator[_SCT_co], Generic[_SCT_co, _FunMatVec
rmatvec: _FunMatVec | None = None,
matmat: _FunMatMat | None = None,
*,
dtype: onp.AnyComplex128DType | type[opt.Just[complex]],
dtype: ToDTypeComplex,
rmatmat: _FunMatMat | None = None,
) -> None: ...

@type_check_only
class _UnaryLinearOperator(LinearOperator[_SCT_co], Generic[_SCT_co]):
A: LinearOperator[_SCT_co]
args: tuple[LinearOperator[_SCT_co]]

def __init__(self, /, A: LinearOperator[_SCT_co]) -> None: ...

@final
Expand All @@ -260,11 +242,13 @@ class _TransposedLinearOperator(_UnaryLinearOperator[_SCT_co], Generic[_SCT_co])
@final
class _SumLinearOperator(LinearOperator[_SCT1_co | _SCT2_co], Generic[_SCT1_co, _SCT2_co]):
args: tuple[LinearOperator[_SCT1_co], LinearOperator[_SCT2_co]]

def __init__(self, /, A: LinearOperator[_SCT1_co], B: LinearOperator[_SCT2_co]) -> None: ...

@final
class _ProductLinearOperator(LinearOperator[_SCT1_co | _SCT2_co], Generic[_SCT1_co, _SCT2_co]):
args: tuple[LinearOperator[_SCT1_co], LinearOperator[_SCT2_co]]

def __init__(self, /, A: LinearOperator[_SCT1_co], B: LinearOperator[_SCT2_co]) -> None: ...

@final
Expand All @@ -273,45 +257,37 @@ class _ScaledLinearOperator(LinearOperator[_SCT_co], Generic[_SCT_co]):
@overload
def __init__(self, /, A: LinearOperator[_SCT_co], alpha: _SCT_co | complex) -> None: ...
@overload
def __init__(self: _ScaledLinearOperator[np.float64], /, A: LinearOperator[np.floating[Any]], alpha: float) -> None: ...
def __init__(self: _ScaledLinearOperator[np.float64], /, A: LinearOperator[Float], alpha: float) -> None: ...
@overload
def __init__(self: _ScaledLinearOperator[np.complex128], /, A: LinearOperator, alpha: complex) -> None: ...

@final
class _PowerLinearOperator(LinearOperator[_SCT_co], Generic[_SCT_co]):
args: tuple[LinearOperator[_SCT_co], op.CanIndex]

def __init__(self, /, A: LinearOperator[_SCT_co], p: op.CanIndex) -> None: ...

class MatrixLinearOperator(LinearOperator[_SCT_co], Generic[_SCT_co]):
A: _spbase | onp.Array2D[_SCT_co]
args: tuple[_spbase | onp.Array2D[_SCT_co]]

def __init__(self, /, A: _spbase | onp.ArrayND[_SCT_co]) -> None: ...

@final
class _AdjointMatrixOperator(MatrixLinearOperator[_SCT_co], Generic[_SCT_co]):
args: tuple[MatrixLinearOperator[_SCT_co]] # type: ignore[assignment]
args: tuple[MatrixLinearOperator[_SCT_co]] # type: ignore[assignment] # pyright: ignore[reportIncompatibleVariableOverride]
@property
@override
def dtype(self, /) -> np.dtype[_SCT_co]: ...
def dtype(self, /) -> np.dtype[_SCT_co]: ... # pyright: ignore[reportIncompatibleVariableOverride]
def __init__(self, /, adjoint: LinearOperator) -> None: ...

class IdentityOperator(LinearOperator[_SCT_co], Generic[_SCT_co]):
@overload
def __init__(self, /, shape: _ToShape, dtype: _ToDType[_SCT_co]) -> None: ...
@overload
def __init__(
self: IdentityOperator[np.float64],
/,
shape: _ToShape,
dtype: onp.AnyFloat64DType | type[float] | None = None,
) -> None: ...
def __init__(self: IdentityOperator[np.float64], /, shape: _ToShape, dtype: ToDTypeFloat | None = None) -> None: ...
@overload
def __init__(
self: IdentityOperator[np.complex128],
/,
shape: _ToShape,
dtype: onp.AnyComplex128DType | type[_JustComplex],
) -> None: ...
def __init__(self: IdentityOperator[np.complex128], /, shape: _ToShape, dtype: ToDTypeComplex) -> None: ...
@overload
def __init__(self, /, shape: _ToShape, dtype: onp.AnyInexactDType) -> None: ...

Expand All @@ -338,12 +314,12 @@ class _HasShapeAndDTypeAndMatVec(Protocol[_SCT_co]):
def matvec(self, /, x: onp.CanArray2D[np.float64] | onp.CanArray2D[np.complex128]) -> onp.ToComplex2D: ...

@overload
def aslinearoperator(A: onp.CanArrayND[_SCT_co]) -> MatrixLinearOperator[_SCT_co]: ...
def aslinearoperator(A: onp.CanArrayND[_InexactT]) -> MatrixLinearOperator[_InexactT]: ...
@overload
def aslinearoperator(A: _spbase) -> MatrixLinearOperator: ...
def aslinearoperator(A: _spbase[_InexactT]) -> MatrixLinearOperator[_InexactT]: ...
@overload
def aslinearoperator(A: onp.ArrayND[np.integer[Any] | np.bool_]) -> MatrixLinearOperator[np.float64]: ...
def aslinearoperator(A: onp.ArrayND[np.bool_ | Int] | _spbase[np.bool_ | Int]) -> MatrixLinearOperator[np.float64]: ...
@overload
def aslinearoperator(A: _HasShapeAndDTypeAndMatVec[_SCT_co]) -> MatrixLinearOperator[_SCT_co]: ...
def aslinearoperator(A: _HasShapeAndDTypeAndMatVec[_InexactT]) -> MatrixLinearOperator[_InexactT]: ...
@overload
def aslinearoperator(A: _HasShapeAndMatVec[_SCT_co]) -> MatrixLinearOperator[_SCT_co]: ...
def aslinearoperator(A: _HasShapeAndMatVec[_InexactT]) -> MatrixLinearOperator[_InexactT]: ...
6 changes: 3 additions & 3 deletions scipy-stubs/sparse/linalg/_matfuncs.pyi
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
# mypy: disable-error-code="override"
# pyright: reportIncompatibleMethodOverride=false

from typing import Any, Final, Generic, Literal, TypeAlias
from typing import Final, Generic, Literal, TypeAlias
from typing_extensions import Self, TypeVar, override

import numpy as np
import optype as op
import optype.numpy as onp
from scipy.sparse._base import _spbase
from scipy.sparse._typing import Complex, Float, Scalar
from ._interface import LinearOperator

__all__ = ["expm", "inv", "matrix_power"]

_SCT_co = TypeVar("_SCT_co", covariant=True, bound=np.inexact[Any], default=np.inexact[Any])
_SCT_co = TypeVar("_SCT_co", covariant=True, bound=Scalar, default=Float | Complex)
_SparseT = TypeVar("_SparseT", bound=_spbase)

_Structure: TypeAlias = Literal["upper_triangular"]
Expand Down

0 comments on commit 40fe461

Please sign in to comment.