Automatic Differentiation
 
Loading...
Searching...
No Matches
opencl_code.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_OPENCL_CODE_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_OPENCL_CODE_HPP
3#ifdef STAN_OPENCL
4
13#include <algorithm>
14#include <set>
15#include <string>
16#include <type_traits>
17#include <utility>
18
19namespace stan {
20namespace math {
21
31template <typename T_code, typename T_scalar>
33 : public operation_cl<opencl_code_output<T_code, T_scalar>, T_scalar,
34 T_code> {
35 public:
36 using Scalar = T_scalar;
37 using base
39 using base::var_name_;
40 const char* custom_var_name_;
41
/**
 * Constructor.
 * @param code custom code object this output belongs to; taken by value and
 * moved into the base operation as its only argument
 * @param custom_var_name name of the variable in the custom code that this
 * output reads from. Only the pointer is stored, not a copy of the string.
 */
47 opencl_code_output(T_code code, const char* custom_var_name)
48 : base(std::move(code)), custom_var_name_(custom_var_name) {}
49
/**
 * Creates a deep copy of this expression.
 * @return a new opencl_code_output wrapping a deep copy of argument 0 and
 * using the same custom variable name
 */
54 inline auto deep_copy() const {
// Deep-copy the wrapped code expression (argument 0 of this operation).
55 auto&& code_copy = this->template get_arg<0>().deep_copy();
// The copy's type may differ from T_code, so the result type is rebuilt
// from decltype of the copy with any reference removed.
56 return opencl_code_output<std::remove_reference_t<decltype(code_copy)>,
57 T_scalar>(std::move(code_copy), custom_var_name_);
58 }
59
/**
 * Generates kernel code for this operation.
 * The produced body is a single declaration of the form
 * `<Scalar type> <var_name_> = <custom_var_name_>;`.
 * @param row_index_name not used by this implementation
 * @param col_index_name not used by this implementation
 * @param view_handled not used by this implementation
 * @param dummy_var_name_code not used by this implementation
 * @return kernel_parts with only `body` set
 */
68 inline kernel_parts generate(const std::string& row_index_name,
69 const std::string& col_index_name,
70 const bool view_handled,
71 const std::string& dummy_var_name_code) const {
72 kernel_parts res{};
// Copy the variable defined by the custom code into this operation's own
// variable.
73 res.body = type_str<Scalar>() + " " + var_name_ + " = " + custom_var_name_
74 + ";\n";
75 return res;
76 }
77
/**
 * Determine indices of extreme sub- and superdiagonals written.
 * @return pair {1 - rows(), cols() - 1}
 */
82 inline std::pair<int, int> extreme_diagonals() const {
83 return {-this->rows() + 1, this->cols() - 1};
84 }
85};
86
87namespace internal {
88template <const char* Code, typename... T_arguments>
90 : public operation_cl<opencl_code_impl<Code, T_arguments...>, double,
91 T_arguments...> {
92 public:
93 // using Scalar = double;
94 using base = operation_cl<opencl_code_impl<Code, T_arguments...>, double,
95 T_arguments...>;
97 = std::tuple<typename std::pair<const char*, T_arguments>::first_type...>;
98 using base::var_name_;
100
/**
 * Constructor.
 * @param names tuple holding one `const char*` name per argument; stored in
 * `names_`
 * @param arguments argument expressions, forwarded to the base operation
 */
107 explicit opencl_code_impl(names_tuple names, T_arguments&&... arguments)
108 : base(std::forward<T_arguments>(arguments)...), names_(names) {}
109
120 const std::string& row_index_name, const std::string& col_index_name,
121 const bool view_handled,
122 std::tuple_element_t<
123 0, std::pair<const std::string&, T_arguments>>... var_names) const {
124 return index_apply<sizeof...(T_arguments)>([this](auto... Is) {
125 kernel_parts res{};
126 std::array<std::string, sizeof...(T_arguments)> input_renames{
127 (type_str<scalar_type_t<decltype(this->template get_arg<Is>())>>()
128 + " " + std::get<Is>(names_) + " = "
129 + this->template get_arg<Is>().var_name_ + ";\n")...};
130 res.body = std::accumulate(input_renames.begin(), input_renames.end(),
131 std::string())
132 + Code;
133 return res;
134 });
135 }
136};
137} // namespace internal
138
144template <const char* Code, typename... T_arguments>
146 public:
147 std::shared_ptr<internal::opencl_code_impl<Code, T_arguments...>> impl_;
148 std::string& var_name_;
150 = std::tuple<typename std::pair<const char*, T_arguments>::first_type...>;
151 using Deriv = internal::opencl_code_impl<Code, T_arguments...>;
152
159 explicit opencl_code_(const names_tuple& names, T_arguments&&... arguments)
160 : impl_(
161 std::make_shared<internal::opencl_code_impl<Code, T_arguments...>>(
162 names, std::forward<T_arguments>(arguments)...)),
164
170 : impl_(other.impl_), var_name_(impl_->var_name_) {}
171
185 std::unordered_map<const void*, const char*>& generated,
186 std::unordered_map<const void*, const char*>& generated_all,
187 name_generator& name_gen, const std::string& row_index_name,
188 const std::string& col_index_name, bool view_handled) const {
189 return impl_->get_kernel_parts(generated, generated_all, name_gen,
190 row_index_name, col_index_name,
191 view_handled);
192 }
193
/**
 * Sets kernel arguments for nested expressions.
 * All parameters are passed unchanged to the shared implementation object.
 * @param generated map keyed by `const void*`, forwarded to `impl_`
 * (presumably the operations already handled -- confirm against operation_cl)
 * @param generated_all second map of the same type, forwarded to `impl_`
 * @param kernel kernel to set arguments on
 * @param arg_num argument counter, passed by non-const reference
 * @return whatever `impl_->set_args` returns
 */
204 auto set_args(std::unordered_map<const void*, const char*>& generated,
205 std::unordered_map<const void*, const char*>& generated_all,
206 cl::Kernel& kernel, int& arg_num) const {
207 return impl_->set_args(generated, generated_all, kernel, arg_num);
208 }
209
/**
 * Adds read event to any matrices used by nested expressions.
 * Delegates to the shared implementation object.
 * @param e the event to add
 */
214 auto add_read_event(cl::Event& e) const { return impl_->add_read_event(e); }
215
/**
 * Adds all write events on any matrices used by nested expressions to a list.
 * Delegates to the shared implementation object.
 * @param[out] events list the events are added to
 */
220 auto get_write_events(std::vector<cl::Event>& events) const {
221 return impl_->get_write_events(events);
222 }
223
/**
 * Number of rows of a matrix that would be the result of evaluating this
 * expression.
 * This overload is selected when there is at least one argument and
 * delegates to the shared implementation object.
 * @return number of rows reported by `impl_`
 */
229 template <int N = sizeof...(T_arguments),
230 std::enable_if_t<(N > 0)>* = nullptr>
231 auto rows() const {
232 return impl_->rows();
233 }
/**
 * Overload selected when the expression has no arguments.
 * @return -1
 * NOTE(review): -1 is presumably the "size not fixed by this expression"
 * sentinel used by the kernel generator -- confirm against operation_cl.
 */
234 template <int N = sizeof...(T_arguments),
235 std::enable_if_t<(N == 0)>* = nullptr>
236 auto rows() const {
237 return -1;
238 }
239
/**
 * Number of columns of a matrix that would be the result of evaluating this
 * expression.
 * This overload is selected when there is at least one argument and
 * delegates to the shared implementation object.
 * @return number of columns reported by `impl_`
 */
245 template <int N = sizeof...(T_arguments),
246 std::enable_if_t<(N > 0)>* = nullptr>
247 auto cols() const {
248 return impl_->cols();
249 }
/**
 * Overload selected when the expression has no arguments.
 * @return -1
 * NOTE(review): -1 is presumably the "size not fixed by this expression"
 * sentinel used by the kernel generator -- confirm against operation_cl.
 */
250 template <int N = sizeof...(T_arguments),
251 std::enable_if_t<(N == 0)>* = nullptr>
252 auto cols() const {
253 return -1;
254 }
255
/**
 * Number of rows threads need to be launched for.
 * This overload is selected when there is at least one argument and
 * delegates to the shared implementation object.
 * @return value reported by `impl_->thread_rows()`
 */
261 template <int N = sizeof...(T_arguments),
262 std::enable_if_t<(N > 0)>* = nullptr>
263 auto thread_rows() const {
264 return impl_->thread_rows();
265 }
/**
 * Overload selected when the expression has no arguments.
 * @return -1
 * NOTE(review): -1 is presumably the "size not fixed by this expression"
 * sentinel used by the kernel generator -- confirm against operation_cl.
 */
266 template <int N = sizeof...(T_arguments),
267 std::enable_if_t<(N == 0)>* = nullptr>
268 auto thread_rows() const {
269 return -1;
270 }
271
/**
 * Number of columns threads need to be launched for.
 * This overload is selected when there is at least one argument and
 * delegates to the shared implementation object.
 * @return value reported by `impl_->thread_cols()`
 */
277 template <int N = sizeof...(T_arguments),
278 std::enable_if_t<(N > 0)>* = nullptr>
279 auto thread_cols() const {
280 return impl_->thread_cols();
281 }
/**
 * Overload selected when the expression has no arguments.
 * @return -1
 * NOTE(review): -1 is presumably the "size not fixed by this expression"
 * sentinel used by the kernel generator -- confirm against operation_cl.
 */
282 template <int N = sizeof...(T_arguments),
283 std::enable_if_t<(N == 0)>* = nullptr>
284 auto thread_cols() const {
285 return -1;
286 }
287
/**
 * Determine indices of extreme sub- and superdiagonals written.
 * Delegates to the shared implementation object.
 * @return whatever `impl_->extreme_diagonals()` returns
 */
292 auto extreme_diagonals() const { return impl_->extreme_diagonals(); }
293
/**
 * Collects data that is needed besides types to uniquely identify a kernel
 * generator expression.
 * All parameters are passed unchanged to the shared implementation object.
 * @param[out] uids vector of ids, forwarded to `impl_`
 * @param[in,out] id_map map from `const void*` to int, forwarded to `impl_`
 * @param[in,out] next_id id counter, passed by non-const reference
 */
301 auto get_unique_matrix_accesses(std::vector<int>& uids,
302 std::unordered_map<const void*, int>& id_map,
303 int& next_id) const {
304 return impl_->get_unique_matrix_accesses(uids, id_map, next_id);
305 }
306
/**
 * Creates a deep copy of this expression.
 * @return a new opencl_code_ with the same `Code` and names, whose arguments
 * are deep copies of this expression's arguments
 */
311 inline auto deep_copy() const {
// index_apply calls the lambda with one index object per argument, so the
// pack `Is` can be expanded over all arguments.
312 return index_apply<sizeof...(T_arguments)>([this](auto... Is) {
// Deep-copy every argument of the implementation into a tuple.
313 auto args_copy
314 = std::make_tuple(this->impl_->template get_arg<Is>().deep_copy()...);
// Build the copy from the types of the copied arguments (which may differ
// from T_arguments), reusing the stored names and moving the copies in.
315 return opencl_code_<
316 Code, std::remove_reference_t<decltype(std::get<Is>(args_copy))>...>(
317 this->impl_->names_, std::move(std::get<Is>(args_copy))...);
318 });
319 }
320
/**
 * Get object representing output variable of custom code.
 * @tparam T_scalar scalar type of the output variable
 * @param var_name name of the variable in the custom code to expose
 * @return opencl_code_output constructed from `*this` and `var_name`
 */
325 template <typename T_scalar>
326 inline auto output(const char* var_name) const {
327 return opencl_code_output<opencl_code_<Code, T_arguments...>, T_scalar>(
328 *this, var_name);
329 }
330};
331
340template <const char* Code, typename... T_arguments,
341 require_all_kernel_expressions_t<T_arguments...>* = nullptr>
342inline auto opencl_code(
343 std::tuple<typename std::pair<const char*, T_arguments>::first_type...>
344 names,
345 T_arguments&&... arguments) {
347 names, as_operation_cl(std::forward<T_arguments>(arguments))...);
348}
349
351} // namespace math
352} // namespace stan
353#endif
354#endif
kernel_parts generate(const std::string &row_index_name, const std::string &col_index_name, const bool view_handled, std::tuple_element_t< 0, std::pair< const std::string &, T_arguments > >... var_names) const
Generates kernel code for this operation.
std::tuple< typename std::pair< const char *, T_arguments >::first_type... > names_tuple
opencl_code_impl(names_tuple names, T_arguments &&... arguments)
Constructor.
Unique name generator for variables used in generated kernels.
std::shared_ptr< internal::opencl_code_impl< Code, T_arguments... > > impl_
auto thread_rows() const
Number of rows threads need to be launched for.
opencl_code_(const names_tuple &names, T_arguments &&... arguments)
Constructor.
auto get_kernel_parts(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name, bool view_handled) const
Generates kernel code for this and nested expressions.
opencl_code_(const opencl_code_< Code, T_arguments... > &other)
Copy constructor.
auto get_write_events(std::vector< cl::Event > &events) const
Adds all write events on any matrices used by nested expressions to a list.
auto get_unique_matrix_accesses(std::vector< int > &uids, std::unordered_map< const void *, int > &id_map, int &next_id) const
Collects data that is needed besides types to uniquely identify a kernel generator expression.
auto thread_cols() const
Number of columns threads need to be launched for.
auto output(const char *var_name) const
Get object representing output variable of custom code.
auto add_read_event(cl::Event &e) const
Adds read event to any matrices used by nested expressions.
auto deep_copy() const
Creates a deep copy of this expression.
auto extreme_diagonals() const
Determine indices of extreme sub- and superdiagonals written.
auto set_args(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, cl::Kernel &kernel, int &arg_num) const
Sets kernel arguments for nested expressions.
auto rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
std::tuple< typename std::pair< const char *, T_arguments >::first_type... > names_tuple
auto cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
Represents custom code in kernel generator expressions.
kernel_parts generate(const std::string &row_index_name, const std::string &col_index_name, const bool view_handled, const std::string &dummy_var_name_code) const
Generates kernel code for this operation.
opencl_code_output(T_code code, const char *custom_var_name)
Constructor.
auto deep_copy() const
Creates a deep copy of this expression.
std::pair< int, int > extreme_diagonals() const
Determine indices of extreme sub- and superdiagonals written.
Represents output variable of custom code in kernel generator expressions.
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
Base for all kernel generator operations.
Non-templated base of operation_cl is needed for easy checking if something is a subclass of operatio...
auto opencl_code(std::tuple< typename std::pair< const char *, T_arguments >::first_type... > names, T_arguments &&... arguments)
Custom code in kernel generator expressions.
require_all_t< is_kernel_expression< Types >... > require_all_kernel_expressions_t
Enables a template if all given types are valid kernel generator expressions.
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
static constexpr double e()
Return the base of the natural logarithm.
Definition constants.hpp:20
std::string type_str()
Determines a string name of a type.
Definition type_str.hpp:14
constexpr auto index_apply(F &&f)
Calls given callable with an index sequence.
typename scalar_type< T >::type scalar_type_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
STL namespace.
Parts of an OpenCL kernel, generated by an expression.