From d65ce27031e49ce2cf75dd9a24f5994a5328937d Mon Sep 17 00:00:00 2001 From: Devon Morris Date: Wed, 21 Jun 2023 10:49:39 -0400 Subject: [PATCH] Fix steepest descent to cache previous param --- .../solver/gradientdescent/steepestdescent.rs | 35 ++++++++++++++----- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/argmin/src/solver/gradientdescent/steepestdescent.rs b/argmin/src/solver/gradientdescent/steepestdescent.rs index c7ae87349..dc42a7205 100644 --- a/argmin/src/solver/gradientdescent/steepestdescent.rs +++ b/argmin/src/solver/gradientdescent/steepestdescent.rs @@ -7,7 +7,7 @@ use crate::core::{ ArgminFloat, CostFunction, DeserializeOwnedAlias, Error, Executor, Gradient, IterState, - LineSearch, OptimizationResult, Problem, SerializeAlias, Solver, KV, + LineSearch, OptimizationResult, Problem, SerializeAlias, Solver, State, KV, }; use argmin_math::ArgminMul; #[cfg(feature = "serde1")] @@ -63,15 +63,18 @@ where fn next_iter( &mut self, problem: &mut Problem, - mut state: IterState, + state: IterState, ) -> Result<(IterState, Option), Error> { - let param_new = state.take_param().ok_or_else(argmin_error_closure!( - NotInitialized, - concat!( - "`SteepestDescent` requires an initial parameter vector. ", - "Please provide an initial guess via `Executor`s `configure` method." - ) - ))?; + let param_new = state + .get_param() + .ok_or_else(argmin_error_closure!( + NotInitialized, + concat!( + "`SteepestDescent` requires an initial parameter vector. ", + "Please provide an initial guess via `Executor`s `configure` method." + ) + ))? + .clone(); let new_cost = problem.cost(¶m_new)?; let new_grad = problem.gradient(¶m_new)?; @@ -153,6 +156,20 @@ mod tests { ); } + #[test] + fn test_next_iter_prev_param_not_erased() { + let linesearch: BacktrackingLineSearch, Vec, ArmijoCondition, f64> = + BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap()); + let mut sd = SteepestDescent::new(linesearch); + let (state, _kv) = sd + .next_iter( + &mut Problem::new(TestProblem::new()), + IterState::new().param(vec![1.0, 2.0]), + ) + .unwrap(); + state.prev_param.unwrap(); + } + #[test] fn test_next_iter_regression() { struct SDProblem {}