1#ifndef STAN_MATH_REV_CONSTRAINT_SIMPLEX_CONSTRAIN_HPP
2#define STAN_MATH_REV_CONSTRAINT_SIMPLEX_CONSTRAIN_HPP
30template <
typename T, require_rev_col_vector_t<T>* =
nullptr>
37 Eigen::VectorXd x_val(N + 1);
39 double stick_len(1.0);
40 for (Eigen::Index k = 0; k < N; ++k) {
41 double log_N_minus_k = std::log(N - k);
42 arena_z.coeffRef(k) =
inv_logit(arena_y.val().coeff(k) - log_N_minus_k);
43 x_val.coeffRef(k) = stick_len * arena_z.coeff(k);
44 stick_len -= x_val(k);
46 x_val.coeffRef(N) = stick_len;
51 return ret_type(arena_x);
55 int N = arena_y.size();
56 double stick_len_val = arena_x.val().coeff(N);
57 double stick_len_adj = arena_x.adj().coeff(N);
58 for (Eigen::Index k = N; k-- > 0;) {
59 arena_x.adj().coeffRef(k) -= stick_len_adj;
60 stick_len_val += arena_x.val().coeff(k);
61 stick_len_adj += arena_x.adj().coeff(k) * arena_z.coeff(k);
62 double arena_z_adj = arena_x.adj().coeff(k) * stick_len_val;
63 arena_y.adj().coeffRef(k)
64 += arena_z_adj * arena_z.coeff(k) * (1.0 - arena_z.coeff(k));
68 return ret_type(arena_x);
84template <
typename T, require_rev_col_vector_t<T>* =
nullptr>
91 Eigen::VectorXd x_val(N + 1);
93 double stick_len(1.0);
94 for (Eigen::Index k = 0; k < N; ++k) {
95 double log_N_minus_k = std::log(N - k);
96 double adj_y_k = arena_y.val().coeff(k) - log_N_minus_k;
98 x_val.coeffRef(k) = stick_len * arena_z.coeff(k);
102 stick_len -= x_val(k);
104 x_val.coeffRef(N) = stick_len;
109 return ret_type(arena_x);
113 int N = arena_y.size();
114 double stick_len_val = arena_x.val().coeff(N);
115 double stick_len_adj = arena_x.adj().coeff(N);
116 for (Eigen::Index k = N; k-- > 0;) {
117 arena_x.adj().coeffRef(k) -= stick_len_adj;
118 stick_len_val += arena_x.val().coeff(k);
119 double log_N_minus_k = std::log(N - k);
120 double adj_y_k = arena_y.val().coeff(k) - log_N_minus_k;
121 arena_y.adj().coeffRef(k) -= lp.adj() *
inv_logit(adj_y_k);
122 arena_y.adj().coeffRef(k) += lp.adj() *
inv_logit(-adj_y_k);
123 stick_len_adj += lp.adj() / stick_len_val;
124 stick_len_adj += arena_x.adj().coeff(k) * arena_z.coeff(k);
125 double arena_z_adj = arena_x.adj().coeff(k) * stick_len_val;
126 arena_y.adj().coeffRef(k)
127 += arena_z_adj * arena_z.coeff(k) * (1.0 - arena_z.coeff(k));
131 return ret_type(arena_x);
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
fvar< T > log(const fvar< T > &x)
fvar< T > log1p_exp(const fvar< T > &x)
plain_type_t< Vec > simplex_constrain(const Vec &y)
Return the simplex corresponding to the specified free vector.
fvar< T > inv_logit(const fvar< T > &x)
Returns the inverse logit function applied to the argument.
typename plain_type< T >::type plain_type_t
typename scalar_type< T >::type scalar_type_t
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...