1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_OPERATION_CL_LHS_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_OPERATION_CL_LHS_HPP
26template <
typename Derived,
typename Scalar,
typename... Args>
31 static constexpr int N =
sizeof...(Args);
50 std::unordered_map<const void*, const char*>& generated,
51 std::unordered_map<const void*, const char*>& generated_all,
53 const std::string& col_index_name)
const {
54 if (generated.count(
this) == 0) {
58 std::string row_index_name_arg = row_index_name;
59 std::string col_index_name_arg = col_index_name;
60 derived().modify_argument_indices(row_index_name_arg, col_index_name_arg);
61 std::array<kernel_parts, N> args_parts = index_apply<N>([&](
auto... Is) {
62 std::unordered_map<const void*, const char*> generated2;
63 return std::array<kernel_parts, N>{
64 this->
template get_arg<Is>().get_kernel_parts_lhs(
65 &Derived::modify_argument_indices
70 generated_all, name_gen, row_index_name_arg,
71 col_index_name_arg)...};
74 = std::accumulate(args_parts.begin(), args_parts.end(),
kernel_parts{});
76 return this->
derived().generate_lhs(
77 row_index_name, col_index_name,
78 this->
template get_arg<Is>().
var_name_...);
81 if (generated_all.count(
this) == 0) {
82 generated_all[
this] =
"";
99 const std::string& col_index_name,
100 const std::string& var_name_arg)
const {
109 template <
typename T_expression,
115 int this_rows =
derived().rows();
116 int this_cols =
derived().cols();
117 if (this_rows == expression.rows() && this_cols == expression.cols()
118 && this_rows * this_cols == 0) {
121 expression.evaluate_into(
derived());
146 inline void set_view(
int bottom_diagonal,
int top_diagonal,
147 int bottom_zero_diagonal,
int top_zero_diagonal)
const {
148 index_apply<N>([&](
auto... Is) {
149 static_cast<void>(std::initializer_list<int>{
150 (this->
template get_arg<Is>().set_view(bottom_diagonal, top_diagonal,
151 bottom_zero_diagonal,
166 index_apply<N>([&](
auto... Is) {
167 static_cast<void>(std::initializer_list<int>{
168 (this->
template get_arg<Is>().check_assign_dimensions(
rows,
cols),
178 index_apply<N>([&](
auto... Is) {
179 static_cast<void>(std::initializer_list<int>{
180 (this->
template get_arg<Is>().add_write_event(
e), 0)...});
190 std::vector<cl::Event>& events)
const {
191 index_apply<N>([&](
auto... Is) {
192 static_cast<void>(std::initializer_list<int>{
193 (this->
template get_arg<Is>().get_clear_read_write_events(events),
std::string generate()
Generates a unique variable name.
Unique name generator for variables used in generated kernels.
void set_view(int bottom_diagonal, int top_diagonal, int bottom_zero_diagonal, int top_zero_diagonal) const
Sets the view of the underlying matrix depending on which of its parts are written to.
void add_write_event(cl::Event &e) const
Adds write event to any matrices used by nested expressions.
void get_clear_read_write_events(std::vector< cl::Event > &events) const
Adds all read and write events on any matrices used by nested expressions to a list and clears them from those matrices.
Derived & operator=(T_expression &&rhs)
Evaluates an expression and assigns it to this.
void check_assign_dimensions(int rows, int cols) const
Sets the dimensions of the underlying expressions if possible.
kernel_parts generate_lhs(const std::string &row_index_name, const std::string &col_index_name, const std::string &var_name_arg) const
Generates kernel code for this and nested expressions if this expression appears on the left-hand side of an assignment.
kernel_parts get_kernel_parts_lhs(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name) const
Generates kernel code for this expression if it appears on the left hand side of an assignment.
const operation_cl_lhs< Derived, Scalar, Args... > & operator=(const operation_cl_lhs< Derived, Scalar, Args... > &rhs) const
Base for all kernel generator operations that can be used on the left-hand side of an expression.
Derived & derived()
Casts the instance into its derived type.
operation_cl(Args &&... arguments)
Constructor.
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
void modify_argument_indices(std::string &row_index_name, std::string &col_index_name) const
Does nothing.
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
Base for all kernel generator operations.
A non-templated base of operation_cl_lhs is needed to easily check whether something is a subclass of operation_cl_lhs.
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
require_all_t< is_kernel_expression_and_not_scalar< Types >... > require_all_kernel_expressions_and_none_scalar_t
Enables a template if all given types are non-scalar types that are valid kernel generator expressions.
static constexpr double e()
Returns the base of the natural logarithm.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Parts of an OpenCL kernel, generated by an expression.