Automatic Differentiation
 
Loading...
Searching...
No Matches
matrix_power.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_REV_MATRIX_POWER_HPP
2#define STAN_MATH_OPENCL_REV_MATRIX_POWER_HPP
3#ifdef STAN_OPENCL
11
12namespace stan {
13namespace math {
14
/**
 * Reverse-mode matrix power for an OpenCL-backed autodiff matrix: builds
 * M^n by repeated multiplication and registers a reverse-pass callback
 * (via make_callback_var) that accumulates the adjoint into M.
 *
 * NOTE(review): the listing line carrying the return type, the function
 * name and the first parameter `M` (listing line 29) is absent from this
 * extract; the file name and the "matrix_power" strings passed to the
 * checks below indicate the function is `matrix_power` -- confirm against
 * the original header.
 *
 * Special cases handled before any multiplication:
 *  - empty M (size() == 0): M is returned as-is;
 *  - n == 0: an N x N identity is returned, built as a diagonal matrix of ones;
 *  - n == 1: M is returned as-is.
 *
 * Otherwise every power M^0 .. M^n of the value of M is stored
 * (n + 1 matrices in total) so the reverse pass can reuse them.
 *
 * @tparam T constrained by require_all_kernel_expressions_and_none_scalar_t
 *   (presumably a non-scalar kernel-generator expression type -- TODO confirm)
 * @param M matrix to exponentiate; checked to be square
 * @param n exponent; checked to be non-negative
 * @return the n-th power of M (see the special cases above)
 */
27template <typename T,
28 require_all_kernel_expressions_and_none_scalar_t<T>* = nullptr>
30 const int n) {
// Validate inputs: M must be square and n must be non-negative.
31 check_square("matrix_power", "M", M);
32 check_nonnegative("matrix_power", "n", n);
33
// Nothing to multiply for an empty matrix.
34 if (M.size() == 0)
35 return M;
36
37 auto N = M.rows();
38
// M^0: diagonal matrix whose diagonal is an N x 1 vector of ones.
39 if (n == 0) {
40 return diag_matrix(constant(1.0, N, 1));
41 }
// M^1: the input itself, no new autodiff node is created.
42 if (n == 1) {
43 return M;
44 }
45
// Cache of value powers: arena_powers[i] holds M^i for i = 0..n.
46 arena_t<std::vector<matrix_cl<double>>> arena_powers(n + 1);
47 arena_powers[0] = diag_matrix(constant(1.0, N, 1));
48 arena_powers[1] = M.val();
// NOTE(review): `i` is size_t while `n` is int; the comparison relies on
// n having been checked non-negative above.
49 for (size_t i = 2; i <= n; ++i) {
50 arena_powers[i] = arena_powers[1] * arena_powers[i - 1];
51 }
52
// The result value is M^n (the last cached power); the lambda below is the
// reverse-pass callback. It captures M, n and the cached powers by value.
53 return make_callback_var(
54 arena_powers.back(),
55 [M, n, arena_powers](vari_value<matrix_cl<double>> res) mutable {
56 const auto& M_val = arena_powers[1];
// adj_C starts as the adjoint of the result; adj_M accumulates the sum.
57 matrix_cl<double> adj_C = res.adj();
58 matrix_cl<double> adj_M = constant(0.0, M_val.rows(), M_val.cols());
// Accumulates sum_{k=0}^{n-2} (M^T)^k * adj * (M^(n-1-k))^T:
// each pass adds the current term, then left-multiplies adj_C by M^T.
59 for (size_t i = n; i > 1; --i) {
60 adj_M += adj_C * transpose(arena_powers[i - 1]);
61 adj_C = transpose(M_val) * adj_C;
62 }
// After the loop adj_C equals (M^T)^(n-1) * adj, i.e. the final (k = n-1)
// term, whose right factor (M^0)^T is the identity.
63 M.adj() += adj_M + adj_C;
64 });
65}
66} // namespace math
67} // namespace stan
68
69#endif
70#endif
Represents an arithmetic matrix on the OpenCL device.
Definition matrix_cl.hpp:47
auto constant(const T a, int rows, int cols)
Matrix of repeated values in kernel generator expressions.
Definition constant.hpp:130
void check_square(const char *function, const char *name, const T_y &y)
Check if the specified matrix is square.
void check_nonnegative(const char *function, const char *name, const T_y &y)
Check if y is non-negative.
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback functor.
plain_type_t< T_m > matrix_power(T_m &&M, const int n)
Returns the nth power of the specified matrix.
auto diag_matrix(T_x &&x)
Return a square diagonal matrix with the specified vector of coefficients as the diagonal values.
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation from C or the boost::math::lgamma implementation.