Automatic Differentiation
 
Loading...
Searching...
No Matches
gaussian_dlm_obs_lpdf.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_PROB_GAUSSIAN_DLM_OBS_LPDF_HPP
2#define STAN_MATH_PRIM_PROB_GAUSSIAN_DLM_OBS_LPDF_HPP
3
21#include <cmath>
22
23/*
24 TODO: time-varying system matrices
25 TODO: use sequential processing even for non-diagonal obs
26 covariance.
27 TODO: add constant terms in observation.
28*/
29namespace stan {
30namespace math {
65template <bool propto, typename T_y, typename T_F, typename T_G, typename T_V,
66 typename T_W, typename T_m0, typename T_C0,
67 require_all_eigen_matrix_dynamic_t<T_y, T_F, T_G, T_V, T_W,
68 T_C0>* = nullptr,
69 require_eigen_col_vector_t<T_m0>* = nullptr>
71 const T_y& y, const T_F& F, const T_G& G, const T_V& V, const T_W& W,
72 const T_m0& m0, const T_C0& C0) {
74 using std::pow;
75 static constexpr const char* function = "gaussian_dlm_obs_lpdf";
76 check_size_match(function, "columns of F", F.cols(), "rows of y", y.rows());
77 check_size_match(function, "rows of F", F.rows(), "rows of G", G.rows());
78 check_size_match(function, "rows of V", V.rows(), "rows of y", y.rows());
79 check_size_match(function, "rows of W", W.rows(), "rows of G", G.rows());
80 check_size_match(function, "size of m0", m0.size(), "rows of G", G.rows());
81 check_size_match(function, "rows of C0", C0.rows(), "rows of G", G.rows());
82 check_square(function, "G", G);
83
84 const auto& y_ref = to_ref(y);
85 const auto& F_ref = to_ref(F);
86 const auto& G_ref = to_ref(G);
87 const auto& V_ref = to_ref(V);
88 const auto& W_ref = to_ref(W);
89 const auto& m0_ref = to_ref(m0);
90 const auto& C0_ref = to_ref(C0);
91
92 check_finite(function, "y", y_ref);
93 check_finite(function, "F", F_ref);
94 check_finite(function, "G", G_ref);
95 // TODO(anyone): incorporate support for infinite V
96 check_finite(function, "V", V_ref);
97 check_pos_semidefinite(function, "V", V_ref);
98 // TODO(anyone): incorporate support for infinite W
99 check_finite(function, "W", W_ref);
100 check_pos_semidefinite(function, "W", W_ref);
101 check_finite(function, "m0", m0_ref);
102 check_pos_semidefinite(function, "C0", C0_ref);
103 check_finite(function, "C0", C0_ref);
104
105 if (size_zero(y)) {
106 return 0;
107 }
108
109 int r = y.rows(); // number of variables
110 int n = G.rows(); // number of states
111
112 T_lp lp(0);
114 lp -= HALF_LOG_TWO_PI * r * y.cols();
115 }
116
118 Eigen::Matrix<T_lp, Eigen::Dynamic, 1> m{m0_ref};
119 Eigen::Matrix<T_lp, Eigen::Dynamic, Eigen::Dynamic> C{C0_ref};
120 Eigen::Matrix<T_lp, Eigen::Dynamic, 1> a(n);
121 Eigen::Matrix<T_lp, Eigen::Dynamic, Eigen::Dynamic> R(n, n);
122 Eigen::Matrix<T_lp, Eigen::Dynamic, 1> f(r);
123 Eigen::Matrix<T_lp, Eigen::Dynamic, Eigen::Dynamic> Q(r, r);
124 Eigen::Matrix<T_lp, Eigen::Dynamic, Eigen::Dynamic> Q_inv(r, r);
125 Eigen::Matrix<T_lp, Eigen::Dynamic, 1> e(r);
126 Eigen::Matrix<T_lp, Eigen::Dynamic, Eigen::Dynamic> A(n, r);
127
128 for (int i = 0; i < y.cols(); i++) {
129 // // Predict state
130 // a_t = G_t m_{t-1}
131 a = multiply(G_ref, m);
132 // R_t = G_t C_{t-1} G_t' + W_t
133 R = quad_form_sym(C, transpose(G_ref)) + W_ref;
134 // // predict observation
135 // f_t = F_t' a_t
136 f = multiply(transpose(F_ref), a);
137 // Q_t = F'_t R_t F_t + V_t
138 Q = quad_form_sym(R, F_ref) + V_ref;
139 Q_inv = inverse_spd(Q);
140 // // filtered state
141 // e_t = y_t - f_t
142 e = y_ref.col(i) - f;
143 // A_t = R_t F_t Q^{-1}_t
144 A = multiply(multiply(R, F_ref), Q_inv);
145 // m_t = a_t + A_t e_t
146 m = a + multiply(A, e);
147 // C = R_t - A_t Q_t A_t'
148 C = R - quad_form_sym(Q, transpose(A));
149 lp -= 0.5 * (log_determinant_spd(Q) + trace_quad_form(Q_inv, e));
150 }
151 }
152 return lp;
153}
154
190template <
191 bool propto, typename T_y, typename T_F, typename T_G, typename T_V,
192 typename T_W, typename T_m0, typename T_C0,
196 const T_y& y, const T_F& F, const T_G& G, const T_V& V, const T_W& W,
197 const T_m0& m0, const T_C0& C0) {
199 using std::log;
200 static constexpr const char* function = "gaussian_dlm_obs_lpdf";
201 check_size_match(function, "columns of F", F.cols(), "rows of y", y.rows());
202 check_size_match(function, "rows of F", F.rows(), "rows of G", G.rows());
203 check_size_match(function, "rows of G", G.rows(), "columns of G", G.cols());
204 check_size_match(function, "size of V", V.size(), "rows of y", y.rows());
205 check_size_match(function, "rows of W", W.rows(), "rows of G", G.rows());
206 check_size_match(function, "size of m0", m0.size(), "rows of G", G.rows());
207 check_size_match(function, "rows of C0", C0.rows(), "rows of G", G.rows());
208
209 const auto& y_ref = to_ref(y);
210 const auto& F_ref = to_ref(F);
211 const auto& G_ref = to_ref(G);
212 const auto& V_ref = to_ref(V);
213 const auto& W_ref = to_ref(W);
214 const auto& m0_ref = to_ref(m0);
215 const auto& C0_ref = to_ref(C0);
216
217 check_finite(function, "y", y_ref);
218 check_finite(function, "F", F_ref);
219 check_finite(function, "G", G_ref);
220 check_nonnegative(function, "V", V_ref);
221 // TODO(anyone): support infinite V
222 check_finite(function, "V", V_ref);
223 check_pos_semidefinite(function, "W", W_ref);
224 // TODO(anyone): support infinite W
225 check_finite(function, "W", W_ref);
226 check_finite(function, "m0", m0_ref);
227 check_pos_semidefinite(function, "C0", C0_ref);
228 check_finite(function, "C0", C0_ref);
229
230 if (y.cols() == 0 || y.rows() == 0) {
231 return 0;
232 }
233
234 int r = y.rows(); // number of variables
235 int n = G.rows(); // number of states
236
237 T_lp lp(0);
238 if (include_summand<propto>::value) {
239 lp -= HALF_LOG_TWO_PI * r * y.cols();
240 }
241
242 if (include_summand<propto, T_y, T_F, T_G, T_V, T_W, T_m0, T_C0>::value) {
243 T_lp f;
244 T_lp Q;
245 T_lp Q_inv;
246 T_lp e;
247 Eigen::Matrix<T_lp, Eigen::Dynamic, 1> A(n);
248 Eigen::Matrix<T_lp, Eigen::Dynamic, 1> Fj(n);
249 Eigen::Matrix<T_lp, Eigen::Dynamic, 1> m{m0_ref};
250 Eigen::Matrix<T_lp, Eigen::Dynamic, Eigen::Dynamic> C{C0_ref};
251
252 for (int i = 0; i < y.cols(); i++) {
253 // Predict state
254 // reuse m and C instead of using a and R
255 m = multiply(G_ref, m);
256 C = quad_form_sym(C, transpose(G_ref)) + W_ref;
257 for (int j = 0; j < y.rows(); ++j) {
258 // predict observation
259 // dim Fj = (n, 1)
260 const auto& Fj = F_ref.col(j);
261 // f_{t, i} = F_{t, i}' m_{t, i-1}
262 f = dot_product(Fj, m);
263 Q = trace_quad_form(C, Fj) + V_ref.coeff(j);
264 if (i == 0)
265 check_positive(function, "Q0", Q);
266 Q_inv = 1.0 / Q;
267 // filtered observation
268 // e_{t, i} = y_{t, i} - f_{t, i}
269 e = y_ref.coeff(j, i) - f;
270 // A_{t, i} = C_{t, i-1} F_{t, i} Q_{t, i}^{-1}
271 A = multiply(multiply(C, Fj), Q_inv);
272 // m_{t, i} = m_{t, i-1} + A_{t, i} e_{t, i}
273 m += multiply(A, e);
274 // c_{t, i} = C_{t, i-1} - Q_{t, i} A_{t, i} A_{t, i}'
275 // tcrossprod throws an error (ambiguous)
276 // C = subtract(C, multiply(Q, tcrossprod(A)));
277 C -= multiply(Q, multiply(A, transpose(A)));
278 C = 0.5 * (C + transpose(C));
279 lp -= 0.5 * (log(Q) + square(e) * Q_inv);
280 }
281 }
282 }
283 return lp;
284}
285
286template <typename T_y, typename T_F, typename T_G, typename T_V, typename T_W,
287 typename T_m0, typename T_C0>
289 const T_y& y, const T_F& F, const T_G& G, const T_V& V, const T_W& W,
290 const T_m0& m0, const T_C0& C0) {
291 return gaussian_dlm_obs_lpdf<false>(y, F, G, V, W, m0, C0);
292}
293
294} // namespace math
295} // namespace stan
296#endif
require_all_t< is_eigen_col_vector< std::decay_t< Types > >... > require_all_eigen_col_vector_t
Require all of the types satisfy is_eigen_col_vector.
require_all_t< is_eigen_matrix_dynamic< std::decay_t< Types > >... > require_all_eigen_matrix_dynamic_t
Require all of the types satisfy is_eigen_matrix_dynamic.
return_type_t< T_y, T_F, T_G, T_V, T_W, T_m0, T_C0 > gaussian_dlm_obs_lpdf(const T_y &y, const T_F &F, const T_G &G, const T_V &V, const T_W &W, const T_m0 &m0, const T_C0 &C0)
The log of a Gaussian dynamic linear model (GDLM).
auto transpose(Arg &&a)
Transposes a kernel generator expression.
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
void check_square(const char *function, const char *name, const T_y &y)
Check if the specified matrix is square.
void check_nonnegative(const char *function, const char *name, const T_y &y)
Check if y is non-negative.
bool size_zero(const T &x)
Returns true if the input has size 0, false otherwise.
Definition size_zero.hpp:19
void check_pos_semidefinite(const char *function, const char *name, const EigMat &y)
Check if the specified matrix is positive semidefinite.
static constexpr double e()
Return the base of the natural logarithm.
Definition constants.hpp:20
value_type_t< EigMat > log_determinant_spd(const EigMat &m)
Returns the log absolute determinant of the specified square matrix.
fvar< T > log(const fvar< T > &x)
Definition log.hpp:15
auto multiply(const Mat1 &m1, const Mat2 &m2)
Return the product of the specified matrices.
Definition multiply.hpp:18
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:17
void check_finite(const char *function, const char *name, const T_y &y)
Check if all values in y are finite.
promote_scalar_t< return_type_t< EigMat1, EigMat2 >, EigMat2 > quad_form_sym(const EigMat1 &A, const EigMat2 &B)
Return the quadratic form of a symmetric matrix.
void check_positive(const char *function, const char *name, const T_y &y)
Check if y is positive.
Eigen::Matrix< value_type_t< EigMat >, Eigen::Dynamic, Eigen::Dynamic > inverse_spd(const EigMat &m)
Returns the inverse of the specified symmetric, pos/neg-definite matrix.
return_type_t< EigMat1, EigMat2 > trace_quad_form(const EigMat1 &A, const EigMat2 &B)
static constexpr double HALF_LOG_TWO_PI
The value of half the natural logarithm of $2\pi$, $\frac{1}{2}\log(2\pi)$.
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
auto dot_product(const T_a &a, const T_b &b)
Returns the dot product of the specified vectors.
fvar< T > square(const fvar< T > &x)
Definition square.hpp:12
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
Template metaprogram to calculate whether a summand needs to be included in a proportional (log) probability calculation.