From ec88c8e77ffc13e5574deebcfbb31c2f0d01a807 Mon Sep 17 00:00:00 2001 From: Fredrik Ekre Date: Thu, 27 Apr 2017 16:11:48 +0200 Subject: [PATCH] unify methods for cholesky --- base/linalg/cholesky.jl | 45 +++++++++-------------------------------- 1 file changed, 9 insertions(+), 36 deletions(-) diff --git a/base/linalg/cholesky.jl b/base/linalg/cholesky.jl index f43e9364db459..a3bdff32fe5d5 100644 --- a/base/linalg/cholesky.jl +++ b/base/linalg/cholesky.jl @@ -16,10 +16,7 @@ # supported for the four LAPACK element types. For other types, e.g. BigFloats Val{true} will # give an error. It is required that the input is Hermitian (including real symmetric) either # through the Hermitian and Symmetric views or exact symmetric or Hermitian elements which -# is checked for and an error is thrown if the check fails. The dispatch -# is further complicated by a limitation in the formulation of Unions. The relevant union -# would be Union{Symmetric{T<:Real,S}, Hermitian} but, right now, it doesn't work in Julia -# so we'll have to define methods for the two elements of the union separately. +# is checked for and an error is thrown if the check fails. # FixMe? The dispatch below seems overly complicated. One simplification could be to # merge the two Cholesky types into one. It would remove the need for Val completely but @@ -121,9 +118,7 @@ non_hermitian_error(f) = throw(ArgumentError("matrix is not symmetric/" * # chol!. Destructive methods for computing Cholesky factor of real symmetric or Hermitian # matrix -chol!(A::Hermitian) = - _chol!(A.uplo == 'U' ? A.data : LinAlg.copytri!(A.data, 'L', true), UpperTriangular) -chol!(A::Symmetric{<:Real,<:StridedMatrix}) = +chol!(A::RealHermSymComplexHerm{<:Real,<:StridedMatrix}) = _chol!(A.uplo == 'U' ? A.data : LinAlg.copytri!(A.data, 'L', true), UpperTriangular) function chol!(A::StridedMatrix) ishermitian(A) || non_hermitian_error("chol!") @@ -134,7 +129,7 @@ end # chol. Non-destructive methods for computing Cholesky factor of a real symmetric or # Hermitian matrix. Promotes elements to a type that is stable under square roots. -function chol(A::Hermitian) +function chol(A::RealHermSymComplexHerm) T = promote_type(typeof(chol(one(eltype(A)))), Float32) AA = similar(A, T, size(A)) if A.uplo == 'U' @@ -144,16 +139,6 @@ function chol(A::Hermitian) end chol!(Hermitian(AA, :U)) end -function chol(A::Symmetric{T,<:AbstractMatrix}) where T<:Real - TT = promote_type(typeof(chol(one(T))), Float32) - AA = similar(A, TT, size(A)) - if A.uplo == 'U' - copy!(AA, A.data) - else - Base.ctranspose!(AA, A.data) - end - chol!(Hermitian(AA, :U)) -end ## for StridedMatrices, check that matrix is symmetric/Hermitian """ @@ -206,14 +191,7 @@ chol(x::Number, args...) = _chol!(x, nothing) # cholfact!. Destructive methods for computing Cholesky factorization of real symmetric # or Hermitian matrix ## No pivoting -function cholfact!(A::Hermitian, ::Type{Val{false}}) - if A.uplo == 'U' - Cholesky(_chol!(A.data, UpperTriangular).data, 'U') - else - Cholesky(_chol!(A.data, LowerTriangular).data, 'L') - end -end -function cholfact!(A::Symmetric{<:Real}, ::Type{Val{false}}) +function cholfact!(A::RealHermSymComplexHerm, ::Type{Val{false}}) if A.uplo == 'U' Cholesky(_chol!(A.data, UpperTriangular).data, 'U') else @@ -248,8 +226,8 @@ function cholfact!(A::StridedMatrix, uplo::Symbol, ::Type{Val{false}}) end ### Default to no pivoting (and storing of upper factor) when not explicit -cholfact!(A::Hermitian) = cholfact!(A, Val{false}) -cholfact!(A::Symmetric{<:Real}) = cholfact!(A, Val{false}) +cholfact!(A::RealHermSymComplexHerm) = cholfact!(A, Val{false}) + #### for StridedMatrices, check that matrix is symmetric/Hermitian function cholfact!(A::StridedMatrix, uplo::Symbol = :U) ishermitian(A) || non_hermitian_error("cholfact!") @@ -288,9 +266,7 @@ end # cholfact. Non-destructive methods for computing Cholesky factorization of real symmetric # or Hermitian matrix ## No pivoting -cholfact(A::Hermitian, ::Type{Val{false}}) = - cholfact!(copy_oftype(A, promote_type(typeof(chol(one(eltype(A)))),Float32)), Val{false}) -cholfact(A::Symmetric{<:Real,<:StridedMatrix}, ::Type{Val{false}}) = +cholfact(A::RealHermSymComplexHerm{<:Real,<:StridedMatrix}, ::Type{Val{false}}) = cholfact!(copy_oftype(A, promote_type(typeof(chol(one(eltype(A)))),Float32)), Val{false}) ### for StridedMatrices, check that matrix is symmetric/Hermitian @@ -342,8 +318,8 @@ function cholfact(A::StridedMatrix, uplo::Symbol, ::Type{Val{false}}) end ### Default to no pivoting (and storing of upper factor) when not explicit -cholfact(A::Hermitian) = cholfact(A, Val{false}) -cholfact(A::Symmetric{<:Real,<:StridedMatrix}) = cholfact(A, Val{false}) +cholfact(A::RealHermSymComplexHerm{<:Real,<:StridedMatrix}) = cholfact(A, Val{false}) + #### for StridedMatrices, check that matrix is symmetric/Hermitian function cholfact(A::StridedMatrix, uplo::Symbol = :U) ishermitian(A) || non_hermitian_error("cholfact") @@ -352,9 +328,6 @@ end ## With pivoting -cholfact(A::Hermitian, ::Type{Val{true}}; tol = 0.0) = - cholfact!(copy_oftype(A, promote_type(typeof(chol(one(eltype(A)))),Float32)), - Val{true}; tol = tol) cholfact(A::RealHermSymComplexHerm{<:Real,<:StridedMatrix}, ::Type{Val{true}}; tol = 0.0) = cholfact!(copy_oftype(A, promote_type(typeof(chol(one(eltype(A)))),Float32)), Val{true}; tol = tol)