Automatic Differentiation
 
Loading...
Searching...
No Matches
sum_to_zero_constrain.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_CONSTRAINT_SUM_TO_ZERO_CONSTRAIN_HPP
2#define STAN_MATH_REV_CONSTRAINT_SUM_TO_ZERO_CONSTRAIN_HPP
3
11#include <cmath>
12
13namespace stan {
14namespace math {
15
16namespace internal {
17
27template <typename T>
28inline void sum_to_zero_vector_backprop(T&& y_adj,
29 const Eigen::VectorXd& z_adj) {
30 const auto N = y_adj.size();
31
32 double sum_u_adj = 0;
33 for (int i = 0; i < N; ++i) {
34 double n = static_cast<double>(i + 1);
35
36 // adjoint of the reverse cumulative sum computed in the forward mode
37 sum_u_adj += z_adj.coeff(i);
38
39 // adjoint of the offset subtraction
40 double v_adj = -z_adj.coeff(i + 1) * n;
41
42 double w_adj = v_adj + sum_u_adj;
43
44 y_adj.coeffRef(i) += w_adj / sqrt(n * (n + 1));
45 }
46}
47
48} // namespace internal
49
74template <typename T, require_rev_col_vector_t<T>* = nullptr>
75inline auto sum_to_zero_constrain(T&& y) {
76 using ret_type = plain_type_t<T>;
77 if (unlikely(y.size() == 0)) {
78 return arena_t<ret_type>(Eigen::VectorXd{{0}});
79 }
80 auto arena_y = to_arena(std::forward<T>(y));
81 arena_t<ret_type> arena_z = sum_to_zero_constrain(arena_y.val());
82
83 reverse_pass_callback([arena_y, arena_z]() mutable {
84 internal::sum_to_zero_vector_backprop(arena_y.adj(), arena_z.adj());
85 });
86
87 return arena_z;
88}
89
/**
 * Return a matrix whose entries sum to zero, corresponding to the
 * specified free matrix, with reverse-mode autodiff support.
 *
 * The forward pass delegates to the prim implementation on the value
 * part; the reverse pass below propagates the constrained matrix's
 * adjoints back onto the free matrix analytically.
 *
 * @tparam T a reverse-mode autodiff matrix type (not a vector)
 * @param x free matrix of size Nf x Mf
 * @return constrained zero-sum matrix, arena-allocated
 */
template <typename T, require_rev_matrix_t<T>* = nullptr,
          require_not_t<is_rev_vector<T>>* = nullptr>
inline auto sum_to_zero_constrain(T&& x) {
  using ret_type = plain_type_t<T>;
  // An empty free matrix constrains to the 1x1 zero matrix.
  if (unlikely(x.size() == 0)) {
    return arena_t<ret_type>(Eigen::MatrixXd{{0}});
  }
  auto arena_x = to_arena(std::forward<T>(x));
  // Double-valued (prim) constrain on the value part.
  arena_t<ret_type> arena_z = sum_to_zero_constrain(arena_x.val());

  reverse_pass_callback([arena_x, arena_z]() mutable {
    // Free matrix is Nf x Mf; indices (Nf, j), (i, Mf), (Nf, Mf) below
    // address what is presumably the appended last row/column of the
    // constrained result (so arena_z is at least (Nf+1) x (Mf+1)).
    const auto Nf = arena_x.val().rows();
    const auto Mf = arena_x.val().cols();

    // Per-row adjoint accumulator carried across the column loop.
    Eigen::VectorXd d_beta = Eigen::VectorXd::Zero(Nf);

    for (int j = 0; j < Mf; ++j) {
      // Column-direction basis weights: a_j = 1/sqrt((j+1)(j+2)),
      // b_j = (j+1) * a_j — the same normalization as the vector case.
      double a_j = inv_sqrt((j + 1.0) * (j + 2.0));
      double b_j = (j + 1.0) * a_j;

      // Running adjoint accumulated down the current column.
      double d_ax = 0.0;

      for (int i = 0; i < Nf; ++i) {
        // Row-direction basis weights, mirroring a_j / b_j above.
        double a_i = inv_sqrt((i + 1.0) * (i + 2.0));
        double b_i = (i + 1.0) * a_i;

        // Combined adjoint of entry (i, j): its own adjoint corrected by
        // the last-row and last-column entries the forward pass coupled
        // it to (signs follow the forward construction).
        double dY = arena_z.adj().coeff(i, j) - arena_z.adj().coeff(Nf, j)
                    + arena_z.adj().coeff(Nf, Mf) - arena_z.adj().coeff(i, Mf);
        // Contribution carried over from earlier columns for this row;
        // read BEFORE d_beta is updated with the current column's -dY.
        double dI_from_beta = a_j * d_beta.coeff(i);
        d_beta.coeffRef(i) += -dY;

        double dI_from_alpha = b_j * dY;
        double dI = dI_from_alpha + dI_from_beta;
        // d_ax holds contributions from earlier rows of this column;
        // read before being decremented by the current dI.
        arena_x.adj().coeffRef(i, j) += b_i * dI + a_i * d_ax;
        d_ax -= dI;
      }
    }
  });

  return arena_z;
}
141
/**
 * Return a sum-to-zero constrained vector/matrix corresponding to the
 * specified free argument, matching the constrain-with-lp signature.
 *
 * `lp` is intentionally left unmodified: this overload adds no
 * log-Jacobian term (the transform's adjustment is constant and
 * dropped — per the prim-level documentation of this constrain).
 *
 * @tparam T a reverse-mode autodiff vector or matrix type
 * @tparam Lp type of the log probability accumulator
 * @param y free vector or matrix
 * @param lp log density accumulator (unused here)
 * @return constrained zero-sum result from the lp-free overload
 */
template <typename T, typename Lp, require_t<is_rev_matrix<T>>* = nullptr>
inline auto sum_to_zero_constrain(T&& y, Lp& lp) {
  return sum_to_zero_constrain(std::forward<T>(y));
}
171
172} // namespace math
173} // namespace stan
174#endif
#define unlikely(x)
int64_t cols(const T_x &x)
Returns the number of columns in the specified kernel generator expression.
Definition cols.hpp:21
int64_t rows(const T_x &x)
Returns the number of rows in the specified kernel generator expression.
Definition rows.hpp:22
void sum_to_zero_vector_backprop(T &&y_adj, const Eigen::VectorXd &z_adj)
The reverse pass backprop for the sum_to_zero_constrain on vectors.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
arena_t< T > to_arena(const T &a)
Converts the given argument into a type that either has any dynamic allocation on the AD stack or schedules its destructor to be called when the AD stack memory is recovered.
Definition to_arena.hpp:25
fvar< T > sqrt(const fvar< T > &x)
Definition sqrt.hpp:18
plain_type_t< Vec > sum_to_zero_constrain(const Vec &y)
Return a vector with sum zero corresponding to the specified free vector.
fvar< T > inv_sqrt(const fvar< T > &x)
Definition inv_sqrt.hpp:14
typename plain_type< std::decay_t< T > >::type plain_type_t
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...