Automatic Differentiation
 
Loading...
Searching...
No Matches
trace_dot.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_TRACE_DOT_HPP
2#define STAN_MATH_REV_FUN_TRACE_DOT_HPP
3
10
11namespace stan {
12namespace math {
13
33template <typename Mat1, typename Mat2,
34 require_all_matrix_t<Mat1, Mat2>* = nullptr,
35 require_any_rev_matrix_t<Mat1, Mat2>* = nullptr>
36inline var trace_dot(Mat1&& A, Mat2&& B) {
37 check_size_match("trace_dot", "A.cols()", A.cols(), "B.rows()", B.rows());
38 check_size_match("trace_dot", "A.rows()", A.rows(), "B.cols()", B.cols());
39 if constexpr (is_autodiff_v<Mat1> && is_autodiff_v<Mat2>) {
40 arena_t<Mat1> arena_A(std::forward<Mat1>(A));
41 arena_t<Mat2> arena_B(std::forward<Mat2>(B));
42 auto res_val = arena_A.val().cwiseProduct(arena_B.val().transpose()).sum();
43 return make_callback_var(res_val, [arena_A, arena_B](auto&& res) mutable {
44 if constexpr (is_var_matrix<Mat1>::value) {
45 arena_A.adj().noalias() += res.adj() * arena_B.val().transpose();
46 } else {
47 arena_A.adj() += res.adj() * arena_B.val().transpose();
48 }
49 if constexpr (is_var_matrix<Mat2>::value) {
50 arena_B.adj().noalias() += res.adj() * arena_A.val().transpose();
51 } else {
52 arena_B.adj() += res.adj() * arena_A.val().transpose();
53 }
54 });
55 } else if constexpr (is_autodiff_v<Mat2>) {
56 arena_t<Mat1> arena_A(std::forward<Mat1>(A));
57 arena_t<Mat2> arena_B(std::forward<Mat2>(B));
58 auto res_val = arena_A.cwiseProduct(arena_B.val().transpose()).sum();
59 return make_callback_var(res_val, [arena_A, arena_B](auto&& res) mutable {
60 if constexpr (is_var_matrix<Mat2>::value) {
61 arena_B.adj().noalias() += res.adj() * arena_A.transpose();
62 } else {
63 arena_B.adj() += res.adj() * arena_A.transpose();
64 }
65 });
66 } else {
67 arena_t<Mat1> arena_A(std::forward<Mat1>(A));
68 arena_t<Mat2> arena_B(std::forward<Mat2>(B));
69 auto res_val = arena_A.val().cwiseProduct(arena_B.transpose()).sum();
70 return make_callback_var(res_val, [arena_A, arena_B](auto&& res) mutable {
71 if constexpr (is_var_matrix<Mat1>::value) {
72 arena_A.adj().noalias() += res.adj() * arena_B.transpose();
73 } else {
74 arena_A.adj() += res.adj() * arena_B.transpose();
75 }
76 });
77 }
78}
79
80} // namespace math
81} // namespace stan
82#endif
return_type_t< EigMat1, EigMat2 > trace_dot(EigMat1 &&A, EigMat2 &&B)
Compute the trace of the product of two matrices with forward-mode autodiff support.
Definition trace_dot.hpp:29
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback functor.
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Check if a type is a var_value whose value_type is derived from Eigen::EigenBase