Automatic Differentiation
 
Loading...
Searching...
No Matches
as_column_vector_or_scalar.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_AS_COLUMN_VECTOR_OR_SCALAR_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_AS_COLUMN_VECTOR_OR_SCALAR_HPP
3#ifdef STAN_OPENCL
4
13#include <algorithm>
14#include <map>
15#include <string>
16#include <tuple>
17#include <type_traits>
18#include <utility>
19
20namespace stan {
21namespace math {
22
/**
 * Represents as_column_vector_or_scalar of a row or column vector in kernel
 * generator expressions. The result always has the shape of a column vector:
 * rows() is the larger and cols() the smaller of the argument's dimensions.
 * Because it derives from operation_cl_lhs, it can also appear on the left
 * hand side of an assignment.
 * @tparam T type of the argument expression
 */
33template <typename T>
35 : public operation_cl_lhs<as_column_vector_or_scalar_<T>,
36 typename std::remove_reference_t<T>::Scalar, T> {
37 public:
// Scalar type of the argument expression; this operation does not change it.
38 using Scalar = typename std::remove_reference_t<T>::Scalar;
40 using base::var_name_;
41 using base::operator=;
42
47 explicit as_column_vector_or_scalar_(T&& a) : base(std::forward<T>(a)) {
48 check_vector("as_column_vector_or_scalar", "a", a);
49 }
50
/**
 * Creates a deep copy of this expression.
 * @return a new expression built around a deep copy of the argument
 */
55 inline auto deep_copy() const {
// Deep-copy the argument; the copy is then moved into the returned expression.
56 auto&& arg_copy = this->template get_arg<0>().deep_copy();
58 std::remove_reference_t<decltype(arg_copy)>>{std::move(arg_copy)};
59 }
60
74 std::unordered_map<const void*, const char*>& generated,
75 std::unordered_map<const void*, const char*>& generated_all,
76 name_generator& name_gen, const std::string& row_index_name,
77 const std::string& col_index_name, bool view_handled) const {
78 kernel_parts res{};
// Generates kernel code for this expression and its argument. Code is only
// emitted the first time this expression is recorded in `generated`;
// otherwise empty kernel parts are returned.
79 if (generated.count(this) == 0) {
80 this->var_name_ = name_gen.generate();
81 generated[this] = "";
// The argument is indexed with row and column indices swapped at runtime
// whenever the `<var_name_>_transpose` kernel argument is nonzero.
82 std::string row_index_name_arg = row_index_name;
83 std::string col_index_name_arg = col_index_name;
84 modify_argument_indices(row_index_name_arg, col_index_name_arg);
// The argument gets a fresh `generated` map. Presumably this is because it is
// indexed with different indices than the enclosing expression, so code
// already generated for it elsewhere cannot be reused. TODO confirm.
85 std::unordered_map<const void*, const char*> generated2;
86 res = this->template get_arg<0>().get_kernel_parts(
87 generated2, generated_all, name_gen, row_index_name_arg,
88 col_index_name_arg, view_handled);
90 = this->generate(row_index_name, col_index_name, view_handled,
91 this->template get_arg<0>().var_name_);
// Kernel arguments are declared only once per kernel, tracked through
// `generated_all`; on later visits the argument declaration is dropped.
92 if (generated_all.count(this) == 0) {
93 generated_all[this] = "";
94 } else {
95 my.args = "";
96 }
97 res += my;
// Fold the accumulated body prefix into the body.
98 res.body = res.body_prefix + res.body;
99 res.body_prefix = "";
100 }
101 return res;
102 }
103
113 inline kernel_parts generate(const std::string& row_index_name,
114 const std::string& col_index_name,
115 const bool view_handled,
116 const std::string& var_name_arg) const {
117 kernel_parts res;
118 res.args = "int " + var_name_ + "_transpose, ";
119 res.body
120 = type_str<Scalar>() + " " + var_name_ + " = " + var_name_arg + ";\n";
121 return res;
122 }
123
137 std::unordered_map<const void*, const char*>& generated,
138 std::unordered_map<const void*, const char*>& generated_all,
139 name_generator& name_gen, const std::string& row_index_name,
140 const std::string& col_index_name) const {
// A variable name is assigned only on the first visit recorded in
// `generated`; the remaining code below runs on every visit.
141 if (generated.count(this) == 0) {
142 generated[this] = "";
143 this->var_name_ = name_gen.generate();
144 }
// The argument is indexed with row and column indices swapped at runtime
// whenever the `<var_name_>_transpose` kernel argument is nonzero.
145 std::string row_index_name_arg = row_index_name;
146 std::string col_index_name_arg = col_index_name;
147 modify_argument_indices(row_index_name_arg, col_index_name_arg);
// The argument is given a fresh `generated` map (presumably for the same
// index-reuse reason as in get_kernel_parts(); TODO confirm).
148 std::unordered_map<const void*, const char*> generated2;
149 kernel_parts res = this->template get_arg<0>().get_kernel_parts_lhs(
150 generated2, generated_all, name_gen, row_index_name_arg,
151 col_index_name_arg);
152 res += this->derived().generate_lhs(row_index_name, col_index_name,
153 this->template get_arg<0>().var_name_);
// Kernel arguments are declared only once per kernel, tracked through
// `generated_all`.
154 if (generated_all.count(this) == 0) {
155 generated_all[this] = "";
156 } else {
157 res.args = "";
158 }
159 return res;
160 }
161
171 inline kernel_parts generate_lhs(const std::string& i, const std::string& j,
172 const std::string& var_name_arg) const {
173 kernel_parts res;
174 res.args = "int " + var_name_ + "_transpose, ";
175 return res;
176 }
177
188 inline void set_args(
189 std::unordered_map<const void*, const char*>& generated,
190 std::unordered_map<const void*, const char*>& generated_all,
191 cl::Kernel& kernel, int& arg_num) const {
192 if (generated.count(this) == 0) {
193 generated[this] = "";
194 std::unordered_map<const void*, const char*> generated2;
195 this->template get_arg<0>().set_args(generated2, generated_all, kernel,
196 arg_num);
197 if (generated_all.count(this) == 0) {
198 generated_all[this] = "";
199 kernel.setArg(arg_num++,
200 static_cast<int>(this->template get_arg<0>().rows()
201 < this->template get_arg<0>().cols()));
202 }
203 }
204 }
205
212 inline void modify_argument_indices(std::string& row_index_name,
213 std::string& col_index_name) const {
214 std::string row_index_name2 = "(" + var_name_ + "_transpose ? "
215 + col_index_name + " : " + row_index_name
216 + ")";
217 col_index_name = "(" + var_name_ + "_transpose ? " + row_index_name + " : "
218 + col_index_name + ")";
219 row_index_name = std::move(row_index_name2);
220 }
221
227 inline int rows() const {
228 return std::max(this->template get_arg<0>().rows(),
229 this->template get_arg<0>().cols());
230 }
231
237 inline int cols() const {
238 return std::min(this->template get_arg<0>().rows(),
239 this->template get_arg<0>().cols());
240 }
241
256 inline void set_view(int bottom_diagonal, int top_diagonal,
257 int bottom_zero_diagonal, int top_zero_diagonal) const {
258 auto& arg = this->template get_arg<0>();
259 if (arg.rows() >= arg.cols()) {
260 arg.set_view(bottom_diagonal, top_diagonal, bottom_zero_diagonal,
261 top_zero_diagonal);
262 } else {
263 arg.set_view(top_diagonal, bottom_diagonal, top_zero_diagonal,
264 bottom_zero_diagonal);
265 }
266 }
267
272 inline std::pair<int, int> extreme_diagonals() const {
273 auto& arg = this->template get_arg<0>();
274 std::pair<int, int> arg_diags = arg.extreme_diagonals();
275 if (arg.rows() >= arg.cols()) {
276 return arg_diags;
277 } else {
278 return {-arg_diags.second, -arg_diags.first};
279 }
280 }
281
289 inline void check_assign_dimensions(int rows, int cols) const {
290 // use a dummy expression with same number of rows and cols to simplify the
291 // check
292 check_vector("as_column_vector_or_scalar_::check_assign_dimensions()",
293 "expression assigned to as_column_vector_or_scalar",
294 constant(0, rows, cols));
295 auto& arg = this->template get_arg<0>();
296 int arg_rows = arg.rows();
297 int arg_cols = arg.cols();
298 if (arg_rows >= arg_cols) {
299 check_size_match("block_.check_assign_dimensions", "Rows of ",
300 "check_assign_dimensions argument", arg_rows, "rows of ",
301 "expression", rows);
302 check_size_match("block_.check_assign_dimensions", "Columns of ",
303 "check_assign_dimensions argument", arg_cols,
304 "columns of ", "expression", cols);
305 } else {
306 check_size_match("block_.check_assign_dimensions", "Columns of ",
307 "check_assign_dimensions argument", arg_cols, "rows of ",
308 "expression", rows);
309 check_size_match("block_.check_assign_dimensions", "Rows of ",
310 "check_assign_dimensions argument", arg_rows,
311 "columns of ", "expression", cols);
312 }
313 }
314};
315
/**
 * as_column_vector_or_scalar of a kernel generator expression.
 * @tparam T type of the argument expression
 * @param a argument expression (expected to be a row or column vector)
 * @return expression viewing `a` as a column vector
 */
323template <typename T,
325inline auto as_column_vector_or_scalar(T&& a) {
// Convert the argument to an operation and deep-copy it, presumably so the
// returned expression does not hold references into temporaries. TODO confirm.
326 auto&& a_operation = as_operation_cl(std::forward<T>(a)).deep_copy();
328 std::remove_reference_t<decltype(a_operation)>>(std::move(a_operation));
329}
331} // namespace math
332} // namespace stan
333
334#endif
335#endif
void check_assign_dimensions(int rows, int cols) const
Checks if desired dimensions match dimensions of the argument vector.
kernel_parts get_kernel_parts(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name, bool view_handled) const
Generates kernel code for this and nested expressions.
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
void set_args(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, cl::Kernel &kernel, int &arg_num) const
Sets kernel arguments for this and nested expressions.
kernel_parts generate_lhs(const std::string &i, const std::string &j, const std::string &var_name_arg) const
Generates kernel code for this and nested expressions if this expression appears on the left hand side of an assignment.
kernel_parts get_kernel_parts_lhs(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name) const
Generates kernel code for this expression if it appears on the left hand side of an assignment.
std::pair< int, int > extreme_diagonals() const
Determine indices of extreme sub- and superdiagonals written.
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
typename std::remove_reference_t< T >::Scalar Scalar
auto deep_copy() const
Creates a deep copy of this expression.
kernel_parts generate(const std::string &row_index_name, const std::string &col_index_name, const bool view_handled, const std::string &var_name_arg) const
Generates kernel code for this expression.
void set_view(int bottom_diagonal, int top_diagonal, int bottom_zero_diagonal, int top_zero_diagonal) const
Sets the view of the underlying matrix depending on which of its parts are written to.
void modify_argument_indices(std::string &row_index_name, std::string &col_index_name) const
Swaps indices row_index_name and col_index_name for the argument expression if necessary.
Represents as_column_vector_or_scalar of a row or column vector in kernel generator expressions.
std::string generate()
Generates a unique variable name.
Unique name generator for variables used in generated kernels.
Base for all kernel generator operations that can be used on left hand side of an expression.
Derived & derived()
Casts the instance into its derived type.
auto as_column_vector_or_scalar(T &&a)
as_column_vector_or_scalar of a kernel generator expression.
auto constant(const T a, int rows, int cols)
Matrix of repeated values in kernel generator expressions.
Definition constant.hpp:130
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
require_all_t< is_kernel_expression_and_not_scalar< Types >... > require_all_kernel_expressions_and_none_scalar_t
Enables a template if all given types are non-scalar types that are valid kernel generator expressions.
fvar< T > arg(const std::complex< fvar< T > > &z)
Return the phase angle of the complex argument.
Definition arg.hpp:19
void check_vector(const char *function, const char *name, const Mat &x)
Check the input is either a row vector or column vector or a matrix with a single row or column.
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
STL namespace.
Parts of an OpenCL kernel, generated by an expression.