Automatic Differentiation
 
Loading...
Searching...
No Matches
matrix_power.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_MATRIX_POWER_HPP
2#define STAN_MATH_REV_FUN_MATRIX_POWER_HPP
3
11#include <vector>
12
13namespace stan {
14namespace math {
15
28template <typename T, require_rev_matrix_t<T>* = nullptr>
29inline plain_type_t<T> matrix_power(const T& M, const int n) {
30 check_square("matrix_power", "M", M);
31 check_nonnegative("matrix_power", "n", n);
32
33 if (M.size() == 0)
34 return M;
35
36 const auto& M_ref = to_ref(M);
37 check_finite("matrix_power", "M", M_ref);
38
39 size_t N = M.rows();
40
41 if (n == 0)
42 return Eigen::MatrixXd::Identity(N, N);
43
44 if (n == 1)
45 return M_ref;
46
47 arena_t<std::vector<Eigen::MatrixXd>> arena_powers(n + 1);
48 arena_t<plain_type_t<T>> arena_M = M_ref;
49
50 arena_powers[0] = Eigen::MatrixXd::Identity(N, N);
51 arena_powers[1] = M_ref.val();
52 for (size_t i = 2; i <= n; ++i) {
53 arena_powers[i] = arena_powers[1] * arena_powers[i - 1];
54 }
55 using ret_type = return_var_matrix_t<T>;
56 arena_t<ret_type> res = arena_powers[arena_powers.size() - 1];
57
58 reverse_pass_callback([arena_M, n, res, arena_powers]() mutable {
59 const auto& M_val = arena_powers[1];
60 Eigen::MatrixXd adj_C = res.adj();
61 Eigen::MatrixXd adj_M = Eigen::MatrixXd::Zero(M_val.rows(), M_val.cols());
62 for (size_t i = n; i > 1; --i) {
63 adj_M += adj_C * arena_powers[i - 1].transpose();
64 adj_C = M_val.transpose() * adj_C;
65 }
66 arena_M.adj() += adj_M + adj_C;
67 });
68
69 return ret_type(res);
70}
71
72} // namespace math
73} // namespace stan
74#endif
void check_square(const char *function, const char *name, const T_y &y)
Check if the specified matrix is square.
void check_nonnegative(const char *function, const char *name, const T_y &y)
Check if y is non-negative.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:17
plain_type_t< T_m > matrix_power(T_m &&M, const int n)
Returns the nth power of the specified matrix.
void check_finite(const char *function, const char *name, const T_y &y)
Check that all values in y are finite.
typename plain_type< T >::type plain_type_t
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
std::conditional_t< is_any_var_matrix< ReturnType, Types... >::value, stan::math::var_value< stan::math::promote_scalar_t< double, plain_type_t< ReturnType > > >, stan::math::promote_scalar_t< stan::math::var_value< double >, plain_type_t< ReturnType > > > return_var_matrix_t
Given an Eigen type and several inputs, determine if a matrix should be var<Matrix> or Matrix<var>.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...