1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_BINARY_OPERATOR_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_BINARY_OPERATOR_HPP
34template <
typename Derived,
typename T_res,
typename T_a,
typename T_b>
54 const std::string function =
"binary_operator" + op;
61 "columns of ",
"b", b.cols());
75 const std::string& col_index_name,
76 const bool view_handled,
77 const std::string& var_name_a,
78 const std::string& var_name_b)
const {
80 res.
body = type_str<Scalar>() +
" " +
var_name_ +
" = " + var_name_a +
" "
81 +
op_ +
" " + var_name_b +
";\n";
104#define ADD_BINARY_OPERATION(class_name, function_name, scalar_type_expr, \
106 template <typename T_a, typename T_b> \
107 class class_name : public binary_operation<class_name<T_a, T_b>, \
108 scalar_type_expr, T_a, T_b> { \
110 = binary_operation<class_name<T_a, T_b>, scalar_type_expr, T_a, T_b>; \
111 using base::arguments_; \
116 class_name(T_a&& a, T_b&& b) \
117 : base(std::forward<T_a>(a), std::forward<T_b>(b), operation) {} \
118 inline auto deep_copy() const { \
119 auto&& a_copy = this->template get_arg<0>().deep_copy(); \
120 auto&& b_copy = this->template get_arg<1>().deep_copy(); \
121 return class_name<std::remove_reference_t<decltype(a_copy)>, \
122 std::remove_reference_t<decltype(b_copy)>>( \
123 std::move(a_copy), std::move(b_copy)); \
127 template <typename T_a, typename T_b, \
128 require_all_kernel_expressions_t<T_a, T_b>* = nullptr, \
129 require_any_not_arithmetic_t<T_a, T_b>* = nullptr> \
130 inline class_name<as_operation_cl_t<T_a>, as_operation_cl_t<T_b>> \
131 function_name(T_a&& a, T_b&& b) { \
132 return {as_operation_cl(std::forward<T_a>(a)), \
133 as_operation_cl(std::forward<T_b>(b))}; \
154#define ADD_BINARY_OPERATION_WITH_CUSTOM_CODE( \
155 class_name, function_name, scalar_type_expr, operation, ...) \
156 template <typename T_a, typename T_b> \
157 class class_name : public binary_operation<class_name<T_a, T_b>, \
158 scalar_type_expr, T_a, T_b> { \
160 = binary_operation<class_name<T_a, T_b>, scalar_type_expr, T_a, T_b>; \
161 using base::arguments_; \
166 class_name(T_a&& a, T_b&& b) \
167 : base(std::forward<T_a>(a), std::forward<T_b>(b), operation) {} \
168 inline auto deep_copy() const { \
169 auto&& a_copy = this->template get_arg<0>().deep_copy(); \
170 auto&& b_copy = this->template get_arg<1>().deep_copy(); \
171 return class_name<std::remove_reference_t<decltype(a_copy)>, \
172 std::remove_reference_t<decltype(b_copy)>>( \
173 std::move(a_copy), std::move(b_copy)); \
178 template <typename T_a, typename T_b, \
179 require_all_kernel_expressions_t<T_a, T_b>* = nullptr, \
180 require_any_not_arithmetic_t<T_a, T_b>* = nullptr> \
181 inline class_name<as_operation_cl_t<T_a>, as_operation_cl_t<T_b>> \
182 function_name(T_a&& a, T_b&& b) { \
183 return {as_operation_cl(std::forward<T_a>(a)), \
184 as_operation_cl(std::forward<T_b>(b))}; \
197 inline std::pair<int, int> extreme_diagonals()
const {
198 std::pair<int, int> diags0
199 = this->
template get_arg<0>().extreme_diagonals();
200 std::pair<int, int> diags1
201 = this->
template get_arg<1>().extreme_diagonals();
202 return {std::max(diags0.first, diags1.first),
203 std::min(diags0.second, diags1.second)};
207 inline std::pair<int, int> extreme_diagonals()
const {
215 "both operands to operator% must have integral scalar types!");
216 inline std::pair<int, int> extreme_diagonals()
const {
223 "<=",
inline std::pair<int, int> extreme_diagonals()
const {
229 ">=",
inline std::pair<int, int> extreme_diagonals()
const {
234 "==",
inline std::pair<int, int> extreme_diagonals()
const {
243 inline std::pair<int, int> extreme_diagonals()
const {
244 std::pair<int, int> diags0
246 std::pair<int, int> diags1
247 = this->
template get_arg<1>().extreme_diagonals();
248 return {std::max(diags0.first, diags1.first),
249 std::min(diags0.second, diags1.second)};
260template <
typename T_a,
typename T_b,
typename = require_arithmetic_t<T_a>,
261 typename = require_all_kernel_expressions_t<T_b>>
275template <
typename T_a,
typename T_b,
279 T_a&& a,
const T_b b) {
284#undef ADD_BINARY_OPERATION
285#undef ADD_BINARY_OPERATION_WITH_CUSTOM_CODE
kernel_parts generate(const std::string &row_index_name, const std::string &col_index_name, const bool view_handled, const std::string &var_name_a, const std::string &var_name_b) const
Generates kernel code for this expression.
binary_operation(T_a &&a, T_b &&b, const std::string &op)
Constructor.
Represents a binary operation in kernel generator expressions.
std::pair< int, int > extreme_diagonals() const
static constexpr int dynamic
std::tuple< Args... > arguments_
std::tuple< std::is_same< Args, void >... > view_transitivity
Base for all kernel generator operations.
require_t< std::is_arithmetic< std::decay_t< T > > > require_arithmetic_t
Requires that the type satisfies std::is_arithmetic.
elt_multiply_< as_operation_cl_t< T_a >, as_operation_cl_t< T_b > > elt_multiply(T_a &&a, T_b &&b)
subtraction_< as_operation_cl_t< T_a >, as_operation_cl_t< T_b > > subtract(T_a &&a, T_b &&b)
addition_< as_operation_cl_t< T_a >, as_operation_cl_t< T_b > > add(T_a &&a, T_b &&b)
elt_divide_< as_operation_cl_t< T_a >, as_operation_cl_t< T_b > > elt_divide(T_a &&a, T_b &&b)
#define ADD_BINARY_OPERATION_WITH_CUSTOM_CODE( class_name, function_name, scalar_type_expr, operation,...)
Defines a new binary operation in the kernel generator that needs to implement a custom function that determines the extreme diagonals.
require_all_t< is_kernel_expression< Types >... > require_all_kernel_expressions_t
Enables a template if all given types are valid kernel generator expressions.
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
std::conditional_t< std::is_lvalue_reference< T >::value, decltype(as_operation_cl< AssignOp >(std::declval< T >())), std::remove_reference_t< decltype(as_operation_cl< AssignOp >(std::declval< T >()))> > as_operation_cl_t
Type that results when converting any valid kernel generator expression into operation.
#define ADD_BINARY_OPERATION(class_name, function_name, scalar_type_expr, operation)
Defines a new binary operation in the kernel generator.
int64_t cols(const T_x &x)
Returns the number of columns in the specified kernel generator expression.
int64_t rows(const T_x &x)
Returns the number of rows in the specified kernel generator expression.
fvar< T > operator*(const fvar< T > &x, const fvar< T > &y)
Return the product of the two arguments.
typename std::common_type_t< typename std::remove_reference_t< Types >::Scalar... > common_scalar_t
Wrapper for std::common_type_t
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
typename scalar_type< T >::type scalar_type_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Parts of an OpenCL kernel, generated by an expression.