Automatic Differentiation
 
Loading...
Searching...
No Matches
diag_post_multiply.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_DIAG_POST_MULTIPLY_HPP
2#define STAN_MATH_REV_FUN_DIAG_POST_MULTIPLY_HPP
3
8
9namespace stan {
10namespace math {
11
// Return the product of matrix m1 and the diagonal matrix formed from the
// vector (or row vector) m2, i.e. m1 * m2.asDiagonal(). This overload is
// enabled only when T1 is a matrix (require_matrix_t), T2 is a vector
// (require_vector_t), and at least one of them has a `var` scalar type
// (require_any_st_var). The result's values are computed from the operands'
// values, and a reverse-pass callback propagates adjoints back to whichever
// operands are autodiff types.
//
// @tparam T1 matrix type
// @tparam T2 vector type
// @param m1 matrix
// @param m2 vector supplying the diagonal; m2.size() must equal m1.cols()
// @return m1 * diag(m2) as ret_type. NOTE(review): ret_type is declared on an
//   elided listing line, presumably return_var_matrix_t<inner_ret_type, T1, T2>
//   (var<Matrix> vs Matrix<var>) — confirm against the full header.
// @throw whatever check_size_match throws on a size mismatch (presumably
//   std::invalid_argument — confirm).
24template <typename T1, typename T2, require_matrix_t<T1>* = nullptr,
25 require_vector_t<T2>* = nullptr,
26 require_any_st_var<T1, T2>* = nullptr>
27inline auto diag_post_multiply(const T1& m1, const T2& m2) {
// m2 supplies one diagonal entry per column of m1.
28 check_size_match("diag_post_multiply", "m2.size()", m2.size(), "m1.cols()",
29 m1.cols());
// Type of the plain (value-only) product; used to derive the return type.
30 using inner_ret_type = decltype(value_of(m1) * value_of(m2).asDiagonal());
32
// Case 1: both operands are autodiff. NOTE(review): arena_m1 / arena_m2 are
// declared on elided listing lines, presumably as arena_t copies of m1 / m2 —
// confirm.
33 if constexpr (is_autodiff_v<T1> && is_autodiff_v<T2>) {
// Forward pass: compute the product from the operands' values only.
36 arena_t<ret_type> ret(arena_m1.val() * arena_m2.val().asDiagonal());
// Captures are by value; `mutable` allows calling the non-const adj() on the
// captured copies. NOTE(review): arena_t objects presumably refer to AD-stack
// storage, so the copies share adjoints with the originals — confirm.
37 reverse_pass_callback([ret, arena_m1, arena_m2]() mutable {
// ret(i, j) = m1(i, j) * m2(j), so the adjoint of m2(j) accumulates
// sum_i m1(i, j) * ret.adj()(i, j), i.e. column sums of the elementwise
// product. NOTE(review): colwise().sum() is a row vector; for a column-vector
// m2 this relies on Eigen's vector-shape handling in `+=` — confirm.
38 arena_m2.adj() += arena_m1.val().cwiseProduct(ret.adj()).colwise().sum();
// d ret(i, j) / d m1(i, j) = m2(j): scale column j of ret.adj() by m2(j).
39 arena_m1.adj() += ret.adj() * arena_m2.val().asDiagonal();
40 });
41 return ret_type(ret);
// Case 2: only m1 is autodiff; m2 contributes constant values.
42 } else if constexpr (is_autodiff_v<T1>) {
45 arena_t<ret_type> ret(arena_m1.val() * arena_m2.asDiagonal());
46 reverse_pass_callback([ret, arena_m1, arena_m2]() mutable {
// Only m1 receives adjoints. NOTE(review): the forward pass uses arena_m2
// directly but this line uses arena_m2.val(); presumably equivalent for a
// non-autodiff m2 — confirm and consider making them consistent.
47 arena_m1.adj() += ret.adj() * arena_m2.val().asDiagonal();
48 });
49 return ret_type(ret);
// Case 3: only m2 is autodiff; m1 contributes constant values.
50 } else if constexpr (is_autodiff_v<T2>) {
53 arena_t<ret_type> ret(arena_m1 * arena_m2.val().asDiagonal());
54 reverse_pass_callback([ret, arena_m1, arena_m2]() mutable {
// Only m2 receives adjoints (column sums of m1 .* ret.adj(), as in case 1).
// NOTE(review): forward pass uses arena_m1 but this line uses arena_m1.val();
// presumably equivalent for a non-autodiff m1 — confirm.
55 arena_m2.adj() += arena_m1.val().cwiseProduct(ret.adj()).colwise().sum();
56 });
57 return ret_type(ret);
58 }
59}
60
61} // namespace math
62} // namespace stan
63
64#endif
auto diag_post_multiply(const T1 &m1, const T2 &m2)
Return the product of a matrix and the diagonal matrix formed from the vector or row_vector.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called during the reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T and that performs any dynamic allocations on the AD stack.
std::conditional_t< is_any_var_matrix< ReturnType, Types... >::value, stan::math::var_value< stan::math::promote_scalar_t< double, plain_type_t< ReturnType > > >, stan::math::promote_scalar_t< stan::math::var_value< double >, plain_type_t< ReturnType > > > return_var_matrix_t
Given an Eigen type and several inputs, determine if a matrix should be var<Matrix> or Matrix<var>.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...