Skip to content

Commit

Permalink
Add *(::Diagonal, ::Diagonal, ::Diagonal) (#49005) (#49007)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlfivefifty authored Mar 16, 2023
1 parent 669d6ca commit c37fc27
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
6 changes: 6 additions & 0 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,12 @@ function (*)(Da::Diagonal, A::AbstractMatrix, Db::Diagonal)
return broadcast(*, Da.diag, A, permutedims(Db.diag))
end

function (*)(Da::Diagonal, Db::Diagonal, Dc::Diagonal)
_muldiag_size_check(Da, Db)
_muldiag_size_check(Db, Dc)
return Diagonal(Da.diag .* Db.diag .* Dc.diag)
end

# Get ambiguous method if try to unify AbstractVector/AbstractMatrix here using AbstractVecOrMat
@inline mul!(out::AbstractVector, D::Diagonal, V::AbstractVector, alpha::Number, beta::Number) =
_muldiag!(out, D, V, alpha, beta)
Expand Down
11 changes: 11 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1155,4 +1155,15 @@ end
@test all(isone, diag(D))
end

@testset "diagonal triple multiplication (#49005)" begin
n = 10
@test *(Diagonal(ones(n)), Diagonal(1:n), Diagonal(ones(n))) isa Diagonal
@test_throws DimensionMismatch (*(Diagonal(ones(n)), Diagonal(1:n), Diagonal(ones(n+1))))
@test_throws DimensionMismatch (*(Diagonal(ones(n)), Diagonal(1:n+1), Diagonal(ones(n+1))))
@test_throws DimensionMismatch (*(Diagonal(ones(n+1)), Diagonal(1:n), Diagonal(ones(n))))

# currently falls back to two-term *
@test *(Diagonal(ones(n)), Diagonal(1:n), Diagonal(ones(n)), Diagonal(1:n)) isa Diagonal
end

end # module TestDiagonal

0 comments on commit c37fc27

Please sign in to comment.