Automatic Differentiation
 
Loading...
Searching...
No Matches
hmm_marginal.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_HMM_MARGINAL_LPDF_HPP
2#define STAN_MATH_REV_FUN_HMM_MARGINAL_LPDF_HPP
3
14#include <vector>
15
16namespace stan {
17namespace math {
18
19template <typename T_omega, typename T_Gamma, typename T_rho, typename T_alpha>
20inline auto hmm_marginal_val(
21 const Eigen::Matrix<T_omega, Eigen::Dynamic, Eigen::Dynamic>& omegas,
22 const T_Gamma& Gamma_val, const T_rho& rho_val,
23 Eigen::Matrix<T_alpha, Eigen::Dynamic, Eigen::Dynamic>& alphas,
24 Eigen::Matrix<T_alpha, Eigen::Dynamic, 1>& alpha_log_norms,
25 T_alpha& norm_norm) {
26 const int n_transitions = omegas.cols() - 1;
27 alphas.col(0) = omegas.col(0).cwiseProduct(rho_val);
28
29 const auto norm = alphas.col(0).maxCoeff();
30 alphas.col(0) /= norm;
31 alpha_log_norms(0) = log(norm);
32
33 auto Gamma_val_transpose = Gamma_val.transpose().eval();
34 for (int n = 1; n <= n_transitions; ++n) {
35 alphas.col(n)
36 = omegas.col(n).cwiseProduct(Gamma_val_transpose * alphas.col(n - 1));
37 const auto col_norm = alphas.col(n).maxCoeff();
38 alphas.col(n) /= col_norm;
39 alpha_log_norms(n) = log(col_norm) + alpha_log_norms(n - 1);
40 }
41 norm_norm = alpha_log_norms(n_transitions);
42 return log(alphas.col(n_transitions).sum()) + norm_norm;
43}
44
// hmm_marginal: log marginal density of a hidden Markov model, returned
// through a partials propagator so gradients flow to log_omegas, Gamma and
// rho.
//
// Inputs as used below: log_omegas presumably holds log observation
// densities, one column per observation (its values are exponentiated into
// `omegas` before the forward pass); Gamma and rho are validated together
// with log_omegas by hmm_check. The density itself comes from
// hmm_marginal_val (scaled forward algorithm); the remainder of the function
// runs a scaled backward pass to form the adjoints.
//
// NOTE(review): this rendering of the header drops several source lines --
// the embedded numbering skips 45-72 (presumably the Doxygen comment),
// 74-75 (the remainder of the template parameter list), and 130, 139, 142,
// 147, 156 and 171. The unmatched closing braces below suggest those lines
// open conditional blocks (likely guarding each edge's partials) -- TODO
// confirm against the original hmm_marginal.hpp before editing this code.
73template <typename T_omega, typename T_Gamma, typename T_rho,
76inline auto hmm_marginal(const T_omega& log_omegas, const T_Gamma& Gamma,
 77 const T_rho& rho) {
78 using T_partial_type = partials_return_t<T_omega, T_Gamma, T_rho>;
79 using eig_matrix_partial
80 = Eigen::Matrix<T_partial_type, Eigen::Dynamic, Eigen::Dynamic>;
81 using eig_vector_partial = Eigen::Matrix<T_partial_type, Eigen::Dynamic, 1>;
82 using T_omega_ref = ref_type_if_not_constant_t<T_omega>;
83 using T_Gamma_ref = ref_type_if_not_constant_t<T_Gamma>;
84 using T_rho_ref = ref_type_if_not_constant_t<T_rho>;
85 int n_states = log_omegas.rows();
86 int n_transitions = log_omegas.cols() - 1;
87
88 T_omega_ref log_omegas_ref = log_omegas;
89 T_Gamma_ref Gamma_ref = Gamma;
90 T_rho_ref rho_ref = rho;
91
92 const auto& Gamma_val = to_ref(value_of(Gamma_ref));
93 const auto& rho_val = to_ref(value_of(rho_ref));
94 hmm_check(log_omegas, Gamma_val, rho_val, "hmm_marginal");
95
// Edge indices follow the argument order below:
// 0 = log_omegas, 1 = Gamma, 2 = rho.
96 auto ops_partials
97 = make_partials_propagator(log_omegas_ref, Gamma_ref, rho_ref);
98
99 eig_matrix_partial alphas(n_states, n_transitions + 1);
100 eig_vector_partial alpha_log_norms(n_transitions + 1);
101 // compute the density using the forward algorithm.
102 eig_matrix_partial omegas = value_of(log_omegas_ref).array().exp();
103 T_partial_type norm_norm;
104 auto log_marginal_density = hmm_marginal_val(
105 omegas, Gamma_val, rho_val, alphas, alpha_log_norms, norm_norm);
106
107 // Variables required for all three Jacobian-adjoint products.
108 auto unnormed_marginal = alphas.col(n_transitions).sum();
109
// Backward pass. kappa[n] is rescaled so its largest entry is one and
// kappa_log_norms[n] accumulates the log scale factors (mirroring the forward
// pass). grad_corr[n] = exp(alpha_log_norms[n] + kappa_log_norms[n] -
// norm_norm) restores both scalings relative to the final forward
// normalizer so the adjoint terms below are on the correct scale.
110 std::vector<eig_vector_partial> kappa(n_transitions);
111 eig_vector_partial kappa_log_norms(n_transitions);
112 std::vector<T_partial_type> grad_corr(n_transitions, 0);
113
// Seed the last backward vector with ones (no later observations).
114 if (n_transitions > 0) {
115 kappa[n_transitions - 1] = Eigen::VectorXd::Ones(n_states);
116 kappa_log_norms(n_transitions - 1) = 0;
117 grad_corr[n_transitions - 1]
118 = exp(alpha_log_norms(n_transitions - 1) - norm_norm);
119 }
120
// `n-- > 0` decrements before the body runs, so the first iteration handles
// n = n_transitions - 2; entry n_transitions - 1 was seeded above.
121 for (int n = n_transitions - 1; n-- > 0;) {
122 kappa[n] = Gamma_val * (omegas.col(n + 2).cwiseProduct(kappa[n + 1]));
123
124 auto norm = kappa[n].maxCoeff();
125 kappa[n] /= norm;
126 kappa_log_norms[n] = log(norm) + kappa_log_norms[n + 1];
127 grad_corr[n] = exp(alpha_log_norms[n] + kappa_log_norms[n] - norm_norm);
128 }
129
// Gamma adjoint (edge 1): sum over transitions of the outer product
// alphas.col(n) * (kappa[n] .* omegas.col(n + 1))^T, rescaled by
// grad_corr[n] and divided by the unnormalized marginal.
131 for (int n = n_transitions - 1; n >= 0; --n) {
132 edge<1>(ops_partials).partials_
133 += grad_corr[n] * alphas.col(n)
134 * kappa[n].cwiseProduct(omegas.col(n + 1)).transpose()
135 / unnormed_marginal;
 136 }
 137 }
 138
 140 // Boundary terms
// With a single observation there is no transition: the density reduces to
// sum(omegas .* rho), so both remaining adjoints are formed directly from
// exp(log_marginal_density) and the function returns early.
 141 if (n_transitions == 0) {
 143 edge<0>(ops_partials).partials_
 144 = omegas.cwiseProduct(rho_val) / exp(log_marginal_density);
 145 }
 146
 148 edge<2>(ops_partials).partials_
 149 = omegas.col(0) / exp(log_marginal_density);
 150 }
 151 return ops_partials.build(log_marginal_density);
 152 } else {
// General case. C = Gamma * (omegas.col(1) .* kappa[0]) is the backward
// contribution seen by the first time point; grad_corr_boundary is its
// scaling correction.
 153 auto grad_corr_boundary = exp(kappa_log_norms(0) - norm_norm);
 154 eig_vector_partial C = Gamma_val * omegas.col(1).cwiseProduct(kappa[0]);
 155
// log_omegas adjoint (edge 0): column n + 1 combines kappa[n] with
// Gamma^T * alphas.col(n); column 0 uses the boundary term with rho. The
// final elementwise product with omegas applies the chain rule for
// omegas = exp(log_omegas).
 157 eig_matrix_partial log_omega_jacad
 158 = Eigen::MatrixXd::Zero(n_states, n_transitions + 1);
 159
 160 for (int n = n_transitions - 1; n >= 0; --n) {
 161 log_omega_jacad.col(n + 1)
 162 = grad_corr[n]
 163 * kappa[n].cwiseProduct(Gamma_val.transpose() * alphas.col(n));
 164 }
 165
 166 log_omega_jacad.col(0) = grad_corr_boundary * C.cwiseProduct(rho_val);
 167 edge<0>(ops_partials).partials_
 168 = log_omega_jacad.cwiseProduct(omegas / unnormed_marginal);
 169 }
 170
// rho adjoint (edge 2): boundary term weighted by the first observation.
 172 partials<2>(ops_partials) = grad_corr_boundary
 173 * C.cwiseProduct(omegas.col(0))
 174 / unnormed_marginal;
 175 }
 176 }
 177 }
 178
 179 return ops_partials.build(log_marginal_density);
 180}
181
182} // namespace math
183} // namespace stan
184#endif
require_t< is_eigen_col_vector< std::decay_t< T > > > require_eigen_col_vector_t
Require type satisfies is_eigen_col_vector.
Definition is_vector.hpp:98
require_all_t< is_eigen< std::decay_t< Types > >... > require_all_eigen_t
Require all of the types satisfy is_eigen.
Definition is_eigen.hpp:120
fvar< T > norm(const std::complex< fvar< T > > &z)
Return the squared magnitude of the complex argument.
Definition norm.hpp:20
auto hmm_marginal(const T_omega &log_omegas, const T_Gamma &Gamma, const T_rho &rho)
For a Hidden Markov Model with observation y, hidden state x, and parameters theta,...
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
fvar< T > log(const fvar< T > &x)
Definition log.hpp:18
void hmm_check(const T_omega &log_omegas, const T_Gamma &Gamma, const T_rho &rho, const char *function)
Check arguments for hidden Markov model functions with a discrete latent state (lpdf,...
Definition hmm_check.hpp:30
auto hmm_marginal_val(const Eigen::Matrix< T_omega, Eigen::Dynamic, Eigen::Dynamic > &omegas, const T_Gamma &Gamma_val, const T_rho &rho_val, Eigen::Matrix< T_alpha, Eigen::Dynamic, Eigen::Dynamic > &alphas, Eigen::Matrix< T_alpha, Eigen::Dynamic, 1 > &alpha_log_norms, T_alpha &norm_norm)
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:17
auto make_partials_propagator(Ops &&... ops)
Construct a partials_propagator.
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:15
typename ref_type_if<!is_constant< T >::value, T >::type ref_type_if_not_constant_t
Definition ref_type.hpp:62
typename partials_return_type< Args... >::type partials_return_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Extends std::true_type when instantiated with zero or more template parameters, all of which extend t...