1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_AS_COLUMN_VECTOR_OR_SCALAR_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_AS_COLUMN_VECTOR_OR_SCALAR_HPP
36 typename std::remove_reference_t<T>::Scalar, T> {
38 using Scalar =
typename std::remove_reference_t<T>::Scalar;
41 using base::operator=;
56 auto&& arg_copy = this->
template get_arg<0>().deep_copy();
58 std::remove_reference_t<
decltype(arg_copy)>>{std::move(arg_copy)};
74 std::unordered_map<const void*, const char*>& generated,
75 std::unordered_map<const void*, const char*>& generated_all,
77 const std::string& col_index_name,
bool view_handled)
const {
79 if (generated.count(
this) == 0) {
82 std::string row_index_name_arg = row_index_name;
83 std::string col_index_name_arg = col_index_name;
85 std::unordered_map<const void*, const char*> generated2;
86 res = this->
template get_arg<0>().get_kernel_parts(
87 generated2, generated_all, name_gen, row_index_name_arg,
88 col_index_name_arg, view_handled);
90 = this->
generate(row_index_name, col_index_name, view_handled,
92 if (generated_all.count(
this) == 0) {
93 generated_all[
this] =
"";
98 res.
body = res.body_prefix + res.body;
114 const std::string& col_index_name,
115 const bool view_handled,
116 const std::string& var_name_arg)
const {
120 = type_str<Scalar>() +
" " +
var_name_ +
" = " + var_name_arg +
";\n";
137 std::unordered_map<const void*, const char*>& generated,
138 std::unordered_map<const void*, const char*>& generated_all,
140 const std::string& col_index_name)
const {
141 if (generated.count(
this) == 0) {
142 generated[
this] =
"";
145 std::string row_index_name_arg = row_index_name;
146 std::string col_index_name_arg = col_index_name;
148 std::unordered_map<const void*, const char*> generated2;
149 kernel_parts res = this->
template get_arg<0>().get_kernel_parts_lhs(
150 generated2, generated_all, name_gen, row_index_name_arg,
152 res += this->
derived().generate_lhs(row_index_name, col_index_name,
154 if (generated_all.count(
this) == 0) {
155 generated_all[
this] =
"";
172 const std::string& var_name_arg)
const {
189 std::unordered_map<const void*, const char*>& generated,
190 std::unordered_map<const void*, const char*>& generated_all,
191 cl::Kernel& kernel,
int& arg_num)
const {
192 if (generated.count(
this) == 0) {
193 generated[
this] =
"";
194 std::unordered_map<const void*, const char*> generated2;
195 this->
template get_arg<0>().set_args(generated2, generated_all, kernel,
197 if (generated_all.count(
this) == 0) {
198 generated_all[
this] =
"";
199 kernel.setArg(arg_num++,
200 static_cast<int>(this->
template get_arg<0>().
rows()
201 < this->
template get_arg<0>().
cols()));
213 std::string& col_index_name)
const {
214 std::string row_index_name2 =
"(" +
var_name_ +
"_transpose ? "
215 + col_index_name +
" : " + row_index_name
217 col_index_name =
"(" +
var_name_ +
"_transpose ? " + row_index_name +
" : "
218 + col_index_name +
")";
219 row_index_name = std::move(row_index_name2);
228 return std::max(this->
template get_arg<0>().
rows(),
229 this->
template get_arg<0>().
cols());
238 return std::min(this->
template get_arg<0>().
rows(),
239 this->
template get_arg<0>().
cols());
256 inline void set_view(
int bottom_diagonal,
int top_diagonal,
257 int bottom_zero_diagonal,
int top_zero_diagonal)
const {
258 auto&
arg = this->
template get_arg<0>();
259 if (
arg.rows() >=
arg.cols()) {
260 arg.set_view(bottom_diagonal, top_diagonal, bottom_zero_diagonal,
263 arg.set_view(top_diagonal, bottom_diagonal, top_zero_diagonal,
264 bottom_zero_diagonal);
273 auto&
arg = this->
template get_arg<0>();
274 std::pair<int, int> arg_diags =
arg.extreme_diagonals();
275 if (
arg.rows() >=
arg.cols()) {
278 return {-arg_diags.second, -arg_diags.first};
292 check_vector(
"as_column_vector_or_scalar_::check_assign_dimensions()",
293 "expression assigned to as_column_vector_or_scalar",
295 auto&
arg = this->
template get_arg<0>();
296 int arg_rows =
arg.rows();
297 int arg_cols =
arg.cols();
298 if (arg_rows >= arg_cols) {
300 "check_assign_dimensions argument", arg_rows,
"rows of ",
303 "check_assign_dimensions argument", arg_cols,
304 "columns of ",
"expression",
cols);
307 "check_assign_dimensions argument", arg_cols,
"rows of ",
310 "check_assign_dimensions argument", arg_rows,
311 "columns of ",
"expression",
cols);
328 std::remove_reference_t<
decltype(a_operation)>>(std::move(a_operation));
void check_assign_dimensions(int rows, int cols) const
Checks if desired dimensions match dimensions of the argument vector.
kernel_parts get_kernel_parts(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name, bool view_handled) const
Generates kernel code for this and nested expressions.
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
void set_args(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, cl::Kernel &kernel, int &arg_num) const
Sets kernel arguments for this and nested expressions.
kernel_parts generate_lhs(const std::string &i, const std::string &j, const std::string &var_name_arg) const
Generates kernel code for this and nested expressions if this expression appears on the left hand side of an assignment.
as_column_vector_or_scalar_(T &&a)
Constructor.
kernel_parts get_kernel_parts_lhs(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name) const
Generates kernel code for this expression if it appears on the left hand side of an assignment.
std::pair< int, int > extreme_diagonals() const
Determine indices of extreme sub- and superdiagonals written.
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
typename std::remove_reference_t< T >::Scalar Scalar
auto deep_copy() const
Creates a deep copy of this expression.
kernel_parts generate(const std::string &row_index_name, const std::string &col_index_name, const bool view_handled, const std::string &var_name_arg) const
Generates kernel code for this expression.
void set_view(int bottom_diagonal, int top_diagonal, int bottom_zero_diagonal, int top_zero_diagonal) const
Sets the view of the underlying matrix depending on which of its parts are written to.
void modify_argument_indices(std::string &row_index_name, std::string &col_index_name) const
Swaps indices row_index_name and col_index_name for the argument expression if necessary.
Represents as_column_vector_or_scalar of a row or column vector in kernel generator expressions.
std::string generate()
Generates a unique variable name.
Unique name generator for variables used in generated kernels.
Base for all kernel generator operations that can be used on left hand side of an expression.
Derived & derived()
Casts the instance into its derived type.
auto as_column_vector_or_scalar(T &&a)
as_column_vector_or_scalar of a kernel generator expression.
auto constant(const T a, int rows, int cols)
Matrix of repeated values in kernel generator expressions.
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
require_all_t< is_kernel_expression_and_not_scalar< Types >... > require_all_kernel_expressions_and_none_scalar_t
Enables a template if all given types are non-scalar types that are valid kernel generator expressions.
fvar< T > arg(const std::complex< fvar< T > > &z)
Return the phase angle of the complex argument.
void check_vector(const char *function, const char *name, const Mat &x)
Check the input is either a row vector or column vector or a matrix with a single row or column.
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Parts of an OpenCL kernel, generated by an expression.