Automatic Differentiation
 
Loading...
Searching...
No Matches
ode_adjoint.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUNCTOR_ODE_ADJOINT_HPP
2#define STAN_MATH_REV_FUNCTOR_ODE_ADJOINT_HPP
3
7#include <ostream>
8#include <vector>
9
10namespace stan {
11namespace math {
12
66template <typename F, typename T_y0, typename T_t0, typename T_ts,
67 typename T_abs_tol_fwd, typename T_abs_tol_bwd, typename... T_Args,
68 require_all_eigen_col_vector_t<T_y0, T_abs_tol_fwd,
69 T_abs_tol_bwd>* = nullptr,
70 require_any_not_st_arithmetic<T_y0, T_t0, T_ts, T_Args...>* = nullptr>
72 const char* function_name, F&& f, const T_y0& y0, const T_t0& t0,
73 const std::vector<T_ts>& ts, double relative_tolerance_forward,
74 const T_abs_tol_fwd& absolute_tolerance_forward,
75 double relative_tolerance_backward,
76 const T_abs_tol_bwd& absolute_tolerance_backward,
77 double relative_tolerance_quadrature, double absolute_tolerance_quadrature,
78 long int max_num_steps, // NOLINT(runtime/int)
79 long int num_steps_between_checkpoints, // NOLINT(runtime/int)
80 int interpolation_polynomial, int solver_forward, int solver_backward,
81 std::ostream* msgs, const T_Args&... args) {
82 using integrator_vari
85 auto integrator = new integrator_vari(
86 function_name, std::forward<F>(f), eval(y0), t0, ts,
87 relative_tolerance_forward, absolute_tolerance_forward,
88 relative_tolerance_backward, absolute_tolerance_backward,
89 relative_tolerance_quadrature, absolute_tolerance_quadrature,
90 max_num_steps, num_steps_between_checkpoints, interpolation_polynomial,
91 solver_forward, solver_backward, msgs, eval(args)...);
92 return integrator->solution();
93}
94
150template <typename F, typename T_y0, typename T_t0, typename T_ts,
151 typename T_abs_tol_fwd, typename T_abs_tol_bwd, typename... T_Args,
152 require_all_eigen_col_vector_t<T_y0, T_abs_tol_fwd,
153 T_abs_tol_bwd>* = nullptr,
154 require_all_st_arithmetic<T_y0, T_t0, T_ts, T_Args...>* = nullptr>
155std::vector<Eigen::VectorXd> ode_adjoint_impl(
156 const char* function_name, F&& f, const T_y0& y0, const T_t0& t0,
157 const std::vector<T_ts>& ts, double relative_tolerance_forward,
158 const T_abs_tol_fwd& absolute_tolerance_forward,
159 double relative_tolerance_backward,
160 const T_abs_tol_bwd& absolute_tolerance_backward,
161 double relative_tolerance_quadrature, double absolute_tolerance_quadrature,
162 long int max_num_steps, // NOLINT(runtime/int)
163 long int num_steps_between_checkpoints, // NOLINT(runtime/int)
164 int interpolation_polynomial, int solver_forward, int solver_backward,
165 std::ostream* msgs, const T_Args&... args) {
166 std::vector<Eigen::VectorXd> ode_solution;
167 {
168 nested_rev_autodiff nested;
169
170 using integrator_vari
173
174 auto integrator = new integrator_vari(
175 function_name, std::forward<F>(f), eval(y0), t0, ts,
176 relative_tolerance_forward, absolute_tolerance_forward,
177 relative_tolerance_backward, absolute_tolerance_backward,
178 relative_tolerance_quadrature, absolute_tolerance_quadrature,
179 max_num_steps, num_steps_between_checkpoints, interpolation_polynomial,
180 solver_forward, solver_backward, msgs, eval(args)...);
181
182 ode_solution = integrator->solution();
183 }
184 return ode_solution;
185}
186
239template <typename F, typename T_y0, typename T_t0, typename T_ts,
240 typename T_abs_tol_fwd, typename T_abs_tol_bwd, typename... T_Args,
241 require_all_eigen_col_vector_t<T_y0, T_abs_tol_fwd,
242 T_abs_tol_bwd>* = nullptr>
244 F&& f, const T_y0& y0, const T_t0& t0, const std::vector<T_ts>& ts,
245 double relative_tolerance_forward,
246 const T_abs_tol_fwd& absolute_tolerance_forward,
247 double relative_tolerance_backward,
248 const T_abs_tol_bwd& absolute_tolerance_backward,
249 double relative_tolerance_quadrature, double absolute_tolerance_quadrature,
250 long int max_num_steps, // NOLINT(runtime/int)
251 long int num_steps_between_checkpoints, // NOLINT(runtime/int)
252 int interpolation_polynomial, int solver_forward, int solver_backward,
253 std::ostream* msgs, const T_Args&... args) {
254 return ode_adjoint_impl(
255 "ode_adjoint_tol_ctl", std::forward<F>(f), y0, t0, ts,
256 relative_tolerance_forward, absolute_tolerance_forward,
257 relative_tolerance_backward, absolute_tolerance_backward,
258 relative_tolerance_quadrature, absolute_tolerance_quadrature,
259 max_num_steps, num_steps_between_checkpoints, interpolation_polynomial,
260 solver_forward, solver_backward, msgs, args...);
261}
262
263} // namespace math
264} // namespace stan
265#endif
Integrator interface for CVODES' adjoint ODE solvers (Adams & BDF methods).
A class following the RAII idiom to start and recover nested autodiff scopes.
require_all_t< std::is_arithmetic< scalar_type_t< std::decay_t< Types > > >... > require_all_st_arithmetic
Require that all of the scalar types satisfy std::is_arithmetic.
require_any_not_t< std::is_arithmetic< scalar_type_t< std::decay_t< Types > > >... > require_any_not_st_arithmetic
Require that at least one of the scalar types does not satisfy std::is_arithmetic.
require_all_t< is_eigen_col_vector< std::decay_t< Types > >... > require_all_eigen_col_vector_t
Require that all of the types satisfy is_eigen_col_vector.
auto ode_adjoint_tol_ctl(F &&f, const T_y0 &y0, const T_t0 &t0, const std::vector< T_ts > &ts, double relative_tolerance_forward, const T_abs_tol_fwd &absolute_tolerance_forward, double relative_tolerance_backward, const T_abs_tol_bwd &absolute_tolerance_backward, double relative_tolerance_quadrature, double absolute_tolerance_quadrature, long int max_num_steps, long int num_steps_between_checkpoints, int interpolation_polynomial, int solver_forward, int solver_backward, std::ostream *msgs, const T_Args &... args)
Solve the ODE initial value problem $y' = f(t, y)$, $y(t_0) = y_0$, at a set of times $\{t_1, t_2, \ldots\}$.
auto ode_adjoint_impl(const char *function_name, F &&f, const T_y0 &y0, const T_t0 &t0, const std::vector< T_ts > &ts, double relative_tolerance_forward, const T_abs_tol_fwd &absolute_tolerance_forward, double relative_tolerance_backward, const T_abs_tol_bwd &absolute_tolerance_backward, double relative_tolerance_quadrature, double absolute_tolerance_quadrature, long int max_num_steps, long int num_steps_between_checkpoints, int interpolation_polynomial, int solver_forward, int solver_backward, std::ostream *msgs, const T_Args &... args)
Solve the ODE initial value problem $y' = f(t, y)$, $y(t_0) = y_0$, at a set of times $\{t_1, t_2, \ldots\}$.
T eval(T &&arg)
Inputs whose plain_type is the same as their own type are forwarded unmodified (for Eigen expressions ...
Definition eval.hpp:20
typename plain_type< T >::type plain_type_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...