Skip to content

Commit

Permalink
alg keyword for LinearAlgebra.svd
Browse files Browse the repository at this point in the history
  • Loading branch information
carstenbauer committed Jun 24, 2019
1 parent f6049d6 commit aae8959
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 16 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Standard library changes
* The BLAS submodule no longer exports `dot`, which conflicts with that in LinearAlgebra ([#31838]).
* `diagm` and `spdiagm` now accept optional `m,n` initial arguments to specify a size ([#31654]).
* `Hessenberg` factorizations `H` now support efficient shifted solves `(H+µI) \ b` and determinants, and use a specialized tridiagonal factorization for Hermitian matrices. There is also a new `UpperHessenberg` matrix type ([#31853]).
* Added keyword argument `alg` to `svd` and `svd!` that allows one to switch between different SVD algorithms ([#31057]).

#### SparseArrays

Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,8 @@ function svd!(M::Bidiagonal{<:BlasReal}; full::Bool = false)
d, e, U, Vt, Q, iQ = LAPACK.bdsdc!(M.uplo, 'I', M.dv, M.ev)
SVD(U, d, Vt)
end
function svd(M::Bidiagonal; full::Bool = false)
svd!(copy(M), full = full)
function svd(M::Bidiagonal; kw...)
svd!(copy(M), kw...)
end

####################
Expand Down
42 changes: 29 additions & 13 deletions stdlib/LinearAlgebra/src/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,20 @@ function SVD{T}(U::AbstractArray, S::AbstractVector{Tr}, Vt::AbstractArray) wher
convert(AbstractArray{T}, Vt))
end


abstract type SVDAlgorithm end
struct DivideAndConquer <: SVDAlgorithm end
struct Simple <: SVDAlgorithm end


# iteration for destructuring into components
Base.iterate(S::SVD) = (S.U, Val(:S))
Base.iterate(S::SVD, ::Val{:S}) = (S.S, Val(:V))
Base.iterate(S::SVD, ::Val{:V}) = (S.V, Val(:done))
Base.iterate(S::SVD, ::Val{:done}) = nothing

"""
svd!(A; full::Bool = false) -> SVD
svd!(A; full::Bool = false, alg::SVDAlgorithm = DivideAndConquer()) -> SVD
`svd!` is the same as [`svd`](@ref), but saves space by
overwriting the input `A`, instead of creating a copy.
Expand Down Expand Up @@ -92,18 +98,25 @@ julia> A
0.0 0.0 -2.0 0.0 0.0
```
"""
function svd!(A::StridedMatrix{T}; full::Bool = false) where T<:BlasFloat
function svd!(A::StridedMatrix{T}; full::Bool = false, alg::SVDAlgorithm = DivideAndConquer()) where T<:BlasFloat
m,n = size(A)
if m == 0 || n == 0
u,s,vt = (Matrix{T}(I, m, full ? m : n), real(zeros(T,0)), Matrix{T}(I, n, n))
else
u,s,vt = LAPACK.gesdd!(full ? 'A' : 'S', A)
if typeof(alg) == DivideAndConquer
u,s,vt = LAPACK.gesdd!(full ? 'A' : 'S', A)
elseif typeof(alg) == Simple
c = full ? 'A' : 'S'
u,s,vt = LAPACK.gesvd!(c, c, A)
else
throw(ArgumentError("Unsupported value for `alg` keyword."))
end
end
SVD(u,s,vt)
end

"""
svd(A; full::Bool = false) -> SVD
svd(A; full::Bool = false, alg::SVDAlgorithm = DivideAndConquer()) -> SVD
Compute the singular value decomposition (SVD) of `A` and return an `SVD` object.
Expand All @@ -120,6 +133,9 @@ and `V` is `N \\times N`, while in the thin factorization `U` is `M
\\times K` and `V` is `N \\times K`, where `K = \\min(M,N)` is the
number of singular values.
If `alg = DivideAndConquer()` (default) a divide-and-conquer algorithm is used to calculate the SVD.
One can set `alg = Simple()` to use a simple (typically slower) algorithm instead.
# Examples
```jldoctest
julia> A = [1. 0. 0. 0. 2.; 0. 0. 3. 0. 0.; 0. 0. 0. 0. 0.; 0. 2. 0. 0. 0.]
Expand All @@ -144,21 +160,21 @@ julia> u == F.U && s == F.S && v == F.V
true
```
"""
function svd(A::StridedVecOrMat{T}; full::Bool = false) where T
svd!(copy_oftype(A, eigtype(T)), full = full)
function svd(A::StridedVecOrMat{T}; full::Bool = false, alg::SVDAlgorithm = DivideAndConquer()) where T
svd!(copy_oftype(A, eigtype(T)), full = full, alg = alg)
end
function svd(x::Number; full::Bool = false)
function svd(x::Number; full::Bool = false, alg::SVDAlgorithm = DivideAndConquer())
SVD(x == 0 ? fill(one(x), 1, 1) : fill(x/abs(x), 1, 1), [abs(x)], fill(one(x), 1, 1))
end
function svd(x::Integer; full::Bool = false)
svd(float(x), full = full)
function svd(x::Integer; full::Bool = false, alg::SVDAlgorithm = DivideAndConquer())
svd(float(x), full = full, alg = alg)
end
function svd(A::Adjoint; full::Bool = false)
s = svd(A.parent, full = full)
function svd(A::Adjoint; full::Bool = false, alg::SVDAlgorithm = DivideAndConquer())
s = svd(A.parent, full = full, alg = alg)
return SVD(s.Vt', s.S, s.U')
end
function svd(A::Transpose; full::Bool = false)
s = svd(A.parent, full = full)
function svd(A::Transpose; full::Bool = false, alg::SVDAlgorithm = DivideAndConquer())
s = svd(A.parent, full = full, alg = alg)
return SVD(transpose(s.Vt), s.S, transpose(s.U))
end

Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2513,7 +2513,7 @@ eigen(A::AbstractTriangular) = Eigen(eigvals(A), eigvecs(A))
# Generic singular systems
for func in (:svd, :svd!, :svdvals)
@eval begin
($func)(A::AbstractTriangular) = ($func)(copyto!(similar(parent(A)), A))
($func)(A::AbstractTriangular; kwargs...) = ($func)(copyto!(similar(parent(A)), A); kwargs...)
end
end

Expand Down
31 changes: 31 additions & 0 deletions stdlib/LinearAlgebra/test/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,4 +139,35 @@ aimg = randn(n,n)/2
end
end



@testset "SVD Algorithms" begin
(x,y) = isapprox(x,y,rtol=1e-15)

allpos = (v) -> begin
for e in v
e < 0 && return false
end
return true
end

x = [0.1 0.2; 0.3 0.4]

for alg in [LinearAlgebra.Simple(), LinearAlgebra.DivideAndConquer()]
sx1 = svd(x, alg = alg)
@test sx1.U * Diagonal(sx1.S) * sx1.Vt x
@test sx1.V * sx1.Vt I
@test sx1.U * sx1.U' I
@test allpos(sx1.S)

sx2 = svd!(copy(x), alg = alg)
@test sx2.U * Diagonal(sx2.S) * sx2.Vt x
@test sx2.V * sx2.Vt I
@test sx2.U * sx2.U' I
@test allpos(sx2.S)
end
end



end # module TestSVD

0 comments on commit aae8959

Please sign in to comment.