Skip to content

Commit

Permalink
sparse: complete bmat and block_{array,diag} (#320)
Browse files Browse the repository at this point in the history
  • Loading branch information
jorenham authored Dec 15, 2024
1 parent f101692 commit cae0291
Showing 1 changed file with 117 additions and 6 deletions.
123 changes: 117 additions & 6 deletions scipy-stubs/sparse/_construct.pyi
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from collections.abc import Callable, Sequence as Seq
from collections.abc import Callable, Iterable, Sequence as Seq
from typing import Any, Literal, Protocol, TypeAlias, TypeVar, overload, type_check_only

import numpy as np
import numpy.typing as npt
import optype.numpy as onp
import optype.typing as opt
from scipy._typing import Seed, Untyped
from ._base import _spbase
from scipy._typing import Seed
from ._base import _spbase, sparray
from ._bsr import bsr_array, bsr_matrix
from ._coo import coo_array, coo_matrix
from ._csc import csc_array, csc_matrix
Expand Down Expand Up @@ -76,6 +76,14 @@ _NonBSRArray: TypeAlias = (
| dok_array[_SCT, tuple[int, int]]
| lil_array[_SCT]
)
_NonCOOArray: TypeAlias = (
bsr_array[_SCT]
| csc_array[_SCT]
| csr_array[_SCT, tuple[int, int]]
| dia_array[_SCT]
| dok_array[_SCT, tuple[int, int]]
| lil_array[_SCT]
)
_NonCSRArray: TypeAlias = (
bsr_array[_SCT]
| coo_array[_SCT, tuple[int, int]]
Expand All @@ -95,6 +103,9 @@ _NonDIAArray: TypeAlias = (
_NonBSRMatrix: TypeAlias = (
coo_matrix[_SCT] | csr_matrix[_SCT] | csc_matrix[_SCT] | dia_matrix[_SCT] | dok_matrix[_SCT] | lil_matrix[_SCT]
)
_NonCOOMatrix: TypeAlias = (
bsr_matrix[_SCT] | csc_matrix[_SCT] | csr_matrix[_SCT] | dia_matrix[_SCT] | dok_matrix[_SCT] | lil_matrix[_SCT]
)
_NonCSRMatrix: TypeAlias = (
bsr_matrix[_SCT] | coo_matrix[_SCT] | csc_matrix[_SCT] | dia_matrix[_SCT] | dok_matrix[_SCT] | lil_matrix[_SCT]
)
Expand All @@ -108,16 +119,20 @@ _SpArrayOut: TypeAlias = coo_array[_SCT, _ShapeT] | csc_array[_SCT] | csr_array[
_SpArrayNonOut: TypeAlias = bsr_array[_SCT] | dia_array[_SCT] | dok_array[_SCT, tuple[int, int]] | lil_array[_SCT]

_FmtBSR: TypeAlias = Literal["bsr"]
_FmtCOO: TypeAlias = Literal["coo"]
_FmtCSR: TypeAlias = Literal["csr"]
_FmtDIA: TypeAlias = Literal["dia"]
_FmtOut: TypeAlias = Literal["coo", "csc", "csr"]
_FmtNonBSR: TypeAlias = Literal["coo", "csc", "csr", "dia", "dok", "lil"]
_FmtNonCOO: TypeAlias = Literal["bsr", "csc", "csr", "dia", "dok", "lil"]
_FmtNonCSR: TypeAlias = Literal["bsr", "coo", "csc", "dia", "dok", "lil"]
_FmtNonDIA: TypeAlias = Literal["bsr", "coo", "csc", "csr", "dok", "lil"]
_FmtNonOut: TypeAlias = Literal["bsr", "dia", "dok", "lil"]

_DataRVS: TypeAlias = Callable[[int], onp.ArrayND[Scalar]]

_ToBlocks: TypeAlias = Seq[Seq[_spbase]] | onp.ArrayND[np.object_]

@type_check_only
class _DataSampler(Protocol):
def __call__(self, /, *, size: int) -> onp.ArrayND[Scalar]: ...
Expand Down Expand Up @@ -773,10 +788,106 @@ def vstack(blocks: Seq[spmatrix], format: _FmtOut | None = None, *, dtype: npt.D
@overload # spmatrix, format: <non-default>, dtype: <unknown>
def vstack(blocks: Seq[spmatrix], format: _FmtNonOut, dtype: npt.DTypeLike) -> _SpMatrixNonOut: ...

_COOArray2D: TypeAlias = coo_array[_SCT, tuple[int, int]]

#
def bmat(blocks: Untyped, format: SPFormat | None = None, dtype: npt.DTypeLike | None = None) -> _SpArray | _SpMatrix: ...
def block_array(blocks: Untyped, *, format: SPFormat | None = None, dtype: npt.DTypeLike | None = None) -> _SpArray: ...
def block_diag(mats: Untyped, format: SPFormat | None = None, dtype: npt.DTypeLike | None = None) -> _SpArray | _SpMatrix: ...
@overload # blocks: <known dtype>, format: <default>, dtype: <default>
def block_array(blocks: Seq[Seq[_spbase[_SCT]]], *, format: _FmtCOO | None = None, dtype: None = None) -> _COOArray2D[_SCT]: ...
@overload # blocks: <unknown dtype>, format: <default>, dtype: <known>
def block_array(blocks: _ToBlocks, *, format: _FmtCOO | None = None, dtype: ToDType[_SCT]) -> _COOArray2D[_SCT]: ...
@overload # blocks: <unknown dtype>, format: <default>, dtype: <unknown>
def block_array(blocks: _ToBlocks, *, format: _FmtCOO | None = None, dtype: npt.DTypeLike) -> _COOArray2D: ...
@overload # blocks: <known dtype>, format: <otherwise>, dtype: <default>
def block_array(blocks: Seq[Seq[_spbase[_SCT]]], *, format: _FmtNonCOO, dtype: None = None) -> _NonCOOArray[_SCT]: ...
@overload # blocks: <unknown dtype>, format: <otherwise>, dtype: <known>
def block_array(blocks: _ToBlocks, *, format: _FmtNonCOO, dtype: ToDType[_SCT]) -> _NonCOOArray[_SCT]: ...
@overload # blocks: <unknown dtype>, format: <otherwise>, dtype: <unknown>
def block_array(blocks: _ToBlocks, *, format: _FmtNonCOO, dtype: npt.DTypeLike) -> _NonCOOArray: ...

#
@overload # blocks: <array, known dtype>, format: <default>, dtype: <default>
def bmat(blocks: Seq[Seq[_SpArray[_SCT]]], format: _FmtCOO | None = None, dtype: None = None) -> _COOArray2D[_SCT]: ...
@overload # blocks: <matrix, known dtype>, format: <default>, dtype: <default>
def bmat(blocks: Seq[Seq[spmatrix[_SCT]]], format: _FmtCOO | None = None, dtype: None = None) -> coo_matrix[_SCT]: ...
@overload # sparray, blocks: <unknown, unknown dtype>, format: <default>, dtype: <known> (positional)
def bmat(blocks: _ToBlocks, format: _FmtCOO | None, dtype: ToDType[_SCT]) -> _COOArray2D[_SCT] | coo_matrix[_SCT]: ...
@overload # sparray, blocks: <unknown, unknown dtype>, format: <default>, dtype: <known> (keyword)
def bmat(blocks: _ToBlocks, format: _FmtCOO | None = None, *, dtype: ToDType[_SCT]) -> _COOArray2D[_SCT] | coo_matrix[_SCT]: ...
@overload # sparray, blocks: <unknown, unknown dtype>, format: <default>, dtype: <unknown>
def bmat(blocks: _ToBlocks, format: _FmtCOO | None = None, dtype: npt.DTypeLike | None = None) -> _COOArray2D | coo_matrix: ...
@overload # sparray, blocks: <array, known dtype>, format: <otherwise>, dtype: <default>
def bmat(blocks: Seq[Seq[_SpArray[_SCT]]], format: _FmtNonCOO, dtype: None = None) -> _NonCOOArray[_SCT]: ...
@overload # sparray, blocks: <matrix, known dtype>, format: <otherwise>, dtype: <default>
def bmat(blocks: Seq[Seq[spmatrix[_SCT]]], format: _FmtNonCOO, dtype: None = None) -> _NonCOOMatrix[_SCT]: ...
@overload # sparray, blocks: <unknown, unknown dtype>, format: <otherwise>, dtype: <known>
def bmat(blocks: _ToBlocks, format: _FmtNonCOO, dtype: ToDType[_SCT]) -> _NonCOOArray[_SCT] | _NonCOOMatrix[_SCT]: ...
@overload # sparray, blocks: <unknown, unknown dtype>, format: <otherwise>, dtype: <unknown>
def bmat(blocks: _ToBlocks, format: _FmtNonCOO, dtype: npt.DTypeLike) -> _NonCOOArray | _NonCOOMatrix: ...

#
@overload # mats: <array, known dtype>
def block_diag(
mats: Iterable[_SpArray[_SCT]],
format: SPFormat | None = None,
dtype: None = None,
) -> _SpArray[_SCT, tuple[int, int]]: ...
@overload # mats: <matrix, known dtype>
def block_diag(
mats: Iterable[spmatrix[_SCT]],
format: SPFormat | None = None,
dtype: None = None,
) -> _SpMatrix[_SCT]: ...
@overload # mats: <unknown, known dtype>
def block_diag(
mats: Iterable[_spbase[_SCT] | onp.ArrayND[_SCT]],
format: SPFormat | None = None,
dtype: None = None,
) -> _SpArray[_SCT, tuple[int, int]] | _SpMatrix[_SCT]: ...
@overload # mats: <array, unknown dtype>, dtype: <known> (positional)
def block_diag(
mats: Iterable[sparray],
format: SPFormat | None,
dtype: ToDType[_SCT],
) -> _SpArray[_SCT, tuple[int, int]]: ...
@overload # mats: <array, unknown dtype>, dtype: <known> (keyword)
def block_diag(
mats: Iterable[sparray],
format: SPFormat | None = None,
*,
dtype: ToDType[_SCT],
) -> _SpArray[_SCT, tuple[int, int]]: ...
@overload # mats: <matrix, unknown dtype>, dtype: <known> (positional)
def block_diag(
mats: Iterable[spmatrix | onp.ArrayND[Scalar] | complex | list[onp.ToComplex] | list[onp.ToComplex1D]],
format: SPFormat | None,
dtype: ToDType[_SCT],
) -> _SpMatrix[_SCT]: ...
@overload # mats: <matrix, unknown dtype>, dtype: <known> (keyword)
def block_diag(
mats: Iterable[spmatrix | onp.ArrayND[Scalar] | complex | list[onp.ToComplex] | list[onp.ToComplex1D]],
format: SPFormat | None = None,
*,
dtype: ToDType[_SCT],
) -> _SpMatrix[_SCT]: ...
@overload # mats: <unknown, unknown dtype>, dtype: <known> (positional)
def block_diag(
mats: Iterable[_spbase | onp.ArrayND[Scalar] | complex | list[onp.ToComplex] | list[onp.ToComplex1D]],
format: SPFormat | None,
dtype: ToDType[_SCT],
) -> _SpArray[_SCT, tuple[int, int]] | _SpMatrix[_SCT]: ...
@overload # mats: <unknown, unknown dtype>, dtype: <known> (keyword)
def block_diag(
mats: Iterable[_spbase | onp.ArrayND[Scalar] | complex | list[onp.ToComplex] | list[onp.ToComplex1D]],
format: SPFormat | None = None,
*,
dtype: ToDType[_SCT],
) -> _SpArray[_SCT, tuple[int, int]] | _SpMatrix[_SCT]: ...
@overload # catch-all
def block_diag(
mats: Iterable[_spbase | onp.ArrayND[Scalar] | complex | list[onp.ToComplex] | list[onp.ToComplex1D]],
format: SPFormat | None = None,
dtype: npt.DTypeLike | None = None,
) -> _SpArray[Any, tuple[int, int]] | _SpMatrix[Any]: ...

#
@overload # shape: 1d, dtype: <default>
Expand Down

0 comments on commit cae0291

Please sign in to comment.