Automatic Differentiation
 
Loading...
Searching...
No Matches
indexing.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_indexing_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_indexing_HPP
3#ifdef STAN_OPENCL
4
10#include <map>
11#include <string>
12#include <utility>
13
14namespace stan {
15namespace math {
16
// indexing_: kernel generator operation that represents indexing of a matrix
// with two matrices of indices. Operands (accessed via get_arg<N>()):
//   arg 0 - the indexed matrix expression (T_mat),
//   arg 1 - the row-index expression (T_row_index),
//   arg 2 - the column-index expression (T_col_index).
// It derives from operation_cl_lhs (the base for operations usable on the
// left hand side of an assignment), so results can also be assigned into it.
// NOTE(review): this is a documentation extract, not compilable source. Each
// line carries its original listing number, and several lines are missing
// (e.g. listing line 29, presumably the `class indexing_` head, and line 40,
// presumably the start of the `using base = ...` alias used below). Restore
// them from the upstream header rather than from this text.
28template <typename T_mat, typename T_row_index, typename T_col_index>
30 : public operation_cl_lhs<indexing_<T_mat, T_row_index, T_col_index>,
31 typename std::remove_reference_t<T_mat>::Scalar,
32 T_mat, T_row_index, T_col_index> {
// Both index expressions must have integral scalar types (compile-time check).
33 static_assert(std::is_integral<value_type_t<T_row_index>>::value,
34 "indexing: Row index scalar type must be an integer!");
35 static_assert(std::is_integral<value_type_t<T_col_index>>::value,
36 "indexing: Column index scalar type must be an integer!");
37
38 public:
// The result's scalar type is the scalar type of the indexed matrix.
39 using Scalar = typename std::remove_reference_t<T_mat>::Scalar;
41 Scalar, T_mat, T_row_index, T_col_index>;
// Re-export the base's variable name and its assignment operators.
42 using base::var_name_;
43 using base::operator=;
44
// Constructor. Forwards the three operands to the base class, then checks
// with check_size_match that the row-index and column-index expressions
// agree in their number of rows and in their number of columns.
// NOTE(review): the condition guarding each check (listing lines 55-56 and
// 60-61) is missing from this extract; only the closing braces survive.
51 indexing_(T_mat&& mat, T_row_index&& row_index, T_col_index&& col_index)
52 : base(std::forward<T_mat>(mat), std::forward<T_row_index>(row_index),
53 std::forward<T_col_index>(col_index)) {
54 const char* function = "indexing";
57 check_size_match(function, "Rows of ", "col_index", col_index.rows(),
58 "rows of ", "row_index", row_index.rows());
59 }
62 check_size_match(function, "Columns of ", "col_index", col_index.cols(),
63 "columns of ", "row_index", row_index.cols());
64 }
65 }
66
// Creates a deep copy of this expression. Each of the three operands is
// deep-copied, and the copies are moved into a new indexing_ whose template
// arguments are the (non-reference) types of those copies.
71 inline auto deep_copy() const {
72 auto&& mat_copy = this->template get_arg<0>().deep_copy();
73 auto&& row_index_copy = this->template get_arg<1>().deep_copy();
74 auto&& col_index_copy = this->template get_arg<2>().deep_copy();
75 return indexing_<std::remove_reference_t<decltype(mat_copy)>,
76 std::remove_reference_t<decltype(row_index_copy)>,
77 std::remove_reference_t<decltype(col_index_copy)>>{
78 std::move(mat_copy), std::move(row_index_copy),
79 std::move(col_index_copy)};
80 }
81
// get_kernel_parts: generates kernel code for this and nested expressions.
// Code is produced only the first time this node is seen in `generated`;
// later calls return an empty (value-initialized) kernel_parts.
// The index expressions are generated with the caller's row/column index
// names. The matrix is then generated with the index expressions' variable
// names as its row/column indices, so it is accessed at the indexed positions.
// The result concatenates row-index, column-index, then matrix parts, and this
// node takes over the matrix's variable name.
// NOTE(review): the signature head (definition at listing line 94) and the
// declaration heads of parts_row_idx / parts_col_idx (listing lines 107 and
// 110) are missing from this extract.
95 std::unordered_map<const void*, const char*>& generated,
96 std::unordered_map<const void*, const char*>& generated_all,
97 name_generator& name_gen, const std::string& row_index_name,
98 const std::string& col_index_name, bool view_handled) const {
99 kernel_parts res{};
100 if (generated.count(this) == 0) {
101 generated[this] = "";
102
103 const auto& mat = this->template get_arg<0>();
104 const auto& row_index = this->template get_arg<1>();
105 const auto& col_index = this->template get_arg<2>();
106
108 generated, generated_all, name_gen, row_index_name, col_index_name,
109 view_handled);
111 generated, generated_all, name_gen, row_index_name, col_index_name,
112 view_handled);
// The matrix gets a fresh `generated` map and view_handled = false.
// Presumably this ensures its code is emitted with the index names even if
// the same matrix was already generated elsewhere in this kernel. TODO confirm.
113 std::unordered_map<const void*, const char*> generated2;
114 kernel_parts parts_mat = mat.get_kernel_parts(
115 generated2, generated_all, name_gen, row_index.var_name_,
116 col_index.var_name_, false);
117
118 res = parts_row_idx + parts_col_idx + parts_mat;
119 var_name_ = mat.var_name_;
120 }
121 return res;
122 }
123
// get_kernel_parts_lhs: generates kernel code for this expression when it
// appears on the left hand side of an assignment. Unlike get_kernel_parts, the
// generation below runs on every call; `generated` is only used to record
// this node. Index expressions are generated as ordinary (right hand side)
// expressions with view_handled = false, and the matrix via its own
// get_kernel_parts_lhs with a fresh `generated` map.
// NOTE(review): the signature head (definition at listing line 136) and the
// trailing arguments of the get_kernel_parts_lhs call (listing line 157) are
// missing from this extract.
137 std::unordered_map<const void*, const char*>& generated,
138 std::unordered_map<const void*, const char*>& generated_all,
139 name_generator& name_gen, const std::string& row_index_name,
140 const std::string& col_index_name) const {
141 if (generated.count(this) == 0) {
142 generated[this] = "";
143 }
144 const auto& mat = this->template get_arg<0>();
145 const auto& row_index = this->template get_arg<1>();
146 const auto& col_index = this->template get_arg<2>();
147
148 kernel_parts parts_row_idx
149 = row_index.get_kernel_parts(generated, generated_all, name_gen,
150 row_index_name, col_index_name, false);
151 kernel_parts parts_col_idx
152 = col_index.get_kernel_parts(generated, generated_all, name_gen,
153 row_index_name, col_index_name, false);
154 std::unordered_map<const void*, const char*> generated2;
155 kernel_parts parts_mat
156 = mat.get_kernel_parts_lhs(generated2, generated_all, name_gen,
158
159 kernel_parts res = parts_row_idx + parts_col_idx + parts_mat;
160 var_name_ = mat.var_name_;
161 return res;
162 }
163
// Sets kernel arguments for this expression. Operands are visited in the same
// order used by get_kernel_parts: row index, column index, then the matrix
// (with its own fresh `generated` map, as during code generation). Does
// nothing if this node is already recorded in `generated`.
174 inline void set_args(
175 std::unordered_map<const void*, const char*>& generated,
176 std::unordered_map<const void*, const char*>& generated_all,
177 cl::Kernel& kernel, int& arg_num) const {
178 if (generated.count(this) == 0) {
179 generated[this] = "";
180 this->template get_arg<1>().set_args(generated, generated_all, kernel,
181 arg_num);
182 this->template get_arg<2>().set_args(generated, generated_all, kernel,
183 arg_num);
184 std::unordered_map<const void*, const char*> generated2;
185 this->template get_arg<0>().set_args(generated2, generated_all, kernel,
186 arg_num);
187 }
188 }
189
// Number of rows of the result: the larger of the two index expressions' row
// counts. NOTE(review): presumably max is used because one index may report a
// placeholder (e.g. dynamic) size. TODO confirm.
195 inline int rows() const {
196 return std::max(this->template get_arg<1>().rows(),
197 this->template get_arg<2>().rows());
198 }
199
// Number of columns of the result: the larger of the two index expressions'
// column counts (see the note on rows()).
205 inline int cols() const {
206 return std::max(this->template get_arg<1>().cols(),
207 this->template get_arg<2>().cols());
208 }
209
// Sets the view of the underlying matrix. The requested diagonals are ignored.
// The written positions depend on the index values at run time, so the indexed
// matrix is always given the widest possible view (int min .. int max for
// every bound).
224 inline void set_view(int bottom_diagonal, int top_diagonal,
225 int bottom_zero_diagonal, int top_zero_diagonal) const {
226 this->template get_arg<0>().set_view(
227 std::numeric_limits<int>::min(), std::numeric_limits<int>::max(),
228 std::numeric_limits<int>::min(), std::numeric_limits<int>::max());
229 }
230
// Extreme sub- and superdiagonals written: reported as unbounded in both
// directions, since any element may be written through the indices.
235 inline std::pair<int, int> extreme_diagonals() const {
236 return {std::numeric_limits<int>::min(), std::numeric_limits<int>::max()};
237 }
238
// Checks, via check_size_match, that the dimensions of the expression being
// assigned (rows, cols) equal this indexing's rows() and cols().
246 inline void check_assign_dimensions(int rows, int cols) const {
247 check_size_match("indexing_.check_assign_dimensions", "Rows of ",
248 "indexing", this->rows(), "rows of ", "expression", rows);
249 check_size_match("indexing_.check_assign_dimensions", "Columns of ",
250 "indexing", this->cols(), "columns of ", "expression",
251 cols);
252 }
253
// Adds `e` as a read event on both index expressions (they are only read) and
// as a write event on the indexed matrix.
258 inline void add_write_event(cl::Event& e) const {
259 this->template get_arg<1>().add_read_event(e);
260 this->template get_arg<2>().add_read_event(e);
261 this->template get_arg<0>().add_write_event(e);
262 }
263
// get_clear_read_write_events: appends to `events` the read and write events
// of the indexed matrix (cleared from it, per its documentation) and the write
// events of both index expressions.
// NOTE(review): the signature head (definition at listing line 269) is
// missing from this extract.
270 std::vector<cl::Event>& events) const {
271 this->template get_arg<0>().get_clear_read_write_events(events);
272 this->template get_arg<1>().get_write_events(events);
273 this->template get_arg<2>().get_write_events(events);
274 }
275};
276
// indexing(mat, row_index, col_index): index a kernel generator expression
// using two expressions for indices. Enabled only if all three arguments are
// valid kernel generator expressions. The matrix argument is converted with
// as_operation_cl and deep-copied. The index arguments are converted with
// as_operation_cl and forwarded into the resulting indexing_ object.
// NOTE(review): presumably the deep copy lets the returned indexing_ own its
// matrix operand independently of the caller's expression. Confirm.
// NOTE(review): listing lines 308-309 are missing from this extract. They are
// presumably the remaining template arguments (e.g.
// as_operation_cl_t<T_row_index>, as_operation_cl_t<T_col_index>) and the
// opening of the constructor call.
301template <typename T_mat, typename T_row_index, typename T_col_index,
302 require_all_kernel_expressions_t<T_mat, T_row_index,
303 T_col_index>* = nullptr>
304inline auto indexing(T_mat&& mat, T_row_index&& row_index,
305 T_col_index&& col_index) {
306 auto&& mat_operation = as_operation_cl(std::forward<T_mat>(mat)).deep_copy();
307 return indexing_<std::remove_reference_t<decltype(mat_operation)>,
310 std::move(mat_operation),
311 as_operation_cl(std::forward<T_row_index>(row_index)),
312 as_operation_cl(std::forward<T_col_index>(col_index)));
313}
314
316} // namespace math
317} // namespace stan
318
319#endif
320#endif
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
Definition index.hpp:129
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
Definition index.hpp:122
Represents operation that determines column index.
Definition index.hpp:80
void get_clear_read_write_events(std::vector< cl::Event > &events) const
Adds all read and write events on any matrices used by nested expressions to a list and clears them from those matrices.
Definition indexing.hpp:269
typename std::remove_reference_t< T_mat >::Scalar Scalar
Definition indexing.hpp:39
auto deep_copy() const
Creates a deep copy of this expression.
Definition indexing.hpp:71
void add_write_event(cl::Event &e) const
Adds write event to indexed matrix and read event to indices.
Definition indexing.hpp:258
indexing_(T_mat &&mat, T_row_index &&row_index, T_col_index &&col_index)
Constructor.
Definition indexing.hpp:51
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
Definition indexing.hpp:205
std::pair< int, int > extreme_diagonals() const
Determine indices of extreme sub- and superdiagonals written.
Definition indexing.hpp:235
void check_assign_dimensions(int rows, int cols) const
Checks if desired dimensions match dimensions of the indexing.
Definition indexing.hpp:246
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
Definition indexing.hpp:195
kernel_parts get_kernel_parts_lhs(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name) const
Generates kernel code for this expression if it appears on the left hand side of an assignment.
Definition indexing.hpp:136
void set_view(int bottom_diagonal, int top_diagonal, int bottom_zero_diagonal, int top_zero_diagonal) const
Sets the view of the underlying matrix depending on which of its parts are written to.
Definition indexing.hpp:224
kernel_parts get_kernel_parts(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name, bool view_handled) const
Generates kernel code for this and nested expressions.
Definition indexing.hpp:94
void set_args(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, cl::Kernel &kernel, int &arg_num) const
Sets kernel arguments for this expression.
Definition indexing.hpp:174
Represents indexing of a matrix with two matrices of indices.
Definition indexing.hpp:32
Unique name generator for variables used in generated kernels.
Base for all kernel generator operations that can be used on left hand side of an expression.
kernel_parts get_kernel_parts(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name, bool view_handled) const
Generates kernel code for this and nested expressions.
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
Definition index.hpp:59
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
Definition index.hpp:66
Represents operation that determines row index.
Definition index.hpp:17
require_all_t< is_kernel_expression< Types >... > require_all_kernel_expressions_t
Enables a template if all given types are valid kernel generator expressions.
auto indexing(T_mat &&mat, T_row_index &&row_index, T_col_index &&col_index)
Index a kernel generator expression using two expressions for indices.
Definition indexing.hpp:304
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
std::conditional_t< std::is_lvalue_reference< T >::value, decltype(as_operation_cl< AssignOp >(std::declval< T >())), std::remove_reference_t< decltype(as_operation_cl< AssignOp >(std::declval< T >()))> > as_operation_cl_t
Type that results when converting any valid kernel generator expression into operation.
static constexpr double e()
Return the base of the natural logarithm.
Definition constants.hpp:20
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
STL namespace.
Parts of an OpenCL kernel, generated by an expression.