Automatic Differentiation
 
log_sum_exp.hpp
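```cpp
#ifndef STAN_MATH_OPENCL_PRIM_LOG_SUM_EXP_HPP
#define STAN_MATH_OPENCL_PRIM_LOG_SUM_EXP_HPP
#ifdef STAN_OPENCL

// Include list reconstructed (these lines were missing from the listing);
// the exact set of headers is an assumption.
#include <stan/math/opencl/copy.hpp>
#include <stan/math/opencl/kernel_generator.hpp>
#include <stan/math/opencl/prim/sum.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/meta.hpp>
#include <cmath>

namespace stan {
namespace math {

/**
 * Return the log of the sum of the exponentiated values of a matrix or
 * kernel generator expression on the OpenCL device. To avoid overflow,
 * the maximum is subtracted before exponentiating:
 *
 *   log(sum(exp(a))) = max(a) + log(sum(exp(a - max(a))))
 *
 * @tparam T type of the kernel generator expression (not a scalar)
 * @param a matrix or expression of values
 * @return log of the sum of exponentiated values, or negative infinity
 *   if a is empty
 */
template <typename T,
          require_all_kernel_expressions_and_none_scalar_t<T>* = nullptr>
inline double log_sum_exp(const T& a) {
  using std::log;
  if (a.size() == 0) {
    return NEGATIVE_INFTY;  // log of an empty sum
  }
  // Branch condition reconstructed (line missing from the listing): the
  // first branch handles an argument that is already an evaluated matrix_cl.
  if (is_matrix_cl<T>::value) {
    // max_2d yields partial maxima; maxCoeff() finishes the reduction on
    // the host.
    double a_max = from_matrix_cl(max_2d(a)).maxCoeff();
    if (!std::isfinite(a_max)) {
      // All entries -inf, or some entry +inf: a - a_max would give NaN.
      return a_max;
    }
    return a_max + log(sum(exp(a - a_max)));
  } else {
    // a is an unevaluated expression: evaluate it and its max in a single
    // kernel so the expression is not computed twice.
    matrix_cl<double> a_eval;
    matrix_cl<double> a_max_cl;
    results(a_eval, a_max_cl) = expressions(a, max_2d(a));
    double a_max = from_matrix_cl(a_max_cl).maxCoeff();
    if (!std::isfinite(a_max)) {
      return a_max;
    }
    return a_max + log(sum(exp(a_eval - a_max)));
  }
}

}  // namespace math
}  // namespace stan

#endif
#endif
```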
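The same computation can be illustrated on the CPU in plain standard C++, without OpenCL. This is a minimal sketch, not Stan Math code: `blockwise_max` and `log_sum_exp_cpu` are hypothetical names, and the two-stage maximum (one partial maximum per block, then a final reduction) only imitates the assumed behavior of `max_2d` followed by `maxCoeff()`. The empty-input and non-finite guards mirror the header above.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical helper imitating a two-stage max reduction. Stage 1 computes
// one partial maximum per block (standing in for OpenCL work groups); stage 2
// reduces the partials, as maxCoeff() does on the host.
// Requires a non-empty x and block_size > 0.
double blockwise_max(const std::vector<double>& x, std::size_t block_size) {
  std::vector<double> partial;
  for (std::size_t start = 0; start < x.size(); start += block_size) {
    std::size_t end = std::min(start + block_size, x.size());
    partial.push_back(*std::max_element(x.begin() + start, x.begin() + end));
  }
  return *std::max_element(partial.begin(), partial.end());
}

// Hypothetical CPU analogue of log_sum_exp using the max-shift trick.
double log_sum_exp_cpu(const std::vector<double>& x,
                       std::size_t block_size = 4) {
  if (x.empty()) {
    return -std::numeric_limits<double>::infinity();  // log of an empty sum
  }
  double x_max = blockwise_max(x, block_size);
  if (!std::isfinite(x_max)) {
    // All entries -inf, or some entry +inf: x - x_max would produce NaN.
    return x_max;
  }
  double s = 0.0;
  for (double v : x) {
    s += std::exp(v - x_max);  // each term lies in [0, 1], so no overflow
  }
  return x_max + std::log(s);
}
```

For example, `log_sum_exp_cpu({1000.0, 1000.0})` returns `1000 + log(2)`, while the naive `log(exp(1000) + exp(1000))` overflows to infinity because `exp(1000)` is not representable as a `double`.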