1#ifndef STAN_MATH_REV_FUN_HMM_MARGINAL_LPDF_HPP
2#define STAN_MATH_REV_FUN_HMM_MARGINAL_LPDF_HPP
// Forward pass of the hidden Markov model marginal computation.
//
// Fills `alphas` one column at a time, rescales every column by its largest
// coefficient, accumulates the logs of those scale factors in
// `alpha_log_norms`, writes the last accumulated value to `norm_norm`, and
// returns log(sum of the last column of alphas) + norm_norm.
//
// NOTE(review): this excerpt is missing several original lines (the
// function-name line, the `norm_norm` parameter line, the closing braces and
// at least one statement). The visible parameters match the hmm_marginal_val
// signature listed in the reference text further down this file -- confirm
// against the complete header before relying on these notes.
19template <
typename T_omega,
typename T_Gamma,
typename T_rho,
typename T_alpha>
// omegas: read column by column below; its column count fixes the number of
// transitions.
21 const Eigen::Matrix<T_omega, Eigen::Dynamic, Eigen::Dynamic>& omegas,
// Gamma_val: only its transpose is used below. rho_val: multiplied
// element-wise into the first column of omegas.
22 const T_Gamma& Gamma_val,
const T_rho& rho_val,
// alphas and alpha_log_norms are non-const references written by this
// function (output storage supplied by the caller).
23 Eigen::Matrix<T_alpha, Eigen::Dynamic, Eigen::Dynamic>& alphas,
24 Eigen::Matrix<T_alpha, Eigen::Dynamic, 1>& alpha_log_norms,
// The number of transitions is one less than the number of columns of omegas.
26 const int n_transitions = omegas.cols() - 1;
// First column: element-wise product of the first column of omegas and rho_val.
27 alphas.col(0) = omegas.col(0).cwiseProduct(rho_val);
// Rescale the first column by its largest coefficient.
29 const auto norm = alphas.col(0).maxCoeff();
30 alphas.col(0) /=
norm;
// Evaluate the transpose of Gamma_val once, ahead of the loop that reuses it.
33 auto Gamma_val_transpose = Gamma_val.transpose().eval();
// Visit the remaining columns, n = 1 .. n_transitions inclusive.
34 for (
int n = 1; n <= n_transitions; ++n) {
// NOTE(review): the left-hand side of this assignment is not visible in this
// excerpt; presumably alphas.col(n), which is what the next line reads --
// confirm.
36 = omegas.col(n).cwiseProduct(Gamma_val_transpose * alphas.col(n - 1));
// Rescale column n by its largest coefficient.
37 const auto col_norm = alphas.col(n).maxCoeff();
38 alphas.col(n) /= col_norm;
// Running sum of the log scale factors.
// NOTE(review): at n == 1 this reads alpha_log_norms(0), whose assignment is
// not visible in this excerpt -- confirm it is set before the loop.
39 alpha_log_norms(n) =
log(col_norm) + alpha_log_norms(n - 1);
// Report the final accumulated log scale factor through norm_norm and add it
// back to the log of the sum of the rescaled last column.
41 norm_norm = alpha_log_norms(n_transitions);
42 return log(alphas.col(n_transitions).sum()) + norm_norm;
// Computes the value returned through ops_partials.build(log_marginal_density)
// and fills the partials stored at edge 0, edge 1 and edge 2 of ops_partials.
//
// NOTE(review): this excerpt is missing several original lines, including the
// `rho` parameter line, the definitions of T_partial_type, T_omega_ref,
// T_Gamma_ref, T_rho_ref, Gamma_val, rho_val, ops_partials and
// log_marginal_density, any guards around the partials updates, and every
// closing brace. From the argument order (log_omegas, Gamma, rho), edges 0, 1
// and 2 presumably correspond to log_omegas, Gamma and rho -- confirm against
// the complete header.
73template <
typename T_omega,
typename T_Gamma,
typename T_rho,
76inline auto hmm_marginal(
const T_omega& log_omegas,
const T_Gamma& Gamma,
// Matrix and column-vector aliases over T_partial_type, used for all the
// intermediate storage below.
79 using eig_matrix_partial
80 = Eigen::Matrix<T_partial_type, Eigen::Dynamic, Eigen::Dynamic>;
81 using eig_vector_partial = Eigen::Matrix<T_partial_type, Eigen::Dynamic, 1>;
// Dimensions come from log_omegas: its row count gives n_states, and its
// column count minus one gives n_transitions.
85 int n_states = log_omegas.rows();
86 int n_transitions = log_omegas.cols() - 1;
// Bind each argument to a local of the corresponding *_ref type.
88 T_omega_ref log_omegas_ref = log_omegas;
89 T_Gamma_ref Gamma_ref = Gamma;
90 T_rho_ref rho_ref = rho;
// Argument checks; "hmm_marginal" is the function name handed to hmm_check.
94 hmm_check(log_omegas, Gamma_val, rho_val,
"hmm_marginal");
// Storage for the forward pass: one column of alphas and one log norm per
// column of log_omegas.
99 eig_matrix_partial alphas(n_states, n_transitions + 1);
100 eig_vector_partial alpha_log_norms(n_transitions + 1);
// Element-wise exponential of the values of log_omegas.
102 eig_matrix_partial omegas =
value_of(log_omegas_ref).array().exp();
103 T_partial_type norm_norm;
// NOTE(review): only the argument list of this call is visible; presumably it
// is the hmm_marginal_val call whose result is the log_marginal_density used
// below -- confirm.
105 omegas, Gamma_val, rho_val, alphas, alpha_log_norms, norm_norm);
// Sum of the last column of alphas, used as a divisor further down.
108 auto unnormed_marginal = alphas.col(n_transitions).sum();
// Per-transition storage for the backward recursion: kappa vectors, their
// accumulated log norms, and the correction factors grad_corr.
110 std::vector<eig_vector_partial> kappa(n_transitions);
111 eig_vector_partial kappa_log_norms(n_transitions);
112 std::vector<T_partial_type> grad_corr(n_transitions, 0);
// Initialise the last entry only when there is at least one transition,
// because index n_transitions - 1 would otherwise be -1.
114 if (n_transitions > 0) {
115 kappa[n_transitions - 1] = Eigen::VectorXd::Ones(n_states);
116 kappa_log_norms(n_transitions - 1) = 0;
117 grad_corr[n_transitions - 1]
118 =
exp(alpha_log_norms(n_transitions - 1) - norm_norm);
// Backward recursion. The condition `n-- > 0` tests n and then decrements it
// before the body runs, so the body sees n = n_transitions - 2 down to 0.
121 for (
int n = n_transitions - 1; n-- > 0;) {
122 kappa[n] = Gamma_val * (omegas.col(n + 2).cwiseProduct(kappa[n + 1]));
// Largest coefficient of kappa[n]; its log is added to the running log norm.
// NOTE(review): the statement that rescales kappa[n] by `norm` is not visible
// in this excerpt -- confirm.
124 auto norm = kappa[n].maxCoeff();
126 kappa_log_norms[n] =
log(
norm) + kappa_log_norms[n + 1];
127 grad_corr[n] =
exp(alpha_log_norms[n] + kappa_log_norms[n] - norm_norm);
// Accumulate into the edge<1> partials, one term per transition: grad_corr[n]
// times the product of alphas.col(n) with the transposed element-wise product
// of kappa[n] and omegas.col(n + 1).
// NOTE(review): the end of this statement is not visible in this excerpt.
131 for (
int n = n_transitions - 1; n >= 0; --n) {
132 edge<1>(ops_partials).partials_
133 += grad_corr[n] * alphas.col(n)
134 * kappa[n].cwiseProduct(omegas.col(n + 1)).transpose()
// Special case with no transitions: the edge<0> and edge<2> partials are
// written directly and the function returns early.
141 if (n_transitions == 0) {
143 edge<0>(ops_partials).partials_
144 = omegas.cwiseProduct(rho_val) /
exp(log_marginal_density);
148 edge<2>(ops_partials).partials_
149 = omegas.col(0) /
exp(log_marginal_density);
151 return ops_partials.build(log_marginal_density);
// Boundary terms, reached only when n_transitions > 0: a correction factor
// for the first column, and C = Gamma_val * (omegas.col(1) element-wise
// kappa[0]).
153 auto grad_corr_boundary =
exp(kappa_log_norms(0) - norm_norm);
154 eig_vector_partial C = Gamma_val * omegas.col(1).cwiseProduct(kappa[0]);
// Zero-initialised matrix with one column per column of log_omegas.
157 eig_matrix_partial log_omega_jacad
158 = Eigen::MatrixXd::Zero(n_states, n_transitions + 1);
// Fill columns 1 .. n_transitions.
// NOTE(review): the middle of this assignment is not visible in this excerpt.
160 for (
int n = n_transitions - 1; n >= 0; --n) {
161 log_omega_jacad.col(n + 1)
163 * kappa[n].cwiseProduct(Gamma_val.transpose() * alphas.col(n));
// Column 0 uses the boundary terms. The edge<0> partials are then
// log_omega_jacad multiplied element-wise by omegas / unnormed_marginal.
166 log_omega_jacad.col(0) = grad_corr_boundary * C.cwiseProduct(rho_val);
167 edge<0>(ops_partials).partials_
168 = log_omega_jacad.cwiseProduct(omegas / unnormed_marginal);
// Partials at index 2, built from the boundary terms and the first column of
// omegas.
// NOTE(review): the end of this statement is not visible in this excerpt.
172 partials<2>(ops_partials) = grad_corr_boundary
173 * C.cwiseProduct(omegas.col(0))
// Hand log_marginal_density to ops_partials.build and return the result.
179 return ops_partials.build(log_marginal_density);
require_t< is_eigen_col_vector< std::decay_t< T > > > require_eigen_col_vector_t
Require that the type satisfies is_eigen_col_vector.
require_all_t< is_eigen< std::decay_t< Types > >... > require_all_eigen_t
Require that all of the types satisfy is_eigen.
fvar< T > norm(const std::complex< fvar< T > > &z)
Return the squared magnitude of the complex argument.
auto hmm_marginal(const T_omega &log_omegas, const T_Gamma &Gamma, const T_rho &rho)
For a Hidden Markov Model with observation y, hidden state x, and parameters theta,...
T value_of(const fvar< T > &v)
Return the value of the specified variable.
fvar< T > log(const fvar< T > &x)
void hmm_check(const T_omega &log_omegas, const T_Gamma &Gamma, const T_rho &rho, const char *function)
Check arguments for hidden Markov model functions with a discrete latent state (lpdf,...
auto hmm_marginal_val(const Eigen::Matrix< T_omega, Eigen::Dynamic, Eigen::Dynamic > &omegas, const T_Gamma &Gamma_val, const T_rho &rho_val, Eigen::Matrix< T_alpha, Eigen::Dynamic, Eigen::Dynamic > &alphas, Eigen::Matrix< T_alpha, Eigen::Dynamic, 1 > &alpha_log_norms, T_alpha &norm_norm)
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
auto make_partials_propagator(Ops &&... ops)
Construct a partials_propagator.
fvar< T > exp(const fvar< T > &x)
typename ref_type_if<!is_constant< T >::value, T >::type ref_type_if_not_constant_t
typename partials_return_type< Args... >::type partials_return_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Extends std::true_type when instantiated with zero or more template parameters, all of which extend t...