From 0b635925272c59aa93bc3e9828717e2aa37221e4 Mon Sep 17 00:00:00 2001 From: Joren Hammudoglu Date: Sun, 15 Dec 2024 02:51:18 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20`sparse`:=20complete=20`kron[sum]`?= =?UTF-8?q?=20(#311)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scipy-stubs/sparse/_construct.pyi | 101 ++++++++++++++++++++++++------ 1 file changed, 82 insertions(+), 19 deletions(-) diff --git a/scipy-stubs/sparse/_construct.pyi b/scipy-stubs/sparse/_construct.pyi index 9574a9d9..c217ed59 100644 --- a/scipy-stubs/sparse/_construct.pyi +++ b/scipy-stubs/sparse/_construct.pyi @@ -1,25 +1,20 @@ +from collections.abc import Sequence from typing import Any, Literal, TypeAlias, TypeVar, overload import numpy as np import numpy.typing as npt +import optype.numpy as onp import optype.typing as opt from scipy._typing import Seed, Untyped, UntypedCallable -from scipy.sparse import ( - bsr_array, - bsr_matrix, - coo_array, - coo_matrix, - csc_array, - csc_matrix, - csr_array, - csr_matrix, - dia_array, - dia_matrix, - dok_array, - dok_matrix, - lil_array, - lil_matrix, -) +from ._base import _spbase +from ._bsr import bsr_array, bsr_matrix +from ._coo import coo_array, coo_matrix +from ._csc import csc_array, csc_matrix +from ._csr import csr_array, csr_matrix +from ._dia import dia_array, dia_matrix +from ._dok import dok_array, dok_matrix +from ._lil import lil_array, lil_matrix +from ._matrix import spmatrix from ._typing import Float, Scalar, SPFormat, ToDType, ToDTypeBool, ToDTypeComplex, ToDTypeFloat, ToDTypeInt, ToShape, ToShape2D __all__ = [ @@ -42,8 +37,12 @@ __all__ = [ ] _SCT = TypeVar("_SCT", bound=Scalar, default=Any) +_SCT1 = TypeVar("_SCT1", bound=Scalar) +_SCT2 = TypeVar("_SCT2", bound=Scalar) _ShapeT = TypeVar("_ShapeT", bound=tuple[int] | tuple[int, int], default=tuple[int] | tuple[int, int]) +_ToMatrix: TypeAlias = spmatrix[_SCT] | Sequence[Sequence[_SCT] | onp.CanArrayND[_SCT]] | onp.CanArrayND[_SCT] + _SpMatrix: TypeAlias = ( bsr_matrix[_SCT] | coo_matrix[_SCT] @@ -53,7 +52,6 @@ _SpMatrix: TypeAlias = ( | dok_matrix[_SCT] | lil_matrix[_SCT] ) - _SpArray: TypeAlias = ( bsr_array[_SCT] | coo_array[_SCT, _ShapeT] @@ -66,7 +64,36 @@ _SpArray: TypeAlias = ( _SpArray1D: TypeAlias = coo_array[_SCT, tuple[int]] | csr_array[_SCT, tuple[int]] | dok_array[_SCT, tuple[int]] _SpArray2D: TypeAlias = _SpArray[_SCT, tuple[int, int]] +_BSRArray: TypeAlias = bsr_array[_SCT] +_CSRArray: TypeAlias = csr_array[_SCT, tuple[int, int]] +_NonBSRArray: TypeAlias = ( + coo_array[_SCT, tuple[int, int]] + | csc_array[_SCT] + | csr_array[_SCT, tuple[int, int]] + | dia_array[_SCT] + | dok_array[_SCT, tuple[int, int]] + | lil_array[_SCT] +) +_NonCSRArray: TypeAlias = ( + bsr_array[_SCT] + | coo_array[_SCT, tuple[int, int]] + | csc_array[_SCT] + | dia_array[_SCT] + | dok_array[_SCT, tuple[int, int]] + | lil_array[_SCT] +) +_NonBSRMatrix: TypeAlias = ( + coo_matrix[_SCT] | csr_matrix[_SCT] | csc_matrix[_SCT] | dia_matrix[_SCT] | dok_matrix[_SCT] | lil_matrix[_SCT] +) +_NonCSRMatrix: TypeAlias = ( + bsr_matrix[_SCT] | coo_matrix[_SCT] | csc_matrix[_SCT] | dia_matrix[_SCT] | dok_matrix[_SCT] | lil_matrix[_SCT] +) + +_FmtBSR: TypeAlias = Literal["bsr"] +_FmtCSR: TypeAlias = Literal["csr"] _FmtDIA: TypeAlias = Literal["dia"] +_FmtNonBSR: TypeAlias = Literal["coo", "csc", "csr", "dia", "dok", "lil"] +_FmtNonCSR: TypeAlias = Literal["bsr", "coo", "csc", "dia", "dok", "lil"] _FmtNonDIA: TypeAlias = Literal["bsr", "coo", "csc", "csr", "dok", "lil"] ### @@ -334,8 +361,44 @@ def eye( ) -> _SpMatrix: ... # -def kron(A: Untyped, B: Untyped, format: SPFormat | None = None) -> _SpArray2D | _SpMatrix: ... -def kronsum(A: Untyped, B: Untyped, format: SPFormat | None = None) -> _SpArray2D | _SpMatrix: ... +@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: {"bsr", None} = ... +def kron(A: _ToMatrix[_SCT1], B: _ToMatrix[_SCT2], format: _FmtBSR | None = None) -> bsr_matrix[_SCT1 | _SCT2]: ... +@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: +def kron(A: _ToMatrix[_SCT1], B: _ToMatrix[_SCT2], format: _FmtNonBSR) -> _NonBSRMatrix[_SCT1 | _SCT2]: ... +@overload # A: sparray, B: sparse, format: {"bsr", None} = ... +def kron(A: _SpArray[_SCT1], B: _spbase[_SCT2], format: _FmtBSR | None = None) -> _BSRArray[_SCT1 | _SCT2]: ... +@overload # A: sparray, B: sparse, format: +def kron(A: _SpArray[_SCT1], B: _spbase[_SCT2], format: _FmtNonBSR) -> _NonBSRArray[_SCT1 | _SCT2]: ... +@overload # A: sparse, B: sparray, format: {"bsr", None} = ... +def kron(A: _spbase[_SCT1], B: _SpArray[_SCT2], format: _FmtBSR | None = None) -> _BSRArray[_SCT1 | _SCT2]: ... +@overload # A: sparse, B: sparray, format: +def kron(A: _spbase[_SCT1], B: _SpArray[_SCT2], format: _FmtNonBSR) -> _NonBSRArray[_SCT1 | _SCT2]: ... +@overload # A: unknown array-like, B: unknown array-like (catch-all) +def kron( + A: onp.ToComplex2D | _spbase[_SCT], + B: onp.ToComplex2D | _spbase[_SCT], + format: SPFormat | None = None, +) -> _SpArray2D[_SCT] | _SpMatrix[_SCT]: ... + +# +@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: {"csr", None} = ... +def kronsum(A: _ToMatrix[_SCT1], B: _ToMatrix[_SCT2], format: _FmtCSR | None = None) -> csr_matrix[_SCT1 | _SCT2]: ... +@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: +def kronsum(A: _ToMatrix[_SCT1], B: _ToMatrix[_SCT2], format: _FmtNonCSR) -> _NonCSRMatrix[_SCT1 | _SCT2]: ... +@overload # A: sparray, B: sparse, format: {"csr", None} = ... +def kronsum(A: _SpArray[_SCT1], B: _spbase[_SCT2], format: _FmtCSR | None = None) -> _CSRArray[_SCT1 | _SCT2]: ... +@overload # A: sparray, B: sparse, format: +def kronsum(A: _SpArray[_SCT1], B: _spbase[_SCT2], format: _FmtNonCSR) -> _NonCSRArray[_SCT1 | _SCT2]: ... +@overload # A: sparse, B: sparray, format: {"csr", None} = ... +def kronsum(A: _spbase[_SCT1], B: _SpArray[_SCT2], format: _FmtCSR | None = None) -> _CSRArray[_SCT1 | _SCT2]: ... +@overload # A: sparse, B: sparray, format: +def kronsum(A: _spbase[_SCT1], B: _SpArray[_SCT2], format: _FmtNonCSR) -> _NonCSRArray[_SCT1 | _SCT2]: ... +@overload # A: unknown array-like, B: unknown array-like (catch-all) +def kronsum( + A: onp.ToComplex2D | _spbase[_SCT], + B: onp.ToComplex2D | _spbase[_SCT], + format: SPFormat | None = None, +) -> _SpArray2D[_SCT] | _SpMatrix[_SCT]: ... # def hstack(blocks: Untyped, format: SPFormat | None = None, dtype: npt.DTypeLike | None = None) -> _SpArray | _SpMatrix: ...