#ifndef STAN_MATH_FWD_FUN_LOG_SOFTMAX_HPP
#define STAN_MATH_FWD_FUN_LOG_SOFTMAX_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/fwd/meta.hpp>
#include <stan/math/fwd/core.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/log_softmax.hpp>
#include <stan/math/prim/fun/softmax.hpp>
21template <
typename T, require_vector_st<is_fvar, T>* =
nullptr>
24 using T_alpha =
decltype(alpha);
26 using T_fvar_inner =
typename T_fvar::Scalar;
28 const Eigen::Ref<const plain_type_t<T_alpha>>& alpha_ref = alpha;
29 Eigen::Matrix<T_fvar_inner, -1, 1> alpha_t = alpha_ref.val();
30 Eigen::Matrix<T_fvar_inner, -1, 1> softmax_alpha_t =
softmax(alpha_t);
32 Eigen::Matrix<T_fvar, -1, 1> log_softmax_alpha(alpha.size());
34 log_softmax_alpha.d().setZero();
36 for (
int m = 0; m < alpha.size(); ++m) {
37 T_fvar_inner negative_alpha_m_d_times_softmax_alpha_t_m
38 = -alpha_ref.coeff(m).d_ * softmax_alpha_t(m);
39 for (
int k = 0; k < alpha.size(); ++k) {
41 log_softmax_alpha(k).d_
42 += alpha_ref.coeff(m).d_
43 + negative_alpha_m_d_times_softmax_alpha_t_m;
45 log_softmax_alpha(k).d_ += negative_alpha_m_d_times_softmax_alpha_t_m;
50 return log_softmax_alpha;
// Referenced symbols (documentation notes):
//   value_type_t     -- alias for `typename value_type<T>::type`; helper for
//                       accessing the underlying type.
//   softmax(const ColVec& alpha)
//   log_softmax(const T& x)
//                    -- returns the log softmax of the specified vector or
//                       container of vectors.
//   lgamma           -- the lgamma implementation in stan-math is based on
//                       either the reentrant-safe lgamma_r implementation ...