Commit

Fix steepest descent to cache previous param
DevonMorris committed Jun 21, 2023
1 parent 8388f38 commit 4566bcd
Showing 1 changed file with 24 additions and 7 deletions.
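
Why the change matters, in brief: `IterState` shifts the current parameter vector into `prev_param` whenever a new parameter is set, so taking the parameter out of the state with `take_param()` at the start of `next_iter` left that slot empty and the previous parameter was erased on every iteration. Reading it with `get_param()` and cloning it instead leaves the slot populated. The snippet below is a minimal sketch of that difference, not part of the commit; it assumes the `IterState` type parameters shown in the diff and direct access to the public `prev_param` field, as the new test does.

// Sketch only (not part of this commit). Assumes argmin's `IterState` moves
// the current parameter into `prev_param` when a new parameter is set via
// `param(...)`.
use argmin::core::{IterState, State};

fn main() {
    // Old code path: `take_param` empties the parameter slot, so the next
    // `param(...)` call has nothing to shift into `prev_param`.
    let mut state: IterState<Vec<f64>, Vec<f64>, (), (), f64> =
        IterState::new().param(vec![1.0, 2.0]);
    let _current = state.take_param();
    let state = state.param(vec![0.5, 1.5]);
    assert!(state.prev_param.is_none());

    // New code path: `get_param()` plus `clone()` leaves the slot populated,
    // so the old value survives as `prev_param` after the update.
    let state: IterState<Vec<f64>, Vec<f64>, (), (), f64> =
        IterState::new().param(vec![1.0, 2.0]);
    let previous = state.get_param().cloned();
    let state = state.param(vec![0.5, 1.5]);
    assert_eq!(state.prev_param, previous);
}
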
31 changes: 24 additions & 7 deletions argmin/src/solver/gradientdescent/steepestdescent.rs
@@ -65,13 +65,16 @@ where
         problem: &mut Problem<O>,
         mut state: IterState<P, G, (), (), F>,
     ) -> Result<(IterState<P, G, (), (), F>, Option<KV>), Error> {
-        let param_new = state.take_param().ok_or_else(argmin_error_closure!(
-            NotInitialized,
-            concat!(
-                "`SteepestDescent` requires an initial parameter vector. ",
-                "Please provide an initial guess via `Executor`s `configure` method."
-            )
-        ))?;
+        let param_new = state
+            .get_param()
+            .ok_or_else(argmin_error_closure!(
+                NotInitialized,
+                concat!(
+                    "`SteepestDescent` requires an initial parameter vector. ",
+                    "Please provide an initial guess via `Executor`s `configure` method."
+                )
+            ))?
+            .clone();
         let new_cost = problem.cost(&param_new)?;
         let new_grad = problem.gradient(&param_new)?;

@@ -153,6 +156,20 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_next_iter_prev_param_not_erased() {
+        let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
+            BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
+        let mut sd = SteepestDescent::new(linesearch);
+        let (state, _kv) = sd
+            .next_iter(
+                &mut Problem::new(TestProblem::new()),
+                IterState::new().param(vec![1.0, 2.0]),
+            )
+            .unwrap();
+        state.prev_param.unwrap();
+    }
+
     #[test]
     fn test_next_iter_regression() {
         struct SDProblem {}
