Automatic Differentiation
 
add_diag.hpp
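#ifndef STAN_MATH_OPENCL_REV_ADD_DIAG_HPP
#define STAN_MATH_OPENCL_REV_ADD_DIAG_HPP
#ifdef STAN_OPENCL

// The include lines are not shown in this listing; the headers below are an
// assumed set covering the symbols this file uses.
#include <stan/math/opencl/rev/arena_type.hpp>
#include <stan/math/opencl/prim/add_diag.hpp>
#include <stan/math/opencl/kernel_generator.hpp>
#include <stan/math/rev/core.hpp>
#include <stan/math/rev/fun/adjoint_of.hpp>
#include <stan/math/rev/fun/value_of.hpp>

namespace stan {
namespace math {

/**
 * Returns a Matrix with values added along the main diagonal.
 *
 * Reverse-mode overload for OpenCL kernel generator expressions; at least
 * one of the two arguments must be a var.
 *
 * @tparam T_m type of the input matrix (a non-scalar kernel expression)
 * @tparam T_a type of the scalar or vector added along the diagonal
 * @param mat input matrix
 * @param to_add scalar or vector added along the main diagonal
 * @return var_value<matrix_cl<double>> holding the result
 */
template <typename T_m, typename T_a,
          require_all_nonscalar_prim_or_rev_kernel_expression_t<T_m>* = nullptr,
          require_all_prim_or_rev_kernel_expression_t<T_a>* = nullptr,
          require_any_var_t<T_m, T_a>* = nullptr>
inline auto add_diag(const T_m& mat, const T_a& to_add) {
  // Convert the operands to arena types so the reverse-pass closure can
  // hold them until the reverse pass runs.
  const arena_t<T_m>& mat_arena = mat;
  const arena_t<T_a>& to_add_arena = to_add;

  // Forward pass: add on the plain double values.
  var_value<matrix_cl<double>> res
      = add_diag(value_of(mat_arena), value_of(to_add_arena));

  reverse_pass_callback([mat_arena, to_add_arena, res]() mutable {
    // d(res)/d(mat) is the identity: mat receives the full adjoint of res.
    if (!is_constant<T_m>::value) {
      adjoint_of(mat_arena) += res.adj();
    }
    if (!is_constant<T_a>::value) {
      if (!is_stan_scalar<T_a>::value) {
        // Vector to_add: element i only affects res(i, i).
        auto& to_add_adj
            = forward_as<var_value<matrix_cl<double>>>(to_add_arena).adj();
        to_add_adj += diagonal(res.adj());
      } else {
        // Scalar to_add affects every diagonal entry, so its adjoint
        // accumulates the sum of the diagonal of res's adjoint.
        auto& to_add_adj = forward_as<var_value<double>>(to_add_arena).adj();
        to_add_adj += sum(diagonal(res.adj()));
      }
    }
  });
  return res;
}
}  // namespace math
}  // namespace stan

#endif  // STAN_OPENCL
#endif  // STAN_MATH_OPENCL_REV_ADD_DIAG_HPP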
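The reverse pass of `add_diag` encodes two derivative rules. Because `res = mat + diag(to_add)`, the Jacobian with respect to `mat` is the identity, so `mat` receives the full adjoint of `res`. A vector `to_add` only touches the diagonal, so element `i` receives the adjoint of `res(i, i)`; a scalar `to_add` touches every diagonal entry, so it receives the sum of the diagonal adjoints. The CPU-only sketch below restates those rules. It does not use OpenCL or Stan Math, and the `DenseMatrix` type and the `add_diag_value` / `add_diag_adjoint_*` helpers are hypothetical names chosen for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical dense, row-major square matrix used only in this sketch.
struct DenseMatrix {
  std::size_t n;             // number of rows (and columns)
  std::vector<double> data;  // n * n entries, row-major
  double& at(std::size_t i, std::size_t j) { return data[i * n + j]; }
  double at(std::size_t i, std::size_t j) const { return data[i * n + j]; }
};

// Forward pass, scalar case: res = mat + a * I.
DenseMatrix add_diag_value(const DenseMatrix& mat, double a) {
  DenseMatrix res = mat;
  for (std::size_t i = 0; i < res.n; ++i) res.at(i, i) += a;
  return res;
}

// Forward pass, vector case: res(i, i) = mat(i, i) + v[i].
DenseMatrix add_diag_value(const DenseMatrix& mat,
                           const std::vector<double>& v) {
  assert(v.size() == mat.n);
  DenseMatrix res = mat;
  for (std::size_t i = 0; i < res.n; ++i) res.at(i, i) += v[i];
  return res;
}

// Reverse pass for mat: d(res)/d(mat) is the identity, so every entry of
// res_adj is accumulated into mat_adj.
void add_diag_adjoint_mat(DenseMatrix& mat_adj, const DenseMatrix& res_adj) {
  for (std::size_t k = 0; k < res_adj.data.size(); ++k) {
    mat_adj.data[k] += res_adj.data[k];
  }
}

// Reverse pass for a scalar: a touches every diagonal entry, so its adjoint
// accumulates the sum of the diagonal of res_adj.
void add_diag_adjoint_scalar(double& a_adj, const DenseMatrix& res_adj) {
  for (std::size_t i = 0; i < res_adj.n; ++i) a_adj += res_adj.at(i, i);
}

// Reverse pass for a vector: v[i] touches only res(i, i).
void add_diag_adjoint_vector(std::vector<double>& v_adj,
                             const DenseMatrix& res_adj) {
  assert(v_adj.size() == res_adj.n);
  for (std::size_t i = 0; i < res_adj.n; ++i) v_adj[i] += res_adj.at(i, i);
}
```

For example, with a result adjoint whose diagonal is 0.5 and 0.25, a scalar adjoint that starts at 1.0 accumulates to 1.75: the existing value is kept once and the diagonal sum is added to it.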
auto diagonal(T &&a)
Diagonal of a kernel generator expression.
Definition diagonal.hpp:136
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
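`add_diag` uses `reverse_pass_callback` to defer its adjoint updates until the reverse sweep. The general idea can be sketched with a hypothetical tape of `std::function` callbacks that runs newest-first. `Var`, `make_var`, `add`, `mul`, `push_reverse_callback`, and `run_reverse_pass` are illustrative names, not Stan Math's API, and Stan Math's actual stack differs in its details.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical scalar node: value and adjoint, shared through a pointer so
// that reverse-pass closures can update the adjoint after the forward pass.
struct VarImpl {
  double val;
  double adj;
};
using Var = std::shared_ptr<VarImpl>;

Var make_var(double v) { return std::make_shared<VarImpl>(VarImpl{v, 0.0}); }

// Hypothetical global tape of reverse-pass callbacks.
std::vector<std::function<void()>>& tape() {
  static std::vector<std::function<void()>> callbacks;
  return callbacks;
}

// Queue a functor to run during the reverse pass.
template <typename F>
void push_reverse_callback(F&& f) {
  tape().emplace_back(std::forward<F>(f));
}

// Run queued callbacks newest-first (reverse of forward order), then clear.
void run_reverse_pass() {
  auto& callbacks = tape();
  for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it) (*it)();
  callbacks.clear();
}

// res = a + b; the closure captures operands and result by value.
Var add(const Var& a, const Var& b) {
  Var res = make_var(a->val + b->val);
  push_reverse_callback([a, b, res]() {
    a->adj += res->adj;
    b->adj += res->adj;
  });
  return res;
}

// res = a * b; each operand's adjoint gets the other operand's value times
// the result's adjoint.
Var mul(const Var& a, const Var& b) {
  Var res = make_var(a->val * b->val);
  push_reverse_callback([a, b, res]() {
    a->adj += b->val * res->adj;
    b->adj += a->val * res->adj;
  });
  return res;
}
```

As in `add_diag`, each closure captures its operands and result by value (here as shared pointers), so they stay alive until the reverse pass runs. For `z = x * y + x` with `x = 3` and `y = 4`, seeding `z`'s adjoint with 1 and running the tape yields adjoints 5 for `x` and 3 for `y`.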
auto & adjoint_of(const T &x)
Returns a reference to a variable's adjoint.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
fvar< T > sum(const std::vector< fvar< T > > &m)
Return the sum of the entries of the specified standard vector.
Definition sum.hpp:22
auto add_diag(T_m &&mat, T_a &&to_add)
Returns a Matrix with values added along the main diagonal.
Definition add_diag.hpp:27
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
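Arena types place their dynamic allocations on the AD stack, where memory is released all together rather than object by object. A minimal bump allocator showing that allocation pattern is sketched below; `BumpArena` and its methods are hypothetical, and Stan Math's own arena allocator differs in its details.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <new>
#include <vector>

// Hypothetical bump ("arena") allocator: storage is handed out linearly from
// large blocks and released all at once, never object by object.
class BumpArena {
 public:
  explicit BumpArena(std::size_t block_size = 4096)
      : block_size_(block_size) {}

  // Returns `bytes` of storage aligned to `align` (a power of two).
  void* allocate(std::size_t bytes,
                 std::size_t align = alignof(std::max_align_t)) {
    assert(align != 0 && (align & (align - 1)) == 0);
    if (cur_ == nullptr || align_up(cur_, align) + bytes > addr(end_)) {
      // Start a new block big enough for this request plus alignment slack.
      std::size_t size = bytes + align;
      if (size < block_size_) size = block_size_;
      blocks_.push_back(std::make_unique<unsigned char[]>(size));
      cur_ = blocks_.back().get();
      end_ = cur_ + size;
    }
    unsigned char* out = cur_ + (align_up(cur_, align) - addr(cur_));
    cur_ = out + bytes;
    return out;
  }

  // Releases every allocation at once.
  void reset() {
    blocks_.clear();
    cur_ = end_ = nullptr;
  }

  std::size_t num_blocks() const { return blocks_.size(); }

 private:
  static std::uintptr_t addr(const unsigned char* p) {
    return reinterpret_cast<std::uintptr_t>(p);
  }
  static std::uintptr_t align_up(const unsigned char* p, std::size_t align) {
    return (addr(p) + align - 1) & ~(static_cast<std::uintptr_t>(align) - 1);
  }

  std::size_t block_size_;
  std::vector<std::unique_ptr<unsigned char[]>> blocks_;
  unsigned char* cur_ = nullptr;
  unsigned char* end_ = nullptr;
};
```

Releasing everything at once suits reverse-mode AD: intermediate values must survive until the reverse pass finishes, after which they can all be discarded together.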
is_constant
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the ...
is_stan_scalar
Checks if decayed type is a var, fvar, or arithmetic.