Automatic Differentiation
 
Loading...
Searching...
No Matches
elt_divide.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_ELT_DIVIDE_HPP
2#define STAN_MATH_REV_FUN_ELT_DIVIDE_HPP
3
9
10namespace stan {
11namespace math {
12
24template <typename Mat1, typename Mat2,
25 require_all_matrix_t<Mat1, Mat2>* = nullptr,
26 require_any_rev_matrix_t<Mat1, Mat2>* = nullptr>
27auto elt_divide(const Mat1& m1, const Mat2& m2) {
28 check_matching_dims("elt_divide", "m1", m1, "m2", m2);
29 using inner_ret_type
30 = decltype((value_of(m1).array() / value_of(m2).array()).matrix());
31 using ret_type = return_var_matrix_t<inner_ret_type, Mat1, Mat2>;
32 if (!is_constant<Mat1>::value && !is_constant<Mat2>::value) {
33 arena_t<promote_scalar_t<var, Mat1>> arena_m1 = m1;
34 arena_t<promote_scalar_t<var, Mat2>> arena_m2 = m2;
35 arena_t<ret_type> ret(arena_m1.val().array() / arena_m2.val().array());
36 reverse_pass_callback([ret, arena_m1, arena_m2]() mutable {
37 for (Eigen::Index j = 0; j < arena_m2.cols(); ++j) {
38 for (Eigen::Index i = 0; i < arena_m2.rows(); ++i) {
39 const auto ret_div
40 = ret.adj().coeff(i, j) / arena_m2.val().coeff(i, j);
41 arena_m1.adj().coeffRef(i, j) += ret_div;
42 arena_m2.adj().coeffRef(i, j) -= ret.val().coeff(i, j) * ret_div;
43 }
44 }
45 });
46 return ret_type(ret);
47 } else if (!is_constant<Mat1>::value) {
48 arena_t<promote_scalar_t<var, Mat1>> arena_m1 = m1;
49 arena_t<promote_scalar_t<double, Mat2>> arena_m2 = value_of(m2);
50 arena_t<ret_type> ret(arena_m1.val().array() / arena_m2.array());
51 reverse_pass_callback([ret, arena_m1, arena_m2]() mutable {
52 arena_m1.adj().array() += ret.adj().array() / arena_m2.array();
53 });
54 return ret_type(ret);
55 } else if (!is_constant<Mat2>::value) {
56 arena_t<promote_scalar_t<double, Mat1>> arena_m1 = value_of(m1);
57 arena_t<promote_scalar_t<var, Mat2>> arena_m2 = m2;
58 arena_t<ret_type> ret(arena_m1.array() / arena_m2.val().array());
59 reverse_pass_callback([ret, arena_m2, arena_m1]() mutable {
60 arena_m2.adj().array()
61 -= ret.val().array() * ret.adj().array() / arena_m2.val().array();
62 });
63 return ret_type(ret);
64 }
65}
66
78template <typename Scal, typename Mat, require_stan_scalar_t<Scal>* = nullptr,
79 require_var_matrix_t<Mat>* = nullptr>
80auto elt_divide(Scal s, const Mat& m) {
81 plain_type_t<Mat> res = value_of(s) / m.val().array();
82
83 reverse_pass_callback([m, s, res]() mutable {
84 m.adj().array() -= res.val().array() * res.adj().array() / m.val().array();
85 if (!is_constant<Scal>::value)
86 forward_as<var>(s).adj() += (res.adj().array() / m.val().array()).sum();
87 });
88
89 return res;
90}
91
92} // namespace math
93} // namespace stan
94
95#endif
elt_divide_< as_operation_cl_t< T_a >, as_operation_cl_t< T_b > > elt_divide(T_a &&a, T_b &&b)
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
void check_matching_dims(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the two containers have the same dimensions.
fvar< T > sum(const std::vector< fvar< T > > &m)
Return the sum of the entries of the specified standard vector.
Definition sum.hpp:22
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...