#ifndef STAN_MATH_REV_FUN_FMA_HPP
#define STAN_MATH_REV_FUN_FMA_HPP
// fma(var, var, var): reverse-pass callback body.
x.adj() += vi.adj() * y.val();
y.adj() += vi.adj() * x.val();
z.adj() += vi.adj();
template <typename Tc, require_arithmetic_t<Tc>* = nullptr>
inline var fma(const var& x, const var& y, Tc&& z) {
  return make_callback_var(fma(x.val(), y.val(), z), [x, y](auto& vi) {
    x.adj() += vi.adj() * y.val();
    y.adj() += vi.adj() * x.val();
  });
}
template <typename Tb, require_arithmetic_t<Tb>* = nullptr>
inline var fma(const var& x, Tb&& y, const var& z) {
  return make_callback_var(fma(x.val(), y, z.val()), [x, y, z](auto& vi) {
    x.adj() += vi.adj() * y;
    z.adj() += vi.adj();
  });
}
template <typename Tb, typename Tc,
          require_all_arithmetic_t<Tb, Tc>* = nullptr>
inline var fma(const var& x, Tb&& y, Tc&& z) {
  return make_callback_var(fma(x.val(), y, z),
                           [x, y](auto& vi) { x.adj() += vi.adj() * y; });
}
template <typename Ta, typename Tc,
          require_all_arithmetic_t<Ta, Tc>* = nullptr>
inline var fma(Ta&& x, const var& y, Tc&& z) {
  return make_callback_var(fma(x, y.val(), z),
                           [x, y](auto& vi) { y.adj() += vi.adj() * x; });
}
template <typename Ta, typename Tb,
          require_all_arithmetic_t<Ta, Tb>* = nullptr>
inline var fma(Ta&& x, Tb&& y, const var& z) {
  return make_callback_var(fma(x, y, z.val()),
                           [z](auto& vi) { z.adj() += vi.adj(); });
}
template <typename Ta, require_arithmetic_t<Ta>* = nullptr>
inline var fma(Ta&& x, const var& y, const var& z) {
  return make_callback_var(fma(x, y.val(), z.val()), [x, y, z](auto& vi) {
    y.adj() += vi.adj() * x;
    z.adj() += vi.adj();
  });
}
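// Hedged usage sketch, not part of this header: how the scalar var overloads
// above are typically exercised. After grad(), the adjoints hold the partials
// of fma(x, y, z) = x * y + z, i.e. y, x, and 1. The function name
// example_fma_scalar_grad is illustrative only.
#include <stan/math/rev.hpp>
#include <iostream>

inline void example_fma_scalar_grad() {  // hypothetical helper, not in Stan
  using stan::math::var;
  var x = 2.0, y = 3.0, z = 0.5;
  var f = stan::math::fma(x, y, z);  // 2 * 3 + 0.5 = 6.5
  f.grad();                          // run the reverse pass
  std::cout << x.adj() << " "        // 3 (= y.val())
            << y.adj() << " "        // 2 (= x.val())
            << z.adj() << "\n";      // 1
  stan::math::recover_memory();      // release the autodiff tape
}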
/**
 * Overload for matrix, matrix, matrix.
 */
template <typename T1, typename T2, typename T3, typename T4,
          require_all_matrix_t<T1, T2, T3>* = nullptr>
inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
  return [arena_x, arena_y, arena_z, ret]() mutable {
    if constexpr (is_autodiff_v<T1>) {
      arena_x.adj().array() += ret.adj().array() * value_of(arena_y).array();
    }
    if constexpr (is_autodiff_v<T2>) {
      arena_y.adj().array() += ret.adj().array() * value_of(arena_x).array();
    }
    if constexpr (is_autodiff_v<T3>) {
      arena_z.adj().array() += ret.adj().array();
    }
  };
}
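// Hedged illustration, not library code: the update rule the matrix/matrix/
// matrix reverse pass above applies. For ret = x .* y + z the elementwise
// partials are y, x, and 1, so each adjoint accumulates the upstream adjoint
// ret.adj() scaled elementwise, exactly as in the lambda. Plain Eigen is used
// to keep the sketch self-contained; the function name is illustrative only.
#include <Eigen/Dense>

inline void example_fma_matrix_adjoint_rule() {
  Eigen::MatrixXd x_val(2, 2), y_val(2, 2), ret_adj(2, 2);
  x_val << 1, 2, 3, 4;
  y_val << 5, 6, 7, 8;
  ret_adj.setOnes();  // pretend the upstream adjoint is all ones

  Eigen::MatrixXd x_adj = Eigen::MatrixXd::Zero(2, 2);
  Eigen::MatrixXd y_adj = Eigen::MatrixXd::Zero(2, 2);
  Eigen::MatrixXd z_adj = Eigen::MatrixXd::Zero(2, 2);

  x_adj.array() += ret_adj.array() * y_val.array();  // equals y_val here
  y_adj.array() += ret_adj.array() * x_val.array();  // equals x_val here
  z_adj.array() += ret_adj.array();                  // equals ones
}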
/**
 * Overload for scalar, matrix, matrix.
 */
template <typename T1, typename T2, typename T3, typename T4,
          require_all_matrix_t<T2, T3>* = nullptr,
          require_stan_scalar_t<T1>* = nullptr>
inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
  return [arena_x, arena_y, arena_z, ret]() mutable {
    if constexpr (is_autodiff_v<T1>) {
      arena_x.adj() += (ret.adj().array() * value_of(arena_y).array()).sum();
    }
    if constexpr (is_autodiff_v<T2>) {
      arena_y.adj().array() += ret.adj().array() * value_of(arena_x);
    }
    if constexpr (is_autodiff_v<T3>) {
      arena_z.adj().array() += ret.adj().array();
    }
  };
}
/**
 * Overload for matrix, scalar, matrix.
 */
template <typename T1, typename T2, typename T3, typename T4,
          require_all_matrix_t<T1, T3>* = nullptr,
          require_stan_scalar_t<T2>* = nullptr>
inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
  return [arena_x, arena_y, arena_z, ret]() mutable {
    if constexpr (is_autodiff_v<T1>) {
      arena_x.adj().array() += ret.adj().array() * value_of(arena_y);
    }
    if constexpr (is_autodiff_v<T2>) {
      arena_y.adj() += (ret.adj().array() * value_of(arena_x).array()).sum();
    }
    if constexpr (is_autodiff_v<T3>) {
      arena_z.adj().array() += ret.adj().array();
    }
  };
}
/**
 * Overload for scalar, scalar, matrix.
 */
template <typename T1, typename T2, typename T3, typename T4,
          require_matrix_t<T3>* = nullptr,
          require_all_stan_scalar_t<T1, T2>* = nullptr>
inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
  return [arena_x, arena_y, arena_z, ret]() mutable {
    if constexpr (is_autodiff_v<T1>) {
      arena_x.adj() += (ret.adj().array() * value_of(arena_y)).sum();
    }
    if constexpr (is_autodiff_v<T2>) {
      arena_y.adj() += (ret.adj().array() * value_of(arena_x)).sum();
    }
    if constexpr (is_autodiff_v<T3>) {
      arena_z.adj().array() += ret.adj().array();
    }
  };
}
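// Hedged illustration, not library code: why scalar arguments reduce with
// .sum(). A scalar broadcast into an n-element fma contributes to every
// element of ret, so by the chain rule its adjoint is the sum of the
// elementwise products with the upstream adjoint, matching the .sum() calls
// in the overloads above. The helper name is illustrative only.
#include <Eigen/Dense>

// Adjoint of scalar x in ret = x * y + z, given upstream adjoint ret_adj.
inline double example_broadcast_scalar_adjoint(const Eigen::ArrayXd& y_val,
                                               const Eigen::ArrayXd& ret_adj) {
  return (ret_adj * y_val).sum();
}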
/**
 * Overload for matrix, matrix, scalar.
 */
template <typename T1, typename T2, typename T3, typename T4,
          require_all_matrix_t<T1, T2>* = nullptr,
          require_stan_scalar_t<T3>* = nullptr>
inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
  return [arena_x, arena_y, arena_z, ret]() mutable {
    if constexpr (is_autodiff_v<T1>) {
      arena_x.adj().array() += ret.adj().array() * value_of(arena_y).array();
    }
    if constexpr (is_autodiff_v<T2>) {
      arena_y.adj().array() += ret.adj().array() * value_of(arena_x).array();
    }
    if constexpr (is_autodiff_v<T3>) {
      arena_z.adj() += ret.adj().sum();
    }
  };
}
/**
 * Overload for scalar, matrix, scalar.
 */
template <typename T1, typename T2, typename T3, typename T4,
          require_matrix_t<T2>* = nullptr,
          require_all_stan_scalar_t<T1, T3>* = nullptr>
inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
  return [arena_x, arena_y, arena_z, ret]() mutable {
    if constexpr (is_autodiff_v<T1>) {
      arena_x.adj() += (ret.adj().array() * value_of(arena_y).array()).sum();
    }
    if constexpr (is_autodiff_v<T2>) {
      arena_y.adj().array() += ret.adj().array() * value_of(arena_x);
    }
    if constexpr (is_autodiff_v<T3>) {
      arena_z.adj() += ret.adj().sum();
    }
  };
}
/**
 * Overload for matrix, scalar, scalar.
 */
template <typename T1, typename T2, typename T3, typename T4,
          require_matrix_t<T1>* = nullptr,
          require_all_stan_scalar_t<T2, T3>* = nullptr>
inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
  return [arena_x, arena_y, arena_z, ret]() mutable {
    if constexpr (is_autodiff_v<T1>) {
      arena_x.adj().array() += ret.adj().array() * value_of(arena_y);
    }
    if constexpr (is_autodiff_v<T2>) {
      arena_y.adj() += (ret.adj().array() * value_of(arena_x).array()).sum();
    }
    if constexpr (is_autodiff_v<T3>) {
      arena_z.adj() += ret.adj().sum();
    }
  };
}
/**
 * The fused multiply-add operation (C99) for any mix of matrix and scalar
 * arguments whose return type is a reverse-mode var.
 */
template <typename T1, typename T2, typename T3,
          require_any_matrix_t<T1, T2, T3>* = nullptr,
          require_var_t<return_type_t<T1, T2, T3>>* = nullptr>
inline auto fma(const T1& x, const T2& y, const T3& z) {
  arena_t<T1> arena_x = x;
  arena_t<T2> arena_y = y;
  arena_t<T3> arena_z = z;
  // (check_matching_dims() calls for the matrix arguments omitted here)
  using ret_type = return_var_matrix_t<
      decltype(fma(value_of(arena_x), value_of(arena_y), value_of(arena_z))),
      T1, T2, T3>;
  arena_t<ret_type> ret
      = fma(value_of(arena_x), value_of(arena_y), value_of(arena_z));
  reverse_pass_callback(fma_reverse_pass(arena_x, arena_y, arena_z, ret));
  return ret_type(ret);
}
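// Hedged usage sketch, not part of this header: the matrix entry point with a
// mix of matrix and scalar arguments. Summing the result and running grad()
// should give x(i).adj() == y(i).val(), y(i).adj() == x(i).val(), and
// z.adj() == 3 (z is broadcast across three elements). The function name
// example_fma_matrix_grad is illustrative only.
#include <stan/math/rev.hpp>

inline void example_fma_matrix_grad() {  // hypothetical helper, not in Stan
  using stan::math::var;
  Eigen::Matrix<var, Eigen::Dynamic, 1> x(3), y(3);
  x << 1, 2, 3;
  y << 4, 5, 6;
  var z = 0.25;
  var f = stan::math::sum(stan::math::fma(x, y, z));
  f.grad();  // run the reverse pass
  stan::math::recover_memory();
}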