Skip to content

Commit

Permalink
Merge pull request #1072 from stan-dev/feature/issue-1071-ode-time-va…
Browse files Browse the repository at this point in the history
…rying

Feature/issue 1071 ode time varying
  • Loading branch information
wds15 authored May 10, 2019
2 parents 0784a82 + 45c9a65 commit ec52b8b
Show file tree
Hide file tree
Showing 17 changed files with 1,051 additions and 343 deletions.
137 changes: 124 additions & 13 deletions stan/math/prim/arr/functor/coupled_ode_observer.hpp
Original file line number Diff line number Diff line change
@@ -1,38 +1,149 @@
#ifndef STAN_MATH_PRIM_ARR_FUNCTOR_COUPLED_ODE_OBSERVER_HPP
#define STAN_MATH_PRIM_ARR_FUNCTOR_COUPLED_ODE_OBSERVER_HPP

#include <stan/math/prim/scal/err/check_size_match.hpp>
#include <stan/math/prim/scal/err/check_less.hpp>
#include <stan/math/prim/scal/meta/is_constant_struct.hpp>
#include <stan/math/prim/scal/meta/operands_and_partials.hpp>
#include <stan/math/prim/scal/meta/broadcast_array.hpp>
#include <stan/math/prim/arr/fun/sum.hpp>
#include <stan/math/prim/scal/meta/return_type.hpp>

#include <vector>

namespace stan {
namespace math {

/**
* Observer for the coupled states. Holds a reference to
* an externally defined vector of vectors passed in at
* construction time.
* Observer for the coupled states. Holds a reference to an
* externally defined vector of vectors passed in at construction time
* which holds the final result on the AD stack. Thus, whenever any of
* the inputs is varying, then the output will be varying as well. The
* sensitivities of the initials and the parameters are taken from the
* coupled state in the order as defined by the
* coupled_ode_system. The sensitivities at each time-point is simply
* the ODE RHS evaluated at that time point.
*
* The output of this class is for all time-points in the ts vector
* which does not contain the initial time-point by the convention
* used in stan-math.
*
*/
template <typename F, typename T1, typename T2, typename T_t0, typename T_ts>
struct coupled_ode_observer {
std::vector<std::vector<double> >& y_coupled_;
int n_;
typedef typename stan::return_type<T1, T2, T_t0, T_ts>::type return_t;

typedef operands_and_partials<std::vector<T1>, std::vector<T2>, T_t0, T_ts>
ops_partials_t;

const F& f_;
const std::vector<T1>& y0_;
const T_t0& t0_;
const std::vector<T_ts>& ts_;
const std::vector<T2>& theta_;
const std::vector<double>& x_;
const std::vector<int>& x_int_;
std::ostream* msgs_;
std::vector<std::vector<return_t>>& y_;
const std::size_t N_;
const std::size_t M_;
const std::size_t index_offset_theta_;
int next_ts_index_;

/**
* Construct a coupled ODE observer from the specified coupled
* Construct a coupled ODE observer for the specified coupled
* vector.
*
* @param y_coupled reference to a vector of vector of doubles.
* @tparam F type of ODE system function.
* @tparam T1 type of scalars for initial values.
* @tparam T2 type of scalars for parameters.
* @tparam T_t0 type of scalar of initial time point.
* @tparam T_ts type of time-points where ODE solution is returned.
* @param[in] f functor for the base ordinary differential equation.
* @param[in] y0 initial state.
* @param[in] theta parameter vector for the ODE.
* @param[in] t0 initial time.
* @param[in] ts times of the desired solutions, in strictly
* increasing order, all greater than the initial time.
* @param[in] x continuous data vector for the ODE.
* @param[in] x_int integer data vector for the ODE.
* @param[out] msgs the print stream for warning messages.
* @param[out] y reference to a vector of vector of the final return
*/
explicit coupled_ode_observer(std::vector<std::vector<double> >& y_coupled)
: y_coupled_(y_coupled), n_(0) {}
coupled_ode_observer(const F& f, const std::vector<T1>& y0,
const std::vector<T2>& theta, const T_t0& t0,
const std::vector<T_ts>& ts,
const std::vector<double>& x,
const std::vector<int>& x_int, std::ostream* msgs,
std::vector<std::vector<return_t>>& y)
: f_(f),
y0_(y0),
t0_(t0),
ts_(ts),
theta_(theta),
x_(x),
x_int_(x_int),
msgs_(msgs),
y_(y),
N_(y0.size()),
M_(theta.size()),
index_offset_theta_(is_constant_struct<T1>::value ? 0 : N_ * N_),
next_ts_index_(0) {}

/**
* Callback function for Boost's ODE solver to record values.
* Callback function for ODE solvers to record values. The coupled
* state returned from the solver is added directly to the AD tree.
*
* The coupled state follows the convention as defined in the
* coupled_ode_system. In brief, the coupled state consists of {f,
* df/dy0, df/dtheta}. Here df/dy0 and df/dtheta are only present if
* their respective sensitivites have been requested.
*
* @param coupled_state solution at the specified time.
* @param t time of solution.
* @param t time of solution. The time must correspond to the
* element ts[next_ts_index_]
*/
void operator()(const std::vector<double>& coupled_state, double t) {
y_coupled_[n_] = coupled_state;
n_++;
check_less("coupled_ode_observer", "time-state number", next_ts_index_,
ts_.size());

std::vector<return_t> yt;
yt.reserve(N_);

ops_partials_t ops_partials(y0_, theta_, t0_, ts_[next_ts_index_]);

std::vector<double> dy_dt;
if (!is_constant_struct<T_ts>::value) {
std::vector<double> y_dbl(coupled_state.begin(),
coupled_state.begin() + N_);
dy_dt = f_(value_of(ts_[next_ts_index_]), y_dbl, value_of(theta_), x_,
x_int_, msgs_);
check_size_match("coupled_ode_observer", "dy_dt", dy_dt.size(), "states",
N_);
}

for (size_t j = 0; j < N_; j++) {
// iterate over parameters for each equation
if (!is_constant_struct<T1>::value) {
for (std::size_t k = 0; k < N_; k++)
ops_partials.edge1_.partials_[k] = coupled_state[N_ + N_ * k + j];
}

if (!is_constant_struct<T2>::value) {
for (std::size_t k = 0; k < M_; k++)
ops_partials.edge2_.partials_[k]
= coupled_state[N_ + index_offset_theta_ + N_ * k + j];
}

if (!is_constant_struct<T_ts>::value) {
ops_partials.edge4_.partials_[0] = dy_dt[j];
}

yt.emplace_back(ops_partials.build(coupled_state[j]));
}

y_.emplace_back(yt);
next_ts_index_++;
}
};

Expand Down
16 changes: 0 additions & 16 deletions stan/math/prim/arr/functor/coupled_ode_system.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,22 +118,6 @@ class coupled_ode_system<F, double, double> {
state[n] = y0_dbl_[n];
return state;
}

/**
* Returns the states of the base system provided on construction.
*
* <p>In this class's implementation, the coupled system is
* equivalent to the base system.
*
* @param y the vector of coupled states after solving the ode. Each
* inner vector is size <code>size()</code>.
* @return the states of the base ode system corresponding to
* <code>y</code>. Each inner vector is size <code>N</code>.
*/
std::vector<std::vector<double> > decouple_states(
const std::vector<std::vector<double> >& y) const {
return y;
}
};

} // namespace math
Expand Down
62 changes: 40 additions & 22 deletions stan/math/prim/arr/functor/integrate_ode_rk45.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
#include <boost/serialization/array_wrapper.hpp>
#endif
#include <boost/numeric/odeint.hpp>
#include <algorithm>
#include <ostream>
#include <functional>
#include <iterator>
#include <vector>

namespace stan {
Expand Down Expand Up @@ -47,6 +50,8 @@ namespace math {
* @tparam F type of ODE system function.
* @tparam T1 type of scalars for initial values.
* @tparam T2 type of scalars for parameters.
* @tparam T_t0 type of scalar of initial time point.
* @tparam T_ts type of time-points where ODE solution is returned.
* @param[in] f functor for the base ordinary differential equation.
* @param[in] y0 initial state.
* @param[in] t0 initial time.
Expand All @@ -65,10 +70,10 @@ namespace math {
* @return a vector of states, each state being a vector of the
* same size as the state variable, corresponding to a time in ts.
*/
template <typename F, typename T1, typename T2>
std::vector<std::vector<typename stan::return_type<T1, T2>::type> >
integrate_ode_rk45(const F& f, const std::vector<T1>& y0, double t0,
const std::vector<double>& ts, const std::vector<T2>& theta,
template <typename F, typename T1, typename T2, typename T_t0, typename T_ts>
std::vector<std::vector<typename stan::return_type<T1, T2, T_t0, T_ts>::type>>
integrate_ode_rk45(const F& f, const std::vector<T1>& y0, const T_t0& t0,
const std::vector<T_ts>& ts, const std::vector<T2>& theta,
const std::vector<double>& x, const std::vector<int>& x_int,
std::ostream* msgs = nullptr,
double relative_tolerance = 1e-6,
Expand All @@ -78,16 +83,19 @@ integrate_ode_rk45(const F& f, const std::vector<T1>& y0, double t0,
using boost::numeric::odeint::max_step_checker;
using boost::numeric::odeint::runge_kutta_dopri5;

const double t0_dbl = value_of(t0);
const std::vector<double> ts_dbl = value_of(ts);

check_finite("integrate_ode_rk45", "initial state", y0);
check_finite("integrate_ode_rk45", "initial time", t0);
check_finite("integrate_ode_rk45", "times", ts);
check_finite("integrate_ode_rk45", "initial time", t0_dbl);
check_finite("integrate_ode_rk45", "times", ts_dbl);
check_finite("integrate_ode_rk45", "parameter vector", theta);
check_finite("integrate_ode_rk45", "continuous data", x);

check_nonzero_size("integrate_ode_rk45", "times", ts);
check_nonzero_size("integrate_ode_rk45", "initial state", y0);
check_ordered("integrate_ode_rk45", "times", ts);
check_less("integrate_ode_rk45", "initial time", t0, ts[0]);
check_nonzero_size("integrate_ode_rk45", "times", ts_dbl);
check_ordered("integrate_ode_rk45", "times", ts_dbl);
check_less("integrate_ode_rk45", "initial time", t0_dbl, ts_dbl[0]);

if (relative_tolerance <= 0)
invalid_argument("integrate_ode_rk45", "relative_tolerance,",
Expand All @@ -104,12 +112,25 @@ integrate_ode_rk45(const F& f, const std::vector<T1>& y0, double t0,

// first time in the vector must be time of initial state
std::vector<double> ts_vec(ts.size() + 1);
ts_vec[0] = t0;
for (size_t n = 0; n < ts.size(); n++)
ts_vec[n + 1] = ts[n];

std::vector<std::vector<double> > y_coupled(ts_vec.size());
coupled_ode_observer observer(y_coupled);
ts_vec[0] = t0_dbl;
std::copy(ts_dbl.begin(), ts_dbl.end(), ts_vec.begin() + 1);

std::vector<std::vector<typename stan::return_type<T1, T2, T_t0, T_ts>::type>>
y;
coupled_ode_observer<F, T1, T2, T_t0, T_ts> observer(f, y0, theta, t0, ts, x,
x_int, msgs, y);
bool observer_initial_recorded = false;

// avoid recording of the initial state which is included by the
// conventions of odeint in the output
auto filtered_observer
= [&](const std::vector<double>& coupled_state, double t) -> void {
if (!observer_initial_recorded) {
observer_initial_recorded = true;
return;
}
observer(coupled_state, t);
};

// the coupled system creates the coupled initial state
std::vector<double> initial_coupled_state = coupled_system.initial_state();
Expand All @@ -119,14 +140,11 @@ integrate_ode_rk45(const F& f, const std::vector<T1>& y0, double t0,
make_dense_output(absolute_tolerance, relative_tolerance,
runge_kutta_dopri5<std::vector<double>, double,
std::vector<double>, double>()),
boost::ref(coupled_system), initial_coupled_state, boost::begin(ts_vec),
boost::end(ts_vec), step_size, observer, max_step_checker(max_num_steps));

// remove the first state corresponding to the initial value
y_coupled.erase(y_coupled.begin());
std::ref(coupled_system), initial_coupled_state, std::begin(ts_vec),
std::end(ts_vec), step_size, filtered_observer,
max_step_checker(max_num_steps));

// the coupled system also encapsulates the decoupling operation
return coupled_system.decouple_states(y_coupled);
return y;
}

} // namespace math
Expand Down
Loading

0 comments on commit ec52b8b

Please sign in to comment.