Automatic Differentiation
 
Loading...
Searching...
No Matches
unary_operation_cl.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_UNARY_OPERATION_CL_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_UNARY_OPERATION_CL_HPP
3#ifdef STAN_OPENCL
4
9#include <string>
10#include <type_traits>
11#include <set>
12#include <utility>
13
14namespace stan {
15namespace math {
16
/**
 * Represents a unary operation in kernel generator expressions.
 *
 * NOTE(review): this is a scraped listing with lines missing. Listing line 23
 * (presumably the `class unary_operation_cl` header) and line 28 (presumably
 * the `using base = ...;` alias that `base` below refers to) are absent, so
 * this fragment does not compile as shown. Restore it from the upstream
 * header rather than reconstructing it by hand.
 *
 * @tparam Derived type of the concrete operation deriving from this class;
 * passed as the first template argument of operation_cl
 * @tparam T type of the argument expression
 * @tparam Scal scalar type of the result, exposed as `Scalar`. The
 * operation_cl base is instantiated with T's Scalar, not with Scal, so the
 * two can differ (logical_negation_ passes bool here).
 */
22template <typename Derived, typename T, typename Scal>
24 : public operation_cl<Derived, typename std::remove_reference_t<T>::Scalar,
25 T> {
26 public:
27 using Scalar = Scal;
29 using base::var_name_;
30
/**
 * Constructor.
 * @param a argument expression; perfect-forwarded to the base
 * @param op operator text emitted directly before the argument's variable
 * name in the generated kernel code (e.g. "!" or "-")
 */
36 unary_operation_cl(T&& a, const std::string& op)
37 : base(std::forward<T>(a)), op_(op) {}
38
/**
 * Generates kernel code for this expression.
 *
 * Emits a single statement into the kernel body:
 * `<Scalar type> <var_name_> = <op_><var_name_arg>;`
 * @param row_index_name not used by this implementation
 * @param col_index_name not used by this implementation
 * @param view_handled not used by this implementation
 * @param var_name_arg name of the variable holding the argument's value
 * @return kernel_parts whose body holds the statement above; every other
 * part is left value-initialized
 */
47 inline kernel_parts generate(const std::string& row_index_name,
48 const std::string& col_index_name,
49 const bool view_handled,
50 const std::string& var_name_arg) const {
51 kernel_parts res{};
52 res.body = type_str<Scalar>() + " " + var_name_ + " = " + op_ + var_name_arg
53 + ";\n";
54 return res;
55 }
56
57 protected:
// Operator text prepended to the argument's variable name in generated code.
58 std::string op_;
59};
60
/**
 * Represents a logical negation in kernel generator expressions.
 *
 * The result scalar type is bool (the third template argument given to
 * unary_operation_cl).
 *
 * NOTE(review): listing lines 67 (presumably the `class logical_negation_`
 * header) and 73 (presumably the `using base = ...;` alias) are missing from
 * this scraped copy.
 *
 * @tparam T type of the argument expression. Its Scalar must be integral
 * (bool counts as integral); the static_assert below enforces this.
 */
66template <typename T>
68 : public unary_operation_cl<logical_negation_<T>, T, bool> {
69 static_assert(
70 std::is_integral<typename std::remove_reference_t<T>::Scalar>::value,
71 "logical_negation: argument must be expression with integral "
72 "or boolean return type!");
74 using base::arguments_;
75
76 public:
/**
 * Constructor.
 * @param a argument expression, negated with the "!" operator in the
 * generated kernel code
 */
81 explicit logical_negation_(T&& a) : base(std::forward<T>(a), "!") {}
82
/**
 * Creates a deep copy of this expression.
 * @return a new logical_negation_ wrapping a deep copy of the argument
 */
87 inline auto deep_copy() const {
88 auto&& arg_copy = this->template get_arg<0>().deep_copy();
89 return logical_negation_<std::remove_reference_t<decltype(arg_copy)>>{
90 std::move(arg_copy)};
91 }
92
/**
 * View of a matrix that would be the result of evaluating this expression.
 *
 * This is always Entire, whatever the argument's view is. The likely reason
 * is that negating the implicit zeros outside a triangular or diagonal view
 * yields true, so no zero structure survives.
 * @return matrix_cl_view::Entire
 */
97 inline matrix_cl_view view() const { return matrix_cl_view::Entire; }
98};
99
/**
 * NOTE(review): listing lines 108-110 are missing from this scraped copy.
 * Only the first template line and the tail of the return statement remain.
 *
 * Presumably this is the `operator!` overload that wraps its argument in
 * logical_negation_ via as_operation_cl, possibly constrained with
 * require_all_kernel_expressions_and_none_scalar_t. Confirm against the
 * upstream header before editing.
 */
107template <typename T,
111 as_operation_cl(std::forward<T>(a)));
112}
113
/**
 * Represents a unary minus operation in kernel generator expressions.
 *
 * The result scalar type is the same as the argument's Scalar.
 *
 * NOTE(review): listing line 120 (presumably the `class unary_minus_` header)
 * and line 123 (presumably the start of `using base = unary_operation_cl<...`,
 * which line 124 continues) are missing from this scraped copy.
 *
 * @tparam T type of the argument expression
 */
119template <typename T>
121 : public unary_operation_cl<unary_minus_<T>, T,
122 typename std::remove_reference_t<T>::Scalar> {
124 typename std::remove_reference_t<T>::Scalar>;
125 using base::arguments_;
126
127 public:
/**
 * Constructor.
 * @param a argument expression, negated with the "-" operator in the
 * generated kernel code
 */
132 explicit unary_minus_(T&& a) : base(std::forward<T>(a), "-") {}
133
/**
 * Creates a deep copy of this expression.
 * @return a new unary_minus_ wrapping a deep copy of the argument
 */
138 inline auto deep_copy() const {
139 auto&& arg_copy = this->template get_arg<0>().deep_copy();
140 return unary_minus_<std::remove_reference_t<decltype(arg_copy)>>{
141 std::move(arg_copy)};
142 }
143
/**
 * View of a matrix that would be the result of evaluating this expression.
 *
 * Forwards the argument's view unchanged. Negation maps zero to zero, so the
 * argument's zero structure is preserved.
 * @return view of the argument expression
 */
148 inline matrix_cl_view view() const {
149 return this->template get_arg<0>().view();
150 }
151};
152
/**
 * NOTE(review): listing lines 161-163 are missing from this scraped copy.
 * Only the first template line and the tail of the return statement remain.
 *
 * Presumably this is the unary `operator-` overload that wraps its argument
 * in unary_minus_ via as_operation_cl. Confirm against the upstream header
 * before editing.
 */
160template <typename T,
164 as_operation_cl(std::forward<T>(a)));
165}
166
167} // namespace math
168} // namespace stan
169
170#endif
171#endif
auto deep_copy() const
Creates a deep copy of this expression.
matrix_cl_view view() const
View of a matrix that would be the result of evaluating this expression.
Represents a logical negation in kernel generator expressions.
std::tuple< Args... > arguments_
Base for all kernel generator operations.
auto deep_copy() const
Creates a deep copy of this expression.
matrix_cl_view view() const
View of a matrix that would be the result of evaluating this expression.
Represents a unary minus operation in kernel generator expressions.
unary_operation_cl(T &&a, const std::string &op)
Constructor.
kernel_parts generate(const std::string &row_index_name, const std::string &col_index_name, const bool view_handled, const std::string &var_name_arg) const
Generates kernel code for this expression.
Represents a unary operation in kernel generator expressions.
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
require_all_t< is_kernel_expression_and_not_scalar< Types >... > require_all_kernel_expressions_and_none_scalar_t
Enables a template if all given types are non-scalar types that are valid kernel generator expressions.
fvar< T > operator-(const fvar< T > &x1, const fvar< T > &x2)
Return the difference of the specified arguments.
bool operator!(const fvar< T > &x)
Return the negation of the value of the argument as defined by !.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
STL namespace.
Parts of an OpenCL kernel, generated by an expression.