Loading [MathJax]/extensions/TeX/AMSsymbols.js
Automatic Differentiation
 
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
Loading...
Searching...
No Matches
simplex_constrain.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_CONSTRAINT_SIMPLEX_CONSTRAIN_HPP
2#define STAN_MATH_PRIM_CONSTRAINT_SIMPLEX_CONSTRAIN_HPP
3
10#include <cmath>
11
12namespace stan {
13namespace math {
14
27template <typename Vec, require_eigen_vector_t<Vec>* = nullptr,
28 require_not_st_var<Vec>* = nullptr>
30 // cut & paste simplex_constrain(Eigen::Matrix, T) w/o Jacobian
31 using std::log;
32 using T = value_type_t<Vec>;
33
34 int Km1 = y.size();
35 plain_type_t<Vec> x(Km1 + 1);
36 T stick_len(1.0);
37 for (Eigen::Index k = 0; k < Km1; ++k) {
38 T z_k = inv_logit(y.coeff(k) - log(Km1 - k));
39 x.coeffRef(k) = stick_len * z_k;
40 stick_len -= x.coeff(k);
41 }
42 x.coeffRef(Km1) = stick_len;
43 return x;
44}
45
61template <typename Vec, typename Lp, require_eigen_vector_t<Vec>* = nullptr,
62 require_not_st_var<Vec>* = nullptr,
63 require_convertible_t<value_type_t<Vec>, Lp>* = nullptr>
64inline plain_type_t<Vec> simplex_constrain(const Vec& y, Lp& lp) {
65 using Eigen::Dynamic;
66 using Eigen::Matrix;
67 using std::log;
68 using T = value_type_t<Vec>;
69
70 int Km1 = y.size(); // K = Km1 + 1
71 plain_type_t<Vec> x(Km1 + 1);
72 T stick_len(1.0);
73 for (Eigen::Index k = 0; k < Km1; ++k) {
74 double eq_share = -log(Km1 - k); // = logit(1.0/(Km1 + 1 - k));
75 T adj_y_k = y.coeff(k) + eq_share;
76 T z_k = inv_logit(adj_y_k);
77 x.coeffRef(k) = stick_len * z_k;
78 lp += log(stick_len);
79 lp -= log1p_exp(-adj_y_k);
80 lp -= log1p_exp(adj_y_k);
81 stick_len -= x.coeff(k); // equivalently *= (1 - z_k);
82 }
83 x.coeffRef(Km1) = stick_len; // no Jacobian contrib for last dim
84 return x;
85}
86
// Overload of simplex_constrain selected via require_std_vector_t<T>.
// The visible lambda calls the single-argument simplex_constrain on the
// value it is handed; no log density accumulator is involved here.
// NOTE(review): listing line 99 is absent from this extraction, so the head
// of the expression that consumes `y` and the lambda is not visible.
// Presumably it returns the result of applying the lambda to each element
// of `y` -- confirm against the upstream header before relying on this.
97template <typename T, require_std_vector_t<T>* = nullptr>
98inline auto simplex_constrain(const T& y) {
100 y, [](auto&& v) { return simplex_constrain(v); });
101}
102
// Overload of simplex_constrain selected via require_std_vector_t<T>, taking
// a log density accumulator `lp` (return_type_t<T> must be convertible to Lp).
// The visible lambda captures `lp` by reference and forwards it to the
// two-argument simplex_constrain, so every call made through the lambda
// updates the same accumulator.
// NOTE(review): listing line 119 is absent from this extraction, so the head
// of the expression that consumes `y` and the lambda is not visible.
// Presumably it returns the result of applying the lambda to each element
// of `y` -- confirm against the upstream header before relying on this.
116template <typename T, typename Lp, require_std_vector_t<T>* = nullptr,
117 require_convertible_t<return_type_t<T>, Lp>* = nullptr>
118inline auto simplex_constrain(const T& y, Lp& lp) {
120 y, [&lp](auto&& v) { return simplex_constrain(v, lp); });
121}
122
// Compile-time dispatcher on the `Jacobian` template flag:
//   - Jacobian == true : forwards to simplex_constrain(y, lp), the overload
//     that takes the log density accumulator.
//   - Jacobian == false: forwards to simplex_constrain(y); `lp` is not
//     touched by this function in that branch.
// Returns plain_type_t<Vec> in both cases.
// NOTE(review): listing line 142 is absent from this extraction, so the
// remainder of the template parameter list (and its closing '>') is not
// visible here -- confirm against the upstream header.
141template <bool Jacobian, typename Vec, typename Lp,
143inline plain_type_t<Vec> simplex_constrain(const Vec& y, Lp& lp) {
144 if constexpr (Jacobian) {
145 return simplex_constrain(y, lp);
146 } else {
147 return simplex_constrain(y);
148 }
149}
150
151} // namespace math
152} // namespace stan
153
154#endif
require_t< std::is_convertible< std::decay_t< T >, std::decay_t< S > > > require_convertible_t
Require that types T and S satisfy std::is_convertible.
typename value_type< T >::type value_type_t
Helper function for accessing underlying type.
fvar< T > log(const fvar< T > &x)
Definition log.hpp:18
fvar< T > log1p_exp(const fvar< T > &x)
Definition log1p_exp.hpp:14
plain_type_t< Vec > simplex_constrain(const Vec &y)
Return the simplex corresponding to the specified free vector.
fvar< T > inv_logit(const fvar< T > &x)
Returns the inverse logit function applied to the argument.
Definition inv_logit.hpp:20
typename plain_type< T >::type plain_type_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...