1#ifndef STAN_MATH_PRIM_FUN_SOFTMAX_HPP
2#define STAN_MATH_PRIM_FUN_SOFTMAX_HPP
46template <
typename Vec,
47 require_eigen_vector_vt<std::is_arithmetic, Vec>* =
nullptr>
52 decltype(
auto) v_ref =
to_ref(std::forward<Vec>(v));
53 const auto theta = (v_ref.array() - v_ref.maxCoeff()).
exp();
54 return (theta / theta.sum()).matrix();
64template <
typename T, require_std_vector_st<std::is_arithmetic, T>* =
nullptr>
66 return apply_vector_unary<T>::apply(std::forward<T>(x), [](
auto&& v) {
67 return softmax(std::forward<
decltype(v)>(v));
auto softmax(T &&x)
Return the softmax of each vector in a container of fvar values.
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
fvar< T > exp(const fvar< T > &x)
typename plain_type< std::decay_t< T > >::type plain_type_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...