Automatic Differentiation
 
Loading...
Searching...
No Matches
fma.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_FMA_HPP
2#define STAN_MATH_REV_FUN_FMA_HPP
3
9
10namespace stan {
11namespace math {
12
31inline var fma(const var& x, const var& y, const var& z) {
32 return make_callback_var(fma(x.val(), y.val(), z.val()), [x, y, z](auto& vi) {
33 x.adj() += vi.adj() * y.val();
34 y.adj() += vi.adj() * x.val();
35 z.adj() += vi.adj();
36 });
37}
38
56template <typename Tc, require_arithmetic_t<Tc>* = nullptr>
57inline var fma(const var& x, const var& y, Tc&& z) {
58 return make_callback_var(fma(x.val(), y.val(), z), [x, y](auto& vi) {
59 x.adj() += vi.adj() * y.val();
60 y.adj() += vi.adj() * x.val();
61 });
62}
63
84template <typename Tb, require_arithmetic_t<Tb>* = nullptr>
85inline var fma(const var& x, Tb&& y, const var& z) {
86 return make_callback_var(fma(x.val(), y, z.val()), [x, y, z](auto& vi) {
87 x.adj() += vi.adj() * y;
88 z.adj() += vi.adj();
89 });
90}
91
113template <typename Tb, typename Tc, require_all_arithmetic_t<Tb, Tc>* = nullptr>
114inline var fma(const var& x, Tb&& y, Tc&& z) {
115 return make_callback_var(fma(x.val(), y, z),
116 [x, y](auto& vi) { x.adj() += vi.adj() * y; });
117}
118
136template <typename Ta, typename Tc, require_all_arithmetic_t<Ta, Tc>* = nullptr>
137inline var fma(Ta&& x, const var& y, Tc&& z) {
138 return make_callback_var(fma(x, y.val(), z),
139 [x, y](auto& vi) { y.adj() += vi.adj() * x; });
140}
141
159template <typename Ta, typename Tb, require_all_arithmetic_t<Ta, Tb>* = nullptr>
160inline var fma(Ta&& x, Tb&& y, const var& z) {
161 return make_callback_var(fma(x, y, z.val()),
162 [z](auto& vi) { z.adj() += vi.adj(); });
163}
164
182template <typename Ta, require_arithmetic_t<Ta>* = nullptr>
183inline var fma(Ta&& x, const var& y, const var& z) {
184 return make_callback_var(fma(x, y.val(), z.val()), [x, y, z](auto& vi) {
185 y.adj() += vi.adj() * x;
186 z.adj() += vi.adj();
187 });
188}
189
190namespace internal {
194template <typename T1, typename T2, typename T3, typename T4,
196inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
197 return [arena_x, arena_y, arena_z, ret]() mutable {
198 if constexpr (is_autodiff_v<T1>) {
199 arena_x.adj().array() += ret.adj().array() * value_of(arena_y).array();
200 }
201 if constexpr (is_autodiff_v<T2>) {
202 arena_y.adj().array() += ret.adj().array() * value_of(arena_x).array();
203 }
204 if constexpr (is_autodiff_v<T3>) {
205 arena_z.adj().array() += ret.adj().array();
206 }
207 };
208}
209
213template <typename T1, typename T2, typename T3, typename T4,
215 require_stan_scalar_t<T1>* = nullptr>
216inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
217 return [arena_x, arena_y, arena_z, ret]() mutable {
218 if constexpr (is_autodiff_v<T1>) {
219 arena_x.adj() += (ret.adj().array() * value_of(arena_y).array()).sum();
220 }
221 if constexpr (is_autodiff_v<T2>) {
222 arena_y.adj().array() += ret.adj().array() * value_of(arena_x);
223 }
224 if constexpr (is_autodiff_v<T3>) {
225 arena_z.adj().array() += ret.adj().array();
226 }
227 };
228}
229
233template <typename T1, typename T2, typename T3, typename T4,
235 require_stan_scalar_t<T2>* = nullptr>
236inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
237 return [arena_x, arena_y, arena_z, ret]() mutable {
238 if constexpr (is_autodiff_v<T1>) {
239 arena_x.adj().array() += ret.adj().array() * value_of(arena_y);
240 }
241 if constexpr (is_autodiff_v<T2>) {
242 arena_y.adj() += (ret.adj().array() * value_of(arena_x).array()).sum();
243 }
244 if constexpr (is_autodiff_v<T3>) {
245 arena_z.adj().array() += ret.adj().array();
246 }
247 };
248}
249
253template <typename T1, typename T2, typename T3, typename T4,
254 require_matrix_t<T3>* = nullptr,
255 require_all_stan_scalar_t<T1, T2>* = nullptr>
256inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
257 return [arena_x, arena_y, arena_z, ret]() mutable {
258 if constexpr (is_autodiff_v<T1>) {
259 arena_x.adj() += (ret.adj().array() * value_of(arena_y)).sum();
260 }
261 if constexpr (is_autodiff_v<T2>) {
262 arena_y.adj() += (ret.adj().array() * value_of(arena_x)).sum();
263 }
264 if constexpr (is_autodiff_v<T3>) {
265 arena_z.adj().array() += ret.adj().array();
266 }
267 };
268}
269
273template <typename T1, typename T2, typename T3, typename T4,
274 require_all_matrix_t<T1, T2>* = nullptr,
275 require_stan_scalar_t<T3>* = nullptr>
276inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
277 return [arena_x, arena_y, arena_z, ret]() mutable {
278 if constexpr (is_autodiff_v<T1>) {
279 arena_x.adj().array() += ret.adj().array() * value_of(arena_y).array();
280 }
281 if constexpr (is_autodiff_v<T2>) {
282 arena_y.adj().array() += ret.adj().array() * value_of(arena_x).array();
283 }
284 if constexpr (is_autodiff_v<T3>) {
285 arena_z.adj() += ret.adj().sum();
286 }
287 };
288}
289
293template <typename T1, typename T2, typename T3, typename T4,
294 require_matrix_t<T2>* = nullptr,
295 require_all_stan_scalar_t<T1, T3>* = nullptr>
296inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
297 return [arena_x, arena_y, arena_z, ret]() mutable {
298 if constexpr (is_autodiff_v<T1>) {
299 arena_x.adj() += (ret.adj().array() * value_of(arena_y).array()).sum();
300 }
301 if constexpr (is_autodiff_v<T2>) {
302 arena_y.adj().array() += ret.adj().array() * value_of(arena_x);
303 }
304 if constexpr (is_autodiff_v<T3>) {
305 arena_z.adj() += ret.adj().sum();
306 }
307 };
308}
309
313template <typename T1, typename T2, typename T3, typename T4,
314 require_matrix_t<T1>* = nullptr,
315 require_all_stan_scalar_t<T2, T3>* = nullptr>
316inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
317 return [arena_x, arena_y, arena_z, ret]() mutable {
318 if constexpr (is_autodiff_v<T1>) {
319 arena_x.adj().array() += ret.adj().array() * value_of(arena_y);
320 }
321 if constexpr (is_autodiff_v<T2>) {
322 arena_y.adj() += (ret.adj().array() * value_of(arena_x).array()).sum();
323 }
324 if constexpr (is_autodiff_v<T3>) {
325 arena_z.adj() += ret.adj().sum();
326 }
327 };
328}
329
330} // namespace internal
331
350template <typename T1, typename T2, typename T3,
351 require_any_matrix_t<T1, T2, T3>* = nullptr,
352 require_var_t<return_type_t<T1, T2, T3>>* = nullptr>
353inline auto fma(const T1& x, const T2& y, const T3& z) {
354 arena_t<T1> arena_x = x;
355 arena_t<T2> arena_y = y;
356 arena_t<T3> arena_z = z;
358 check_matching_dims("fma", "x", arena_x, "y", arena_y);
359 }
361 check_matching_dims("fma", "x", arena_x, "z", arena_z);
362 }
364 check_matching_dims("fma", "y", arena_y, "z", arena_z);
365 }
366 using inner_ret_type
367 = decltype(fma(value_of(arena_x), value_of(arena_y), value_of(arena_z)));
370 = fma(value_of(arena_x), value_of(arena_y), value_of(arena_z));
372 internal::fma_reverse_pass(arena_x, arena_y, arena_z, ret));
373 return ret_type(ret);
374}
375
376} // namespace math
377} // namespace stan
378#endif
require_all_t< is_matrix< std::decay_t< Types > >... > require_all_matrix_t
Require all of the types satisfy is_matrix.
Definition is_matrix.hpp:38
require_t< is_stan_scalar< std::decay_t< T > > > require_stan_scalar_t
Require type satisfies is_stan_scalar.
auto fma_reverse_pass(T1 &arena_x, T2 &arena_y, T3 &arena_z, T4 &ret)
Overload for matrix, matrix, matrix.
Definition fma.hpp:196
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback functor.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
void check_matching_dims(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the two containers have the same dimensions.
auto sum(const std::vector< T > &m)
Return the sum of the entries of the specified standard vector.
Definition sum.hpp:23
fvar< return_type_t< T1, T2, T3 > > fma(const fvar< T1 > &x1, const fvar< T2 > &x2, const fvar< T3 > &x3)
The fused multiply-add operation (C99).
Definition fma.hpp:60
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
std::conditional_t< is_any_var_matrix< ReturnType, Types... >::value, stan::math::var_value< stan::math::promote_scalar_t< double, plain_type_t< ReturnType > > >, stan::math::promote_scalar_t< stan::math::var_value< double >, plain_type_t< ReturnType > > > return_var_matrix_t
Given an Eigen type and several inputs, determine if a matrix should be var<Matrix> or Matrix<var>.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Check if a type is derived from Eigen::EigenBase or is a var_value whose value_type is derived from Eigen::EigenBase.
Definition is_matrix.hpp:18