1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_APPEND_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_APPEND_HPP
33template <
typename T_a,
typename T_b>
35 common_scalar_t<T_a, T_b>, T_a, T_b> {
51 :
base(
std::forward<T_a>(a),
std::forward<T_b>(b)) {
54 "columns of ",
"b", b.cols());
58 "should be nonnegative!");
62 "should be nonnegative!");
71 auto&& a_copy = this->
template get_arg<0>().deep_copy();
72 auto&& b_copy = this->
template get_arg<1>().deep_copy();
73 return append_row_<std::remove_reference_t<
decltype(a_copy)>,
74 std::remove_reference_t<
decltype(b_copy)>>{
75 std::move(a_copy), std::move(b_copy)};
91 std::unordered_map<const void*, const char*>& generated,
92 std::unordered_map<const void*, const char*>& generated_all,
94 const std::string& col_index_name,
bool view_handled)
const {
96 if (generated.count(
this) == 0) {
99 kernel_parts parts_a = this->
template get_arg<0>().get_kernel_parts(
100 generated, generated_all, name_gen, row_index_name, col_index_name,
102 std::string row_index_name_b
103 =
"(" + row_index_name +
" - " +
var_name_ +
"_first_rows)";
104 std::unordered_map<const void*, const char*> generated_b;
105 kernel_parts parts_b = this->
template get_arg<1>().get_kernel_parts(
106 generated_b, generated_all, name_gen, row_index_name_b,
107 col_index_name,
true);
108 res = parts_a + parts_b;
110 "if("+ row_index_name +
" < " +
var_name_ +
"_first_rows){\n"
112 var_name_ +
" = " + this->
template get_arg<0>().var_name_ +
";\n"
115 var_name_ +
" = " + this->
template get_arg<1>().var_name_ +
";\n"
117 res.args +=
"int " +
var_name_ +
"_first_rows, ";
133 std::unordered_map<const void*, const char*>& generated,
134 std::unordered_map<const void*, const char*>& generated_all,
135 cl::Kernel& kernel,
int& arg_num)
const {
136 if (generated.count(
this) == 0) {
137 generated[
this] =
"";
138 this->
template get_arg<0>().set_args(generated, generated_all, kernel,
140 std::unordered_map<const void*, const char*> generated_b;
141 this->
template get_arg<1>().set_args(generated_b, generated_all, kernel,
143 kernel.setArg(arg_num++, this->
template get_arg<0>().
rows());
153 return this->
template get_arg<0>().rows()
154 + this->
template get_arg<1>().rows();
162 std::pair<int, int> a_diags
163 = this->
template get_arg<0>().extreme_diagonals();
164 std::pair<int, int> b_diags
165 = this->
template get_arg<1>().extreme_diagonals();
166 int my_rows = this->
template get_arg<0>().rows();
167 return {std::min(a_diags.first, b_diags.first - my_rows),
168 std::max(a_diags.second, b_diags.second - my_rows)};
181template <
typename Ta,
typename Tb,
186 return append_row_<std::remove_reference_t<
decltype(a_operation)>,
187 std::remove_reference_t<
decltype(b_operation)>>(
188 std::move(a_operation), std::move(b_operation));
196template <
typename T_a,
typename T_b>
198 common_scalar_t<T_a, T_b>, T_a, T_b> {
214 :
base(
std::forward<T_a>(a),
std::forward<T_b>(b)) {
221 "should be nonnegative!");
225 "should be nonnegative!");
234 auto&& a_copy = this->
template get_arg<0>().deep_copy();
235 auto&& b_copy = this->
template get_arg<1>().deep_copy();
236 return append_col_<std::remove_reference_t<
decltype(a_copy)>,
237 std::remove_reference_t<
decltype(b_copy)>>{
238 std::move(a_copy), std::move(b_copy)};
254 std::unordered_map<const void*, const char*>& generated,
255 std::unordered_map<const void*, const char*>& generated_all,
257 const std::string& col_index_name,
bool view_handled)
const {
259 if (generated.count(
this) == 0) {
261 generated[
this] =
"";
262 kernel_parts parts_a = this->
template get_arg<0>().get_kernel_parts(
263 generated, generated_all, name_gen, row_index_name, col_index_name,
265 std::string col_index_name_b
266 =
"(" + col_index_name +
" - " +
var_name_ +
"_first_cols)";
267 std::unordered_map<const void*, const char*> generated_b;
268 kernel_parts parts_b = this->
template get_arg<1>().get_kernel_parts(
269 generated_b, generated_all, name_gen, row_index_name,
270 col_index_name_b,
true);
271 res = parts_a + parts_b;
273 "if("+ col_index_name +
" < " +
var_name_ +
"_first_cols){\n"
275 var_name_ +
" = " + this->
template get_arg<0>().var_name_ +
";\n"
278 var_name_ +
" = " + this->
template get_arg<1>().var_name_ +
";\n"
280 res.args +=
"int " +
var_name_ +
"_first_cols, ";
296 std::unordered_map<const void*, const char*>& generated,
297 std::unordered_map<const void*, const char*>& generated_all,
298 cl::Kernel& kernel,
int& arg_num)
const {
299 if (generated.count(
this) == 0) {
300 generated[
this] =
"";
301 this->
template get_arg<0>().set_args(generated, generated_all, kernel,
303 std::unordered_map<const void*, const char*> generated_b;
304 this->
template get_arg<1>().set_args(generated_b, generated_all, kernel,
306 kernel.setArg(arg_num++, this->
template get_arg<0>().
cols());
316 return this->
template get_arg<0>().cols()
317 + this->
template get_arg<1>().cols();
325 std::pair<int, int> a_diags
326 = this->
template get_arg<0>().extreme_diagonals();
327 std::pair<int, int> b_diags
328 = this->
template get_arg<1>().extreme_diagonals();
329 int my_cols = this->
template get_arg<0>().cols();
330 return {std::min(a_diags.first, b_diags.first + my_cols),
331 std::max(a_diags.second, b_diags.second + my_cols)};
344template <
typename Ta,
typename Tb,
349 return append_col_<std::remove_reference_t<
decltype(a_operation)>,
350 std::remove_reference_t<
decltype(b_operation)>>(
351 std::move(a_operation), std::move(b_operation));
Represents appending of cols in kernel generator expressions.
Represents appending of rows in kernel generator expressions.
std::string generate()
Generates a unique variable name.
Unique name generator for variables used in generated kernels.
static constexpr int dynamic
std::tuple< Args... > arguments_
Base for all kernel generator operations.
void set_args(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, cl::Kernel &kernel, int &arg_num) const
Sets kernel arguments for this and nested expressions.
kernel_parts get_kernel_parts(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name, bool view_handled) const
Generates kernel code for this and nested expressions.
std::pair< int, int > extreme_diagonals() const
Determine indices of extreme sub- and superdiagonals written.
std::pair< int, int > extreme_diagonals() const
Determine indices of extreme sub- and superdiagonals written.
auto deep_copy() const
Creates a deep copy of this expression.
auto deep_copy() const
Creates a deep copy of this expression.
auto append_col(Ta &&a, Tb &&b)
Stack the cols of the arguments.
common_scalar_t< T_a, T_b > Scalar
append_col_(T_a &&a, T_b &&b)
Constructor.
kernel_parts get_kernel_parts(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name, bool view_handled) const
Generates kernel code for this and nested expressions.
append_row_(T_a &&a, T_b &&b)
Constructor.
void set_args(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, cl::Kernel &kernel, int &arg_num) const
Sets kernel arguments for this and nested expressions.
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
common_scalar_t< T_a, T_b > Scalar
auto append_row(Ta &&a, Tb &&b)
Stack the rows of the first argument on top of the second argument.
require_all_t< is_kernel_expression_and_not_scalar< Types >... > require_all_kernel_expressions_and_none_scalar_t
Enables a template if all given types are non-scalar types that are valid kernel generator expressions.
typename std::common_type_t< typename std::remove_reference_t< Types >::Scalar... > common_scalar_t
Wrapper for std::common_type_t
void invalid_argument(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw an invalid_argument exception with a consistently formatted message.
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Parts of an OpenCL kernel, generated by an expression.