1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_BLOCK_ZERO_BASED_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_BLOCK_ZERO_BASED_HPP
33 typename std::remove_reference_t<T>::Scalar, T> {
35 using Scalar =
typename std::remove_reference_t<T>::Scalar;
39 using base::operator=;
61 " should be non-negative, but is ");
65 " should be non-negative, but is ");
69 " should be non-negative, but is ");
73 " should be non-negative, but is ");
86 auto&& arg_copy = this->
template get_arg<0>().deep_copy();
87 return block_<std::remove_reference_t<
decltype(arg_copy)>>{
104 std::unordered_map<const void*, const char*>& generated,
105 std::unordered_map<const void*, const char*>& generated_all,
107 const std::string& col_index_name,
bool view_handled)
const {
109 if (generated.count(
this) == 0) {
111 generated[
this] =
"";
112 std::string row_index_name_arg = row_index_name;
113 std::string col_index_name_arg = col_index_name;
115 std::unordered_map<const void*, const char*> generated2;
116 res = this->
template get_arg<0>().get_kernel_parts(
117 generated2, generated_all, name_gen, row_index_name_arg,
118 col_index_name_arg, view_handled);
120 = this->
generate(row_index_name, col_index_name, view_handled,
122 if (generated_all.count(
this) == 0) {
123 generated_all[
this] =
"";
128 res.
body = res.body_prefix + res.body;
129 res.body_prefix =
"";
144 const std::string& col_index_name,
145 const bool view_handled,
146 const std::string& var_name_arg)
const {
149 = type_str<Scalar>() +
" " +
var_name_ +
" = " + var_name_arg +
";\n";
160 std::string& col_index_name)
const {
161 row_index_name =
"(" + row_index_name +
" + " +
var_name_ +
"_i)";
162 col_index_name =
"(" + col_index_name +
" + " +
var_name_ +
"_j)";
178 std::unordered_map<const void*, const char*>& generated,
179 std::unordered_map<const void*, const char*>& generated_all,
181 const std::string& col_index_name)
const {
182 if (generated.count(
this) == 0) {
183 generated[
this] =
"";
186 std::string row_index_name_arg = row_index_name;
187 std::string col_index_name_arg = col_index_name;
189 std::unordered_map<const void*, const char*> generated2;
190 kernel_parts res = this->
template get_arg<0>().get_kernel_parts_lhs(
191 generated2, generated_all, name_gen, row_index_name_arg,
193 res += this->
derived().generate_lhs(row_index_name, col_index_name,
195 if (generated_all.count(
this) == 0) {
196 generated_all[
this] =
"";
213 const std::string& var_name_arg)
const {
230 std::unordered_map<const void*, const char*>& generated,
231 std::unordered_map<const void*, const char*>& generated_all,
232 cl::Kernel& kernel,
int& arg_num)
const {
233 if (generated.count(
this) == 0) {
234 generated[
this] =
"";
235 std::unordered_map<const void*, const char*> generated2;
236 this->
template get_arg<0>().set_args(generated2, generated_all, kernel,
238 if (generated_all.count(
this) == 0) {
239 generated_all[
this] =
"";
273 inline void set_view(
int bottom_diagonal,
int top_diagonal,
274 int bottom_zero_diagonal,
int top_zero_diagonal)
const {
276 auto& a = this->
template get_arg<0>();
278 bottom_diagonal + change, top_diagonal + change,
281 ? bottom_zero_diagonal
296 std::pair<int, int> arg_diags
297 = this->
template get_arg<0>().extreme_diagonals();
313 cols_,
"columns of ",
"expression",
cols);
343 return block_<std::remove_reference_t<
decltype(a_operation)>>(
344 std::move(a_operation), start_row, start_col,
rows,
cols);
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
void set_args(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, cl::Kernel &kernel, int &arg_num) const
Sets kernel arguments for this and nested expressions.
kernel_parts generate(const std::string &row_index_name, const std::string &col_index_name, const bool view_handled, const std::string &var_name_arg) const
Generates kernel code for this expression.
typename std::remove_reference_t< T >::Scalar Scalar
kernel_parts get_kernel_parts_lhs(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name) const
Generates kernel code for this expression if it appears on the left hand side of an assignment.
kernel_parts generate_lhs(const std::string &i, const std::string &j, const std::string &var_name_arg) const
Generates kernel code for this and nested expressions if this expression appears on the left hand side of an assignment.
void modify_argument_indices(std::string &row_index_name, std::string &col_index_name) const
Sets offset of block to indices of the argument expression.
auto deep_copy() const
Creates a deep copy of this expression.
kernel_parts get_kernel_parts(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name, bool view_handled) const
Generates kernel code for this and nested expressions.
void check_assign_dimensions(int rows, int cols) const
Checks if desired dimensions match dimensions of the block.
void set_view(int bottom_diagonal, int top_diagonal, int bottom_zero_diagonal, int top_zero_diagonal) const
Sets view of the underlying matrix depending on which part is written.
block_(T &&a, int start_row, int start_col, int rows, int cols)
Constructor.
std::pair< int, int > extreme_diagonals() const
Determine indices of extreme sub- and superdiagonals written.
std::tuple< std::true_type > view_transitivity
Represents submatrix block in kernel generator expressions.
std::string generate()
Generates a unique variable name.
Unique name generator for variables used in generated kernels.
Base for all kernel generator operations that can be used on left hand side of an expression.
static constexpr int dynamic
Derived & derived()
Casts the instance into its derived type.
auto block_zero_based(T &&a, int start_row, int start_col, int rows, int cols)
Block of a kernel generator expression.
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
require_all_t< is_kernel_expression_and_not_scalar< Types >... > require_all_kernel_expressions_and_none_scalar_t
Enables a template if all given types are non-scalar types that are valid kernel generator expressions.
int64_t cols(const T_x &x)
Returns the number of columns in the specified kernel generator expression.
int64_t rows(const T_x &x)
Returns the number of rows in the specified kernel generator expression.
void invalid_argument(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw an invalid_argument exception with a consistently formatted message.
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Parts of an OpenCL kernel, generated by an expression.