Skip to content

Commit

Permalink
sparse: complete [sp]diags[_array] (#314)
Browse files Browse the repository at this point in the history
  • Loading branch information
jorenham authored Dec 15, 2024
1 parent aac20bc commit d9e34eb
Showing 1 changed file with 267 additions and 23 deletions.
290 changes: 267 additions & 23 deletions scipy-stubs/sparse/_construct.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ _SCT1 = TypeVar("_SCT1", bound=Scalar)
_SCT2 = TypeVar("_SCT2", bound=Scalar)
_ShapeT = TypeVar("_ShapeT", bound=tuple[int] | tuple[int, int], default=tuple[int] | tuple[int, int])

_ToMatrix: TypeAlias = spmatrix[_SCT] | Seq[Seq[_SCT] | onp.CanArrayND[_SCT]] | onp.CanArrayND[_SCT]
_ToArray1D: TypeAlias = Seq[_SCT] | onp.CanArrayND[_SCT]
_ToArray2D: TypeAlias = Seq[Seq[_SCT] | onp.CanArrayND[_SCT]] | onp.CanArrayND[_SCT]
_ToSpMatrix: TypeAlias = spmatrix[_SCT] | _ToArray2D[_SCT]

_SpMatrix: TypeAlias = (
bsr_matrix[_SCT]
Expand Down Expand Up @@ -82,12 +84,23 @@ _NonCSRArray: TypeAlias = (
| dok_array[_SCT, tuple[int, int]]
| lil_array[_SCT]
)
_NonDIAArray: TypeAlias = (
bsr_array[_SCT]
| coo_array[_SCT, tuple[int, int]]
| csc_array[_SCT]
| csr_array[_SCT, tuple[int, int]]
| dok_array[_SCT, tuple[int, int]]
| lil_array[_SCT]
)
_NonBSRMatrix: TypeAlias = (
coo_matrix[_SCT] | csr_matrix[_SCT] | csc_matrix[_SCT] | dia_matrix[_SCT] | dok_matrix[_SCT] | lil_matrix[_SCT]
)
_NonCSRMatrix: TypeAlias = (
bsr_matrix[_SCT] | coo_matrix[_SCT] | csc_matrix[_SCT] | dia_matrix[_SCT] | dok_matrix[_SCT] | lil_matrix[_SCT]
)
_NonDIAMatrix: TypeAlias = (
bsr_matrix[_SCT] | coo_matrix[_SCT] | csc_matrix[_SCT] | csr_matrix[_SCT] | dok_matrix[_SCT] | lil_matrix[_SCT]
)

_SpMatrixOut: TypeAlias = coo_matrix[_SCT] | csc_matrix[_SCT] | csr_matrix[_SCT]
_SpMatrixNonOut: TypeAlias = bsr_matrix[_SCT] | dia_matrix[_SCT] | dok_matrix[_SCT] | lil_matrix[_SCT]
Expand All @@ -106,32 +119,263 @@ _FmtNonOut: TypeAlias = Literal["bsr", "dia", "dok", "lil"]
###

#
def spdiags(
data: Untyped,
diags: Untyped,
m: Untyped | None = None,
n: Untyped | None = None,
format: SPFormat | None = None,
) -> Untyped: ...

#
@overload # diagonals: <known>, dtype: None = ..., format: {"dia", None} = ...
def diags_array(
diagonals: Untyped,
diagonals: _ToArray1D[_SCT] | _ToArray2D[_SCT],
/,
*,
offsets: int = 0,
shape: ToShape | None = None,
format: SPFormat | None = None,
offsets: onp.ToInt | onp.ToInt1D = 0,
shape: ToShape2D | None = None,
format: _FmtDIA | None = None,
dtype: None = None,
) -> dia_array[_SCT]: ...
@overload # diagonals: <known>, dtype: None = ..., format: <otherwise>
def diags_array(
diagonals: _ToArray1D[_SCT] | _ToArray2D[_SCT],
/,
*,
offsets: onp.ToInt | onp.ToInt1D = 0,
shape: ToShape2D | None = None,
format: _FmtNonDIA,
dtype: None = None,
) -> _NonDIAArray[_SCT]: ...
@overload # diagonals: <unknown>, format: {"dia", None} = ..., dtype: int
def diags_array(
diagonals: onp.ToFloat1D | onp.ToFloat2D,
/,
*,
offsets: onp.ToInt | onp.ToInt1D = 0,
shape: ToShape2D | None = None,
format: _FmtDIA | None = None,
dtype: ToDTypeInt,
) -> dia_array[np.int_]: ...
@overload # diagonals: <unknown>, format: <otherwise>, dtype: int
def diags_array(
diagonals: onp.ToFloat1D | onp.ToFloat2D,
/,
*,
offsets: onp.ToInt | onp.ToInt1D = 0,
shape: ToShape2D | None = None,
format: _FmtNonDIA,
dtype: ToDTypeInt,
) -> _NonDIAArray[np.int_]: ...
@overload # diagonals: <unknown>, format: {"dia", None} = ..., dtype: float
def diags_array(
diagonals: onp.ToFloat1D | onp.ToFloat2D,
/,
*,
offsets: onp.ToInt | onp.ToInt1D = 0,
shape: ToShape2D | None = None,
format: _FmtDIA | None = None,
dtype: ToDTypeFloat,
) -> dia_array[np.float64]: ...
@overload # diagonals: <unknown>, format: <otherwise>, dtype: float
def diags_array(
diagonals: onp.ToFloat1D | onp.ToFloat2D,
/,
*,
offsets: onp.ToInt | onp.ToInt1D = 0,
shape: ToShape2D | None = None,
format: _FmtNonDIA,
dtype: ToDTypeFloat,
) -> _NonDIAArray[np.float64]: ...
@overload # diagonals: <unknown>, format: {"dia", None} = ..., dtype: complex
def diags_array(
diagonals: onp.ToComplex1D | onp.ToComplex2D,
/,
*,
offsets: onp.ToInt | onp.ToInt1D = 0,
shape: ToShape2D | None = None,
format: _FmtDIA | None = None,
dtype: ToDTypeComplex,
) -> dia_array[np.complex128]: ...
@overload # diagonals: <unknown>, format: <otherwise>, dtype: complex
def diags_array(
diagonals: onp.ToComplex1D | onp.ToComplex2D,
/,
*,
offsets: onp.ToInt | onp.ToInt1D = 0,
shape: ToShape2D | None = None,
format: _FmtNonDIA,
dtype: ToDTypeComplex,
) -> _NonDIAArray[np.complex128]: ...
@overload # diagonals: <unknown>, format: {"dia", None} = ..., dtype: <known>
def diags_array(
diagonals: onp.ToComplex1D | onp.ToComplex2D,
/,
*,
offsets: onp.ToInt | onp.ToInt1D = 0,
shape: ToShape2D | None = None,
format: _FmtDIA | None = None,
dtype: ToDType[_SCT],
) -> dia_array[_SCT]: ...
@overload # diagonals: <unknown>, format: <otherwise>, dtype: <known>
def diags_array(
diagonals: onp.ToComplex1D | onp.ToComplex2D,
/,
*,
offsets: onp.ToInt | onp.ToInt1D = 0,
shape: ToShape2D | None = None,
format: _FmtNonDIA,
dtype: ToDType[_SCT],
) -> _NonDIAArray[_SCT]: ...
@overload # diagonals: <unknown>, format: {"dia", None} = ..., dtype: <unknown>
def diags_array(
diagonals: onp.ToComplex1D | onp.ToComplex2D,
/,
*,
offsets: onp.ToInt | onp.ToInt1D = 0,
shape: ToShape2D | None = None,
format: _FmtDIA | None = None,
dtype: npt.DTypeLike | None = None,
) -> _SpArray1D | _SpArray2D: ...
) -> dia_array: ...
@overload # diagonals: <unknown>, format: <otherwise>, dtype: <unknown>
def diags_array(
diagonals: onp.ToComplex1D | onp.ToComplex2D,
/,
*,
offsets: onp.ToInt | onp.ToInt1D = 0,
shape: ToShape2D | None = None,
format: _FmtNonDIA,
dtype: npt.DTypeLike | None = None,
) -> _NonDIAArray: ...

#
# NOTE: `diags_array` should be prefered over `diags`
@overload # diagonals: <known>, format: {"dia", None} = ...
def diags(
diagonals: Untyped,
offsets: int = 0,
diagonals: _ToArray1D[_SCT] | _ToArray2D[_SCT],
offsets: onp.ToInt | onp.ToInt1D = 0,
shape: ToShape2D | None = None,
format: SPFormat | None = None,
format: _FmtDIA | None = None,
dtype: ToDType[_SCT] | None = None,
) -> dia_array[_SCT]: ...
@overload # diagonals: <known>, format: <otherwise> (positional)
def diags(
diagonals: _ToArray1D[_SCT] | _ToArray2D[_SCT],
offsets: onp.ToInt | onp.ToInt1D,
shape: ToShape2D | None,
format: _FmtNonDIA,
dtype: ToDType[_SCT] | None = None,
) -> _NonDIAArray[_SCT]: ...
@overload # diagonals: <known>, format: <otherwise> (keyword)
def diags(
diagonals: _ToArray1D[_SCT] | _ToArray2D[_SCT],
offsets: onp.ToInt | onp.ToInt1D = 0,
shape: ToShape2D | None = None,
*,
format: _FmtNonDIA,
dtype: ToDType[_SCT] | None = None,
) -> _NonDIAArray[_SCT]: ...
@overload # diagonals: <unknown>, format: {"dia", None} = ..., dtype: <known> (positional)
def diags(
diagonals: onp.ToComplex1D | onp.ToComplex2D,
offsets: onp.ToInt | onp.ToInt1D,
shape: ToShape2D | None,
format: _FmtDIA | None,
dtype: ToDType[_SCT],
) -> dia_array[_SCT]: ...
@overload # diagonals: <unknown>, format: {"dia", None} = ..., dtype: <known> (keyword)
def diags(
diagonals: onp.ToComplex1D | onp.ToComplex2D,
offsets: onp.ToInt | onp.ToInt1D = 0,
shape: ToShape2D | None = None,
format: _FmtDIA | None = None,
*,
dtype: ToDType[_SCT],
) -> dia_array[_SCT]: ...
@overload # diagonals: <unknown>, format: <otherwise> (positional), dtype: <known>
def diags(
diagonals: onp.ToComplex1D | onp.ToComplex2D,
offsets: onp.ToInt | onp.ToInt1D,
shape: ToShape2D | None,
format: _FmtNonDIA,
dtype: ToDType[_SCT],
) -> _NonDIAArray[_SCT]: ...
@overload # diagonals: <unknown>, format: <otherwise> (keyword), dtype: <known>
def diags(
diagonals: onp.ToComplex1D | onp.ToComplex2D,
offsets: onp.ToInt | onp.ToInt1D = 0,
shape: ToShape2D | None = None,
*,
format: _FmtNonDIA,
dtype: ToDType[_SCT],
) -> _NonDIAArray[_SCT]: ...
@overload # diagonals: <unknown>, format: {"dia", None} = ..., dtype: <unknown>
def diags(
diagonals: onp.ToComplex1D | onp.ToComplex2D,
offsets: onp.ToInt | onp.ToInt1D = 0,
shape: ToShape2D | None = None,
format: _FmtDIA | None = None,
dtype: npt.DTypeLike | None = None,
) -> dia_array: ...
@overload # diagonals: <unknown>, format: <otherwise> (positional), dtype: <unknown>
def diags(
diagonals: onp.ToComplex1D | onp.ToComplex2D,
offsets: onp.ToInt | onp.ToInt1D,
shape: ToShape2D | None,
format: _FmtNonDIA,
dtype: npt.DTypeLike | None = None,
) -> _NonDIAArray: ...
@overload # diagonals: <unknown>, format: <otherwise> (keyword), dtype: <unknown>
def diags(
diagonals: onp.ToComplex1D | onp.ToComplex2D,
offsets: onp.ToInt | onp.ToInt1D = 0,
shape: ToShape2D | None = None,
*,
format: _FmtNonDIA,
dtype: npt.DTypeLike | None = None,
) -> _NonDIAArray: ...

# NOTE: `diags_array` should be prefered over `spdiags`
@overload
def spdiags(
data: _ToArray1D[_SCT] | _ToArray2D[_SCT],
diags: onp.ToInt | onp.ToInt1D,
m: onp.ToJustInt,
n: onp.ToJustInt,
format: _FmtDIA | None = None,
) -> dia_matrix[_SCT]: ...
@overload
def spdiags(
data: _ToArray1D[_SCT] | _ToArray2D[_SCT],
diags: onp.ToInt | onp.ToInt1D,
m: tuple[onp.ToJustInt, onp.ToJustInt] | None = None,
n: None = None,
format: _FmtDIA | None = None,
) -> dia_matrix[_SCT]: ...
@overload
def spdiags(
data: _ToArray1D[_SCT] | _ToArray2D[_SCT],
diags: onp.ToInt | onp.ToInt1D,
m: onp.ToJustInt,
n: onp.ToJustInt,
format: _FmtNonDIA,
) -> _NonDIAMatrix[_SCT]: ...
@overload
def spdiags(
data: _ToArray1D[_SCT] | _ToArray2D[_SCT],
diags: onp.ToInt | onp.ToInt1D,
m: tuple[onp.ToJustInt, onp.ToJustInt] | None = None,
n: None = None,
*,
format: _FmtNonDIA,
) -> _NonDIAMatrix[_SCT]: ...
@overload
def spdiags(
data: Seq[complex] | Seq[Seq[complex]],
diags: onp.ToInt | onp.ToInt1D,
m: onp.ToJustInt,
n: onp.ToJustInt,
format: SPFormat | None = None,
) -> _SpMatrix: ...
@overload
def spdiags(
data: Seq[complex] | Seq[Seq[complex]],
diags: onp.ToInt | onp.ToInt1D,
m: tuple[onp.ToJustInt, onp.ToJustInt] | None = None,
n: None = None,
*,
format: SPFormat | None = None,
) -> _SpMatrix: ...

#
Expand Down Expand Up @@ -369,9 +613,9 @@ def eye(

#
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: {"bsr", None} = ...
def kron(A: _ToMatrix[_SCT1], B: _ToMatrix[_SCT2], format: _FmtBSR | None = None) -> bsr_matrix[_SCT1 | _SCT2]: ...
def kron(A: _ToSpMatrix[_SCT1], B: _ToSpMatrix[_SCT2], format: _FmtBSR | None = None) -> bsr_matrix[_SCT1 | _SCT2]: ...
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: <otherwise>
def kron(A: _ToMatrix[_SCT1], B: _ToMatrix[_SCT2], format: _FmtNonBSR) -> _NonBSRMatrix[_SCT1 | _SCT2]: ...
def kron(A: _ToSpMatrix[_SCT1], B: _ToSpMatrix[_SCT2], format: _FmtNonBSR) -> _NonBSRMatrix[_SCT1 | _SCT2]: ...
@overload # A: sparray, B: sparse, format: {"bsr", None} = ...
def kron(A: _SpArray[_SCT1], B: _spbase[_SCT2], format: _FmtBSR | None = None) -> _BSRArray[_SCT1 | _SCT2]: ...
@overload # A: sparray, B: sparse, format: <otherwise>
Expand All @@ -389,9 +633,9 @@ def kron(

#
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: {"csr", None} = ...
def kronsum(A: _ToMatrix[_SCT1], B: _ToMatrix[_SCT2], format: _FmtCSR | None = None) -> csr_matrix[_SCT1 | _SCT2]: ...
def kronsum(A: _ToSpMatrix[_SCT1], B: _ToSpMatrix[_SCT2], format: _FmtCSR | None = None) -> csr_matrix[_SCT1 | _SCT2]: ...
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: <otherwise>
def kronsum(A: _ToMatrix[_SCT1], B: _ToMatrix[_SCT2], format: _FmtNonCSR) -> _NonCSRMatrix[_SCT1 | _SCT2]: ...
def kronsum(A: _ToSpMatrix[_SCT1], B: _ToSpMatrix[_SCT2], format: _FmtNonCSR) -> _NonCSRMatrix[_SCT1 | _SCT2]: ...
@overload # A: sparray, B: sparse, format: {"csr", None} = ...
def kronsum(A: _SpArray[_SCT1], B: _spbase[_SCT2], format: _FmtCSR | None = None) -> _CSRArray[_SCT1 | _SCT2]: ...
@overload # A: sparray, B: sparse, format: <otherwise>
Expand Down

0 comments on commit d9e34eb

Please sign in to comment.