1#ifndef STAN_MATH_REV_FUN_QUAD_FORM_HPP
2#define STAN_MATH_REV_FUN_QUAD_FORM_HPP
18template <
typename Ta,
int Ra,
int Ca,
typename Tb,
int Rb,
int Cb>
21 inline void compute(
const Eigen::Matrix<double, Ra, Ca>& A,
22 const Eigen::Matrix<double, Rb, Cb>& B) {
25 matrix_d M = 0.5 * (Cd + Cd.transpose());
28 for (
int j = 0; j <
C_.cols(); j++) {
29 for (
int i = 0; i <
C_.rows(); i++) {
30 C_(i, j) =
var(
new vari(Cd(i, j),
false));
37 const Eigen::Matrix<Tb, Rb, Cb>& B,
38 bool symmetric =
false)
43 Eigen::Matrix<Ta, Ra, Ca>
A_;
44 Eigen::Matrix<Tb, Rb, Cb>
B_;
45 Eigen::Matrix<var, Cb, Cb>
C_;
49template <
typename Ta,
int Ra,
int Ca,
typename Tb,
int Rb,
int Cb>
52 inline void chainA(Eigen::Matrix<double, Ra, Ca>& A,
53 const Eigen::Matrix<double, Rb, Cb>& Bd,
54 const Eigen::Matrix<double, Cb, Cb>& adjC) {}
55 inline void chainB(Eigen::Matrix<double, Rb, Cb>& B,
56 const Eigen::Matrix<double, Ra, Ca>& Ad,
57 const Eigen::Matrix<double, Rb, Cb>& Bd,
58 const Eigen::Matrix<double, Cb, Cb>& adjC) {}
60 inline void chainA(Eigen::Matrix<var, Ra, Ca>& A,
61 const Eigen::Matrix<double, Rb, Cb>& Bd,
62 const Eigen::Matrix<double, Cb, Cb>& adjC) {
63 A.adj() += Bd * adjC * Bd.transpose();
65 inline void chainB(Eigen::Matrix<var, Rb, Cb>& B,
66 const Eigen::Matrix<double, Ra, Ca>& Ad,
67 const Eigen::Matrix<double, Rb, Cb>& Bd,
68 const Eigen::Matrix<double, Cb, Cb>& adjC) {
69 B.adj() += Ad * Bd * adjC.transpose() + Ad.transpose() * Bd * adjC;
72 inline void chainAB(Eigen::Matrix<Ta, Ra, Ca>& A,
73 Eigen::Matrix<Tb, Rb, Cb>& B,
74 const Eigen::Matrix<double, Ra, Ca>& Ad,
75 const Eigen::Matrix<double, Rb, Cb>& Bd,
76 const Eigen::Matrix<double, Cb, Cb>& adjC) {
83 const Eigen::Matrix<Tb, Rb, Cb>& B,
bool symmetric =
false)
114template <
typename Mat1,
typename Mat2,
137 arena_res += arena_res.transpose().eval();
140 return_t res = arena_res;
146 arena_A.adj().noalias() +=
value_of(arena_B) * C_adj_B_t;
148 arena_A.adj() +=
value_of(arena_B) * C_adj_B_t;
152 arena_B.adj().noalias()
153 +=
value_of(arena_A) * C_adj_B_t.transpose()
157 +=
value_of(arena_A) * C_adj_B_t.transpose()
174 arena_res += arena_res.transpose().eval();
177 return_t res = arena_res;
183 arena_B.adj().noalias()
184 += arena_A * C_adj_B_t.transpose()
185 + arena_A.transpose() *
value_of(arena_B) * res.adj();
187 arena_B.adj() += arena_A * C_adj_B_t.transpose()
188 + arena_A.transpose() *
value_of(arena_B) * res.adj();
204 arena_res += arena_res.transpose().eval();
207 return_t res = arena_res;
210 auto C_adj_B_t = (res.adj() * arena_B.transpose());
213 arena_A.adj().noalias() += arena_B * C_adj_B_t;
215 arena_A.adj() += arena_B * C_adj_B_t;
240template <
typename EigMat1,
typename EigMat2,
246 bool symmetric =
false) {
253 EigMat2::RowsAtCompileTime, EigMat2::ColsAtCompileTime>(A, B, symmetric);
255 return baseVari->
impl_->C_;
271template <
typename EigMat,
typename ColVec, require_eigen_t<EigMat>* =
nullptr,
272 require_eigen_col_vector_t<ColVec>* =
nullptr,
273 require_any_vt_var<EigMat, ColVec>* =
nullptr>
274inline var quad_form(
const EigMat& A,
const ColVec& B,
bool symmetric =
false) {
281 ColVec::RowsAtCompileTime, 1>(A, B, symmetric);
283 return baseVari->
impl_->C_(0, 0);
307template <
typename Mat1,
typename Mat2,
311inline auto quad_form(
const Mat1& A,
const Mat2& B,
bool symmetric =
false) {
333template <
typename Mat,
typename Vec, require_matrix_t<Mat>* =
nullptr,
334 require_col_vector_t<Vec>* =
nullptr,
335 require_any_var_matrix_t<Mat, Vec>* =
nullptr>
336inline var quad_form(
const Mat& A,
const Vec& B,
bool symmetric =
false) {
A chainable_alloc is an object which is constructed and destructed normally but the memory lifespan i...
require_not_t< is_col_vector< std::decay_t< T > > > require_not_col_vector_t
Require type does not satisfy is_col_vector.
require_not_t< is_eigen_col_vector< std::decay_t< T > > > require_not_eigen_col_vector_t
Require type does not satisfy is_eigen_col_vector.
require_all_t< is_eigen< std::decay_t< Types > >... > require_all_eigen_t
Require all of the types satisfy is_eigen.
require_all_t< is_matrix< std::decay_t< Types > >... > require_all_matrix_t
Require all of the types satisfy is_matrix.
auto transpose(Arg &&a)
Transposes a kernel generator expression.
int64_t cols(const T_x &x)
Returns the number of columns in the specified kernel generator expression.
typename value_type< T >::type value_type_t
Helper function for accessing underlying type.
require_any_t< is_var_matrix< std::decay_t< Types > >... > require_any_var_matrix_t
Require any of the types satisfy is_var_matrix.
require_any_t< is_var< value_type_t< std::decay_t< Types > > >... > require_any_vt_var
Require any of the value_types satisfy is_var.
auto quad_form_impl(const Mat1 &A, const Mat2 &B, bool symmetric)
Return the quadratic form $B^T A B$.
void check_square(const char *function, const char *name, const T_y &y)
Check if the specified matrix is square.
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
typename promote_scalar_type< std::decay_t< T >, std::decay_t< S > >::type promote_scalar_t
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
vari_value< double > vari
promote_scalar_t< return_type_t< EigMat1, EigMat2 >, EigMat2 > quad_form(const EigMat1 &A, const EigMat2 &B)
Return the quadratic form $B^T A B$.
arena_t< T > to_arena(const T &a)
Converts given argument into a type that either has any dynamic allocation on AD stack or schedules i...
void check_not_nan(const char *function, const char *name, const T_y &y)
Check if y is not NaN.
Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic > matrix_d
Type for matrix of double values.
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
std::conditional_t< is_any_var_matrix< ReturnType, Types... >::value, stan::math::var_value< stan::math::promote_scalar_t< double, plain_type_t< ReturnType > > >, stan::math::promote_scalar_t< stan::math::var_value< double >, plain_type_t< ReturnType > > > return_var_matrix_t
Given an Eigen type and several inputs, determine if a matrix should be var<Matrix> or Matrix<var>.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the ...
Check if a type is a var_value whose value_type is derived from Eigen::EigenBase.