Loading [MathJax]/extensions/TeX/AMSsymbols.js
Automatic Differentiation
 
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
Loading...
Searching...
No Matches
simplex_constrain.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_CONSTRAINT_SIMPLEX_CONSTRAIN_HPP
2#define STAN_MATH_PRIM_CONSTRAINT_SIMPLEX_CONSTRAIN_HPP
3
11#include <cmath>
12
13namespace stan {
14namespace math {
15
/**
 * Return the simplex corresponding to the specified free vector.
 *
 * For an input of size N the result has N + 1 entries. The entries are
 * computed as exp(z - max_val) / d, where d is the sum of exp(z - max_val)
 * over all entries, so they are normalized by construction.
 *
 * @tparam Vec vector type, constrained by require_eigen_vector_t and
 *   require_not_st_var
 * @param y free vector of size N
 * @return vector of size N + 1; when N == 0 the single entry is 1
 */
32template <typename Vec, require_eigen_vector_t<Vec>* = nullptr,
33 require_not_st_var<Vec>* = nullptr>
// NOTE(review): listing line 34 (the function signature) is missing from this
// extract; the body below reads a parameter named y and returns z.
35 using T = value_type_t<Vec>;
36 const auto N = y.size();
37
38 plain_type_t<Vec> z = Eigen::VectorXd::Zero(N + 1);
// Empty input: the result is the one-element vector {1}.
39 if (unlikely(N == 0)) {
40 z.coeffRef(0) = 1;
41 return z;
42 }
43
// to_ref evaluates y up front if it is an expensive Eigen expression.
44 auto&& y_ref = to_ref(y);
// Running sum of the scaled inputs w for indices i..N.
45 T sum_w(0);
46
47 T d(0); // sum of exponentials
48 T max_val(0);
// Starts at -infinity so the first fmax below picks the first z entry.
49 T max_val_old(negative_infinity());
50
// Walk from the last index down to 1, filling z in place.
51 for (int i = N; i > 0; --i) {
52 double n = static_cast<double>(i);
53 auto w = y_ref(i - 1) * inv_sqrt(n * (n + 1));
54 sum_w += w;
55
56 z.coeffRef(i - 1) += sum_w;
57 z.coeffRef(i) -= w * n;
58
// z(i) is final here: later iterations only write indices below i.
// Update the running maximum and rescale d to it; both exp arguments are
// <= 0 because max_val is the greater of max_val_old and z(i).
59 max_val = fmax(max_val_old, z.coeff(i));
60 d = d * exp(max_val_old - max_val) + exp(z.coeff(i) - max_val);
61 max_val_old = max_val;
62 }
63
64 // above loop doesn't reach i==0
65 max_val = fmax(max_val_old, z.coeff(0));
66 d = d * exp(max_val_old - max_val) + exp(z.coeff(0) - max_val);
67
// Normalize: every entry becomes exp(z - max_val) / d.
68 z.array() = (z.array() - max_val).exp() / d;
69
70 return z;
71}
72
/**
 * Map a free vector of size N to a vector of size N + 1 whose entries are
 * exp(z - max_val) / d (d being the sum of exp(z - max_val) over all
 * entries), and add the accompanying log term to lp.
 *
 * The increment to lp is -(N + 1) * (max_val + log(d)) + 0.5 * log(N + 1),
 * computed from the running maximum and the sum of exponentials rather than
 * from the finished result. When N == 0 the function returns before touching
 * lp.
 *
 * @tparam Vec vector type, constrained by require_eigen_vector_t and
 *   require_not_st_var
 * @tparam Lp type of the accumulator; value_type_t<Vec> must be convertible
 *   to it
 * @param y free vector of size N
 * @param[in,out] lp accumulator that is incremented in place
 * @return vector of size N + 1; when N == 0 the single entry is 1
 */
91template <typename Vec, typename Lp, require_eigen_vector_t<Vec>* = nullptr,
92 require_not_st_var<Vec>* = nullptr,
93 require_convertible_t<value_type_t<Vec>, Lp>* = nullptr>
94inline plain_type_t<Vec> simplex_constrain(const Vec& y, Lp& lp) {
95 using std::log;
96 using T = value_type_t<Vec>;
97 const auto N = y.size();
98
99 plain_type_t<Vec> z = Eigen::VectorXd::Zero(N + 1);
// Empty input: return the one-element vector {1}; lp is left unchanged.
100 if (unlikely(N == 0)) {
101 z.coeffRef(0) = 1;
102 return z;
103 }
104
// to_ref evaluates y up front if it is an expensive Eigen expression.
105 auto&& y_ref = to_ref(y);
// Running sum of the scaled inputs w for indices i..N.
106 T sum_w(0);
107
108 T d(0); // sum of exponentials
109 T max_val(0);
// Starts at -infinity so the first fmax below picks the first z entry.
110 T max_val_old(negative_infinity());
111
// Walk from the last index down to 1, filling z in place.
112 for (int i = N; i > 0; --i) {
113 double n = static_cast<double>(i);
114 auto w = y_ref(i - 1) * inv_sqrt(n * (n + 1));
115 sum_w += w;
116
117 z.coeffRef(i - 1) += sum_w;
118 z.coeffRef(i) -= w * n;
119
// z(i) is final here: later iterations only write indices below i.
// Update the running maximum and rescale d to it; both exp arguments are
// <= 0 because max_val is the greater of max_val_old and z(i).
120 max_val = fmax(max_val_old, z.coeff(i));
121 d = d * exp(max_val_old - max_val) + exp(z.coeff(i) - max_val);
122 max_val_old = max_val;
123 }
124
125 // above loop doesn't reach i==0
126 max_val = fmax(max_val_old, z.coeff(0));
127 d = d * exp(max_val_old - max_val) + exp(z.coeff(0) - max_val);
128
// Normalize: every entry becomes exp(z - max_val) / d.
129 z.array() = (z.array() - max_val).exp() / d;
130
131 // equivalent to z.log().sum() + 0.5 * log(N + 1)
132 lp += -(N + 1) * (max_val + log(d)) + 0.5 * log(N + 1);
133
134 return z;
135}
136
/**
 * Overload of simplex_constrain selected by the require_std_vector_t
 * constraint on T. The lambda passed below forwards its argument v to
 * simplex_constrain(v).
 *
 * @tparam T argument type, constrained by require_std_vector_t
 * @param y input passed, together with the lambda, to the call below
 * @return whatever that call yields (return type is deduced)
 */
147template <typename T, require_std_vector_t<T>* = nullptr>
148inline auto simplex_constrain(const T& y) {
// NOTE(review): listing line 149 (the head of the call that receives y and
// the lambda) is missing from this extract; restore it from the header.
150 y, [](auto&& v) { return simplex_constrain(v); });
151}
152
/**
 * Overload of simplex_constrain selected by the require_std_vector_t
 * constraint on T, taking an accumulator lp. The lambda passed below captures
 * lp by reference and forwards its argument v to simplex_constrain(v, lp), so
 * each such call can modify the same lp.
 *
 * @tparam T argument type, constrained by require_std_vector_t
 * @tparam Lp type of the accumulator; return_type_t<T> must be convertible
 *   to it
 * @param y input passed, together with the lambda, to the call below
 * @param[in,out] lp accumulator handed by reference to every inner call
 * @return whatever that call yields (return type is deduced)
 */
166template <typename T, typename Lp, require_std_vector_t<T>* = nullptr,
167 require_convertible_t<return_type_t<T>, Lp>* = nullptr>
168inline auto simplex_constrain(const T& y, Lp& lp) {
// NOTE(review): listing line 169 (the head of the call that receives y and
// the lambda) is missing from this extract; restore it from the header.
170 y, [&lp](auto&& v) { return simplex_constrain(v, lp); });
171}
172
/**
 * Compile-time dispatch between the two simplex_constrain forms.
 *
 * When Jacobian is true this calls simplex_constrain(y, lp); otherwise it
 * calls simplex_constrain(y) and lp is not used.
 *
 * @tparam Jacobian selects which overload is called
 * @tparam Vec type of y
 * @tparam Lp type of lp
 * @param y input forwarded to the selected overload
 * @param[in,out] lp forwarded only when Jacobian is true
 * @return result of the selected overload
 */
191template <bool Jacobian, typename Vec, typename Lp,
// NOTE(review): listing line 192 (the remainder of the template parameter
// list) is missing from this extract; restore it from the header.
193inline plain_type_t<Vec> simplex_constrain(const Vec& y, Lp& lp) {
// if constexpr: only the selected branch is instantiated.
194 if constexpr (Jacobian) {
195 return simplex_constrain(y, lp);
196 } else {
197 return simplex_constrain(y);
198 }
199}
200
201} // namespace math
202} // namespace stan
203
204#endif
#define unlikely(x)
require_t< std::is_convertible< std::decay_t< T >, std::decay_t< S > > > require_convertible_t
Requires that types T and S satisfy std::is_convertible.
typename value_type< T >::type value_type_t
Helper function for accessing underlying type.
static constexpr double negative_infinity()
Return negative infinity.
fvar< T > log(const fvar< T > &x)
Definition log.hpp:18
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:17
fvar< T > fmax(const fvar< T > &x1, const fvar< T > &x2)
Return the greater of the two specified arguments.
Definition fmax.hpp:23
plain_type_t< Vec > simplex_constrain(const Vec &y)
Return the simplex corresponding to the specified free vector.
fvar< T > inv_sqrt(const fvar< T > &x)
Definition inv_sqrt.hpp:14
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:15
typename plain_type< T >::type plain_type_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...