Automatic Differentiation
 
Loading...
Searching...
No Matches
trace_gen_quad_form.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_FUN_TRACE_GEN_QUAD_FORM_HPP
2#define STAN_MATH_PRIM_FUN_TRACE_GEN_QUAD_FORM_HPP
3
11#include <exception>
12
13namespace stan {
14namespace math {
15
32template <typename TD, typename TA, typename TB,
33 typename = require_all_eigen_t<TD, TA, TB>,
34 typename = require_all_not_vt_var<TD, TA, TB>,
35 typename = require_any_not_vt_arithmetic<TD, TA, TB>>
36inline auto trace_gen_quad_form(TD&& D, TA&& A, TB&& B) {
37 check_square("trace_gen_quad_form", "A", A);
38 check_square("trace_gen_quad_form", "D", D);
39 check_multiplicable("trace_gen_quad_form", "A", A, "B", B);
40 check_multiplicable("trace_gen_quad_form", "B", B, "D", D);
41 decltype(auto) B_ref = to_ref(std::forward<TB>(B));
42 return make_holder(
43 [](auto&& D_, auto&& A_, auto&& B_ref_) {
44 return multiply(B_ref_, D_.transpose())
45 .cwiseProduct(multiply(A_, B_ref_))
46 .sum();
47 },
48 std::forward<TD>(D), std::forward<TA>(A),
49 std::forward<decltype(B_ref)>(B_ref));
50}
51
70template <typename EigMatD, typename EigMatA, typename EigMatB,
71 require_all_eigen_vt<std::is_arithmetic, EigMatD, EigMatA,
72 EigMatB>* = nullptr>
73inline double trace_gen_quad_form(const EigMatD& D, const EigMatA& A,
74 const EigMatB& B) {
75 check_square("trace_gen_quad_form", "A", A);
76 check_square("trace_gen_quad_form", "D", D);
77 check_multiplicable("trace_gen_quad_form", "A", A, "B", B);
78 check_multiplicable("trace_gen_quad_form", "B", B, "D", D);
79 const auto& B_ref = to_ref(B);
80 return (B_ref * D.transpose()).cwiseProduct(A * B_ref).sum();
81}
82
83} // namespace math
84} // namespace stan
85
86#endif
require_all_t< container_type_check_base< is_eigen, value_type_t, TypeCheck, Check >... > require_all_eigen_vt
Require that all of the types satisfy is_eigen.
Definition is_eigen.hpp:191
void check_square(const char *function, const char *name, const T_y &y)
Check if the specified matrix is square.
auto multiply(Mat1 &&m1, Mat2 &&m2)
Return the product of the specified matrices.
Definition multiply.hpp:20
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
auto make_holder(F &&func, Args &&... args)
Calls given function with given arguments.
Definition holder.hpp:481
auto trace_gen_quad_form(TD &&D, TA &&A, TB &&B)
Return the trace of D times the quadratic form of B and A.
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:18
The lgamma implementation in stan-math is based on either the reentrant-safe lgamma_r implementation ...