1#ifndef STAN_MATH_REV_FUN_FMA_HPP
2#define STAN_MATH_REV_FUN_FMA_HPP
33 x.adj() += vi.adj() * y.val();
34 y.adj() += vi.adj() * x.val();
56template <
typename Tc, require_arithmetic_t<Tc>* =
nullptr>
59 x.adj() += vi.adj() * y.val();
60 y.adj() += vi.adj() * x.val();
84template <
typename Tb, require_arithmetic_t<Tb>* =
nullptr>
87 x.adj() += vi.adj() * y;
113template <
typename Tb,
typename Tc, require_all_arithmetic_t<Tb, Tc>* =
nullptr>
116 [x, y](
auto& vi) { x.adj() += vi.adj() * y; });
136template <
typename Ta,
typename Tc, require_all_arithmetic_t<Ta, Tc>* =
nullptr>
139 [x, y](
auto& vi) { y.adj() += vi.adj() * x; });
159template <
typename Ta,
typename Tb, require_all_arithmetic_t<Ta, Tb>* =
nullptr>
162 [z](
auto& vi) { z.adj() += vi.adj(); });
182template <
typename Ta, require_arithmetic_t<Ta>* =
nullptr>
185 y.adj() += vi.adj() * x;
194template <
typename T1,
typename T2,
typename T3,
typename T4,
197 return [arena_x, arena_y, arena_z, ret]()
mutable {
202 forward_as<T1_var>(arena_x).adj().array()
203 += ret.adj().array() *
value_of(arena_y).array();
206 forward_as<T2_var>(arena_y).adj().array()
207 += ret.adj().array() *
value_of(arena_x).array();
210 forward_as<T3_var>(arena_z).adj().array() += ret.adj().array();
218template <
typename T1,
typename T2,
typename T3,
typename T4,
222 return [arena_x, arena_y, arena_z, ret]()
mutable {
227 forward_as<T1_var>(arena_x).adj()
228 += (ret.adj().array() *
value_of(arena_y).array()).
sum();
231 forward_as<T2_var>(arena_y).adj().array()
232 += ret.adj().array() *
value_of(arena_x);
235 forward_as<T3_var>(arena_z).adj().array() += ret.adj().array();
243template <
typename T1,
typename T2,
typename T3,
typename T4,
246inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
247 return [arena_x, arena_y, arena_z, ret]()
mutable {
252 forward_as<T1_var>(arena_x).adj().array()
253 += ret.adj().array() *
value_of(arena_y);
256 forward_as<T2_var>(arena_y).adj()
257 += (ret.adj().array() *
value_of(arena_x).array()).
sum();
259 if (!is_constant<T3>::value) {
260 forward_as<T3_var>(arena_z).adj().array() += ret.adj().array();
268template <
typename T1,
typename T2,
typename T3,
typename T4,
269 require_matrix_t<T3>* =
nullptr,
270 require_all_stan_scalar_t<T1, T2>* =
nullptr>
271inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
272 return [arena_x, arena_y, arena_z, ret]()
mutable {
273 using T1_var = arena_t<promote_scalar_t<var, T1>>;
274 using T2_var = arena_t<promote_scalar_t<var, T2>>;
275 using T3_var = arena_t<promote_scalar_t<var, T3>>;
276 if (!is_constant<T1>::value) {
277 forward_as<T1_var>(arena_x).adj()
280 if (!is_constant<T2>::value) {
281 forward_as<T2_var>(arena_y).adj()
284 if (!is_constant<T3>::value) {
285 forward_as<T3_var>(arena_z).adj().array() += ret.adj().array();
293template <
typename T1,
typename T2,
typename T3,
typename T4,
294 require_all_matrix_t<T1, T2>* =
nullptr,
295 require_stan_scalar_t<T3>* =
nullptr>
296inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
297 return [arena_x, arena_y, arena_z, ret]()
mutable {
298 using T1_var = arena_t<promote_scalar_t<var, T1>>;
299 using T2_var = arena_t<promote_scalar_t<var, T2>>;
300 using T3_var = arena_t<promote_scalar_t<var, T3>>;
301 if (!is_constant<T1>::value) {
302 forward_as<T1_var>(arena_x).adj().array()
303 += ret.adj().array() *
value_of(arena_y).array();
305 if (!is_constant<T2>::value) {
306 forward_as<T2_var>(arena_y).adj().array()
307 += ret.adj().array() *
value_of(arena_x).array();
309 if (!is_constant<T3>::value) {
310 forward_as<T3_var>(arena_z).adj() += ret.adj().sum();
318template <
typename T1,
typename T2,
typename T3,
typename T4,
319 require_matrix_t<T2>* =
nullptr,
320 require_all_stan_scalar_t<T1, T3>* =
nullptr>
321inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
322 return [arena_x, arena_y, arena_z, ret]()
mutable {
323 using T1_var = arena_t<promote_scalar_t<var, T1>>;
324 using T2_var = arena_t<promote_scalar_t<var, T2>>;
325 using T3_var = arena_t<promote_scalar_t<var, T3>>;
326 if (!is_constant<T1>::value) {
327 forward_as<T1_var>(arena_x).adj()
328 += (ret.adj().array() *
value_of(arena_y).array()).
sum();
330 if (!is_constant<T2>::value) {
331 forward_as<T2_var>(arena_y).adj().array()
332 += ret.adj().array() *
value_of(arena_x);
334 if (!is_constant<T3>::value) {
335 forward_as<T3_var>(arena_z).adj() += ret.adj().sum();
343template <
typename T1,
typename T2,
typename T3,
typename T4,
344 require_matrix_t<T1>* =
nullptr,
345 require_all_stan_scalar_t<T2, T3>* =
nullptr>
346inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
347 return [arena_x, arena_y, arena_z, ret]()
mutable {
348 using T1_var = arena_t<promote_scalar_t<var, T1>>;
349 using T2_var = arena_t<promote_scalar_t<var, T2>>;
350 using T3_var = arena_t<promote_scalar_t<var, T3>>;
351 if (!is_constant<T1>::value) {
352 forward_as<T1_var>(arena_x).adj().array()
353 += ret.adj().array() *
value_of(arena_y);
355 if (!is_constant<T2>::value) {
356 forward_as<T2_var>(arena_y).adj()
357 += (ret.adj().array() *
value_of(arena_x).array()).
sum();
359 if (!is_constant<T3>::value) {
360 forward_as<T3_var>(arena_z).adj() += ret.adj().sum();
385template <
typename T1,
typename T2,
typename T3,
386 require_any_matrix_t<T1, T2, T3>* =
nullptr,
387 require_var_t<return_type_t<T1, T2, T3>>* =
nullptr>
388inline auto fma(
const T1& x,
const T2& y,
const T3& z) {
408 return ret_type(ret);
require_all_t< is_matrix< std::decay_t< Types > >... > require_all_matrix_t
Require all of the types satisfy is_matrix.
require_t< is_stan_scalar< std::decay_t< T > > > require_stan_scalar_t
Require type satisfies is_stan_scalar.
auto fma_reverse_pass(T1 &arena_x, T2 &arena_y, T3 &arena_z, T4 &ret)
Overload for matrix, matrix, matrix.
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback funct...
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
void check_matching_dims(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the two containers have the same dimensions.
auto sum(const std::vector< T > &m)
Return the sum of the entries of the specified standard vector.
fvar< return_type_t< T1, T2, T3 > > fma(const fvar< T1 > &x1, const fvar< T2 > &x2, const fvar< T3 > &x3)
The fused multiply-add operation (C99).
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
std::conditional_t< is_any_var_matrix< ReturnType, Types... >::value, stan::math::var_value< stan::math::promote_scalar_t< double, plain_type_t< ReturnType > > >, stan::math::promote_scalar_t< stan::math::var_value< double >, plain_type_t< ReturnType > > > return_var_matrix_t
Given an Eigen type and several inputs, determine if a matrix should be var<Matrix> or Matrix<var>.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the ...
Check if a type is derived from Eigen::EigenBase or is a var_value whose value_type is derived from E...