Skip to content

Commit

Permalink
Update diagonal.jl
Browse files Browse the repository at this point in the history
replace all `AbstractArray` with `AbstractVecOrMat`
  • Loading branch information
N5N3 authored Sep 22, 2021
1 parent eb4b99a commit 7e0be45
Showing 1 changed file with 27 additions and 29 deletions.
56 changes: 27 additions & 29 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -316,10 +316,11 @@ function mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta
return C
end

#TODO: many of /, \ related function has no size check and singular check
(/)(A::AbstractVecOrMat, D::Diagonal) =
rdiv!((typeof(oneunit(eltype(D))/oneunit(eltype(A)))).(A), D)
(/)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag ./ Db.diag)

ldiv!(x::AbstractArray, A::Diagonal, b::AbstractArray) = (x .= A.diag .\ b)

function rdiv!(A::AbstractMatrix, D::Diagonal)
require_one_based_indexing(A)
dd = D.diag
Expand All @@ -339,8 +340,30 @@ function rdiv!(A::AbstractMatrix, D::Diagonal)
A
end

(/)(A::AbstractArray, D::Diagonal) =
rdiv!((typeof(oneunit(eltype(D))/oneunit(eltype(A)))).(A), D)
(\)(D::Diagonal, A::AbstractMatrix) =
ldiv!(D, (typeof(oneunit(eltype(D))/oneunit(eltype(A)))).(A))
(\)(D::Diagonal, b::AbstractVector) = D.diag .\ b
(\)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag .\ Db.diag)

ldiv!(x::AbstractVecOrMat, A::Diagonal, b::AbstractVecOrMat) = (x .= A.diag .\ b)

function ldiv!(D::Diagonal, B::AbstractVecOrMat)
m, n = size(B, 1), size(B, 2)
if m != length(D.diag)
throw(DimensionMismatch("diagonal matrix is $(length(D.diag)) by $(length(D.diag)) but right hand side has $m rows"))
end
(m == 0 || n == 0) && return B
for j = 1:n
for i = 1:m
di = D.diag[i]
if di == 0
throw(SingularException(i))
end
B[i,j] = di \ B[i,j]
end
end
return B
end

# (l/r)mul!, l/rdiv!, *, / and \ Optimization for AbstractTriangular.
# These functions are generally more efficient if we calculate the whole data field.
Expand Down Expand Up @@ -469,30 +492,6 @@ for f in (:exp, :cis, :log, :sqrt,
@eval $f(D::Diagonal) = Diagonal($f.(D.diag))
end

(\)(D::Diagonal, A::AbstractMatrix) =
ldiv!(D, (typeof(oneunit(eltype(D))/oneunit(eltype(A)))).(A))

(\)(D::Diagonal, b::AbstractVector) = D.diag .\ b
(\)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag .\ Db.diag)

function ldiv!(D::Diagonal, B::AbstractVecOrMat)
m, n = size(B, 1), size(B, 2)
if m != length(D.diag)
throw(DimensionMismatch("diagonal matrix is $(length(D.diag)) by $(length(D.diag)) but right hand side has $m rows"))
end
(m == 0 || n == 0) && return B
for j = 1:n
for i = 1:m
di = D.diag[i]
if di == 0
throw(SingularException(i))
end
B[i,j] = di \ B[i,j]
end
end
return B
end

function inv(D::Diagonal{T}) where T
Di = similar(D.diag, typeof(inv(zero(T))))
for i = 1:length(D.diag)
Expand Down Expand Up @@ -572,7 +571,6 @@ function _mapreduce_prod(f, x, D::Diagonal, y)
end
end


function cholesky!(A::Diagonal, ::Val{false} = Val(false); check::Bool = true)
info = 0
for (i, di) in enumerate(A.diag)
Expand Down

0 comments on commit 7e0be45

Please sign in to comment.