diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index d849640a351f1..e4d4615c1a2c5 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -316,7 +316,7 @@ function mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta return C end -#TODO: many of /, \ related function has no size check and singular check +#TODO: many of /, \ related function has no singular check (/)(A::AbstractVecOrMat, D::Diagonal) = rdiv!((typeof(oneunit(eltype(D))/oneunit(eltype(A)))).(A), D) (/)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag ./ Db.diag) @@ -345,6 +345,7 @@ end (\)(D::Diagonal, b::AbstractVector) = D.diag .\ b (\)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag .\ Db.diag) +#TODO: we should check size(x,2) == size(b,2) ldiv!(x::AbstractVecOrMat, A::Diagonal, b::AbstractVecOrMat) = (x .= A.diag .\ b) function ldiv!(D::Diagonal, B::AbstractVecOrMat) @@ -548,11 +549,17 @@ function svd(D::Diagonal{<:Number}) return SVD(Up, S[piv], copy(Vp')) end -# disambiguation methods: * of Diagonal and Adj/Trans AbsVec -*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal) = Adjoint(map((t,s) -> t'*s, D.diag, parent(x))) -*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal) = Transpose(map((t,s) -> transpose(t)*s, D.diag, parent(x))) -*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y) -*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y) +# disambiguation methods: * and / of Diagonal and Adj/Trans AbsVec +*(x::AdjointAbsVec, D::Diagonal) = Adjoint(map((t,s) -> t'*s, D.diag, parent(x))) +*(x::TransposeAbsVec, D::Diagonal) = Transpose(map((t,s) -> transpose(t)*s, D.diag, parent(x))) +*(x::AdjointAbsVec, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y) +*(x::TransposeAbsVec, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y) +/(u::AdjointAbsVec, D::Diagonal) = adjoint(adjoint(D) \ u.parent) +/(u::TransposeAbsVec, D::Diagonal) = transpose(transpose(D) \ u.parent) +# disambiguation methods: Call unoptimized version for user defined AbstractTriangular. +*(A::AbstractTriangular, D::Diagonal) = Base.@invoke *(A::AbstractMatrix, D::Diagonal) +*(D::Diagonal, A::AbstractTriangular) = Base.@invoke *(D::Diagonal, A::AbstractMatrix) + dot(x::AbstractVector, D::Diagonal, y::AbstractVector) = _mapreduce_prod(dot, x, D, y) dot(A::Diagonal, B::Diagonal) = dot(A.diag, B.diag) @@ -620,4 +627,4 @@ end function Base.muladd(A::Diagonal, B::Diagonal, z::Diagonal) Diagonal(A.diag .* B.diag .+ z.diag) -end +end \ No newline at end of file