Loading [MathJax]/extensions/tex2jax.js
Automatic Differentiation
 
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
Loading...
Searching...
No Matches
stochastic_column_constrain.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_CONSTRAINT_STOCHASTIC_COLUMN_CONSTRAIN_HPP
2#define STAN_MATH_REV_CONSTRAINT_STOCHASTIC_COLUMN_CONSTRAIN_HPP
3
11#include <cmath>
12
13namespace stan {
14namespace math {
15
// Reverse-mode overload, enabled through require_rev_matrix_t<T>.
// NOTE(review): the signature line is absent from this listing (the numbering
// jumps from 25 to 27); the body reads a single matrix-like argument named
// `y` -- confirm the exact declaration against the real header.
//
// Flow: copy `y` into an arena_t, evaluate the constrain on its values only,
// then register a reverse-pass callback that pushes the result's adjoints
// back onto `y`, one column at a time.
25template <typename T, require_rev_matrix_t<T>* = nullptr>
27 using ret_type = plain_type_t<T>;
28
29 const auto N = y.rows();
30 const auto M = y.cols();
// Copy of the input that the reverse-pass callback below captures.
31 arena_t<T> arena_y = y;
32
// Forward pass: the constrain is applied to the values of arena_y only.
33 arena_t<ret_type> arena_x = stochastic_column_constrain(arena_y.val_op());
34
// With zero rows or zero columns, return before registering any callback.
35 if (unlikely(N == 0 || M == 0)) {
36 return arena_x;
37 }
38
// The callback captures both arena objects by copy.
39 reverse_pass_callback([arena_y, arena_x]() mutable {
// The outer M is not in the capture list, so it is read again here.
40 const auto M = arena_y.cols();
41
42 auto&& x_val = arena_x.val_op();
43 auto&& x_adj = arena_x.adj_op();
44
// Scratch vector, allocated once and overwritten for every column.
45 Eigen::VectorXd x_pre_softmax_adj(x_val.rows());
46 for (Eigen::Index i = 0; i < M; ++i) {
47 // backprop for softmax
// For column i: -x * dot(x_adj, x) + (x elementwise-times x_adj).
48 x_pre_softmax_adj.noalias()
49 = -x_val.col(i) * x_adj.col(i).dot(x_val.col(i))
50 + x_val.col(i).cwiseProduct(x_adj.col(i));
51
52 // backprop for sum_to_zero_constrain
// Passes the adjoints of column i of arena_y together with the scratch
// vector computed above to the helper.
53 internal::sum_to_zero_vector_backprop(arena_y.col(i).adj(),
54 x_pre_softmax_adj);
55 }
56 });
57
58 return arena_x;
59}
60
// Reverse-mode overload that also updates the accumulator `lp`, enabled
// through require_rev_matrix_t<T>.
// NOTE(review): the first signature line is absent from this listing (the
// numbering jumps from 74 to 76); only the trailing `scalar_type_t<T>& lp`
// parameter is visible, and the body also reads a matrix-like argument named
// `y` -- confirm the exact declaration against the real header.
//
// Flow: copy `y` into an arena_t, evaluate the value-level constrain (which
// fills a plain double), add that double to `lp`, then register a
// reverse-pass callback that pushes the adjoints of the result and of `lp`
// back onto `y`, one column at a time.
74template <typename T, require_rev_matrix_t<T>* = nullptr>
76 scalar_type_t<T>& lp) {
77 using ret_type = plain_type_t<T>;
78
79 const auto N = y.rows();
80 const auto M = y.cols();
// Copy of the input that the reverse-pass callback below captures.
81 arena_t<T> arena_y = y;
82
// Forward pass on values only; the callee writes its contribution into the
// local double, which is then added to lp.
83 double lp_val = 0;
84 arena_t<ret_type> arena_x
85 = stochastic_column_constrain(arena_y.val_op(), lp_val);
86 lp += lp_val;
87
// With zero rows or zero columns, return before registering any callback
// (lp has already been updated above).
88 if (unlikely(N == 0 || M == 0)) {
89 return arena_x;
90 }
91
// lp is captured by copy; the callback reads lp.adj() from that copy.
92 reverse_pass_callback([arena_y, arena_x, lp]() mutable {
// The outer M is not in the capture list, so it is read again here.
93 const auto M = arena_y.cols();
94
95 auto&& x_val = arena_x.val_op();
96 auto&& x_adj = arena_x.adj_op();
97
98 const auto x_val_rows = x_val.rows();
99
// Scratch vector, allocated once and overwritten for every column.
100 Eigen::VectorXd x_pre_softmax_adj(x_val.rows());
101 for (Eigen::Index i = 0; i < M; ++i) {
102 // backprop for softmax
// For column i:
//   -x * (dot(x_adj, x) + lp.adj() * x_val_rows)
//   + (x elementwise-times x_adj) + lp.adj() added to every entry.
103 x_pre_softmax_adj.noalias()
104 = -x_val.col(i)
105 * (x_adj.col(i).dot(x_val.col(i)) + lp.adj() * x_val_rows)
106 + (x_val.col(i).cwiseProduct(x_adj.col(i)).array() + lp.adj())
107 .matrix();
108
109 // backprop for sum_to_zero_constrain
// Passes the adjoints of column i of arena_y together with the scratch
// vector computed above to the helper.
110 internal::sum_to_zero_vector_backprop(arena_y.col(i).adj(),
111 x_pre_softmax_adj);
112 }
113 });
114
115 return arena_x;
116}
117
118} // namespace math
119} // namespace stan
120#endif
#define unlikely(x)
void sum_to_zero_vector_backprop(T &&y_adj, const Eigen::VectorXd &z_adj)
The reverse pass backprop for the sum_to_zero_constrain on vectors.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
plain_type_t< Mat > stochastic_column_constrain(const Mat &y)
Return a column stochastic matrix.
typename plain_type< std::decay_t< T > >::type plain_type_t
typename scalar_type< T >::type scalar_type_t
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...