Automatic Differentiation
 
Loading...
Searching...
No Matches
rep_matrix.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_REP_MATRIX_HPP
2#define STAN_MATH_REV_FUN_REP_MATRIX_HPP
3
9
10namespace stan {
11namespace math {
12
23template <typename Ret, typename T, require_var_matrix_t<Ret>* = nullptr,
24 require_var_t<T>* = nullptr>
25inline auto rep_matrix(const T& x, int m, int n) {
26 check_nonnegative("rep_matrix", "rows", m);
27 check_nonnegative("rep_matrix", "cols", n);
28 return make_callback_var(
29 value_type_t<Ret>::Constant(m, n, x.val()),
30 [x](auto& rep) mutable { x.adj() += rep.adj().sum(); });
31}
32
43template <typename Ret, typename Vec, require_var_matrix_t<Ret>* = nullptr,
44 require_var_matrix_t<Vec>* = nullptr>
45inline auto rep_matrix(const Vec& x, int n) {
47 check_nonnegative("rep_matrix", "rows", n);
48 return make_callback_var(x.val().replicate(n, 1), [x](auto& rep) mutable {
49 x.adj() += rep.adj().colwise().sum();
50 });
51 } else {
52 check_nonnegative("rep_matrix", "cols", n);
53 return make_callback_var(x.val().replicate(1, n), [x](auto& rep) mutable {
54 x.adj() += rep.adj().rowwise().sum();
55 });
56 }
57}
58
59} // namespace math
60} // namespace stan
61
62#endif
auto rep_matrix(const value_type_t< T > &x, int n, int m)
Creates a matrix_cl by replicating the given value of arithmetic type.
void check_nonnegative(const char *function, const char *name, const T_y &y)
Check if y is non-negative.
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback functor.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
If the input type T has a static compile-time constant RowsAtCompileTime equal to 1, this has a st...