Automatic Differentiation
 
Loading...
Searching...
No Matches
matrix_power.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_MATRIX_POWER_HPP
2#define STAN_MATH_REV_FUN_MATRIX_POWER_HPP
3
12#include <vector>
13
14namespace stan {
15namespace math {
16
29template <typename T, require_rev_matrix_t<T>* = nullptr>
30inline plain_type_t<T> matrix_power(const T& M, const int n) {
31 check_square("matrix_power", "M", M);
32 check_nonnegative("matrix_power", "n", n);
33
34 if (M.size() == 0)
35 return M;
36
37 const auto& M_ref = to_ref(M);
38 check_finite("matrix_power", "M", M_ref);
39
40 size_t N = M.rows();
41
42 if (n == 0)
43 return Eigen::MatrixXd::Identity(N, N);
44
45 if (n == 1)
46 return M_ref;
47
48 arena_t<std::vector<Eigen::MatrixXd>> arena_powers(n + 1);
49 arena_t<plain_type_t<T>> arena_M = M_ref;
50
51 arena_powers[0] = Eigen::MatrixXd::Identity(N, N);
52 arena_powers[1] = M_ref.val();
53 for (size_t i = 2; i <= n; ++i) {
54 arena_powers[i] = arena_powers[1] * arena_powers[i - 1];
55 }
56 using ret_type = return_var_matrix_t<T>;
57 arena_t<ret_type> res = arena_powers[arena_powers.size() - 1];
58
59 reverse_pass_callback([arena_M, n, res, arena_powers]() mutable {
60 const auto& M_val = arena_powers[1];
61 Eigen::MatrixXd adj_C = res.adj();
62 Eigen::MatrixXd adj_M = Eigen::MatrixXd::Zero(M_val.rows(), M_val.cols());
63 for (size_t i = n; i > 1; --i) {
64 adj_M += adj_C * arena_powers[i - 1].transpose();
65 adj_C = M_val.transpose() * adj_C;
66 }
67 arena_M.adj() += adj_M + adj_C;
68 });
69
70 return ret_type(res);
71}
72
73} // namespace math
74} // namespace stan
75#endif
void check_square(const char *function, const char *name, const T_y &y)
Check if the specified matrix is square.
void check_nonnegative(const char *function, const char *name, const T_y &y)
Check if y is non-negative.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:17
plain_type_t< T_m > matrix_power(T_m &&M, const int n)
Returns the nth power of the specified matrix.
void check_finite(const char *function, const char *name, const T_y &y)
Check if all values in y are finite.
typename plain_type< T >::type plain_type_t
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
std::conditional_t< is_any_var_matrix< ReturnType, Types... >::value, stan::math::var_value< stan::math::promote_scalar_t< double, plain_type_t< ReturnType > > >, stan::math::promote_scalar_t< stan::math::var_value< double >, plain_type_t< ReturnType > > > return_var_matrix_t
Given an Eigen type and several inputs, determine if a matrix should be var<Matrix> or Matrix<var>.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...