Automatic Differentiation
 
Loading...
Searching...
No Matches
arena_matrix_cl.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_REV_ARENA_MATRIX_CL_HPP
2#define STAN_MATH_OPENCL_REV_ARENA_MATRIX_CL_HPP
3#ifdef STAN_OPENCL
4
9#include <utility>
10
11namespace stan {
12namespace math {
13namespace internal {
// NOTE(review): this is a doxygen source listing; the class head (listing
// line 15) and lines 24-28 are missing from this extract. The
// mem-initializer list below shows that the bases are chainable_alloc and
// matrix_cl<T>.
// internal::arena_matrix_cl_impl<T> is the object that arena_matrix_cl<T>
// points to. It is a matrix_cl<T> that is also a chainable_alloc.
14template <typename T>
16 public:
 // Re-export the element/type aliases of the wrapped matrix_cl<T>.
17 using Scalar = typename matrix_cl<T>::Scalar;
18 using type = typename matrix_cl<T>::type;
19
 // Forwards every argument to a matrix_cl<T> constructor; the
 // chainable_alloc base is default-constructed.
20 template <typename... Args>
21 explicit arena_matrix_cl_impl(Args&&... args)
22 : chainable_alloc(), matrix_cl<T>(std::forward<Args>(args)...) {}
23
 // Bring matrix_cl<T>'s assignment operators into scope. Without this, this
 // class's own (defaulted) operator= overloads would hide them.
29 using matrix_cl<T>::operator=;
30};
31
32} // namespace internal
33
// NOTE(review): this is a doxygen source listing; the class head (listing
// line 39) and interior lines (e.g. 41, 49-53, 73, 124-142) are missing from
// this extract.
// arena_matrix_cl<T> is a handle to a matrix_cl<T> stored in an
// internal::arena_matrix_cl_impl<T>, reached through the pointer member
// `impl_`. Per the generated docs, the matrix "schedules its destructor to be
// called, so it can be used on the AD stack". The copy/move operations are
// listed as defaulted. Assuming `impl_` is the only data member (confirm),
// copies share one impl rather than duplicating the device matrix.
38template <typename T>
40 private:
 // matrix_cl<T>'s constructor and operator= taking an arena_matrix_cl read
 // the private `impl_` directly, hence the friendship.
42 template <typename>
43 friend class matrix_cl;
44
45 public:
 // Same element/type aliases as the wrapped matrix_cl<T>.
46 using Scalar = typename matrix_cl<T>::Scalar;
47 using type = typename matrix_cl<T>::type;
48
 // General constructor. It allocates a new arena_matrix_cl_impl<T> with `new`
 // and perfectly forwards every argument to it. No deleting destructor is
 // visible here; presumably the chainable_alloc base of the impl handles
 // destruction (confirm).
54 template <typename... Args>
55 explicit arena_matrix_cl(Args&&... args)
56 : impl_(
57 new internal::arena_matrix_cl_impl<T>(std::forward<Args>(args)...)) {}
58
64
 // Implicit constructor from a kernel generator expression. Its enable-if
 // constraint sits on listing line 73, which is missing from this extract.
70 // we need this as a separate overload, because the general constructor is
71 // explicit
72 template <typename Expr,
74 arena_matrix_cl(Expr&& expression) // NOLINT(runtime/explicit)
75 : impl_(new internal::arena_matrix_cl_impl<T>(
76 std::forward<Expr>(expression))) {}
77
 // Implicit conversions let an arena_matrix_cl be passed wherever a
 // matrix_cl<T> reference is expected. They return the impl viewed as its
 // matrix_cl<T> base, so no copy is made.
83 operator const matrix_cl<T> &() const { // NOLINT(runtime/explicit)
84 return *static_cast<const matrix_cl<T>*>(impl_);
85 }
86 operator matrix_cl<T> &() { // NOLINT(runtime/explicit)
87 return *static_cast<matrix_cl<T>*>(impl_);
88 }
89
 // Produce a standalone matrix_cl<T>. The const& overload copies *impl_; the
 // && overload moves from it.
 // NOTE(review): copies of an arena_matrix_cl share the same impl_, so
 // eval() && leaves every other copy referring to a moved-from matrix.
 // Confirm callers only do this on the last user.
94 matrix_cl<T> eval() const& { return *impl_; }
95 matrix_cl<T> eval() && { return std::move(*impl_); }
96
97 // Wrappers to functions with explicit template parameters are implemented
98 // without macros.
 // Forwards to impl_->zeros_strict_tri with the requested view.
99 template <matrix_cl_view matrix_view = matrix_cl_view::Entire>
100 inline void zeros_strict_tri() {
101 impl_->template zeros_strict_tri<matrix_view>();
102 }
103
 // Defines a non-const member `function_name` that perfectly forwards its
 // arguments to impl_->function_name and returns the result via
 // decltype(auto), preserving references.
108#define ARENA_MATRIX_CL_FUNCTION_WRAPPER(function_name) \
109 template <typename... Args> \
110 inline decltype(auto) function_name(Args&&... args) { \
111 return impl_->function_name(std::forward<Args>(args)...); \
112 }
113
 // Same as above, but the generated member is const-qualified. Constness is
 // shallow: inside a const member `impl_` is a const pointer to a non-const
 // impl, so these wrappers may still reach non-const members of the impl.
118#define ARENA_MATRIX_CL_CONST_FUNCTION_WRAPPER(function_name) \
119 template <typename... Args> \
120 inline decltype(auto) function_name(Args&&... args) const { \
121 return impl_->function_name(std::forward<Args>(args)...); \
122 }
123
 // NOTE(review): listing lines 124-142 are missing from this extract.
 // Presumably they instantiate the wrapper macros for the members in the
 // generated member list (rows, cols, size, view, buffer, *_events, ...).
143
 // The helper macros are only needed inside this class.
144#undef ARENA_MATRIX_CL_FUNCTION_WRAPPER
145#undef ARENA_MATRIX_CL_CONST_FUNCTION_WRAPPER
146};
// NOTE(review): the signature (listing line 148) is missing from this extract.
// The body reads `A.impl_`, so this is the matrix_cl<T> constructor taking an
// arena_matrix_cl<T> named `A`. It copies A's buffer handle, dimensions, view
// and event lists, so both objects refer to the same device data. That
// presumably relies on the buffer handle being reference-counted (confirm).
147template <typename T>
149 // works like a move constructor, except it does not modify `A`
150 : buffer_cl_(A.impl_->buffer_cl_),
151 rows_(A.impl_->rows_),
152 cols_(A.impl_->cols_),
153 view_(A.impl_->view_),
154 write_events_(A.impl_->write_events_),
155 read_events_(A.impl_->read_events_) {}
156
// NOTE(review): the signature (listing line 158) is missing from this extract.
// The body reads `a.impl_` and returns *this, so this is matrix_cl<T>'s
// assignment from an arena_matrix_cl<T> named `a`. It takes over a's
// dimensions, view, buffer handle and event lists without modifying `a`.
157template <typename T>
159 // works like a move assignment operator, except it does not modify `a`
160 view_ = a.impl_->view();
161 rows_ = a.impl_->rows();
162 cols_ = a.impl_->cols();
 // Wait for pending events on this object's current buffer before replacing
 // it. Presumably this keeps in-flight device work on the old buffer from
 // losing its owner (confirm).
163 this->wait_for_read_write_events();
164 buffer_cl_ = a.impl_->buffer_cl_;
165 write_events_ = a.impl_->write_events_;
166 read_events_ = a.impl_->read_events_;
167 return *this;
168}
169
170} // namespace math
171} // namespace stan
172
173#endif
174#endif
#define ARENA_MATRIX_CL_FUNCTION_WRAPPER(function_name)
Implements a wrapper for a non-const function in matrix_cl.
#define ARENA_MATRIX_CL_CONST_FUNCTION_WRAPPER(function_name)
Implements a wrapper for a const function in matrix_cl.
typename matrix_cl< T >::type type
matrix_cl< T > eval() const &
Evaluates this.
decltype(auto) read_write_events(Args &&... args) const
decltype(auto) read_events(Args &&... args) const
arena_matrix_cl(arena_matrix_cl< T > &&)=default
decltype(auto) view(Args &&... args) const
decltype(auto) clear_read_write_events(Args &&... args) const
decltype(auto) cols(Args &&... args) const
arena_matrix_cl(const arena_matrix_cl< T > &)=default
decltype(auto) add_read_event(Args &&... args) const
decltype(auto) clear_read_events(Args &&... args) const
arena_matrix_cl(Args &&... args)
General constructor; forwards its arguments to the matching matrix_cl constructor.
decltype(auto) wait_for_write_events(Args &&... args) const
decltype(auto) add_write_event(Args &&... args) const
decltype(auto) buffer(Args &&... args) const
decltype(auto) wait_for_read_events(Args &&... args) const
arena_matrix_cl(arena_matrix_cl< T > &)=default
decltype(auto) write_events(Args &&... args) const
internal::arena_matrix_cl_impl< T > * impl_
decltype(auto) wait_for_read_write_events(Args &&... args) const
arena_matrix_cl(Expr &&expression)
Constructor from a kernel generator expression.
decltype(auto) clear_write_events(Args &&... args) const
decltype(auto) rows(Args &&... args) const
arena_matrix_cl< T > & operator=(arena_matrix_cl< T > &&)=default
typename matrix_cl< T >::Scalar Scalar
decltype(auto) size(Args &&... args) const
decltype(auto) add_read_write_event(Args &&... args) const
arena_matrix_cl< T > & operator=(const arena_matrix_cl< T > &)=default
A variant of matrix_cl that schedules its destructor to be called, so it can be used on the AD stack.
A chainable_alloc is an object which is constructed and destructed normally but the memory lifespan i...
arena_matrix_cl_impl(arena_matrix_cl_impl< T > &)=default
arena_matrix_cl_impl(const arena_matrix_cl_impl< T > &)=default
arena_matrix_cl_impl(arena_matrix_cl_impl< T > &&)=default
arena_matrix_cl_impl< T > & operator=(arena_matrix_cl_impl< T > &&)=default
arena_matrix_cl_impl< T > & operator=(const arena_matrix_cl_impl< T > &)=default
typename matrix_cl< T >::Scalar Scalar
Non-templated base class for matrix_cl; it simplifies checking whether something is a matrix_cl.
matrix_cl< T > & operator=(matrix_cl< T > &&a)
Move assignment operator.
Represents an arithmetic matrix on the OpenCL device.
Definition matrix_cl.hpp:47
require_all_t< is_kernel_expression_and_not_scalar< Types >... > require_all_kernel_expressions_and_none_scalar_t
Enables a template if all given types are non-scalar types that are a valid kernel generator expressi...
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
STL namespace.