Automatic Differentiation
 
Loading...
Searching...
No Matches
log_softmax.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_LOG_SOFTMAX_HPP
2#define STAN_MATH_REV_FUN_LOG_SOFTMAX_HPP
3
13#include <cmath>
14#include <vector>
15
16namespace stan {
17namespace math {
18
19namespace internal {
20
21class log_softmax_elt_vari : public vari {
22 private:
24 const double* softmax_alpha_;
25 const int size_; // array sizes
26 const int idx_; // in in softmax output
27
28 public:
29 log_softmax_elt_vari(double val, vari** alpha, const double* softmax_alpha,
30 int size, int idx)
31 : vari(val),
32 alpha_(alpha),
33 softmax_alpha_(softmax_alpha),
34 size_(size),
35 idx_(idx) {}
36 void chain() {
37 for (int m = 0; m < size_; ++m) {
38 if (m == idx_) {
39 alpha_[m]->adj_ += adj_ * (1 - softmax_alpha_[m]);
40 } else {
41 alpha_[m]->adj_ -= adj_ * softmax_alpha_[m];
42 }
43 }
44 }
45};
46} // namespace internal
47
56template <typename T, require_eigen_st<is_var, T>* = nullptr>
57auto log_softmax(const T& x) {
58 const int a_size = x.size();
59
60 check_nonzero_size("log_softmax", "x", x);
61
62 const auto& x_ref = to_ref(x);
63
64 vari** x_vi_array
66 Eigen::Map<vector_vi>(x_vi_array, a_size) = x_ref.vi();
67
68 vector_d x_d = x_ref.val();
69
70 // fold logic of math::softmax() and math::log_softmax()
71 // to save computations
72
73 vector_d diff = (x_d.array() - x_d.maxCoeff());
74 vector_d softmax_x_d = diff.array().exp();
75 double sum = softmax_x_d.sum();
76 vector_d log_softmax_x_d = diff.array() - std::log(sum);
77
78 // end fold
79 double* softmax_x_d_array
81 Eigen::Map<vector_d>(softmax_x_d_array, a_size) = softmax_x_d.array() / sum;
82
83 plain_type_t<T> log_softmax_x(a_size);
84 for (int k = 0; k < a_size; ++k) {
85 log_softmax_x(k) = var(new internal::log_softmax_elt_vari(
86 log_softmax_x_d[k], x_vi_array, softmax_x_d_array, a_size, k));
87 }
88 return log_softmax_x;
89}
90
99template <typename T, require_var_matrix_t<T>* = nullptr>
100inline auto log_softmax(const T& x) {
101 check_nonzero_size("log_softmax", "x", x);
102
103 const auto& theta = (x.val().array() - x.val().maxCoeff()).eval();
104
105 return make_callback_var(
106 (theta.array() - log(theta.exp().sum())).matrix(),
107 [x](const auto& res) mutable {
108 x.adj().noalias()
109 += res.adj() - (res.adj().sum() * res.val().array().exp()).matrix();
110 });
111}
112
122template <typename T, require_std_vector_st<is_var, T>* = nullptr>
123inline auto log_softmax(const T& x) {
124 return apply_vector_unary<T>::apply(
125 x, [](const auto& alpha) { return log_softmax(alpha); });
126}
127
128} // namespace math
129} // namespace stan
130#endif
log_softmax_elt_vari(double val, vari **alpha, const double *softmax_alpha, int size, int idx)
T * alloc_array(size_t n)
Allocate an array on the arena of the specified size to hold values of the specified template paramet...
size_t size(const T &m)
Returns the size (number of the elements) of a matrix_cl or var_value<matrix_cl<T>>.
Definition size.hpp:18
Eigen::Matrix< double, Eigen::Dynamic, 1 > vector_d
Type for (column) vector of double values.
Definition typedefs.hpp:24
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback funct...
T eval(T &&arg)
Inputs which have a plain_type equal to their own type are forwarded unmodified (for Eigen expressions ...
Definition eval.hpp:20
fvar< T > log(const fvar< T > &x)
Definition log.hpp:15
vari_value< double > vari
Definition vari.hpp:197
fvar< T > sum(const std::vector< fvar< T > > &m)
Return the sum of the entries of the specified standard vector.
Definition sum.hpp:22
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:17
var_value< double > var
Definition var.hpp:1187
void check_nonzero_size(const char *function, const char *name, const T_y &y)
Check if the specified matrix/vector is of non-zero size.
auto log_softmax(const T &x)
Return the log softmax of the specified vector or container of vectors.
typename plain_type< T >::type plain_type_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
static thread_local AutodiffStackStorage * instance_