Automatic Differentiation
 
Loading...
Searching...
No Matches
elt_function_cl.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_ELT_FUNCTION_CL_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_ELT_FUNCTION_CL_HPP
3#ifdef STAN_OPENCL
4
34#include <array>
35#include <string>
36#include <type_traits>
37#include <set>
38#include <utility>
39
40namespace stan {
41namespace math {
42
/**
 * Represents an element-wise function in kernel generator expressions.
 *
 * The class derives from operation_cl<Derived, Scal, T...> and stores the
 * name of the function that is emitted into the generated kernel source.
 *
 * @tparam Derived type passed through to operation_cl as the derived type
 * @tparam Scal scalar type of the result; exposed as `Scalar`
 * @tparam T types of the argument expressions
 */
53template <typename Derived, typename Scal, typename... T>
54class elt_function_cl : public operation_cl<Derived, Scal, T...> {
55 public:
56 using Scalar = Scal;
57 using base = operation_cl<Derived, Scalar, T...>;
58 using base::var_name_;
59
/**
 * Constructor.
 * @param fun function name, stored in fun_ and later written into the kernel
 * @param args argument expressions, forwarded to the base class
 */
65 elt_function_cl(const std::string& fun, T&&... args) // NOLINT
66 : base(std::forward<T>(args)...), fun_(fun) {}
67
// Generates kernel code for this expression.
// NOTE(review): the line that opens this member function's signature is
// absent from this listing (numbering jumps from 66 to 77); the parameters
// below belong to `generate`, which returns kernel_parts.
// The conditional_t<false, T, const std::string&> pack expands to exactly one
// `const std::string&` per argument type in T.
77 const std::string& row_index_name, const std::string& col_index_name,
78 const bool view_handled,
79 std::conditional_t<false, T, const std::string&>... var_names_arg) const {
80 kernel_parts res{};
81
// Concatenate every entry of the derived class' `includes` into res.includes.
82 for (const char* incl : base::derived().includes) {
83 res.includes += incl;
84 }
// Build "name1, name2, ..., " by appending ", " to each argument name and
// concatenating; the trailing ", " is removed by the substr call below.
// NOTE(review): std::accumulate is declared in <numeric>, which is not among
// the standard includes visible in this listing - confirm it arrives
// transitively.
85 std::array<std::string, sizeof...(T)> var_names_arg_arr
86 = {(var_names_arg + ", ")...};
87 std::string var_names_list = std::accumulate(
88 var_names_arg_arr.begin(), var_names_arg_arr.end(), std::string());
// Emits: <Scalar type> <var_name_> = <fun_>((double)<arg names>);
// The "(double)" text is written once, directly before the first argument
// name, so for multi-argument functions it textually applies only to the
// first argument.
89 res.body = type_str<Scalar>() + " " + var_name_ + " = " + fun_ + "((double)"
90 + var_names_list.substr(0, var_names_list.size() - 2) + ");\n";
91 return res;
92 }
93
94 protected:
// Name of the function written into the generated kernel body.
95 std::string fun_;
96};
97
/**
 * Generates a class template `fun_<T1, T2>` and a free function `fun(a, b)`
 * for a binary element-wise function.
 *
 * The generated class derives from elt_function_cl with Scalar = double and
 * passes the stringified macro argument (#fun) as the function name.
 * - The constructor calls check_size_match on the rows (and, separately, the
 *   columns) of both arguments unless either side equals base::dynamic.
 * - deep_copy() deep-copies both arguments and rebuilds the expression.
 * - extreme_diagonals() returns {-rows() + 1, cols() - 1}.
 * - The static `includes` vector is initialized from the variadic macro
 *   arguments; elt_function_cl::generate appends its entries to the kernel's
 *   includes.
 * The free function is enabled through
 * require_all_kernel_expressions_t<T1, T2> and
 * require_any_not_stan_scalar_t<T1, T2>, and wraps both arguments with
 * as_operation_cl before constructing the class.
 *
 * NOTE(review): the constructor body reads a.rows()/b.rows()/a.cols()/b.cols()
 * after `a` and `b` were forwarded into the base constructor; confirm that
 * forwarded-from arguments still report their sizes.
 *
 * @param fun function name
 * @param ... sources (const char*) placed into `includes`; may be empty
 */
104#define ADD_BINARY_FUNCTION_WITH_INCLUDES(fun, ...) \
105 template <typename T1, typename T2> \
106 class fun##_ : public elt_function_cl<fun##_<T1, T2>, double, T1, T2> { \
107 using base = elt_function_cl<fun##_<T1, T2>, double, T1, T2>; \
108 using base::arguments_; \
109 \
110 public: \
111 using base::rows; \
112 using base::cols; \
113 static const std::vector<const char*> includes; \
114 explicit fun##_(T1&& a, T2&& b) \
115 : base(#fun, std::forward<T1>(a), std::forward<T2>(b)) { \
116 if (a.rows() != base::dynamic && b.rows() != base::dynamic) { \
117 check_size_match(#fun, "Rows of ", "a", a.rows(), "rows of ", "b", \
118 b.rows()); \
119 } \
120 if (a.cols() != base::dynamic && b.cols() != base::dynamic) { \
121 check_size_match(#fun, "Columns of ", "a", a.cols(), "columns of ", \
122 "b", b.cols()); \
123 } \
124 } \
125 inline auto deep_copy() const { \
126 auto&& arg1_copy = this->template get_arg<0>().deep_copy(); \
127 auto&& arg2_copy = this->template get_arg<1>().deep_copy(); \
128 return fun##_<std::remove_reference_t<decltype(arg1_copy)>, \
129 std::remove_reference_t<decltype(arg2_copy)>>{ \
130 std::move(arg1_copy), std::move(arg2_copy)}; \
131 } \
132 inline std::pair<int, int> extreme_diagonals() const { \
133 return {-rows() + 1, cols() - 1}; \
134 } \
135 }; \
136 \
137 template <typename T1, typename T2, \
138 require_all_kernel_expressions_t<T1, T2>* = nullptr, \
139 require_any_not_stan_scalar_t<T1, T2>* = nullptr> \
140 inline fun##_<as_operation_cl_t<T1>, as_operation_cl_t<T2>> fun(T1&& a, \
141 T2&& b) { \
142 return fun##_<as_operation_cl_t<T1>, as_operation_cl_t<T2>>( \
143 as_operation_cl(std::forward<T1>(a)), \
144 as_operation_cl(std::forward<T2>(b))); \
145 } \
146 template <typename T1, typename T2> \
147 const std::vector<const char*> fun##_<T1, T2>::includes{__VA_ARGS__};
148
/**
 * Generates a class template `fun_<T>` and a free function `fun(a)` for a
 * general unary element-wise function that is defined by OpenCL or in the
 * supplied include sources.
 *
 * The generated class derives from elt_function_cl with Scalar = double and
 * passes the stringified macro argument (#fun) as the function name.
 * - deep_copy() deep-copies the single argument and rebuilds the expression.
 * - extreme_diagonals() returns {-rows() + 1, cols() - 1}.
 * - The static `includes` vector is initialized from the variadic macro
 *   arguments.
 * The free function is enabled through the defaulted template parameter
 * require_all_kernel_expressions_and_none_scalar_t<T> and wraps its argument
 * with as_operation_cl.
 *
 * @param fun function name
 * @param ... sources (const char*) placed into `includes`; may be empty
 */
155#define ADD_UNARY_FUNCTION_WITH_INCLUDES(fun, ...) \
156 template <typename T> \
157 class fun##_ : public elt_function_cl<fun##_<T>, double, T> { \
158 using base = elt_function_cl<fun##_<T>, double, T>; \
159 using base::arguments_; \
160 \
161 public: \
162 using base::rows; \
163 using base::cols; \
164 static const std::vector<const char*> includes; \
165 explicit fun##_(T&& a) : base(#fun, std::forward<T>(a)) {} \
166 inline auto deep_copy() const { \
167 auto&& arg_copy = this->template get_arg<0>().deep_copy(); \
168 return fun##_<std::remove_reference_t<decltype(arg_copy)>>{ \
169 std::move(arg_copy)}; \
170 } \
171 inline std::pair<int, int> extreme_diagonals() const { \
172 return {-rows() + 1, cols() - 1}; \
173 } \
174 }; \
175 \
176 template <typename T, typename Cond \
177 = require_all_kernel_expressions_and_none_scalar_t<T>> \
178 inline fun##_<as_operation_cl_t<T>> fun(T&& a) { \
179 return fun##_<as_operation_cl_t<T>>(as_operation_cl(std::forward<T>(a))); \
180 } \
181 template <typename T> \
182 const std::vector<const char*> fun##_<T>::includes{__VA_ARGS__};
183
/**
 * Generates a class and function for a general unary function that is defined
 * by OpenCL. Expands to ADD_UNARY_FUNCTION_WITH_INCLUDES with no include
 * sources, so the generated `includes` vector is empty.
 *
 * @param fun function name
 */
189#define ADD_UNARY_FUNCTION(fun) ADD_UNARY_FUNCTION_WITH_INCLUDES(fun)
190
/**
 * Generates a class template `fun_<T>` and a free function `fun(a)` for a
 * unary element-wise function, in the variant named "pass zero".
 *
 * Differences from ADD_UNARY_FUNCTION_WITH_INCLUDES visible in this macro:
 * - the class declares `view_transitivness = std::make_tuple(true)`;
 * - the class does not define extreme_diagonals();
 * - the static `includes` vector is always empty.
 * NOTE(review): presumably the `true` flag tells the base class that the
 * argument's matrix view carries over to the result because f(0) == 0 -
 * confirm against operation_cl.
 *
 * @param fun function name
 */
197#define ADD_UNARY_FUNCTION_PASS_ZERO(fun) \
198 template <typename T> \
199 class fun##_ : public elt_function_cl<fun##_<T>, double, T> { \
200 using base = elt_function_cl<fun##_<T>, double, T>; \
201 using base::arguments_; \
202 \
203 public: \
204 using base::rows; \
205 using base::cols; \
206 static constexpr auto view_transitivness = std::make_tuple(true); \
207 static const std::vector<const char*> includes; \
208 explicit fun##_(T&& a) : base(#fun, std::forward<T>(a)) {} \
209 inline auto deep_copy() const { \
210 auto&& arg_copy = this->template get_arg<0>().deep_copy(); \
211 return fun##_<std::remove_reference_t<decltype(arg_copy)>>{ \
212 std::move(arg_copy)}; \
213 } \
214 }; \
215 \
216 template <typename T, typename Cond \
217 = require_all_kernel_expressions_and_none_scalar_t<T>> \
218 inline fun##_<as_operation_cl_t<T>> fun(T&& a) { \
219 return fun##_<as_operation_cl_t<T>>(as_operation_cl(std::forward<T>(a))); \
220 } \
221 template <typename T> \
222 const std::vector<const char*> fun##_<T>::includes{};
223
/**
 * Generates a class template `fun_<T>` and a free function `fun(a)` for a
 * unary classification function whose result scalar type is bool.
 *
 * The generated class derives from elt_function_cl with Scalar = bool,
 * declares `view_transitivness = std::make_tuple(true)`, and has an empty
 * `includes` vector. The variadic macro arguments are pasted verbatim as the
 * return expression of extreme_diagonals(), so they must form an expression
 * convertible to std::pair<int, int> that is valid inside the class.
 *
 * @param fun function name
 * @param ... expression returned by extreme_diagonals()
 */
230#define ADD_CLASSIFICATION_FUNCTION(fun, ...) \
231 template <typename T> \
232 class fun##_ : public elt_function_cl<fun##_<T>, bool, T> { \
233 using base = elt_function_cl<fun##_<T>, bool, T>; \
234 using base::arguments_; \
235 \
236 public: \
237 using base::rows; \
238 using base::cols; \
239 static constexpr auto view_transitivness = std::make_tuple(true); \
240 static const std::vector<const char*> includes; \
241 explicit fun##_(T&& a) : base(#fun, std::forward<T>(a)) {} \
242 inline auto deep_copy() const { \
243 auto&& arg_copy = this->template get_arg<0>().deep_copy(); \
244 return fun##_<std::remove_reference_t<decltype(arg_copy)>>{ \
245 std::move(arg_copy)}; \
246 } \
247 inline std::pair<int, int> extreme_diagonals() const { \
248 return __VA_ARGS__; \
249 } \
250 }; \
251 \
252 template <typename T, typename Cond \
253 = require_all_kernel_expressions_and_none_scalar_t<T>> \
254 inline fun##_<as_operation_cl_t<T>> fun(T&& a) { \
255 return fun##_<as_operation_cl_t<T>>(as_operation_cl(std::forward<T>(a))); \
256 } \
257 template <typename T> \
258 const std::vector<const char*> fun##_<T>::includes{};
259
263
267
272
285
290
296
298 opencl_kernels::digamma_device_function)
299ADD_UNARY_FUNCTION_WITH_INCLUDES(log1m, opencl_kernels::log1m_device_function)
301 opencl_kernels::log1p_exp_device_function,
302 opencl_kernels::log_inv_logit_device_function)
304 opencl_kernels::log1m_exp_device_function)
306 opencl_kernels::log1p_exp_device_function)
308 opencl_kernels::inv_square_device_function)
310 opencl_kernels::inv_logit_device_function)
311ADD_UNARY_FUNCTION_WITH_INCLUDES(logit, opencl_kernels::log1m_device_function,
312 opencl_kernels::logit_device_function)
313ADD_UNARY_FUNCTION_WITH_INCLUDES(Phi, opencl_kernels::phi_device_function)
315 opencl_kernels::inv_logit_device_function,
316 opencl_kernels::phi_approx_device_function)
317ADD_UNARY_FUNCTION_WITH_INCLUDES(inv_Phi, opencl_kernels::log1m_device_function,
318 opencl_kernels::phi_device_function,
319 opencl_kernels::inv_phi_device_function)
321 log1m_inv_logit, opencl_kernels::log1p_exp_device_function,
322 opencl_kernels::log1m_inv_logit_device_function)
324 opencl_kernels::trigamma_device_function)
326 square,
327 "\n#ifndef STAN_MATH_OPENCL_KERNELS_DEVICE_FUNCTIONS_SQUARE\n"
328 "#define STAN_MATH_OPENCL_KERNELS_DEVICE_FUNCTIONS_SQUARE\n"
329 "double square(double x){return x*x;}\n"
330 "#endif\n")
331
// Classification functions (bool result). isfinite reports the full range
// {-rows() + 1, cols() - 1} as its extreme diagonals, whereas isinf and isnan
// forward the extreme diagonals of their argument.
332ADD_CLASSIFICATION_FUNCTION(isfinite, {-rows() + 1, cols() - 1})
333ADD_CLASSIFICATION_FUNCTION(isinf,
334 this->template get_arg<0>().extreme_diagonals())
335ADD_CLASSIFICATION_FUNCTION(isnan,
336 this->template get_arg<0>().extreme_diagonals())
337
// Binary functions instantiated without any include sources.
338ADD_BINARY_FUNCTION_WITH_INCLUDES(fdim)
339ADD_BINARY_FUNCTION_WITH_INCLUDES(fmax)
340ADD_BINARY_FUNCTION_WITH_INCLUDES(fmin)
341ADD_BINARY_FUNCTION_WITH_INCLUDES(fmod)
342ADD_BINARY_FUNCTION_WITH_INCLUDES(hypot)
343ADD_BINARY_FUNCTION_WITH_INCLUDES(ldexp)
344ADD_BINARY_FUNCTION_WITH_INCLUDES(pow)
345ADD_BINARY_FUNCTION_WITH_INCLUDES(copysign)
346
// Binary functions that list device-function sources from opencl_kernels.
// Each listed source ends up in the generated class' `includes` vector in the
// order given here.
347ADD_BINARY_FUNCTION_WITH_INCLUDES(
348 beta, stan::math::opencl_kernels::beta_device_function)
349ADD_BINARY_FUNCTION_WITH_INCLUDES(
350 binomial_coefficient_log,
351 stan::math::opencl_kernels::lgamma_stirling_device_function,
352 stan::math::opencl_kernels::lgamma_stirling_diff_device_function,
353 stan::math::opencl_kernels::lbeta_device_function,
354 stan::math::opencl_kernels::binomial_coefficient_log_device_function)
355ADD_BINARY_FUNCTION_WITH_INCLUDES(
356 lbeta, stan::math::opencl_kernels::lgamma_stirling_device_function,
357 stan::math::opencl_kernels::lgamma_stirling_diff_device_function,
358 stan::math::opencl_kernels::lbeta_device_function)
359ADD_BINARY_FUNCTION_WITH_INCLUDES(
360 log_inv_logit_diff, opencl_kernels::log1p_exp_device_function,
361 opencl_kernels::log1m_exp_device_function,
362 opencl_kernels::log_inv_logit_diff_device_function)
363ADD_BINARY_FUNCTION_WITH_INCLUDES(log_diff_exp,
364 opencl_kernels::log1m_exp_device_function,
365 opencl_kernels::log_diff_exp_device_function)
366ADD_BINARY_FUNCTION_WITH_INCLUDES(
367 multiply_log, stan::math::opencl_kernels::multiply_log_device_function)
368ADD_BINARY_FUNCTION_WITH_INCLUDES(
369 lmultiply, stan::math::opencl_kernels::lmultiply_device_function)
370
// Remove the helper macros so they are not visible after this point.
371#undef ADD_BINARY_FUNCTION_WITH_INCLUDES
372#undef ADD_UNARY_FUNCTION_WITH_INCLUDES
373#undef ADD_UNARY_FUNCTION
374#undef ADD_UNARY_FUNCTION_PASS_ZERO
375#undef ADD_CLASSIFICATION_FUNCTION
376
378} // namespace math
379} // namespace stan
380#endif
381#endif
elt_function_cl(const std::string &fun, T &&... args)
Constructor.
kernel_parts generate(const std::string &row_index_name, const std::string &col_index_name, const bool view_handled, std::conditional_t< false, T, const std::string & >... var_names_arg) const
Generates kernel code for this expression.
Represents an element-wise function in kernel generator expressions.
Derived & derived()
Casts the instance into its derived type.
Base for all kernel generator operations.
rsqrt_< as_operation_cl_t< T > > rsqrt(T &&a)
#define ADD_UNARY_FUNCTION(fun)
Generates a class and function for a general unary function that is defined by OpenCL.
#define ADD_UNARY_FUNCTION_PASS_ZERO(fun)
Generates a class and function for a unary function, defined by OpenCL, with the special property that it passes through zero (f(0) = 0).
#define ADD_UNARY_FUNCTION_WITH_INCLUDES(fun,...)
Generates a class and function for a general unary function that is defined by OpenCL or in the included code.
fvar< T > acos(const fvar< T > &x)
Definition acos.hpp:15
fvar< T > sin(const fvar< T > &x)
Definition sin.hpp:14
fvar< T > acosh(const fvar< T > &x)
Definition acosh.hpp:16
fvar< T > logit(const fvar< T > &x)
Definition logit.hpp:14
fvar< T > expm1(const fvar< T > &x)
Definition expm1.hpp:13
fvar< T > atanh(const fvar< T > &x)
Return inverse hyperbolic tangent of specified value.
Definition atanh.hpp:24
fvar< T > inv_square(const fvar< T > &x)
fvar< T > exp2(const fvar< T > &x)
Definition exp2.hpp:14
constexpr double log2()
Return natural logarithm of two.
Definition log2.hpp:17
fvar< T > log1m_exp(const fvar< T > &x)
Return the natural logarithm of one minus the exponentiation of the specified argument.
Definition log1m_exp.hpp:23
fvar< T > asinh(const fvar< T > &x)
Definition asinh.hpp:15
fvar< T > cosh(const fvar< T > &x)
Definition cosh.hpp:15
fvar< T > log(const fvar< T > &x)
Definition log.hpp:15
fvar< T > erf(const fvar< T > &x)
Definition erf.hpp:15
fvar< T > Phi_approx(const fvar< T > &x)
Return an approximation of the unit normal cumulative distribution function (CDF).
fvar< T > log_inv_logit(const fvar< T > &x)
fvar< T > cbrt(const fvar< T > &x)
Return cube root of specified argument.
Definition cbrt.hpp:20
fvar< T > sinh(const fvar< T > &x)
Definition sinh.hpp:13
fvar< T > log1p_exp(const fvar< T > &x)
Definition log1p_exp.hpp:13
fvar< T > sqrt(const fvar< T > &x)
Definition sqrt.hpp:17
fvar< T > atan(const fvar< T > &x)
Definition atan.hpp:15
fvar< T > trigamma(const fvar< T > &u)
Return the value of the trigamma function at the specified argument (i.e., the second derivative of the log gamma function at the specified argument).
Definition trigamma.hpp:21
fvar< T > tan(const fvar< T > &x)
Definition tan.hpp:14
fvar< T > Phi(const fvar< T > &x)
Definition Phi.hpp:14
fvar< T > erfc(const fvar< T > &x)
Definition erfc.hpp:15
fvar< T > log1p(const fvar< T > &x)
Definition log1p.hpp:12
fvar< T > inv_Phi(const fvar< T > &p)
Definition inv_Phi.hpp:15
fvar< T > floor(const fvar< T > &x)
Definition floor.hpp:12
fvar< T > lgamma(const fvar< T > &x)
Return the natural logarithm of the gamma function applied to the specified argument.
Definition lgamma.hpp:21
static constexpr double log10()
Returns the natural logarithm of ten.
fvar< T > tanh(const fvar< T > &x)
Definition tanh.hpp:15
fvar< T > cos(const fvar< T > &x)
Definition cos.hpp:14
fvar< T > round(const fvar< T > &x)
Return the closest integer to the specified argument, with halfway cases rounded away from zero.
Definition round.hpp:24
fvar< T > tgamma(const fvar< T > &x)
Return the result of applying the gamma function to the specified argument.
Definition tgamma.hpp:21
fvar< T > asin(const fvar< T > &x)
Definition asin.hpp:15
fvar< T > inv_logit(const fvar< T > &x)
Returns the inverse logit function applied to the argument.
Definition inv_logit.hpp:20
fvar< T > ceil(const fvar< T > &x)
Definition ceil.hpp:12
fvar< T > log1m(const fvar< T > &x)
Definition log1m.hpp:12
fvar< T > log1m_inv_logit(const fvar< T > &x)
Return the natural logarithm of one minus the inverse logit of the specified argument.
fvar< T > digamma(const fvar< T > &x)
Return the derivative of the log gamma function at the specified argument.
Definition digamma.hpp:23
fvar< T > square(const fvar< T > &x)
Definition square.hpp:12
fvar< T > trunc(const fvar< T > &x)
Return the nearest integral value that is not larger in magnitude than the specified argument.
Definition trunc.hpp:20
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:13
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
STL namespace.
Parts of an OpenCL kernel, generated by an expression.