From 88965d8b2b3dbbd07150bb64486117690a013b5c Mon Sep 17 00:00:00 2001 From: cheukhinhojerry Date: Fri, 22 Mar 2024 19:03:44 -0700 Subject: [PATCH 1/2] add trunc_svd svoler --- src/solvers.jl | 37 +++++++++++++++++++++++++++++++++++++ test/test_linearsolvers.jl | 7 +++++++ 2 files changed, 44 insertions(+) diff --git a/src/solvers.jl b/src/solvers.jl index 54ac29c..96f034d 100644 --- a/src/solvers.jl +++ b/src/solvers.jl @@ -2,6 +2,7 @@ using LinearAlgebra: qr, I, norm using LowRankApprox: pqrfact using IterativeSolvers using .BayesianLinear +using LinearAlgebra: SVD, svd @doc raw""" `struct QR` : linear least squares solver, using standard QR factorisation; @@ -148,3 +149,39 @@ function SKLEARN_ARD(; n_iter = 300, tol = 1e-3, threshold_lambda = 10000) end # solve(solver::SKLEARN_ARD, ...) is implemented in ext/ + +@doc raw""" +`struct Truncated_SVD` : linear least squares solver +```math + θ = \arg\min \| A \theta - y \|^2 + \lambda \| P \theta \|^2 +``` +Constructor +```julia +ACEfit.Truncated_SVD(; lambda = 0.0, P = nothing) +``` +where +* `rtol` : relative tolerance +* `P` : right-preconditioner / tychonov operator +""" +struct Truncated_SVD + rtol::Number + P::Any +end + +Truncated_SVD(; rtol = 1e-9, P = I) = Truncated_SVD(rtol, P) + +function trunc_svd(USV::SVD, Y, rtol) + U, S, V = USV # svd(A) + Ikeep = findall(x -> x > rtol, S ./ maximum(S)) + U1 = @view U[:, Ikeep] + S1 = S[Ikeep] + V1 = @view V[:, Ikeep] + return V1 * (S1 .\ (U1' * Y)) +end + +function solve(solver::Truncated_SVD, A, y) + AP = A / solver.P + θP = trunc_svd(svd(AP), y, solver.rtol) + return Dict{String, Any}("C" => solver.P \ θP) +end + diff --git a/test/test_linearsolvers.jl b/test/test_linearsolvers.jl index dca778b..fca1b4e 100644 --- a/test/test_linearsolvers.jl +++ b/test/test_linearsolvers.jl @@ -79,3 +79,10 @@ results = ACEfit.solve(solver, A, y) C = results["C"] @show norm(A * C - y) @show norm(C) + +@info(" ... Truncated_SVD") +solver = ACEfit.Truncated_SVD() +results = ACEfit.solve(solver, A, y) +C = results["C"] +@show norm(A * C - y) +@show norm(C) From e1d674d86a6a808d7d2089d95665c7059552e1ab Mon Sep 17 00:00:00 2001 From: cheukhinhojerry Date: Fri, 22 Mar 2024 19:49:25 -0700 Subject: [PATCH 2/2] update doc --- src/solvers.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/solvers.jl b/src/solvers.jl index 96f034d..11c0f5f 100644 --- a/src/solvers.jl +++ b/src/solvers.jl @@ -153,7 +153,7 @@ end @doc raw""" `struct Truncated_SVD` : linear least squares solver ```math - θ = \arg\min \| A \theta - y \|^2 + \lambda \| P \theta \|^2 + θ = \arg\min \| A P^{-1} \theta - y \|^2 ``` Constructor ```julia