Automatic Differentiation
 
Loading...
Searching...
No Matches
multiply.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_MULTIPLY_HPP
2#define STAN_MATH_REV_FUN_MULTIPLY_HPP
3
8#include <type_traits>
9
10namespace stan {
11namespace math {
12
// Reverse-mode product of two matrices, res = A * B.
//
// Selected when both arguments are matrices, the promoted return type is a
// var (require_return_type_t<is_var, T1, T2>), and the arguments are NOT a
// row-vector/column-vector pair (that pair goes to the overload below that
// returns a scalar var via .dot).
//
// The forward pass multiplies values only and stores the result in arena
// memory. The callbacks registered with reverse_pass_callback then apply the
// chain rule for res = A * B, updating only the non-constant operands:
//   adj(A) += adj(res) * val(B)^T
//   adj(B) += val(A)^T * adj(res)
//
// @tparam T1 type of the left matrix
// @tparam T2 type of the right matrix
// @param A left matrix (forwarded into arena storage when non-constant)
// @param B right matrix (forwarded into arena storage when non-constant)
// @return res, an arena_t of return_var_matrix_t<...>, which picks between
//   var<Matrix> and Matrix<var> based on the argument types
// NOTE(review): check_multiplicable presumably throws when the inner
//   dimensions do not agree; the exception type is not visible here.
//
// NOTE(review): this listing elides source lines 31, 40, 42, 53 and 64 (the
// embedded line numbers jump). As a result, the guard of the first branch,
// the call that opens the first callback, the inner branch condition, and
// the declarations of the constant operand (arena_A in the second branch,
// arena_B in the third) are not shown.
26template <typename T1, typename T2, require_all_matrix_t<T1, T2>* = nullptr,
27 require_return_type_t<is_var, T1, T2>* = nullptr,
28 require_not_row_and_col_vector_t<T1, T2>* = nullptr>
29inline auto multiply(T1&& A, T2&& B) {
30 check_multiplicable("multiply", "A", A, "B", B);
 // First branch: both A and B are promoted to var and receive adjoints.
 // Presumably elided line 31 guards "neither argument is constant" -- confirm.
 // Values are cached in arena storage so the callback can capture them.
32 arena_t<promote_scalar_t<var, T1>> arena_A(std::forward<T1>(A));
33 arena_t<promote_scalar_t<var, T2>> arena_B(std::forward<T2>(B));
34 auto arena_A_val = to_arena(arena_A.val());
35 auto arena_B_val = to_arena(arena_B.val());
36 using return_t
37 = return_var_matrix_t<decltype(arena_A_val * arena_B_val), T1, T2>;
38 arena_t<return_t> res = arena_A_val * arena_B_val;
39
 // Elided line 40 presumably opens the reverse_pass_callback( call closed
 // by "});" below -- confirm.
41 [arena_A, arena_B, arena_A_val, arena_B_val, res]() mutable {
 // Elided line 42 holds this branch's condition. This path multiplies by
 // res.adj_op() directly. The else path evaluates res.adj() once into
 // res_adj and reuses it for both products.
43 arena_A.adj() += res.adj_op() * arena_B_val.transpose();
44 arena_B.adj() += arena_A_val.transpose() * res.adj_op();
45 } else {
46 auto res_adj = res.adj().eval();
47 arena_A.adj() += res_adj * arena_B_val.transpose();
48 arena_B.adj() += arena_A_val.transpose() * res_adj;
49 }
50 });
51 return res;
52 } else if (!is_constant<T2>::value) {
 // Only B is non-constant here, so only B's adjoint is updated. arena_A is
 // declared on elided line 53 (presumably A's values -- confirm).
54 arena_t<promote_scalar_t<var, T2>> arena_B(std::forward<T2>(B));
 // decltype is unevaluated, so naming B after it was forwarded on the
 // previous line does not read a moved-from object.
55 using return_t
56 = return_var_matrix_t<decltype(arena_A * value_of(B).eval()), T1, T2>;
57 arena_t<return_t> res = arena_A * arena_B.val_op();
 // Chain rule for B only: adj(B) += A^T * adj(res).
58 reverse_pass_callback([arena_B, arena_A, res]() mutable {
59 arena_B.adj() += arena_A.transpose() * res.adj_op();
60 });
61 return res;
62 } else {
 // T2 is constant here, and the var-return constraint means A must carry
 // the var. arena_B is declared on elided line 64.
63 arena_t<promote_scalar_t<var, T1>> arena_A(std::forward<T1>(A));
65 using return_t
66 = return_var_matrix_t<decltype(value_of(arena_A).eval() * arena_B), T1,
67 T2>;
68 arena_t<return_t> res = arena_A.val_op() * arena_B;
 // Chain rule for A only: adj(A) += adj(res) * B^T.
69 reverse_pass_callback([arena_A, arena_B, res]() mutable {
70 arena_A.adj() += res.adj_op() * arena_B.transpose();
71 });
72
73 return res;
74 }
75}
76
// Reverse-mode product of a row vector A and a column vector B, returning
// the scalar var res = A . B (computed with .dot on the values).
//
// Selected when both arguments are matrices, the return type is a var, and
// they form a row-vector/column-vector pair (require_row_and_col_vector_t).
// The adjoint updates scale the other operand's values by the scalar adjoint
// of res, and only the non-constant operands are updated:
//   adj(A) += adj(res) * val(B)^T
//   adj(B) += val(A)^T * adj(res)
//
// @tparam T1 row-vector type (B is transposed to match A in the callback)
// @tparam T2 column-vector type
// @param A row vector
// @param B column vector
// @return scalar var holding the dot product of the values
//
// NOTE(review): this listing elides source lines 92-94, 99, 107-108 and
// 115-116. As a result, the guard of the first branch, the declarations of
// arena_A / arena_B / arena_A_val / arena_B_val used in the branches, and
// the call that opens the first callback are not shown.
87template <typename T1, typename T2, require_all_matrix_t<T1, T2>* = nullptr,
88 require_return_type_t<is_var, T1, T2>* = nullptr,
89 require_row_and_col_vector_t<T1, T2>* = nullptr>
90inline var multiply(const T1& A, const T2& B) {
91 check_multiplicable("multiply", "A", A, "B", B);
 // First branch: both operands receive adjoints (guard on an elided line,
 // presumably "neither argument is constant" -- confirm). Their values are
 // cached as double arena vectors for the callback.
95 arena_t<promote_scalar_t<double, T1>> arena_A_val = value_of(arena_A);
96 arena_t<promote_scalar_t<double, T2>> arena_B_val = value_of(arena_B);
97 var res = arena_A_val.dot(arena_B_val);
98
 // Elided line 99 presumably opens reverse_pass_callback( -- confirm.
100 [arena_A, arena_B, arena_A_val, arena_B_val, res]() mutable {
 // res is a scalar, so .array() turns each update into a coefficient-wise
 // scaling of the other operand's values.
101 auto res_adj = res.adj();
102 arena_A.adj().array() += res_adj * arena_B_val.transpose().array();
103 arena_B.adj().array() += arena_A_val.transpose().array() * res_adj;
104 });
105 return res;
106 } else if (!is_constant<T2>::value) {
 // Only B is non-constant here. arena_A_val (declared on an elided line)
 // holds A's values.
109 var res = arena_A_val.dot(value_of(arena_B));
110 reverse_pass_callback([arena_B, arena_A_val, res]() mutable {
111 arena_B.adj().array() += arena_A_val.transpose().array() * res.adj();
112 });
113 return res;
114 } else {
 // Only A is non-constant here. arena_B_val (declared on an elided line)
 // holds B's values.
117 var res = value_of(arena_A).dot(arena_B_val);
118 reverse_pass_callback([arena_A, arena_B_val, res]() mutable {
119 arena_A.adj().array() += res.adj() * arena_B_val.transpose().array();
120 });
121 return res;
122 }
123}
124
// Reverse-mode product of a scalar a and a matrix B, res = a * B.
//
// Selected when a is not a matrix, B is a matrix, the return type is a var,
// and the arguments are not a row/column vector pair. The result type is
// return_var_matrix_t<T2, T1, T2>, i.e. it is derived from B's type. Only
// the non-constant operands receive adjoints, following the chain rule:
//   adj(a) += sum(adj(res) .* val(B))
//   adj(B) += val(a) * adj(res)
//
// @tparam T1 non-matrix (scalar) type of a
// @tparam T2 matrix type of B
// @param a scalar factor
// @param B matrix (forwarded into arena storage when non-constant)
// @return res, an arena_t of return_var_matrix_t<T2, T1, T2>
//
// NOTE(review): this listing elides source lines 141 and 168, so the guard
// of the first branch and the declaration of arena_B in the last branch are
// not shown.
136template <typename T1, typename T2, require_not_matrix_t<T1>* = nullptr,
137 require_matrix_t<T2>* = nullptr,
138 require_return_type_t<is_var, T1, T2>* = nullptr,
139 require_not_row_and_col_vector_t<T1, T2>* = nullptr>
140inline auto multiply(const T1& a, T2&& B) {
 // First branch: both a and B receive adjoints (guard on elided line 141,
 // presumably "neither argument is constant" -- confirm). a_val is a's value
 // and is captured by copy into the callback.
142 arena_t<promote_scalar_t<var, T2>> arena_B(std::forward<T2>(B));
143 using return_t = return_var_matrix_t<T2, T1, T2>;
144 var av = a;
145 auto a_val = value_of(av);
146 arena_t<return_t> res = a_val * arena_B.val().array();
 // A single pass over res (column j outer, row i inner) accumulates both
 // the scalar adjoint of a and the element adjoints of B.
147 reverse_pass_callback([av, a_val, arena_B, res]() mutable {
148 for (Eigen::Index j = 0; j < res.cols(); ++j) {
149 for (Eigen::Index i = 0; i < res.rows(); ++i) {
150 const auto res_adj = res.adj().coeffRef(i, j);
151 av.adj() += res_adj * arena_B.val().coeff(i, j);
152 arena_B.adj().coeffRef(i, j) += a_val * res_adj;
153 }
154 }
155 });
156 return res;
157 } else if (!is_constant<T2>::value) {
 // Only B is non-constant here, so a is reduced to its double value.
158 double val_a = value_of(a);
159 arena_t<promote_scalar_t<var, T2>> arena_B(std::forward<T2>(B));
160 using return_t = return_var_matrix_t<T2, T1, T2>;
161 arena_t<return_t> res = val_a * arena_B.val().array();
162 reverse_pass_callback([val_a, arena_B, res]() mutable {
163 arena_B.adj().array() += val_a * res.adj().array();
164 });
165 return res;
166 } else {
 // Only a is non-constant here. arena_B is declared on elided line 168
 // (presumably B's double values -- confirm).
167 var av = a;
169 using return_t = return_var_matrix_t<T2, T1, T2>;
170 arena_t<return_t> res = av.val() * arena_B.array();
 // adj(a) is the sum of adj(res) weighted elementwise by B.
171 reverse_pass_callback([av, arena_B, res]() mutable {
172 av.adj() += (res.adj().array() * arena_B.array()).sum();
173 });
174 return res;
175 }
176}
177
// Reverse-mode product of a matrix A and a non-matrix (scalar) B.
//
// Scalar multiplication commutes, so this forwards both arguments, swapped,
// to multiply(B, A). That call presumably resolves to the scalar-first
// overload above.
// Constraints: A is a matrix, B is not, at least one scalar type is var
// (require_any_st_var), neither value type is complex, and the arguments are
// not a row/column vector pair.
//
// NOTE(review): the scalar-first overload is constrained with
// require_return_type_t<is_var, T1, T2>, while this overload uses
// require_any_st_var plus explicit non-complex checks. Confirm that the two
// constraint sets are meant to admit matching argument combinations.
//
// @param A matrix
// @param B scalar
// @return the result of multiply(B, A)
189template <typename T1, typename T2, require_matrix_t<T1>* = nullptr,
190 require_not_matrix_t<T2>* = nullptr,
191 require_any_st_var<T1, T2>* = nullptr,
192 require_not_complex_t<value_type_t<T1>>* = nullptr,
193 require_not_complex_t<value_type_t<T2>>* = nullptr,
194 require_not_row_and_col_vector_t<T1, T2>* = nullptr>
195inline auto multiply(T1&& A, T2&& B) {
196 return multiply(std::forward<T2>(B), std::forward<T1>(A));
197}
198
// Operator form of multiply, enabled when at least one argument is a var
// matrix (require_any_var_matrix_t). Presumably "var matrix" means a
// var_value wrapping an Eigen type -- confirm.
// Both arguments are perfectly forwarded, unchanged and in order, to
// multiply.
//
// @param a left operand
// @param b right operand
// @return the result of multiply(a, b)
209template <typename T1, typename T2, require_any_var_matrix_t<T1, T2>* = nullptr>
210inline auto operator*(T1&& a, T2&& b) {
211 return multiply(std::forward<T1>(a), std::forward<T2>(b));
212}
213
214} // namespace math
215} // namespace stan
216#endif
fvar< T > operator*(const fvar< T > &x, const fvar< T > &y)
Return the product of the two arguments.
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
auto multiply(const Mat1 &m1, const Mat2 &m2)
Return the product of the specified matrices.
Definition multiply.hpp:19
arena_t< T > to_arena(const T &a)
Converts given argument into a type that either has any dynamic allocation on AD stack or schedules its destructor to be called when AD stack memory is recovered.
Definition to_arena.hpp:25
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
std::conditional_t< is_any_var_matrix< ReturnType, Types... >::value, stan::math::var_value< stan::math::promote_scalar_t< double, plain_type_t< ReturnType > > >, stan::math::promote_scalar_t< stan::math::var_value< double >, plain_type_t< ReturnType > > > return_var_matrix_t
Given an Eigen type and several inputs, determine if a matrix should be var<Matrix> or Matrix<var>.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the C++ const sense).
Check if a type is a var_value whose value_type is derived from Eigen::EigenBase