Automatic Differentiation
 
Loading...
Searching...
No Matches
optional_broadcast.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_OPTIONALBROADCAST_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_OPTIONALBROADCAST_HPP
3#ifdef STAN_OPENCL
4
11#include <limits>
12#include <string>
13#include <type_traits>
14#include <map>
15#include <utility>
16
17namespace stan {
18namespace math {
19
/**
 * Represents an optional broadcasting operation in kernel generator
 * expressions.
 *
 * The listing had lost the class-head line (28) and the right-hand side of
 * the `base` alias (34); both are restored here from the base specifier and
 * from the constructor's use of `base(...)`.
 *
 * @tparam T type of the wrapped argument expression
 * @tparam Colwise whether broadcasting along columns (over rows) is enabled
 * @tparam Rowwise whether broadcasting along rows (over columns) is enabled
 */
27template <typename T, bool Colwise, bool Rowwise>
28class optional_broadcast_
29 : public operation_cl<optional_broadcast_<T, Colwise, Rowwise>,
30 typename std::remove_reference_t<T>::Scalar, T> {
31 public:
32 using Scalar = typename std::remove_reference_t<T>::Scalar;
33 using base
34 = operation_cl<optional_broadcast_<T, Colwise, Rowwise>, Scalar, T>;
35 using base::var_name_;
36
/**
 * Constructor. Forwards the argument expression to the `base` constructor.
 * @param a expression to (optionally) broadcast
 */
41 explicit optional_broadcast_(T&& a) : base(std::forward<T>(a)) {}
42
/**
 * Creates a deep copy of this expression.
 *
 * Calls `deep_copy()` on argument 0 and wraps the result in a new
 * `optional_broadcast_` with the same `Colwise` / `Rowwise` flags.
 * @return copy of \c *this
 */
47 inline auto deep_copy() const {
48 auto&& arg_copy = this->template get_arg<0>().deep_copy();
// The wrapper is parameterized on the non-reference type of the copy and
// takes the copy by move.
49 return optional_broadcast_<std::remove_reference_t<decltype(arg_copy)>,
50 Colwise, Rowwise>{std::move(arg_copy)};
51 }
52
/**
 * Generates kernel code for this expression.
 *
 * The body declares a variable named `var_name_` of type `type_str<Scalar>()`
 * and assigns it the argument's variable. One `int` kernel argument is added
 * per enabled broadcast direction.
 * @param row_idx_name row index variable name (not used in this body)
 * @param col_idx_name column index variable name (not used in this body)
 * @param view_handled whether the caller already handled the view (not used
 * in this body)
 * @param var_name_arg name of the variable holding the argument's value
 * @return parts of the kernel for this expression
 */
62 inline kernel_parts generate(const std::string& row_idx_name,
63 const std::string& col_idx_name,
64 const bool view_handled,
65 const std::string& var_name_arg) const {
66 kernel_parts res;
67 res.body
68 += type_str<Scalar>() + " " + var_name_ + " = " + var_name_arg + ";\n";
// The flag names below must match the ones used in modify_argument_indices().
69 if (Colwise) {
70 res.args += "int " + var_name_ + "is_multirow, ";
71 }
72 if (Rowwise) {
73 res.args += "int " + var_name_ + "is_multicol, ";
74 }
75 return res;
76 }
77
/**
 * Rewrites the index expressions passed on to the argument.
 *
 * Each enabled direction multiplies its index by the matching
 * `is_multirow` / `is_multicol` kernel argument, so the index becomes 0
 * whenever that flag is 0.
 * @param[in, out] row_idx_name row index expression
 * @param[in, out] col_idx_name column index expression
 */
83 inline void modify_argument_indices(std::string& row_idx_name,
84 std::string& col_idx_name) const {
85 if (Colwise) {
86 row_idx_name = "(" + row_idx_name + " * " + var_name_ + "is_multirow)";
87 }
88 if (Rowwise) {
89 col_idx_name = "(" + col_idx_name + " * " + var_name_ + "is_multicol)";
90 }
91 }
92
/**
 * Sets kernel arguments for this and nested expressions.
 *
 * Does nothing if `this` is already a key in `generated`. Otherwise the
 * argument's kernel arguments are set first, followed by the `int` flags
 * for the enabled broadcast directions (rows flag before columns flag, the
 * same order in which generate() declares them).
 * @param[in,out] generated map of expressions already handled in this scope
 * @param[in,out] generated_all map passed through unchanged to the argument
 * @param kernel kernel to set arguments on
 * @param[in,out] arg_num consecutive number of the next argument to set;
 * incremented once per flag set here
 */
// NOTE(review): std::unordered_map is used but <unordered_map> is not among
// the includes visible in this listing - confirm it is included (directly or
// transitively) in the real header.
103 inline void set_args(
104 std::unordered_map<const void*, const char*>& generated,
105 std::unordered_map<const void*, const char*>& generated_all,
106 cl::Kernel& kernel, int& arg_num) const {
107 if (generated.count(this) == 0) {
108 generated[this] = "";
// The argument gets a fresh, empty map instead of `generated`.
109 std::unordered_map<const void*, const char*> generated2;
110 this->template get_arg<0>().set_args(generated2, generated_all, kernel,
111 arg_num);
// Flag is 1 when the argument has more than one row/column (or none), and 0
// when it has exactly one.
112 if (Colwise) {
113 kernel.setArg(arg_num++, static_cast<int>(
114 this->template get_arg<0>().rows() != 1));
115 }
116 if (Rowwise) {
117 kernel.setArg(arg_num++, static_cast<int>(
118 this->template get_arg<0>().cols() != 1));
119 }
120 }
121 }
122
/**
 * Number of rows of a matrix that would be the result of evaluating this
 * expression.
 *
 * The listing had lost the true-arm of the conditional (line 130); it is
 * restored here.
 * @return `base::dynamic` when `Colwise` is set and the argument has exactly
 * one row, otherwise the argument's number of rows
 */
128 inline int rows() const {
129 return Colwise && this->template get_arg<0>().rows() == 1
130 ? base::dynamic
131 : this->template get_arg<0>().rows();
132 }
133
/**
 * Number of columns of a matrix that would be the result of evaluating this
 * expression.
 *
 * The listing had lost the true-arm of the conditional (line 141); it is
 * restored here.
 * @return `base::dynamic` when `Rowwise` is set and the argument has exactly
 * one column, otherwise the argument's number of columns
 */
139 inline int cols() const {
140 return Rowwise && this->template get_arg<0>().cols() == 1
141 ? base::dynamic
142 : this->template get_arg<0>().cols();
143 }
144
/**
 * View of a matrix that would be the result of evaluating this expression.
 *
 * The listing had lost the bodies of both conditionals (lines 152 and 155),
 * which left the computed view identical to the argument's. They are
 * restored here: a broadcast direction that is in effect widens the view
 * with `either()`.
 * @return view of the argument, extended for each broadcast that applies
 */
149 inline matrix_cl_view view() const {
150 matrix_cl_view view = this->template get_arg<0>().view();
// A single row broadcast over all rows also fills the lower triangle
// (bottom_diagonal() reports the minimum in the same case).
151 if (Colwise && this->template get_arg<0>().rows() == 1) {
152 view = either(view, matrix_cl_view::Lower);
153 }
// A single column broadcast over all columns also fills the upper triangle
// (top_diagonal() reports the maximum in the same case).
154 if (Rowwise && this->template get_arg<0>().cols() == 1) {
155 view = either(view, matrix_cl_view::Upper);
156 }
157 return view;
158 }
159
/**
 * Determine index of bottom diagonal written.
 * @return the smallest `int` when `Colwise` is set and the argument has
 * exactly one row, otherwise the argument's bottom diagonal
 */
164 inline int bottom_diagonal() const {
165 if (Colwise && this->template get_arg<0>().rows() == 1) {
166 return std::numeric_limits<int>::min();
167 } else {
168 return this->template get_arg<0>().bottom_diagonal();
169 }
170 }
171
/**
 * Determine index of top diagonal written.
 * @return the largest `int` when `Rowwise` is set and the argument has
 * exactly one column, otherwise the argument's top diagonal
 */
176 inline int top_diagonal() const {
177 if (Rowwise && this->template get_arg<0>().cols() == 1) {
178 return std::numeric_limits<int>::max();
179 } else {
180 return this->template get_arg<0>().top_diagonal();
181 }
182 }
183};
184
/**
 * Broadcast an expression in specified dimension(s) if the size along that
 * dimension equals 1.
 *
 * The listing had lost the closing template parameter (line 201) and the
 * function name with its parameter list (line 203); both are restored here.
 * @tparam Colwise whether to enable broadcasting along columns
 * @tparam Rowwise whether to enable broadcasting along rows
 * @tparam T type of the input expression
 * @param a input expression
 * @return optional broadcast wrapping a deep copy of the input, converted to
 * an operation
 */
200template <bool Colwise, bool Rowwise, typename T,
201 typename = require_all_kernel_expressions_and_none_scalar_t<T>>
202inline optional_broadcast_<as_operation_cl_t<T>, Colwise, Rowwise>
203optional_broadcast(T&& a) {
// The converted operation is deep-copied before being moved into the result.
204 auto&& a_operation = as_operation_cl(std::forward<T>(a)).deep_copy();
205 return optional_broadcast_<as_operation_cl_t<T>, Colwise, Rowwise>(
206 std::move(a_operation));
207}
208
/**
 * Broadcast an expression in the rowwise dimension if the number of columns
 * equals 1.
 *
 * The listing had lost the closing template parameter (line 222); it is
 * restored here.
 * @tparam T type of the input expression
 * @param a input expression
 * @return result of `optional_broadcast<false, true>` applied to the input
 */
221template <typename T,
222 typename = require_all_kernel_expressions_and_none_scalar_t<T>>
223inline auto rowwise_optional_broadcast(T&& a) {
224 return optional_broadcast<false, true>(std::forward<T>(a));
225}
226
/**
 * Broadcast an expression in the colwise dimension if the number of rows
 * equals 1.
 *
 * The listing had lost the closing template parameter (line 240); it is
 * restored here.
 * @tparam T type of the input expression
 * @param a input expression
 * @return result of `optional_broadcast<true, false>` applied to the input
 */
239template <typename T,
240 typename = require_all_kernel_expressions_and_none_scalar_t<T>>
241inline auto colwise_optional_broadcast(T&& a) {
242 return optional_broadcast<true, false>(std::forward<T>(a));
243}
244
245} // namespace math
246} // namespace stan
247
248#endif
249#endif
static constexpr int dynamic
Base for all kernel generator operations.
typename std::remove_reference_t< T >::Scalar Scalar
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
int bottom_diagonal() const
Determine index of bottom diagonal written.
matrix_cl_view view() const
View of a matrix that would be the result of evaluating this expression.
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
kernel_parts generate(const std::string &row_idx_name, const std::string &col_idx_name, const bool view_handled, const std::string &var_name_arg) const
Generates kernel code for this and nested expressions.
void modify_argument_indices(std::string &row_idx_name, std::string &col_idx_name) const
Sets index/indices along broadcasted dimension(s) to 0.
int top_diagonal() const
Determine index of top diagonal written.
auto deep_copy() const
Creates a deep copy of this expression.
void set_args(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, cl::Kernel &kernel, int &arg_num) const
Sets kernel arguments for this and nested expressions.
Represents an optional broadcasting operation in kernel generator expressions.
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
require_all_t< is_kernel_expression_and_not_scalar< Types >... > require_all_kernel_expressions_and_none_scalar_t
Enables a template if all given types are non-scalar types that are a valid kernel generator expressi...
const matrix_cl_view either(const matrix_cl_view left_view, const matrix_cl_view right_view)
Determines which parts are nonzero in any of the input views.
optional_broadcast_< as_operation_cl_t< T >, Colwise, Rowwise > optional_broadcast(T &&a)
Broadcast an expression in specified dimension(s) if the size along that dimension equals 1.
auto colwise_optional_broadcast(T &&a)
Broadcast an expression in the colwise dimension if the number of rows equals 1.
auto rowwise_optional_broadcast(T &&a)
Broadcast an expression in the rowwise dimension if the number of columns equals 1.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
STL namespace.
Parts of an OpenCL kernel, generated by an expression.