Automatic Differentiation
 
Loading...
Searching...
No Matches
calc_if.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_CALC_IF_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_CALC_IF_HPP
3#ifdef STAN_OPENCL
4
11#include <string>
12#include <type_traits>
13#include <map>
14#include <utility>
15
16namespace stan {
17namespace math {
18
28template <bool Do_Calculate, typename T>
30 : public operation_cl<calc_if_<Do_Calculate, T>,
31 typename std::remove_reference_t<T>::Scalar, T> {
32 public:
33 using Scalar = typename std::remove_reference_t<T>::Scalar;
35 using base::var_name_;
36
41 explicit calc_if_(T&& a) : base(std::forward<T>(a)) {}
42
43 inline kernel_parts generate(const std::string& row_index_name,
44 const std::string& col_index_name,
45 const bool view_handled,
46 const std::string& var_name_arg) const {
47 if (Do_Calculate) {
48 var_name_ = var_name_arg;
49 }
50 return {};
51 }
52
67 template <typename T_result>
69 std::unordered_map<const void*, const char*>& generated,
70 std::unordered_map<const void*, const char*>& generated_all,
71 name_generator& ng, const std::string& row_index_name,
72 const std::string& col_index_name, const T_result& result) const {
73 if (Do_Calculate) {
74 return this->template get_arg<0>().get_whole_kernel_parts(
75 generated, generated_all, ng, row_index_name, col_index_name, result);
76 } else {
77 return {};
78 }
79 }
80
91 inline void set_args(
92 std::unordered_map<const void*, const char*>& generated,
93 std::unordered_map<const void*, const char*>& generated_all,
94 cl::Kernel& kernel, int& arg_num) const {
95 if (Do_Calculate) {
96 this->template get_arg<0>().set_args(generated, generated_all, kernel,
97 arg_num);
98 }
99 }
100
105 inline int thread_rows() const {
106 return this->template get_arg<0>().thread_rows();
107 }
108
113 inline int thread_cols() const {
114 return this->template get_arg<0>().thread_cols();
115 }
116};
117
118template <bool Do_Calculate, typename T,
120 std::enable_if_t<Do_Calculate>* = nullptr>
123 as_operation_cl(std::forward<T>(a)));
124}
125
126template <bool Do_Calculate, typename T,
127 std::enable_if_t<!Do_Calculate>* = nullptr>
130}
131
132namespace internal {
133template <typename T>
134struct is_without_output_impl : std::false_type {};
135
136template <typename T>
137struct is_without_output_impl<calc_if_<false, T>> : std::true_type {};
138} // namespace internal
139
140template <typename T>
143} // namespace math
144} // namespace stan
145
146#endif
147#endif // STAN_MATH_OPENCL_KERNEL_GENERATOR_calc_if_HPP
calc_if_(T &&a)
Constructor.
Definition calc_if.hpp:41
int thread_rows() const
Number of rows threads need to be launched for.
Definition calc_if.hpp:105
kernel_parts generate(const std::string &row_index_name, const std::string &col_index_name, const bool view_handled, const std::string &var_name_arg) const
Definition calc_if.hpp:43
kernel_parts get_whole_kernel_parts(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &ng, const std::string &row_index_name, const std::string &col_index_name, const T_result &result) const
Generates kernel code for assigning this expression into the result expression.
Definition calc_if.hpp:68
typename std::remove_reference_t< T >::Scalar Scalar
Definition calc_if.hpp:33
void set_args(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, cl::Kernel &kernel, int &arg_num) const
Sets kernel arguments for nested expressions.
Definition calc_if.hpp:91
int thread_cols() const
Number of columns threads need to be launched for.
Definition calc_if.hpp:113
Represents a calc_if in kernel generator expressions.
Definition calc_if.hpp:31
Unique name generator for variables used in generated kernels.
Base for all kernel generator operations.
Represents a scalar in kernel generator expressions.
Definition scalar.hpp:29
calc_if_< true, as_operation_cl_t< T > > calc_if(T &&a)
Definition calc_if.hpp:121
require_all_t< is_kernel_expression< Types >... > require_all_kernel_expressions_t
Enables a template if all given types are valid kernel generator expressions.
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
STL namespace.
Parts of an OpenCL kernel, generated by an expression.