Skip to content

Commit

Permalink
add tests for integrate out with solve and solve_dense
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Oct 19, 2024
1 parent 2b47142 commit cd23c33
Showing 1 changed file with 88 additions and 22 deletions.
110 changes: 88 additions & 22 deletions src/ode_solver/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ where
.ceil() as usize,
);
let mut ret_y = <<Eqn::V as DefaultDenseMatrix>::M as Matrix>::zeros(nstates, ntimes_guess);
let mut write_out = |t: Eqn::T, y: &Eqn::V| {
let mut write_out = |t: Eqn::T, y: &Eqn::V, g: &Eqn::V| {
ret_t.push(t);
let mut y_i = {
let max_i = ret_y.ncols();
Expand All @@ -144,21 +144,37 @@ where
};
match problem.eqn.out() {
Some(out) => {
y_i.copy_from(&out.call(y, t))
if problem.integrate_out {
y_i.copy_from(g);
} else {
y_i.copy_from(&out.call(y, t))
}
}
None => y_i.copy_from(y),
}
};

// do the main loop
write_out(self.state().unwrap().t, self.state().unwrap().y);
write_out(
self.state().unwrap().t,
self.state().unwrap().y,
self.state().unwrap().g,
);
self.set_stop_time(final_time)?;
while self.step()? != OdeSolverStopReason::TstopReached {
write_out(self.state().unwrap().t, self.state().unwrap().y);
write_out(
self.state().unwrap().t,
self.state().unwrap().y,
self.state().unwrap().g,
);
}

// store the final step
write_out(self.state().unwrap().t, self.state().unwrap().y);
write_out(
self.state().unwrap().t,
self.state().unwrap().y,
self.state().unwrap().g,
);
Ok((ret_y, ret_t))
}

Expand Down Expand Up @@ -186,6 +202,18 @@ where
return Err(ode_solver_error!(InvalidTEval));
}

let mut write_out = |i: usize, y: &Eqn::V, g: Option<&Eqn::V>| {
let mut y_out = ret.column_mut(i);
if let Some(g) = g {
y_out.copy_from(g);
} else {
match problem.eqn.out() {
Some(out) => y_out.copy_from(&out.call(y, t_eval[i])),
None => y_out.copy_from(y),
}
}
};

// do loop
self.set_stop_time(t_eval[t_eval.len() - 1])?;
let mut step_reason = OdeSolverStopReason::InternalTimestep;
Expand All @@ -194,25 +222,26 @@ where
step_reason = self.step()?;
}
let y = self.interpolate(*t)?;
let mut y_out = ret.column_mut(i);
match problem.eqn.out() {
Some(out) => y_out.copy_from(&out.call(&y, *t)),
None => y_out.copy_from(&y),
if problem.integrate_out {
let g = self.interpolate_out(*t)?;
write_out(i, &y, Some(&g));
} else {
write_out(i, &y, None);
}
}

// do final step
while step_reason != OdeSolverStopReason::TstopReached {
step_reason = self.step()?;
}
{
let mut y_out = ret.column_mut(t_eval.len() - 1);
match problem.eqn.out() {
Some(out) => {
y_out.copy_from(&out.call(self.state().unwrap().y, self.state().unwrap().t))
}
None => y_out.copy_from(self.state().unwrap().y),
}
if problem.integrate_out {
write_out(
t_eval.len() - 1,
self.state().unwrap().y,
Some(self.state().unwrap().g),
);
} else {
write_out(t_eval.len() - 1, self.state().unwrap().y, None);
}
Ok(ret)
}
Expand Down Expand Up @@ -313,10 +342,13 @@ where
}
}


#[cfg(test)]
mod test {
use crate::{ode_solver::test_models::exponential_decay::exponential_decay_problem, scale, Bdf, OdeSolverMethod, OdeSolverState, Vector};
use crate::{
ode_solver::test_models::exponential_decay::exponential_decay_problem,
ode_solver::test_models::exponential_decay::exponential_decay_problem_adjoint, scale, Bdf,
OdeSolverMethod, OdeSolverState, Vector,
};

#[test]
fn test_solve() {
Expand All @@ -325,9 +357,7 @@ mod test {

let k = 0.1;
let y0 = nalgebra::DVector::from_vec(vec![1.0, 1.0]);
let expect = |t: f64| {
&y0 * scale(f64::exp(-k * t))
};
let expect = |t: f64| &y0 * scale(f64::exp(-k * t));
let state = OdeSolverState::new(&problem, &s).unwrap();
let (y, t) = s.solve(&problem, state, 10.0).unwrap();
assert!((t[0] - 0.0).abs() < 1e-10);
Expand All @@ -338,6 +368,29 @@ mod test {
}
}

#[test]
fn test_solve_integrate_out() {
let mut s = Bdf::default();
let (problem, _soln) = exponential_decay_problem_adjoint::<nalgebra::DMatrix<f64>>();

let k = 0.1;
let y0 = nalgebra::DVector::from_vec(vec![1.0, 1.0]);
let t0 = 0.0;
let expect = |t: f64| {
let g = &y0 * scale((f64::exp(-k * t0) - f64::exp(-k * t)) / k);
nalgebra::DVector::<f64>::from_vec(vec![
1.0 * g[0] + 2.0 * g[1],
3.0 * g[0] + 4.0 * g[1],
])
};
let state = OdeSolverState::new(&problem, &s).unwrap();
let (y, t) = s.solve(&problem, state, 10.0).unwrap();
for (i, t_i) in t.iter().enumerate() {
let y_i = y.column(i).into_owned();
y_i.assert_eq_norm(&expect(*t_i), problem.atol.as_ref(), problem.rtol, 15.0);
}
}

#[test]
fn test_dense_solve() {
let mut s = Bdf::default();
Expand All @@ -352,4 +405,17 @@ mod test {
}
}

#[test]
fn test_dense_solve_integrate_out() {
let mut s = Bdf::default();
let (problem, soln) = exponential_decay_problem_adjoint::<nalgebra::DMatrix<f64>>();

let state = OdeSolverState::new(&problem, &s).unwrap();
let t_eval = soln.solution_points.iter().map(|p| p.t).collect::<Vec<_>>();
let y = s.solve_dense(&problem, state, t_eval.as_slice()).unwrap();
for (i, soln_pt) in soln.solution_points.iter().enumerate() {
let y_i = y.column(i).into_owned();
y_i.assert_eq_norm(&soln_pt.state, problem.atol.as_ref(), problem.rtol, 15.0);
}
}
}

0 comments on commit cd23c33

Please sign in to comment.