Automatic Differentiation
 
Loading...
Searching...
No Matches
stochastic_row_constrain.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_CONSTRAINT_STOCHASTIC_ROW_CONSTRAIN_HPP
2#define STAN_MATH_REV_CONSTRAINT_STOCHASTIC_ROW_CONSTRAIN_HPP
3
11#include <cmath>
12#include <tuple>
13#include <vector>
14
15namespace stan {
16namespace math {
17
25template <typename T, require_rev_matrix_t<T>* = nullptr>
27 using ret_type = plain_type_t<T>;
28 const Eigen::Index N = y.rows();
29 const Eigen::Index M = y.cols();
30 arena_t<Eigen::MatrixXd> x_val(N, M + 1);
31 if (unlikely(N == 0 || M == 0)) {
32 return ret_type(x_val);
33 }
34 arena_t<T> arena_y = y;
35 arena_t<Eigen::MatrixXd> arena_z(N, M);
36 Eigen::Array<double, -1, 1> stick_len = Eigen::Array<double, -1, 1>::Ones(N);
37 for (Eigen::Index j = 0; j < M; ++j) {
38 double log_N_minus_k = std::log(M - j);
39 arena_z.col(j).array()
40 = inv_logit((arena_y.col(j).val_op().array() - log_N_minus_k).matrix());
41 x_val.col(j).array() = stick_len * arena_z.col(j).array();
42 stick_len -= x_val.col(j).array();
43 }
44 x_val.col(M).array() = stick_len;
45 arena_t<ret_type> arena_x = x_val;
46 reverse_pass_callback([arena_y, arena_x, arena_z]() mutable {
47 const Eigen::Index M = arena_y.cols();
48 auto arena_y_arr = arena_y.array();
49 auto arena_x_arr = arena_x.array();
50 auto arena_z_arr = arena_z.array();
51 auto stick_len_val_arr = arena_x_arr.col(M).val_op().eval();
52 auto stick_len_adj_arr = arena_x_arr.col(M).adj_op().eval();
53 for (Eigen::Index k = M; k-- > 0;) {
54 arena_x_arr.col(k).adj() -= stick_len_adj_arr;
55 stick_len_val_arr += arena_x_arr.col(k).val_op();
56 stick_len_adj_arr += arena_x_arr.col(k).adj_op() * arena_z_arr.col(k);
57 arena_y_arr.col(k).adj() += arena_x_arr.adj_op().col(k)
58 * stick_len_val_arr * arena_z_arr.col(k)
59 * (1.0 - arena_z_arr.col(k));
60 }
61 });
62 return ret_type(arena_x);
63}
64
78template <typename T, require_rev_matrix_t<T>* = nullptr>
80 scalar_type_t<T>& lp) {
81 using ret_type = plain_type_t<T>;
82 const Eigen::Index N = y.rows();
83 const Eigen::Index M = y.cols();
84 arena_t<Eigen::MatrixXd> x_val(N, M + 1);
85 if (unlikely(N == 0 || M == 0)) {
86 return ret_type(x_val);
87 }
88 arena_t<T> arena_y = y;
89 arena_t<Eigen::MatrixXd> arena_z(N, M);
90 Eigen::Array<double, -1, 1> stick_len = Eigen::Array<double, -1, 1>::Ones(N);
91 for (Eigen::Index j = 0; j < M; ++j) {
92 double log_N_minus_k = std::log(M - j);
93 auto adj_y_k = arena_y.col(j).val_op().array() - log_N_minus_k;
94 arena_z.col(j).array() = inv_logit(adj_y_k);
95 x_val.col(j).array() = stick_len * arena_z.col(j).array();
96 lp += sum(log(stick_len)) - sum(log1p_exp(-adj_y_k))
97 - sum(log1p_exp(adj_y_k));
98 stick_len -= x_val.col(j).array();
99 }
100 x_val.col(M).array() = stick_len;
101 arena_t<ret_type> arena_x = x_val;
102 reverse_pass_callback([arena_y, arena_x, arena_z, lp]() mutable {
103 const Eigen::Index M = arena_y.cols();
104 auto arena_y_arr = arena_y.array();
105 auto arena_x_arr = arena_x.array();
106 auto arena_z_arr = arena_z.array();
107 auto stick_len_val = arena_x_arr.col(M).val_op().eval();
108 auto stick_len_adj = arena_x_arr.col(M).adj_op().eval();
109 for (Eigen::Index k = M; k-- > 0;) {
110 const double log_N_minus_k = std::log(M - k);
111 arena_x_arr.col(k).adj() -= stick_len_adj;
112 stick_len_val += arena_x_arr.col(k).val_op();
113 stick_len_adj += lp.adj() / stick_len_val
114 + arena_x_arr.adj_op().col(k) * arena_z_arr.col(k);
115 auto adj_y_k = arena_y_arr.col(k).val_op() - log_N_minus_k;
116 arena_y_arr.col(k).adj()
117 += -(lp.adj() * inv_logit(adj_y_k)) + lp.adj() * inv_logit(-adj_y_k)
118 + arena_x_arr.col(k).adj_op() * stick_len_val * arena_z_arr.col(k)
119 * (1.0 - arena_z_arr.col(k));
120 }
121 });
122 return ret_type(arena_x);
123}
124
125} // namespace math
126} // namespace stan
127#endif
#define unlikely(x)
plain_type_t< Mat > stochastic_row_constrain(const Mat &y)
Return a row stochastic matrix.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in the reverse pass.
T eval(T &&arg)
Inputs whose plain_type is equal to their own type are forwarded unmodified (for Eigen expressions ...
Definition eval.hpp:20
fvar< T > log(const fvar< T > &x)
Definition log.hpp:15
fvar< T > log1p_exp(const fvar< T > &x)
Definition log1p_exp.hpp:13
fvar< T > sum(const std::vector< fvar< T > > &m)
Return the sum of the entries of the specified standard vector.
Definition sum.hpp:22
fvar< T > inv_logit(const fvar< T > &x)
Returns the inverse logit function applied to the argument.
Definition inv_logit.hpp:20
typename plain_type< T >::type plain_type_t
typename scalar_type< T >::type scalar_type_t
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...