1#ifndef STAN_MATH_REV_FUN_DIAG_PRE_MULTIPLY_HPP
2#define STAN_MATH_REV_FUN_DIAG_PRE_MULTIPLY_HPP
23template <
typename T1,
typename T2, require_vector_t<T1>* =
nullptr,
24 require_matrix_t<T2>* =
nullptr,
25 require_any_st_var<T1, T2>* =
nullptr>
30 using ret_type = return_var_matrix_t<inner_ret_type, T1, T2>;
31 if (!is_constant<T1>::value && !is_constant<T2>::value) {
32 arena_t<promote_scalar_t<var, T1>> arena_m1 = m1;
33 arena_t<promote_scalar_t<var, T2>> arena_m2 = m2;
34 arena_t<ret_type> ret(arena_m1.val().asDiagonal() * arena_m2.val());
36 arena_m1.adj() += arena_m2.val().cwiseProduct(ret.adj()).rowwise().sum();
37 arena_m2.adj() += arena_m1.val().asDiagonal() * ret.adj();
40 }
else if (!is_constant<T1>::value) {
41 arena_t<promote_scalar_t<var, T1>> arena_m1 = m1;
42 arena_t<promote_scalar_t<double, T2>> arena_m2 =
value_of(m2);
43 arena_t<ret_type> ret(arena_m1.val().asDiagonal() * arena_m2);
45 arena_m1.adj() += arena_m2.val().cwiseProduct(ret.adj()).rowwise().sum();
48 }
else if (!is_constant<T2>::value) {
49 arena_t<promote_scalar_t<double, T1>> arena_m1 =
value_of(m1);
50 arena_t<promote_scalar_t<var, T2>> arena_m2 = m2;
51 arena_t<ret_type> ret(arena_m1.asDiagonal() * arena_m2.val());
53 arena_m2.adj() += arena_m1.val().asDiagonal() * ret.adj();
auto diag_pre_multiply(const T1 &m1, const T2 &m2)
Return the product of the diagonal matrix formed from the vector or row_vector and a matrix.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
The lgamma implementation in stan-math is based on either the reentrant-safe lgamma_r implementation ...