Automatic Differentiation
 
Loading...
Searching...
No Matches
fma.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_FMA_HPP
2#define STAN_MATH_REV_FUN_FMA_HPP
3
9
10namespace stan {
11namespace math {
12
31inline var fma(const var& x, const var& y, const var& z) {
32 return make_callback_var(fma(x.val(), y.val(), z.val()), [x, y, z](auto& vi) {
33 x.adj() += vi.adj() * y.val();
34 y.adj() += vi.adj() * x.val();
35 z.adj() += vi.adj();
36 });
37}
38
56template <typename Tc, require_arithmetic_t<Tc>* = nullptr>
57inline var fma(const var& x, const var& y, Tc&& z) {
58 return make_callback_var(fma(x.val(), y.val(), z), [x, y](auto& vi) {
59 x.adj() += vi.adj() * y.val();
60 y.adj() += vi.adj() * x.val();
61 });
62}
63
84template <typename Tb, require_arithmetic_t<Tb>* = nullptr>
85inline var fma(const var& x, Tb&& y, const var& z) {
86 return make_callback_var(fma(x.val(), y, z.val()), [x, y, z](auto& vi) {
87 x.adj() += vi.adj() * y;
88 z.adj() += vi.adj();
89 });
90}
91
113template <typename Tb, typename Tc, require_all_arithmetic_t<Tb, Tc>* = nullptr>
114inline var fma(const var& x, Tb&& y, Tc&& z) {
115 return make_callback_var(fma(x.val(), y, z),
116 [x, y](auto& vi) { x.adj() += vi.adj() * y; });
117}
118
136template <typename Ta, typename Tc, require_all_arithmetic_t<Ta, Tc>* = nullptr>
137inline var fma(Ta&& x, const var& y, Tc&& z) {
138 return make_callback_var(fma(x, y.val(), z),
139 [x, y](auto& vi) { y.adj() += vi.adj() * x; });
140}
141
159template <typename Ta, typename Tb, require_all_arithmetic_t<Ta, Tb>* = nullptr>
160inline var fma(Ta&& x, Tb&& y, const var& z) {
161 return make_callback_var(fma(x, y, z.val()),
162 [z](auto& vi) { z.adj() += vi.adj(); });
163}
164
182template <typename Ta, require_arithmetic_t<Ta>* = nullptr>
183inline var fma(Ta&& x, const var& y, const var& z) {
184 return make_callback_var(fma(x, y.val(), z.val()), [x, y, z](auto& vi) {
185 y.adj() += vi.adj() * x;
186 z.adj() += vi.adj();
187 });
188}
189
190namespace internal {
194template <typename T1, typename T2, typename T3, typename T4,
196inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
197 return [arena_x, arena_y, arena_z, ret]() mutable {
202 forward_as<T1_var>(arena_x).adj().array()
203 += ret.adj().array() * value_of(arena_y).array();
204 }
206 forward_as<T2_var>(arena_y).adj().array()
207 += ret.adj().array() * value_of(arena_x).array();
208 }
210 forward_as<T3_var>(arena_z).adj().array() += ret.adj().array();
211 }
212 };
213}
214
218template <typename T1, typename T2, typename T3, typename T4,
220 require_stan_scalar_t<T1>* = nullptr>
221inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
222 return [arena_x, arena_y, arena_z, ret]() mutable {
223 using T1_var = arena_t<promote_scalar_t<var, T1>>;
224 using T2_var = arena_t<promote_scalar_t<var, T2>>;
225 using T3_var = arena_t<promote_scalar_t<var, T3>>;
227 forward_as<T1_var>(arena_x).adj()
228 += (ret.adj().array() * value_of(arena_y).array()).sum();
229 }
231 forward_as<T2_var>(arena_y).adj().array()
232 += ret.adj().array() * value_of(arena_x);
233 }
235 forward_as<T3_var>(arena_z).adj().array() += ret.adj().array();
236 }
237 };
238}
239
243template <typename T1, typename T2, typename T3, typename T4,
245 require_stan_scalar_t<T2>* = nullptr>
246inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
247 return [arena_x, arena_y, arena_z, ret]() mutable {
248 using T1_var = arena_t<promote_scalar_t<var, T1>>;
249 using T2_var = arena_t<promote_scalar_t<var, T2>>;
250 using T3_var = arena_t<promote_scalar_t<var, T3>>;
252 forward_as<T1_var>(arena_x).adj().array()
253 += ret.adj().array() * value_of(arena_y);
254 }
256 forward_as<T2_var>(arena_y).adj()
257 += (ret.adj().array() * value_of(arena_x).array()).sum();
258 }
259 if (!is_constant<T3>::value) {
260 forward_as<T3_var>(arena_z).adj().array() += ret.adj().array();
261 }
262 };
263}
264
268template <typename T1, typename T2, typename T3, typename T4,
269 require_matrix_t<T3>* = nullptr,
270 require_all_stan_scalar_t<T1, T2>* = nullptr>
271inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
272 return [arena_x, arena_y, arena_z, ret]() mutable {
273 using T1_var = arena_t<promote_scalar_t<var, T1>>;
274 using T2_var = arena_t<promote_scalar_t<var, T2>>;
275 using T3_var = arena_t<promote_scalar_t<var, T3>>;
276 if (!is_constant<T1>::value) {
277 forward_as<T1_var>(arena_x).adj()
278 += (ret.adj().array() * value_of(arena_y)).sum();
279 }
280 if (!is_constant<T2>::value) {
281 forward_as<T2_var>(arena_y).adj()
282 += (ret.adj().array() * value_of(arena_x)).sum();
283 }
284 if (!is_constant<T3>::value) {
285 forward_as<T3_var>(arena_z).adj().array() += ret.adj().array();
286 }
287 };
288}
289
293template <typename T1, typename T2, typename T3, typename T4,
294 require_all_matrix_t<T1, T2>* = nullptr,
295 require_stan_scalar_t<T3>* = nullptr>
296inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
297 return [arena_x, arena_y, arena_z, ret]() mutable {
298 using T1_var = arena_t<promote_scalar_t<var, T1>>;
299 using T2_var = arena_t<promote_scalar_t<var, T2>>;
300 using T3_var = arena_t<promote_scalar_t<var, T3>>;
301 if (!is_constant<T1>::value) {
302 forward_as<T1_var>(arena_x).adj().array()
303 += ret.adj().array() * value_of(arena_y).array();
304 }
305 if (!is_constant<T2>::value) {
306 forward_as<T2_var>(arena_y).adj().array()
307 += ret.adj().array() * value_of(arena_x).array();
308 }
309 if (!is_constant<T3>::value) {
310 forward_as<T3_var>(arena_z).adj() += ret.adj().sum();
311 }
312 };
313}
314
318template <typename T1, typename T2, typename T3, typename T4,
319 require_matrix_t<T2>* = nullptr,
320 require_all_stan_scalar_t<T1, T3>* = nullptr>
321inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
322 return [arena_x, arena_y, arena_z, ret]() mutable {
323 using T1_var = arena_t<promote_scalar_t<var, T1>>;
324 using T2_var = arena_t<promote_scalar_t<var, T2>>;
325 using T3_var = arena_t<promote_scalar_t<var, T3>>;
326 if (!is_constant<T1>::value) {
327 forward_as<T1_var>(arena_x).adj()
328 += (ret.adj().array() * value_of(arena_y).array()).sum();
329 }
330 if (!is_constant<T2>::value) {
331 forward_as<T2_var>(arena_y).adj().array()
332 += ret.adj().array() * value_of(arena_x);
333 }
334 if (!is_constant<T3>::value) {
335 forward_as<T3_var>(arena_z).adj() += ret.adj().sum();
336 }
337 };
338}
339
343template <typename T1, typename T2, typename T3, typename T4,
344 require_matrix_t<T1>* = nullptr,
345 require_all_stan_scalar_t<T2, T3>* = nullptr>
346inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
347 return [arena_x, arena_y, arena_z, ret]() mutable {
348 using T1_var = arena_t<promote_scalar_t<var, T1>>;
349 using T2_var = arena_t<promote_scalar_t<var, T2>>;
350 using T3_var = arena_t<promote_scalar_t<var, T3>>;
351 if (!is_constant<T1>::value) {
352 forward_as<T1_var>(arena_x).adj().array()
353 += ret.adj().array() * value_of(arena_y);
354 }
355 if (!is_constant<T2>::value) {
356 forward_as<T2_var>(arena_y).adj()
357 += (ret.adj().array() * value_of(arena_x).array()).sum();
358 }
359 if (!is_constant<T3>::value) {
360 forward_as<T3_var>(arena_z).adj() += ret.adj().sum();
361 }
362 };
363}
364
365} // namespace internal
366
385template <typename T1, typename T2, typename T3,
386 require_any_matrix_t<T1, T2, T3>* = nullptr,
387 require_var_t<return_type_t<T1, T2, T3>>* = nullptr>
388inline auto fma(const T1& x, const T2& y, const T3& z) {
389 arena_t<T1> arena_x = x;
390 arena_t<T2> arena_y = y;
391 arena_t<T3> arena_z = z;
393 check_matching_dims("fma", "x", arena_x, "y", arena_y);
394 }
396 check_matching_dims("fma", "x", arena_x, "z", arena_z);
397 }
399 check_matching_dims("fma", "y", arena_y, "z", arena_z);
400 }
401 using inner_ret_type
402 = decltype(fma(value_of(arena_x), value_of(arena_y), value_of(arena_z)));
405 = fma(value_of(arena_x), value_of(arena_y), value_of(arena_z));
407 internal::fma_reverse_pass(arena_x, arena_y, arena_z, ret));
408 return ret_type(ret);
409}
410
411} // namespace math
412} // namespace stan
413#endif
require_all_t< is_matrix< std::decay_t< Types > >... > require_all_matrix_t
Require all of the types satisfy is_matrix.
Definition is_matrix.hpp:38
require_t< is_stan_scalar< std::decay_t< T > > > require_stan_scalar_t
Require type satisfies is_stan_scalar.
auto fma_reverse_pass(T1 &arena_x, T2 &arena_y, T3 &arena_z, T4 &ret)
Overload for matrix, matrix, matrix.
Definition fma.hpp:196
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback functor.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
void check_matching_dims(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the two containers have the same dimensions.
auto sum(const std::vector< T > &m)
Return the sum of the entries of the specified standard vector.
Definition sum.hpp:23
fvar< return_type_t< T1, T2, T3 > > fma(const fvar< T1 > &x1, const fvar< T2 > &x2, const fvar< T3 > &x3)
The fused multiply-add operation (C99).
Definition fma.hpp:60
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
std::conditional_t< is_any_var_matrix< ReturnType, Types... >::value, stan::math::var_value< stan::math::promote_scalar_t< double, plain_type_t< ReturnType > > >, stan::math::promote_scalar_t< stan::math::var_value< double >, plain_type_t< ReturnType > > > return_var_matrix_t
Given an Eigen type and several inputs, determine if a matrix should be var<Matrix> or Matrix<var>.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the C++ const sense).
Check if a type is derived from Eigen::EigenBase or is a var_value whose value_type is derived from Eigen::EigenBase.
Definition is_matrix.hpp:18