Automatic Differentiation
 
Loading...
Searching...
No Matches
operator_division.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_CORE_OPERATOR_DIVISION_HPP
2#define STAN_MATH_REV_CORE_OPERATOR_DIVISION_HPP
3
17#include <complex>
18#include <type_traits>
19
20namespace stan {
21namespace math {
22
61inline var operator/(const var& dividend, const var& divisor) {
62 return make_callback_var(
63 dividend.val() / divisor.val(), [dividend, divisor](auto&& vi) {
64 dividend.adj() += vi.adj() / divisor.val();
65 divisor.adj()
66 -= vi.adj() * dividend.val() / (divisor.val() * divisor.val());
67 });
68}
69
83template <typename Arith, require_arithmetic_t<Arith>* = nullptr>
84inline var operator/(const var& dividend, Arith divisor) {
85 if (divisor == 1.0) {
86 return dividend;
87 }
88 return make_callback_var(
89 dividend.val() / divisor,
90 [dividend, divisor](auto&& vi) { dividend.adj() += vi.adj() / divisor; });
91}
92
105template <typename Arith, require_arithmetic_t<Arith>* = nullptr>
106inline var operator/(Arith dividend, const var& divisor) {
107 return make_callback_var(
108 dividend / divisor.val(), [dividend, divisor](auto&& vi) {
109 divisor.adj() -= vi.adj() * dividend / (divisor.val() * divisor.val());
110 });
111}
112
122template <typename Scalar, typename Mat, require_matrix_t<Mat>* = nullptr,
123 require_stan_scalar_t<Scalar>* = nullptr,
124 require_all_st_var_or_arithmetic<Scalar, Mat>* = nullptr,
125 require_any_st_var<Scalar, Mat>* = nullptr>
126inline auto divide(const Mat& m, Scalar c) {
129 var arena_c = c;
130 auto inv_c = (1.0 / arena_c.val());
131 arena_t<promote_scalar_t<var, Mat>> res = inv_c * arena_m.val();
132 reverse_pass_callback([arena_c, inv_c, arena_m, res]() mutable {
133 auto inv_times_adj = (inv_c * res.adj().array()).eval();
134 arena_c.adj() -= (inv_times_adj * res.val().array()).sum();
135 arena_m.adj().array() += inv_times_adj;
136 });
137 return promote_scalar_t<var, Mat>(res);
138 } else if (!is_constant<Mat>::value) {
140 auto inv_c = (1.0 / value_of(c));
141 arena_t<promote_scalar_t<var, Mat>> res = inv_c * arena_m.val();
142 reverse_pass_callback([inv_c, arena_m, res]() mutable {
143 arena_m.adj().array() += inv_c * res.adj_op().array();
144 });
145 return promote_scalar_t<var, Mat>(res);
146 } else {
147 var arena_c = c;
148 auto inv_c = (1.0 / arena_c.val());
149 arena_t<promote_scalar_t<var, Mat>> res = inv_c * value_of(m).array();
150 reverse_pass_callback([arena_c, inv_c, res]() mutable {
151 arena_c.adj() -= inv_c * (res.adj().array() * res.val().array()).sum();
152 });
153 return promote_scalar_t<var, Mat>(res);
154 }
155}
156
166template <typename Scalar, typename Mat, require_matrix_t<Mat>* = nullptr,
167 require_stan_scalar_t<Scalar>* = nullptr,
168 require_all_st_var_or_arithmetic<Scalar, Mat>* = nullptr,
169 require_any_st_var<Scalar, Mat>* = nullptr>
170inline auto divide(Scalar c, const Mat& m) {
173 auto inv_m = to_arena(arena_m.val().array().inverse());
174 var arena_c = c;
175 arena_t<promote_scalar_t<var, Mat>> res = arena_c.val() * inv_m;
176 reverse_pass_callback([arena_c, inv_m, arena_m, res]() mutable {
177 auto inv_times_res = (inv_m * res.adj().array()).eval();
178 arena_m.adj().array() -= inv_times_res * res.val().array();
179 arena_c.adj() += (inv_times_res).sum();
180 });
181 return promote_scalar_t<var, Mat>(res);
182 } else if (!is_constant<Mat>::value) {
184 auto inv_m = to_arena(arena_m.val().array().inverse());
186 reverse_pass_callback([inv_m, arena_m, res]() mutable {
187 arena_m.adj().array() -= inv_m * res.adj().array() * res.val().array();
188 });
189 return promote_scalar_t<var, Mat>(res);
190 } else {
191 auto inv_m = to_arena(value_of(m).array().inverse());
192 var arena_c = c;
193 arena_t<promote_scalar_t<var, Mat>> res = arena_c.val() * inv_m;
194 reverse_pass_callback([arena_c, inv_m, res]() mutable {
195 arena_c.adj() += (inv_m * res.adj().array()).sum();
196 });
197 return promote_scalar_t<var, Mat>(res);
198 }
199}
200
202
212template <typename Mat1, typename Mat2,
215inline auto divide(const Mat1& m1, const Mat2& m2) {
219 auto inv_m2 = to_arena(arena_m2.val().array().inverse());
220 using val_ret = decltype((inv_m2 * arena_m1.val().array()).matrix().eval());
222 arena_t<ret_type> res = (inv_m2.array() * arena_m1.val().array()).matrix();
223 reverse_pass_callback([inv_m2, arena_m1, arena_m2, res]() mutable {
224 auto inv_times_res = (inv_m2 * res.adj().array()).eval();
225 arena_m1.adj().array() += inv_times_res;
226 arena_m2.adj().array() -= inv_times_res * res.val().array();
227 });
228 return ret_type(res);
229 } else if (!is_constant<Mat2>::value) {
232 auto inv_m2 = to_arena(arena_m2.val().array().inverse());
233 using val_ret = decltype((inv_m2 * arena_m1.array()).matrix().eval());
235 arena_t<ret_type> res = (inv_m2.array() * arena_m1.array()).matrix();
236 reverse_pass_callback([inv_m2, arena_m1, arena_m2, res]() mutable {
237 arena_m2.adj().array() -= inv_m2 * res.adj().array() * res.val().array();
238 });
239 return ret_type(res);
240 } else {
243 auto inv_m2 = to_arena(arena_m2.array().inverse());
244 using val_ret = decltype((inv_m2 * arena_m1.val().array()).matrix().eval());
246 arena_t<ret_type> res = (inv_m2.array() * arena_m1.val().array()).matrix();
247 reverse_pass_callback([inv_m2, arena_m1, arena_m2, res]() mutable {
248 arena_m1.adj().array() += inv_m2 * res.adj().array();
249 });
250 return ret_type(res);
251 }
252}
253
254template <typename T1, typename T2, require_any_var_matrix_t<T1, T2>* = nullptr>
255inline auto operator/(const T1& dividend, const T2& divisor) {
256 return divide(dividend, divisor);
257}
258
259inline std::complex<var> operator/(const std::complex<var>& x1,
260 const std::complex<var>& x2) {
261 return internal::complex_divide(x1, x2);
262}
263
264} // namespace math
265} // namespace stan
266#endif
require_any_t< container_type_check_base< is_matrix, scalar_type_t, TypeCheck, Check >... > require_any_matrix_st
Require that at least one of the types satisfies is_matrix.
Definition is_matrix.hpp:64
require_all_t< container_type_check_base< is_matrix, scalar_type_t, TypeCheck, Check >... > require_all_matrix_st
Require that all of the types satisfy is_matrix.
Definition is_matrix.hpp:73
auto divide(T_a &&a, double d)
Returns the elementwise division of the kernel generator expression.
Definition divide.hpp:20
complex_return_t< U, V > complex_divide(const U &lhs, const V &rhs)
Return the quotient of the specified arguments.
fvar< T > operator/(const fvar< T > &x1, const fvar< T > &x2)
Return the result of dividing the first argument by the second.
typename promote_scalar_type< std::decay_t< T >, std::decay_t< S > >::type promote_scalar_t
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback funct...
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
arena_t< T > to_arena(const T &a)
Converts given argument into a type that either has any dynamic allocation on AD stack or schedules i...
Definition to_arena.hpp:25
fvar< T > sum(const std::vector< fvar< T > > &m)
Return the sum of the entries of the specified standard vector.
Definition sum.hpp:22
Eigen::Matrix< value_type_t< EigMat >, EigMat::RowsAtCompileTime, EigMat::ColsAtCompileTime > inverse(const EigMat &m)
Forward mode specialization of calculating the inverse of the matrix.
Definition inverse.hpp:29
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
std::conditional_t< is_any_var_matrix< ReturnType, Types... >::value, stan::math::var_value< stan::math::promote_scalar_t< double, plain_type_t< ReturnType > > >, stan::math::promote_scalar_t< stan::math::var_value< double >, plain_type_t< ReturnType > > > return_var_matrix_t
Given an Eigen type and several inputs, determine if a matrix should be var<Matrix> or Matrix<var>.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the ...