Automatic Differentiation
 
rep_matrix.hpp
#ifndef STAN_MATH_OPENCL_REV_REP_MATRIX_HPP
#define STAN_MATH_OPENCL_REV_REP_MATRIX_HPP
#ifdef STAN_OPENCL


namespace stan {
namespace math {

template <typename T_ret, require_var_vt<is_matrix_cl, T_ret>* = nullptr>
inline var_value<matrix_cl<double>> rep_matrix(const var& A, int n, int m) {
  // Forward pass: replicate the value of A into an n x m matrix_cl.
  // Reverse pass: every entry of the result equals A, so A's adjoint
  // gains the sum of all adjoints of the result.
  return make_callback_var(rep_matrix<matrix_cl<double>>(A.val(), n, m),
                           [A](vari_value<matrix_cl<double>>& res) mutable {
                             A.adj() += sum(res.adj());
                           });
}

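template <typename T,
          require_all_kernel_expressions_and_none_scalar_t<T>* = nullptr>
inline var_value<matrix_cl<double>> rep_matrix(const var_value<T>& A, int m) {
  return make_callback_var(
      rep_matrix(A.val(), m), [A](vari_value<matrix_cl<double>>& res) mutable {
        // Nothing to propagate when A is empty.
        if (A.adj().size() != 0) {
          matrix_cl<double> A_adj = std::move(A.adj());
          try {
            // One work item per element of A_adj; the kernel receives
            // res.adj() with its dimensions and view and propagates it
            // back into A_adj.
            opencl_kernels::rep_matrix_rev(
                cl::NDRange(A_adj.rows(), A_adj.cols()), A_adj, res.adj(),
                res.adj().rows(), res.adj().cols(), res.adj().view());
          } catch (const cl::Error& e) {
            // Report OpenCL failures through Stan's error handling.
            check_opencl_error("rep_matrix(rev OpenCL)", e);
          }
          A.adj() = std::move(A_adj);
        }
      });
}

}  // namespace math
}  // namespace stan

#endif
#endif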