Skip to content

Commit

Permalink
sparse: complete kron[sum] (#311)
Browse files Browse the repository at this point in the history
  • Loading branch information
jorenham authored Dec 15, 2024
1 parent e3ab140 commit 0b63592
Showing 1 changed file with 82 additions and 19 deletions.
101 changes: 82 additions & 19 deletions scipy-stubs/sparse/_construct.pyi
Original file line number Diff line number Diff line change
@@ -1,25 +1,20 @@
from collections.abc import Sequence
from typing import Any, Literal, TypeAlias, TypeVar, overload

import numpy as np
import numpy.typing as npt
import optype.numpy as onp
import optype.typing as opt
from scipy._typing import Seed, Untyped, UntypedCallable
from scipy.sparse import (
bsr_array,
bsr_matrix,
coo_array,
coo_matrix,
csc_array,
csc_matrix,
csr_array,
csr_matrix,
dia_array,
dia_matrix,
dok_array,
dok_matrix,
lil_array,
lil_matrix,
)
from ._base import _spbase
from ._bsr import bsr_array, bsr_matrix
from ._coo import coo_array, coo_matrix
from ._csc import csc_array, csc_matrix
from ._csr import csr_array, csr_matrix
from ._dia import dia_array, dia_matrix
from ._dok import dok_array, dok_matrix
from ._lil import lil_array, lil_matrix
from ._matrix import spmatrix
from ._typing import Float, Scalar, SPFormat, ToDType, ToDTypeBool, ToDTypeComplex, ToDTypeFloat, ToDTypeInt, ToShape, ToShape2D

__all__ = [
Expand All @@ -42,8 +37,12 @@ __all__ = [
]

_SCT = TypeVar("_SCT", bound=Scalar, default=Any)
_SCT1 = TypeVar("_SCT1", bound=Scalar)
_SCT2 = TypeVar("_SCT2", bound=Scalar)
_ShapeT = TypeVar("_ShapeT", bound=tuple[int] | tuple[int, int], default=tuple[int] | tuple[int, int])

_ToMatrix: TypeAlias = spmatrix[_SCT] | Sequence[Sequence[_SCT] | onp.CanArrayND[_SCT]] | onp.CanArrayND[_SCT]

_SpMatrix: TypeAlias = (
bsr_matrix[_SCT]
| coo_matrix[_SCT]
Expand All @@ -53,7 +52,6 @@ _SpMatrix: TypeAlias = (
| dok_matrix[_SCT]
| lil_matrix[_SCT]
)

_SpArray: TypeAlias = (
bsr_array[_SCT]
| coo_array[_SCT, _ShapeT]
Expand All @@ -66,7 +64,36 @@ _SpArray: TypeAlias = (
_SpArray1D: TypeAlias = coo_array[_SCT, tuple[int]] | csr_array[_SCT, tuple[int]] | dok_array[_SCT, tuple[int]]
_SpArray2D: TypeAlias = _SpArray[_SCT, tuple[int, int]]

_BSRArray: TypeAlias = bsr_array[_SCT]
_CSRArray: TypeAlias = csr_array[_SCT, tuple[int, int]]
_NonBSRArray: TypeAlias = (
coo_array[_SCT, tuple[int, int]]
| csc_array[_SCT]
| csr_array[_SCT, tuple[int, int]]
| dia_array[_SCT]
| dok_array[_SCT, tuple[int, int]]
| lil_array[_SCT]
)
_NonCSRArray: TypeAlias = (
bsr_array[_SCT]
| coo_array[_SCT, tuple[int, int]]
| csc_array[_SCT]
| dia_array[_SCT]
| dok_array[_SCT, tuple[int, int]]
| lil_array[_SCT]
)
_NonBSRMatrix: TypeAlias = (
coo_matrix[_SCT] | csr_matrix[_SCT] | csc_matrix[_SCT] | dia_matrix[_SCT] | dok_matrix[_SCT] | lil_matrix[_SCT]
)
_NonCSRMatrix: TypeAlias = (
bsr_matrix[_SCT] | coo_matrix[_SCT] | csc_matrix[_SCT] | dia_matrix[_SCT] | dok_matrix[_SCT] | lil_matrix[_SCT]
)

_FmtBSR: TypeAlias = Literal["bsr"]
_FmtCSR: TypeAlias = Literal["csr"]
_FmtDIA: TypeAlias = Literal["dia"]
_FmtNonBSR: TypeAlias = Literal["coo", "csc", "csr", "dia", "dok", "lil"]
_FmtNonCSR: TypeAlias = Literal["bsr", "coo", "csc", "dia", "dok", "lil"]
_FmtNonDIA: TypeAlias = Literal["bsr", "coo", "csc", "csr", "dok", "lil"]

###
Expand Down Expand Up @@ -334,8 +361,44 @@ def eye(
) -> _SpMatrix: ...

#
def kron(A: Untyped, B: Untyped, format: SPFormat | None = None) -> _SpArray2D | _SpMatrix: ...
def kronsum(A: Untyped, B: Untyped, format: SPFormat | None = None) -> _SpArray2D | _SpMatrix: ...
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: {"bsr", None} = ...
def kron(A: _ToMatrix[_SCT1], B: _ToMatrix[_SCT2], format: _FmtBSR | None = None) -> bsr_matrix[_SCT1 | _SCT2]: ...
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: <otherwise>
def kron(A: _ToMatrix[_SCT1], B: _ToMatrix[_SCT2], format: _FmtNonBSR) -> _NonBSRMatrix[_SCT1 | _SCT2]: ...
@overload # A: sparray, B: sparse, format: {"bsr", None} = ...
def kron(A: _SpArray[_SCT1], B: _spbase[_SCT2], format: _FmtBSR | None = None) -> _BSRArray[_SCT1 | _SCT2]: ...
@overload # A: sparray, B: sparse, format: <otherwise>
def kron(A: _SpArray[_SCT1], B: _spbase[_SCT2], format: _FmtNonBSR) -> _NonBSRArray[_SCT1 | _SCT2]: ...
@overload # A: sparse, B: sparray, format: {"bsr", None} = ...
def kron(A: _spbase[_SCT1], B: _SpArray[_SCT2], format: _FmtBSR | None = None) -> _BSRArray[_SCT1 | _SCT2]: ...
@overload # A: sparse, B: sparray, format: <otherwise>
def kron(A: _spbase[_SCT1], B: _SpArray[_SCT2], format: _FmtNonBSR) -> _NonBSRArray[_SCT1 | _SCT2]: ...
@overload # A: unknown array-like, B: unknown array-like (catch-all)
def kron(
A: onp.ToComplex2D | _spbase[_SCT],
B: onp.ToComplex2D | _spbase[_SCT],
format: SPFormat | None = None,
) -> _SpArray2D[_SCT] | _SpMatrix[_SCT]: ...

#
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: {"csr", None} = ...
def kronsum(A: _ToMatrix[_SCT1], B: _ToMatrix[_SCT2], format: _FmtCSR | None = None) -> csr_matrix[_SCT1 | _SCT2]: ...
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: <otherwise>
def kronsum(A: _ToMatrix[_SCT1], B: _ToMatrix[_SCT2], format: _FmtNonCSR) -> _NonCSRMatrix[_SCT1 | _SCT2]: ...
@overload # A: sparray, B: sparse, format: {"csr", None} = ...
def kronsum(A: _SpArray[_SCT1], B: _spbase[_SCT2], format: _FmtCSR | None = None) -> _CSRArray[_SCT1 | _SCT2]: ...
@overload # A: sparray, B: sparse, format: <otherwise>
def kronsum(A: _SpArray[_SCT1], B: _spbase[_SCT2], format: _FmtNonCSR) -> _NonCSRArray[_SCT1 | _SCT2]: ...
@overload # A: sparse, B: sparray, format: {"csr", None} = ...
def kronsum(A: _spbase[_SCT1], B: _SpArray[_SCT2], format: _FmtCSR | None = None) -> _CSRArray[_SCT1 | _SCT2]: ...
@overload # A: sparse, B: sparray, format: <otherwise>
def kronsum(A: _spbase[_SCT1], B: _SpArray[_SCT2], format: _FmtNonCSR) -> _NonCSRArray[_SCT1 | _SCT2]: ...
@overload # A: unknown array-like, B: unknown array-like (catch-all)
def kronsum(
A: onp.ToComplex2D | _spbase[_SCT],
B: onp.ToComplex2D | _spbase[_SCT],
format: SPFormat | None = None,
) -> _SpArray2D[_SCT] | _SpMatrix[_SCT]: ...

#
def hstack(blocks: Untyped, format: SPFormat | None = None, dtype: npt.DTypeLike | None = None) -> _SpArray | _SpMatrix: ...
Expand Down

0 comments on commit 0b63592

Please sign in to comment.