Automatic Differentiation
 
Loading...
Searching...
No Matches
laplace_marginal_density.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_MIX_FUNCTOR_LAPLACE_MARGINAL_DENSITY_HPP
2#define STAN_MATH_MIX_FUNCTOR_LAPLACE_MARGINAL_DENSITY_HPP
11#include <unsupported/Eigen/MatrixFunctions>
12#include <cmath>
13
21namespace stan {
22namespace math {
23
// Overload of laplace_marginal_density used when
// `is_all_arithmetic_scalar<CovarArgs, LLTupleArgs>` holds, i.e. when the
// scalars of both argument packs are arithmetic. No gradient bookkeeping is
// done here. The covariance matrix is built eagerly as an `Eigen::MatrixXd`.
// The result is the `.lmd` member of the estimator's return value.
//
// NOTE(review): the line that opens the estimator call (original line 65)
// is missing from this listing. Lines 66-68 show that it receives the
// forwarded `ll_fun`, the forwarded `ll_args`, the moved `covariance`,
// `options` and `msgs`.
//
// @param ll_fun              likelihood functor, forwarded unchanged
// @param ll_args             tuple of likelihood arguments, forwarded unchanged
// @param covariance_function functor invoked as
//                            `covariance_function(covar_args..., msgs)`
// @param covar_args          tuple unpacked into `covariance_function`
// @param options             laplace_options forwarded to the estimator
// @param msgs                stream appended to the covariance functor call
//                            and passed on to the estimator
// @return the `lmd` member of the estimator result
51template <
52 typename LLFun, typename LLTupleArgs, typename CovarFun, typename CovarArgs,
53 bool InitTheta,
54 require_t<is_all_arithmetic_scalar<CovarArgs, LLTupleArgs>>* = nullptr>
55inline auto laplace_marginal_density(LLFun&& ll_fun, LLTupleArgs&& ll_args,
56 CovarFun&& covariance_function,
57 CovarArgs&& covar_args,
58 const laplace_options<InitTheta>& options,
59 std::ostream* msgs) {
 // Unpack covar_args, append msgs, and materialise the covariance matrix.
60 Eigen::MatrixXd covariance = stan::math::apply(
61 [msgs, &covariance_function](auto&&... args) {
62 return covariance_function(std::forward<decltype(args)>(args)..., msgs);
63 },
64 std::forward<CovarArgs>(covar_args));
66 std::forward<LLFun>(ll_fun), std::forward<LLTupleArgs>(ll_args),
67 std::move(covariance), options, msgs)
68 .lmd;
69}
70
71namespace internal {
72
73template <bool ZeroInput = false, typename Output, typename Input,
74 require_tuple_t<Output>* = nullptr, require_tuple_t<Input>* = nullptr,
75 require_t<std::bool_constant<
76 std::tuple_size_v<std::decay_t<Output>> == 0>>* = nullptr,
77 require_t<std::bool_constant<
78 std::tuple_size_v<std::decay_t<Input>> == 0>>* = nullptr>
79inline constexpr void copy_compute_s2(Output&& output, Input&& input) {}
80
// Accumulates half of the adjoints held in `input` into `output`.
//
// `iter_tuple_nested` walks `output` and `input` together and calls the
// lambda below on each matching leaf pair. For every leaf the lambda does
// `output_i += 0.5 * adj(input_i)`. It handles three leaf kinds:
//   - std::vector: both sides are viewed through Eigen::Map; output as
//     doubles, input as vars.
//   - Eigen objects.
//   - Stan scalars.
// Any other leaf type fails the static_assert at compile time.
//
// @tparam ZeroInput when true, each input adjoint is reset to zero right
//         after it has been copied (presumably so that a later accumulation
//         from the same vars does not count it twice).
//
// NOTE(review): the remaining template constraints (original lines 90-91)
// are not shown in this listing.
89template <bool ZeroInput = false, typename Output, typename Input,
92inline constexpr void copy_compute_s2(Output&& output, Input&& input) {
 // When both are tuples, their sizes must agree so the walk pairs up
 // elements one-to-one.
93 if constexpr (is_tuple_v<Output> && is_tuple_v<Input>) {
94 static_assert(
95 std::tuple_size<std::decay_t<Output>>::value
96 == std::tuple_size<std::decay_t<Input>>::value,
97 "INTERNAL ERROR:(laplace_marginal_lpdf) copy_compute_s2 called on "
98 "tuples of different sizes. This is an internal error, please report "
99 "it: "
100 "https://github.com/stan-dev/math/issues");
101 }
102 return iter_tuple_nested(
103 [](auto&& output_i, auto&& input_i) {
104 using output_i_t = std::decay_t<decltype(output_i)>;
 // std::vector leaf: wrap both vectors in Eigen maps so the update
 // is one vectorised expression. The input map is typed as
 // Matrix<var, -1, 1>, so input_i is assumed to hold vars.
105 if constexpr (is_std_vector_v<output_i_t>) {
106 using dbl_map_t = Eigen::Map<Eigen::Matrix<double, -1, 1>>;
107 using var_map_t = Eigen::Map<Eigen::Matrix<var, -1, 1>>;
108 var_map_t input_map(input_i.data(), input_i.size());
109 dbl_map_t(output_i.data(), output_i.size()).array()
110 += 0.5 * input_map.adj().array();
111 if constexpr (ZeroInput) {
112 input_map.adj().setZero();
113 }
 // Eigen leaf: element-wise update through .adj().
114 } else if constexpr (is_eigen_v<output_i_t>) {
115 output_i.array() += 0.5 * input_i.adj().array();
116 if constexpr (ZeroInput) {
117 input_i.adj().setZero();
118 }
 // Scalar leaf.
119 } else if constexpr (is_stan_scalar_v<output_i_t>) {
120 output_i += (0.5 * input_i.adj());
121 if constexpr (ZeroInput) {
122 input_i.adj() = 0;
123 }
124 } else {
 // sizeof(T*) == 0 is always false, but it depends on the type, so
 // this assertion fires only if the branch is actually instantiated.
125 static_assert(
126 sizeof(std::decay_t<output_i_t>*) == 0,
127 "INTERNAL ERROR:(laplace_marginal_lpdf) copy_compute_s2 was "
128 "not able to deduce the actions needed for the given type. "
129 "This is an internal error, please report it: "
130 "https://github.com/stan-dev/math/issues");
131 }
132 },
133 std::forward<Output>(output), std::forward<Input>(input));
134}
135
136} // namespace internal
137
// Reverse-mode overload of laplace_marginal_density: returns the log
// marginal density as a `var` and attaches gradients for every `var`
// contained in `covar_args` and `ll_args`.
//
// NOTE(review): several source lines are missing from this listing:
//   - 167: the enabling require_* constraint.
//   - 196: the line opening the estimator call that produces `md_est`.
//   - 211: the definition of `ZeroOut`.
//   - 285: the line that creates the callback var `Z`.
// The notes below describe only the visible code.
//
// Outline of the visible code:
//  1. Inside a `nested_rev_autodiff` scope, deep-copies both argument packs
//     into fresh `var`s (`deep_copy_vargs<var>`).
//  2. Evaluates the covariance on the copies, wrapping it with
//     `to_var_value` if any covariance argument is a var. Then runs the
//     estimator on the double values of the covariance and likelihood
//     arguments, giving `md_est`.
//  3. Depending on `md_est.solver_used` (1, 2, or the LU "else" branch),
//     forms R, the matrix A passed to `compute_s2`, and the vector s2.
//  4. If the covariance arguments contain vars, seeds the covariance adjoint
//     with K_adj_arena through a callback var, runs `grad`, and collects
//     the adjoints of the copied covariance arguments into `covar_args_adj`.
//  5. If the likelihood arguments contain vars, accumulates their adjoints
//     from `compute_s2`, `ll_arg_grad` and `diff_eta_implicit` into
//     `partial_parm`.
//  6. After the nested scope closes, builds `var ret(lmd)` and registers the
//     collected adjoints against the caller's original arguments with
//     `reverse_pass_collect_adjoints`.
//
// @return the log marginal density (`md_est.lmd`) as a `var`.
165template <typename LLFun, typename LLTupleArgs, typename CovarFun,
166 typename CovarArgs, bool InitTheta,
168inline auto laplace_marginal_density(LLFun&& ll_fun, LLTupleArgs&& ll_args,
169 CovarFun&& covariance_function,
170 CovarArgs&& covar_args,
171 const laplace_options<InitTheta>& options,
172 std::ostream* msgs) {
 // to_ref evaluates any Eigen expression arguments once, up front.
173 auto covar_args_refs = to_ref(std::forward<CovarArgs>(covar_args));
174 auto ll_args_refs = to_ref(std::forward<LLTupleArgs>(ll_args));
175 // Solver 1, 2, 3
176 constexpr bool ll_args_contain_var = is_any_var_scalar<LLTupleArgs>::value;
 // Zero-initialised arena buffers shaped like the argument packs. They are
 // allocated before the nested scope opens (presumably so they outlive it)
 // and collect the adjoints computed inside it.
177 auto partial_parm = internal::make_zeroed_arena(ll_args_refs);
178 auto covar_args_adj = internal::make_zeroed_arena(covar_args_refs);
179 double lmd = 0.0;
180 {
181 nested_rev_autodiff nested;
182 // Make one hard copy here
183 auto ll_args_copy = internal::deep_copy_vargs<var>(ll_args_refs);
184 auto covar_args_copy = internal::deep_copy_vargs<var>(covar_args_refs);
 // Covariance on the copied arguments. If any of them is a var, the
 // result is wrapped in a var_value matrix so its adjoint can be seeded
 // later.
185 auto covariance = stan::math::apply(
186 [&covariance_function, &msgs](auto&&... args) {
187 if constexpr (is_any_var_scalar_v<decltype(args)...>) {
188 return to_var_value(covariance_function(args..., msgs));
189 } else {
190 return covariance_function(args..., msgs);
191 }
192 },
193 covar_args_copy);
194 decltype(auto) covariance_val = value_of(covariance);
195 decltype(auto) ll_args_vals = value_of(ll_args_copy);
197 ll_fun, ll_args_vals, covariance_val, options, msgs);
198 auto ll_args_filter = internal::filter_var_scalar_types(ll_args_copy);
199 // tuple of references to var types
200 // Solver 1, 2
 // Multiplying a size by a bool gives a 0x0 matrix when the active solver
 // does not use that buffer.
201 const bool solver_1_or_2
202 = md_est.solver_used == 1 || md_est.solver_used == 2;
203 arena_t<Eigen::MatrixXd> R(md_est.theta.size() * solver_1_or_2,
204 md_est.theta.size() * solver_1_or_2);
205 // Solver 3
206 arena_t<Eigen::MatrixXd> LU_solve_covariance(
207 covariance.rows() * (md_est.solver_used == 3),
208 covariance.cols() * (md_est.solver_used == 3));
209 // Solver 1, 2, 3
210 arena_t<Eigen::VectorXd> s2(md_est.theta.size());
212 if (md_est.solver_used == 1) {
 // NOTE(review): both hessian_block_size branches below compute tmp, R
 // and C identically; only the s2 computation differs. That shared code
 // could be hoisted above the `if`.
213 if (options.hessian_block_size == 1) {
214 arena_t<Eigen::MatrixXd> tmp = md_est.W_r.toDense();
215 md_est.L.template triangularView<Eigen::Lower>().solveInPlace(tmp);
216 R.noalias() = tmp.transpose() * tmp;
217 arena_t<Eigen::MatrixXd> C
218 = md_est.L.template triangularView<Eigen::Lower>().solve(
219 md_est.W_r * covariance_val);
220 if constexpr (ll_args_contain_var) {
221 arena_t<Eigen::MatrixXd> A = covariance_val - C.transpose() * C;
222 auto s2_tmp = laplace_likelihood::compute_s2(
223 ll_fun, md_est.theta, A, options.hessian_block_size, ll_args_copy,
224 msgs);
225 s2.deep_copy(s2_tmp);
226 internal::copy_compute_s2<ZeroOut>(partial_parm, ll_args_filter);
227 } else {
 // Diagonal-only shortcut for block size 1 without var likelihood
 // arguments: s2 = 0.5 * diag(K - C^T C) .* third_diff(...).
228 s2.deep_copy(
229 (0.5
230 * (covariance_val.diagonal() - (C.transpose() * C).diagonal())
231 .cwiseProduct(laplace_likelihood::third_diff(
232 ll_fun, md_est.theta, ll_args_vals, msgs))));
233 }
234
235 } else {
236 arena_t<Eigen::MatrixXd> tmp = md_est.W_r.toDense();
237 md_est.L.template triangularView<Eigen::Lower>().solveInPlace(tmp);
238 R.noalias() = tmp.transpose() * tmp;
239 arena_t<Eigen::MatrixXd> C
240 = md_est.L.template triangularView<Eigen::Lower>().solve(
241 md_est.W_r * covariance_val);
242 arena_t<Eigen::MatrixXd> A = covariance_val - C.transpose() * C;
243 auto s2_tmp = laplace_likelihood::compute_s2(ll_fun, md_est.theta, A,
244 options.hessian_block_size,
245 ll_args_copy, msgs);
246 s2.deep_copy(s2_tmp);
247 internal::copy_compute_s2<ZeroOut>(partial_parm, ll_args_filter);
248 }
249 } else if (md_est.solver_used == 2) {
250 R = md_est.W_r
251 - md_est.W_r * md_est.K_root
252 * md_est.L.transpose()
253 .template triangularView<Eigen::Upper>()
254 .solve(
255 md_est.L.template triangularView<Eigen::Lower>()
256 .solve(md_est.K_root.transpose() * md_est.W_r));
257
258 arena_t<Eigen::MatrixXd> C
259 = md_est.L.template triangularView<Eigen::Lower>().solve(
260 md_est.K_root.transpose());
261 auto s2_tmp = laplace_likelihood::compute_s2(
262 ll_fun, md_est.theta, (C.transpose() * C).eval(),
263 options.hessian_block_size, ll_args_copy, msgs);
264 s2.deep_copy(s2_tmp);
265 internal::copy_compute_s2<ZeroOut>(partial_parm, ll_args_filter);
266 } else { // options.solver with LU decomposition
267 LU_solve_covariance = md_est.LU.solve(covariance_val);
268 auto I_minus_BinvKW
269 = Eigen::MatrixXd::Identity(md_est.W_r.rows(), md_est.W_r.cols())
270 - LU_solve_covariance * md_est.W_r;
271 R = md_est.W_r * I_minus_BinvKW; // == W - W B^{-1} K W
272 arena_t<Eigen::MatrixXd> A
273 = covariance_val - covariance_val * md_est.W_r * LU_solve_covariance;
274 auto s2_tmp = laplace_likelihood::compute_s2(ll_fun, md_est.theta, A,
275 options.hessian_block_size,
276 ll_args_copy, msgs);
277 s2.deep_copy(s2_tmp);
278 internal::copy_compute_s2<ZeroOut>(partial_parm, ll_args_filter);
279 }
 // NOTE(review): this gate uses is_any_var_scalar_v<scalar_type_t<CovarArgs>>,
 // while the matching gate after the nested scope uses
 // is_any_var_scalar_v<CovarArgs>. Consider using one spelling for both.
 //
 // The covariance adjoint is seeded with
 //   K_adj = 0.5 a a^T - 0.5 R + s2 theta_grad^T - (R K s2) theta_grad^T.
 // It is pushed into covariance.adj() by a callback var; grad() then
 // propagates it back to the copied covariance arguments.
280 if constexpr (is_any_var_scalar_v<scalar_type_t<CovarArgs>>) {
281 arena_t<Eigen::MatrixXd> K_adj_arena
282 = 0.5 * md_est.a * md_est.a.transpose() - 0.5 * R
283 + s2 * md_est.theta_grad.transpose()
284 - (R * (covariance.val() * s2)) * md_est.theta_grad.transpose();
286 0.0, [covariance, K_adj_arena](auto&& vi) mutable {
287 covariance.adj().array() += vi.adj() * K_adj_arena.array();
288 });
289 grad(Z.vi_);
290 auto covar_args_filter
291 = internal::filter_var_scalar_types(covar_args_copy);
292 internal::collect_adjoints(covar_args_adj, covar_args_filter);
293 }
294 if constexpr (ll_args_contain_var) {
295 laplace_likelihood::ll_arg_grad(ll_fun, md_est.theta, ll_args_copy, msgs);
296 internal::collect_adjoints<ZeroOut>(partial_parm, ll_args_filter);
297 arena_t<Eigen::VectorXd> v;
 // NOTE(review): C++ groups `covariance_val * R * covariance_val * s2`
 // left to right. This forms two dense n x n matrix-matrix products and
 // also computes `covariance_val * s2` twice. The algebraically equal
 // `covariance_val * (s2 - R * (covariance_val * s2))` needs only
 // matrix-vector products (compare the grouping used in K_adj_arena).
298 if (md_est.solver_used == 1 || md_est.solver_used == 2) {
299 v = covariance_val * s2 - covariance_val * R * covariance_val * s2;
300 } else {
301 v = LU_solve_covariance * s2;
302 }
303 laplace_likelihood::diff_eta_implicit(ll_fun, v, md_est.theta,
304 ll_args_copy, msgs);
305 internal::collect_adjoints<ZeroOut>(partial_parm, ll_args_filter);
306 }
307 lmd = md_est.lmd;
308 }
 // The nested scope has closed. Create the result var and attach the
 // collected adjoints to the caller's original arguments.
309 var ret(lmd);
310 if constexpr (is_any_var_scalar_v<CovarArgs>) {
311 auto covar_args_filter = internal::filter_var_scalar_types(covar_args_refs);
312 internal::reverse_pass_collect_adjoints(ret, covar_args_filter,
313 std::move(covar_args_adj));
314 }
315 if constexpr (ll_args_contain_var) {
316 auto ll_args_filter = internal::filter_var_scalar_types(ll_args_refs);
317 internal::reverse_pass_collect_adjoints(ret, ll_args_filter,
318 std::move(partial_parm));
319 }
320 return ret;
321}
322
323} // namespace math
324} // namespace stan
325
326#endif
Reference for calculations of the marginal density and its gradients: Margossian et al. (2020),...
constexpr decltype(auto) filter_var_scalar_types(T &&t)
Filters a tuple, returning a tuple of references to the elements whose scalar type is var.
void reverse_pass_collect_adjoints(var ret, Output &&output, Input &&input)
Collects adjoints from a tuple or std::vector of tuples.
constexpr auto make_zeroed_arena(Input &&input)
Creates an arena type that is the same type as the input and initialized with zeros.
auto laplace_marginal_density_est(LLFun &&ll_fun, LLTupleArgs &&ll_args, CovarMat &&covariance, const laplace_options< InitTheta > &options, std::ostream *msgs)
For a latent Gaussian model with hyperparameters phi and latent variables theta, and observations y,...
void collect_adjoints(Output &output, Input &&input)
Collect the adjoints from the input and add them to the output.
constexpr void copy_compute_s2(Output &&output, Input &&input)
auto compute_s2(F &&f, Theta &&theta, AMat &&A, int hessian_block_size, TupleArgs &&ll_args, Stream *msgs)
A wrapper that accepts a tuple as arguments.
auto diff_eta_implicit(F &&f, V_t &&v, Theta &&theta, TupleArgs &&ll_args, Stream *msgs)
A wrapper that accepts a tuple as arguments.
auto ll_arg_grad(F &&f, Theta &&theta, TupleArgs &&ll_tup, Stream *msgs=nullptr)
Eigen::VectorXd third_diff(F &&f, Theta &&theta, TupleArgs &&ll_args, Stream *msgs)
A wrapper that accepts a tuple as arguments.
void iter_tuple_nested(F &&f, Types &&... args)
Iterate and nest into a tuple or std::vector to apply f to each matrix or scalar type.
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback funct...
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
auto laplace_marginal_density(LLFun &&ll_fun, LLTupleArgs &&ll_args, CovarFun &&covariance_function, CovarArgs &&covar_args, const laplace_options< InitTheta > &options, std::ostream *msgs)
For a latent Gaussian model with global parameters phi, latent variables theta, and observations y,...
var_value< Eigen::Matrix< double, T::RowsAtCompileTime, T::ColsAtCompileTime > > to_var_value(const T &a)
Converts an Eigen matrix (or vector or row_vector) or expression of vars into var_value.
var_value< double > var
Definition var.hpp:1187
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:18
static void grad()
Compute the gradient for all variables starting from the end of the AD tape.
Definition grad.hpp:26
constexpr decltype(auto) apply(F &&f, Tuple &&t, PreArgs &&... pre_args)
Definition apply.hpp:51
constexpr bool is_any_var_scalar_v
Definition is_var.hpp:50
std::enable_if_t< Check::value > require_t
If condition is true, template is enabled.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...