Skip to content

Commit

Permalink
Generalize Diagonal * AdjOrTransAbsMat to arbitrary element types (#5…
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub authored Dec 5, 2023
1 parent 8d0eec9 commit c1ca0d3
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 11 deletions.
9 changes: 0 additions & 9 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -320,15 +320,6 @@ end
rmul!(A::AbstractMatrix, D::Diagonal) = @inline mul!(A, A, D)
lmul!(D::Diagonal, B::AbstractVecOrMat) = @inline mul!(B, D, B)

function (*)(A::AdjOrTransAbsMat, D::Diagonal)
Ac = copy_similar(A, promote_op(*, eltype(A), eltype(D.diag)))
rmul!(Ac, D)
end
function (*)(D::Diagonal, A::AdjOrTransAbsMat)
Ac = copy_similar(A, promote_op(*, eltype(A), eltype(D.diag)))
lmul!(D, Ac)
end

function __muldiag!(out, D::Diagonal, B, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
require_one_based_indexing(out, B)
alpha, beta = _add.alpha, _add.beta
Expand Down
8 changes: 8 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ using .Main.InfiniteArrays
isdefined(Main, :FillArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "FillArrays.jl"))
using .Main.FillArrays

isdefined(Main, :SizedArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "SizedArrays.jl"))
using .Main.SizedArrays

const n=12 # Size of matrix problem to test
Random.seed!(1)

Expand Down Expand Up @@ -778,6 +781,11 @@ end
D = Diagonal(fill(M, n))
@test D == Matrix{eltype(D)}(D)
end

S = SizedArray{(2,3)}(reshape([1:6;],2,3))
D = Diagonal(fill(S,3))
@test D * fill(S,2,3)' == fill(S * S', 3, 2)
@test fill(S,3,2)' * D == fill(S' * S, 2, 3)
end

@testset "Eigensystem for block diagonal (issue #30681)" begin
Expand Down
13 changes: 11 additions & 2 deletions test/testhelpers/SizedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ module SizedArrays

import Base: +, *, ==

using LinearAlgebra

export SizedArray

struct SizedArray{SZ,T,N,A<:AbstractArray} <: AbstractArray{T,N}
Expand All @@ -31,9 +33,16 @@ Base.getindex(A::SizedArray, i...) = getindex(A.data, i...)
Base.zero(::Type{T}) where T <: SizedArray = SizedArray{size(T)}(zeros(eltype(T), size(T)))
+(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = SizedArray{SZ}(S1.data + S2.data)
==(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = S1.data == S2.data
function *(S1::SizedArray, S2::SizedArray)

const SizedArrayLike = Union{SizedArray, Transpose{<:Any, <:SizedArray}, Adjoint{<:Any, <:SizedArray}}

_data(S::SizedArray) = S.data
_data(T::Transpose{<:Any, <:SizedArray}) = transpose(_data(parent(T)))
_data(T::Adjoint{<:Any, <:SizedArray}) = adjoint(_data(parent(T)))

function *(S1::SizedArrayLike, S2::SizedArrayLike)
0 < ndims(S1) < 3 && 0 < ndims(S2) < 3 && size(S1, 2) == size(S2, 1) || throw(ArgumentError("size mismatch!"))
data = S1.data * S2.data
data = _data(S1) * _data(S2)
SZ = ndims(data) == 1 ? (size(S1, 1), ) : (size(S1, 1), size(S2, 2))
SizedArray{SZ}(data)
end
Expand Down

0 comments on commit c1ca0d3

Please sign in to comment.