Automatic Differentiation
 
Loading...
Searching...
No Matches
log_sum_exp.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_FWD_FUN_LOG_SUM_EXP_HPP
2#define STAN_MATH_FWD_FUN_LOG_SUM_EXP_HPP
3
11#include <cmath>
12#include <vector>
13
14namespace stan {
15namespace math {
16
17template <typename T>
18inline fvar<T> log_sum_exp(const fvar<T>& x1, const fvar<T>& x2) {
19 using std::exp;
20 return fvar<T>(log_sum_exp(x1.val_, x2.val_),
21 x1.d_ * inv_logit(-(x2.val_ - x1.val_))
22 + x2.d_ * inv_logit(-(x1.val_ - x2.val_)));
23}
24
25template <typename T>
26inline fvar<T> log_sum_exp(double x1, const fvar<T>& x2) {
27 using std::exp;
28 if (x1 == NEGATIVE_INFTY) {
29 return fvar<T>(x2.val_, x2.d_);
30 }
31 return fvar<T>(log_sum_exp(x1, x2.val_), x2.d_ * inv_logit(-(x1 - x2.val_)));
32}
33
34template <typename T>
35inline fvar<T> log_sum_exp(const fvar<T>& x1, double x2) {
36 return log_sum_exp(x2, x1);
37}
38
/**
 * Return the log of the sum of the exponentiated elements of a container
 * whose scalar type is fvar (enforced by require_container_st<is_fvar, T>).
 *
 * Result components:
 * - Value: log_sum_exp of the element values.
 * - Tangent: sum_i(d_i * exp(val_i)) / sum_i(exp(val_i)), i.e. the element
 *   tangents weighted by softmax(values). This is the gradient of log-sum-exp.
 *
 * NOTE(review): doxygen source line 56 is not rendered in this listing.
 * It is the call that receives `to_ref(x)` and the lambda below (closed by
 * the `});` line). Presumably it is Stan's apply_vector_unary<...>::reduce,
 * and its result is returned. Confirm against the real header.
 *
 * @tparam T container type with fvar scalars
 * @param x container of fvar values
 * @return log(sum(exp(x))) together with its directional derivative
 */
54template <typename T, require_container_st<is_fvar, T>* = nullptr>
55inline auto log_sum_exp(const T& x) {
57 to_ref(x), [&](const auto& v) {
58 using T_fvar_inner = typename value_type_t<decltype(v)>::Scalar;
59 using mat_type = Eigen::Matrix<T_fvar_inner, -1, -1>;
// Element values; the tangents are read below via v.d().
60 mat_type vals = v.val();
// NOTE(review): exponentiating the raw values is numerically fragile.
// - Overflow: for double data, any value above roughly 709 makes exp return
//   inf, so the tangent becomes inf/inf = NaN.
// - Underflow: if every value is below roughly -745, all terms become 0, so
//   the tangent becomes 0/0 = NaN.
// - Empty container: the tangent is also 0/0.
// Both failures can occur even though log_sum_exp(vals) itself is
// presumably finite. A stable alternative is to compute
// lse = log_sum_exp(vals) first and use exp(vals - lse) as the weights.
// Those weights sum to 1, so no division is needed. TODO confirm and fix.
61 mat_type exp_vals = vals.array().exp();
62
// Build the result: value from log_sum_exp, tangent as the exp-weighted
// average of the element tangents.
63 return fvar<T_fvar_inner>(
64 log_sum_exp(vals),
65 v.d().cwiseProduct(exp_vals).sum() / exp_vals.sum());
66 });
67}
68
69} // namespace math
70} // namespace stan
71#endif
typename value_type< T >::type value_type_t
Helper function for accessing underlying type.
static constexpr double NEGATIVE_INFTY
Negative infinity.
Definition constants.hpp:51
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:17
fvar< T > inv_logit(const fvar< T > &x)
Returns the inverse logit function applied to the argument.
Definition inv_logit.hpp:20
fvar< T > log_sum_exp(const fvar< T > &x1, const fvar< T > &x2)
namespace stan::math
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
Scalar val_
The value of this variable.
Definition fvar.hpp:49
Scalar d_
The tangent (derivative) of this variable.
Definition fvar.hpp:61
This template class represents scalars used in forward-mode automatic differentiation,...
Definition fvar.hpp:40