diff --git a/argmin/src/solver/gradientdescent/steepestdescent.rs b/argmin/src/solver/gradientdescent/steepestdescent.rs index c7ae87349..9afa77b31 100644 --- a/argmin/src/solver/gradientdescent/steepestdescent.rs +++ b/argmin/src/solver/gradientdescent/steepestdescent.rs @@ -65,13 +65,16 @@ where problem: &mut Problem, mut state: IterState, ) -> Result<(IterState, Option), Error> { - let param_new = state.take_param().ok_or_else(argmin_error_closure!( - NotInitialized, - concat!( - "`SteepestDescent` requires an initial parameter vector. ", - "Please provide an initial guess via `Executor`s `configure` method." - ) - ))?; + let param_new = state + .get_param() + .ok_or_else(argmin_error_closure!( + NotInitialized, + concat!( + "`SteepestDescent` requires an initial parameter vector. ", + "Please provide an initial guess via `Executor`s `configure` method." + ) + ))? + .clone(); let new_cost = problem.cost(¶m_new)?; let new_grad = problem.gradient(¶m_new)?; @@ -153,6 +156,20 @@ mod tests { ); } + #[test] + fn test_next_iter_prev_param_not_erased() { + let linesearch: BacktrackingLineSearch, Vec, ArmijoCondition, f64> = + BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap()); + let mut sd = SteepestDescent::new(linesearch); + let (state, _kv) = sd + .next_iter( + &mut Problem::new(TestProblem::new()), + IterState::new().param(vec![1.0, 2.0]), + ) + .unwrap(); + state.prev_param.unwrap(); + } + #[test] fn test_next_iter_regression() { struct SDProblem {}