Skip to content

Commit

Permalink
sparse: complete {h,v}stack (#313)
Browse files Browse the repository at this point in the history
  • Loading branch information
jorenham authored Dec 15, 2024
1 parent cc91d80 commit aac20bc
Showing 1 changed file with 123 additions and 4 deletions.
127 changes: 123 additions & 4 deletions scipy-stubs/sparse/_construct.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Sequence
from collections.abc import Sequence as Seq
from typing import Any, Literal, TypeAlias, TypeVar, overload

import numpy as np
Expand Down Expand Up @@ -41,7 +41,7 @@ _SCT1 = TypeVar("_SCT1", bound=Scalar)
_SCT2 = TypeVar("_SCT2", bound=Scalar)
_ShapeT = TypeVar("_ShapeT", bound=tuple[int] | tuple[int, int], default=tuple[int] | tuple[int, int])

_ToMatrix: TypeAlias = spmatrix[_SCT] | Sequence[Sequence[_SCT] | onp.CanArrayND[_SCT]] | onp.CanArrayND[_SCT]
_ToMatrix: TypeAlias = spmatrix[_SCT] | Seq[Seq[_SCT] | onp.CanArrayND[_SCT]] | onp.CanArrayND[_SCT]

_SpMatrix: TypeAlias = (
bsr_matrix[_SCT]
Expand Down Expand Up @@ -89,12 +89,19 @@ _NonCSRMatrix: TypeAlias = (
bsr_matrix[_SCT] | coo_matrix[_SCT] | csc_matrix[_SCT] | dia_matrix[_SCT] | dok_matrix[_SCT] | lil_matrix[_SCT]
)

_SpMatrixOut: TypeAlias = coo_matrix[_SCT] | csc_matrix[_SCT] | csr_matrix[_SCT]
_SpMatrixNonOut: TypeAlias = bsr_matrix[_SCT] | dia_matrix[_SCT] | dok_matrix[_SCT] | lil_matrix[_SCT]
_SpArrayOut: TypeAlias = coo_array[_SCT, _ShapeT] | csc_array[_SCT] | csr_array[_SCT, _ShapeT]
_SpArrayNonOut: TypeAlias = bsr_array[_SCT] | dia_array[_SCT] | dok_array[_SCT, tuple[int, int]] | lil_array[_SCT]

_FmtBSR: TypeAlias = Literal["bsr"]
_FmtCSR: TypeAlias = Literal["csr"]
_FmtDIA: TypeAlias = Literal["dia"]
_FmtOut: TypeAlias = Literal["coo", "csc", "csr"]
_FmtNonBSR: TypeAlias = Literal["coo", "csc", "csr", "dia", "dok", "lil"]
_FmtNonCSR: TypeAlias = Literal["bsr", "coo", "csc", "dia", "dok", "lil"]
_FmtNonDIA: TypeAlias = Literal["bsr", "coo", "csc", "csr", "dok", "lil"]
_FmtNonOut: TypeAlias = Literal["bsr", "dia", "dok", "lil"]

###

Expand Down Expand Up @@ -400,9 +407,121 @@ def kronsum(
format: SPFormat | None = None,
) -> _SpArray2D[_SCT] | _SpMatrix[_SCT]: ...

# NOTE: hstack and vstack have identical signatures
@overload # sparray, format: <default>, dtype: None
def hstack(blocks: Seq[_SpArray[_SCT]], format: _FmtOut | None = None, dtype: None = None) -> _SpArrayOut[_SCT]: ...
@overload # sparray, format: <non-default>, dtype: None
def hstack(blocks: Seq[_SpArray[_SCT]], format: _FmtNonOut, dtype: None = None) -> _SpArrayNonOut[_SCT]: ...
@overload # sparray, format: <default>, dtype: <int>
def hstack(blocks: Seq[_SpArray], format: _FmtOut | None = None, *, dtype: ToDTypeBool) -> _SpArrayOut[np.bool_]: ...
@overload # sparray, format: <non-default>, dtype: <int>
def hstack(blocks: Seq[_SpArray], format: _FmtNonOut, dtype: ToDTypeBool) -> _SpArrayNonOut[np.bool_]: ...
@overload # sparray, format: <default>, dtype: <int>
def hstack(blocks: Seq[_SpArray], format: _FmtOut | None = None, *, dtype: ToDTypeInt) -> _SpArrayOut[np.int_]: ...
@overload # sparray, format: <non-default>, dtype: <int>
def hstack(blocks: Seq[_SpArray], format: _FmtNonOut, dtype: ToDTypeInt) -> _SpArrayNonOut[np.int_]: ...
@overload # sparray, format: <default>, dtype: <float>
def hstack(blocks: Seq[_SpArray], format: _FmtOut | None = None, *, dtype: ToDTypeFloat) -> _SpArrayOut[np.float64]: ...
@overload # sparray, format: <non-default>, dtype: <float>
def hstack(blocks: Seq[_SpArray], format: _FmtNonOut, dtype: ToDTypeFloat) -> _SpArrayNonOut[np.float64]: ...
@overload # sparray, format: <default>, dtype: <complex>
def hstack(blocks: Seq[_SpArray], format: _FmtOut | None = None, *, dtype: ToDTypeComplex) -> _SpArrayOut[np.complex128]: ...
@overload # sparray, format: <non-default>, dtype: <complex>
def hstack(blocks: Seq[_SpArray], format: _FmtNonOut, dtype: ToDTypeComplex) -> _SpArrayNonOut[np.complex128]: ...
@overload # sparray, format: <default>, dtype: <known>
def hstack(blocks: Seq[_SpArray], format: _FmtOut | None = None, *, dtype: ToDType[_SCT]) -> _SpArrayOut[_SCT]: ...
@overload # sparray, format: <non-default>, dtype: <known>
def hstack(blocks: Seq[_SpArray], format: _FmtNonOut, dtype: ToDType[_SCT]) -> _SpArrayNonOut[_SCT]: ...
@overload # sparray, format: <default>, dtype: <unknown>
def hstack(blocks: Seq[_SpArray], format: _FmtOut | None = None, *, dtype: npt.DTypeLike) -> _SpArrayOut: ...
@overload # sparray, format: <non-default>, dtype: <unknown>
def hstack(blocks: Seq[_SpArray], format: _FmtNonOut, dtype: npt.DTypeLike) -> _SpArrayNonOut: ...
@overload # spmatrix, format: <default>, dtype: None
def hstack(blocks: Seq[spmatrix[_SCT]], format: _FmtOut | None = None, dtype: None = None) -> _SpMatrixOut[_SCT]: ...
@overload # spmatrix, format: <non-default>, dtype: None
def hstack(blocks: Seq[spmatrix[_SCT]], format: _FmtNonOut, dtype: None = None) -> _SpMatrixNonOut[_SCT]: ...
@overload # spmatrix, format: <default>, dtype: <int>
def hstack(blocks: Seq[spmatrix], format: _FmtOut | None = None, *, dtype: ToDTypeBool) -> _SpMatrixOut[np.bool_]: ...
@overload # spmatrix, format: <non-default>, dtype: <int>
def hstack(blocks: Seq[spmatrix], format: _FmtNonOut, dtype: ToDTypeBool) -> _SpMatrixNonOut[np.bool_]: ...
@overload # spmatrix, format: <default>, dtype: <int>
def hstack(blocks: Seq[spmatrix], format: _FmtOut | None = None, *, dtype: ToDTypeInt) -> _SpMatrixOut[np.int_]: ...
@overload # spmatrix, format: <non-default>, dtype: <int>
def hstack(blocks: Seq[spmatrix], format: _FmtNonOut, dtype: ToDTypeInt) -> _SpMatrixNonOut[np.int_]: ...
@overload # spmatrix, format: <default>, dtype: <float>
def hstack(blocks: Seq[spmatrix], format: _FmtOut | None = None, *, dtype: ToDTypeFloat) -> _SpMatrixOut[np.float64]: ...
@overload # spmatrix, format: <non-default>, dtype: <float>
def hstack(blocks: Seq[spmatrix], format: _FmtNonOut, dtype: ToDTypeFloat) -> _SpMatrixNonOut[np.float64]: ...
@overload # spmatrix, format: <default>, dtype: <complex>
def hstack(blocks: Seq[spmatrix], format: _FmtOut | None = None, *, dtype: ToDTypeComplex) -> _SpMatrixOut[np.complex128]: ...
@overload # spmatrix, format: <non-default>, dtype: <complex>
def hstack(blocks: Seq[spmatrix], format: _FmtNonOut, dtype: ToDTypeComplex) -> _SpMatrixNonOut[np.complex128]: ...
@overload # spmatrix, format: <default>, dtype: <known>
def hstack(blocks: Seq[spmatrix], format: _FmtOut | None = None, *, dtype: ToDType[_SCT]) -> _SpMatrixOut[_SCT]: ...
@overload # spmatrix, format: <non-default>, dtype: <known>
def hstack(blocks: Seq[spmatrix], format: _FmtNonOut, dtype: ToDType[_SCT]) -> _SpMatrixNonOut[_SCT]: ...
@overload # spmatrix, format: <default>, dtype: <unknown>
def hstack(blocks: Seq[spmatrix], format: _FmtOut | None = None, *, dtype: npt.DTypeLike) -> _SpMatrixOut: ...
@overload # spmatrix, format: <non-default>, dtype: <unknown>
def hstack(blocks: Seq[spmatrix], format: _FmtNonOut, dtype: npt.DTypeLike) -> _SpMatrixNonOut: ...

#
def hstack(blocks: Untyped, format: SPFormat | None = None, dtype: npt.DTypeLike | None = None) -> _SpArray | _SpMatrix: ...
def vstack(blocks: Untyped, format: SPFormat | None = None, dtype: npt.DTypeLike | None = None) -> _SpArray | _SpMatrix: ...
@overload # sparray, format: <default>, dtype: None
def vstack(blocks: Seq[_SpArray[_SCT]], format: _FmtOut | None = None, dtype: None = None) -> _SpArrayOut[_SCT]: ...
@overload # sparray, format: <non-default>, dtype: None
def vstack(blocks: Seq[_SpArray[_SCT]], format: _FmtNonOut, dtype: None = None) -> _SpArrayNonOut[_SCT]: ...
@overload # sparray, format: <default>, dtype: <int>
def vstack(blocks: Seq[_SpArray], format: _FmtOut | None = None, *, dtype: ToDTypeBool) -> _SpArrayOut[np.bool_]: ...
@overload # sparray, format: <non-default>, dtype: <int>
def vstack(blocks: Seq[_SpArray], format: _FmtNonOut, dtype: ToDTypeBool) -> _SpArrayNonOut[np.bool_]: ...
@overload # sparray, format: <default>, dtype: <int>
def vstack(blocks: Seq[_SpArray], format: _FmtOut | None = None, *, dtype: ToDTypeInt) -> _SpArrayOut[np.int_]: ...
@overload # sparray, format: <non-default>, dtype: <int>
def vstack(blocks: Seq[_SpArray], format: _FmtNonOut, dtype: ToDTypeInt) -> _SpArrayNonOut[np.int_]: ...
@overload # sparray, format: <default>, dtype: <float>
def vstack(blocks: Seq[_SpArray], format: _FmtOut | None = None, *, dtype: ToDTypeFloat) -> _SpArrayOut[np.float64]: ...
@overload # sparray, format: <non-default>, dtype: <float>
def vstack(blocks: Seq[_SpArray], format: _FmtNonOut, dtype: ToDTypeFloat) -> _SpArrayNonOut[np.float64]: ...
@overload # sparray, format: <default>, dtype: <complex>
def vstack(blocks: Seq[_SpArray], format: _FmtOut | None = None, *, dtype: ToDTypeComplex) -> _SpArrayOut[np.complex128]: ...
@overload # sparray, format: <non-default>, dtype: <complex>
def vstack(blocks: Seq[_SpArray], format: _FmtNonOut, dtype: ToDTypeComplex) -> _SpArrayNonOut[np.complex128]: ...
@overload # sparray, format: <default>, dtype: <known>
def vstack(blocks: Seq[_SpArray], format: _FmtOut | None = None, *, dtype: ToDType[_SCT]) -> _SpArrayOut[_SCT]: ...
@overload # sparray, format: <non-default>, dtype: <known>
def vstack(blocks: Seq[_SpArray], format: _FmtNonOut, dtype: ToDType[_SCT]) -> _SpArrayNonOut[_SCT]: ...
@overload # sparray, format: <default>, dtype: <unknown>
def vstack(blocks: Seq[_SpArray], format: _FmtOut | None = None, *, dtype: npt.DTypeLike) -> _SpArrayOut: ...
@overload # sparray, format: <non-default>, dtype: <unknown>
def vstack(blocks: Seq[_SpArray], format: _FmtNonOut, dtype: npt.DTypeLike) -> _SpArrayNonOut: ...
@overload # spmatrix, format: <default>, dtype: None
def vstack(blocks: Seq[spmatrix[_SCT]], format: _FmtOut | None = None, dtype: None = None) -> _SpMatrixOut[_SCT]: ...
@overload # spmatrix, format: <non-default>, dtype: None
def vstack(blocks: Seq[spmatrix[_SCT]], format: _FmtNonOut, dtype: None = None) -> _SpMatrixNonOut[_SCT]: ...
@overload # spmatrix, format: <default>, dtype: <int>
def vstack(blocks: Seq[spmatrix], format: _FmtOut | None = None, *, dtype: ToDTypeBool) -> _SpMatrixOut[np.bool_]: ...
@overload # spmatrix, format: <non-default>, dtype: <int>
def vstack(blocks: Seq[spmatrix], format: _FmtNonOut, dtype: ToDTypeBool) -> _SpMatrixNonOut[np.bool_]: ...
@overload # spmatrix, format: <default>, dtype: <int>
def vstack(blocks: Seq[spmatrix], format: _FmtOut | None = None, *, dtype: ToDTypeInt) -> _SpMatrixOut[np.int_]: ...
@overload # spmatrix, format: <non-default>, dtype: <int>
def vstack(blocks: Seq[spmatrix], format: _FmtNonOut, dtype: ToDTypeInt) -> _SpMatrixNonOut[np.int_]: ...
@overload # spmatrix, format: <default>, dtype: <float>
def vstack(blocks: Seq[spmatrix], format: _FmtOut | None = None, *, dtype: ToDTypeFloat) -> _SpMatrixOut[np.float64]: ...
@overload # spmatrix, format: <non-default>, dtype: <float>
def vstack(blocks: Seq[spmatrix], format: _FmtNonOut, dtype: ToDTypeFloat) -> _SpMatrixNonOut[np.float64]: ...
@overload # spmatrix, format: <default>, dtype: <complex>
def vstack(blocks: Seq[spmatrix], format: _FmtOut | None = None, *, dtype: ToDTypeComplex) -> _SpMatrixOut[np.complex128]: ...
@overload # spmatrix, format: <non-default>, dtype: <complex>
def vstack(blocks: Seq[spmatrix], format: _FmtNonOut, dtype: ToDTypeComplex) -> _SpMatrixNonOut[np.complex128]: ...
@overload # spmatrix, format: <default>, dtype: <known>
def vstack(blocks: Seq[spmatrix], format: _FmtOut | None = None, *, dtype: ToDType[_SCT]) -> _SpMatrixOut[_SCT]: ...
@overload # spmatrix, format: <non-default>, dtype: <known>
def vstack(blocks: Seq[spmatrix], format: _FmtNonOut, dtype: ToDType[_SCT]) -> _SpMatrixNonOut[_SCT]: ...
@overload # spmatrix, format: <default>, dtype: <unknown>
def vstack(blocks: Seq[spmatrix], format: _FmtOut | None = None, *, dtype: npt.DTypeLike) -> _SpMatrixOut: ...
@overload # spmatrix, format: <non-default>, dtype: <unknown>
def vstack(blocks: Seq[spmatrix], format: _FmtNonOut, dtype: npt.DTypeLike) -> _SpMatrixNonOut: ...

#
def bmat(blocks: Untyped, format: SPFormat | None = None, dtype: npt.DTypeLike | None = None) -> _SpArray | _SpMatrix: ...
Expand Down

0 comments on commit aac20bc

Please sign in to comment.