Automatic Differentiation
 
Loading...
Searching...
No Matches
binary_operation.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_BINARY_OPERATOR_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_BINARY_OPERATOR_HPP
3#ifdef STAN_OPENCL
4
14#include <algorithm>
15#include <string>
16#include <tuple>
17#include <type_traits>
18#include <utility>
19
20namespace stan {
21namespace math {
22
34template <typename Derived, typename T_res, typename T_a, typename T_b>
35class binary_operation : public operation_cl<Derived, T_res, T_a, T_b> {
36 public:
37 using Scalar = T_res;
39 using base::var_name_;
40
41 protected:
42 std::string op_;
43 using base::arguments_;
44
45 public:
52 binary_operation(T_a&& a, T_b&& b, const std::string& op) // NOLINT
53 : base(std::forward<T_a>(a), std::forward<T_b>(b)), op_(op) {
54 const std::string function = "binary_operator" + op;
55 if (a.rows() != base::dynamic && b.rows() != base::dynamic) {
56 check_size_match(function.c_str(), "Rows of ", "a", a.rows(), "rows of ",
57 "b", b.rows());
58 }
59 if (a.cols() != base::dynamic && b.cols() != base::dynamic) {
60 check_size_match(function.c_str(), "Columns of ", "a", a.cols(),
61 "columns of ", "b", b.cols());
62 }
63 }
64
74 inline kernel_parts generate(const std::string& row_index_name,
75 const std::string& col_index_name,
76 const bool view_handled,
77 const std::string& var_name_a,
78 const std::string& var_name_b) const {
79 kernel_parts res{};
80 res.body = type_str<Scalar>() + " " + var_name_ + " = " + var_name_a + " "
81 + op_ + " " + var_name_b + ";\n";
82 return res;
83 }
84};
85
// Expands to a single comma. A comma that is not enclosed in parentheses
// splits a macro argument in two, so where a macro argument (such as the
// scalar_type_expr of the ADD_BINARY_OPERATION* macros) needs a comma, e.g.
// presumably between template arguments, it is written as COMMA instead.
// Undefined again at the end of this file.
90#define COMMA ,
91
/**
 * Defines a new binary operation in the kernel generator.
 *
 * Expands to two things:
 *  - a class template `class_name<T_a, T_b>` derived from binary_operation,
 *    with result scalar type `scalar_type_expr` and operator `operation`,
 *    exposing rows(), cols() and deep_copy();
 *  - a free function `function_name(a, b)` that wraps both arguments with
 *    as_operation_cl and builds a `class_name` from them. It participates in
 *    overload resolution only when both arguments are kernel expressions and
 *    at least one of them is not arithmetic.
 *
 * @param class_name name of the generated class template
 * @param function_name name of the generated function (may be an operator)
 * @param scalar_type_expr result scalar type; may refer to T_a and T_b. A
 * comma inside it must be written as COMMA, otherwise the preprocessor splits
 * the argument.
 * @param operation operator string placed between the operands in the
 * generated kernel code
 */
#define ADD_BINARY_OPERATION(class_name, function_name, scalar_type_expr,     \
                             operation)                                       \
  template <typename T_a, typename T_b>                                       \
  class class_name : public binary_operation<class_name<T_a, T_b>,           \
                                             scalar_type_expr, T_a, T_b> {    \
   private:                                                                   \
    using base                                                                \
        = binary_operation<class_name<T_a, T_b>, scalar_type_expr, T_a, T_b>; \
    using base::arguments_;                                                   \
                                                                              \
   public:                                                                    \
    using base::cols;                                                         \
    using base::rows;                                                         \
                                                                              \
    class_name(T_a&& a, T_b&& b) /* NOLINT */                                 \
        : base(std::forward<T_a>(a), std::forward<T_b>(b), operation) {}      \
                                                                              \
    /* Rebuilds this operation on deep copies of both operands. */            \
    inline auto deep_copy() const {                                           \
      auto&& lhs_copy = this->template get_arg<0>().deep_copy();              \
      auto&& rhs_copy = this->template get_arg<1>().deep_copy();              \
      using lhs_copy_t = std::remove_reference_t<decltype(lhs_copy)>;         \
      using rhs_copy_t = std::remove_reference_t<decltype(rhs_copy)>;         \
      return class_name<lhs_copy_t, rhs_copy_t>(std::move(lhs_copy),          \
                                                std::move(rhs_copy));         \
    }                                                                         \
  };                                                                          \
                                                                              \
  /* Builds a class_name from any two kernel expressions. */                  \
  template <typename T_a, typename T_b,                                       \
            require_all_kernel_expressions_t<T_a, T_b>* = nullptr,            \
            require_any_not_arithmetic_t<T_a, T_b>* = nullptr>                \
  inline class_name<as_operation_cl_t<T_a>, as_operation_cl_t<T_b>>           \
  function_name(T_a&& a, T_b&& b) { /* NOLINT */                              \
    return {as_operation_cl(std::forward<T_a>(a)),                            \
            as_operation_cl(std::forward<T_b>(b))};                           \
  }
135
/**
 * Defines a new binary operation in the kernel generator whose class needs
 * additional members, such as a custom extreme_diagonals() or a
 * view_transitivity type.
 *
 * Generates the same kind of class template and free function as
 * ADD_BINARY_OPERATION. In addition, everything passed after `operation` is
 * inserted verbatim into the public section of the generated class, after
 * deep_copy(). Inside that code, rows() and cols() are available unqualified
 * and the operands can be reached with this->template get_arg<N>().
 *
 * @param class_name name of the generated class template
 * @param function_name name of the generated function (may be an operator)
 * @param scalar_type_expr result scalar type; may refer to T_a and T_b. A
 * comma inside it must be written as COMMA, otherwise the preprocessor splits
 * the argument.
 * @param operation operator string placed between the operands in the
 * generated kernel code
 * @param ... member declarations/definitions added to the generated class
 */
#define ADD_BINARY_OPERATION_WITH_CUSTOM_CODE(                                \
    class_name, function_name, scalar_type_expr, operation, ...)              \
  template <typename T_a, typename T_b>                                       \
  class class_name : public binary_operation<class_name<T_a, T_b>,           \
                                             scalar_type_expr, T_a, T_b> {    \
   private:                                                                   \
    using base                                                                \
        = binary_operation<class_name<T_a, T_b>, scalar_type_expr, T_a, T_b>; \
    using base::arguments_;                                                   \
                                                                              \
   public:                                                                    \
    using base::cols;                                                         \
    using base::rows;                                                         \
                                                                              \
    class_name(T_a&& a, T_b&& b) /* NOLINT */                                 \
        : base(std::forward<T_a>(a), std::forward<T_b>(b), operation) {}      \
                                                                              \
    /* Rebuilds this operation on deep copies of both operands. */            \
    inline auto deep_copy() const {                                           \
      auto&& lhs_copy = this->template get_arg<0>().deep_copy();              \
      auto&& rhs_copy = this->template get_arg<1>().deep_copy();              \
      using lhs_copy_t = std::remove_reference_t<decltype(lhs_copy)>;         \
      using rhs_copy_t = std::remove_reference_t<decltype(rhs_copy)>;         \
      return class_name<lhs_copy_t, rhs_copy_t>(std::move(lhs_copy),          \
                                                std::move(rhs_copy));         \
    }                                                                         \
                                                                              \
    /* Caller-supplied members. */                                            \
    __VA_ARGS__                                                               \
  };                                                                          \
                                                                              \
  /* Builds a class_name from any two kernel expressions. */                  \
  template <typename T_a, typename T_b,                                       \
            require_all_kernel_expressions_t<T_a, T_b>* = nullptr,            \
            require_any_not_arithmetic_t<T_a, T_b>* = nullptr>                \
  inline class_name<as_operation_cl_t<T_a>, as_operation_cl_t<T_b>>           \
  function_name(T_a&& a, T_b&& b) { /* NOLINT */                              \
    return {as_operation_cl(std::forward<T_a>(a)),                            \
            as_operation_cl(std::forward<T_b>(b))};                           \
  }
186
187ADD_BINARY_OPERATION(addition_operator_, operator+,
193 "-");
196 using view_transitivity = std::tuple<std::true_type, std::true_type>;
197 inline std::pair<int, int> extreme_diagonals() const {
198 std::pair<int, int> diags0
199 = this->template get_arg<0>().extreme_diagonals();
200 std::pair<int, int> diags1
201 = this->template get_arg<1>().extreme_diagonals();
202 return {std::max(diags0.first, diags1.first),
203 std::min(diags0.second, diags1.second)};
204 });
207 inline std::pair<int, int> extreme_diagonals() const {
208 return {-rows() + 1, cols() - 1};
209 });
212 static_assert(
213 std::is_integral<scalar_type_t<T_a>>::value&&
214 std::is_integral<scalar_type_t<T_b>>::value,
215 "both operands to operator% must have integral scalar types!");
216 inline std::pair<int, int> extreme_diagonals() const {
217 return {-rows() + 1, cols() - 1};
218 });
219
// less_than_ / operator<: emits `<var> = <a> < <b>` into the kernel; the
// result scalar type is bool.
// NOTE(review): unlike operator<= below, no custom extreme_diagonals() is
// supplied -- presumably because 0 < 0 is false, so the inherited default
// stays valid; confirm.
220ADD_BINARY_OPERATION(less_than_, operator<, bool, "<");
222 less_than_or_equal_, operator<=, bool,
223 "<=", inline std::pair<int, int> extreme_diagonals() const {
224 return {-rows() + 1, cols() - 1};
225 });
// greater_than_ / operator>: emits `<var> = <a> > <b>` into the kernel; the
// result scalar type is bool.
// NOTE(review): unlike operator>= below, no custom extreme_diagonals() is
// supplied -- presumably because 0 > 0 is false, so the inherited default
// stays valid; confirm.
226ADD_BINARY_OPERATION(greater_than_, operator>, bool, ">");
228 greater_than_or_equal_, operator>=, bool,
229 ">=", inline std::pair<int, int> extreme_diagonals() const {
230 return {-rows() + 1, cols() - 1};
231 });
233 equals_, operator==, bool,
234 "==", inline std::pair<int, int> extreme_diagonals() const {
235 return {-rows() + 1, cols() - 1};
236 });
// not_equals_ / operator!=: emits `<var> = <a> != <b>` into the kernel; the
// result scalar type is bool.
237ADD_BINARY_OPERATION(not_equals_, operator!=, bool, "!=");
238
// logical_or_ / operator||: emits `<var> = <a> || <b>` into the kernel; the
// result scalar type is bool. Being an overloaded operator, this
// operator|| does not short-circuit on the host: both operand expressions
// are always evaluated when building the expression.
239ADD_BINARY_OPERATION(logical_or_, operator||, bool, "||");
241 logical_and_, operator&&, bool, "&&",
242 using view_transitivity = std::tuple<std::true_type, std::true_type>;
243 inline std::pair<int, int> extreme_diagonals() const {
244 std::pair<int, int> diags0
245 = this->template get_arg<0>().extreme_diagonals();
246 std::pair<int, int> diags1
247 = this->template get_arg<1>().extreme_diagonals();
248 return {std::max(diags0.first, diags1.first),
249 std::min(diags0.second, diags1.second)};
250 });
251
260template <typename T_a, typename T_b, typename = require_arithmetic_t<T_a>,
261 typename = require_all_kernel_expressions_t<T_b>>
263 T_a a, T_b&& b) { // NOLINT
264 return {as_operation_cl(a), as_operation_cl(std::forward<T_b>(b))};
265}
266
275template <typename T_a, typename T_b,
277 typename = require_arithmetic_t<T_b>>
279 T_a&& a, const T_b b) { // NOLINT
280 return {as_operation_cl(std::forward<T_a>(a)), as_operation_cl(b)};
281}
282
283#undef COMMA
284#undef ADD_BINARY_OPERATION
285#undef ADD_BINARY_OPERATION_WITH_CUSTOM_CODE
287} // namespace math
288} // namespace stan
289#endif
290#endif
kernel_parts generate(const std::string &row_index_name, const std::string &col_index_name, const bool view_handled, const std::string &var_name_a, const std::string &var_name_b) const
Generates kernel code for this expression.
binary_operation(T_a &&a, T_b &&b, const std::string &op)
Constructor.
Represents a binary operation in kernel generator expressions.
std::pair< int, int > extreme_diagonals() const
static constexpr int dynamic
std::tuple< Args... > arguments_
std::tuple< std::is_same< Args, void >... > view_transitivity
Base for all kernel generator operations.
require_t< std::is_arithmetic< std::decay_t< T > > > require_arithmetic_t
Requires that the type satisfies std::is_arithmetic.
elt_multiply_< as_operation_cl_t< T_a >, as_operation_cl_t< T_b > > elt_multiply(T_a &&a, T_b &&b)
subtraction_< as_operation_cl_t< T_a >, as_operation_cl_t< T_b > > subtract(T_a &&a, T_b &&b)
addition_< as_operation_cl_t< T_a >, as_operation_cl_t< T_b > > add(T_a &&a, T_b &&b)
elt_divide_< as_operation_cl_t< T_a >, as_operation_cl_t< T_b > > elt_divide(T_a &&a, T_b &&b)
#define ADD_BINARY_OPERATION_WITH_CUSTOM_CODE( class_name, function_name, scalar_type_expr, operation,...)
Defines a new binary operation in the kernel generator that needs to implement a custom function that deter...
require_all_t< is_kernel_expression< Types >... > require_all_kernel_expressions_t
Enables a template if all given types are valid kernel generator expressions.
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
std::conditional_t< std::is_lvalue_reference< T >::value, decltype(as_operation_cl< AssignOp >(std::declval< T >())), std::remove_reference_t< decltype(as_operation_cl< AssignOp >(std::declval< T >()))> > as_operation_cl_t
Type that results when converting any valid kernel generator expression into operation.
#define ADD_BINARY_OPERATION(class_name, function_name, scalar_type_expr, operation)
Defines a new binary operation in the kernel generator.
int rows(const T_x &x)
Returns the number of rows in the specified kernel generator expression.
Definition rows.hpp:21
int cols(const T_x &x)
Returns the number of columns in the specified kernel generator expression.
Definition cols.hpp:20
fvar< T > operator*(const fvar< T > &x, const fvar< T > &y)
Return the product of the two arguments.
typename std::common_type_t< typename std::remove_reference_t< Types >::Scalar... > common_scalar_t
Wrapper for std::common_type_t
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
typename scalar_type< T >::type scalar_type_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
STL namespace.
Parts of an OpenCL kernel, generated by an expression.