Loading [MathJax]/extensions/TeX/AMSsymbols.js
Automatic Differentiation
 
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
Loading...
Searching...
No Matches
stochastic_row_constrain.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_CONSTRAINT_STOCHASTIC_ROW_CONSTRAIN_HPP
2#define STAN_MATH_REV_CONSTRAINT_STOCHASTIC_ROW_CONSTRAIN_HPP
3
11#include <cmath>
12
13namespace stan {
14namespace math {
15
/**
 * Return a row-stochastic matrix for a reverse-mode autodiff input.
 *
 * The forward pass runs the prim (double) overload of
 * stochastic_row_constrain on the value part of y; the reverse pass
 * back-propagates adjoints row-by-row through the softmax and then
 * through the sum-to-zero constraint.
 *
 * @tparam T a reverse-mode (rev) matrix type
 * @param y free matrix to constrain
 * @return matrix whose rows each form a simplex
 */
23template <typename T, require_rev_matrix_t<T>* = nullptr>
24inline auto stochastic_row_constrain(const T& y) {
25 using ret_type = plain_type_t<T>;
26
27 const auto N = y.rows();
28 const auto M = y.cols();
 // Copy y into the AD arena so its values/adjoints remain valid when the
 // reverse-pass callback fires.
29 arena_t<T> arena_y = y;
30
 // Forward pass on the value (double) matrix only.
31 arena_t<ret_type> arena_x = stochastic_row_constrain(arena_y.val_op());
32
 // Empty input: nothing to back-propagate, so skip registering a callback.
33 if (unlikely(N == 0 || M == 0)) {
34 return arena_x;
35 }
36
37 reverse_pass_callback([arena_y, arena_x]() mutable {
38 const auto N = arena_y.rows();
39
40 auto&& x_val = arena_x.val_op();
41 auto&& x_adj = arena_x.adj_op();
42
 // Scratch buffer for one row's adjoint, reused across iterations.
43 Eigen::VectorXd x_pre_softmax_adj(x_val.cols());
44 for (Eigen::Index i = 0; i < N; ++i) {
45 // backprop for softmax
 // Jacobian-adjoint product of softmax: x .* adj - x * (adj . x)
46 x_pre_softmax_adj.noalias()
47 = -x_val.row(i) * x_adj.row(i).dot(x_val.row(i))
48 + x_val.row(i).cwiseProduct(x_adj.row(i));
49
50 // backprop for sum_to_zero_constrain
51 internal::sum_to_zero_vector_backprop(arena_y.row(i).adj(),
52 x_pre_softmax_adj);
53 }
54 });
55
56 return arena_x;
57}
58
72template <typename T, require_rev_matrix_t<T>* = nullptr>
74 scalar_type_t<T>& lp) {
75 using ret_type = plain_type_t<T>;
76
77 const auto N = y.rows();
78 const auto M = y.cols();
79 arena_t<T> arena_y = y;
80
81 double lp_val = 0;
82 arena_t<ret_type> arena_x
83 = stochastic_row_constrain(arena_y.val_op(), lp_val);
84 lp += lp_val;
85
86 if (unlikely(N == 0 || M == 0)) {
87 return arena_x;
88 }
89
90 reverse_pass_callback([arena_y, arena_x, lp]() mutable {
91 const auto N = arena_y.rows();
92
93 auto&& x_val = arena_x.val_op();
94 auto&& x_adj = arena_x.adj_op();
95
96 const auto x_val_cols = x_val.cols();
97
98 Eigen::VectorXd x_pre_softmax_adj(x_val.cols());
99 for (Eigen::Index i = 0; i < N; ++i) {
100 // backprop for softmax
101 x_pre_softmax_adj.noalias()
102 = -x_val.row(i)
103 * (x_adj.row(i).dot(x_val.row(i)) + lp.adj() * x_val_cols)
104 + (x_val.row(i).cwiseProduct(x_adj.row(i)).array() + lp.adj())
105 .matrix();
106
107 // backprop for sum_to_zero_constrain
108 internal::sum_to_zero_vector_backprop(arena_y.row(i).adj(),
109 x_pre_softmax_adj);
110 }
111 });
112
113 return arena_x;
114}
115
116} // namespace math
117} // namespace stan
118#endif
#define unlikely(x)
void sum_to_zero_vector_backprop(T &&y_adj, const Eigen::VectorXd &z_adj)
The reverse pass backprop for the sum_to_zero_constrain on vectors.
plain_type_t< Mat > stochastic_row_constrain(const Mat &y)
Return a row stochastic matrix.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
typename plain_type< std::decay_t< T > >::type plain_type_t
typename scalar_type< T >::type scalar_type_t
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...