Automatic Differentiation
 
Loading...
Searching...
No Matches
sum_to_zero_constrain.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_CONSTRAINT_SUM_TO_ZERO_CONSTRAIN_HPP
2#define STAN_MATH_REV_CONSTRAINT_SUM_TO_ZERO_CONSTRAIN_HPP
3
10#include <cmath>
11#include <tuple>
12#include <vector>
13
14namespace stan {
15namespace math {
16
41template <typename T, require_rev_col_vector_t<T>* = nullptr>
42inline auto sum_to_zero_constrain(T&& y) {
43 using ret_type = plain_type_t<T>;
44 if (unlikely(y.size() == 0)) {
45 return arena_t<ret_type>(Eigen::VectorXd{{0}});
46 }
47 auto arena_y = to_arena(std::forward<T>(y));
48 arena_t<ret_type> arena_z = sum_to_zero_constrain(arena_y.val());
49
50 reverse_pass_callback([arena_y, arena_z]() mutable {
51 const auto N = arena_y.size();
52
53 double sum_u_adj = 0;
54 for (int i = 0; i < N; ++i) {
55 double n = static_cast<double>(i + 1);
56
57 // adjoint of the reverse cumulative sum computed in the forward mode
58 sum_u_adj += arena_z.adj()(i);
59
60 // adjoint of the offset subtraction
61 double v_adj = -arena_z.adj()(i + 1) * n;
62
63 double w_adj = v_adj + sum_u_adj;
64
65 arena_y.adj()(i) += w_adj / sqrt(n * (n + 1));
66 }
67 });
68
69 return arena_z;
70}
71
97template <typename T, require_rev_col_vector_t<T>* = nullptr>
98inline auto sum_to_zero_constrain(T&& y, scalar_type_t<T>& lp) {
99 return sum_to_zero_constrain(std::forward<T>(y));
100}
101
102} // namespace math
103} // namespace stan
104#endif
#define unlikely(x)
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
arena_t< T > to_arena(const T &a)
Converts given argument into a type that either has any dynamic allocation on AD stack or schedules i...
Definition to_arena.hpp:25
fvar< T > sqrt(const fvar< T > &x)
Definition sqrt.hpp:17
plain_type_t< Vec > sum_to_zero_constrain(const Vec &y)
Return a vector with sum zero corresponding to the specified free vector.
typename plain_type< T >::type plain_type_t
typename scalar_type< T >::type scalar_type_t
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...