Automatic Differentiation
 
Loading...
Searching...
No Matches
log_sum_exp.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_FWD_FUN_LOG_SUM_EXP_HPP
2#define STAN_MATH_FWD_FUN_LOG_SUM_EXP_HPP
3
11#include <cmath>
12#include <vector>
13
14namespace stan {
15namespace math {
16
17template <typename T>
18inline fvar<T> log_sum_exp(const fvar<T>& x1, const fvar<T>& x2) {
19 return fvar<T>(log_sum_exp(x1.val_, x2.val_),
20 x1.d_ * inv_logit(-(x2.val_ - x1.val_))
21 + x2.d_ * inv_logit(-(x1.val_ - x2.val_)));
22}
23
24template <typename T>
25inline fvar<T> log_sum_exp(double x1, const fvar<T>& x2) {
26 if (x1 == NEGATIVE_INFTY) {
27 return fvar<T>(x2.val_, x2.d_);
28 }
29 return fvar<T>(log_sum_exp(x1, x2.val_), x2.d_ * inv_logit(-(x1 - x2.val_)));
30}
31
32template <typename T>
33inline fvar<T> log_sum_exp(const fvar<T>& x1, double x2) {
34 return log_sum_exp(x2, x1);
35}
36
52template <typename T, require_container_st<is_fvar, T>* = nullptr>
53inline auto log_sum_exp(const T& x) {
55 to_ref(x), [&](const auto& v) {
56 using T_fvar_inner = typename value_type_t<decltype(v)>::Scalar;
57 using mat_type = Eigen::Matrix<T_fvar_inner, -1, -1>;
58 mat_type vals = v.val();
59 mat_type exp_vals = vals.array().exp();
60
61 return fvar<T_fvar_inner>(
62 log_sum_exp(vals),
63 v.d().cwiseProduct(exp_vals).sum() / exp_vals.sum());
64 });
65}
66
67} // namespace math
68} // namespace stan
69#endif
typename value_type< T >::type value_type_t
Helper function for accessing underlying type.
static constexpr double NEGATIVE_INFTY
Negative infinity.
Definition constants.hpp:51
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:17
fvar< T > inv_logit(const fvar< T > &x)
Returns the inverse logit function applied to the argument.
Definition inv_logit.hpp:20
fvar< T > log_sum_exp(const fvar< T > &x1, const fvar< T > &x2)
Return the log of the sum of the exponentiated values of the two arguments.
Scalar val_
The value of this variable.
Definition fvar.hpp:49
Scalar d_
The tangent (derivative) of this variable.
Definition fvar.hpp:61
This template class represents scalars used in forward-mode automatic differentiation,...
Definition fvar.hpp:40