Skip to content

Commit

Permalink
use last segment info for checkpointing
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Oct 16, 2024
1 parent 3d416b7 commit 36e6ab9
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 26 deletions.
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ pub use ode_solver::state::{StateRef, StateRefMut};
pub use ode_solver::{
adjoint_equations::AdjointContext, adjoint_equations::AdjointEquations,
adjoint_equations::AdjointInit, adjoint_equations::AdjointRhs, bdf::Bdf, bdf::BdfAdj,
bdf_state::BdfState, builder::OdeBuilder, checkpointing::Checkpointing,
bdf_state::BdfState, builder::OdeBuilder, checkpointing::Checkpointing, checkpointing::HermiteInterpolator,
equations::AugmentedOdeEquations, equations::AugmentedOdeEquationsImplicit, equations::NoAug,
equations::OdeEquations, equations::OdeEquationsAdjoint, equations::OdeEquationsImplicit,
equations::OdeEquationsSens, equations::OdeSolverEquations, method::OdeSolverMethod,
Expand Down
4 changes: 2 additions & 2 deletions src/ode_solver/adjoint_equations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ mod tests {
h: 0.0,
};
let checkpointer =
Checkpointing::new(&problem, solver, 0, vec![state.clone(), state.clone()]);
Checkpointing::new(&problem, solver, 0, vec![state.clone(), state.clone()], None);
let context = Rc::new(RefCell::new(AdjointContext::new(checkpointer)));
let adj_eqn = AdjointEquations::new(&problem.eqn, context.clone(), false);
// F(λ, x, t) = -f^T_x(x, t) λ
Expand Down Expand Up @@ -662,7 +662,7 @@ mod tests {
h: 0.0,
};
let checkpointer =
Checkpointing::new(&problem, solver, 0, vec![state.clone(), state.clone()]);
Checkpointing::new(&problem, solver, 0, vec![state.clone(), state.clone()], None);
let context = Rc::new(RefCell::new(AdjointContext::new(checkpointer)));
let mut adj_eqn = AdjointEquations::new(&problem.eqn, context, true);

Expand Down
16 changes: 9 additions & 7 deletions src/ode_solver/bdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::rc::Rc;
use crate::{
error::{DiffsolError, OdeSolverError},
AdjointContext, AdjointEquations, NoAug, OdeEquationsAdjoint, OdeEquationsSens, SensEquations,
StateRef, StateRefMut,
StateRef, StateRefMut, HermiteInterpolator
};

use num_traits::{abs, One, Pow, Zero};
Expand Down Expand Up @@ -1094,6 +1094,7 @@ where
fn new_adjoint_solver(
&self,
checkpoints: Vec<Self::State>,
last_segment: HermiteInterpolator<Eqn::V>,
include_in_error_control: bool,
) -> Result<Self::AdjointSolver, DiffsolError> {
// construct checkpointing
Expand All @@ -1105,6 +1106,7 @@ where
checkpointer_solver,
checkpoints.len() - 2,
checkpoints,
Some(last_segment),
);

// construct adjoint equations and problem
Expand Down Expand Up @@ -1298,9 +1300,9 @@ mod test {
"###);
insta::assert_yaml_snapshot!(s.problem().as_ref().unwrap().eqn.rhs().statistics(), @r###"
---
number_of_calls: 166
number_of_jac_muls: 8
number_of_matrix_evals: 4
number_of_calls: 84
number_of_jac_muls: 6
number_of_matrix_evals: 3
number_of_jac_adj_muls: 254
"###);
insta::assert_yaml_snapshot!(adjoint_solver.get_statistics(), @r###"
Expand Down Expand Up @@ -1328,9 +1330,9 @@ mod test {
"###);
insta::assert_yaml_snapshot!(s.problem().as_ref().unwrap().eqn.rhs().statistics(), @r###"
---
number_of_calls: 210
number_of_jac_muls: 21
number_of_matrix_evals: 7
number_of_calls: 208
number_of_jac_muls: 18
number_of_matrix_evals: 6
number_of_jac_adj_muls: 201
"###);
insta::assert_yaml_snapshot!(adjoint_solver.get_statistics(), @r###"
Expand Down
35 changes: 24 additions & 11 deletions src/ode_solver/checkpointing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ impl<V> HermiteInterpolator<V>
where
V: Vector,
{
pub fn new(ys: Vec<V>, ydots: Vec<V>, ts: Vec<V::T>) -> Self {
HermiteInterpolator { ys, ydots, ts }
}
pub fn reset<Eqn, Method, State>(
&mut self,
problem: &OdeSolverProblem<Eqn>,
Expand Down Expand Up @@ -125,22 +128,24 @@ where
mut solver: Method,
start_idx: usize,
checkpoints: Vec<Method::State>,
segment: Option<HermiteInterpolator<Eqn::V>>,
) -> Self {
if checkpoints.len() < 2 {
panic!("Checkpoints must have at least 2 elements");
}
if start_idx >= checkpoints.len() - 1 {
panic!("start_idx must be less than checkpoints.len() - 1");
}
let mut segment = HermiteInterpolator::default();
segment
.reset(
let segment = segment.unwrap_or_else(|| {
let mut segment = HermiteInterpolator::default();
segment.reset(
problem,
&mut solver,
&checkpoints[start_idx],
&checkpoints[start_idx + 1],
)
.unwrap();
).unwrap();
segment
});
let segment = RefCell::new(segment);
let previous_segment = RefCell::new(None);
let solver = RefCell::new(solver);
Expand Down Expand Up @@ -212,7 +217,7 @@ mod tests {
OdeSolverMethod, OdeSolverState, Op, Vector,
};

use super::Checkpointing;
use super::{Checkpointing, HermiteInterpolator};

#[test]
fn test_checkpointing() {
Expand All @@ -224,17 +229,25 @@ mod tests {
solver.set_problem(state0.clone(), &problem).unwrap();
let mut checkpoints = vec![state0];
let mut i = 0;
let mut ys = Vec::new();
let mut ts = Vec::new();
let mut ydots = Vec::new();
while solver.state().unwrap().t < t_final {
ts.push(solver.state().unwrap().t);
ys.push(solver.state().unwrap().y.clone());
ydots.push(solver.state().unwrap().dy.clone());
solver.step().unwrap();
i += 1;
if i % n_steps == 0 {
if i % n_steps == 0 && solver.state().unwrap().t < t_final {
checkpoints.push(solver.checkpoint().unwrap());
ts.clear();
ys.clear();
ydots.clear();
}
}
if i % n_steps != 0 {
checkpoints.push(solver.checkpoint().unwrap());
}
let checkpointer = Checkpointing::new(&problem, solver, checkpoints.len() - 2, checkpoints);
checkpoints.push(solver.checkpoint().unwrap());
let segment = HermiteInterpolator::new(ys, ydots, ts);
let checkpointer = Checkpointing::new(&problem, solver, checkpoints.len() - 2, checkpoints, Some(segment));
let mut y = DVector::zeros(problem.eqn.rhs().nstates());
for point in soln.solution_points.iter().rev() {
checkpointer.interpolate(point.t, &mut y).unwrap();
Expand Down
3 changes: 3 additions & 0 deletions src/ode_solver/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ use crate::{
OdeSolverState, Op, SensEquations, StateRef, StateRefMut, VectorViewMut,
};

use super::checkpointing::HermiteInterpolator;

#[derive(Debug, PartialEq)]
pub enum OdeSolverStopReason<T: Scalar> {
InternalTimestep,
Expand Down Expand Up @@ -274,6 +276,7 @@ where
fn new_adjoint_solver(
&self,
checkpoints: Vec<Self::State>,
last_segment: HermiteInterpolator<Eqn::V>,
include_in_error_control: bool,
) -> Result<Self::AdjointSolver, DiffsolError>;
}
21 changes: 16 additions & 5 deletions src/ode_solver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ mod tests {
use std::rc::Rc;

use self::problem::OdeSolverSolution;
use checkpointing::HermiteInterpolator;
use method::{AdjointOdeSolverMethod, SensitivitiesOdeSolverMethod};
use nalgebra::ComplexField;

Expand Down Expand Up @@ -242,13 +243,22 @@ mod tests {
let mut nsteps = 0;
let (rtol, atol) = (solution.rtol, &solution.atol);
let mut checkpoints = vec![method.checkpoint().unwrap()];
let mut ts = Vec::new();
let mut ys = Vec::new();
let mut ydots = Vec::new();
for point in solution.solution_points.iter() {
while method.state().unwrap().t.abs() < point.t.abs() {
ts.push(method.state().unwrap().t);
ys.push(method.state().unwrap().y.clone());
ydots.push(method.state().unwrap().dy.clone());
method.step().unwrap();
nsteps += 1;
if nsteps > 50 {
if nsteps > 50 && method.state().unwrap().t.abs() < t1.abs() {
checkpoints.push(method.checkpoint().unwrap());
nsteps = 0;
ts.clear();
ys.clear();
ydots.clear();
}
}
let soln = method.interpolate_out(point.t).unwrap();
Expand All @@ -264,17 +274,18 @@ mod tests {
point.state
);
}
ts.push(method.state().unwrap().t);
ys.push(method.state().unwrap().y.clone());
ydots.push(method.state().unwrap().dy.clone());
checkpoints.push(method.checkpoint().unwrap());
let mut adjoint_solver = method.new_adjoint_solver(checkpoints, true).unwrap();
let last_segment = HermiteInterpolator::new(ys, ydots, ts);
let mut adjoint_solver = method.new_adjoint_solver(checkpoints, last_segment, true).unwrap();
let y_expect = M::V::from_element(problem.eqn.rhs().nstates(), M::T::zero());
adjoint_solver
.state()
.unwrap()
.y
.assert_eq_st(&y_expect, M::T::from(1e-9));
//for i in 0..problem.eqn.out().unwrap().nout() {
// adjoint_solver.state().unwrap().s[i].assert_eq_st(&y_expect, M::T::from(1e-9));
//}
let g_expect = M::V::from_element(problem.eqn.rhs().nparams(), M::T::zero());
for i in 0..problem.eqn.out().unwrap().nout() {
adjoint_solver.state().unwrap().sg[i].assert_eq_st(&g_expect, M::T::from(1e-9));
Expand Down

0 comments on commit 36e6ab9

Please sign in to comment.