Automatic Differentiation
 
Loading...
Searching...
No Matches
matrix_normal_prec_rng.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_PROB_MATRIX_NORMAL_PREC_RNG_HPP
2#define STAN_MATH_PRIM_PROB_MATRIX_NORMAL_PREC_RNG_HPP
3
7#include <boost/random/normal_distribution.hpp>
8#include <boost/random/variate_generator.hpp>
9
10namespace stan {
11namespace math {
12
31template <class RNG>
32inline Eigen::MatrixXd matrix_normal_prec_rng(const Eigen::MatrixXd &Mu,
33 const Eigen::MatrixXd &Sigma,
34 const Eigen::MatrixXd &D,
35 RNG &rng) {
36 using boost::normal_distribution;
37 using boost::variate_generator;
38 static constexpr const char *function = "matrix_normal_prec_rng";
39 check_positive(function, "Sigma rows", Sigma.rows());
40 check_finite(function, "Sigma", Sigma);
41 check_symmetric(function, "Sigma", Sigma);
42 check_positive(function, "D rows", D.rows());
43 check_finite(function, "D", D);
44 check_symmetric(function, "D", D);
45 check_size_match(function, "Rows of location parameter", Mu.rows(),
46 "Rows of Sigma", Sigma.rows());
47 check_size_match(function, "Columns of location parameter", Mu.cols(),
48 "Rows of D", D.rows());
49 check_finite(function, "Location parameter", Mu);
50
51 Eigen::LDLT<Eigen::MatrixXd> Sigma_ldlt(Sigma);
52 // Sigma = PS^T LS DS LS^T PS
53 // PS a permutation matrix.
54 // LS lower triangular with unit diagonal.
55 // DS diagonal.
56 Eigen::LDLT<Eigen::MatrixXd> D_ldlt(D);
57 // D = PD^T LD DD LD^T PD
58
59 check_pos_semidefinite(function, "Sigma", Sigma_ldlt);
60 check_pos_semidefinite(function, "D", D_ldlt);
61
62 // If
63 // C ~ N[0, I, I]
64 // Then
65 // A C B ~ N[0, A A^T, B^T B]
66 // So to get
67 // Y - Mu ~ N[0, Sigma^(-1), D^(-1)]
68 // We need to do
69 // Y - Mu = Q^T^(-1) C R^(-1)
70 // Where Q^T^(-1) and R^(-1) are such that
71 // Q^(-1) Q^(-1)^T = Sigma^(-1)
72 // R^(-1)^T R^(-1) = D^(-1)
73 // We choose:
74 // Q^(-1)^T = PS^T LS^T^(-1) sqrt[DS]^(-1)
75 // R^(-1) = sqrt[DD]^(-1) LD^(-1) PD
76 // And therefore
77 // Y - Mu = (PS^T LS^T^(-1) sqrt[DS]^(-1)) C (sqrt[DD]^(-1) LD^(-1) PD)
78
79 int m = Sigma.rows();
80 int n = D.rows();
81
82 variate_generator<RNG &, normal_distribution<>> std_normal_rng(
83 rng, normal_distribution<>(0, 1));
84
85 // X = sqrt[DS]^(-1) C sqrt[DD]^(-1)
86 // X ~ N[0, DS, DD]
87 Eigen::MatrixXd X(m, n);
88 Eigen::VectorXd row_stddev
89 = Sigma_ldlt.vectorD().array().inverse().sqrt().matrix();
90 Eigen::VectorXd col_stddev
91 = D_ldlt.vectorD().array().inverse().sqrt().matrix();
92 for (int col = 0; col < n; ++col) {
93 for (int row = 0; row < m; ++row) {
94 double stddev = row_stddev(row) * col_stddev(col);
95 // C(row, col) = std_normal_rng();
96 X(row, col) = stddev * std_normal_rng();
97 }
98 }
99
100 // Y - Mu = PS^T (LS^T^(-1) X LD^(-1)) PD
101 // Y' = LS^T^(-1) X LD^(-1)
102 // Y' = LS^T.solve(X) LD^(-1)
103 // Y' = (LD^(-1)^T (LS^T.solve(X))^T)^T
104 // Y' = (LD^T.solve((LS^T.solve(X))^T))^T
105 // Y = Mu + PS^T Y' PD
106 Eigen::MatrixXd Y = Mu
107 + (Sigma_ldlt.transpositionsP().transpose()
108 * (D_ldlt.matrixU().solve(
109 (Sigma_ldlt.matrixU().solve(X)).transpose()))
110 .transpose()
111 * D_ldlt.transpositionsP());
112
113 return Y;
114}
115} // namespace math
116} // namespace stan
117#endif
void check_symmetric(const char *function, const char *name, const matrix_cl< T > &y)
Check if the matrix_cl is symmetric.
Eigen::MatrixXd matrix_normal_prec_rng(const Eigen::MatrixXd &Mu, const Eigen::MatrixXd &Sigma, const Eigen::MatrixXd &D, RNG &rng)
Sample from the matrix normal distribution for the given Mu, Sigma and D where Sigma and D are gi...
auto transpose(Arg &&a)
Transposes a kernel generator expression.
auto col(T_x &&x, size_t j)
Return the specified column of the specified kernel generator expression using start-at-1 indexing.
Definition col.hpp:23
auto row(T_x &&x, size_t j)
Return the specified row of the specified kernel generator expression using start-at-1 indexing.
Definition row.hpp:23
double std_normal_rng(RNG &rng)
Return a standard Normal random variate using the specified random number generator.
void check_pos_semidefinite(const char *function, const char *name, const EigMat &y)
Check if the specified matrix is positive semidefinite.
void check_finite(const char *function, const char *name, const T_y &y)
Check if all values in y are finite.
void check_positive(const char *function, const char *name, const T_y &y)
Check if y is positive.
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9