1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_MULTI_RESULT_KERNEL_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_MULTI_RESULT_KERNEL_HPP
33template <
int N,
typename... T_results>
35 template <
typename... T_expressions>
39 N - 1, T_results...>::template
inner<T_expressions...>;
41 std::tuple_element_t<N, std::tuple<T_results...>>>;
43 std::tuple_element_t<N, std::tuple<T_expressions...>>>;
51 std::vector<cl::Event>& events,
52 const std::tuple<std::pair<T_results, T_expressions>...>&
54 next::get_clear_events(events, assignment_pairs);
55 std::get<N>(assignment_pairs).second.get_write_events(events);
56 std::get<N>(assignment_pairs).first.get_clear_read_write_events(events);
69 int n_rows,
int n_cols,
70 const std::tuple<std::pair<T_results, T_expressions>...>&
72 next::check_assign_dimensions(n_rows, n_cols, assignment_pairs);
73 const auto& expression = std::get<N>(assignment_pairs).second;
74 const auto& result = std::get<N>(assignment_pairs).first;
75 constexpr const char* function =
"results.operator=";
77 int expression_rows = expression.rows();
78 int expression_cols = expression.cols();
80 && expression_cols == -1) {
81 expression_cols = n_cols;
82 expression_rows = expression.thread_rows();
84 expression_rows < 0 ? n_rows : expression_rows, expression_cols);
86 && expression_cols == -1) {
88 if (expression_rows == 0) {
91 expression_cols = (n_cols + expression_rows - 1) / expression_rows;
94 if (expression.thread_rows() != -1) {
96 expression.thread_rows(),
"rows of ",
97 "first expression", n_rows);
99 expression_rows = n_rows;
101 if (expression.thread_cols() != -1) {
103 expression.thread_cols(),
"columns of ",
104 "first expression", n_cols);
106 expression_cols = n_cols;
109 result.check_assign_dimensions(expression_rows, expression_cols);
110 int bottom_written = 1 - expression.rows();
111 int top_written = expression.cols() - 1;
112 std::pair<int, int> extreme_diagonals = expression.extreme_diagonals();
113 result.set_view(std::max(extreme_diagonals.first, bottom_written),
114 std::min(extreme_diagonals.second, top_written),
115 bottom_written, top_written);
131 std::unordered_map<const void*, const char*>& generated,
132 std::unordered_map<const void*, const char*>& generated_all,
134 const std::string& col_index_name,
135 const std::tuple<std::pair<T_results, T_expressions>...>&
138 = next::generate(generated, generated_all, ng, row_index_name,
139 col_index_name, assignment_pairs);
142 = std::get<N>(assignment_pairs)
143 .second.get_whole_kernel_parts(
144 generated, generated_all, ng, row_index_name,
145 col_index_name, std::get<N>(assignment_pairs).first);
161 std::unordered_map<const void*, const char*>& generated,
162 std::unordered_map<const void*, const char*>& generated_all,
163 cl::Kernel& kernel,
int& arg_num,
164 const std::tuple<std::pair<T_results, T_expressions>...>&
166 next::set_args(generated, generated_all, kernel, arg_num,
169 std::get<N>(assignment_pairs)
170 .second.set_args(generated, generated_all, kernel, arg_num);
171 std::get<N>(assignment_pairs)
172 .first.set_args(generated, generated_all, kernel, arg_num);
181 cl::Event
e,
const std::tuple<std::pair<T_results, T_expressions>...>&
183 next::add_event(
e, assignment_pairs);
184 std::get<N>(assignment_pairs).second.add_read_event(
e);
185 std::get<N>(assignment_pairs).first.add_write_event(
e);
196 std::vector<int>& uids, std::unordered_map<const void*, int>& id_map,
198 const std::tuple<std::pair<T_results, T_expressions>...>&
200 std::get<N>(assignment_pairs)
201 .second.get_unique_matrix_accesses(uids, id_map, next_id);
202 std::get<N>(assignment_pairs)
203 .first.get_unique_matrix_accesses(uids, id_map, next_id);
204 next::get_unique_matrix_accesses(uids, id_map, next_id, assignment_pairs);
210template <
typename... T_results>
212 template <
typename... T_expressions>
215 std::vector<cl::Event>& events,
216 const std::tuple<std::pair<T_results, T_expressions>...>&
220 int n_rows,
int n_cols,
221 const std::tuple<std::pair<T_results, T_expressions>...>&
225 std::unordered_map<const void*, const char*>& generated,
226 std::unordered_map<const void*, const char*>& generated_all,
228 const std::string& col_index_name,
229 const std::tuple<std::pair<T_results, T_expressions>...>&
235 std::unordered_map<const void*, const char*>& generated,
236 std::unordered_map<const void*, const char*>& generated_all,
237 cl::Kernel& kernel,
int& arg_num,
238 const std::tuple<std::pair<T_results, T_expressions>...>&
242 cl::Event
e,
const std::tuple<std::pair<T_results, T_expressions>...>&
246 std::vector<int>& uids, std::unordered_map<const void*, int>& id_map,
248 const std::tuple<std::pair<T_results, T_expressions>...>&
253template <
int N,
typename... T_results>
254template <
typename... T_expressions>
255std::map<std::vector<int>, cl::Kernel> multi_result_kernel_internal<
256 N, T_results...>::inner<T_expressions...>::kernel_cache_;
264template <
typename... T_expressions>
277 template <
typename...>
279 template <
typename...>
288template <
typename... T_expressions>
298template <
typename... T_results>
318 template <
typename... T_expressions,
319 typename = std::enable_if_t<
sizeof...(T_results)
320 ==
sizeof...(T_expressions)>>
322 index_apply<
sizeof...(T_expressions)>([
this, &exprs](
auto... Is) {
338 typename... T_expressions,
339 typename = std::enable_if_t<
sizeof...(T_results)
340 ==
sizeof...(T_expressions)>>
342 index_apply<
sizeof...(T_expressions)>([
this, &exprs](
auto... Is) {
343 auto tmp = std::tuple_cat(make_assignment_pair<AssignOp>(
345 index_apply<std::tuple_size<
decltype(tmp)>::value>(
346 [
this, &tmp](
auto... Is2) {
348 std::get<Is2>(tmp).first, std::get<Is2>(tmp).second)...));
360 template <
typename... T_expressions,
361 typename = std::enable_if_t<
sizeof...(T_results)
362 ==
sizeof...(T_expressions)>>
364 compound_assignment_impl<assign_op_cl::plus_equals>(exprs);
374 template <
typename... T_expressions,
375 typename = std::enable_if_t<
sizeof...(T_results)
376 ==
sizeof...(T_expressions)>>
378 compound_assignment_impl<assign_op_cl::minus_equals>(exprs);
388 template <
typename... T_expressions,
389 typename = std::enable_if_t<
sizeof...(T_results)
390 ==
sizeof...(T_expressions)>>
392 compound_assignment_impl<assign_op_cl::divide_equals>(exprs);
402 template <
typename... T_expressions,
403 typename = std::enable_if_t<
sizeof...(T_results)
404 ==
sizeof...(T_expressions)>>
406 compound_assignment_impl<assign_op_cl::multiply_equals>(exprs);
416 template <
typename... T_expressions,
417 typename = std::enable_if_t<
sizeof...(T_results)
418 ==
sizeof...(T_expressions)>>
421 return index_apply<
sizeof...(T_expressions)>([
this, &exprs](
auto... Is) {
434 template <
typename... T_res,
typename... T_expressions>
436 const std::tuple<std::pair<T_res, T_expressions>...>& assignment_pairs) {
438 std::tuple_size<std::tuple<T_expressions...>>::value - 1,
439 T_res...>::template inner<T_expressions...>;
440 static constexpr bool require_specific_local_size
442 T_expressions>::Deriv::require_specific_local_size>...>::value;
445 std::unordered_map<const void*, const char*> generated;
446 std::unordered_map<const void*, const char*> generated_all;
447 kernel_parts parts = impl::generate(generated, generated_all, ng,
"i",
"j",
450 if (require_specific_local_size) {
453 "kernel void calculate(" + parts.
args +
454 "const int rows, const int cols){\n"
455 "const int gid_i = get_global_id(0);\n"
456 "const int gid_j = get_global_id(1);\n"
457 "const int lid_i = get_local_id(0);\n"
458 "const int lsize_i = get_local_size(0);\n"
459 "const int gsize_i = get_global_size(0);\n"
460 "const int gsize_j = get_global_size(1);\n"
461 "const int wg_id_i = get_group_id(0);\n"
462 "const int wg_id_j = get_group_id(1);\n"
463 "const int n_groups_i = get_num_groups(0);\n"
465 "for(int j = gid_j; j < cols; j+=gsize_j){\n"
467 "for(int i = gid_i; i < rows; i+=gsize_i){\n"
478 "kernel void calculate(" +
479 parts.
args.substr(0, parts.
args.size() - 2) +
481 "int i = get_global_id(0);\n"
482 "int j = get_global_id(1);\n"
500 template <
typename... T_res,
typename... T_expressions>
502 const std::tuple<std::pair<T_res, T_expressions>...>& assignment_pairs) {
504 std::tuple_size<std::tuple<T_expressions...>>::value - 1,
505 T_res...>::template inner<T_expressions...>;
507 static constexpr bool any_output
514 static constexpr bool require_specific_local_size
516 T_expressions>::Deriv::require_specific_local_size>...>::value;
518 int n_rows = std::get<0>(assignment_pairs).second.thread_rows();
519 int n_cols = std::get<0>(assignment_pairs).second.thread_cols();
520 const char* function =
"results_cl.assignment";
521 impl::check_assign_dimensions(n_rows, n_cols, assignment_pairs);
522 if (n_rows * n_cols == 0) {
527 " must be nonnegative, but is ",
528 " (broadcasted expressions can not be evaluated)");
532 " must be nonnegative, but is ",
533 " (broadcasted expressions can not be evaluated)");
536 std::vector<int> uids;
537 std::unordered_map<const void*, int> id_map;
539 impl::get_unique_matrix_accesses(uids, id_map, next_id, assignment_pairs);
542 if (impl::kernel_cache_[uids]() == NULL) {
545 opts[
"LOCAL_SIZE_"] = std::min(64, opts.at(
"LOCAL_SIZE_"));
547 "calculate", {view_kernel_helpers, src}, opts);
550 cl::Kernel& kernel = impl::kernel_cache_[uids];
553 std::unordered_map<const void*, const char*> generated;
554 std::unordered_map<const void*, const char*> generated_all;
555 impl::set_args(generated, generated_all, kernel, arg_num,
558 std::vector<cl::Event> events;
559 impl::get_clear_events(events, assignment_pairs);
561 if (require_specific_local_size) {
562 kernel.setArg(arg_num++, n_rows);
563 kernel.setArg(arg_num++, n_cols);
568 int wgs_cols = (n_cols + wgs_rows - 1) / wgs_rows;
571 kernel, cl::NullRange, cl::NDRange(local * wgs_rows, wgs_cols),
572 cl::NDRange(local, 1), &events, &
e);
575 cl::NDRange(n_rows, n_cols),
576 cl::NullRange, &events, &
e);
578 impl::add_event(
e, assignment_pairs);
579 }
catch (
const cl::Error&
e) {
603 typename T_expression,
606 std::is_arithmetic<std::decay_t<
607 T_expression>>>>* =
nullptr>
609 T_expression&& expression) {
612 as_operation_cl<AssignOp>(std::forward<T_result>(result)),
629 typename T_expression,
632 T_expression&& expression) {
633 return std::make_tuple();
654 s << result.function_ <<
": " << result.err_variable_ <<
" = "
655 << result.arg_.a_ <<
", but it must be " << result.must_be_ <<
"!";
656 throw std::domain_error(s.str());
658 return std::make_tuple();
667template <
typename... T_results>
std::tuple< T_expressions... > expressions_
friend class adjoint_results_cl
expressions_cl(T_expressions &&... expressions)
Constructor.
Represents multiple expressions that will be calculated in same kernel.
Unique name generator for variables used in generated kernels.
void register_kernel_cache(cl::Kernel *cache)
Registers a cached kernel.
The API to access the methods and values in opencl_context_base.
void operator+=(const expressions_cl< T_expressions... > &exprs)
Incrementing results_ object by expressions_cl object executes the kernel that evaluates expressions ...
void operator*=(const expressions_cl< T_expressions... > &exprs)
Elementwise multiplying results_ object by expressions_cl object executes the kernel that evaluates expr...
void operator=(const expressions_cl< T_expressions... > &exprs)
Assigning expressions_cl object to results_ object generates and executes the kernel that evaluates e...
results_cl(T_results &&... results)
Constructor.
void operator/=(const expressions_cl< T_expressions... > &exprs)
Elementwise dividing results_ object by expressions_cl object executes the kernel that evaluates expres...
static std::tuple make_assignment_pair(T_check &&result, T_pass &&pass)
Checks on scalars are done separately in this overload instead of in kernel.
static void assignment_impl(const std::tuple< std::pair< T_res, T_expressions >... > &assignment_pairs)
Implementation of assignments of expressions to results.
std::string get_kernel_source_for_evaluating(const expressions_cl< T_expressions... > &exprs)
Generates kernel source for evaluating given expressions into results held by this.
static std::string get_kernel_source_impl(const std::tuple< std::pair< T_res, T_expressions >... > &assignment_pairs)
Implementation of kernel source generation.
std::tuple< T_results... > results_
void compound_assignment_impl(const expressions_cl< T_expressions... > &exprs)
Incrementing results_ object by expressions_cl object executes the kernel that evaluates expressions ...
void operator-=(const expressions_cl< T_expressions... > &exprs)
Decrementing results_ object by expressions_cl object executes the kernel that evaluates expressions and...
static auto make_assignment_pair(T_result &&result, T_expression &&expression)
Makes a std::pair of one result and one expression and wraps it into a tuple.
static void assignment_impl(const std::tuple<> &)
Implementation of assignments of no expressions to no results.
Represents results that will be calculated in same kernel.
void check_opencl_error(const char *function, const cl::Error &e)
Throws a domain error specifying the OpenCL error that occurred.
require_t< std::is_integral< std::decay_t< T > > > require_integral_t
Require type satisfies std::is_integral.
auto compile_kernel(const char *name, const std::vector< std::string > &sources, const std::unordered_map< std::string, int > &options)
Compile an OpenCL kernel.
opencl_context_base::map_base_opts & base_opts() noexcept
Returns a reference to the map of kernel defines.
cl::CommandQueue & queue() noexcept
Returns the reference to the active OpenCL command queue for the device.
results_cl< T_results... > results(T_results &&... results)
Deduces types for constructing results_cl object.
void operator=(bool condition)
Assignment of a scalar bool triggers the scalar check.
expressions_cl< T_expressions... > expressions(T_expressions &&... expressions)
Deduces types for constructing expressions_cl object.
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
std::conditional_t< std::is_lvalue_reference< T >::value, decltype(as_operation_cl< AssignOp >(std::declval< T >())), std::remove_reference_t< decltype(as_operation_cl< AssignOp >(std::declval< T >()))> > as_operation_cl_t
Type that results when converting any valid kernel generator expression into operation.
std::integral_constant< bool, B > bool_constant
Alias for a struct that wraps a static constant of type bool.
int colwise_reduction_wgs_rows(int n_rows, int n_cols)
Determine number of work groups in rows direction that will be run for colwise reduction of given siz...
static constexpr double e()
Return the base of the natural logarithm.
void invalid_argument(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw an invalid_argument exception with a consistently formatted message.
constexpr auto index_apply(F &&f)
Calls given callable with an index sequence.
assign_op_cl
Ops that decide the type of assignment for LHS operations.
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
std::enable_if_t<!math::disjunction< Checks... >::value > require_all_not_t
If all conditions are false, template is enabled.
std::enable_if_t< Check::value > require_t
If condition is true, template is enabled.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Extends std::true_type when instantiated with zero or more template parameters, all of which extend t...
Extends std::false_type when instantiated with zero or more template parameters, all of which extend ...
std::remove_reference_t< std::tuple_element_t< N, std::tuple< T_results... > > > T_current_result
typename multi_result_kernel_internal< N - 1, T_results... >::template inner< T_expressions... > next
std::remove_reference_t< std::tuple_element_t< N, std::tuple< T_expressions... > > > T_current_expression
static void get_clear_events(std::vector< cl::Event > &events, const std::tuple< std::pair< T_results, T_expressions >... > &assignment_pairs)
Generates list of all events kernel assigning expressions to results must wait on.
static void check_assign_dimensions(int n_rows, int n_cols, const std::tuple< std::pair< T_results, T_expressions >... > &assignment_pairs)
Assigns the dimensions of expressions to matching results if possible.
static void add_event(cl::Event e, const std::tuple< std::pair< T_results, T_expressions >... > &assignment_pairs)
Adds event to matrices used in kernel.
static std::map< std::vector< int >, cl::Kernel > kernel_cache_
static void get_unique_matrix_accesses(std::vector< int > &uids, std::unordered_map< const void *, int > &id_map, int &next_id, const std::tuple< std::pair< T_results, T_expressions >... > &assignment_pairs)
Collects data that is needed besides types to uniquely identify a kernel.
static kernel_parts generate(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &ng, const std::string &row_index_name, const std::string &col_index_name, const std::tuple< std::pair< T_results, T_expressions >... > &assignment_pairs)
Generates kernel source for assignment of expressions to results.
static void set_args(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, cl::Kernel &kernel, int &arg_num, const std::tuple< std::pair< T_results, T_expressions >... > &assignment_pairs)
Sets kernel arguments.
static kernel_parts generate(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &ng, const std::string &row_index_name, const std::string &col_index_name, const std::tuple< std::pair< T_results, T_expressions >... > &assignment_pairs)
static void get_clear_events(std::vector< cl::Event > &events, const std::tuple< std::pair< T_results, T_expressions >... > &assignment_pairs)
static void set_args(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, cl::Kernel &kernel, int &arg_num, const std::tuple< std::pair< T_results, T_expressions >... > &assignment_pairs)
static void add_event(cl::Event e, const std::tuple< std::pair< T_results, T_expressions >... > &assignment_pairs)
static void get_unique_matrix_accesses(std::vector< int > &uids, std::unordered_map< const void *, int > &id_map, int &next_id, const std::tuple< std::pair< T_results, T_expressions >... > &assignment_pairs)
static void check_assign_dimensions(int n_rows, int n_cols, const std::tuple< std::pair< T_results, T_expressions >... > &assignment_pairs)
std::string initialization
Parts of an OpenCL kernel, generated by an expression.