Skip to content

Commit

Permalink
Fix performance issue with diagonal multiplication
Browse files Browse the repository at this point in the history
Co-authored-by: Dilum Aluthge <dilum@aluthge.com>
  • Loading branch information
dkarrasch and DilumAluthge committed Mar 18, 2022
1 parent 1e64682 commit 0c620d8
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 39 deletions.
111 changes: 74 additions & 37 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -276,38 +276,91 @@ function *(D::Diagonal, transA::Transpose{<:Any,<:AbstractMatrix})
lmul!(D, At)
end

@inline function __muldiag!(out, D::Diagonal, B, alpha, beta)
if iszero(beta)
out .= (D.diag .* B) .*ₛ alpha
# in __muldiag! below we unroll the loops manually, since broadcasting may be unable to
# prove that they are vectorizable
function __muldiag!(out, D::Diagonal, B, alpha, beta)
# TODO: check if this code can be replaced by a single line
# out .= (D.diag .* B) .*ₛ alpha .+ out .*ₛ beta
require_one_based_indexing(out)
if iszero(alpha)
_rmul_or_fill!(out, beta)
else
out .= (D.diag .* B) .*ₛ alpha .+ out .* beta
if iszero(beta)
@inbounds for j in axes(B, 2)
@simd for i in axes(B, 1)
out[i,j] = D.diag[i] * B[i,j] * alpha
end
end
else
@inbounds for j in axes(B, 2)
@simd for i in axes(B, 1)
out[i,j] = D.diag[i] * B[i,j] * alpha + out[i,j] * beta
end
end
end
end
return out
end

@inline function __muldiag!(out, A, D::Diagonal, alpha, beta)
if iszero(beta)
out .= (A .* permutedims(D.diag)) .*ₛ alpha
function __muldiag!(out, A, D::Diagonal, alpha, beta)
# TODO: check if this code can be replaced by a single line
# out .= (B .* permutedims(D.diag)) .*ₛ alpha .+ out .*ₛ beta
require_one_based_indexing(out)
if iszero(alpha)
_rmul_or_fill!(out, beta)
else
out .= (A .* permutedims(D.diag)) .*ₛ alpha .+ out .* beta
if iszero(beta)
@inbounds for j in axes(A, 2)
dja = D.diag[j] * alpha
@simd for i in axes(A, 1)
out[i,j] = A[i,j] * dja
end
end
else
@inbounds for j in axes(A, 2)
dja = D.diag[j] * alpha
@simd for i in axes(A, 1)
out[i,j] = A[i,j] * dja + out[i,j] * beta
end
end
end
end
return out
end

@inline function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, alpha, beta)
if iszero(beta)
out.diag .= (D1.diag .* D2.diag) .*ₛ alpha
function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, alpha, beta)
# TODO: check if this code can be replaced by a single line
# out.diag .= (D1.diag .* D2.diag) .*ₛ alpha .+ out.diag .*ₛ beta
d1 = D1.diag
d2 = D2.diag
if iszero(alpha)
_rmul_or_fill!(out.diag, beta)
else
out.diag .= (D1.diag .* D2.diag) .*ₛ alpha .+ out.diag .* beta
if iszero(beta)
@inbounds @simd for i in eachindex(out.diag)
out.diag[i] = d1[i] * d2[i] * alpha
end
else
@inbounds @simd for i in eachindex(out.diag)
out.diag[i] = d1[i] * d2[i] * alpha + out.diag[i] * beta
end
end
end
return out
end
function __muldiag!(out, D1::Diagonal, D2::Diagonal, alpha, beta)
require_one_based_indexing(out)
mA = size(D1, 1)
d1 = D1.diag
d2 = D2.diag
_rmul_or_fill!(out, beta)
if !iszero(alpha)
@inbounds @simd for i in 1:mA
out[i,i] += d1[i] * d2[i] * alpha
end
end
return out
end

# only needed for ambiguity resolution, as mul! is explicitly defined for these arguments
@inline __muldiag!(out, D1::Diagonal, D2::Diagonal, alpha, beta) =
mul!(out, D1, D2, alpha, beta)

@inline function _muldiag!(out, A, B, alpha, beta)
function _muldiag!(out, A, B, alpha, beta)
_muldiag_size_check(out, A, B)
__muldiag!(out, A, B, alpha, beta)
return out
Expand All @@ -332,24 +385,8 @@ end
@inline mul!(C::Diagonal, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) =
_muldiag!(C, Da, Db, alpha, beta)

function mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number)
_muldiag_size_check(C, Da, Db)
require_one_based_indexing(C)
mA = size(Da, 1)
da = Da.diag
db = Db.diag
_rmul_or_fill!(C, beta)
if iszero(beta)
@inbounds @simd for i in 1:mA
C[i,i] = Ref(da[i] * db[i]) .*ₛ alpha
end
else
@inbounds @simd for i in 1:mA
C[i,i] += Ref(da[i] * db[i]) .*ₛ alpha
end
end
return C
end
mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) =
_muldiag!(C, Da, Db, alpha, beta)

_init(op, A::AbstractArray{<:Number}, B::AbstractArray{<:Number}) =
(_ -> zero(typeof(op(oneunit(eltype(A)), oneunit(eltype(B))))))
Expand Down
3 changes: 1 addition & 2 deletions stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
# inside this function.
function *end
Broadcast.broadcasted(::typeof(*ₛ), out, beta) =
iszero(beta::Number) ? false :
isone(beta::Number) ? broadcasted(identity, out) : broadcasted(*, out, beta)
iszero(beta::Number) ? false : broadcasted(*, out, beta)

"""
MulAddMul(alpha, beta)
Expand Down

0 comments on commit 0c620d8

Please sign in to comment.