Automatic Differentiation
 
Loading...
Searching...
No Matches
operation_cl_lhs.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_OPERATION_CL_LHS_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_OPERATION_CL_LHS_HPP
3#ifdef STAN_OPENCL
4
7#include <string>
8#include <map>
9#include <array>
10#include <numeric>
11#include <vector>
12
13namespace stan {
14namespace math {
15
// Base for all kernel generator operations that can be used on the left hand
// side of an expression. Derived is the concrete operation type (reached via
// derived()); Scalar and Args are passed on unchanged to operation_cl.
// NOTE(review): the base-class list continues after the trailing comma on the
// class line; the rest of it is not visible in this listing.
26template <typename Derived, typename Scalar, typename... Args>
27class operation_cl_lhs : public operation_cl<Derived, Scalar, Args...>,
29 protected:
// Shorthand for the operation_cl base this class extends.
30 using base = operation_cl<Derived, Scalar, Args...>;
// Number of types in the Args pack, i.e. how many arguments get_arg<>() indexes.
31 static constexpr int N = sizeof...(Args);
32
33 public:
// Bring the base's derived() into scope so it can be called unqualified.
34 using base::derived;
36
// get_kernel_parts_lhs: generates kernel code for this expression if it
// appears on the left hand side of an assignment.
// NOTE(review): the line carrying the return type and function name is not
// visible in this listing; only the parameter list and body are shown.
//
// generated      - map keyed by operation address; this operation is added to
//                  it (with an empty string) the first time it is seen here.
// generated_all  - second map keyed by operation address; used below to decide
//                  whether res.args is kept or cleared.
// name_gen       - supplies the variable name stored in var_name_.
// row_index_name - row index name handed to generate_lhs.
// col_index_name - column index name handed to generate_lhs.
// Returns the accumulated kernel_parts of all arguments plus this operation's
// own part.
50 std::unordered_map<const void*, const char*>& generated,
51 std::unordered_map<const void*, const char*>& generated_all,
52 name_generator& name_gen, const std::string& row_index_name,
53 const std::string& col_index_name) const {
// First visit through this map: register the operation and give it a name.
54 if (generated.count(this) == 0) {
55 generated[this] = "";
56 this->var_name_ = name_gen.generate();
57 }
// Work on copies so the derived class can rewrite the index names seen by the
// arguments without touching the names used for this operation itself.
58 std::string row_index_name_arg = row_index_name;
59 std::string col_index_name_arg = col_index_name;
60 derived().modify_argument_indices(row_index_name_arg, col_index_name_arg);
// Collect the kernel parts of every argument. Each argument receives either
// the caller's `generated` map or the fresh local map `generated2`, chosen by
// the pointer-to-member comparison below.
// NOTE(review): the right-hand operand of that comparison is cut off in this
// listing; it looks like a test for whether Derived overrides
// modify_argument_indices - confirm against the real header.
61 std::array<kernel_parts, N> args_parts = index_apply<N>([&](auto... Is) {
62 std::unordered_map<const void*, const char*> generated2;
63 return std::array<kernel_parts, N>{
64 this->template get_arg<Is>().get_kernel_parts_lhs(
65 &Derived::modify_argument_indices
66 == &operation_cl<Derived, Scalar,
68 ? generated
69 : generated2,
70 generated_all, name_gen, row_index_name_arg,
71 col_index_name_arg)...};
72 });
// Fold the per-argument parts into one kernel_parts value.
73 kernel_parts res
74 = std::accumulate(args_parts.begin(), args_parts.end(), kernel_parts{});
// Ask the derived class for this operation's own code, passing the original
// (unmodified) index names and the variable names of all arguments.
75 kernel_parts my_part = index_apply<N>([&](auto... Is) {
76 return this->derived().generate_lhs(
77 row_index_name, col_index_name,
78 this->template get_arg<Is>().var_name_...);
79 });
80 res += my_part;
// Record this operation in generated_all on first sight; if it was already
// recorded, the args field of the result is emptied.
81 if (generated_all.count(this) == 0) {
82 generated_all[this] = "";
83 } else {
84 res.args = "";
85 }
86 return res;
87 }
88
98 inline kernel_parts generate_lhs(const std::string& row_index_name,
99 const std::string& col_index_name,
100 const std::string& var_name_arg) const {
101 return {};
102 }
103
// Evaluates an expression and assigns it to this.
// rhs is converted to an operation with as_operation_cl() and its derived()
// object is evaluated into derived(). Returns derived().
// NOTE(review): the second template parameter is cut off in this listing
// (only `typename` is visible); it is presumably a constraint on
// T_expression - confirm against the real header.
109 template <typename T_expression,
110 typename
112 Derived& operator=(T_expression&& rhs) {
113 auto expression
114 = as_operation_cl(std::forward<T_expression>(rhs)).derived();
115 int this_rows = derived().rows();
116 int this_cols = derived().cols();
// Both sides have the same shape and that shape holds no elements: skip the
// evaluation entirely.
117 if (this_rows == expression.rows() && this_cols == expression.cols()
118 && this_rows * this_cols == 0) {
119 return derived();
120 }
121 expression.evaluate_into(derived());
122 return derived();
123 }
// Copy-assignment overload: forwards rhs to the templated operator= by naming
// the template argument explicitly as a const reference to this class type,
// and returns that call's result.
// NOTE(review): the line holding the parameter list is not visible in this
// listing.
124 // Copy assignment delegates to general assignment operator. If we didn't
125 // implement this, we would get ambiguities in overload resolution with
126 // implicitly generated one
127 inline const operation_cl_lhs<Derived, Scalar, Args...>& operator=(
129 return operator=<const operation_cl_lhs<Derived, Scalar, Args...>&>(rhs);
130 }
131
146 inline void set_view(int bottom_diagonal, int top_diagonal,
147 int bottom_zero_diagonal, int top_zero_diagonal) const {
148 index_apply<N>([&](auto... Is) {
149 static_cast<void>(std::initializer_list<int>{
150 (this->template get_arg<Is>().set_view(bottom_diagonal, top_diagonal,
151 bottom_zero_diagonal,
152 top_zero_diagonal),
153 0)...});
154 });
155 }
156
165 inline void check_assign_dimensions(int rows, int cols) const {
166 index_apply<N>([&](auto... Is) {
167 static_cast<void>(std::initializer_list<int>{
168 (this->template get_arg<Is>().check_assign_dimensions(rows, cols),
169 0)...});
170 });
171 }
172
177 inline void add_write_event(cl::Event& e) const {
178 index_apply<N>([&](auto... Is) {
179 static_cast<void>(std::initializer_list<int>{
180 (this->template get_arg<Is>().add_write_event(e), 0)...});
181 });
182 }
183
// get_clear_read_write_events: adds all read and write events on any matrices
// used by nested expressions to the `events` list. This level only forwards
// `events` to the same-named method of every argument, in argument order.
// NOTE(review): the line carrying the return type and function name is not
// visible in this listing.
190 std::vector<cl::Event>& events) const {
191 index_apply<N>([&](auto... Is) {
// The braced list forces left-to-right evaluation; each call is paired with a
// 0 so the pack can be expanded into ints.
192 static_cast<void>(std::initializer_list<int>{
193 (this->template get_arg<Is>().get_clear_read_write_events(events),
194 0)...});
195 });
196 }
197};
199} // namespace math
200} // namespace stan
201
202#endif
203#endif
std::string generate()
Generates a unique variable name.
Unique name generator for variables used in generated kernels.
void set_view(int bottom_diagonal, int top_diagonal, int bottom_zero_diagonal, int top_zero_diagonal) const
Sets the view of the underlying matrix depending on which of its parts are written to.
void add_write_event(cl::Event &e) const
Adds write event to any matrices used by nested expressions.
void get_clear_read_write_events(std::vector< cl::Event > &events) const
Adds all read and write events on any matrices used by nested expressions to a list and clears them from those matrices.
Derived & operator=(T_expression &&rhs)
Evaluates an expression and assigns it to this.
void check_assign_dimensions(int rows, int cols) const
Sets the dimensions of the underlying expressions if possible.
kernel_parts generate_lhs(const std::string &row_index_name, const std::string &col_index_name, const std::string &var_name_arg) const
Generates kernel code for this and nested expressions if this expression appears on the left hand side of an assignment.
kernel_parts get_kernel_parts_lhs(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name) const
Generates kernel code for this expression if it appears on the left hand side of an assignment.
const operation_cl_lhs< Derived, Scalar, Args... > & operator=(const operation_cl_lhs< Derived, Scalar, Args... > &rhs) const
Base for all kernel generator operations that can be used on left hand side of an expression.
Derived & derived()
Casts the instance into its derived type.
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
void modify_argument_indices(std::string &row_index_name, std::string &col_index_name) const
Does nothing.
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
Base for all kernel generator operations.
Non-templated base of operation_cl_lhs is needed for easy checking if something is a subclass of operation_cl_lhs.
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
require_all_t< is_kernel_expression_and_not_scalar< Types >... > require_all_kernel_expressions_and_none_scalar_t
Enables a template if all given types are non-scalar types that are valid kernel generator expressions.
static constexpr double e()
Return the base of the natural logarithm.
Definition constants.hpp:20
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
Parts of an OpenCL kernel, generated by an expression.