Skip to content

Commit

Permalink
Merge pull request #11 from N5N3/Tri_Diag_Unify
Browse files Browse the repository at this point in the history
Update
  • Loading branch information
N5N3 authored Sep 23, 2021
2 parents 7e0be45 + b519518 commit 272278b
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ function mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta
return C
end

#TODO: many of /, \ related function has no size check and singular check
#TODO: many of /, \ related function has no singular check
(/)(A::AbstractVecOrMat, D::Diagonal) =
rdiv!((typeof(oneunit(eltype(D))/oneunit(eltype(A)))).(A), D)
(/)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag ./ Db.diag)
Expand Down Expand Up @@ -345,6 +345,7 @@ end
(\)(D::Diagonal, b::AbstractVector) = D.diag .\ b
(\)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag .\ Db.diag)

#TODO: we should check size(x,2) == size(b,2)
ldiv!(x::AbstractVecOrMat, A::Diagonal, b::AbstractVecOrMat) = (x .= A.diag .\ b)

function ldiv!(D::Diagonal, B::AbstractVecOrMat)
Expand Down Expand Up @@ -548,11 +549,17 @@ function svd(D::Diagonal{<:Number})
return SVD(Up, S[piv], copy(Vp'))
end

# disambiguation methods: * of Diagonal and Adj/Trans AbsVec
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal) = Adjoint(map((t,s) -> t'*s, D.diag, parent(x)))
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal) = Transpose(map((t,s) -> transpose(t)*s, D.diag, parent(x)))
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y)
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y)
# disambiguation methods: * and / of Diagonal and Adj/Trans AbsVec
*(x::AdjointAbsVec, D::Diagonal) = Adjoint(map((t,s) -> t'*s, D.diag, parent(x)))
*(x::TransposeAbsVec, D::Diagonal) = Transpose(map((t,s) -> transpose(t)*s, D.diag, parent(x)))
*(x::AdjointAbsVec, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y)
*(x::TransposeAbsVec, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y)
/(u::AdjointAbsVec, D::Diagonal) = adjoint(adjoint(D) \ u.parent)
/(u::TransposeAbsVec, D::Diagonal) = transpose(transpose(D) \ u.parent)
# disambiguation methods: Call unoptimized version for user defined AbstractTriangular.
*(A::AbstractTriangular, D::Diagonal) = Base.@invoke *(A::AbstractMatrix, D::Diagonal)
*(D::Diagonal, A::AbstractTriangular) = Base.@invoke *(D::Diagonal, A::AbstractMatrix)

dot(x::AbstractVector, D::Diagonal, y::AbstractVector) = _mapreduce_prod(dot, x, D, y)

dot(A::Diagonal, B::Diagonal) = dot(A.diag, B.diag)
Expand Down Expand Up @@ -620,4 +627,4 @@ end

function Base.muladd(A::Diagonal, B::Diagonal, z::Diagonal)
Diagonal(A.diag .* B.diag .+ z.diag)
end
end

0 comments on commit 272278b

Please sign in to comment.