Automatic Differentiation
 
Loading...
Searching...
No Matches
softmax.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_FUN_SOFTMAX_HPP
2#define STAN_MATH_PRIM_FUN_SOFTMAX_HPP
3
9
10namespace stan {
11namespace math {
12
/**
 * Return the softmax of each vector in a container of arithmetic values.
 *
 * For every vector `v` handled by the inner lambda, the result is
 * `exp(v - max(v)) / sum(exp(v - max(v)))`, which equals
 * `exp(v) / sum(exp(v))`. An empty vector is returned unchanged.
 *
 * Overload selection, as expressed by the template constraints below:
 *  - the scalar type of `Container` must be arithmetic
 *    (`require_st_arithmetic`);
 *  - `Container` must be a container (`require_container_t`);
 *  - Eigen types that are not vectors are rejected (the `require_not_t`
 *    clause excludes `is_eigen && !is_eigen_vector`).
 *
 * @tparam Container type of the input container
 * @param x container of vectors to normalize; it is passed through `to_ref`
 *   and then to `make_holder`
 * @return the value produced by `make_holder`. The inner lambda returns
 *   `plain_type_t<decltype(v)>`, the plain (evaluated) type of each vector.
 *   NOTE(review): presumably `make_holder` keeps the `to_ref`-ed argument
 *   alive for as long as the returned object; confirm in holder.hpp.
 */
47template <typename Container, require_st_arithmetic<Container>* = nullptr,
48 require_container_t<Container>* = nullptr,
49 require_not_t<bool_constant<
50 is_eigen<std::decay_t<Container>>::value
51 && !is_eigen_vector<std::decay_t<Container>>::value>>* = nullptr>
52inline auto softmax(Container&& x) {
53 return make_holder(
54 [](auto&& a) {
// NOTE(review): line 55 of the original header is missing from this
// rendering, so the call that receives the two arguments below is not
// visible here. Presumably it is `return apply_vector_unary<...>::apply(`,
// which maps the inner lambda over each vector of `a`. Restore it from the
// real softmax.hpp; TODO confirm.
56 std::forward<decltype(a)>(a),
57 [](auto&& v) -> plain_type_t<decltype(v)> {
// Empty input: there is nothing to normalize, so return it as-is.
58 if (v.size() == 0) {
59 return v;
60 }
// Subtracting the maximum makes every exponent <= 0, so exp() is at most 1
// and cannot overflow. The shift cancels in the ratio below. For finite
// input, the max element contributes exp(0) = 1, so the sum is >= 1 and
// the division is safe.
// NOTE(review): `theta` is an unevaluated Eigen expression (`auto` with no
// `.eval()`). Each exp() is therefore computed once inside `theta.sum()`
// and again when the quotient is converted to the return type. Appending
// `.eval()` here would compute each exponential only once.
61 const auto theta = (v.array() - v.maxCoeff()).exp();
62 return (theta / theta.sum()).matrix();
63 });
64 },
// to_ref evaluates expensive Eigen expressions before make_holder passes
// the result to the lambda above.
65 to_ref(std::forward<Container>(x)));
66}
67
68} // namespace math
69} // namespace stan
70#endif
auto softmax(T &&x)
Return the softmax of each vector in a container of fvar values.
Definition softmax.hpp:23
auto make_holder(F &&func, Args &&... args)
Calls given function with given arguments.
Definition holder.hpp:481
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:18
constexpr decltype(auto) apply(F &&f, Tuple &&t, PreArgs &&... pre_args)
Definition apply.hpp:51
typename plain_type< std::decay_t< T > >::type plain_type_t
Plain (non-expression) type associated with T; for an Eigen expression, the type that expression can be evaluated into.