Automatic Differentiation
 
Loading...
Searching...
No Matches
multiply_log.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_MULTIPLY_LOG_HPP
2#define STAN_MATH_REV_FUN_MULTIPLY_LOG_HPP
3
12#include <cmath>
13
14namespace stan {
15namespace math {
16
17namespace internal {
19 public:
21 : op_vv_vari(multiply_log(avi->val_, bvi->val_), avi, bvi) {}
22 void chain() {
23 using std::log;
24 avi_->adj_ += adj_ * log(bvi_->val_);
25 bvi_->adj_ += adj_ * avi_->val_ / bvi_->val_;
26 }
27};
29 public:
30 multiply_log_vd_vari(vari* avi, double b)
31 : op_vd_vari(multiply_log(avi->val_, b), avi, b) {}
32 void chain() {
33 using std::log;
34 avi_->adj_ += adj_ * log(bd_);
35 }
36};
38 public:
39 multiply_log_dv_vari(double a, vari* bvi)
40 : op_dv_vari(multiply_log(a, bvi->val_), a, bvi) {}
41 void chain() { bvi_->adj_ += adj_ * ad_ / bvi_->val_; }
42};
43} // namespace internal
44
56inline var multiply_log(const var& a, const var& b) {
57 return var(new internal::multiply_log_vv_vari(a.vi_, b.vi_));
58}
69inline var multiply_log(const var& a, double b) {
70 return var(new internal::multiply_log_vd_vari(a.vi_, b));
71}
82inline var multiply_log(double a, const var& b) {
83 if (a == 1.0) {
84 return log(b);
85 }
86 return var(new internal::multiply_log_dv_vari(a, b.vi_));
87}
88
// multiply_log overload for two matrix arguments, selected when both T1 and
// T2 satisfy require_all_matrix_t and at least one satisfies
// require_any_var_matrix_t.
//
// a, b: the two matrix operands; check_matching_dims is called on them
//       before anything else.
// Returns the result of make_callback_var, whose callback accumulates the
// operand adjoints from the result's adjoint.
//
// NOTE(review): this listing jumps from source line 104 to 108. The omitted
// lines presumably open the "both operands non-constant" branch and declare
// arena_a / arena_b for it -- confirm against the full header.
101template <typename T1, typename T2, require_all_matrix_t<T1, T2>* = nullptr,
102 require_any_var_matrix_t<T1, T2>* = nullptr>
103inline auto multiply_log(const T1& a, const T2& b) {
104 check_matching_dims("multiply_log", "a", a, "b", b);
108
// Branch whose callback updates both operands:
//   adj(a) += adj(res) * log(b)   and   adj(b) += adj(res) * a / b,
// all elementwise (array expressions).
109 return make_callback_var(
110 multiply_log(arena_a.val(), arena_b.val()),
111 [arena_a, arena_b](const auto& res) mutable {
112 arena_a.adj().array()
113 += res.adj().array() * arena_b.val().array().log();
114 arena_b.adj().array() += res.adj().array() * arena_a.val().array()
115 / arena_b.val().array();
116 });
// T1 is not constant here: b is stored through value_of, and only a's adjoint
// is updated.
117 } else if (!is_constant<T1>::value) {
118 arena_t<promote_scalar_t<var, T1>> arena_a = a;
119 arena_t<promote_scalar_t<double, T2>> arena_b = value_of(b);
120
121 return make_callback_var(multiply_log(arena_a.val(), arena_b),
122 [arena_a, arena_b](const auto& res) mutable {
123 arena_a.adj().array()
124 += res.adj().array()
125 * arena_b.val().array().log();
126 });
// Remaining case: a is stored through value_of, and only b's adjoint is
// updated.
127 } else {
128 arena_t<promote_scalar_t<double, T1>> arena_a = value_of(a);
129 arena_t<promote_scalar_t<var, T2>> arena_b = b;
130
131 return make_callback_var(multiply_log(arena_a, arena_b.val()),
132 [arena_a, arena_b](const auto& res) mutable {
133 arena_b.adj().array() += res.adj().array()
134 * arena_a.val().array()
135 / arena_b.val().array();
136 });
137 }
138}
139
149template <typename T1, typename T2, require_var_matrix_t<T1>* = nullptr,
150 require_stan_scalar_t<T2>* = nullptr>
151inline auto multiply_log(const T1& a, const T2& b) {
152 using std::log;
153
154 if (!is_constant<T1>::value && !is_constant<T2>::value) {
155 arena_t<promote_scalar_t<var, T1>> arena_a = a;
156 var arena_b = b;
157
158 return make_callback_var(
159 multiply_log(arena_a.val(), arena_b.val()),
160 [arena_a, arena_b](const auto& res) mutable {
161 arena_a.adj().array() += res.adj().array() * log(arena_b.val());
162 arena_b.adj() += (res.adj().array() * arena_a.val().array()).sum()
163 / arena_b.val();
164 });
165 } else if (!is_constant<T1>::value) {
166 arena_t<promote_scalar_t<var, T1>> arena_a = a;
167
168 return make_callback_var(multiply_log(arena_a.val(), value_of(b)),
169 [arena_a, b](const auto& res) mutable {
170 arena_a.adj().array()
171 += res.adj().array() * log(value_of(b));
172 });
173 } else {
174 arena_t<promote_scalar_t<double, T1>> arena_a = value_of(a);
175 var arena_b = b;
176
177 return make_callback_var(
178 multiply_log(arena_a, arena_b.val()),
179 [arena_a, arena_b](const auto& res) mutable {
180 arena_b.adj()
181 += (res.adj().array() * arena_a.array()).sum() / arena_b.val();
182 });
183 }
184}
185
195template <typename T1, typename T2, require_stan_scalar_t<T1>* = nullptr,
196 require_var_matrix_t<T2>* = nullptr>
197inline auto multiply_log(const T1& a, const T2& b) {
198 if (!is_constant<T1>::value && !is_constant<T2>::value) {
199 var arena_a = a;
200 arena_t<promote_scalar_t<var, T2>> arena_b = b;
201
202 return make_callback_var(
203 multiply_log(arena_a.val(), arena_b.val()),
204 [arena_a, arena_b](const auto& res) mutable {
205 arena_a.adj()
206 += (res.adj().array() * arena_b.val().array().log()).sum();
207 arena_b.adj().array()
208 += arena_a.val() * res.adj().array() / arena_b.val().array();
209 });
210 } else if (!is_constant<T1>::value) {
211 var arena_a = a;
212 arena_t<promote_scalar_t<double, T2>> arena_b = value_of(b);
213
214 return make_callback_var(
215 multiply_log(arena_a.val(), arena_b),
216 [arena_a, arena_b](const auto& res) mutable {
217 arena_a.adj()
218 += (res.adj().array() * arena_b.val().array().log()).sum();
219 });
220 } else {
221 arena_t<promote_scalar_t<var, T2>> arena_b = b;
222
223 return make_callback_var(multiply_log(value_of(a), arena_b.val()),
224 [a, arena_b](const auto& res) mutable {
225 arena_b.adj().array() += value_of(a)
226 * res.adj().array()
227 / arena_b.val().array();
228 });
229 }
230}
231
232} // namespace math
233} // namespace stan
234#endif
fvar< T > multiply_log(const fvar< T > &x1, const fvar< T > &x2)
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback functor.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
fvar< T > log(const fvar< T > &x)
Definition log.hpp:15
void check_matching_dims(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the two containers have the same dimensions.
var_value< double > var
Definition var.hpp:1187
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the C++ const sense).