Automatic Differentiation
 
Loading...
Searching...
No Matches
lmultiply.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_LMULTPLY_HPP
2#define STAN_MATH_REV_FUN_LMULTPLY_HPP
3
13#include <cmath>
14
15namespace stan {
16namespace math {
17
18namespace internal {
20 public:
22 : op_vv_vari(lmultiply(avi->val_, bvi->val_), avi, bvi) {}
23 void chain() {
24 using std::log;
25 avi_->adj_ += adj_ * log(bvi_->val_);
26 bvi_->adj_ += adj_ * avi_->val_ / bvi_->val_;
27 }
28};
30 public:
31 lmultiply_vd_vari(vari* avi, double b)
32 : op_vd_vari(lmultiply(avi->val_, b), avi, b) {}
33 void chain() {
34 using std::log;
35 avi_->adj_ += adj_ * log(bd_);
36 }
37};
39 public:
40 lmultiply_dv_vari(double a, vari* bvi)
41 : op_dv_vari(lmultiply(a, bvi->val_), a, bvi) {}
42 void chain() { bvi_->adj_ += adj_ * ad_ / bvi_->val_; }
43};
44} // namespace internal
45
57inline var lmultiply(const var& a, const var& b) {
58 return var(new internal::lmultiply_vv_vari(a.vi_, b.vi_));
59}
70inline var lmultiply(const var& a, double b) {
71 return var(new internal::lmultiply_vd_vari(a.vi_, b));
72}
83inline var lmultiply(double a, const var& b) {
84 if (a == 1.0) {
85 return log(b);
86 }
87 return var(new internal::lmultiply_dv_vari(a, b.vi_));
88}
89
102template <typename T1, typename T2, require_all_matrix_t<T1, T2>* = nullptr,
103 require_any_var_matrix_t<T1, T2>* = nullptr>
104inline auto lmultiply(const T1& a, const T2& b) {
105 check_matching_dims("lmultiply", "a", a, "b", b);
109
110 return make_callback_var(
111 lmultiply(arena_a.val(), arena_b.val()),
112 [arena_a, arena_b](const auto& res) mutable {
113 arena_a.adj().array()
114 += res.adj().array() * arena_b.val().array().log();
115 arena_b.adj().array() += res.adj().array() * arena_a.val().array()
116 / arena_b.val().array();
117 });
118 } else if (!is_constant<T1>::value) {
119 arena_t<promote_scalar_t<var, T1>> arena_a = a;
120 arena_t<promote_scalar_t<double, T2>> arena_b = value_of(b);
121
122 return make_callback_var(lmultiply(arena_a.val(), arena_b),
123 [arena_a, arena_b](const auto& res) mutable {
124 arena_a.adj().array()
125 += res.adj().array()
126 * arena_b.val().array().log();
127 });
128 } else {
129 arena_t<promote_scalar_t<double, T1>> arena_a = value_of(a);
130 arena_t<promote_scalar_t<var, T2>> arena_b = b;
131
132 return make_callback_var(lmultiply(arena_a, arena_b.val()),
133 [arena_a, arena_b](const auto& res) mutable {
134 arena_b.adj().array() += res.adj().array()
135 * arena_a.val().array()
136 / arena_b.val().array();
137 });
138 }
139}
140
150template <typename T1, typename T2, require_var_matrix_t<T1>* = nullptr,
151 require_stan_scalar_t<T2>* = nullptr>
152inline auto lmultiply(const T1& a, const T2& b) {
153 using std::log;
154
155 if (!is_constant<T1>::value && !is_constant<T2>::value) {
156 arena_t<promote_scalar_t<var, T1>> arena_a = a;
157 var arena_b = b;
158
159 return make_callback_var(
160 lmultiply(arena_a.val(), arena_b.val()),
161 [arena_a, arena_b](const auto& res) mutable {
162 arena_a.adj().array() += res.adj().array() * log(arena_b.val());
163 arena_b.adj() += (res.adj().array() * arena_a.val().array()).sum()
164 / arena_b.val();
165 });
166 } else if (!is_constant<T1>::value) {
167 arena_t<promote_scalar_t<var, T1>> arena_a = a;
168
169 return make_callback_var(lmultiply(arena_a.val(), value_of(b)),
170 [arena_a, b](const auto& res) mutable {
171 arena_a.adj().array()
172 += res.adj().array() * log(value_of(b));
173 });
174 } else {
175 arena_t<promote_scalar_t<double, T1>> arena_a = value_of(a);
176 var arena_b = b;
177
178 return make_callback_var(
179 lmultiply(arena_a, arena_b.val()),
180 [arena_a, arena_b](const auto& res) mutable {
181 arena_b.adj()
182 += (res.adj().array() * arena_a.array()).sum() / arena_b.val();
183 });
184 }
185}
186
196template <typename T1, typename T2, require_stan_scalar_t<T1>* = nullptr,
197 require_var_matrix_t<T2>* = nullptr>
198inline auto lmultiply(const T1& a, const T2& b) {
199 if (!is_constant<T1>::value && !is_constant<T2>::value) {
200 var arena_a = a;
201 arena_t<promote_scalar_t<var, T2>> arena_b = b;
202
203 return make_callback_var(
204 lmultiply(arena_a.val(), arena_b.val()),
205 [arena_a, arena_b](const auto& res) mutable {
206 arena_a.adj()
207 += (res.adj().array() * arena_b.val().array().log()).sum();
208 arena_b.adj().array()
209 += arena_a.val() * res.adj().array() / arena_b.val().array();
210 });
211 } else if (!is_constant<T1>::value) {
212 var arena_a = a;
213 arena_t<promote_scalar_t<double, T2>> arena_b = value_of(b);
214
215 return make_callback_var(
216 lmultiply(arena_a.val(), arena_b),
217 [arena_a, arena_b](const auto& res) mutable {
218 arena_a.adj()
219 += (res.adj().array() * arena_b.val().array().log()).sum();
220 });
221 } else {
222 arena_t<promote_scalar_t<var, T2>> arena_b = b;
223
224 return make_callback_var(lmultiply(value_of(a), arena_b.val()),
225 [a, arena_b](const auto& res) mutable {
226 arena_b.adj().array() += value_of(a)
227 * res.adj().array()
228 / arena_b.val().array();
229 });
230 }
231}
232
233} // namespace math
234} // namespace stan
235#endif
lmultiply_vv_vari(vari *avi, vari *bvi)
Definition lmultiply.hpp:21
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback funct...
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
fvar< T > log(const fvar< T > &x)
Definition log.hpp:15
fvar< T > lmultiply(const fvar< T > &x1, const fvar< T > &x2)
Definition lmultiply.hpp:13
void check_matching_dims(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the two containers have the same dimensions.
var_value< double > var
Definition var.hpp:1187
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the ...