Automatic Differentiation
 
Loading...
Searching...
No Matches
operator_addition.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_CORE_OPERATOR_ADDITION_HPP
2#define STAN_MATH_REV_CORE_OPERATOR_ADDITION_HPP
3
11
12namespace stan {
13namespace math {
14
53inline var operator+(const var& a, const var& b) {
54 return make_callback_vari(a.vi_->val_ + b.vi_->val_,
55 [avi = a.vi_, bvi = b.vi_](const auto& vi) mutable {
56 avi->adj_ += vi.adj_;
57 bvi->adj_ += vi.adj_;
58 });
59}
60
73template <typename Arith, require_arithmetic_t<Arith>* = nullptr>
74inline var operator+(const var& a, Arith b) {
75 if (unlikely(b == 0.0)) {
76 return a;
77 }
78 return make_callback_vari(
79 a.vi_->val_ + b,
80 [avi = a.vi_](const auto& vi) mutable { avi->adj_ += vi.adj_; });
81}
82
95template <typename Arith, require_arithmetic_t<Arith>* = nullptr>
96inline var operator+(Arith a, const var& b) {
97 return b + a; // by symmetry
98}
99
// add(a, b) for two matrix operands that both carry adjoints (template
// parameters VarMat1 / VarMat2; the constraint selecting this overload is not
// visible in this listing).
//
// Forward pass: verifies that a and b have matching dimensions via
// check_matching_dims, stores both operands in arena_t types (so their
// dynamic allocations live on the AD stack and can be captured by the
// reverse-pass callback), and returns ret = a.val() + b.val().
// Reverse pass: each element of ret's adjoint is added to the corresponding
// element of both a's and b's adjoints, since d(a + b)/da = d(a + b)/db = 1.
//
// @param a first matrix addend (forwarded into arena storage)
// @param b second matrix addend (forwarded into arena storage)
// @return arena-backed result of type ret_type
109template <typename VarMat1, typename VarMat2,
// NOTE(review): listing line 110 is missing here, so the constraint that
// closes this template parameter list is not visible; restore it from the
// upstream header before compiling.
111inline auto add(VarMat1&& a, VarMat2&& b) {
112 check_matching_dims("add", "a", a, "b", b);
113 using op_ret_type = decltype(a.val() + b.val());
// NOTE(review): listing line 114 is missing; `ret_type`, used below, is not
// declared anywhere in this listing (presumably an alias derived from
// op_ret_type, which is otherwise unused here -- confirm upstream).
115 arena_t<VarMat1> arena_a(std::forward<VarMat1>(a));
116 arena_t<VarMat2> arena_b(std::forward<VarMat2>(b));
117 arena_t<ret_type> ret(arena_a.val() + arena_b.val());
118 reverse_pass_callback([ret, arena_a, arena_b]() mutable {
// Visit every element (outer loop over columns, inner over rows); each result
// adjoint is read once and added to both operands.
119 for (Eigen::Index j = 0; j < ret.cols(); ++j) {
120 for (Eigen::Index i = 0; i < ret.rows(); ++i) {
121 const auto ref_adj = ret.adj().coeffRef(i, j);
122 arena_a.adj().coeffRef(i, j) += ref_adj;
123 arena_b.adj().coeffRef(i, j) += ref_adj;
124 }
125 }
126 });
127 return ret;
128}
129
// add(a, b) where a is a matrix that carries adjoints (template parameter
// VarMat) and b is a non-autodiff operand (template parameter Arith). The
// constraints selecting this overload are not visible in this listing.
//
// Forward pass: stores a in an arena_t type so its allocations live on the
// AD stack, and returns ret = a.val().array() + as_array_or_scalar(b).
// Reverse pass: ret's adjoint is added to a's adjoint. b is not captured by
// the callback, so it receives no gradient.
//
// @param a matrix addend with adjoints (forwarded into arena storage)
// @param b constant addend (presumably a scalar or an Eigen object, given the
//          as_array_or_scalar wrapper -- confirm against the constraints)
// @return arena-backed result of type ret_type
139template <typename Arith, typename VarMat,
// NOTE(review): listing lines 140-141 are missing, so the constraints that
// close this template parameter list are not visible.
142inline auto add(VarMat&& a, const Arith& b) {
// NOTE(review): listing line 143 is missing; the `}` on listing line 145
// closes a scope opened there. Presumably it restricts the dimension check
// to the case where b is a container -- confirm upstream.
144 check_matching_dims("add", "a", a, "b", b);
145 }
146 using op_ret_type
147 = decltype((a.val().array() + as_array_or_scalar(b)).matrix());
// NOTE(review): listing line 148 is missing; `ret_type`, used below, is not
// declared in this listing (presumably derived from op_ret_type).
149 arena_t<VarMat> arena_a(std::forward<VarMat>(a));
150 arena_t<ret_type> ret(arena_a.val().array() + as_array_or_scalar(b));
// NOTE(review): listing line 151 is missing; the lambda below is the argument
// of a call whose head is not shown (presumably reverse_pass_callback, as in
// the sibling overloads). Only a's adjoint is updated.
152 [ret, arena_a]() mutable { arena_a.adj() += ret.adj_op(); });
153 return ret;
154}
155
// add(a, b) with a non-autodiff operand a (template parameter Arith) and a
// matrix b that carries adjoints (template parameter VarMat). Addition
// commutes, so this swaps the arguments and calls add(b, a), using
// std::forward to preserve b's value category.
//
// @param a constant addend
// @param b matrix addend with adjoints
// @return result of add(b, a)
165template <typename Arith, typename VarMat,
// NOTE(review): listing lines 166-167 are missing, so the constraints that
// close this template parameter list are not visible.
168inline auto add(const Arith& a, VarMat&& b) {
169 return add(std::forward<VarMat>(b), a);
170}
171
// add(a, b) with a scalar autodiff variable a (template parameter Var) and a
// matrix b (template parameter EigMat) that contributes only its values.
//
// Forward pass: ret = a.val() + b.array(), i.e. a's value is added to every
// element of b. The result type is return_var_matrix_t<EigMat>.
// Reverse pass: because a was broadcast to every element, its adjoint
// receives the sum of all of ret's adjoints. b is not captured by the
// callback and gets no gradient.
//
// @param a scalar variable addend
// @param b matrix addend (values only)
// @return arena-backed result of type return_var_matrix_t<EigMat>
181template <typename Var, typename EigMat,
// NOTE(review): listing lines 182-183 are missing, so the constraints that
// close this template parameter list are not visible.
184inline auto add(const Var& a, const EigMat& b) {
185 using ret_type = return_var_matrix_t<EigMat>;
186 arena_t<ret_type> ret(a.val() + b.array());
187 reverse_pass_callback([ret, a]() mutable { a.adj() += ret.adj().sum(); });
188 return ret;
189}
190
// add(a, b) with a matrix a (template parameter EigMat) and a scalar autodiff
// variable b (template parameter Var). Addition commutes, so this forwards to
// add(b, a) with the arguments swapped.
//
// @param a matrix addend
// @param b scalar variable addend
// @return result of add(b, a)
200template <typename EigMat, typename Var,
// NOTE(review): listing lines 201-202 are missing, so the constraints that
// close this template parameter list are not visible.
203inline auto add(const EigMat& a, const Var& b) {
204 return add(b, a);
205}
206
// add(a, b) with a scalar autodiff variable a (template parameter Var) and a
// matrix b that carries adjoints (template parameter VarMat).
//
// Forward pass: stores b in an arena_t type so its allocations live on the
// AD stack, and returns ret = a.val() + b.val().array(), i.e. a's value is
// added to every element of b. The result type is return_var_matrix_t<VarMat>.
// Reverse pass: each element's adjoint is added to the matching element of
// b's adjoint and is also accumulated into a's adjoint. As a result, a ends
// up with the sum of all of ret's adjoints.
//
// @param a scalar variable addend
// @param b matrix addend with adjoints (forwarded into arena storage)
// @return arena-backed result of type return_var_matrix_t<VarMat>
217template <typename Var, typename VarMat,
// NOTE(review): listing lines 218-219 are missing, so the constraints that
// close this template parameter list are not visible.
220inline auto add(const Var& a, VarMat&& b) {
221 using ret_type = return_var_matrix_t<VarMat>;
222 arena_t<VarMat> arena_b(std::forward<VarMat>(b));
223 arena_t<ret_type> ret(a.val() + arena_b.val().array());
224 reverse_pass_callback([ret, a, arena_b]() mutable {
225 for (Eigen::Index j = 0; j < ret.cols(); ++j) {
226 for (Eigen::Index i = 0; i < ret.rows(); ++i) {
227 const auto ret_adj = ret.adj().coeffRef(i, j);
// a was broadcast to every element, so it accumulates every element's adjoint.
228 a.adj() += ret_adj;
229 arena_b.adj().coeffRef(i, j) += ret_adj;
230 }
231 }
232 });
233 return ret;
234}
235
// add(a, b) with a matrix a that carries adjoints (template parameter VarMat)
// and a scalar autodiff variable b (template parameter Var). Addition
// commutes, so this forwards to add(b, a), using std::forward to preserve
// a's value category.
//
// @param a matrix addend with adjoints
// @param b scalar variable addend
// @return result of add(b, a)
246template <typename Var, typename VarMat,
// NOTE(review): listing lines 247-248 are missing, so the constraints that
// close this template parameter list are not visible.
249inline auto add(VarMat&& a, const Var& b) {
250 return add(b, std::forward<VarMat>(a));
251}
252
// Scalar fallback: add(a, b) simply evaluates a + b, so dispatch goes to
// whichever operator+ overload matches the operand types.
//
// @param a first addend
// @param b second addend
// @return a + b
253template <typename T1, typename T2,
// NOTE(review): listing lines 254-255 are missing. This overload and the one
// that follows it have identical visible signatures, so the missing
// constraints must differ; otherwise the two would be redefinitions.
256inline auto add(const T1& a, const T2& b) {
257 return a + b;
258}
259
// Second scalar fallback: add(a, b) simply evaluates a + b, so dispatch goes
// to whichever operator+ overload matches the operand types.
//
// @param a first addend
// @param b second addend
// @return a + b
260template <typename T1, typename T2,
// NOTE(review): listing line 261 is missing. Its constraint is what
// distinguishes this overload from the preceding one, which has the same
// visible signature.
262inline auto add(const T1& a, const T2& b) {
263 return a + b;
264}
265
// operator+ for the types admitted by the (not visible) constraint on
// VarMat1 / VarMat2. It forwards both arguments, preserving their value
// categories, to the matching add() overload.
//
// @param a first addend
// @param b second addend
// @return result of add(a, b)
275template <typename VarMat1, typename VarMat2,
// NOTE(review): listing line 276 is missing, so the constraint that closes
// this template parameter list is not visible.
277inline auto operator+(VarMat1&& a, VarMat2&& b) {
278 return add(std::forward<VarMat1>(a), std::forward<VarMat2>(b));
279}
280
281} // namespace math
282} // namespace stan
283#endif
#define unlikely(x)
require_any_t< std::is_arithmetic< std::decay_t< Types > >... > require_any_arithmetic_t
Require any of the types satisfy std::is_arithmetic.
require_t< std::is_arithmetic< scalar_type_t< std::decay_t< T > > > > require_st_arithmetic
Require scalar type satisfies std::is_arithmetic.
require_t< container_type_check_base< is_eigen, value_type_t, TypeCheck, Check... > > require_eigen_vt
Require type satisfies is_eigen.
Definition is_eigen.hpp:97
addition_< as_operation_cl_t< T_a >, as_operation_cl_t< T_b > > add(T_a &&a, T_b &&b)
require_t< is_rev_matrix< std::decay_t< T > > > require_rev_matrix_t
Require type satisfies is_rev_matrix.
require_all_t< is_rev_matrix< std::decay_t< Types > >... > require_all_rev_matrix_t
Require all of the types satisfy is_rev_matrix.
require_any_t< is_var_matrix< std::decay_t< Types > >... > require_any_var_matrix_t
Require any of the types satisfy is_var_matrix.
require_t< container_type_check_base< is_var, value_type_t, TypeCheck, Check... > > require_var_vt
Require type satisfies is_var.
Definition is_var.hpp:64
require_any_t< container_type_check_base< is_var, value_type_t, TypeCheck, Check >... > require_any_var_vt
Require any of the types satisfy is_var.
Definition is_var.hpp:73
require_all_t< container_type_check_base< is_var, value_type_t, TypeCheck, Check >... > require_all_var_vt
Require all of the types satisfy is_var.
Definition is_var.hpp:82
T as_array_or_scalar(T &&v)
Returns specified input value.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
void check_matching_dims(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the two containers have the same dimensions.
fvar< T > operator+(const fvar< T > &x1, const fvar< T > &x2)
Return the sum of the specified forward mode addends.
internal::callback_vari< plain_type_t< T >, F > * make_callback_vari(T &&value, F &&functor)
Creates a new vari with given value and a callback that implements the reverse pass (chain).
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
std::conditional_t< is_any_var_matrix< ReturnType, Types... >::value, stan::math::var_value< stan::math::promote_scalar_t< double, plain_type_t< ReturnType > > >, stan::math::promote_scalar_t< stan::math::var_value< double >, plain_type_t< ReturnType > > > return_var_matrix_t
Given an Eigen type and several inputs, determine if a matrix should be var<Matrix> or Matrix<var>.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Check if type derives from EigenBase
Definition is_eigen.hpp:21