Automatic Differentiation
 
Loading...
Searching...
No Matches
stochastic_column_constrain.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_CONSTRAINT_STOCHASTIC_COLUMN_CONSTRAIN_HPP
2#define STAN_MATH_REV_CONSTRAINT_STOCHASTIC_COLUMN_CONSTRAIN_HPP
3
11#include <cmath>
12#include <tuple>
13#include <vector>
14
15namespace stan {
16namespace math {
17
26template <typename T, require_rev_matrix_t<T>* = nullptr>
28 using ret_type = plain_type_t<T>;
29 const Eigen::Index N = y.rows();
30 const Eigen::Index M = y.cols();
31 using eigen_mat_rowmajor
32 = Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
33 arena_t<eigen_mat_rowmajor> x_val(N + 1, M);
34 if (unlikely(N == 0 || M == 0)) {
35 return ret_type(x_val);
36 }
38 arena_t<eigen_mat_rowmajor> arena_z(N, M);
39 using arr_vec = Eigen::Array<double, 1, -1>;
40 arr_vec stick_len = arr_vec::Constant(M, 1.0);
41 for (Eigen::Index k = 0; k < N; ++k) {
42 const double log_N_minus_k = std::log(N - k);
43 arena_z.row(k)
44 = inv_logit(arena_y.array().row(k).val_op() - log_N_minus_k).matrix();
45 x_val.row(k) = stick_len.array() * arena_z.array().row(k);
46 stick_len -= x_val.array().row(k);
47 }
48 x_val.row(N) = stick_len;
49 arena_t<ret_type> arena_x = x_val;
50 reverse_pass_callback([arena_y, arena_x, arena_z]() mutable {
51 const Eigen::Index N = arena_y.rows();
52 auto arena_x_arr = arena_x.array();
53 auto arena_y_arr = arena_y.array();
54 auto arena_z_arr = arena_z.array();
55 auto stick_len_val = arena_x.array().row(N).val().eval();
56 auto stick_len_adj = arena_x.array().row(N).adj().eval();
57 for (Eigen::Index k = N; k-- > 0;) {
58 arena_x_arr.row(k).adj() -= stick_len_adj;
59 stick_len_val += arena_x_arr.row(k).val();
60 stick_len_adj += arena_x_arr.row(k).adj() * arena_z_arr.row(k);
61 auto arena_z_adj = arena_x_arr.row(k).adj() * stick_len_val;
62 arena_y_arr.row(k).adj()
63 += arena_z_adj * arena_z_arr.row(k) * (1.0 - arena_z_arr.row(k));
64 }
65 });
66 return ret_type(arena_x);
67}
68
82template <typename T, require_rev_matrix_t<T>* = nullptr>
84 scalar_type_t<T>& lp) {
85 using ret_type = plain_type_t<T>;
86 const Eigen::Index N = y.rows();
87 const Eigen::Index M = y.cols();
88 using eigen_mat_rowmajor
89 = Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
90 arena_t<eigen_mat_rowmajor> x_val(N + 1, M);
91 if (unlikely(N == 0 || M == 0)) {
92 return ret_type(x_val);
93 }
95 arena_t<eigen_mat_rowmajor> arena_z(N, M);
96 using arr_vec = Eigen::Array<double, 1, -1>;
97 arr_vec stick_len = arr_vec::Constant(M, 1.0);
98 arr_vec adj_y_k(N);
99 for (Eigen::Index k = 0; k < N; ++k) {
100 double log_N_minus_k = std::log(N - k);
101 adj_y_k = arena_y.array().row(k).val() - log_N_minus_k;
102 arena_z.array().row(k) = inv_logit(adj_y_k);
103 x_val.array().row(k) = stick_len * arena_z.array().row(k);
104 lp += sum(log(stick_len)) - sum(log1p_exp(-adj_y_k))
105 - sum(log1p_exp(adj_y_k));
106 stick_len -= x_val.array().row(k);
107 }
108 x_val.array().row(N) = stick_len;
109 arena_t<ret_type> arena_x = x_val;
110 reverse_pass_callback([arena_y, arena_x, arena_z, lp]() mutable {
111 const Eigen::Index N = arena_y.rows();
112 auto arena_x_arr = arena_x.array();
113 auto arena_y_arr = arena_y.array();
114 auto arena_z_arr = arena_z.array();
115 auto stick_len_val = arena_x.array().row(N).val().eval();
116 auto stick_len_adj = arena_x.array().row(N).adj().eval();
117 for (Eigen::Index k = N; k-- > 0;) {
118 const double log_N_minus_k = std::log(N - k);
119 arena_x_arr.row(k).adj() -= stick_len_adj;
120 stick_len_val += arena_x_arr.row(k).val();
121 stick_len_adj += lp.adj() / stick_len_val
122 + arena_x_arr.row(k).adj() * arena_z_arr.row(k);
123 auto adj_y_k = arena_y_arr.row(k).val() - log_N_minus_k;
124 auto arena_z_adj = arena_x_arr.row(k).adj() * stick_len_val;
125 arena_y_arr.row(k).adj()
126 += -(lp.adj() * inv_logit(adj_y_k)) + lp.adj() * inv_logit(-adj_y_k)
127 + arena_z_adj * arena_z_arr.row(k) * (1.0 - arena_z_arr.row(k));
128 }
129 });
130 return ret_type(arena_x);
131}
132
133} // namespace math
134} // namespace stan
135#endif
#define unlikely(x)
auto row(T_x &&x, size_t j)
Return the specified row of the specified kernel generator expression using start-at-1 indexing.
Definition row.hpp:23
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T eval(T &&arg)
Inputs which have a plain_type equal to their own type are forwarded unmodified (for Eigen expressions ...
Definition eval.hpp:20
fvar< T > log(const fvar< T > &x)
Definition log.hpp:18
plain_type_t< Mat > stochastic_column_constrain(const Mat &y)
Return a column stochastic matrix.
fvar< T > log1p_exp(const fvar< T > &x)
Definition log1p_exp.hpp:14
auto sum(const std::vector< T > &m)
Return the sum of the entries of the specified standard vector.
Definition sum.hpp:23
fvar< T > inv_logit(const fvar< T > &x)
Returns the inverse logit function applied to the argument.
Definition inv_logit.hpp:20
typename plain_type< T >::type plain_type_t
typename scalar_type< T >::type scalar_type_t
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...