Automatic Differentiation
 
Loading...
Searching...
No Matches
log_sum_exp.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_LOG_SUM_EXP_HPP
2#define STAN_MATH_REV_FUN_LOG_SUM_EXP_HPP
3
12#include <cmath>
13#include <vector>
14
15namespace stan {
16namespace math {
17namespace internal {
18
20 public:
22 : op_vv_vari(log_sum_exp(avi->val_, bvi->val_), avi, bvi) {}
23 void chain() {
24 avi_->adj_ += adj_ * inv_logit(avi_->val_ - bvi_->val_);
25 bvi_->adj_ += adj_ * inv_logit(bvi_->val_ - avi_->val_);
26 }
27};
29 public:
30 log_sum_exp_vd_vari(vari* avi, double b)
31 : op_vd_vari(log_sum_exp(avi->val_, b), avi, b) {}
32 void chain() {
33 if (val_ == NEGATIVE_INFTY) {
34 avi_->adj_ += adj_;
35 } else {
36 avi_->adj_ += adj_ * inv_logit(avi_->val_ - bd_);
37 }
38 }
39};
40} // namespace internal
41
45inline var log_sum_exp(const var& a, const var& b) {
46 return var(new internal::log_sum_exp_vv_vari(a.vi_, b.vi_));
47}
51inline var log_sum_exp(const var& a, double b) {
52 return var(new internal::log_sum_exp_vd_vari(a.vi_, b));
53}
57inline var log_sum_exp(double a, const var& b) {
58 return var(new internal::log_sum_exp_vd_vari(b.vi_, a));
59}
60
67template <typename T, require_eigen_st<is_var, T>* = nullptr,
68 require_not_var_matrix_t<T>* = nullptr>
69inline var log_sum_exp(const T& v) {
70 arena_t<decltype(v)> arena_v = v;
71 arena_t<decltype(v.val())> arena_v_val = arena_v.val();
72 var res = log_sum_exp(arena_v_val);
73
74 reverse_pass_callback([arena_v, arena_v_val, res]() mutable {
75 arena_v.adj()
76 += res.adj() * (arena_v_val.array().val() - res.val()).exp().matrix();
77 });
78
79 return res;
80}
81
88template <typename T, require_var_matrix_t<T>* = nullptr>
89inline var log_sum_exp(const T& x) {
90 return make_callback_vari(log_sum_exp(x.val()), [x](const auto& res) mutable {
91 x.adj() += res.adj() * (x.val().array().val() - res.val()).exp().matrix();
92 });
93}
94
101template <typename T, require_std_vector_st<is_var, T>* = nullptr>
102inline auto log_sum_exp(const T& x) {
103 return apply_vector_unary<T>::reduce(
104 x, [](const auto& v) { return log_sum_exp(v); });
105}
106
107} // namespace math
108} // namespace stan
109#endif
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
static constexpr double NEGATIVE_INFTY
Negative infinity.
Definition constants.hpp:51
var_value< double > var
Definition var.hpp:1187
fvar< T > inv_logit(const fvar< T > &x)
Returns the inverse logit function applied to the argument.
Definition inv_logit.hpp:20
internal::callback_vari< plain_type_t< T >, F > * make_callback_vari(T &&value, F &&functor)
Creates a new vari with given value and a callback that implements the reverse pass (chain).
fvar< T > log_sum_exp(const fvar< T > &x1, const fvar< T > &x2)
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9