1#ifndef STAN_MATH_REV_CONSTRAINT_STOCHASTIC_ROW_CONSTRAIN_HPP
2#define STAN_MATH_REV_CONSTRAINT_STOCHASTIC_ROW_CONSTRAIN_HPP
25template <
typename T, require_rev_matrix_t<T>* =
nullptr>
28 const Eigen::Index N = y.rows();
29 const Eigen::Index M = y.cols();
32 return ret_type(x_val);
36 Eigen::Array<double, -1, 1> stick_len = Eigen::Array<double, -1, 1>::Ones(N);
37 for (Eigen::Index j = 0; j < M; ++j) {
38 double log_N_minus_k = std::log(M - j);
39 arena_z.col(j).array()
40 =
inv_logit((arena_y.col(j).val_op().array() - log_N_minus_k).matrix());
41 x_val.col(j).array() = stick_len * arena_z.col(j).array();
42 stick_len -= x_val.col(j).array();
44 x_val.col(M).array() = stick_len;
47 const Eigen::Index M = arena_y.cols();
48 auto arena_y_arr = arena_y.array();
49 auto arena_x_arr = arena_x.array();
50 auto arena_z_arr = arena_z.array();
51 auto stick_len_val_arr = arena_x_arr.col(M).val_op().
eval();
52 auto stick_len_adj_arr = arena_x_arr.col(M).adj_op().
eval();
53 for (Eigen::Index k = M; k-- > 0;) {
54 arena_x_arr.col(k).adj() -= stick_len_adj_arr;
55 stick_len_val_arr += arena_x_arr.col(k).val_op();
56 stick_len_adj_arr += arena_x_arr.col(k).adj_op() * arena_z_arr.col(k);
57 arena_y_arr.col(k).adj() += arena_x_arr.adj_op().col(k)
58 * stick_len_val_arr * arena_z_arr.col(k)
59 * (1.0 - arena_z_arr.col(k));
62 return ret_type(arena_x);
78template <
typename T, require_rev_matrix_t<T>* =
nullptr>
82 const Eigen::Index N = y.rows();
83 const Eigen::Index M = y.cols();
86 return ret_type(x_val);
90 Eigen::Array<double, -1, 1> stick_len = Eigen::Array<double, -1, 1>::Ones(N);
91 for (Eigen::Index j = 0; j < M; ++j) {
92 double log_N_minus_k = std::log(M - j);
93 auto adj_y_k = arena_y.col(j).val_op().array() - log_N_minus_k;
94 arena_z.col(j).array() =
inv_logit(adj_y_k);
95 x_val.col(j).array() = stick_len * arena_z.col(j).array();
98 stick_len -= x_val.col(j).array();
100 x_val.col(M).array() = stick_len;
103 const Eigen::Index M = arena_y.cols();
104 auto arena_y_arr = arena_y.array();
105 auto arena_x_arr = arena_x.array();
106 auto arena_z_arr = arena_z.array();
107 auto stick_len_val = arena_x_arr.col(M).val_op().
eval();
108 auto stick_len_adj = arena_x_arr.col(M).adj_op().
eval();
109 for (Eigen::Index k = M; k-- > 0;) {
110 const double log_N_minus_k = std::log(M - k);
111 arena_x_arr.col(k).adj() -= stick_len_adj;
112 stick_len_val += arena_x_arr.col(k).val_op();
113 stick_len_adj += lp.adj() / stick_len_val
114 + arena_x_arr.adj_op().col(k) * arena_z_arr.col(k);
115 auto adj_y_k = arena_y_arr.col(k).val_op() - log_N_minus_k;
116 arena_y_arr.col(k).adj()
118 + arena_x_arr.col(k).adj_op() * stick_len_val * arena_z_arr.col(k)
119 * (1.0 - arena_z_arr.col(k));
122 return ret_type(arena_x);
plain_type_t< Mat > stochastic_row_constrain(const Mat &y)
Return a row stochastic matrix.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T eval(T &&arg)
Inputs whose plain_type is equal to their own type are forwarded unmodified (for Eigen expressions ...
fvar< T > log(const fvar< T > &x)
fvar< T > log1p_exp(const fvar< T > &x)
auto sum(const std::vector< T > &m)
Return the sum of the entries of the specified standard vector.
fvar< T > inv_logit(const fvar< T > &x)
Returns the inverse logit function applied to the argument.
typename plain_type< T >::type plain_type_t
typename scalar_type< T >::type scalar_type_t
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...