Skip to content

Commit

Permalink
finished removing op params from solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Oct 16, 2024
1 parent 2b0bebd commit 2fd87e9
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 66 deletions.
8 changes: 3 additions & 5 deletions benches/ode_solvers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -547,8 +547,6 @@ criterion_main!(benches);

mod benchmarks {
use diffsol::matrix::MatrixRef;
use diffsol::op::bdf::BdfCallable;
use diffsol::op::sdirk::SdirkCallable;
use diffsol::vector::VectorRef;
use diffsol::LinearSolver;
use diffsol::{
Expand All @@ -560,7 +558,7 @@ mod benchmarks {
pub fn bdf<Eqn>(
problem: &OdeSolverProblem<Eqn>,
t: Eqn::T,
ls: impl LinearSolver<BdfCallable<Eqn>>,
ls: impl LinearSolver<Eqn::M>,
) where
Eqn: OdeEquationsImplicit,
Eqn::M: Matrix + DefaultSolver,
Expand All @@ -577,7 +575,7 @@ mod benchmarks {
pub fn esdirk34<Eqn>(
problem: &OdeSolverProblem<Eqn>,
t: Eqn::T,
linear_solver: impl LinearSolver<SdirkCallable<Eqn>>,
linear_solver: impl LinearSolver<Eqn::M>,
) where
Eqn: OdeEquationsImplicit,
Eqn::M: Matrix + DefaultSolver,
Expand All @@ -594,7 +592,7 @@ mod benchmarks {
pub fn tr_bdf2<Eqn>(
problem: &OdeSolverProblem<Eqn>,
t: Eqn::T,
linear_solver: impl LinearSolver<SdirkCallable<Eqn>>,
linear_solver: impl LinearSolver<Eqn::M>,
) where
Eqn: OdeEquationsImplicit,
Eqn::M: Matrix + DefaultSolver,
Expand Down
4 changes: 2 additions & 2 deletions src/nonlinear_solver/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

use std::rc::Rc;

use crate::{error::DiffsolError, Matrix, NonLinearOpJacobian};
use crate::{error::DiffsolError, Matrix, NonLinearOpJacobian, NonLinearOp};
use convergence::Convergence;

pub struct NonLinearSolveSolution<V> {
Expand Down Expand Up @@ -36,7 +36,7 @@ pub trait NonLinearSolver<M: Matrix>: Default {
}

/// Solve the problem `F(x) = 0` in place.
fn solve_in_place<C: NonLinearOpJacobian<V=M::V, T=M::T, M=M>>(&mut self, op: &C, x: &mut C::V, t: C::T, error_y: &C::V)
fn solve_in_place<C: NonLinearOp<V=M::V, T=M::T, M=M>>(&mut self, op: &C, x: &mut C::V, t: C::T, error_y: &C::V)
-> Result<(), DiffsolError>;

/// Solve the linearised problem `J * x = b`, where `J` was calculated using [Self::reset_jacobian].
Expand Down
4 changes: 2 additions & 2 deletions src/nonlinear_solver/newton.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::rc::Rc;
use crate::{
error::{DiffsolError, NonLinearSolverError},
non_linear_solver_error, Convergence, ConvergenceStatus, LinearSolver,
NonLinearOpJacobian, NonLinearSolver, Vector, Matrix,
NonLinearOpJacobian, NonLinearSolver, Vector, Matrix, NonLinearOp
};

pub fn newton_iteration<V: Vector>(
Expand Down Expand Up @@ -91,7 +91,7 @@ impl<M: Matrix, Ls: LinearSolver<M>> NonLinearSolver<M> for NewtonNonlinearSolve
self.linear_solver.solve_in_place(x)
}

fn solve_in_place<C: NonLinearOpJacobian<V=M::V, T=M::T, M=M>>(
fn solve_in_place<C: NonLinearOp<V=M::V, T=M::T, M=M>>(
&mut self,
op: &C,
xn: &mut M::V,
Expand Down
45 changes: 23 additions & 22 deletions src/ode_solver/bdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::{
op::bdf::BdfCallable,
scalar::scale,
vector::DefaultDenseMatrix,
AugmentedOdeEquations, BdfState, Checkpointing, DenseMatrix, IndexType, InitOp, JacobianUpdate,
AugmentedOdeEquations, BdfState, Checkpointing, DenseMatrix, IndexType, JacobianUpdate,
MatrixViewMut, NewtonNonlinearSolver, NonLinearOp, NonLinearSolver, OdeEquationsImplicit,
OdeSolverMethod, OdeSolverProblem, OdeSolverState, OdeSolverStopReason, Op, Scalar,
Vector, VectorRef, VectorView, VectorViewMut,
Expand Down Expand Up @@ -87,7 +87,7 @@ pub struct Bdf<
> {
nonlinear_solver: Nls,
ode_problem: Option<OdeSolverProblem<Eqn>>,
op: Option<Rc<BdfCallable<Eqn>>>,
op: Option<BdfCallable<Eqn>>,
n_equal_steps: usize,
y_delta: Eqn::V,
g_delta: Eqn::V,
Expand Down Expand Up @@ -259,12 +259,12 @@ where
//let y = &self.y_predict;
//let t = self.t_predict;
if self.jacobian_update.check_rhs_jacobian_update(c, &state) {
self.op.unwrap().set_jacobian_is_stale();
self.nonlinear_solver.reset_jacobian(y, t);
self.op.as_mut().unwrap().set_jacobian_is_stale();
self.nonlinear_solver.reset_jacobian(self.op.as_ref().unwrap(), y, t);
self.jacobian_update.update_rhs_jacobian();
self.jacobian_update.update_jacobian(c);
} else if self.jacobian_update.check_jacobian_update(c, &state) {
self.nonlinear_solver.reset_jacobian(y, t);
self.nonlinear_solver.reset_jacobian(self.op.as_ref().unwrap(), y, t);
self.jacobian_update.update_jacobian(c);
}
}
Expand Down Expand Up @@ -297,7 +297,7 @@ where
}
}

self.nonlinear_problem_op().set_c(new_h, self.alpha[order]);
self.op.as_mut().unwrap().set_c(new_h, self.alpha[order]);

self.state.as_mut().unwrap().h = new_h;

Expand Down Expand Up @@ -333,7 +333,7 @@ where
if self.ode_problem.as_ref().unwrap().integrate_out {
let out = self.ode_problem.as_ref().unwrap().eqn.out().unwrap();
out.call_inplace(&self.y_predict, self.t_predict, &mut state.dg);
self.nonlinear_solver.problem().f.integrate_out(
self.op.as_ref().unwrap().integrate_out(
&state.dg,
&state.gdiff,
self.gamma.as_slice(),
Expand Down Expand Up @@ -363,7 +363,7 @@ where
if op.eqn().integrate_out() {
let out = op.eqn().out().unwrap();
out.call_inplace(&state.s[i], self.t_predict, &mut state.dsg[i]);
self.nonlinear_solver.problem().f.integrate_out(
self.op.as_ref().unwrap().integrate_out(
&state.dsg[i],
&state.sgdiff[i],
self.gamma.as_slice(),
Expand Down Expand Up @@ -414,7 +414,7 @@ where
Self::_predict_using_diff(&mut self.y_predict, &state.diff, state.order);

// update psi and c (h, D, y0 has changed)
self.nonlinear_problem_op().set_psi_and_y0(
self.op.as_mut().unwrap().set_psi_and_y0(
&state.diff,
self.gamma.as_slice(),
self.alpha.as_slice(),
Expand Down Expand Up @@ -514,7 +514,7 @@ where

// update for new state
{
let dy_new = self.nonlinear_solver.problem().f.tmp();
let dy_new = self.op.as_ref().unwrap().tmp();
let y_new = &self.y_predict;
Rc::get_mut(op.eqn_mut())
.unwrap()
Expand Down Expand Up @@ -549,7 +549,7 @@ where
let s_new = &mut self.state.as_mut().unwrap().s[i];
s_new.copy_from(&self.s_predict);
self.nonlinear_solver
.solve_other_in_place(&*op, s_new, t_new, &self.s_predict)?;
.solve_in_place(&*op, s_new, t_new, &self.s_predict)?;
self.statistics.number_of_nonlinear_solver_iterations +=
self.nonlinear_solver.convergence().niter();
let s_new = &*s_new;
Expand Down Expand Up @@ -577,7 +577,7 @@ where
Eqn: OdeEquationsImplicit,
AugmentedEqn: AugmentedOdeEquations<Eqn> + OdeEquationsImplicit,
M: DenseMatrix<T = Eqn::T, V = Eqn::V>,
Nls: NonLinearSolver<BdfCallable<Eqn>>,
Nls: NonLinearSolver<Eqn::M>,
for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
{
Expand Down Expand Up @@ -702,15 +702,15 @@ where
state.check_consistent_with_problem(problem)?;

// setup linear solver for first step
let bdf_callable = Rc::new(BdfCallable::new(problem));
let bdf_callable = BdfCallable::new(problem);
bdf_callable.set_c(state.h, self.alpha[state.order]);

let nonlinear_problem = SolverProblem::new_from_ode_problem(bdf_callable, problem);
self.nonlinear_solver.set_problem(&nonlinear_problem);
self.nonlinear_solver.set_problem(&bdf_callable, problem.rtol, problem.atol.clone());
self.nonlinear_solver
.convergence_mut()
.set_max_iter(Self::NEWTON_MAXITER);
self.nonlinear_solver.reset_jacobian(&state.y, state.t);
self.nonlinear_solver.reset_jacobian(&bdf_callable, &state.y, state.t);
self.op = Some(bdf_callable);

// setup root solver
if let Some(root_fn) = problem.eqn.root() {
Expand Down Expand Up @@ -773,6 +773,7 @@ where

// solve BDF equation using y0 as starting point
let mut solve_result = self.nonlinear_solver.solve_in_place(
self.op.as_ref().unwrap(),
&mut self.y_delta,
self.t_predict,
&self.y_predict,
Expand Down Expand Up @@ -880,7 +881,7 @@ where

// update statistics
self.statistics.number_of_linear_solver_setups =
self.nonlinear_problem_op().number_of_jac_evals();
self.op.as_ref().unwrap().number_of_jac_evals();
self.statistics.number_of_steps += 1;
self.jacobian_update.step();

Expand Down Expand Up @@ -1025,7 +1026,7 @@ where
Eqn: OdeEquationsImplicit,
AugmentedEqn: AugmentedOdeEquations<Eqn> + OdeEquationsImplicit,
M: DenseMatrix<T = Eqn::T, V = Eqn::V>,
Nls: NonLinearSolver<BdfCallable<Eqn>>,
Nls: NonLinearSolver<Eqn::M>,
for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
{
Expand Down Expand Up @@ -1070,14 +1071,14 @@ where
Eqn: OdeEquationsAdjoint,
AugmentedEqn: AugmentedOdeEquations<Eqn> + OdeEquationsAdjoint,
M: DenseMatrix<T = Eqn::T, V = Eqn::V>,
Nls: NonLinearSolver<BdfCallable<Eqn>>,
Nls: NonLinearSolver<Eqn::M>,
for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
{
type AdjointSolver = Bdf<
M,
AdjointEquations<Eqn, Self>,
Nls::SelfNewOp<BdfCallable<AdjointEquations<Eqn, Self>>>,
Nls,
AdjointEquations<Eqn, Self>,
>;

Expand Down Expand Up @@ -1114,12 +1115,12 @@ where
// initialise adjoint state
let mut state =
Self::State::new_without_initialise_augmented(&adj_problem, &mut new_augmented_eqn)?;
let mut init_nls = Nls::SelfNewOp::<InitOp<AdjointEquations<Eqn, Self>>>::default();
let mut init_nls = Nls::default();
let new_augmented_eqn =
state.set_consistent_augmented(&adj_problem, new_augmented_eqn, &mut init_nls)?;

// create adjoint solver
let adjoint_nls = Nls::SelfNewOp::<BdfCallable<AdjointEquations<Eqn, Self>>>::default();
let adjoint_nls = Nls::default();
let mut adjoint_solver = Self::AdjointSolver::new(adjoint_nls);

// setup the solver
Expand Down
34 changes: 17 additions & 17 deletions src/ode_solver/sdirk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::SdirkState;
use crate::SensEquations;
use crate::Tableau;
use crate::{
nonlinear_solver::NonLinearSolver, op::sdirk::SdirkCallable, scale, solver::SolverProblem,
nonlinear_solver::NonLinearSolver, op::sdirk::SdirkCallable, scale,
AugmentedOdeEquations, DenseMatrix, JacobianUpdate, NonLinearOp, OdeEquations, OdeSolverMethod,
OdeSolverProblem, OdeSolverState, Op, Scalar, Vector, VectorViewMut, StateRef, StateRefMut,
OdeEquationsSens, OdeEquationsImplicit
Expand All @@ -43,7 +43,7 @@ pub type SdirkAdj<M, Eqn, LS> = Sdirk<
impl<M, Eqn, LS> SensitivitiesOdeSolverMethod<Eqn> for Sdirk<M, Eqn, LS, SensEquations<Eqn>>
where
M: DenseMatrix<T = Eqn::T, V = Eqn::V>,
LS: LinearSolver<SdirkCallable<Eqn>>,
LS: LinearSolver<Eqn::M>,
Eqn: OdeEquationsSens,
for<'a> &'a Eqn::V: VectorRef<Eqn::V>,
for<'a> &'a Eqn::M: MatrixRef<Eqn::M>,
Expand Down Expand Up @@ -308,7 +308,7 @@ where
if state.t + state.h > tstop + troundoff {
let factor = (tstop - state.t) / state.h;
state.h *= factor;
self.nonlinear_solver.problem().f.set_h(state.h);
self.op.as_mut().unwrap().set_h(state.h);
}
Ok(None)
}
Expand Down Expand Up @@ -350,7 +350,7 @@ where

// solve
let op = self.s_op.as_ref().unwrap();
self.nonlinear_solver.solve_other_in_place(op, ds, t, s0)?;
self.nonlinear_solver.solve_in_place(op, ds, t, s0)?;

self.old_y_sens[j].copy_from(&op.get_last_f_eval());
self.statistics.number_of_nonlinear_solver_iterations +=
Expand Down Expand Up @@ -394,14 +394,14 @@ where

fn _jacobian_updates(&mut self, h: Eqn::T, state: SolverState) {
if self.jacobian_update.check_rhs_jacobian_update(h, &state) {
self.nonlinear_solver.problem().f.set_jacobian_is_stale();
self.op.as_mut().unwrap().set_jacobian_is_stale();
self.nonlinear_solver
.reset_jacobian(&self.old_f, self.state.as_ref().unwrap().t);
.reset_jacobian(self.op.as_ref().unwrap(), &self.old_f, self.state.as_ref().unwrap().t);
self.jacobian_update.update_rhs_jacobian();
self.jacobian_update.update_jacobian(h);
} else if self.jacobian_update.check_jacobian_update(h, &state) {
self.nonlinear_solver
.reset_jacobian(&self.old_f, self.state.as_ref().unwrap().t);
.reset_jacobian(self.op.as_ref().unwrap(), &self.old_f, self.state.as_ref().unwrap().t);
self.jacobian_update.update_jacobian(h);
}
}
Expand All @@ -417,7 +417,7 @@ where
}

// update h for new step size
self.nonlinear_solver.problem().f.set_h(new_h);
self.op.as_mut().unwrap().set_h(new_h);

// update state
self.state.as_mut().unwrap().h = new_h;
Expand Down Expand Up @@ -463,19 +463,18 @@ where
problem: &OdeSolverProblem<Eqn>,
) -> Result<(), DiffsolError> {
// setup linear solver for first step
let callable = Rc::new(SdirkCallable::new(problem, self.gamma));
let callable = SdirkCallable::new(problem, self.gamma);
callable.set_h(state.h);
self.jacobian_update.update_jacobian(state.h);
self.jacobian_update.update_rhs_jacobian();
let nonlinear_problem = SolverProblem::new_from_ode_problem(callable.clone(), problem);
self.nonlinear_solver.set_problem(&nonlinear_problem);
self.op = Some(callable);
self.nonlinear_solver.set_problem(&callable, problem.rtol, problem.atol.clone());

// set max iterations for nonlinear solver
self.nonlinear_solver
.convergence_mut()
.set_max_iter(Self::NEWTON_MAXITER);
self.nonlinear_solver.reset_jacobian(&state.y, state.t);
self.nonlinear_solver.reset_jacobian(&callable, &state.y, state.t);
self.op = Some(callable);

// update statistics
self.statistics = BdfStatistics::default();
Expand Down Expand Up @@ -562,7 +561,7 @@ where

for i in start..self.tableau.s() {
let t = t0 + self.tableau.c()[i] * h;
self.nonlinear_solver.problem().f.set_phi(
self.op.as_mut().unwrap().set_phi(
&self.diff.columns(0, i),
&self.state.as_ref().unwrap().y,
&self.a_rows[i],
Expand All @@ -571,6 +570,7 @@ where
Self::predict_stage(i, &self.diff, &mut self.old_f, &self.tableau);

let mut solve_result = self.nonlinear_solver.solve_in_place(
self.op.as_ref().unwrap(),
&mut self.old_f,
t,
&self.state.as_ref().unwrap().y,
Expand All @@ -582,7 +582,7 @@ where
if solve_result.is_ok() {
// old_y now has the new y soln and old_f has the new dy soln
self.old_y
.copy_from(&self.nonlinear_solver.problem().f.get_last_f_eval());
.copy_from(&self.op.as_ref().unwrap().get_last_f_eval());
if self.s_op.is_some() {
solve_result = self.solve_for_sensitivities(i, t);
}
Expand Down Expand Up @@ -702,7 +702,7 @@ where

// update statistics
self.statistics.number_of_linear_solver_setups =
self.nonlinear_solver.problem().f.number_of_jac_evals();
self.op.as_ref().unwrap().number_of_jac_evals();
self.statistics.number_of_steps += 1;
self.jacobian_update.step();

Expand Down Expand Up @@ -886,7 +886,7 @@ where
impl<M, Eqn, AugmentedEqn, LS> AugmentedOdeSolverMethod<Eqn, AugmentedEqn>
for Sdirk<M, Eqn, LS, AugmentedEqn>
where
LS: LinearSolver<SdirkCallable<Eqn>>,
LS: LinearSolver<Eqn::M>,
M: DenseMatrix<T = Eqn::T, V = Eqn::V>,
Eqn: OdeEquationsImplicit,
AugmentedEqn: AugmentedOdeEquations<Eqn>,
Expand Down
Loading

0 comments on commit 2fd87e9

Please sign in to comment.