Loading [MathJax]/extensions/TeX/AMSsymbols.js
Automatic Differentiation
 
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
Loading...
Searching...
No Matches
cholesky_corr_constrain.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_CONSTRAINT_CHOLESKY_CORR_CONSTRAIN_HPP
2#define STAN_MATH_PRIM_CONSTRAINT_CHOLESKY_CORR_CONSTRAIN_HPP
3
10#include <cmath>
11
12namespace stan {
13namespace math {
14
15template <typename EigVec, require_eigen_col_vector_t<EigVec>* = nullptr>
16inline Eigen::Matrix<value_type_t<EigVec>, Eigen::Dynamic, Eigen::Dynamic>
17cholesky_corr_constrain(const EigVec& y, int K) {
18 using Eigen::Dynamic;
19 using Eigen::Matrix;
20 using std::sqrt;
21 using T_scalar = value_type_t<EigVec>;
22 int k_choose_2 = (K * (K - 1)) / 2;
23 check_size_match("cholesky_corr_constrain", "constrain size", y.size(),
24 "k_choose_2", k_choose_2);
25 Matrix<T_scalar, Dynamic, 1> z = corr_constrain(y);
26 Matrix<T_scalar, Dynamic, Dynamic> x(K, K);
27 if (K == 0) {
28 return x;
29 }
30 x.setZero();
31 x.coeffRef(0, 0) = 1;
32 int k = 0;
33 for (int i = 1; i < K; ++i) {
34 x.coeffRef(i, 0) = z.coeff(k++);
35 T_scalar sum_sqs = square(x.coeff(i, 0));
36 for (int j = 1; j < i; ++j) {
37 x.coeffRef(i, j) = z.coeff(k++) * sqrt(1.0 - sum_sqs);
38 sum_sqs += square(x.coeff(i, j));
39 }
40 x.coeffRef(i, i) = sqrt(1.0 - sum_sqs);
41 }
42 return x;
43}
44
45// FIXME to match above after debugged
// Same transform as the two-argument overload above, but additionally
// accumulates log-Jacobian terms of the transform into lp.
// NOTE(review): listing lines 47-48 (the remainder of the template parameter
// list, presumably its SFINAE constraints) are not shown in this rendering.
46template <typename EigVec, typename Lp,
49inline Eigen::Matrix<value_type_t<EigVec>, Eigen::Dynamic, Eigen::Dynamic>
50cholesky_corr_constrain(const EigVec& y, int K, Lp& lp) {
51 using Eigen::Dynamic;
52 using Eigen::Matrix;
53 using std::sqrt;
54 using T_scalar = value_type_t<EigVec>;
55 int k_choose_2 = (K * (K - 1)) / 2;
// y must hold exactly K * (K - 1) / 2 unconstrained values.
56 check_size_match("cholesky_corr_constrain", "y.size()", y.size(),
57 "k_choose_2", k_choose_2);
// NOTE(review): presumably corr_constrain(y, lp) maps each element into
// (-1, 1) and adds its own log-Jacobian to lp -- confirm against its
// definition.
58 Matrix<T_scalar, Dynamic, 1> z = corr_constrain(y, lp);
59 Matrix<T_scalar, Dynamic, Dynamic> x(K, K);
60 if (K == 0) {
61 return x;
62 }
// Lower-triangular result: entries above the diagonal stay zero.
63 x.setZero();
64 x.coeffRef(0, 0) = 1;
// k walks z in row-major order over the strict lower triangle.
65 int k = 0;
66 for (int i = 1; i < K; ++i) {
67 x.coeffRef(i, 0) = z.coeff(k++);
// sum_sqs: squared length of row i placed so far.
68 T_scalar sum_sqs = square(x.coeff(i, 0));
69 for (int j = 1; j < i; ++j) {
// x(i, j) = z(k) * sqrt(1 - sum_sqs), so dx(i, j)/dz(k) = sqrt(1 - sum_sqs).
// Its log is 0.5 * log(1 - sum_sqs), computed here with log1m. Column 0
// contributes no term because x(i, 0) = z(k) directly.
70 lp += 0.5 * log1m(sum_sqs);
71 x.coeffRef(i, j) = z.coeff(k++) * sqrt(1.0 - sum_sqs);
72 sum_sqs += square(x.coeff(i, j));
73 }
// Diagonal takes the remaining length so row i has unit squared norm.
74 x.coeffRef(i, i) = sqrt(1.0 - sum_sqs);
75 }
76 return x;
77}
78
// Overload for a std::vector of unconstrained vectors: the lambda below calls
// the single-vector overload cholesky_corr_constrain(v, K) on one element v.
// NOTE(review): listing line 91 (the call that receives y and the lambda) is
// not shown in this rendering; presumably it applies the lambda to each
// element of y and collects the results -- confirm.
89template <typename T, require_std_vector_t<T>* = nullptr>
90inline auto cholesky_corr_constrain(const T& y, int K) {
92 y, [K](auto&& v) { return cholesky_corr_constrain(v, K); });
93}
94
// Overload for a std::vector of unconstrained vectors that also accumulates
// into lp. The lambda captures lp by reference, so every call to
// cholesky_corr_constrain(v, K, lp) adds to the same accumulator.
// NOTE(review): listing line 111 (the call that receives y and the lambda) is
// not shown in this rendering; presumably it applies the lambda to each
// element of y -- confirm.
108template <typename T, typename Lp, require_std_vector_t<T>* = nullptr,
109 require_convertible_t<return_type_t<T>, Lp>* = nullptr>
110inline auto cholesky_corr_constrain(const T& y, int K, Lp& lp) {
112 y, [&lp, K](auto&& v) { return cholesky_corr_constrain(v, K, lp); });
113}
114
// Compile-time dispatch on Jacobian:
//  - when true, forward to the overload that accumulates into lp;
//  - otherwise call the two-argument overload, leaving lp unmodified.
// NOTE(review): listing line 134 (the rest of the template parameter list) is
// not shown in this rendering.
133template <bool Jacobian, typename T, typename Lp,
135inline auto cholesky_corr_constrain(const T& y, int K, Lp& lp) {
136 if constexpr (Jacobian) {
137 return cholesky_corr_constrain(y, K, lp);
138 } else {
139 return cholesky_corr_constrain(y, K);
140 }
141}
142
143} // namespace math
144} // namespace stan
145#endif
require_t< std::is_convertible< std::decay_t< T >, std::decay_t< S > > > require_convertible_t
Require that std::decay_t<T> is convertible to std::decay_t<S> (std::is_convertible).
require_t< is_eigen_vector< std::decay_t< T > > > require_eigen_vector_t
Require that the type satisfies is_eigen_vector.
typename value_type< T >::type value_type_t
Helper function for accessing underlying type.
fvar< T > sqrt(const fvar< T > &x)
Definition sqrt.hpp:18
Eigen::Matrix< value_type_t< EigVec >, Eigen::Dynamic, Eigen::Dynamic > cholesky_corr_constrain(const EigVec &y, int K)
plain_type_t< T > corr_constrain(const T &x)
Return the result of transforming the specified scalar or container of values to have a valid correlation value.
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
fvar< T > log1m(const fvar< T > &x)
Definition log1m.hpp:12
fvar< T > square(const fvar< T > &x)
Definition square.hpp:12
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...