Automatic Differentiation
 
Loading...
Searching...
No Matches
operation_cl.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_OPERATION_CL_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_OPERATION_CL_HPP
3#ifdef STAN_OPENCL
4
13#include <CL/opencl.hpp>
14#include <algorithm>
15#include <string>
16#include <utility>
17#include <tuple>
18#include <map>
19#include <array>
20#include <numeric>
21#include <vector>
22
23namespace stan {
24namespace math {
25
// kernel_parts (partial listing): parts of an OpenCL kernel, generated by an
// expression. Each std::string member holds one section of kernel source
// text. NOTE(review): this listing omits some source lines (e.g. the struct
// header and parts of operator+ and operator+=); the notes below describe
// only the visible lines.
33 std::string includes; // any function definitions - as if they were included
34 // at the start of kernel source
35 std::string declarations; // declarations of any local variables
36 std::string initialization; // the code for initializations done by all
37 // threads, even if they have no work
38 std::string
39 body_prefix; // the code that should be run at the start of the kernel
40 // body (before the code for arguments of an operation)
41 std::string body; // the body of the kernel - code executing operations
42 std::string body_suffix; // the code that should be run at the end of the
43 // kernel body
44 std::string
45 reduction_1d; // the code for reductions within work group by all
46 // threads, even if they have no work. Run once per column.
47 std::string
48 reduction_2d; // the code for reductions within work group by all
49 // threads, even if they have no work. Run only once.
50 std::string args; // kernel arguments
51
// operator+ (signature line not shown): returns a new kernel_parts built by
// concatenating each field of *this with the matching field of `other`.
// The visible lines show this for includes, body and args. The remaining
// fields are presumably handled the same way on the hidden lines - confirm.
53 return {includes + other.includes,
57 body + other.body,
61 args + other.args};
62 }
63
// operator+= (signature line not shown): appends each field of `other` to
// the matching field of *this and returns *this. The visible lines show this
// for includes, body, body_suffix and args. The other fields are presumably
// appended on the hidden lines - confirm.
65 includes += other.includes;
69 body += other.body;
70 body_suffix += other.body_suffix;
73 args += other.args;
74 return *this;
75 }
76};
77
78inline std::ostream& operator<<(std::ostream& os, kernel_parts& parts) {
79 os << "args:" << std::endl;
80 os << parts.args.substr(0, parts.args.size() - 2) << std::endl;
81 os << "Decl:" << std::endl;
82 os << parts.declarations << std::endl;
83 os << "Init:" << std::endl;
84 os << parts.initialization << std::endl;
85 os << "body:" << std::endl;
86 os << parts.body << std::endl;
87 os << "body_suffix:" << std::endl;
88 os << parts.body_suffix << std::endl;
89 os << "reduction_1d:" << std::endl;
90 os << parts.reduction_1d << std::endl;
91 os << "reduction_2d:" << std::endl;
92 os << parts.reduction_2d << std::endl;
93 return os;
94}
95
// operation_cl: base for all kernel generator operations. It uses CRTP:
// Derived is the concrete operation type (see derived()), Scalar is the
// scalar type of the evaluated result (eval() returns matrix_cl<Scalar>),
// and Args are the types of the argument operations.
// NOTE(review): the class-head line (103) is not shown in this listing.
102template <typename Derived, typename Scalar, typename... Args>
// Compile-time check that every argument type (with references removed)
// derives from operation_cl_base.
104 static_assert(
105 conjunction<std::is_base_of<operation_cl_base,
106 std::remove_reference_t<Args>>...>::value,
107 "operation_cl: all arguments to operation must be operations!");
108
109 protected:
// Argument operations, stored with the exact types in Args. A reference type
// in Args is stored as a reference.
110 std::tuple<Args...> arguments_;
// mutable because it is assigned from const member functions such as
// get_kernel_parts() and generate().
111 mutable std::string var_name_; // name of the variable that holds result of
112 // this operation in the kernel
113
114 public:
119 inline Derived& derived() { return *static_cast<Derived*>(this); }
120
125 inline const Derived& derived() const {
126 return *static_cast<const Derived*>(this);
127 }
128
129 using Deriv = Derived;
130 using ArgsTuple = std::tuple<Args...>;
132 // number of arguments this operation has
133 static constexpr int N = sizeof...(Args);
134 using view_transitivity = std::tuple<std::is_same<Args, void>...>;
135 // value representing a not yet determined size
136 static constexpr int dynamic = -1;
137
142 template <size_t N>
143 const auto& get_arg() const {
144 return std::get<N>(arguments_);
145 }
146
152 explicit operation_cl(Args&&... arguments)
153 : arguments_(std::forward<Args>(arguments)...) {}
154
// eval() (signature not shown; per the index below it is
// `matrix_cl<Scalar> eval() const`): evaluates the expression into a new
// matrix.
160 int rows = derived().rows();
161 int cols = derived().cols();
162 const char* function = "operation_cl.eval()";
// A negative size (e.g. `dynamic`, -1) means the size is not determined, as
// for broadcast expressions. Such expressions cannot be evaluated, so an
// invalid_argument error is raised.
163 if (rows < 0) {
164 invalid_argument(function, "Number of rows of expression", rows,
165 " must be nonnegative, but is ",
166 " (broadcasted expressions can not be evaluated)");
167 }
168 if (cols < 0) {
169 invalid_argument(function, "Number of columns of expression", cols,
170 " must be nonnegative, but is ",
171 " (broadcasted expressions can not be evaluated)");
172 }
// NOTE(review): the declaration of `res` (line 173) is not shown in this
// listing. evaluate_into is skipped when the result is empty.
174 if (res.size() > 0) {
175 this->evaluate_into(res);
176 }
177 return res;
178 }
179
// Evaluates this expression into the given left-hand-side expression
// (definition is not part of this listing).
187 template <typename T_lhs>
188 inline void evaluate_into(T_lhs& lhs) const;
189
// Generates kernel source for evaluating this expression into the given
// left-hand-side expression (definition is not part of this listing).
197 template <typename T_lhs>
198 inline std::string get_kernel_source_for_evaluating_into(
199 const T_lhs& lhs) const;
200
// get_whole_kernel_parts (line 214, with its return type and name, is not
// shown; per the index below it returns kernel_parts): generates kernel code
// for assigning this expression into the `result` expression.
213 template <typename T_result>
215 std::unordered_map<const void*, const char*>& generated,
216 std::unordered_map<const void*, const char*>& generated_all,
217 name_generator& ng, const std::string& row_index_name,
218 const std::string& col_index_name, const T_result& result) const {
// Code computing this expression (view_handled passed as false).
219 kernel_parts parts = derived().get_kernel_parts(
220 generated, generated_all, ng, row_index_name, col_index_name, false);
// Kernel parts for the result (left-hand-side) expression.
221 kernel_parts out_parts = result.get_kernel_parts_lhs(
222 generated, generated_all, ng, row_index_name, col_index_name);
// Complete the assignment in the result's body: append the operator text
// from assignment_op<T_result>() followed by this expression's variable name.
223 out_parts.body += assignment_op<T_result>() + derived().var_name_ + ";\n";
// The result's parts are appended after this expression's parts.
224 parts += out_parts;
225 return parts;
226 }
227
// get_kernel_parts (signature line 240 not shown; see the index below):
// generates kernel code for this and nested expressions. `generated` records
// already-visited nodes by address. If this node is already in it, empty
// parts are returned.
241 std::unordered_map<const void*, const char*>& generated,
242 std::unordered_map<const void*, const char*>& generated_all,
243 name_generator& name_gen, const std::string& row_index_name,
244 const std::string& col_index_name, bool view_handled) const {
245 kernel_parts res{};
246 if (generated.count(this) == 0) {
// First visit: pick a fresh variable name and mark this node as generated.
247 this->var_name_ = name_gen.generate();
248 generated[this] = "";
// Let the derived operation rewrite the indices at which its arguments are
// accessed. The default implementation leaves them unchanged.
249 std::string row_index_name_arg = row_index_name;
250 std::string col_index_name_arg = col_index_name;
251 derived().modify_argument_indices(row_index_name_arg, col_index_name_arg);
// Generate code for each argument. Which map is passed depends on a
// comparison involving &Derived::modify_argument_indices, whose other operand
// is on hidden line 256. Presumably arguments share `generated` unless the
// derived class overrides modify_argument_indices, in which case they get
// the fresh `generated2` - confirm. view_handled is forwarded only to
// arguments whose view_transitivity entry is true.
252 std::array<kernel_parts, N> args_parts = index_apply<N>([&](auto... Is) {
253 std::unordered_map<const void*, const char*> generated2;
254 return std::array<kernel_parts, N>{this->get_arg<Is>().get_kernel_parts(
255 &Derived::modify_argument_indices
257 ? generated
258 : generated2,
259 generated_all, name_gen, row_index_name_arg, col_index_name_arg,
260 view_handled
261 && std::tuple_element_t<
262 Is, typename Deriv::view_transitivity>::value)...};
263 });
// Concatenate the argument parts in order.
264 res = std::accumulate(args_parts.begin(), args_parts.end(),
265 kernel_parts{});
// This operation's own code, generated with the original (unmodified)
// indices and the arguments' variable names.
266 kernel_parts my_part = index_apply<N>([&](auto... Is) {
267 return this->derived().generate(row_index_name, col_index_name,
268 view_handled,
269 this->get_arg<Is>().var_name_...);
270 });
271 res += my_part;
// Move body_prefix in front of the accumulated body, then clear it.
272 res.body = res.body_prefix + res.body;
273 res.body_prefix = "";
274 }
275 return res;
276 }
277
286 inline kernel_parts generate(const std::string& row_index_name,
287 const std::string& col_index_name,
288 const bool view_handled,
289 const std::string& var_name_arg) const {
290 var_name_ = var_name_arg;
291 return {};
292 }
293
302 inline void modify_argument_indices(std::string& row_index_name,
303 std::string& col_index_name) const {}
304
// set_args: sets kernel arguments for nested expressions. Each node is
// processed at most once per `generated` map, keyed by its address.
// `arg_num` is passed by reference to every argument's set_args call.
315 inline void set_args(
316 std::unordered_map<const void*, const char*>& generated,
317 std::unordered_map<const void*, const char*>& generated_all,
318 cl::Kernel& kernel, int& arg_num) const {
319 if (generated.count(this) == 0) {
320 generated[this] = "";
321 // parameter pack expansion returns a comma-separated list of values,
322 // which can not be used as an expression. We work around that by using
323 // comma operator to get a list of ints, which we use to construct an
324 // initializer_list from. Cast to voids avoids warnings about unused
325 // expression.
// NOTE(review): the other operand of the &Derived::modify_argument_indices
// comparison below is on hidden line 331. The map choice presumably mirrors
// get_kernel_parts - confirm.
326 index_apply<N>([&](auto... Is) {
327 std::unordered_map<const void*, const char*> generated2;
328 static_cast<void>(std::initializer_list<int>{
329 (this->get_arg<Is>().set_args(
330 &Derived::modify_argument_indices
332 ? generated
333 : generated2,
334 generated_all, kernel, arg_num),
335 0)...});
336 });
337 }
338 }
339
344 inline void add_read_event(cl::Event& e) const {
345 index_apply<N>([&](auto... Is) {
346 static_cast<void>(std::initializer_list<int>{
347 (this->get_arg<Is>().add_read_event(e), 0)...});
348 });
349 }
350
355 inline void get_write_events(std::vector<cl::Event>& events) const {
356 index_apply<N>([&](auto... Is) {
357 static_cast<void>(std::initializer_list<int>{
358 (this->template get_arg<Is>().get_write_events(events), 0)...});
359 });
360 }
361
367 inline int rows() const {
368 static_assert(
369 N > 0, "default rows does not work on expressions with no arguments!");
370 return index_apply<N>([&](auto... Is) {
371 // assuming all non-dynamic sizes match
372 return std::max({this->get_arg<Is>().rows()...});
373 });
374 }
375
381 inline int cols() const {
382 static_assert(
383 N > 0, "default cols does not work on expressions with no arguments!");
384 return index_apply<N>([&](auto... Is) {
385 // assuming all non-dynamic sizes match
386 return std::max({this->get_arg<Is>().cols()...});
387 });
388 }
389
395 inline int size() const { return derived().rows() * derived().cols(); }
396
402 inline int thread_rows() const { return derived().rows(); }
403
409 inline int thread_cols() const { return derived().cols(); }
410
416 inline std::pair<int, int> extreme_diagonals() const {
417 static_assert(N > 0,
418 "default extreme_diagonals does not work on expressions with "
419 "no arguments!");
420 return index_apply<N>([&](auto... Is) {
421 auto arg_diags
422 = std::make_tuple(this->get_arg<Is>().extreme_diagonals()...);
423 int bottom = std::min(
424 std::initializer_list<int>({std::get<Is>(arg_diags).first...}));
425 int top = std::max(
426 std::initializer_list<int>({std::get<Is>(arg_diags).second...}));
427 return std::make_pair(bottom, top);
428 });
429 }
430
// view(): view of a matrix that would be the result of evaluating this
// expression, derived from the extreme diagonals of the derived expression.
// NOTE(review): lines 437, 439, 441 and 444 are not shown in this listing.
// Presumably a negative lowest diagonal marks the lower triangle as nonzero
// and a positive highest diagonal marks the upper triangle (e.g. via
// either()) - confirm against the full header.
435 inline matrix_cl_view view() const {
436 std::pair<int, int> diagonals = derived().extreme_diagonals();
438 if (diagonals.first < 0) {
440 } else {
442 }
443 if (diagonals.second > 0) {
445 }
446 return view;
447 }
448
// get_unique_matrix_accesses (signature line 456 not shown; per the index
// below: void get_unique_matrix_accesses(std::vector<int>& uids,
// std::unordered_map<const void*, int>& id_map, int& next_id) const).
// Collects data that, besides types, is needed to uniquely identify a kernel
// generator expression. This default forwards the call to every argument,
// first to last.
457 std::vector<int>& uids, std::unordered_map<const void*, int>& id_map,
458 int& next_id) const {
459 index_apply<N>([&](auto... Is) {
460 static_cast<void>(std::initializer_list<int>{(
461 this->get_arg<Is>().get_unique_matrix_accesses(uids, id_map, next_id),
462 0)...});
463 });
464 }
465};
466
467template <typename Derived, typename Scalar, typename... Args>
468const bool operation_cl<Derived, Scalar, Args...>::require_specific_local_size
469 = std::max({false,
470 std::decay_t<Args>::Deriv::require_specific_local_size...});
472} // namespace math
473} // namespace stan
474
475#endif
476#endif
Represents an arithmetic matrix on the OpenCL device.
Definition matrix_cl.hpp:47
std::string generate()
Generates a unique variable name.
Unique name generator for variables used in generated kernels.
const auto & get_arg() const
Returns an argument to this operation.
matrix_cl_view view() const
View of a matrix that would be the result of evaluating this expression.
static constexpr int dynamic
kernel_parts get_kernel_parts(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name, bool view_handled) const
Generates kernel code for this and nested expressions.
std::pair< int, int > extreme_diagonals() const
Determine indices of extreme sub- and superdiagonals written.
Derived & derived()
Casts the instance into its derived type.
static constexpr int N
std::tuple< Args... > ArgsTuple
kernel_parts get_whole_kernel_parts(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &ng, const std::string &row_index_name, const std::string &col_index_name, const T_result &result) const
Generates kernel code for assigning this expression into result expression.
operation_cl(Args &&... arguments)
Constructor.
void get_write_events(std::vector< cl::Event > &events) const
Adds all write events on any matrices used by nested expressions to a list.
matrix_cl< Scalar > eval() const
Evaluates the expression.
kernel_parts generate(const std::string &row_index_name, const std::string &col_index_name, const bool view_handled, const std::string &var_name_arg) const
Generates kernel code for this expression.
void add_read_event(cl::Event &e) const
Adds read event to any matrices used by nested expressions.
std::tuple< Args... > arguments_
int thread_rows() const
Number of rows threads need to be launched for.
int size() const
Size of a matrix that would be the result of evaluating this expression.
void get_unique_matrix_accesses(std::vector< int > &uids, std::unordered_map< const void *, int > &id_map, int &next_id) const
Collects data that is needed besides types to uniquely identify a kernel generator expression.
void set_args(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, cl::Kernel &kernel, int &arg_num) const
Sets kernel arguments for nested expressions.
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
void modify_argument_indices(std::string &row_index_name, std::string &col_index_name) const
Does nothing.
const Derived & derived() const
Casts the instance into its derived type.
int thread_cols() const
Number of columns threads need to be launched for.
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
std::tuple< std::is_same< Args, void >... > view_transitivity
Base for all kernel generator operations.
Represents a two dimensional reduction in kernel generator expressions.
Non-templated base of operation_cl, needed for easily checking whether something is a subclass of operation_cl.
std::string get_kernel_source_for_evaluating_into(const T_lhs &lhs) const
Generates kernel source for evaluating this expression into given left-hand-side expression.
void evaluate_into(T_lhs &lhs) const
Evaluates this expression into given left-hand-side expression.
static const bool require_specific_local_size
std::ostream & operator<<(std::ostream &os, kernel_parts &parts)
const matrix_cl_view either(const matrix_cl_view left_view, const matrix_cl_view right_view)
Determines which parts are nonzero in any of the input views.
static constexpr double e()
Return the base of the natural logarithm.
Definition constants.hpp:20
void invalid_argument(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw an invalid_argument exception with a consistently formatted message.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
STL namespace.
Extends std::true_type when instantiated with zero or more template parameters, all of which extend t...
kernel_parts & operator+=(const kernel_parts &other)
kernel_parts operator+(const kernel_parts &other)
Parts of an OpenCL kernel, generated by an expression.