Loading [MathJax]/extensions/TeX/AMSsymbols.js
Automatic Differentiation
 
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
Loading...
Searching...
No Matches
simplex_constrain.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_CONSTRAINT_SIMPLEX_CONSTRAIN_HPP
2#define STAN_MATH_REV_CONSTRAINT_SIMPLEX_CONSTRAIN_HPP
3
10#include <cmath>
11
12namespace stan {
13namespace math {
14
29template <typename T, require_rev_col_vector_t<T>* = nullptr>
30inline auto simplex_constrain(const T& y) {
31 using ret_type = plain_type_t<T>;
32
33 const auto N = y.size();
34 arena_t<T> arena_y = y;
35
36 arena_t<ret_type> arena_x = simplex_constrain(arena_y.val());
37
38 if (unlikely(N == 0)) {
39 return ret_type(arena_x);
40 }
41
42 reverse_pass_callback([arena_y, arena_x]() mutable {
43 auto&& res_val = arena_x.val();
44
45 // backprop for softmax
46 Eigen::VectorXd x_pre_softmax_adj = -res_val * arena_x.adj().dot(res_val)
47 + res_val.cwiseProduct(arena_x.adj());
48
49 // backprop for sum_to_zero_constrain
50 internal::sum_to_zero_vector_backprop(arena_y.adj(), x_pre_softmax_adj);
51 });
52
53 return ret_type(arena_x);
54}
55
70template <typename T, require_rev_col_vector_t<T>* = nullptr>
71inline auto simplex_constrain(const T& y, scalar_type_t<T>& lp) {
72 using ret_type = plain_type_t<T>;
73
74 const auto N = y.size();
75 arena_t<T> arena_y = y;
76
77 double lp_val = 0.0;
78 arena_t<ret_type> arena_x = simplex_constrain(arena_y.val(), lp_val);
79 lp += lp_val;
80
81 if (unlikely(N == 0)) {
82 return ret_type(arena_x);
83 }
84
85 reverse_pass_callback([arena_y, arena_x, lp]() mutable {
86 auto&& res_val = arena_x.val();
87
88 // backprop for log jacobian contribution to log density is equivalent to
89 // arena_x.adj().array() += lp.adj() / res_val.array();
90 // but is folded into the following to avoid needing to modify the adjoints
91 // in-place
92
93 // backprop for softmax
94 Eigen::VectorXd x_pre_softmax_adj
95 = -res_val * (arena_x.adj().dot(res_val) + res_val.size() * lp.adj())
96 + (res_val.cwiseProduct(arena_x.adj()).array() + lp.adj()).matrix();
97
98 // backprop for sum_to_zero_constrain
99 internal::sum_to_zero_vector_backprop(arena_y.adj(), x_pre_softmax_adj);
100 });
101
102 return ret_type(arena_x);
103}
104
105} // namespace math
106} // namespace stan
107#endif
#define unlikely(x)
void sum_to_zero_vector_backprop(T &&y_adj, const Eigen::VectorXd &z_adj)
The reverse pass backprop for the sum_to_zero_constrain on vectors.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
plain_type_t< Vec > simplex_constrain(const Vec &y)
Return the simplex corresponding to the specified free vector.
double dot(const std::vector< double > &x, const std::vector< double > &y)
Definition dot.hpp:11
typename plain_type< std::decay_t< T > >::type plain_type_t
typename scalar_type< T >::type scalar_type_t
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...