Automatic Differentiation
 
Loading...
Searching...
No Matches
log_sum_exp.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_FUN_LOG_SUM_EXP_HPP
2#define STAN_MATH_PRIM_FUN_LOG_SUM_EXP_HPP
3
9#include <cmath>
10#include <vector>
11
12namespace stan {
13namespace math {
14
51template <typename T1, typename T2, require_all_not_st_var<T1, T2>* = nullptr,
52 require_all_stan_scalar_t<T1, T2>* = nullptr>
53inline return_type_t<T1, T2> log_sum_exp(const T2& a, const T1& b) {
54 if (a == NEGATIVE_INFTY) {
55 return b;
56 }
57 if (a == INFTY && b == INFTY) {
58 return INFTY;
59 }
60 if (a > b) {
61 return a + log1p_exp(b - a);
62 }
63 return b + log1p_exp(a - b);
64}
65
81template <typename T, require_container_st<std::is_arithmetic, T>* = nullptr>
82inline auto log_sum_exp(T&& x) {
83 return apply_vector_unary<T>::reduce(std::forward<T>(x), [](auto&& v) {
84 if (v.size() == 0) {
85 return NEGATIVE_INFTY;
86 }
87 const auto& v_ref = to_ref(v);
88 const double max = v_ref.maxCoeff();
89 if (!std::isfinite(max)) {
90 return max;
91 }
92 return max + std::log((v_ref.array() - max).exp().sum());
93 });
94}
95
106template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr>
107inline auto log_sum_exp(T1&& a, T2&& b) {
108 return apply_scalar_binary(
109 [](auto&& c, auto&& d) {
110 return log_sum_exp(std::forward<decltype(c)>(c),
111 std::forward<decltype(d)>(d));
112 },
113 std::forward<T1>(a), std::forward<T2>(b));
114}
115
116} // namespace math
117} // namespace stan
118
119#endif
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
auto apply_scalar_binary(F &&f, T1 &&x, T2 &&y)
Base template function for vectorization of binary scalar functions defined by applying a functor to ...
static constexpr double NEGATIVE_INFTY
Negative infinity.
Definition constants.hpp:51
auto max(T1 x, T2 y)
Returns the maximum value of the two specified scalar arguments.
Definition max.hpp:25
fvar< T > log1p_exp(const fvar< T > &x)
Definition log1p_exp.hpp:14
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:18
static constexpr double INFTY
Positive infinity.
Definition constants.hpp:46
fvar< T > log_sum_exp(const fvar< T > &x1, const fvar< T > &x2)
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...