Skip to content

Commit

Permalink
Automatically construct the jacobian for FastLM
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 7, 2023
1 parent 513eef6 commit 9e27c1a
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 34 deletions.
4 changes: 3 additions & 1 deletion docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ pages = ["index.md",
"api/minpack.md",
"api/nlsolve.md",
"api/sundials.md",
"api/steadystatediffeq.md"],
"api/steadystatediffeq.md",
"api/leastsquaresoptim.md",
"api/fastlevenbergmarquardt.md"],
"Release Notes" => "release_notes.md",
]
69 changes: 56 additions & 13 deletions ext/NonlinearSolveFastLevenbergMarquardtExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ module NonlinearSolveFastLevenbergMarquardtExt
using ArrayInterface, NonlinearSolve, SciMLBase
import ConcreteStructs: @concrete
import FastLevenbergMarquardt as FastLM
import FiniteDiff, ForwardDiff

function _fast_lm_solver(::FastLevenbergMarquardtJL{linsolve}, x) where {linsolve}
if linsolve == :cholesky
if linsolve === :cholesky

Check warning on line 9 in ext/NonlinearSolveFastLevenbergMarquardtExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveFastLevenbergMarquardtExt.jl#L9

Added line #L9 was not covered by tests
return FastLM.CholeskySolver(ArrayInterface.undefmatrix(x))
elseif linsolve == :qr
elseif linsolve === :qr

Check warning on line 11 in ext/NonlinearSolveFastLevenbergMarquardtExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveFastLevenbergMarquardtExt.jl#L11

Added line #L11 was not covered by tests
return FastLM.QRSolver(eltype(x), length(x))
else
throw(ArgumentError("Unknown FastLevenbergMarquardt Linear Solver: $linsolve"))
Expand All @@ -33,23 +34,65 @@ end

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem,
alg::FastLevenbergMarquardtJL, args...; alias_u0 = false, abstol = 1e-8,
reltol = 1e-8, verbose = false, maxiters = 1000, kwargs...)
reltol = 1e-8, maxiters = 1000, kwargs...)
iip = SciMLBase.isinplace(prob)
u0 = alias_u0 ? prob.u0 : deepcopy(prob.u0)

@assert prob.f.jac!==nothing "FastLevenbergMarquardt requires a Jacobian!"
u = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
fu = NonlinearSolve.evaluate_f(prob, u)

Check warning on line 40 in ext/NonlinearSolveFastLevenbergMarquardtExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveFastLevenbergMarquardtExt.jl#L39-L40

Added lines #L39 - L40 were not covered by tests

f! = InplaceFunction{iip}(prob.f)
J! = InplaceFunction{iip}(prob.f.jac)

resid_prototype = prob.f.resid_prototype === nothing ?
(!iip ? prob.f(u0, prob.p) : zeros(u0)) :
prob.f.resid_prototype
if prob.f.jac === nothing
use_forward_diff = if alg.autodiff === nothing
ForwardDiff.can_dual(eltype(u))

Check warning on line 46 in ext/NonlinearSolveFastLevenbergMarquardtExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveFastLevenbergMarquardtExt.jl#L44-L46

Added lines #L44 - L46 were not covered by tests
else
alg.autodiff isa AutoForwardDiff

Check warning on line 48 in ext/NonlinearSolveFastLevenbergMarquardtExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveFastLevenbergMarquardtExt.jl#L48

Added line #L48 was not covered by tests
end
uf = SciMLBase.JacobianWrapper{iip}(prob.f, prob.p)
if use_forward_diff
cache = iip ? ForwardDiff.JacobianConfig(uf, fu, u) :

Check warning on line 52 in ext/NonlinearSolveFastLevenbergMarquardtExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveFastLevenbergMarquardtExt.jl#L50-L52

Added lines #L50 - L52 were not covered by tests
ForwardDiff.JacobianConfig(uf, u)
else
cache = FiniteDiff.JacobianCache(u, fu)

Check warning on line 55 in ext/NonlinearSolveFastLevenbergMarquardtExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveFastLevenbergMarquardtExt.jl#L55

Added line #L55 was not covered by tests
end
J! = if iip
if use_forward_diff
fu_cache = similar(fu)
function (J, x, p)
uf.p = p
ForwardDiff.jacobian!(J, uf, fu_cache, x, cache)
return J

Check warning on line 63 in ext/NonlinearSolveFastLevenbergMarquardtExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveFastLevenbergMarquardtExt.jl#L57-L63

Added lines #L57 - L63 were not covered by tests
end
else
function (J, x, p)
uf.p = p
FiniteDiff.finite_difference_jacobian!(J, uf, x, cache)
return J

Check warning on line 69 in ext/NonlinearSolveFastLevenbergMarquardtExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveFastLevenbergMarquardtExt.jl#L66-L69

Added lines #L66 - L69 were not covered by tests
end
end
else
if use_forward_diff
function (J, x, p)
uf.p = p
ForwardDiff.jacobian!(J, uf, x, cache)
return J

Check warning on line 77 in ext/NonlinearSolveFastLevenbergMarquardtExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveFastLevenbergMarquardtExt.jl#L73-L77

Added lines #L73 - L77 were not covered by tests
end
else
function (J, x, p)
uf.p = p
J_ = FiniteDiff.finite_difference_jacobian(uf, x, cache)
copyto!(J, J_)
return J

Check warning on line 84 in ext/NonlinearSolveFastLevenbergMarquardtExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveFastLevenbergMarquardtExt.jl#L80-L84

Added lines #L80 - L84 were not covered by tests
end
end
end
else
J! = InplaceFunction{iip}(prob.f.jac)

Check warning on line 89 in ext/NonlinearSolveFastLevenbergMarquardtExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveFastLevenbergMarquardtExt.jl#L89

Added line #L89 was not covered by tests
end

J = similar(u0, length(resid_prototype), length(u0))
J = similar(u, length(fu), length(u))

Check warning on line 92 in ext/NonlinearSolveFastLevenbergMarquardtExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveFastLevenbergMarquardtExt.jl#L92

Added line #L92 was not covered by tests

solver = _fast_lm_solver(alg, u0)
LM = FastLM.LMWorkspace(u0, resid_prototype, J)
solver = _fast_lm_solver(alg, u)
LM = FastLM.LMWorkspace(u, fu, J)

Check warning on line 95 in ext/NonlinearSolveFastLevenbergMarquardtExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveFastLevenbergMarquardtExt.jl#L94-L95

Added lines #L94 - L95 were not covered by tests

return FastLevenbergMarquardtJLCache(f!, J!, prob, alg, LM, solver,
(; xtol = abstol, ftol = reltol, maxit = maxiters, alg.factor, alg.factoraccept,
Expand Down
10 changes: 5 additions & 5 deletions ext/NonlinearSolveLeastSquaresOptimExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ import ConcreteStructs: @concrete
import LeastSquaresOptim as LSO

function _lso_solver(::LeastSquaresOptimJL{alg, linsolve}) where {alg, linsolve}
ls = linsolve == :qr ? LSO.QR() :
(linsolve == :cholesky ? LSO.Cholesky() :
(linsolve == :lsmr ? LSO.LSMR() : nothing))
if alg == :lm
ls = linsolve === :qr ? LSO.QR() :

Check warning on line 8 in ext/NonlinearSolveLeastSquaresOptimExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveLeastSquaresOptimExt.jl#L8

Added line #L8 was not covered by tests
(linsolve === :cholesky ? LSO.Cholesky() :
(linsolve === :lsmr ? LSO.LSMR() : nothing))
if alg === :lm

Check warning on line 11 in ext/NonlinearSolveLeastSquaresOptimExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveLeastSquaresOptimExt.jl#L11

Added line #L11 was not covered by tests
return LSO.LevenbergMarquardt(ls)
elseif alg == :dogleg
elseif alg === :dogleg

Check warning on line 13 in ext/NonlinearSolveLeastSquaresOptimExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveLeastSquaresOptimExt.jl#L13

Added line #L13 was not covered by tests
return LSO.Dogleg(ls)
else
throw(ArgumentError("Unknown LeastSquaresOptim Algorithm: $alg"))
Expand Down
12 changes: 8 additions & 4 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,10 @@ function FastShortcutNonlinearPolyalg(; concrete_jac = nothing, linsolve = nothi
autodiff = nothing) where {JAC, SA}
if JAC
if SA
algs = (SimpleNewtonRaphson(; autodiff),
SimpleTrustRegion(; autodiff),
algs = (SimpleNewtonRaphson(;

Check warning on line 247 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L246-L247

Added lines #L246 - L247 were not covered by tests
autodiff = ifelse(autodiff === nothing, AutoForwardDiff(), autodiff)),
SimpleTrustRegion(;
autodiff = ifelse(autodiff === nothing, AutoForwardDiff(), autodiff)),
NewtonRaphson(; concrete_jac, linsolve, precs, linesearch = BackTracking(),
autodiff),
TrustRegion(; concrete_jac, linsolve, precs,
Expand All @@ -263,8 +265,10 @@ function FastShortcutNonlinearPolyalg(; concrete_jac = nothing, linsolve = nothi
algs = (SimpleBroyden(),

Check warning on line 265 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L264-L265

Added lines #L264 - L265 were not covered by tests
Broyden(; init_jacobian = Val(:true_jacobian)),
SimpleKlement(),
SimpleNewtonRaphson(; autodiff),
SimpleTrustRegion(; autodiff),
SimpleNewtonRaphson(;
autodiff = ifelse(autodiff === nothing, AutoForwardDiff(), autodiff)),
SimpleTrustRegion(;
autodiff = ifelse(autodiff === nothing, AutoForwardDiff(), autodiff)),
NewtonRaphson(; concrete_jac, linsolve, precs, linesearch = BackTracking(),
autodiff),
TrustRegion(; concrete_jac, linsolve, precs,
Expand Down
18 changes: 11 additions & 7 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ function LeastSquaresOptimJL(alg = :lm; linsolve = nothing, autodiff::Symbol = :
end

"""
FastLevenbergMarquardtJL(linsolve = :cholesky)
FastLevenbergMarquardtJL(linsolve = :cholesky; autodiff = nothing)
Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl)
for solving `NonlinearLeastSquaresProblem`.
Expand All @@ -46,19 +46,20 @@ for solving `NonlinearLeastSquaresProblem`.
This is not really the fastest solver. It is called that since the original package
is called "Fast". `LevenbergMarquardt()` is almost always a better choice.
!!! warning
This algorithm requires the jacobian function to be provided!
## Arguments:
- `linsolve`: Linear solver to use. Can be `:qr` or `:cholesky`.
- `autodiff`: determines the backend used for the Jacobian. Note that this argument is
ignored if an analytical Jacobian is passed, as that will be used instead. Defaults to
`nothing` which means that a default is selected according to the problem specification!
Valid choices are `nothing`, `AutoForwardDiff` or `AutoFiniteDiff`.
!!! note
This algorithm is only available if `FastLevenbergMarquardt.jl` is installed.
"""
@concrete struct FastLevenbergMarquardtJL{linsolve} <: AbstractNonlinearSolveAlgorithm
autodiff
factor
factoraccept
factorreject
Expand All @@ -71,14 +72,17 @@ end

function FastLevenbergMarquardtJL(linsolve::Symbol = :cholesky; factor = 1e-6,
factoraccept = 13.0, factorreject = 3.0, factorupdate = :marquardt,
minscale = 1e-12, maxscale = 1e16, minfactor = 1e-28, maxfactor = 1e32)
minscale = 1e-12, maxscale = 1e16, minfactor = 1e-28, maxfactor = 1e32,
autodiff = nothing)
@assert linsolve in (:qr, :cholesky)
@assert factorupdate in (:marquardt, :nielson)
@assert autodiff === nothing || autodiff isa AutoFiniteDiff ||

Check warning on line 79 in src/extension_algs.jl

View check run for this annotation

Codecov / codecov/patch

src/extension_algs.jl#L79

Added line #L79 was not covered by tests
autodiff isa AutoForwardDiff

if Base.get_extension(@__MODULE__, :NonlinearSolveFastLevenbergMarquardtExt) === nothing
error("LeastSquaresOptimJL requires FastLevenbergMarquardt.jl to be loaded")
end

return FastLevenbergMarquardtJL{linsolve}(factor, factoraccept, factorreject,
return FastLevenbergMarquardtJL{linsolve}(autodiff, factor, factoraccept, factorreject,

Check warning on line 86 in src/extension_algs.jl

View check run for this annotation

Codecov / codecov/patch

src/extension_algs.jl#L86

Added line #L86 was not covered by tests
factorupdate, minscale, maxscale, minfactor, maxfactor)
end
31 changes: 27 additions & 4 deletions test/nonlinear_least_squares.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,35 @@ function jac!(J, θ, p)
return J
end

prob = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
resid_prototype = zero(y_target), jac = jac!), θ_init, x)
jac(θ, p) = ForwardDiff.jacobian-> loss_function(θ, p), θ)

solvers = [FastLevenbergMarquardtJL(:cholesky), FastLevenbergMarquardtJL(:qr)]
probs = [
NonlinearLeastSquaresProblem(NonlinearFunction{true}(loss_function;
resid_prototype = zero(y_target), jac = jac!), θ_init, x),
NonlinearLeastSquaresProblem(NonlinearFunction{false}(loss_function;
resid_prototype = zero(y_target), jac = jac), θ_init, x),
NonlinearLeastSquaresProblem(NonlinearFunction{false}(loss_function; jac), θ_init, x),
]

solvers = [FastLevenbergMarquardtJL(linsolve) for linsolve in (:cholesky, :qr)]

for solver in solvers, prob in probs
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
@test norm(sol.resid) < 1e-6
end

probs = [
NonlinearLeastSquaresProblem(NonlinearFunction{true}(loss_function;
resid_prototype = zero(y_target)), θ_init, x),
NonlinearLeastSquaresProblem(NonlinearFunction{false}(loss_function;
resid_prototype = zero(y_target)), θ_init, x),
NonlinearLeastSquaresProblem(NonlinearFunction{false}(loss_function), θ_init, x),
]

solvers = [FastLevenbergMarquardtJL(linsolve; autodiff) for linsolve in (:cholesky, :qr),
autodiff in (nothing, AutoForwardDiff(), AutoFiniteDiff())]

for solver in solvers
for solver in solvers, prob in probs
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
@test norm(sol.resid) < 1e-6
end

0 comments on commit 9e27c1a

Please sign in to comment.