From cae0291bddee246c46ef3af0b26eb40283097685 Mon Sep 17 00:00:00 2001 From: Joren Hammudoglu Date: Sun, 15 Dec 2024 20:17:51 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20`sparse`:=20complete=20`bmat`=20and?= =?UTF-8?q?=20`block=5F{array,diag}`=20(#320)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scipy-stubs/sparse/_construct.pyi | 123 ++++++++++++++++++++++++++++-- 1 file changed, 117 insertions(+), 6 deletions(-) diff --git a/scipy-stubs/sparse/_construct.pyi b/scipy-stubs/sparse/_construct.pyi index f67fc4d4..ab5903ab 100644 --- a/scipy-stubs/sparse/_construct.pyi +++ b/scipy-stubs/sparse/_construct.pyi @@ -1,12 +1,12 @@ -from collections.abc import Callable, Sequence as Seq +from collections.abc import Callable, Iterable, Sequence as Seq from typing import Any, Literal, Protocol, TypeAlias, TypeVar, overload, type_check_only import numpy as np import numpy.typing as npt import optype.numpy as onp import optype.typing as opt -from scipy._typing import Seed, Untyped -from ._base import _spbase +from scipy._typing import Seed +from ._base import _spbase, sparray from ._bsr import bsr_array, bsr_matrix from ._coo import coo_array, coo_matrix from ._csc import csc_array, csc_matrix @@ -76,6 +76,14 @@ _NonBSRArray: TypeAlias = ( | dok_array[_SCT, tuple[int, int]] | lil_array[_SCT] ) +_NonCOOArray: TypeAlias = ( + bsr_array[_SCT] + | csc_array[_SCT] + | csr_array[_SCT, tuple[int, int]] + | dia_array[_SCT] + | dok_array[_SCT, tuple[int, int]] + | lil_array[_SCT] +) _NonCSRArray: TypeAlias = ( bsr_array[_SCT] | coo_array[_SCT, tuple[int, int]] @@ -95,6 +103,9 @@ _NonDIAArray: TypeAlias = ( _NonBSRMatrix: TypeAlias = ( coo_matrix[_SCT] | csr_matrix[_SCT] | csc_matrix[_SCT] | dia_matrix[_SCT] | dok_matrix[_SCT] | lil_matrix[_SCT] ) +_NonCOOMatrix: TypeAlias = ( + bsr_matrix[_SCT] | csc_matrix[_SCT] | csr_matrix[_SCT] | dia_matrix[_SCT] | dok_matrix[_SCT] | lil_matrix[_SCT] +) _NonCSRMatrix: TypeAlias = ( bsr_matrix[_SCT] | coo_matrix[_SCT] | csc_matrix[_SCT] | dia_matrix[_SCT] | dok_matrix[_SCT] | lil_matrix[_SCT] ) @@ -108,16 +119,20 @@ _SpArrayOut: TypeAlias = coo_array[_SCT, _ShapeT] | csc_array[_SCT] | csr_array[ _SpArrayNonOut: TypeAlias = bsr_array[_SCT] | dia_array[_SCT] | dok_array[_SCT, tuple[int, int]] | lil_array[_SCT] _FmtBSR: TypeAlias = Literal["bsr"] +_FmtCOO: TypeAlias = Literal["coo"] _FmtCSR: TypeAlias = Literal["csr"] _FmtDIA: TypeAlias = Literal["dia"] _FmtOut: TypeAlias = Literal["coo", "csc", "csr"] _FmtNonBSR: TypeAlias = Literal["coo", "csc", "csr", "dia", "dok", "lil"] +_FmtNonCOO: TypeAlias = Literal["bsr", "csc", "csr", "dia", "dok", "lil"] _FmtNonCSR: TypeAlias = Literal["bsr", "coo", "csc", "dia", "dok", "lil"] _FmtNonDIA: TypeAlias = Literal["bsr", "coo", "csc", "csr", "dok", "lil"] _FmtNonOut: TypeAlias = Literal["bsr", "dia", "dok", "lil"] _DataRVS: TypeAlias = Callable[[int], onp.ArrayND[Scalar]] +_ToBlocks: TypeAlias = Seq[Seq[_spbase]] | onp.ArrayND[np.object_] + @type_check_only class _DataSampler(Protocol): def __call__(self, /, *, size: int) -> onp.ArrayND[Scalar]: ... @@ -773,10 +788,106 @@ def vstack(blocks: Seq[spmatrix], format: _FmtOut | None = None, *, dtype: npt.D @overload # spmatrix, format: , dtype: def vstack(blocks: Seq[spmatrix], format: _FmtNonOut, dtype: npt.DTypeLike) -> _SpMatrixNonOut: ... +_COOArray2D: TypeAlias = coo_array[_SCT, tuple[int, int]] + # -def bmat(blocks: Untyped, format: SPFormat | None = None, dtype: npt.DTypeLike | None = None) -> _SpArray | _SpMatrix: ... -def block_array(blocks: Untyped, *, format: SPFormat | None = None, dtype: npt.DTypeLike | None = None) -> _SpArray: ... -def block_diag(mats: Untyped, format: SPFormat | None = None, dtype: npt.DTypeLike | None = None) -> _SpArray | _SpMatrix: ... +@overload # blocks: , format: , dtype: +def block_array(blocks: Seq[Seq[_spbase[_SCT]]], *, format: _FmtCOO | None = None, dtype: None = None) -> _COOArray2D[_SCT]: ... +@overload # blocks: , format: , dtype: +def block_array(blocks: _ToBlocks, *, format: _FmtCOO | None = None, dtype: ToDType[_SCT]) -> _COOArray2D[_SCT]: ... +@overload # blocks: , format: , dtype: +def block_array(blocks: _ToBlocks, *, format: _FmtCOO | None = None, dtype: npt.DTypeLike) -> _COOArray2D: ... +@overload # blocks: , format: , dtype: +def block_array(blocks: Seq[Seq[_spbase[_SCT]]], *, format: _FmtNonCOO, dtype: None = None) -> _NonCOOArray[_SCT]: ... +@overload # blocks: , format: , dtype: +def block_array(blocks: _ToBlocks, *, format: _FmtNonCOO, dtype: ToDType[_SCT]) -> _NonCOOArray[_SCT]: ... +@overload # blocks: , format: , dtype: +def block_array(blocks: _ToBlocks, *, format: _FmtNonCOO, dtype: npt.DTypeLike) -> _NonCOOArray: ... + +# +@overload # blocks: , format: , dtype: +def bmat(blocks: Seq[Seq[_SpArray[_SCT]]], format: _FmtCOO | None = None, dtype: None = None) -> _COOArray2D[_SCT]: ... +@overload # blocks: , format: , dtype: +def bmat(blocks: Seq[Seq[spmatrix[_SCT]]], format: _FmtCOO | None = None, dtype: None = None) -> coo_matrix[_SCT]: ... +@overload # sparray, blocks: , format: , dtype: (positional) +def bmat(blocks: _ToBlocks, format: _FmtCOO | None, dtype: ToDType[_SCT]) -> _COOArray2D[_SCT] | coo_matrix[_SCT]: ... +@overload # sparray, blocks: , format: , dtype: (keyword) +def bmat(blocks: _ToBlocks, format: _FmtCOO | None = None, *, dtype: ToDType[_SCT]) -> _COOArray2D[_SCT] | coo_matrix[_SCT]: ... +@overload # sparray, blocks: , format: , dtype: +def bmat(blocks: _ToBlocks, format: _FmtCOO | None = None, dtype: npt.DTypeLike | None = None) -> _COOArray2D | coo_matrix: ... +@overload # sparray, blocks: , format: , dtype: +def bmat(blocks: Seq[Seq[_SpArray[_SCT]]], format: _FmtNonCOO, dtype: None = None) -> _NonCOOArray[_SCT]: ... +@overload # sparray, blocks: , format: , dtype: +def bmat(blocks: Seq[Seq[spmatrix[_SCT]]], format: _FmtNonCOO, dtype: None = None) -> _NonCOOMatrix[_SCT]: ... +@overload # sparray, blocks: , format: , dtype: +def bmat(blocks: _ToBlocks, format: _FmtNonCOO, dtype: ToDType[_SCT]) -> _NonCOOArray[_SCT] | _NonCOOMatrix[_SCT]: ... +@overload # sparray, blocks: , format: , dtype: +def bmat(blocks: _ToBlocks, format: _FmtNonCOO, dtype: npt.DTypeLike) -> _NonCOOArray | _NonCOOMatrix: ... + +# +@overload # mats: +def block_diag( + mats: Iterable[_SpArray[_SCT]], + format: SPFormat | None = None, + dtype: None = None, +) -> _SpArray[_SCT, tuple[int, int]]: ... +@overload # mats: +def block_diag( + mats: Iterable[spmatrix[_SCT]], + format: SPFormat | None = None, + dtype: None = None, +) -> _SpMatrix[_SCT]: ... +@overload # mats: +def block_diag( + mats: Iterable[_spbase[_SCT] | onp.ArrayND[_SCT]], + format: SPFormat | None = None, + dtype: None = None, +) -> _SpArray[_SCT, tuple[int, int]] | _SpMatrix[_SCT]: ... +@overload # mats: , dtype: (positional) +def block_diag( + mats: Iterable[sparray], + format: SPFormat | None, + dtype: ToDType[_SCT], +) -> _SpArray[_SCT, tuple[int, int]]: ... +@overload # mats: , dtype: (keyword) +def block_diag( + mats: Iterable[sparray], + format: SPFormat | None = None, + *, + dtype: ToDType[_SCT], +) -> _SpArray[_SCT, tuple[int, int]]: ... +@overload # mats: , dtype: (positional) +def block_diag( + mats: Iterable[spmatrix | onp.ArrayND[Scalar] | complex | list[onp.ToComplex] | list[onp.ToComplex1D]], + format: SPFormat | None, + dtype: ToDType[_SCT], +) -> _SpMatrix[_SCT]: ... +@overload # mats: , dtype: (keyword) +def block_diag( + mats: Iterable[spmatrix | onp.ArrayND[Scalar] | complex | list[onp.ToComplex] | list[onp.ToComplex1D]], + format: SPFormat | None = None, + *, + dtype: ToDType[_SCT], +) -> _SpMatrix[_SCT]: ... +@overload # mats: , dtype: (positional) +def block_diag( + mats: Iterable[_spbase | onp.ArrayND[Scalar] | complex | list[onp.ToComplex] | list[onp.ToComplex1D]], + format: SPFormat | None, + dtype: ToDType[_SCT], +) -> _SpArray[_SCT, tuple[int, int]] | _SpMatrix[_SCT]: ... +@overload # mats: , dtype: (keyword) +def block_diag( + mats: Iterable[_spbase | onp.ArrayND[Scalar] | complex | list[onp.ToComplex] | list[onp.ToComplex1D]], + format: SPFormat | None = None, + *, + dtype: ToDType[_SCT], +) -> _SpArray[_SCT, tuple[int, int]] | _SpMatrix[_SCT]: ... +@overload # catch-all +def block_diag( + mats: Iterable[_spbase | onp.ArrayND[Scalar] | complex | list[onp.ToComplex] | list[onp.ToComplex1D]], + format: SPFormat | None = None, + dtype: npt.DTypeLike | None = None, +) -> _SpArray[Any, tuple[int, int]] | _SpMatrix[Any]: ... # @overload # shape: 1d, dtype: