Automatic Differentiation
 
Loading...
Searching...
No Matches
block_zero_based.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_BLOCK_ZERO_BASED_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_BLOCK_ZERO_BASED_HPP
3#ifdef STAN_OPENCL
4
12#include <map>
13#include <string>
14#include <tuple>
15#include <type_traits>
16#include <utility>
17
18namespace stan {
19namespace math {
20
30template <typename T>
31class block_
32 : public operation_cl_lhs<block_<T>,
33 typename std::remove_reference_t<T>::Scalar, T> {
34 public:
35 using Scalar = typename std::remove_reference_t<T>::Scalar;
37 using base::var_name_;
38 using view_transitivity = std::tuple<std::true_type>;
39 using base::operator=;
40
41 protected:
43
44 public:
53 block_(T&& a, int start_row, int start_col, int rows, int cols)
54 : base(std::forward<T>(a)),
55 start_row_(start_row),
56 start_col_(start_col),
57 rows_(rows),
58 cols_(cols) {
59 if (start_col < 0) {
60 invalid_argument("block", "start_col", start_col,
61 " should be non-negative, but is ");
62 }
63 if (start_row < 0) {
64 invalid_argument("block", "start_row", start_row,
65 " should be non-negative, but is ");
66 }
67 if (rows < 0) {
68 invalid_argument("block", "rows", rows,
69 " should be non-negative, but is ");
70 }
71 if (cols < 0) {
72 invalid_argument("block", "cols", cols,
73 " should be non-negative, but is ");
74 }
75 if ((a.rows() != base::dynamic && (start_row + rows) > a.rows())
76 || (a.cols() != base::dynamic && (start_col + cols) > a.cols())) {
77 invalid_argument("block", "block of \"a\"", " is out of bounds", "");
78 }
79 }
80
85 inline auto deep_copy() const {
86 auto&& arg_copy = this->template get_arg<0>().deep_copy();
87 return block_<std::remove_reference_t<decltype(arg_copy)>>{
88 std::move(arg_copy), start_row_, start_col_, rows_, cols_};
89 }
90
104 std::unordered_map<const void*, const char*>& generated,
105 std::unordered_map<const void*, const char*>& generated_all,
106 name_generator& name_gen, const std::string& row_index_name,
107 const std::string& col_index_name, bool view_handled) const {
108 kernel_parts res{};
109 if (generated.count(this) == 0) {
110 this->var_name_ = name_gen.generate();
111 generated[this] = "";
112 std::string row_index_name_arg = row_index_name;
113 std::string col_index_name_arg = col_index_name;
114 modify_argument_indices(row_index_name_arg, col_index_name_arg);
115 std::unordered_map<const void*, const char*> generated2;
116 res = this->template get_arg<0>().get_kernel_parts(
117 generated2, generated_all, name_gen, row_index_name_arg,
118 col_index_name_arg, view_handled);
119 kernel_parts my
120 = this->generate(row_index_name, col_index_name, view_handled,
121 this->template get_arg<0>().var_name_);
122 if (generated_all.count(this) == 0) {
123 generated_all[this] = "";
124 } else {
125 my.args = "";
126 }
127 res += my;
128 res.body = res.body_prefix + res.body;
129 res.body_prefix = "";
130 }
131 return res;
132 }
133
143 inline kernel_parts generate(const std::string& row_index_name,
144 const std::string& col_index_name,
145 const bool view_handled,
146 const std::string& var_name_arg) const {
147 kernel_parts res;
148 res.body
149 = type_str<Scalar>() + " " + var_name_ + " = " + var_name_arg + ";\n";
150 res.args = "int " + var_name_ + "_i, int " + var_name_ + "_j, ";
151 return res;
152 }
153
159 inline void modify_argument_indices(std::string& row_index_name,
160 std::string& col_index_name) const {
161 row_index_name = "(" + row_index_name + " + " + var_name_ + "_i)";
162 col_index_name = "(" + col_index_name + " + " + var_name_ + "_j)";
163 }
164
178 std::unordered_map<const void*, const char*>& generated,
179 std::unordered_map<const void*, const char*>& generated_all,
180 name_generator& name_gen, const std::string& row_index_name,
181 const std::string& col_index_name) const {
182 if (generated.count(this) == 0) {
183 generated[this] = "";
184 this->var_name_ = name_gen.generate();
185 }
186 std::string row_index_name_arg = row_index_name;
187 std::string col_index_name_arg = col_index_name;
188 modify_argument_indices(row_index_name_arg, col_index_name_arg);
189 std::unordered_map<const void*, const char*> generated2;
190 kernel_parts res = this->template get_arg<0>().get_kernel_parts_lhs(
191 generated2, generated_all, name_gen, row_index_name_arg,
192 col_index_name_arg);
193 res += this->derived().generate_lhs(row_index_name, col_index_name,
194 this->template get_arg<0>().var_name_);
195 if (generated_all.count(this) == 0) {
196 generated_all[this] = "";
197 } else {
198 res.args = "";
199 }
200 return res;
201 }
202
212 inline kernel_parts generate_lhs(const std::string& i, const std::string& j,
213 const std::string& var_name_arg) const {
214 kernel_parts res;
215 res.args = "int " + var_name_ + "_i, int " + var_name_ + "_j, ";
216 return res;
217 }
218
229 inline void set_args(
230 std::unordered_map<const void*, const char*>& generated,
231 std::unordered_map<const void*, const char*>& generated_all,
232 cl::Kernel& kernel, int& arg_num) const {
233 if (generated.count(this) == 0) {
234 generated[this] = "";
235 std::unordered_map<const void*, const char*> generated2;
236 this->template get_arg<0>().set_args(generated2, generated_all, kernel,
237 arg_num);
238 if (generated_all.count(this) == 0) {
239 generated_all[this] = "";
240 kernel.setArg(arg_num++, start_row_);
241 kernel.setArg(arg_num++, start_col_);
242 }
243 }
244 }
245
251 inline int rows() const { return rows_; }
252
258 inline int cols() const { return cols_; }
259
273 inline void set_view(int bottom_diagonal, int top_diagonal,
274 int bottom_zero_diagonal, int top_zero_diagonal) const {
275 int change = start_col_ - start_row_;
276 auto& a = this->template get_arg<0>();
277 a.set_view(
278 bottom_diagonal + change, top_diagonal + change,
279 (start_col_ == 0 && start_row_ <= 1 && start_row_ + rows_ == a.rows()
280 && start_col_ + cols_ >= std::min(a.rows() - 1, a.cols())
281 ? bottom_zero_diagonal
282 : bottom_diagonal)
283 + change,
284 (start_row_ == 0 && start_col_ <= 1 && start_col_ + cols_ == a.cols()
285 && start_row_ + rows_ >= std::min(a.rows(), a.cols() - 1)
286 ? top_zero_diagonal
287 : top_diagonal)
288 + change);
289 }
290
295 inline std::pair<int, int> extreme_diagonals() const {
296 std::pair<int, int> arg_diags
297 = this->template get_arg<0>().extreme_diagonals();
298 return {arg_diags.first - start_col_ + start_row_,
299 arg_diags.second - start_col_ + start_row_};
300 }
301
309 inline void check_assign_dimensions(int rows, int cols) const {
310 check_size_match("block_.check_assign_dimensions", "Rows of ", "block",
311 rows_, "rows of ", "expression", rows);
312 check_size_match("block_.check_assign_dimensions", "Columns of ", "block",
313 cols_, "columns of ", "expression", cols);
314 }
315};
316
338template <typename T,
340inline auto block_zero_based(T&& a, int start_row, int start_col, int rows,
341 int cols) {
342 auto&& a_operation = as_operation_cl(std::forward<T>(a)).deep_copy();
343 return block_<std::remove_reference_t<decltype(a_operation)>>(
344 std::move(a_operation), start_row, start_col, rows, cols);
345}
347} // namespace math
348} // namespace stan
349
350#endif
351#endif
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
void set_args(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, cl::Kernel &kernel, int &arg_num) const
Sets kernel arguments for this and nested expressions.
kernel_parts generate(const std::string &row_index_name, const std::string &col_index_name, const bool view_handled, const std::string &var_name_arg) const
Generates kernel code for this expression.
typename std::remove_reference_t< T >::Scalar Scalar
kernel_parts get_kernel_parts_lhs(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name) const
Generates kernel code for this expression if it appears on the left hand side of an assignment.
kernel_parts generate_lhs(const std::string &i, const std::string &j, const std::string &var_name_arg) const
Generates kernel code for this and nested expressions if this expression appears on the left hand side of an assignment.
void modify_argument_indices(std::string &row_index_name, std::string &col_index_name) const
Sets offset of block to indices of the argument expression.
auto deep_copy() const
Creates a deep copy of this expression.
kernel_parts get_kernel_parts(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name, bool view_handled) const
Generates kernel code for this and nested expressions.
void check_assign_dimensions(int rows, int cols) const
Checks if desired dimensions match dimensions of the block.
void set_view(int bottom_diagonal, int top_diagonal, int bottom_zero_diagonal, int top_zero_diagonal) const
Sets view of the underlying matrix depending on which part is written.
block_(T &&a, int start_row, int start_col, int rows, int cols)
Constructor.
std::pair< int, int > extreme_diagonals() const
Determine indices of extreme sub- and superdiagonals written.
std::tuple< std::true_type > view_transitivity
Represents submatrix block in kernel generator expressions.
std::string generate()
Generates a unique variable name.
Unique name generator for variables used in generated kernels.
Base for all kernel generator operations that can be used on left hand side of an expression.
Derived & derived()
Casts the instance into its derived type.
auto block_zero_based(T &&a, int start_row, int start_col, int rows, int cols)
Block of a kernel generator expression.
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
require_all_t< is_kernel_expression_and_not_scalar< Types >... > require_all_kernel_expressions_and_none_scalar_t
Enables a template if all given types are non-scalar types that are valid kernel generator expressions.
int rows(const T_x &x)
Returns the number of rows in the specified kernel generator expression.
Definition rows.hpp:21
int cols(const T_x &x)
Returns the number of columns in the specified kernel generator expression.
Definition cols.hpp:20
void invalid_argument(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw an invalid_argument exception with a consistently formatted message.
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
STL namespace.
Parts of an OpenCL kernel, generated by an expression.