diff --git a/NEWS.md b/NEWS.md index 2ed6586a110dc..303864055e955 100644 --- a/NEWS.md +++ b/NEWS.md @@ -45,6 +45,7 @@ Standard library changes * The BLAS submodule no longer exports `dot`, which conflicts with that in LinearAlgebra ([#31838]). * `diagm` and `spdiagm` now accept optional `m,n` initial arguments to specify a size ([#31654]). * `Hessenberg` factorizations `H` now support efficient shifted solves `(H+µI) \ b` and determinants, and use a specialized tridiagonal factorization for Hermitian matrices. There is also a new `UpperHessenberg` matrix type ([#31853]). +* Added keyword argument `alg` to `svd` and `svd!` that allows one to switch between different SVD algorithms ([#31057]). #### SparseArrays diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 8547e1c2e5f0f..7f409c95c7063 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -198,8 +198,8 @@ function svd!(M::Bidiagonal{<:BlasReal}; full::Bool = false) d, e, U, Vt, Q, iQ = LAPACK.bdsdc!(M.uplo, 'I', M.dv, M.ev) SVD(U, d, Vt) end -function svd(M::Bidiagonal; full::Bool = false) - svd!(copy(M), full = full) +function svd(M::Bidiagonal; kw...) + svd!(copy(M), kw...) end #################### diff --git a/stdlib/LinearAlgebra/src/svd.jl b/stdlib/LinearAlgebra/src/svd.jl index 89b7501811230..7a33e31d57cbc 100644 --- a/stdlib/LinearAlgebra/src/svd.jl +++ b/stdlib/LinearAlgebra/src/svd.jl @@ -54,6 +54,12 @@ function SVD{T}(U::AbstractArray, S::AbstractVector{Tr}, Vt::AbstractArray) wher convert(AbstractArray{T}, Vt)) end + +abstract type SVDAlgorithm end +struct DivideAndConquer <: SVDAlgorithm end +struct Simple <: SVDAlgorithm end + + # iteration for destructuring into components Base.iterate(S::SVD) = (S.U, Val(:S)) Base.iterate(S::SVD, ::Val{:S}) = (S.S, Val(:V)) @@ -61,7 +67,7 @@ Base.iterate(S::SVD, ::Val{:V}) = (S.V, Val(:done)) Base.iterate(S::SVD, ::Val{:done}) = nothing """ - svd!(A; full::Bool = false) -> SVD + svd!(A; full::Bool = false, alg::SVDAlgorithm = DivideAndConquer()) -> SVD `svd!` is the same as [`svd`](@ref), but saves space by overwriting the input `A`, instead of creating a copy. @@ -92,18 +98,25 @@ julia> A 0.0 0.0 -2.0 0.0 0.0 ``` """ -function svd!(A::StridedMatrix{T}; full::Bool = false) where T<:BlasFloat +function svd!(A::StridedMatrix{T}; full::Bool = false, alg::SVDAlgorithm = DivideAndConquer()) where T<:BlasFloat m,n = size(A) if m == 0 || n == 0 u,s,vt = (Matrix{T}(I, m, full ? m : n), real(zeros(T,0)), Matrix{T}(I, n, n)) else - u,s,vt = LAPACK.gesdd!(full ? 'A' : 'S', A) + if typeof(alg) == DivideAndConquer + u,s,vt = LAPACK.gesdd!(full ? 'A' : 'S', A) + elseif typeof(alg) == Simple + c = full ? 'A' : 'S' + u,s,vt = LAPACK.gesvd!(c, c, A) + else + throw(ArgumentError("Unsupported value for `alg` keyword.")) + end end SVD(u,s,vt) end """ - svd(A; full::Bool = false) -> SVD + svd(A; full::Bool = false, alg::SVDAlgorithm = DivideAndConquer()) -> SVD Compute the singular value decomposition (SVD) of `A` and return an `SVD` object. @@ -120,6 +133,9 @@ and `V` is `N \\times N`, while in the thin factorization `U` is `M \\times K` and `V` is `N \\times K`, where `K = \\min(M,N)` is the number of singular values. +If `alg = DivideAndConquer()` (default) a divide-and-conquer algorithm is used to calculate the SVD. +One can set `alg = Simple()` to use a simple (typically slower) algorithm instead. + # Examples ```jldoctest julia> A = [1. 0. 0. 0. 2.; 0. 0. 3. 0. 0.; 0. 0. 0. 0. 0.; 0. 2. 0. 0. 0.] @@ -144,21 +160,21 @@ julia> u == F.U && s == F.S && v == F.V true ``` """ -function svd(A::StridedVecOrMat{T}; full::Bool = false) where T - svd!(copy_oftype(A, eigtype(T)), full = full) +function svd(A::StridedVecOrMat{T}; full::Bool = false, alg::SVDAlgorithm = DivideAndConquer()) where T + svd!(copy_oftype(A, eigtype(T)), full = full, alg = alg) end -function svd(x::Number; full::Bool = false) +function svd(x::Number; full::Bool = false, alg::SVDAlgorithm = DivideAndConquer()) SVD(x == 0 ? fill(one(x), 1, 1) : fill(x/abs(x), 1, 1), [abs(x)], fill(one(x), 1, 1)) end -function svd(x::Integer; full::Bool = false) - svd(float(x), full = full) +function svd(x::Integer; full::Bool = false, alg::SVDAlgorithm = DivideAndConquer()) + svd(float(x), full = full, alg = alg) end -function svd(A::Adjoint; full::Bool = false) - s = svd(A.parent, full = full) +function svd(A::Adjoint; full::Bool = false, alg::SVDAlgorithm = DivideAndConquer()) + s = svd(A.parent, full = full, alg = alg) return SVD(s.Vt', s.S, s.U') end -function svd(A::Transpose; full::Bool = false) - s = svd(A.parent, full = full) +function svd(A::Transpose; full::Bool = false, alg::SVDAlgorithm = DivideAndConquer()) + s = svd(A.parent, full = full, alg = alg) return SVD(transpose(s.Vt), s.S, transpose(s.U)) end diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index fc78bf57547ed..e751a4c801d5c 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -2513,7 +2513,7 @@ eigen(A::AbstractTriangular) = Eigen(eigvals(A), eigvecs(A)) # Generic singular systems for func in (:svd, :svd!, :svdvals) @eval begin - ($func)(A::AbstractTriangular) = ($func)(copyto!(similar(parent(A)), A)) + ($func)(A::AbstractTriangular; kwargs...) = ($func)(copyto!(similar(parent(A)), A); kwargs...) end end diff --git a/stdlib/LinearAlgebra/test/svd.jl b/stdlib/LinearAlgebra/test/svd.jl index 7a29d86974a61..0a23720bd6cc5 100644 --- a/stdlib/LinearAlgebra/test/svd.jl +++ b/stdlib/LinearAlgebra/test/svd.jl @@ -139,4 +139,35 @@ aimg = randn(n,n)/2 end end + + +@testset "SVD Algorithms" begin + ≊(x,y) = isapprox(x,y,rtol=1e-15) + + allpos = (v) -> begin + for e in v + e < 0 && return false + end + return true + end + + x = [0.1 0.2; 0.3 0.4] + + for alg in [LinearAlgebra.Simple(), LinearAlgebra.DivideAndConquer()] + sx1 = svd(x, alg = alg) + @test sx1.U * Diagonal(sx1.S) * sx1.Vt ≊ x + @test sx1.V * sx1.Vt ≊ I + @test sx1.U * sx1.U' ≊ I + @test allpos(sx1.S) + + sx2 = svd!(copy(x), alg = alg) + @test sx2.U * Diagonal(sx2.S) * sx2.Vt ≊ x + @test sx2.V * sx2.Vt ≊ I + @test sx2.U * sx2.U' ≊ I + @test allpos(sx2.S) + end +end + + + end # module TestSVD