1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_indexing_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_indexing_HPP
28template <
typename T_mat,
typename T_row_index,
typename T_col_index>
31 typename std::remove_reference_t<T_mat>::Scalar,
32 T_mat, T_row_index, T_col_index> {
33 static_assert(std::is_integral<value_type_t<T_row_index>>::value,
34 "indexing: Row index scalar type must be an integer!");
35 static_assert(std::is_integral<value_type_t<T_col_index>>::value,
36 "indexing: Column index scalar type must be an integer!");
39 using Scalar =
typename std::remove_reference_t<T_mat>::Scalar;
41 Scalar, T_mat, T_row_index, T_col_index>;
43 using base::operator=;
54 const char* function =
"indexing";
72 auto&& mat_copy = this->
template get_arg<0>().deep_copy();
73 auto&& row_index_copy = this->
template get_arg<1>().deep_copy();
74 auto&& col_index_copy = this->
template get_arg<2>().deep_copy();
75 return indexing_<std::remove_reference_t<
decltype(mat_copy)>,
76 std::remove_reference_t<
decltype(row_index_copy)>,
77 std::remove_reference_t<
decltype(col_index_copy)>>{
78 std::move(mat_copy), std::move(row_index_copy),
79 std::move(col_index_copy)};
95 std::unordered_map<const void*, const char*>& generated,
96 std::unordered_map<const void*, const char*>& generated_all,
98 const std::string& col_index_name,
bool view_handled)
const {
100 if (generated.count(
this) == 0) {
101 generated[
this] =
"";
103 const auto& mat = this->
template get_arg<0>();
104 const auto&
row_index = this->
template get_arg<1>();
105 const auto&
col_index = this->
template get_arg<2>();
108 generated, generated_all, name_gen, row_index_name, col_index_name,
111 generated, generated_all, name_gen, row_index_name, col_index_name,
113 std::unordered_map<const void*, const char*> generated2;
118 res = parts_row_idx + parts_col_idx + parts_mat;
137 std::unordered_map<const void*, const char*>& generated,
138 std::unordered_map<const void*, const char*>& generated_all,
140 const std::string& col_index_name)
const {
141 if (generated.count(
this) == 0) {
142 generated[
this] =
"";
144 const auto& mat = this->
template get_arg<0>();
145 const auto&
row_index = this->
template get_arg<1>();
146 const auto&
col_index = this->
template get_arg<2>();
150 row_index_name, col_index_name,
false);
153 row_index_name, col_index_name,
false);
154 std::unordered_map<const void*, const char*> generated2;
156 = mat.get_kernel_parts_lhs(generated2, generated_all, name_gen,
159 kernel_parts res = parts_row_idx + parts_col_idx + parts_mat;
175 std::unordered_map<const void*, const char*>& generated,
176 std::unordered_map<const void*, const char*>& generated_all,
177 cl::Kernel& kernel,
int& arg_num)
const {
178 if (generated.count(
this) == 0) {
179 generated[
this] =
"";
180 this->
template get_arg<1>().set_args(generated, generated_all, kernel,
182 this->
template get_arg<2>().set_args(generated, generated_all, kernel,
184 std::unordered_map<const void*, const char*> generated2;
185 this->
template get_arg<0>().set_args(generated2, generated_all, kernel,
196 return std::max(this->
template get_arg<1>().
rows(),
197 this->
template get_arg<2>().
rows());
206 return std::max(this->
template get_arg<1>().
cols(),
207 this->
template get_arg<2>().
cols());
224 inline void set_view(
int bottom_diagonal,
int top_diagonal,
225 int bottom_zero_diagonal,
int top_zero_diagonal)
const {
226 this->
template get_arg<0>().set_view(
227 std::numeric_limits<int>::min(), std::numeric_limits<int>::max(),
228 std::numeric_limits<int>::min(), std::numeric_limits<int>::max());
236 return {std::numeric_limits<int>::min(), std::numeric_limits<int>::max()};
248 "indexing", this->
rows(),
"rows of ",
"expression", rows);
250 "indexing", this->
cols(),
"columns of ",
"expression",
259 this->
template get_arg<1>().add_read_event(
e);
260 this->
template get_arg<2>().add_read_event(
e);
261 this->
template get_arg<0>().add_write_event(
e);
270 std::vector<cl::Event>& events)
const {
271 this->
template get_arg<0>().get_clear_read_write_events(events);
272 this->
template get_arg<1>().get_write_events(events);
273 this->
template get_arg<2>().get_write_events(events);
301template <
typename T_mat,
typename T_row_index,
typename T_col_index,
303 T_col_index>* =
nullptr>
306 auto&& mat_operation =
as_operation_cl(std::forward<T_mat>(mat)).deep_copy();
307 return indexing_<std::remove_reference_t<
decltype(mat_operation)>,
310 std::move(mat_operation),
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
Represents operation that determines column index.
void get_clear_read_write_events(std::vector< cl::Event > &events) const
Adds all read and write events on any matrices used by nested expressions to a list and clears them f...
typename std::remove_reference_t< T_mat >::Scalar Scalar
auto deep_copy() const
Creates a deep copy of this expression.
void add_write_event(cl::Event &e) const
Adds write event to indexed matrix and read event to indices.
indexing_(T_mat &&mat, T_row_index &&row_index, T_col_index &&col_index)
Constructor.
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
std::pair< int, int > extreme_diagonals() const
Determine indices of extreme sub- and superdiagonals written.
void check_assign_dimensions(int rows, int cols) const
Checks if desired dimensions match dimensions of the indexing.
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
kernel_parts get_kernel_parts_lhs(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name) const
Generates kernel code for this expression if it appears on the left hand side of an assignment.
void set_view(int bottom_diagonal, int top_diagonal, int bottom_zero_diagonal, int top_zero_diagonal) const
Sets the view of the underlying matrix depending on which of its parts are written to.
kernel_parts get_kernel_parts(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name, bool view_handled) const
Generates kernel code for this and nested expressions.
void set_args(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, cl::Kernel &kernel, int &arg_num) const
Sets kernel arguments for this expression.
Represents indexing of a matrix with two matrices of indices.
Unique name generator for variables used in generated kernels.
Base for all kernel generator operations that can be used on left hand side of an expression.
static constexpr int dynamic
kernel_parts get_kernel_parts(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name, bool view_handled) const
Generates kernel code for this and nested expressions.
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
Represents operation that determines row index.
require_all_t< is_kernel_expression< Types >... > require_all_kernel_expressions_t
Enables a template if all given types are valid kernel generator expressions.
auto indexing(T_mat &&mat, T_row_index &&row_index, T_col_index &&col_index)
Index a kernel generator expression using two expressions for indices.
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
std::conditional_t< std::is_lvalue_reference< T >::value, decltype(as_operation_cl< AssignOp >(std::declval< T >())), std::remove_reference_t< decltype(as_operation_cl< AssignOp >(std::declval< T >()))> > as_operation_cl_t
Type that results when converting any valid kernel generator expression into operation.
static constexpr double e()
Return the base of the natural logarithm.
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Parts of an OpenCL kernel, generated by an expression.