1#ifndef STAN_MATH_REV_CONSTRAINT_STOCHASTIC_COLUMN_CONSTRAIN_HPP
2#define STAN_MATH_REV_CONSTRAINT_STOCHASTIC_COLUMN_CONSTRAIN_HPP
// Reverse-mode column-stochastic constraining transform (overload that does
// not reference a log-density accumulator in the visible lines).
// NOTE(review): this is an extracted listing. The leading integers are line
// numbers of the original header, and several original lines (the function
// signature, the declarations of x_val / arena_y / arena_z / arena_x, the
// wrapper around the reverse pass, and the closing braces) are not present.
26template <
typename T, require_rev_matrix_t<T>* =
nullptr>
// N and M are the row and column counts of the unconstrained input y.
29 const Eigen::Index N = y.rows();
30 const Eigen::Index M = y.cols();
// Row-major, double-valued dynamic matrix type.
// Presumably the storage type of x_val and arena_z - TODO confirm.
31 using eigen_mat_rowmajor
32 = Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
// Early return of x_val; the condition guarding it is not visible here
// (presumably an empty-input check - TODO confirm).
35 return ret_type(x_val);
// Forward pass: stick-breaking applied to every column at once.
// stick_len is a length-M row array holding the mass still unassigned in
// each column; every column starts with a full stick of 1.0.
39 using arr_vec = Eigen::Array<double, 1, -1>;
40 arr_vec stick_len = arr_vec::Constant(M, 1.0);
41 for (Eigen::Index k = 0; k < N; ++k) {
// Shifting by log(N - k) before inv_logit makes an all-zero y produce equal
// entries 1 / (N + 1) in every output row.
42 const double log_N_minus_k = std::log(N - k);
// The assignment target is on a line missing from this listing; presumably
// row k of arena_z, which is read on the next statement - TODO confirm.
44 =
inv_logit(arena_y.array().row(k).val_op() - log_N_minus_k).matrix();
// Row k takes the fraction z of what is left; the remainder shrinks by it.
45 x_val.row(k) = stick_len.array() * arena_z.array().row(k);
46 stick_len -= x_val.array().row(k);
// The extra row (index N) receives the leftover stick, so rows 0..N of each
// column add up to the initial 1.0.
48 x_val.row(N) = stick_len;
// Reverse pass. Presumably the body of a callback registered with
// reverse_pass_callback; the enclosing lambda is not visible here.
51 const Eigen::Index N = arena_y.rows();
52 auto arena_x_arr = arena_x.array();
53 auto arena_y_arr = arena_y.array();
54 auto arena_z_arr = arena_z.array();
// Start from the final state of the forward loop: the leftover stick is the
// last output row, so its value and adjoint seed the running stick state.
// eval() materialises copies that are then updated in place.
55 auto stick_len_val = arena_x.array().
row(N).val().
eval();
56 auto stick_len_adj = arena_x.array().
row(N).adj().
eval();
// Visits k = N - 1 down to 0, undoing the forward loop one row at a time.
57 for (Eigen::Index k = N; k-- > 0;) {
// Forward did stick_len -= x_k, so the stick adjoint flows into x_k negated.
58 arena_x_arr.row(k).adj() -= stick_len_adj;
// Restore the stick length as it was before row k was broken off.
59 stick_len_val += arena_x_arr.row(k).val();
// Forward did x_k = stick_len * z_k: propagate to stick_len (factor z_k) ...
60 stick_len_adj += arena_x_arr.row(k).adj() * arena_z_arr.row(k);
// ... and to z_k (factor stick_len).
61 auto arena_z_adj = arena_x_arr.row(k).adj() * stick_len_val;
// Chain through inv_logit, whose derivative is z * (1 - z).
62 arena_y_arr.row(k).adj()
63 += arena_z_adj * arena_z_arr.row(k) * (1.0 - arena_z_arr.row(k));
// Result is arena_x converted to ret_type (ret_type is declared on a line
// not shown in this listing).
66 return ret_type(arena_x);
// Reverse-mode column-stochastic constraining transform, variant whose
// reverse pass also reads lp.adj().
// NOTE(review): this is an extracted listing. The leading integers are line
// numbers of the original header, and several original lines (the function
// signature including the lp parameter, the declarations of x_val / arena_y /
// arena_z / arena_x / adj_y_k, the statements that update lp, the wrapper
// around the reverse pass, and the closing braces) are not present.
82template <
typename T, require_rev_matrix_t<T>* =
nullptr>
// N and M are the row and column counts of the unconstrained input y.
86 const Eigen::Index N = y.rows();
87 const Eigen::Index M = y.cols();
// Row-major, double-valued dynamic matrix type.
// Presumably the storage type of x_val and arena_z - TODO confirm.
88 using eigen_mat_rowmajor
89 = Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
// Early return of x_val; the condition guarding it is not visible here
// (presumably an empty-input check - TODO confirm).
92 return ret_type(x_val);
// Forward pass: stick-breaking applied to every column at once.
// stick_len is a length-M row array holding the mass still unassigned in
// each column; every column starts with a full stick of 1.0.
96 using arr_vec = Eigen::Array<double, 1, -1>;
97 arr_vec stick_len = arr_vec::Constant(M, 1.0);
99 for (Eigen::Index k = 0; k < N; ++k) {
100 double log_N_minus_k = std::log(N - k);
// adj_y_k is declared on a line not shown here; it holds the values of row k
// of y shifted by log(N - k), i.e. the argument passed to inv_logit.
101 adj_y_k = arena_y.array().row(k).val() - log_N_minus_k;
102 arena_z.array().row(k) =
inv_logit(adj_y_k);
// Row k takes the fraction z of what is left of the stick.
103 x_val.array().row(k) = stick_len * arena_z.array().row(k);
// NOTE(review): original lines 104-105 are missing here; presumably they add
// this row's log-Jacobian contribution to lp - TODO confirm.
106 stick_len -= x_val.array().row(k);
// The extra row (index N) receives the leftover stick, so rows 0..N of each
// column add up to the initial 1.0.
108 x_val.array().row(N) = stick_len;
// Reverse pass. Presumably the body of a callback registered with
// reverse_pass_callback; the enclosing lambda is not visible here.
111 const Eigen::Index N = arena_y.rows();
112 auto arena_x_arr = arena_x.array();
113 auto arena_y_arr = arena_y.array();
114 auto arena_z_arr = arena_z.array();
// Start from the final state of the forward loop: the leftover stick is the
// last output row, so its value and adjoint seed the running stick state.
// eval() materialises copies that are then updated in place.
115 auto stick_len_val = arena_x.array().
row(N).val().
eval();
116 auto stick_len_adj = arena_x.array().
row(N).adj().
eval();
// Visits k = N - 1 down to 0, undoing the forward loop one row at a time.
117 for (Eigen::Index k = N; k-- > 0;) {
118 const double log_N_minus_k = std::log(N - k);
// Forward did stick_len -= x_k, so the stick adjoint flows into x_k negated.
119 arena_x_arr.row(k).adj() -= stick_len_adj;
// Restore the stick length as it was before row k was broken off.
120 stick_len_val += arena_x_arr.row(k).val();
// Two contributions to the stick adjoint: lp.adj() / stick_len (the form a
// log(stick_len) term in lp would produce - presumably added on a forward
// line not shown) and the x_k = stick_len * z_k product (factor z_k).
121 stick_len_adj += lp.adj() / stick_len_val
122 + arena_x_arr.row(k).adj() * arena_z_arr.row(k);
// Recompute the shifted row that was fed to inv_logit in the forward pass.
// Presumably consumed by the missing continuation line below - TODO confirm.
123 auto adj_y_k = arena_y_arr.row(k).val() - log_N_minus_k;
// Adjoint of z_k coming from the product x_k = stick_len * z_k.
124 auto arena_z_adj = arena_x_arr.row(k).adj() * stick_len_val;
// NOTE(review): original line 126 (the assignment operator and the first
// term of this expression) is missing from the listing. The visible term
// chains arena_z_adj through inv_logit, whose derivative is z * (1 - z).
125 arena_y_arr.row(k).adj()
127 + arena_z_adj * arena_z_arr.row(k) * (1.0 - arena_z_arr.row(k));
// Result is arena_x converted to ret_type (ret_type is declared on a line
// not shown in this listing).
130 return ret_type(arena_x);
auto row(T_x &&x, size_t j)
Return the specified row of the specified kernel generator expression using start-at-1 indexing.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T eval(T &&arg)
Inputs which have a plain_type equal to their own type are forwarded unmodified (for Eigen expressions ...
fvar< T > log(const fvar< T > &x)
plain_type_t< Mat > stochastic_column_constrain(const Mat &y)
Return a column stochastic matrix.
fvar< T > log1p_exp(const fvar< T > &x)
auto sum(const std::vector< T > &m)
Return the sum of the entries of the specified standard vector.
fvar< T > inv_logit(const fvar< T > &x)
Returns the inverse logit function applied to the argument.
typename plain_type< T >::type plain_type_t
typename scalar_type< T >::type scalar_type_t
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...