Automatic Differentiation
 
Loading...
Searching...
No Matches
idas_service.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUNCTOR_IDAS_SERVICE_HPP
2#define STAN_MATH_REV_FUNCTOR_IDAS_SERVICE_HPP
3
8#include <idas/idas.h>
9#include <nvector/nvector_serial.h>
10#include <sunmatrix/sunmatrix_dense.h>
11#include <sunlinsol/sunlinsol_dense.h>
12#include <sundials/sundials_context.h>
13#include <ostream>
14#include <vector>
15#include <algorithm>
16#include <type_traits>
17
18namespace stan {
19namespace math {
20
31template <typename dae_type>
// SUNDIALS context handed to every N_VNew_Serial / IDACreate /
// SUNDenseMatrix / SUNLinSol_Dense call in the constructor. Declared first
// so it is initialized before the members below, whose initializers use it.
33 sundials::Context sundials_context_;
// Number of sensitivity parameters, copied from dae.ns.
34 int ns;
// State vector y and its time derivative y', each of length dae.N, seeded
// from dae.dbl_yy / dae.dbl_yp. Freed in the destructor.
35 N_Vector nv_yy;
36 N_Vector nv_yp;
// Arrays of ns sensitivity vectors for y and y'. Start as nullptr; the
// forward-sensitivity overload of idas_sens_init fills its reference
// arguments by cloning nv_yy / nv_yp (presumably these members are what is
// passed in — confirm). Freed in the destructor only when
// dae_type::use_fwd_sens.
37 N_Vector* nv_yys;
38 N_Vector* nv_yps;
// IDAS solver memory obtained from IDACreate; released with IDAFree.
39 void* mem;
// Dense dae.N x dae.N matrix and the dense linear solver attached to IDAS
// via IDASetLinearSolver.
40 SUNMatrix A;
41 SUNLinearSolver LS;
42
// Set up the IDAS solver for `dae`, starting at time t0.
// Allocates (all through sundials_context_):
//   - the length-dae.N state and derivative vectors,
//   - the IDAS memory,
//   - a dense dae.N x dae.N matrix and the dense linear solver.
// Then it copies dae.dbl_yy / dae.dbl_yp into nv_yy / nv_yp, calls IDAInit
// with dae_type::idas_res, registers &dae as IDAS user data and attaches the
// dense linear solver.
//
// @param t0  initial time passed to IDAInit.
// @param dae DAE system; its address is stored via IDASetUserData, so it
//            presumably must outlive this object — confirm with callers.
//
// NOTE(review): if CHECK_IDAS_CALL reports failure by throwing, the
// destructor does not run for a partially constructed object, so the raw
// handles allocated in the initializer list (vectors, mem, A, LS) would
// leak. Confirm CHECK_IDAS_CALL's behavior.
49 idas_service(double t0, dae_type& dae)
51 ns(dae.ns),
52 nv_yy(N_VNew_Serial(dae.N, sundials_context_)),
53 nv_yp(N_VNew_Serial(dae.N, sundials_context_)),
54 nv_yys(nullptr),
55 nv_yps(nullptr),
56 mem(IDACreate(sundials_context_)),
57 A(SUNDenseMatrix(dae.N, dae.N, sundials_context_)),
// LS is built from nv_yy and A. Both are declared before LS, so they are
// already initialized when this initializer runs.
58 LS(SUNLinSol_Dense(nv_yy, A, sundials_context_)) {
59 const int n = dae.N;
// Copy the initial state and initial derivative into the serial vectors.
60 for (auto i = 0; i < n; ++i) {
61 NV_Ith_S(nv_yy, i) = dae.dbl_yy[i];
62 NV_Ith_S(nv_yp, i) = dae.dbl_yp[i];
63 }
64
65 CHECK_IDAS_CALL(IDAInit(mem, dae_type::idas_res, t0, nv_yy, nv_yp));
// The residual callbacks receive &dae back through the user-data pointer.
66 CHECK_IDAS_CALL(IDASetUserData(mem, static_cast<void*>(&dae)));
67 CHECK_IDAS_CALL(IDASetLinearSolver(mem, LS, A));
68
70 }
71
// Release every SUNDIALS resource created by the constructor: the linear
// solver and matrix, the IDAS memory, then the state/derivative vectors.
// The sensitivity vector arrays exist only in the forward-sensitivity
// configuration, so they are destroyed only when dae_type::use_fwd_sens.
// NOTE(review): this class holds raw SUNDIALS handles and defines a
// destructor, but no copy constructor / copy assignment is visible. An
// implicit copy would share these handles and free them twice; consider
// `= delete`-ing the copy (and move) operations.
73 SUNLinSolFree(LS);
74 SUNMatDestroy(A);
75 IDAFree(&mem);
76 N_VDestroy(nv_yy);
77 N_VDestroy(nv_yp);
78 if (dae_type::use_fwd_sens) {
79 N_VDestroyVectorArray(nv_yys, ns);
80 N_VDestroyVectorArray(nv_yps, ns);
81 }
82 }
83
84 template <typename dae_t = dae_type,
85 std::enable_if_t<!dae_t::use_fwd_sens>* = nullptr>
86 void idas_sens_init(N_Vector* yys, N_Vector* yps, int ns, int n) {}
87
// Forward-sensitivity overload.
// - Clones ns copies of nv_yy / nv_yp into yys / yps, written back through
//   the reference parameters.
// - Zeroes every clone.
// - Seeds the unit entries for initial-value parameters via set_init_sens.
// - Registers the vectors with IDASensInit, using the staggered corrector
//   (IDA_STAGGERED) and dae_type::idas_sens_res.
// @param yys out: array of ns sensitivity vectors for y.
// @param yps out: array of ns sensitivity vectors for y'.
// @param ns  number of sensitivity parameters (shadows the member `ns`).
// @param n   state dimension, forwarded to set_init_sens.
88 template <typename dae_t = dae_type,
89 std::enable_if_t<dae_t::use_fwd_sens>* = nullptr>
90 void idas_sens_init(N_Vector*& yys, N_Vector*& yps, int ns, int n) {
91 yys = N_VCloneVectorArray(ns, nv_yy);
92 yps = N_VCloneVectorArray(ns, nv_yp);
// Cloning reproduces the vector layout, not necessarily its values, so
// start every sensitivity at zero explicitly.
// NOTE(review): `is` is size_t while `ns` is int (signed/unsigned compare).
93 for (size_t is = 0; is < ns; ++is) {
94 N_VConst(RCONST(0.0), yys[is]);
95 N_VConst(RCONST(0.0), yps[is]);
96 }
97 set_init_sens(yys, yps, n);
// Hand the initial sensitivities and the sensitivity residual to IDAS.
99 IDASensInit(mem, ns, IDA_STAGGERED, dae_type::idas_sens_res, yys, yps));
100 }
101
102 template <typename dae_t = dae_type,
103 std::enable_if_t<dae_t::is_var_yy0 && dae_t::is_var_yp0>* = nullptr>
104 void set_init_sens(N_Vector*& yys, N_Vector*& yps, int n) {
105 for (size_t i = 0; i < n; ++i) {
106 NV_Ith_S(yys[i], i) = 1.0;
107 }
108 for (size_t i = 0; i < n; ++i) {
109 NV_Ith_S(yps[i + n], i) = 1.0;
110 }
111 }
112
113 template <
114 typename dae_t = dae_type,
115 std::enable_if_t<dae_t::is_var_yy0 && (!dae_t::is_var_yp0)>* = nullptr>
116 void set_init_sens(N_Vector*& yys, N_Vector*& yps, int n) {
117 for (size_t i = 0; i < n; ++i) {
118 NV_Ith_S(yys[i], i) = 1.0;
119 }
120 }
121
122 template <
123 typename dae_t = dae_type,
124 std::enable_if_t<(!dae_t::is_var_yy0) && dae_t::is_var_yp0>* = nullptr>
125 void set_init_sens(N_Vector*& yys, N_Vector*& yps, int n) {
126 for (size_t i = 0; i < n; ++i) {
127 NV_Ith_S(yps[i], i) = 1.0;
128 }
129 }
130
131 template <
132 typename dae_t = dae_type,
133 std::enable_if_t<(!dae_t::is_var_yy0) && (!dae_t::is_var_yp0)>* = nullptr>
134 void set_init_sens(N_Vector*& yys, N_Vector*& yps, int n) {}
135};
136} // namespace math
137} // namespace stan
138
139#endif
#define CHECK_IDAS_CALL(call)
std::vector< Eigen::Matrix< stan::return_type_t< T_yy, T_yp, T_Args... >, -1, 1 > > dae(const F &f, const T_yy &yy0, const T_yp &yp0, double t0, const std::vector< double > &ts, std::ostream *msgs, const T_Args &... args)
Solve the DAE initial value problem f(t, y, y')=0, y(t0) = yy0, y'(t0)=yp0 at a set of times,...
Definition dae.hpp:167
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
void idas_sens_init(N_Vector *yys, N_Vector *yps, int ns, int n)
void set_init_sens(N_Vector *&yys, N_Vector *&yps, int n)
sundials::Context sundials_context_
void idas_sens_init(N_Vector *&yys, N_Vector *&yps, int ns, int n)
idas_service(double t0, dae_type &dae)
Construct IDAS DAE mem & workspace.
For each type of DAE (with a different residual functor F and sensitivity parameters), we allocate mem and workspace...