Skip to content

Commit

Permalink
Fix steepest descent to cache previous param
Browse files Browse the repository at this point in the history
  • Loading branch information
DevonMorris authored and stefan-k committed Jun 21, 2023
1 parent f2f921a commit d65ce27
Showing 1 changed file with 26 additions and 9 deletions.
35 changes: 26 additions & 9 deletions argmin/src/solver/gradientdescent/steepestdescent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

use crate::core::{
ArgminFloat, CostFunction, DeserializeOwnedAlias, Error, Executor, Gradient, IterState,
LineSearch, OptimizationResult, Problem, SerializeAlias, Solver, KV,
LineSearch, OptimizationResult, Problem, SerializeAlias, Solver, State, KV,
};
use argmin_math::ArgminMul;
#[cfg(feature = "serde1")]
Expand Down Expand Up @@ -63,15 +63,18 @@ where
fn next_iter(
&mut self,
problem: &mut Problem<O>,
mut state: IterState<P, G, (), (), F>,
state: IterState<P, G, (), (), F>,
) -> Result<(IterState<P, G, (), (), F>, Option<KV>), Error> {
let param_new = state.take_param().ok_or_else(argmin_error_closure!(
NotInitialized,
concat!(
"`SteepestDescent` requires an initial parameter vector. ",
"Please provide an initial guess via `Executor`s `configure` method."
)
))?;
let param_new = state
.get_param()
.ok_or_else(argmin_error_closure!(
NotInitialized,
concat!(
"`SteepestDescent` requires an initial parameter vector. ",
"Please provide an initial guess via `Executor`s `configure` method."
)
))?
.clone();
let new_cost = problem.cost(&param_new)?;
let new_grad = problem.gradient(&param_new)?;

Expand Down Expand Up @@ -153,6 +156,20 @@ mod tests {
);
}

#[test]
fn test_next_iter_prev_param_not_erased() {
    // Regression check: a single call to `next_iter` must leave the
    // previous parameter vector cached in the returned state instead of
    // consuming it (the solver clones the param rather than taking it).
    let condition = ArmijoCondition::new(0.2).unwrap();
    let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
        BacktrackingLineSearch::new(condition);
    let mut solver = SteepestDescent::new(linesearch);

    let mut problem = Problem::new(TestProblem::new());
    let init_state = IterState::new().param(vec![1.0, 2.0]);
    let (state, _kv) = solver.next_iter(&mut problem, init_state).unwrap();

    // Panics if the previous parameter was erased by the iteration.
    state.prev_param.unwrap();
}

#[test]
fn test_next_iter_regression() {
struct SDProblem {}
Expand Down

0 comments on commit d65ce27

Please sign in to comment.