1#ifndef STAN_MATH_REV_FUN_TRACE_GEN_QUAD_FORM_HPP
2#define STAN_MATH_REV_FUN_TRACE_GEN_QUAD_FORM_HPP
18template <
typename Td,
int Rd,
int Cd,
typename Ta,
int Ra,
int Ca,
typename Tb,
23 const Eigen::Matrix<Ta, Ra, Ca>& A,
24 const Eigen::Matrix<Tb, Rb, Cb>& B)
31 Eigen::Matrix<Td, Rd, Cd>
D_;
32 Eigen::Matrix<Ta, Ra, Ca>
A_;
33 Eigen::Matrix<Tb, Rb, Cb>
B_;
36template <
typename Td,
int Rd,
int Cd,
typename Ta,
int Ra,
int Ca,
typename Tb,
41 const Eigen::Matrix<double, Rd, Cd>& D,
42 const Eigen::Matrix<double, Ra, Ca>& A,
43 const Eigen::Matrix<double, Rb, Cb>& B,
44 Eigen::Matrix<var, Rd, Cd>* varD,
45 Eigen::Matrix<var, Ra, Ca>* varA,
46 Eigen::Matrix<var, Rb, Cb>* varB) {
47 Eigen::Matrix<double, Ca, Cb> AtB;
48 Eigen::Matrix<double, Ra, Cb> BD;
53 AtB.noalias() = A.transpose() * B;
57 (*varB).adj() += adj * (A * BD + AtB * D.transpose());
60 (*varA).adj() += adj * (B * BD.transpose());
63 (*varD).adj() += adj * (B.transpose() * AtB);
75 reinterpret_cast<Eigen::Matrix<var, Rd, Cd>*
>(
76 std::is_same<Td, var>::value ? (&
impl_->D_) : NULL),
77 reinterpret_cast<Eigen::Matrix<var, Ra, Ca>*
>(
78 std::is_same<Ta, var>::value ? (&
impl_->A_) : NULL),
79 reinterpret_cast<Eigen::Matrix<var, Rb, Cb>*
>(
80 std::is_same<Tb, var>::value ? (&
impl_->B_) : NULL));
87template <
typename Td,
typename Ta,
typename Tb,
95 constexpr int Rd = Td::RowsAtCompileTime;
96 constexpr int Cd = Td::ColsAtCompileTime;
97 constexpr int Ra = Ta::RowsAtCompileTime;
98 constexpr int Ca = Ta::ColsAtCompileTime;
99 constexpr int Rb = Tb::RowsAtCompileTime;
100 constexpr int Cb = Tb::ColsAtCompileTime;
108 Ra, Ca, Tb_scal, Rb, Cb>(
113 Tb_scal, Rb, Cb>(baseVari));
135template <
typename Td,
typename Ta,
typename Tb,
150 auto arena_BDT =
to_arena(arena_B.val_op() * arena_D.val_op().transpose());
151 auto arena_AB =
to_arena(arena_A.val_op() * arena_B.val_op());
153 var res = (arena_BDT.transpose() * arena_AB).
trace();
156 [arena_A, arena_B, arena_D, arena_BDT, arena_AB, res]()
mutable {
157 double C_adj = res.adj();
159 arena_A.adj() += C_adj * arena_BDT * arena_B.val_op().transpose();
161 arena_B.adj() += C_adj
162 * (arena_AB * arena_D.val_op()
163 + arena_A.val_op().transpose() * arena_BDT);
165 arena_D.adj() += C_adj * (arena_AB.transpose() * arena_B.val_op());
169 }
else if (!is_constant<Ta>::value && !is_constant<Tb>::value
170 && is_constant<Td>::value) {
171 arena_t<promote_scalar_t<double, Td>> arena_D =
value_of(D);
172 arena_t<promote_scalar_t<var, Ta>> arena_A = A;
173 arena_t<promote_scalar_t<var, Tb>> arena_B = B;
175 auto arena_BDT =
to_arena(arena_B.val_op() * arena_D.transpose());
176 auto arena_AB =
to_arena(arena_A.val_op() * arena_B.val_op());
178 var res = (arena_BDT.transpose() * arena_AB).
trace();
182 double C_adj = res.adj();
184 arena_A.adj() += C_adj * arena_BDT * arena_B.val_op().transpose();
187 * (arena_AB * arena_D + arena_A.val_op().transpose() * arena_BDT);
191 }
else if (!is_constant<Ta>::value && is_constant<Tb>::value
192 && !is_constant<Td>::value) {
193 arena_t<promote_scalar_t<var, Td>> arena_D = D;
194 arena_t<promote_scalar_t<var, Ta>> arena_A = A;
195 arena_t<promote_scalar_t<double, Tb>> arena_B =
value_of(B);
197 auto arena_BDT =
to_arena(arena_B.val_op() * arena_D.val_op().transpose());
198 auto arena_AB =
to_arena(arena_A.val_op() * arena_B.val_op());
200 var res = (arena_BDT.transpose() * arena_A.val_op() * arena_B).
trace();
203 [arena_A, arena_B, arena_D, arena_BDT, arena_AB, res]()
mutable {
204 double C_adj = res.adj();
206 arena_A.adj() += C_adj * arena_BDT * arena_B.transpose();
207 arena_D.adj() += C_adj * arena_AB.transpose() * arena_B;
211 }
else if (!is_constant<Ta>::value && is_constant<Tb>::value
212 && is_constant<Td>::value) {
213 arena_t<promote_scalar_t<double, Td>> arena_D =
value_of(D);
214 arena_t<promote_scalar_t<var, Ta>> arena_A = A;
215 arena_t<promote_scalar_t<double, Tb>> arena_B =
value_of(B);
217 auto arena_BDT =
to_arena(arena_B * arena_D);
219 var res = (arena_BDT.transpose() * arena_A.val_op() * arena_B).
trace();
222 arena_A.adj() += res.adj() * arena_BDT * arena_B.val_op().transpose();
226 }
else if (is_constant<Ta>::value && !is_constant<Tb>::value
227 && !is_constant<Td>::value) {
228 arena_t<promote_scalar_t<var, Td>> arena_D = D;
229 arena_t<promote_scalar_t<double, Ta>> arena_A =
value_of(A);
230 arena_t<promote_scalar_t<var, Tb>> arena_B = B;
232 auto arena_AB =
to_arena(arena_A * arena_B.val_op());
233 auto arena_BDT =
to_arena(arena_B.val_op() * arena_D.val_op());
235 var res = (arena_BDT.transpose() * arena_AB).
trace();
239 double C_adj = res.adj();
243 * (arena_AB * arena_D.val_op() + arena_A.transpose() * arena_BDT);
245 arena_D.adj() += C_adj * (arena_AB.transpose() * arena_B.val_op());
249 }
else if (is_constant<Ta>::value && !is_constant<Tb>::value
250 && is_constant<Td>::value) {
251 arena_t<promote_scalar_t<double, Td>> arena_D =
value_of(D);
252 arena_t<promote_scalar_t<double, Ta>> arena_A =
value_of(A);
253 arena_t<promote_scalar_t<var, Tb>> arena_B = B;
255 auto arena_AB =
to_arena(arena_A * arena_B.val_op());
256 auto arena_BDT =
to_arena(arena_B.val_op() * arena_D.val_op());
258 var res = (arena_BDT.transpose() * arena_AB).
trace();
261 [arena_A, arena_B, arena_D, arena_AB, arena_BDT, res]()
mutable {
262 arena_B.adj() += res.adj()
263 * (arena_AB * arena_D.val_op()
264 + arena_A.val_op().transpose() * arena_BDT);
268 }
else if (is_constant<Ta>::value && is_constant<Tb>::value
269 && !is_constant<Td>::value) {
270 arena_t<promote_scalar_t<var, Td>> arena_D = D;
271 arena_t<promote_scalar_t<double, Ta>> arena_A =
value_of(A);
272 arena_t<promote_scalar_t<double, Tb>> arena_B =
value_of(B);
274 auto arena_AB =
to_arena(arena_A * arena_B);
276 var res = (arena_D.val_op() * arena_B.transpose() * arena_AB).
trace();
279 arena_D.adj() += res.adj() * (arena_AB.transpose() * arena_B);
A chainable_alloc is an object which is constructed and destructed normally but the memory lifespan is managed along with the arena allocator for the gradient calculation.
require_all_t< is_eigen< std::decay_t< Types > >... > require_all_eigen_t
Require all of the types satisfy is_eigen.
require_all_t< is_matrix< std::decay_t< Types > >... > require_all_matrix_t
Require all of the types satisfy is_matrix.
typename value_type< T >::type value_type_t
Helper function for accessing underlying type.
require_any_t< is_var_matrix< std::decay_t< Types > >... > require_any_var_matrix_t
Require any of the types satisfy is_var_matrix.
require_any_t< is_var< std::decay_t< Types > >... > require_any_var_t
Require any of the types satisfy is_var.
void check_square(const char *function, const char *name, const T_y &y)
Check if the specified matrix is square.
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
value_type_t< T > trace(const T &m)
Calculates trace (sum of diagonal) of given kernel generator expression.
arena_t< T > to_arena(const T &a)
Converts given argument into a type that either has any dynamic allocation on AD stack or schedules its destructor to be called when AD stack memory is recovered.
auto trace_gen_quad_form(const TD &D, const TA &A, const TB &B)
Return the trace of D times the quadratic form of B and A.
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the C++ const sense).