Loading [MathJax]/extensions/tex2jax.js
Automatic Differentiation
 
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
Loading...
Searching...
No Matches
sum_to_zero_free.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_CONSTRAINT_SUM_TO_ZERO_FREE_HPP
2#define STAN_MATH_PRIM_CONSTRAINT_SUM_TO_ZERO_FREE_HPP
3
11#include <cmath>
12
13namespace stan {
14namespace math {
15
38template <typename Vec, require_eigen_vector_t<Vec>* = nullptr>
39inline plain_type_t<Vec> sum_to_zero_free(const Vec& z) {
40 const auto& z_ref = to_ref(z);
41 check_sum_to_zero("stan::math::sum_to_zero_free", "sum_to_zero variable",
42 z_ref);
43
44 const auto N = z.size() - 1;
45
46 plain_type_t<Vec> y = Eigen::VectorXd::Zero(N);
47 if (unlikely(N == 0)) {
48 return y;
49 }
50
51 y.coeffRef(N - 1) = -z_ref.coeff(N) * sqrt(N * (N + 1)) / N;
52
53 value_type_t<Vec> sum_w(0);
54
55 for (int i = N - 2; i >= 0; --i) {
56 double n = static_cast<double>(i + 1);
57 auto w = y.coeff(i + 1) / sqrt((n + 1) * (n + 2));
58 sum_w += w;
59 y.coeffRef(i) = (sum_w - z_ref.coeff(i + 1)) * sqrt(n * (n + 1)) / n;
60 }
61
62 return y;
63}
64
73template <typename Mat, require_eigen_matrix_dynamic_t<Mat>* = nullptr>
74inline plain_type_t<Mat> sum_to_zero_free(const Mat& z) {
75 const auto& z_ref = to_ref(z);
76 check_sum_to_zero("stan::math::sum_to_zero_free", "sum_to_zero variable",
77 z_ref);
78
79 const auto N = z_ref.rows() - 1;
80 const auto M = z_ref.cols() - 1;
81
82 plain_type_t<Mat> x = Eigen::MatrixXd::Zero(N, M);
83 if (unlikely(N == 0 || M == 0)) {
84 return x;
85 }
86
87 Eigen::Matrix<value_type_t<Mat>, -1, 1> beta = Eigen::VectorXd::Zero(N);
88
89 for (int j = M - 1; j >= 0; --j) {
90 value_type_t<Mat> ax_previous(0);
91
92 double a_j = inv_sqrt((j + 1.0) * (j + 2.0));
93 double b_j = (j + 1.0) * a_j;
94
95 for (int i = N - 1; i >= 0; --i) {
96 double a_i = inv_sqrt((i + 1.0) * (i + 2.0));
97 double b_i = (i + 1.0) * a_i;
98
99 auto alpha_plus_beta = z_ref.coeff(i, j) + beta.coeff(i);
100
101 x.coeffRef(i, j) = (alpha_plus_beta + b_j * ax_previous) / (b_j * b_i);
102 beta.coeffRef(i) += a_j * (b_i * x.coeff(i, j) - ax_previous);
103 ax_previous += a_i * x.coeff(i, j);
104 }
105 }
106
107 return x;
108}
109
/**
 * Overload for std::vector containers. The lambda below forwards each
 * argument `v` it receives to the Eigen overloads of sum_to_zero_free,
 * so presumably each element of `z` is unconstrained independently.
 *
 * NOTE(review): original source line 119 is absent from this rendering.
 * It should hold the call that receives `z` and the lambda, likely the
 * start of the return statement. Confirm against the real header before
 * relying on this text.
 *
 * @tparam T std::vector type (enforced by require_std_vector_t)
 * @param z container of sum-to-zero vectors or matrices
 * @return container of the corresponding unconstrained values
 */
117template <typename T, require_std_vector_t<T>* = nullptr>
118inline auto sum_to_zero_free(const T& z) {
120 z, [](auto&& v) { return sum_to_zero_free(v); });
121}
122
123} // namespace math
124} // namespace stan
125
126#endif
#define unlikely(x)
typename value_type< T >::type value_type_t
Helper function for accessing underlying type.
plain_type_t< Vec > sum_to_zero_free(const Vec &z)
Return an unconstrained vector.
void check_sum_to_zero(const char *function, const char *name, const T &theta)
Throw an exception if the specified vector does not sum to 0.
fvar< T > sqrt(const fvar< T > &x)
Definition sqrt.hpp:18
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:17
fvar< T > beta(const fvar< T > &x1, const fvar< T > &x2)
Return fvar with the beta function applied to the specified arguments and its gradient.
Definition beta.hpp:51
fvar< T > inv_sqrt(const fvar< T > &x)
Definition inv_sqrt.hpp:14
typename plain_type< T >::type plain_type_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...