Automatic Differentiation
 
Loading...
Searching...
No Matches
operator_subtraction.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_CORE_OPERATOR_SUBTRACTION_HPP
2#define STAN_MATH_REV_CORE_OPERATOR_SUBTRACTION_HPP
3
12
13namespace stan {
14namespace math {
15
56inline var operator-(const var& a, const var& b) {
57 return make_callback_vari(a.vi_->val_ - b.vi_->val_,
58 [avi = a.vi_, bvi = b.vi_](const auto& vi) mutable {
59 avi->adj_ += vi.adj_;
60 bvi->adj_ -= vi.adj_;
61 });
62}
63
77template <typename Arith, require_arithmetic_t<Arith>* = nullptr>
78inline var operator-(const var& a, Arith b) {
79 if (unlikely(b == 0.0)) {
80 return a;
81 }
82 return make_callback_vari(
83 a.vi_->val_ - b,
84 [avi = a.vi_](const auto& vi) mutable { avi->adj_ += vi.adj_; });
85}
86
100template <typename Arith, require_arithmetic_t<Arith>* = nullptr>
101inline var operator-(Arith a, const var& b) {
102 return make_callback_vari(
103 a - b.vi_->val_,
104 [bvi = b.vi_](const auto& vi) mutable { bvi->adj_ -= vi.adj_; });
105}
106
/**
 * Subtract two matrix-like autodiff arguments: computes `a.val() - b.val()`
 * and registers a reverse-pass callback that propagates the result's
 * adjoints back to both operands.
 *
 * The dimensions of `a` and `b` are checked with `check_matching_dims`
 * before anything else is done.
 *
 * @param a First operand.
 * @param b Second operand.
 * @return `ret_type` constructed from the arena-backed result.
 */
// NOTE(review): listing line 117 (the rest of the template parameter list,
// presumably the type constraint) is absent from this extract.
116template <typename VarMat1, typename VarMat2,
118inline auto subtract(const VarMat1& a, const VarMat2& b) {
119 check_matching_dims("subtract", "a", a, "b", b);
120 using op_ret_type = decltype(a.val() - b.val());
// NOTE(review): listing line 121 is absent; `ret_type`, used below, is not
// defined in the visible lines. It is presumably derived from `op_ret_type`
// there, so confirm against the upstream header.
// Copy both operands into arena-backed types so the callback can capture them.
122 arena_t<VarMat1> arena_a = a;
123 arena_t<VarMat2> arena_b = b;
124 arena_t<ret_type> ret((arena_a.val() - arena_b.val()));
125 reverse_pass_callback([ret, arena_a, arena_b]() mutable {
// Outer loop over columns, inner loop over rows.
126 for (Eigen::Index j = 0; j < ret.cols(); ++j) {
127 for (Eigen::Index i = 0; i < ret.rows(); ++i) {
// Read the result adjoint once and reuse it for both operands.
128 const auto ref_adj = ret.adj().coeffRef(i, j);
129 arena_a.adj().coeffRef(i, j) += ref_adj;
130 arena_b.adj().coeffRef(i, j) -= ref_adj;
131 }
132 }
133 });
134 return ret_type(ret);
135}
136
/**
 * Subtract an arithmetic operand `b` from a matrix-like autodiff operand `a`.
 *
 * The result is built from `a.val().array() - as_array_or_scalar(b)`. In the
 * visible reverse-pass lambda only `a` receives adjoints: the result's
 * adjoint is added to `arena_a.adj()`.
 *
 * @param a Autodiff operand.
 * @param b Arithmetic operand.
 * @return `ret_type` constructed from the arena-backed result.
 */
// NOTE(review): listing lines 147-148 (the rest of the template parameter
// list) are absent from this extract.
146template <typename Arith, typename VarMat,
149inline auto subtract(const VarMat& a, const Arith& b) {
// NOTE(review): listing line 150 is absent. The closing brace on listing
// line 152 suggests it opened a conditional guarding the dimension check, so
// confirm the condition against the upstream header.
151 check_matching_dims("subtract", "a", a, "b", b);
152 }
153 using op_ret_type = plain_type_t<decltype(
154 (a.val().array() - as_array_or_scalar(b)).matrix())>;
// NOTE(review): listing line 155 is absent; `ret_type`, used below, is not
// defined in the visible lines.
156 arena_t<VarMat> arena_a = a;
157 arena_t<ret_type> ret(arena_a.val().array() - as_array_or_scalar(b));
// NOTE(review): listing line 158 is absent; the lambda below is presumably
// the argument of a `reverse_pass_callback(` call opened there.
159 [ret, arena_a]() mutable { arena_a.adj() += ret.adj(); });
160 return ret_type(ret);
161}
162
/**
 * Subtract a matrix-like autodiff operand `b` from an arithmetic operand `a`.
 *
 * The result is built from `as_array_or_scalar(a) - b.val().array()`. In the
 * visible reverse-pass lambda only `b` receives adjoints: the result's
 * adjoint is subtracted from `arena_b.adj()`.
 *
 * @param a Arithmetic operand.
 * @param b Autodiff operand.
 * @return `ret_type` constructed from the arena-backed result.
 */
// NOTE(review): listing lines 173-174 (the rest of the template parameter
// list) are absent from this extract.
172template <typename Arith, typename VarMat,
175inline auto subtract(const Arith& a, const VarMat& b) {
// NOTE(review): listing line 176 is absent. The closing brace on listing
// line 178 suggests it opened a conditional guarding the dimension check, so
// confirm the condition against the upstream header.
177 check_matching_dims("subtract", "a", a, "b", b);
178 }
179 using op_ret_type = plain_type_t<decltype(
180 (as_array_or_scalar(a) - b.val().array()).matrix())>;
// NOTE(review): listing line 181 is absent; `ret_type`, used below, is not
// defined in the visible lines.
182 arena_t<VarMat> arena_b = b;
183 arena_t<ret_type> ret(as_array_or_scalar(a) - arena_b.val().array());
// NOTE(review): listing line 184 is absent; the lambda below is presumably
// the argument of a `reverse_pass_callback(` call opened there.
// NOTE(review): this overload reads the adjoint through `adj_op()` whereas
// the mirrored (VarMat, Arith) overload uses `adj()`; confirm this is
// intentional.
185 [ret, arena_b]() mutable { arena_b.adj() -= ret.adj_op(); });
186 return ret_type(ret);
187}
188
/**
 * Subtract a matrix `b` from an autodiff scalar `a`.
 *
 * The result is built from `a.val() - b.array()`. Only `a` receives an
 * adjoint update on the reverse pass; no adjoint of `b` is touched in this
 * function.
 *
 * @param a Autodiff scalar operand.
 * @param b Matrix operand.
 * @return `ret_type` (i.e. `return_var_matrix_t<EigMat>`) built from the
 *   arena-backed result.
 */
// NOTE(review): listing lines 199-200 (the rest of the template parameter
// list) are absent from this extract.
198template <typename Var, typename EigMat,
201inline auto subtract(const Var& a, const EigMat& b) {
202 using ret_type = return_var_matrix_t<EigMat>;
203 arena_t<ret_type> ret(a.val() - b.array());
// `a` feeds every element of the result with partial +1, so its adjoint
// gains the sum of all result adjoints.
204 reverse_pass_callback([ret, a]() mutable { a.adj() += ret.adj().sum(); });
205 return ret_type(ret);
206}
207
/**
 * Subtract an autodiff scalar `b` from a matrix `a`.
 *
 * The result is built from `a.array() - b.val()`. Only `b` receives an
 * adjoint update on the reverse pass; no adjoint of `a` is touched in this
 * function.
 *
 * @param a Matrix operand.
 * @param b Autodiff scalar operand.
 * @return `ret_type` (i.e. `return_var_matrix_t<EigMat>`) built from the
 *   arena-backed result.
 */
// NOTE(review): listing lines 218-219 (the rest of the template parameter
// list) are absent from this extract.
217template <typename EigMat, typename Var,
220inline auto subtract(const EigMat& a, const Var& b) {
221 using ret_type = return_var_matrix_t<EigMat>;
222 arena_t<ret_type> ret(a.array() - b.val());
// `b` feeds every element of the result with partial -1, so the sum of all
// result adjoints is subtracted from its adjoint.
223 reverse_pass_callback([ret, b]() mutable { b.adj() -= ret.adj().sum(); });
224 return ret_type(ret);
225}
226
/**
 * Subtract a matrix-like autodiff operand `b` from an autodiff scalar `a`.
 *
 * The result is built from `a.val() - arena_b.val().array()`. On the reverse
 * pass each result adjoint is added to the adjoint of `a` and subtracted
 * from the matching element of `b`'s adjoint.
 *
 * @param a Autodiff scalar operand.
 * @param b Autodiff matrix operand.
 * @return `ret_type` (i.e. `return_var_matrix_t<VarMat>`) built from the
 *   arena-backed result.
 */
// NOTE(review): listing lines 238-239 (the rest of the template parameter
// list) are absent from this extract.
237template <typename Var, typename VarMat,
240inline auto subtract(const Var& a, const VarMat& b) {
241 using ret_type = return_var_matrix_t<VarMat>;
// Copy `b` into an arena-backed type so the callback can capture it.
242 arena_t<VarMat> arena_b(b);
243 arena_t<ret_type> ret(a.val() - arena_b.val().array());
244 reverse_pass_callback([ret, a, arena_b]() mutable {
// Outer loop over columns, inner loop over rows.
245 for (Eigen::Index j = 0; j < ret.cols(); ++j) {
246 for (Eigen::Index i = 0; i < ret.rows(); ++i) {
// Read the result adjoint once and reuse it for both operands.
247 auto ret_adj = ret.adj().coeff(i, j);
248 a.adj() += ret_adj;
249 arena_b.adj().coeffRef(i, j) -= ret_adj;
250 }
251 }
252 });
253 return ret_type(ret);
254}
255
/**
 * Subtract an autodiff scalar `b` from a matrix-like autodiff operand `a`.
 *
 * The result is built from `arena_a.val().array() - b.val()`. On the reverse
 * pass each result adjoint is added to the matching element of `a`'s adjoint
 * and subtracted from the adjoint of `b`.
 *
 * @param a Autodiff matrix operand.
 * @param b Autodiff scalar operand.
 * @return `ret_type` (i.e. `return_var_matrix_t<VarMat>`) built from the
 *   arena-backed result.
 */
// NOTE(review): listing lines 267-268 (the rest of the template parameter
// list) are absent from this extract.
266template <typename Var, typename VarMat,
269inline auto subtract(const VarMat& a, const Var& b) {
270 using ret_type = return_var_matrix_t<VarMat>;
// Copy `a` into an arena-backed type so the callback can capture it.
271 arena_t<VarMat> arena_a(a);
272 arena_t<ret_type> ret(arena_a.val().array() - b.val());
273 reverse_pass_callback([ret, b, arena_a]() mutable {
// Outer loop over columns, inner loop over rows.
274 for (Eigen::Index j = 0; j < ret.cols(); ++j) {
275 for (Eigen::Index i = 0; i < ret.rows(); ++i) {
// Read the result adjoint once and reuse it for both operands.
276 const auto ret_adj = ret.adj().coeff(i, j);
277 arena_a.adj().coeffRef(i, j) += ret_adj;
278 b.adj() -= ret_adj;
279 }
280 }
281 });
282 return ret_type(ret);
283}
284
/**
 * Named-function form of subtraction that forwards to `a - b`.
 *
 * @param a First operand.
 * @param b Second operand.
 * @return The result of `a - b`.
 */
// NOTE(review): listing lines 286-287 (the rest of the template parameter
// list) are absent from this extract. The constraints that distinguish this
// overload from the next `subtract(const T1&, const T2&)` are therefore not
// visible here.
285template <typename T1, typename T2,
288inline auto subtract(const T1& a, const T2& b) {
289 return a - b;
290}
291
/**
 * Named-function form of subtraction that forwards to `a - b`.
 *
 * @param a First operand.
 * @param b Second operand.
 * @return The result of `a - b`.
 */
// NOTE(review): listing line 293 (the rest of the template parameter list)
// is absent from this extract. The constraint that distinguishes this
// overload from the preceding `subtract(const T1&, const T2&)` is therefore
// not visible here.
292template <typename T1, typename T2,
294inline auto subtract(const T1& a, const T2& b) {
295 return a - b;
296}
297
/**
 * Operator form of matrix subtraction; forwards to `subtract(a, b)`.
 *
 * @param a First operand.
 * @param b Second operand.
 * @return Whatever the selected `subtract` overload returns.
 */
// NOTE(review): listing line 308 (the rest of the template parameter list,
// presumably the constraint limiting this operator to matrix autodiff types)
// is absent from this extract.
307template <typename VarMat1, typename VarMat2,
309inline auto operator-(const VarMat1& a, const VarMat2& b) {
310 return subtract(a, b);
311}
312
313} // namespace math
314} // namespace stan
315#endif
#define unlikely(x)
require_any_t< std::is_arithmetic< std::decay_t< Types > >... > require_any_arithmetic_t
Require any of the types satisfy std::is_arithmetic.
require_t< std::is_arithmetic< scalar_type_t< std::decay_t< T > > > > require_st_arithmetic
Require scalar type satisfies std::is_arithmetic.
require_t< container_type_check_base< is_eigen, value_type_t, TypeCheck, Check... > > require_eigen_vt
Require type satisfies is_eigen.
Definition is_eigen.hpp:152
subtraction_< as_operation_cl_t< T_a >, as_operation_cl_t< T_b > > subtract(T_a &&a, T_b &&b)
require_t< is_rev_matrix< std::decay_t< T > > > require_rev_matrix_t
Require type satisfies is_rev_matrix.
require_all_t< is_rev_matrix< std::decay_t< Types > >... > require_all_rev_matrix_t
Require all of the types satisfy is_rev_matrix.
require_any_t< is_var_matrix< std::decay_t< Types > >... > require_any_var_matrix_t
Require any of the types satisfy is_var_matrix.
require_t< container_type_check_base< is_var, value_type_t, TypeCheck, Check... > > require_var_vt
Require type satisfies is_var.
Definition is_var.hpp:64
require_any_t< container_type_check_base< is_var, value_type_t, TypeCheck, Check >... > require_any_var_vt
Require any of the types satisfy is_var.
Definition is_var.hpp:73
require_all_t< container_type_check_base< is_var, value_type_t, TypeCheck, Check >... > require_all_var_vt
Require all of the types satisfy is_var.
Definition is_var.hpp:82
fvar< T > operator-(const fvar< T > &x1, const fvar< T > &x2)
Return the difference of the specified arguments.
T as_array_or_scalar(T &&v)
Returns specified input value.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
void check_matching_dims(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the two containers have the same dimensions.
internal::callback_vari< plain_type_t< T >, F > * make_callback_vari(T &&value, F &&functor)
Creates a new vari with given value and a callback that implements the reverse pass (chain).
typename plain_type< T >::type plain_type_t
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
std::conditional_t< is_any_var_matrix< ReturnType, Types... >::value, stan::math::var_value< stan::math::promote_scalar_t< double, plain_type_t< ReturnType > > >, stan::math::promote_scalar_t< stan::math::var_value< double >, plain_type_t< ReturnType > > > return_var_matrix_t
Given an Eigen type and several inputs, determine if a matrix should be var<Matrix> or Matrix<var>.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Check if type derives from EigenBase
Definition is_eigen.hpp:21