Automatic Differentiation
 
Loading...
Searching...
No Matches
vari.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_REV_VARI_HPP
2#define STAN_MATH_OPENCL_REV_VARI_HPP
3#ifdef STAN_OPENCL
4
9
10namespace stan {
11namespace math {
12
13template <typename T>
14class vari_cl_base : public vari_base {
15 public:
20
26
37 template <typename S, typename U,
41 : val_(std::forward<S>(val)), adj_(std::forward<U>(adj)) {}
42
48 inline const auto& val() const { return val_; }
49 inline auto& val_op() { return val_; }
50
59 inline auto& adj() { return adj_; }
60 inline auto& adj() const { return adj_; }
61 inline auto& adj_op() { return adj_; }
62
71 auto block(int row, int col, int rows, int cols) {
72 auto&& val_block = stan::math::block_zero_based(val_, row, col, rows, cols);
73 auto&& adj_block = stan::math::block_zero_based(adj_, row, col, rows, cols);
74 return vari_view<std::decay_t<decltype(val_block)>>(std::move(val_block),
75 std::move(adj_block));
76 }
77
82 auto transpose() {
83 auto&& val_t = stan::math::transpose(val_);
84 auto&& adj_t = stan::math::transpose(adj_);
85 return vari_view<std::decay_t<decltype(val_t)>>(std::move(val_t),
86 std::move(adj_t));
87 }
88
96 return vari_view<std::decay_t<decltype(val_t)>>(std::move(val_t),
97 std::move(adj_t));
98 }
99
104 auto reverse() {
105 auto&& val_t = stan::math::reverse(val_);
106 auto&& adj_t = stan::math::reverse(adj_);
107 return vari_view<std::decay_t<decltype(val_t)>>(std::move(val_t),
108 std::move(adj_t));
109 }
110
120 template <typename RowIndex, typename ColIndex>
121 auto index(const RowIndex& row_index, const ColIndex& col_index) {
122 RowIndex r1 = row_index;
123 RowIndex r2 = row_index;
124 ColIndex c1 = col_index;
125 ColIndex c2 = col_index;
126 auto&& val_t = stan::math::indexing(val_, std::move(r1), std::move(c1));
127 auto&& adj_t = stan::math::indexing(adj_, std::move(r2), std::move(c2));
128 return vari_view<std::decay_t<decltype(val_t)>>(std::move(val_t),
129 std::move(adj_t));
130 }
131
135 const Eigen::Index rows() const { return val_.rows(); }
139 const Eigen::Index cols() const { return val_.cols(); }
143 const Eigen::Index size() const { return rows() * cols(); }
144
145 virtual void chain() {}
146};
147
148template <typename T>
150 : public vari_cl_base<T> {
151 public:
155 static constexpr int RowsAtCompileTime{-1};
159 static constexpr int ColsAtCompileTime{-1};
160
161 using value_type = T;
162 using vari_cl_base<T>::vari_cl_base;
163 inline void set_zero_adjoint() final {}
164};
165
177template <typename T>
179 public vari_cl_base<T> {
180 public:
181 using value_type = T;
182
186 static constexpr int RowsAtCompileTime{-1};
190 static constexpr int ColsAtCompileTime{-1};
191
206 template <typename S, require_convertible_t<S&, T>* = nullptr>
207 explicit vari_value(S&& x)
208 : chainable_alloc(),
209 vari_cl_base<T>(std::forward<S>(x), constant(0, x.rows(), x.cols())) {
210 ChainableStack::instance_->var_stack_.push_back(this);
211 }
212
226 template <typename S, require_eigen_t<S>* = nullptr,
227 require_vt_same<T, S>* = nullptr>
228 explicit vari_value(const S& x)
229 : chainable_alloc(), vari_cl_base<T>(x, constant(0, x.rows(), x.cols())) {
230 ChainableStack::instance_->var_stack_.push_back(this);
231 }
232
251 template <typename S, require_convertible_t<S&, T>* = nullptr>
252 vari_value(S&& x, bool stacked)
253 : chainable_alloc(),
254 vari_cl_base<T>(std::forward<S>(x), constant(0, x.rows(), x.cols())) {
255 if (stacked) {
256 ChainableStack::instance_->var_stack_.push_back(this);
257 } else {
259 }
260 }
261
267 inline void set_zero_adjoint() final {
268 this->adj_ = constant(0, this->rows(), this->cols());
269 }
270
271 private:
272 template <typename, typename>
273 friend class var_value;
274};
275
276} // namespace math
277} // namespace stan
278
279#endif
280#endif
A chainable_alloc is an object which is constructed and destructed normally but the memory lifespan i...
Represents operation that determines column index.
Definition index.hpp:80
Represents operation that determines row index.
Definition index.hpp:17
Abstract base class from which vari_value and all of its derived classes inherit.
Definition vari.hpp:28
auto transpose()
Returns a transposed view into the matrix.
Definition vari.hpp:82
virtual void chain()
Apply the chain rule to this variable based on the variables on which it depends.
Definition vari.hpp:145
T val_
The value of this variable.
Definition vari.hpp:19
auto & adj() const
Definition vari.hpp:60
vari_cl_base(S &&val, U &&adj)
Construct a matrix_cl variable implementation from a value and adjoint.
Definition vari.hpp:40
const Eigen::Index rows() const
Return the number of rows for this class's val_ member.
Definition vari.hpp:135
auto reverse()
Returns reverse view into the row or column vector.
Definition vari.hpp:104
auto & adj()
Return a reference to the derivative of the root expression with respect to this expression.
Definition vari.hpp:59
auto index(const RowIndex &row_index, const ColIndex &col_index)
Return indexed view into a matrix.
Definition vari.hpp:121
const Eigen::Index cols() const
Return the number of columns for this class's val_ member.
Definition vari.hpp:139
const Eigen::Index size() const
Return the size of this class's val_ member.
Definition vari.hpp:143
T adj_
The adjoint of this variable, which is the partial derivative of this variable with respect to the ro...
Definition vari.hpp:25
auto as_column_vector_or_scalar()
Returns column vector view into the row or column vector.
Definition vari.hpp:93
auto block(int row, int col, int rows, int cols)
Returns a view into a block of matrix.
Definition vari.hpp:71
const auto & val() const
Return a constant reference to the value of this vari.
Definition vari.hpp:48
vari_value(S &&x, bool stacked)
Construct a matrix_cl variable implementation from a value.
Definition vari.hpp:252
vari_value(S &&x)
Construct a matrix_cl variable implementation from a value.
Definition vari.hpp:207
vari_value(const S &x)
Construct a matrix_cl variable implementation from an Eigen value.
Definition vari.hpp:228
void set_zero_adjoint() final
Set the adjoint value of this variable to 0.
Definition vari.hpp:267
A vari_view is used to read from a slice of a vari_value with an inner Eigen type.
Definition vari.hpp:206
require_t< is_kernel_expression_lhs< std::decay_t< T > > > require_kernel_expression_lhs_t
Require type satisfies is_kernel_expression_lhs.
require_t< is_matrix_cl< std::decay_t< T > > > require_matrix_cl_t
Require type satisfies is_matrix_cl.
auto block_zero_based(T &&a, int start_row, int start_col, int rows, int cols)
Block of a kernel generator expression.
auto as_column_vector_or_scalar(T &&a)
as_column_vector_or_scalar of a kernel generator expression.
auto transpose(Arg &&a)
Transposes a kernel generator expression.
auto constant(const T a, int rows, int cols)
Matrix of repeated values in kernel generator expressions.
Definition constant.hpp:130
auto indexing(T_mat &&mat, T_row_index &&row_index, T_col_index &&col_index)
Index a kernel generator expression using two expressions for indices.
Definition indexing.hpp:304
auto col(T_x &&x, size_t j)
Return the specified column of the specified kernel generator expression using start-at-1 indexing.
Definition col.hpp:23
int rows(const T_x &x)
Returns the number of rows in the specified kernel generator expression.
Definition rows.hpp:21
auto row(T_x &&x, size_t j)
Return the specified row of the specified kernel generator expression using start-at-1 indexing.
Definition row.hpp:23
int cols(const T_x &x)
Returns the number of columns in the specified kernel generator expression.
Definition cols.hpp:20
auto reverse(T_x &&x)
Return reversed view into the specified vector or row vector.
Definition reverse.hpp:20
std::enable_if_t< Check::value > require_t
If condition is true, template is enabled.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
STL namespace.
static thread_local AutodiffStackStorage * instance_