1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_OPERATION_CL_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_OPERATION_CL_HPP
13#include <CL/opencl.hpp>
79 os <<
"args:" << std::endl;
80 os << parts.
args.substr(0, parts.
args.size() - 2) << std::endl;
81 os <<
"Decl:" << std::endl;
83 os <<
"Init:" << std::endl;
85 os <<
"body:" << std::endl;
86 os << parts.
body << std::endl;
87 os <<
"body_suffix:" << std::endl;
89 os <<
"reduction_1d:" << std::endl;
91 os <<
"reduction_2d:" << std::endl;
102template <
typename Derived,
typename Scalar,
typename... Args>
106 std::remove_reference_t<Args>>...>::value,
107 "operation_cl: all arguments to operation must be operations!");
/**
 * Casts this instance into its derived type.
 * Performs a static_cast of `this` to `Derived*` and dereferences it.
 * @return mutable reference to this object viewed as `Derived`
 */
119 inline Derived&
derived() {
return *
static_cast<Derived*
>(
this); }
126 return *
static_cast<const Derived*
>(
this);
/** Number of types in the `Args` template parameter pack. */
133 static constexpr int N =
sizeof...(Args);
162 const char* function =
"operation_cl.eval()";
165 " must be nonnegative, but is ",
166 " (broadcasted expressions can not be evaluated)");
170 " must be nonnegative, but is ",
171 " (broadcasted expressions can not be evaluated)");
174 if (res.
size() > 0) {
187 template <
typename T_lhs>
197 template <
typename T_lhs>
199 const T_lhs& lhs)
const;
213 template <
typename T_result>
215 std::unordered_map<const void*, const char*>& generated,
216 std::unordered_map<const void*, const char*>& generated_all,
218 const std::string& col_index_name,
const T_result& result)
const {
220 generated, generated_all, ng, row_index_name, col_index_name,
false);
222 generated, generated_all, ng, row_index_name, col_index_name);
223 out_parts.
body += assignment_op<T_result>() +
derived().var_name_ +
";\n";
241 std::unordered_map<const void*, const char*>& generated,
242 std::unordered_map<const void*, const char*>& generated_all,
244 const std::string& col_index_name,
bool view_handled)
const {
246 if (generated.count(
this) == 0) {
247 this->var_name_ = name_gen.
generate();
248 generated[
this] =
"";
249 std::string row_index_name_arg = row_index_name;
250 std::string col_index_name_arg = col_index_name;
251 derived().modify_argument_indices(row_index_name_arg, col_index_name_arg);
252 std::array<kernel_parts, N> args_parts = index_apply<N>([&](
auto... Is) {
253 std::unordered_map<const void*, const char*> generated2;
254 return std::array<kernel_parts, N>{this->get_arg<Is>().get_kernel_parts(
255 &Derived::modify_argument_indices
259 generated_all, name_gen, row_index_name_arg, col_index_name_arg,
261 && std::tuple_element_t<
262 Is,
typename Deriv::view_transitivity>::value)...};
264 res = std::accumulate(args_parts.begin(), args_parts.end(),
267 return this->
derived().generate(row_index_name, col_index_name,
272 res.
body = res.body_prefix + res.body;
273 res.body_prefix =
"";
287 const std::string& col_index_name,
288 const bool view_handled,
289 const std::string& var_name_arg)
const {
303 std::string& col_index_name)
const {}
316 std::unordered_map<const void*, const char*>& generated,
317 std::unordered_map<const void*, const char*>& generated_all,
318 cl::Kernel& kernel,
int& arg_num)
const {
319 if (generated.count(
this) == 0) {
320 generated[
this] =
"";
326 index_apply<N>([&](
auto... Is) {
327 std::unordered_map<const void*, const char*> generated2;
328 static_cast<void>(std::initializer_list<int>{
329 (this->get_arg<Is>().set_args(
330 &Derived::modify_argument_indices
334 generated_all, kernel, arg_num),
345 index_apply<N>([&](
auto... Is) {
346 static_cast<void>(std::initializer_list<int>{
347 (this->get_arg<Is>().add_read_event(
e), 0)...});
356 index_apply<N>([&](
auto... Is) {
357 static_cast<void>(std::initializer_list<int>{
358 (this->
template get_arg<Is>().get_write_events(events), 0)...});
369 N > 0,
"default rows does not work on expressions with no arguments!");
370 return index_apply<N>([&](
auto... Is) {
372 return std::max({this->get_arg<Is>().rows()...});
383 N > 0,
"default cols does not work on expressions with no arguments!");
384 return index_apply<N>([&](
auto... Is) {
386 return std::max({this->get_arg<Is>().cols()...});
418 "default extreme_diagonals does not work on expressions with "
420 return index_apply<N>([&](
auto... Is) {
423 int bottom = std::min(
424 std::initializer_list<int>({std::get<Is>(arg_diags).first...}));
426 std::initializer_list<int>({std::get<Is>(arg_diags).second...}));
427 return std::make_pair(bottom, top);
436 std::pair<int, int> diagonals =
derived().extreme_diagonals();
438 if (diagonals.first < 0) {
443 if (diagonals.second > 0) {
457 std::vector<int>& uids, std::unordered_map<const void*, int>& id_map,
458 int& next_id)
const {
459 index_apply<N>([&](
auto... Is) {
460 static_cast<void>(std::initializer_list<int>{(
461 this->get_arg<Is>().get_unique_matrix_accesses(uids, id_map, next_id),
467template <
typename Derived,
typename Scalar,
typename... Args>
468const bool operation_cl<Derived, Scalar, Args...>::require_specific_local_size
470 std::decay_t<Args>::Deriv::require_specific_local_size...});
Represents an arithmetic matrix on the OpenCL device.
std::string generate()
Generates a unique variable name.
Unique name generator for variables used in generated kernels.
const auto & get_arg() const
Returns an argument to this operation.
matrix_cl_view view() const
View of a matrix that would be the result of evaluating this expression.
static constexpr int dynamic
kernel_parts get_kernel_parts(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name, bool view_handled) const
Generates kernel code for this and nested expressions.
std::pair< int, int > extreme_diagonals() const
Determine indices of extreme sub- and superdiagonals written.
Derived & derived()
Casts the instance into its derived type.
std::tuple< Args... > ArgsTuple
kernel_parts get_whole_kernel_parts(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &ng, const std::string &row_index_name, const std::string &col_index_name, const T_result &result) const
Generates kernel code for assigning this expression into result expression.
operation_cl(Args &&... arguments)
Constructor.
void get_write_events(std::vector< cl::Event > &events) const
Adds all write events on any matrices used by nested expressions to a list.
matrix_cl< Scalar > eval() const
Evaluates the expression.
kernel_parts generate(const std::string &row_index_name, const std::string &col_index_name, const bool view_handled, const std::string &var_name_arg) const
Generates kernel code for this expression.
void add_read_event(cl::Event &e) const
Adds read event to any matrices used by nested expressions.
std::tuple< Args... > arguments_
int thread_rows() const
Number of rows threads need to be launched for.
int size() const
Size of a matrix that would be the result of evaluating this expression.
void get_unique_matrix_accesses(std::vector< int > &uids, std::unordered_map< const void *, int > &id_map, int &next_id) const
Collects data that is needed besides types to uniquely identify a kernel generator expression.
void set_args(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, cl::Kernel &kernel, int &arg_num) const
Sets kernel arguments for nested expressions.
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
void modify_argument_indices(std::string &row_index_name, std::string &col_index_name) const
Does nothing.
const Derived & derived() const
Casts the instance into its derived type.
int thread_cols() const
Number of columns threads need to be launched for.
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
std::tuple< std::is_same< Args, void >... > view_transitivity
Base for all kernel generator operations.
Represents a two dimensional reduction in kernel generator expressions.
Non-templated base of operation_cl is needed for easy checking if something is a subclass of operatio...
std::string get_kernel_source_for_evaluating_into(const T_lhs &lhs) const
Generates kernel source for evaluating this expression into given left-hand-side expression.
void evaluate_into(T_lhs &lhs) const
Evaluates this expression into given left-hand-side expression.
static const bool require_specific_local_size
std::ostream & operator<<(std::ostream &os, kernel_parts &parts)
const matrix_cl_view either(const matrix_cl_view left_view, const matrix_cl_view right_view)
Determines which parts are nonzero in any of the input views.
static constexpr double e()
Return the base of the natural logarithm.
void invalid_argument(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw an invalid_argument exception with a consistently formatted message.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Extends std::true_type when instantiated with zero or more template parameters, all of which extend t...
kernel_parts & operator+=(const kernel_parts &other)
kernel_parts operator+(const kernel_parts &other)
std::string initialization
Parts of an OpenCL kernel, generated by an expression.