Automatic Differentiation
 
Loading...
Searching...
No Matches
simplex_constrain.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_CONSTRAINT_SIMPLEX_CONSTRAIN_HPP
2#define STAN_MATH_REV_CONSTRAINT_SIMPLEX_CONSTRAIN_HPP
3
11#include <cmath>
12#include <tuple>
13#include <vector>
14
15namespace stan {
16namespace math {
17
30template <typename T, require_rev_col_vector_t<T>* = nullptr>
31inline auto simplex_constrain(const T& y) {
32 using ret_type = plain_type_t<T>;
33
34 size_t N = y.size();
35 arena_t<T> arena_y = y;
36 arena_t<Eigen::VectorXd> arena_z(N);
37 Eigen::VectorXd x_val(N + 1);
38
39 double stick_len(1.0);
40 for (Eigen::Index k = 0; k < N; ++k) {
41 double log_N_minus_k = std::log(N - k);
42 arena_z.coeffRef(k) = inv_logit(arena_y.val().coeff(k) - log_N_minus_k);
43 x_val.coeffRef(k) = stick_len * arena_z.coeff(k);
44 stick_len -= x_val(k);
45 }
46 x_val.coeffRef(N) = stick_len;
47
48 arena_t<ret_type> arena_x = x_val;
49
50 if (unlikely(N == 0)) {
51 return ret_type(arena_x);
52 }
53
54 reverse_pass_callback([arena_y, arena_x, arena_z]() mutable {
55 int N = arena_y.size();
56 double stick_len_val = arena_x.val().coeff(N);
57 double stick_len_adj = arena_x.adj().coeff(N);
58 for (Eigen::Index k = N; k-- > 0;) {
59 arena_x.adj().coeffRef(k) -= stick_len_adj;
60 stick_len_val += arena_x.val().coeff(k);
61 stick_len_adj += arena_x.adj().coeff(k) * arena_z.coeff(k);
62 double arena_z_adj = arena_x.adj().coeff(k) * stick_len_val;
63 arena_y.adj().coeffRef(k)
64 += arena_z_adj * arena_z.coeff(k) * (1.0 - arena_z.coeff(k));
65 }
66 });
67
68 return ret_type(arena_x);
69}
70
84template <typename T, require_rev_col_vector_t<T>* = nullptr>
85auto simplex_constrain(const T& y, scalar_type_t<T>& lp) {
86 using ret_type = plain_type_t<T>;
87
88 size_t N = y.size();
89 arena_t<T> arena_y = y;
90 arena_t<Eigen::VectorXd> arena_z(N);
91 Eigen::VectorXd x_val(N + 1);
92
93 double stick_len(1.0);
94 for (Eigen::Index k = 0; k < N; ++k) {
95 double log_N_minus_k = std::log(N - k);
96 double adj_y_k = arena_y.val().coeff(k) - log_N_minus_k;
97 arena_z.coeffRef(k) = inv_logit(adj_y_k);
98 x_val.coeffRef(k) = stick_len * arena_z.coeff(k);
99 lp += log(stick_len);
100 lp -= log1p_exp(-adj_y_k);
101 lp -= log1p_exp(adj_y_k);
102 stick_len -= x_val(k);
103 }
104 x_val.coeffRef(N) = stick_len;
105
106 arena_t<ret_type> arena_x = x_val;
107
108 if (unlikely(N == 0)) {
109 return ret_type(arena_x);
110 }
111
112 reverse_pass_callback([arena_y, arena_x, arena_z, lp]() mutable {
113 int N = arena_y.size();
114 double stick_len_val = arena_x.val().coeff(N);
115 double stick_len_adj = arena_x.adj().coeff(N);
116 for (Eigen::Index k = N; k-- > 0;) {
117 arena_x.adj().coeffRef(k) -= stick_len_adj;
118 stick_len_val += arena_x.val().coeff(k);
119 double log_N_minus_k = std::log(N - k);
120 double adj_y_k = arena_y.val().coeff(k) - log_N_minus_k;
121 arena_y.adj().coeffRef(k) -= lp.adj() * inv_logit(adj_y_k);
122 arena_y.adj().coeffRef(k) += lp.adj() * inv_logit(-adj_y_k);
123 stick_len_adj += lp.adj() / stick_len_val;
124 stick_len_adj += arena_x.adj().coeff(k) * arena_z.coeff(k);
125 double arena_z_adj = arena_x.adj().coeff(k) * stick_len_val;
126 arena_y.adj().coeffRef(k)
127 += arena_z_adj * arena_z.coeff(k) * (1.0 - arena_z.coeff(k));
128 }
129 });
130
131 return ret_type(arena_x);
132}
133
134} // namespace math
135} // namespace stan
136#endif
#define unlikely(x)
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
fvar< T > log(const fvar< T > &x)
Definition log.hpp:18
fvar< T > log1p_exp(const fvar< T > &x)
Definition log1p_exp.hpp:14
plain_type_t< Vec > simplex_constrain(const Vec &y)
Return the simplex corresponding to the specified free vector.
fvar< T > inv_logit(const fvar< T > &x)
Returns the inverse logit function applied to the argument.
Definition inv_logit.hpp:20
typename plain_type< T >::type plain_type_t
typename scalar_type< T >::type scalar_type_t
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...