Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sparse: complete kron[sum] #311

Merged
merged 1 commit into from
Dec 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 82 additions & 19 deletions scipy-stubs/sparse/_construct.pyi
Original file line number Diff line number Diff line change
@@ -1,25 +1,20 @@
from collections.abc import Sequence
from typing import Any, Literal, TypeAlias, TypeVar, overload

import numpy as np
import numpy.typing as npt
import optype.numpy as onp
import optype.typing as opt
from scipy._typing import Seed, Untyped, UntypedCallable
from scipy.sparse import (
bsr_array,
bsr_matrix,
coo_array,
coo_matrix,
csc_array,
csc_matrix,
csr_array,
csr_matrix,
dia_array,
dia_matrix,
dok_array,
dok_matrix,
lil_array,
lil_matrix,
)
from ._base import _spbase
from ._bsr import bsr_array, bsr_matrix
from ._coo import coo_array, coo_matrix
from ._csc import csc_array, csc_matrix
from ._csr import csr_array, csr_matrix
from ._dia import dia_array, dia_matrix
from ._dok import dok_array, dok_matrix
from ._lil import lil_array, lil_matrix
from ._matrix import spmatrix
from ._typing import Float, Scalar, SPFormat, ToDType, ToDTypeBool, ToDTypeComplex, ToDTypeFloat, ToDTypeInt, ToShape, ToShape2D

__all__ = [
Expand All @@ -42,8 +37,12 @@ __all__ = [
]

_SCT = TypeVar("_SCT", bound=Scalar, default=Any)
_SCT1 = TypeVar("_SCT1", bound=Scalar)
_SCT2 = TypeVar("_SCT2", bound=Scalar)
_ShapeT = TypeVar("_ShapeT", bound=tuple[int] | tuple[int, int], default=tuple[int] | tuple[int, int])

_ToMatrix: TypeAlias = spmatrix[_SCT] | Sequence[Sequence[_SCT] | onp.CanArrayND[_SCT]] | onp.CanArrayND[_SCT]

_SpMatrix: TypeAlias = (
bsr_matrix[_SCT]
| coo_matrix[_SCT]
Expand All @@ -53,7 +52,6 @@ _SpMatrix: TypeAlias = (
| dok_matrix[_SCT]
| lil_matrix[_SCT]
)

_SpArray: TypeAlias = (
bsr_array[_SCT]
| coo_array[_SCT, _ShapeT]
Expand All @@ -66,7 +64,36 @@ _SpArray: TypeAlias = (
_SpArray1D: TypeAlias = coo_array[_SCT, tuple[int]] | csr_array[_SCT, tuple[int]] | dok_array[_SCT, tuple[int]]
_SpArray2D: TypeAlias = _SpArray[_SCT, tuple[int, int]]

_BSRArray: TypeAlias = bsr_array[_SCT]
_CSRArray: TypeAlias = csr_array[_SCT, tuple[int, int]]
_NonBSRArray: TypeAlias = (
coo_array[_SCT, tuple[int, int]]
| csc_array[_SCT]
| csr_array[_SCT, tuple[int, int]]
| dia_array[_SCT]
| dok_array[_SCT, tuple[int, int]]
| lil_array[_SCT]
)
_NonCSRArray: TypeAlias = (
bsr_array[_SCT]
| coo_array[_SCT, tuple[int, int]]
| csc_array[_SCT]
| dia_array[_SCT]
| dok_array[_SCT, tuple[int, int]]
| lil_array[_SCT]
)
_NonBSRMatrix: TypeAlias = (
coo_matrix[_SCT] | csr_matrix[_SCT] | csc_matrix[_SCT] | dia_matrix[_SCT] | dok_matrix[_SCT] | lil_matrix[_SCT]
)
_NonCSRMatrix: TypeAlias = (
bsr_matrix[_SCT] | coo_matrix[_SCT] | csc_matrix[_SCT] | dia_matrix[_SCT] | dok_matrix[_SCT] | lil_matrix[_SCT]
)

_FmtBSR: TypeAlias = Literal["bsr"]
_FmtCSR: TypeAlias = Literal["csr"]
_FmtDIA: TypeAlias = Literal["dia"]
_FmtNonBSR: TypeAlias = Literal["coo", "csc", "csr", "dia", "dok", "lil"]
_FmtNonCSR: TypeAlias = Literal["bsr", "coo", "csc", "dia", "dok", "lil"]
_FmtNonDIA: TypeAlias = Literal["bsr", "coo", "csc", "csr", "dok", "lil"]

###
Expand Down Expand Up @@ -334,8 +361,44 @@ def eye(
) -> _SpMatrix: ...

#
def kron(A: Untyped, B: Untyped, format: SPFormat | None = None) -> _SpArray2D | _SpMatrix: ...
def kronsum(A: Untyped, B: Untyped, format: SPFormat | None = None) -> _SpArray2D | _SpMatrix: ...
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: {"bsr", None} = ...
def kron(A: _ToMatrix[_SCT1], B: _ToMatrix[_SCT2], format: _FmtBSR | None = None) -> bsr_matrix[_SCT1 | _SCT2]: ...
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: <otherwise>
def kron(A: _ToMatrix[_SCT1], B: _ToMatrix[_SCT2], format: _FmtNonBSR) -> _NonBSRMatrix[_SCT1 | _SCT2]: ...
@overload # A: sparray, B: sparse, format: {"bsr", None} = ...
def kron(A: _SpArray[_SCT1], B: _spbase[_SCT2], format: _FmtBSR | None = None) -> _BSRArray[_SCT1 | _SCT2]: ...
@overload # A: sparray, B: sparse, format: <otherwise>
def kron(A: _SpArray[_SCT1], B: _spbase[_SCT2], format: _FmtNonBSR) -> _NonBSRArray[_SCT1 | _SCT2]: ...
@overload # A: sparse, B: sparray, format: {"bsr", None} = ...
def kron(A: _spbase[_SCT1], B: _SpArray[_SCT2], format: _FmtBSR | None = None) -> _BSRArray[_SCT1 | _SCT2]: ...
@overload # A: sparse, B: sparray, format: <otherwise>
def kron(A: _spbase[_SCT1], B: _SpArray[_SCT2], format: _FmtNonBSR) -> _NonBSRArray[_SCT1 | _SCT2]: ...
@overload # A: unknown array-like, B: unknown array-like (catch-all)
def kron(
A: onp.ToComplex2D | _spbase[_SCT],
B: onp.ToComplex2D | _spbase[_SCT],
format: SPFormat | None = None,
) -> _SpArray2D[_SCT] | _SpMatrix[_SCT]: ...

#
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: {"csr", None} = ...
def kronsum(A: _ToMatrix[_SCT1], B: _ToMatrix[_SCT2], format: _FmtCSR | None = None) -> csr_matrix[_SCT1 | _SCT2]: ...
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: <otherwise>
def kronsum(A: _ToMatrix[_SCT1], B: _ToMatrix[_SCT2], format: _FmtNonCSR) -> _NonCSRMatrix[_SCT1 | _SCT2]: ...
@overload # A: sparray, B: sparse, format: {"csr", None} = ...
def kronsum(A: _SpArray[_SCT1], B: _spbase[_SCT2], format: _FmtCSR | None = None) -> _CSRArray[_SCT1 | _SCT2]: ...
@overload # A: sparray, B: sparse, format: <otherwise>
def kronsum(A: _SpArray[_SCT1], B: _spbase[_SCT2], format: _FmtNonCSR) -> _NonCSRArray[_SCT1 | _SCT2]: ...
@overload # A: sparse, B: sparray, format: {"csr", None} = ...
def kronsum(A: _spbase[_SCT1], B: _SpArray[_SCT2], format: _FmtCSR | None = None) -> _CSRArray[_SCT1 | _SCT2]: ...
@overload # A: sparse, B: sparray, format: <otherwise>
def kronsum(A: _spbase[_SCT1], B: _SpArray[_SCT2], format: _FmtNonCSR) -> _NonCSRArray[_SCT1 | _SCT2]: ...
@overload # A: unknown array-like, B: unknown array-like (catch-all)
def kronsum(
A: onp.ToComplex2D | _spbase[_SCT],
B: onp.ToComplex2D | _spbase[_SCT],
format: SPFormat | None = None,
) -> _SpArray2D[_SCT] | _SpMatrix[_SCT]: ...

#
def hstack(blocks: Untyped, format: SPFormat | None = None, dtype: npt.DTypeLike | None = None) -> _SpArray | _SpMatrix: ...
Expand Down
Loading