Automatic Differentiation
 
Loading...
Searching...
No Matches
append.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_APPEND_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_APPEND_HPP
3#ifdef STAN_OPENCL
4
#include <algorithm>
#include <map>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
20
21namespace stan {
22namespace math {
23
33template <typename T_a, typename T_b>
34class append_row_ : public operation_cl<append_row_<T_a, T_b>,
35 common_scalar_t<T_a, T_b>, T_a, T_b> {
36 public:
39 using base::var_name_;
40
41 protected:
42 using base::arguments_;
43
44 public:
50 append_row_(T_a&& a, T_b&& b) // NOLINT
51 : base(std::forward<T_a>(a), std::forward<T_b>(b)) {
52 if (a.cols() != base::dynamic && b.cols() != base::dynamic) {
53 check_size_match("append_row", "Columns of ", "a", a.cols(),
54 "columns of ", "b", b.cols());
55 }
56 if (a.rows() < 0) {
57 invalid_argument("append_row", "Rows of a", a.rows(),
58 "should be nonnegative!");
59 }
60 if (b.rows() < 0) {
61 invalid_argument("append_row", "Rows of b", b.rows(),
62 "should be nonnegative!");
63 }
64 }
65
70 inline auto deep_copy() const {
71 auto&& a_copy = this->template get_arg<0>().deep_copy();
72 auto&& b_copy = this->template get_arg<1>().deep_copy();
73 return append_row_<std::remove_reference_t<decltype(a_copy)>,
74 std::remove_reference_t<decltype(b_copy)>>{
75 std::move(a_copy), std::move(b_copy)};
76 }
77
91 std::unordered_map<const void*, const char*>& generated,
92 std::unordered_map<const void*, const char*>& generated_all,
93 name_generator& name_gen, const std::string& row_index_name,
94 const std::string& col_index_name, bool view_handled) const {
95 kernel_parts res{};
96 if (generated.count(this) == 0) {
97 var_name_ = name_gen.generate();
98 generated[this] = "";
99 kernel_parts parts_a = this->template get_arg<0>().get_kernel_parts(
100 generated, generated_all, name_gen, row_index_name, col_index_name,
101 true);
102 std::string row_index_name_b
103 = "(" + row_index_name + " - " + var_name_ + "_first_rows)";
104 std::unordered_map<const void*, const char*> generated_b;
105 kernel_parts parts_b = this->template get_arg<1>().get_kernel_parts(
106 generated_b, generated_all, name_gen, row_index_name_b,
107 col_index_name, true);
108 res = parts_a + parts_b;
109 res.body = type_str<Scalar>() + " " + var_name_ + ";\n"
110 "if("+ row_index_name +" < " + var_name_ + "_first_rows){\n"
111 + parts_a.body +
112 var_name_ + " = " + this->template get_arg<0>().var_name_ + ";\n"
113 "} else{\n"
114 + parts_b.body +
115 var_name_ + " = " + this->template get_arg<1>().var_name_ + ";\n"
116 "}\n";
117 res.args += "int " + var_name_ + "_first_rows, ";
118 }
119 return res;
120 }
121
132 inline void set_args(
133 std::unordered_map<const void*, const char*>& generated,
134 std::unordered_map<const void*, const char*>& generated_all,
135 cl::Kernel& kernel, int& arg_num) const {
136 if (generated.count(this) == 0) {
137 generated[this] = "";
138 this->template get_arg<0>().set_args(generated, generated_all, kernel,
139 arg_num);
140 std::unordered_map<const void*, const char*> generated_b;
141 this->template get_arg<1>().set_args(generated_b, generated_all, kernel,
142 arg_num);
143 kernel.setArg(arg_num++, this->template get_arg<0>().rows());
144 }
145 }
146
152 inline int rows() const {
153 return this->template get_arg<0>().rows()
154 + this->template get_arg<1>().rows();
155 }
156
161 inline std::pair<int, int> extreme_diagonals() const {
162 std::pair<int, int> a_diags
163 = this->template get_arg<0>().extreme_diagonals();
164 std::pair<int, int> b_diags
165 = this->template get_arg<1>().extreme_diagonals();
166 int my_rows = this->template get_arg<0>().rows();
167 return {std::min(a_diags.first, b_diags.first - my_rows),
168 std::max(a_diags.second, b_diags.second - my_rows)};
169 }
170};
171
181template <typename Ta, typename Tb,
183inline auto append_row(Ta&& a, Tb&& b) {
184 auto&& a_operation = as_operation_cl(std::forward<Ta>(a)).deep_copy();
185 auto&& b_operation = as_operation_cl(std::forward<Tb>(b)).deep_copy();
186 return append_row_<std::remove_reference_t<decltype(a_operation)>,
187 std::remove_reference_t<decltype(b_operation)>>(
188 std::move(a_operation), std::move(b_operation));
189}
190
196template <typename T_a, typename T_b>
197class append_col_ : public operation_cl<append_col_<T_a, T_b>,
198 common_scalar_t<T_a, T_b>, T_a, T_b> {
199 public:
202 using base::var_name_;
203
204 protected:
205 using base::arguments_;
206
207 public:
213 append_col_(T_a&& a, T_b&& b) // NOLINT
214 : base(std::forward<T_a>(a), std::forward<T_b>(b)) {
215 if (a.rows() != base::dynamic && b.rows() != base::dynamic) {
216 check_size_match("append_col", "Rows of ", "a", a.rows(), "rows of ", "b",
217 b.rows());
218 }
219 if (a.cols() < 0) {
220 invalid_argument("append_col", "Columns of a", a.cols(),
221 "should be nonnegative!");
222 }
223 if (b.cols() < 0) {
224 invalid_argument("append_col", "Columns of b", b.cols(),
225 "should be nonnegative!");
226 }
227 }
228
233 inline auto deep_copy() const {
234 auto&& a_copy = this->template get_arg<0>().deep_copy();
235 auto&& b_copy = this->template get_arg<1>().deep_copy();
236 return append_col_<std::remove_reference_t<decltype(a_copy)>,
237 std::remove_reference_t<decltype(b_copy)>>{
238 std::move(a_copy), std::move(b_copy)};
239 }
240
254 std::unordered_map<const void*, const char*>& generated,
255 std::unordered_map<const void*, const char*>& generated_all,
256 name_generator& name_gen, const std::string& row_index_name,
257 const std::string& col_index_name, bool view_handled) const {
258 kernel_parts res{};
259 if (generated.count(this) == 0) {
260 var_name_ = name_gen.generate();
261 generated[this] = "";
262 kernel_parts parts_a = this->template get_arg<0>().get_kernel_parts(
263 generated, generated_all, name_gen, row_index_name, col_index_name,
264 true);
265 std::string col_index_name_b
266 = "(" + col_index_name + " - " + var_name_ + "_first_cols)";
267 std::unordered_map<const void*, const char*> generated_b;
268 kernel_parts parts_b = this->template get_arg<1>().get_kernel_parts(
269 generated_b, generated_all, name_gen, row_index_name,
270 col_index_name_b, true);
271 res = parts_a + parts_b;
272 res.body = type_str<Scalar>() + " " + var_name_ + ";\n"
273 "if("+ col_index_name +" < " + var_name_ + "_first_cols){\n"
274 + parts_a.body +
275 var_name_ + " = " + this->template get_arg<0>().var_name_ + ";\n"
276 "} else{\n"
277 + parts_b.body +
278 var_name_ + " = " + this->template get_arg<1>().var_name_ + ";\n"
279 "}\n";
280 res.args += "int " + var_name_ + "_first_cols, ";
281 }
282 return res;
283 }
284
295 inline void set_args(
296 std::unordered_map<const void*, const char*>& generated,
297 std::unordered_map<const void*, const char*>& generated_all,
298 cl::Kernel& kernel, int& arg_num) const {
299 if (generated.count(this) == 0) {
300 generated[this] = "";
301 this->template get_arg<0>().set_args(generated, generated_all, kernel,
302 arg_num);
303 std::unordered_map<const void*, const char*> generated_b;
304 this->template get_arg<1>().set_args(generated_b, generated_all, kernel,
305 arg_num);
306 kernel.setArg(arg_num++, this->template get_arg<0>().cols());
307 }
308 }
309
315 inline int cols() const {
316 return this->template get_arg<0>().cols()
317 + this->template get_arg<1>().cols();
318 }
319
324 inline std::pair<int, int> extreme_diagonals() const {
325 std::pair<int, int> a_diags
326 = this->template get_arg<0>().extreme_diagonals();
327 std::pair<int, int> b_diags
328 = this->template get_arg<1>().extreme_diagonals();
329 int my_cols = this->template get_arg<0>().cols();
330 return {std::min(a_diags.first, b_diags.first + my_cols),
331 std::max(a_diags.second, b_diags.second + my_cols)};
332 }
333};
334
344template <typename Ta, typename Tb,
346inline auto append_col(Ta&& a, Tb&& b) {
347 auto&& a_operation = as_operation_cl(std::forward<Ta>(a)).deep_copy();
348 auto&& b_operation = as_operation_cl(std::forward<Tb>(b)).deep_copy();
349 return append_col_<std::remove_reference_t<decltype(a_operation)>,
350 std::remove_reference_t<decltype(b_operation)>>(
351 std::move(a_operation), std::move(b_operation));
352}
353
354} // namespace math
355} // namespace stan
356
357#endif
358#endif
Represents appending of cols in kernel generator expressions.
Definition append.hpp:198
Represents appending of rows in kernel generator expressions.
Definition append.hpp:35
std::string generate()
Generates a unique variable name.
Unique name generator for variables used in generated kernels.
static constexpr int dynamic
std::tuple< Args... > arguments_
Base for all kernel generator operations.
void set_args(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, cl::Kernel &kernel, int &arg_num) const
Sets kernel arguments for this and nested expressions.
Definition append.hpp:132
kernel_parts get_kernel_parts(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name, bool view_handled) const
Generates kernel code for this and nested expressions.
Definition append.hpp:253
std::pair< int, int > extreme_diagonals() const
Determine indices of extreme sub- and superdiagonals written.
Definition append.hpp:324
std::pair< int, int > extreme_diagonals() const
Determine indices of extreme sub- and superdiagonals written.
Definition append.hpp:161
auto deep_copy() const
Creates a deep copy of this expression.
Definition append.hpp:233
auto deep_copy() const
Creates a deep copy of this expression.
Definition append.hpp:70
auto append_col(Ta &&a, Tb &&b)
Stack the cols of the arguments.
Definition append.hpp:346
common_scalar_t< T_a, T_b > Scalar
Definition append.hpp:37
append_col_(T_a &&a, T_b &&b)
Constructor.
Definition append.hpp:213
kernel_parts get_kernel_parts(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name, bool view_handled) const
Generates kernel code for this and nested expressions.
Definition append.hpp:90
append_row_(T_a &&a, T_b &&b)
Constructor.
Definition append.hpp:50
void set_args(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, cl::Kernel &kernel, int &arg_num) const
Sets kernel arguments for this and nested expressions.
Definition append.hpp:295
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
Definition append.hpp:315
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
Definition append.hpp:152
common_scalar_t< T_a, T_b > Scalar
Definition append.hpp:200
auto append_row(Ta &&a, Tb &&b)
Stack the rows of the first argument on top of the second argument.
Definition append.hpp:183
require_all_t< is_kernel_expression_and_not_scalar< Types >... > require_all_kernel_expressions_and_none_scalar_t
Enables a template if all given types are non-scalar types that are a valid kernel generator expressi...
typename std::common_type_t< typename std::remove_reference_t< Types >::Scalar... > common_scalar_t
Wrapper for std::common_type_t
void invalid_argument(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw an invalid_argument exception with a consistently formatted message.
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
STL namespace.
Parts of an OpenCL kernel, generated by an expression.