Automatic Differentiation
 
Loading...
Searching...
No Matches
softmax.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_FUN_SOFTMAX_HPP
2#define STAN_MATH_PRIM_FUN_SOFTMAX_HPP
3
8#include <cmath>
9
10namespace stan {
11namespace math {
12
46template <typename Vec,
47 require_eigen_vector_vt<std::is_arithmetic, Vec>* = nullptr>
48inline plain_type_t<Vec> softmax(Vec&& v) {
49 if (v.size() == 0) {
50 return v;
51 }
52 decltype(auto) v_ref = to_ref(std::forward<Vec>(v));
53 const auto theta = (v_ref.array() - v_ref.maxCoeff()).exp();
54 return (theta / theta.sum()).matrix();
55}
56
64template <typename T, require_std_vector_st<std::is_arithmetic, T>* = nullptr>
65inline auto softmax(T&& x) {
66 return apply_vector_unary<T>::apply(std::forward<T>(x), [](auto&& v) {
67 return softmax(std::forward<decltype(v)>(v));
68 });
69}
70
71} // namespace math
72} // namespace stan
73
74#endif
auto softmax(T &&x)
Return the softmax of each vector in a container of fvar values.
Definition softmax.hpp:22
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:18
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:15
typename plain_type< std::decay_t< T > >::type plain_type_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...