Automatic Differentiation
 
Loading...
Searching...
No Matches
elt_function_cl.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_ELT_FUNCTION_CL_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_ELT_FUNCTION_CL_HPP
3#ifdef STAN_OPENCL
4
35#include <array>
36#include <string>
37#include <type_traits>
38#include <set>
39#include <utility>
40
41namespace stan {
42namespace math {
43
/**
 * Represents an element-wise function in kernel generator expressions.
 *
 * The class stores the name of the function to call and, in `generate`,
 * emits one assignment statement that calls that function on the argument
 * strings it is handed.
 *
 * @tparam Derived derived type, forwarded to `operation_cl`
 * @tparam Scal scalar type of the result; its `type_str` is used as the type
 * of the variable declared in the generated statement
 * @tparam T types of the arguments, forwarded to `operation_cl`
 */
54template <typename Derived, typename Scal, typename... T>
55class elt_function_cl : public operation_cl<Derived, Scal, T...> {
56 public:
57 using Scalar = Scal;
58 using base = operation_cl<Derived, Scalar, T...>;
59 using base::var_name_;
60
 /**
  * Constructor.
  * @param fun name of the function; copied into `fun_` and later spliced
  * verbatim into the generated code
  * @param args arguments of the function, perfectly forwarded to the base
  */
66 elt_function_cl(const std::string& fun, T&&... args) // NOLINT
67 : base(std::forward<T>(args)...), fun_(fun) {}
68
 /**
  * Generates kernel code for this expression.
  *
  * NOTE(review): the line that carries the return type and function name
  * (listing line 77) is absent from this extract; only the parameter list
  * and body are shown below.
  *
  * @param row_index_name not referenced in the body shown here
  * @param col_index_name not referenced in the body shown here
  * @param view_handled not referenced in the body shown here
  * @param var_names_arg one `const std::string&` per type in `T`
  * (`std::conditional_t<false, ...>` always selects its second alternative);
  * the strings are spliced in as the call arguments
  * @return `kernel_parts` whose `includes` is the concatenation of
  * `derived().includes` and whose `body` is the single assignment statement
  */
78 const std::string& row_index_name, const std::string& col_index_name,
79 const bool view_handled,
80 std::conditional_t<false, T, const std::string&>... var_names_arg) const {
81 kernel_parts res{};
82
 // Append every include string the derived class declares.
83 for (const char* incl : base::derived().includes) {
84 res.includes += incl;
85 }
 // Suffix each argument with ", " and concatenate them into one list.
86 std::array<std::string, sizeof...(T)> var_names_arg_arr
87 = {(var_names_arg + ", ")...};
88 std::string var_names_list = std::accumulate(
89 var_names_arg_arr.begin(), var_names_arg_arr.end(), std::string());
 // Emits `<Scalar type> <var_name_> = <fun_>((double)<args>);`.
 // The "(double)" text appears once, right after the opening parenthesis,
 // so it prefixes the first argument only. The substr drops the trailing
 // ", " left by the concatenation above.
90 res.body = type_str<Scalar>() + " " + var_name_ + " = " + fun_ + "((double)"
91 + var_names_list.substr(0, var_names_list.size() - 2) + ");\n";
92 return res;
93 }
94
95 protected:
 // Name of the function written into the generated call.
96 std::string fun_;
97};
98
/**
 * Generates a class and a factory function for a binary element-wise
 * function.
 *
 * For `fun` the macro defines:
 * - class template `fun_<T1, T2>` deriving from `elt_function_cl` with
 *   `double` as the scalar type. Its constructor passes the stringified
 *   name `#fun` to the base and calls `check_size_match` on rows and on
 *   columns, each check being skipped when either operand reports
 *   `base::dynamic` for that dimension. `deep_copy()` rebuilds the
 *   operation from deep copies of both arguments, and
 *   `extreme_diagonals()` returns `{-rows() + 1, cols() - 1}`.
 * - free function `fun(a, b)`, constrained by
 *   `require_all_kernel_expressions_t` and `require_any_not_stan_scalar_t`,
 *   which wraps both arguments with `as_operation_cl` and constructs the
 *   class.
 * - the out-of-class definition of the static member `includes`,
 *   brace-initialized from the variadic arguments.
 *
 * NOTE(review): the constructor reads `a` and `b` after they have been
 * passed through `std::forward` to the base constructor; confirm this is
 * safe when `T1`/`T2` are non-reference types.
 *
 * @param fun name of the function
 * @param ... include strings placed in `includes` (may be empty)
 */
105#define ADD_BINARY_FUNCTION_WITH_INCLUDES(fun, ...) \
106 template <typename T1, typename T2> \
107 class fun##_ : public elt_function_cl<fun##_<T1, T2>, double, T1, T2> { \
108 using base = elt_function_cl<fun##_<T1, T2>, double, T1, T2>; \
109 using base::arguments_; \
110 \
111 public: \
112 using base::rows; \
113 using base::cols; \
114 static const std::vector<const char*> includes; \
115 explicit fun##_(T1&& a, T2&& b) \
116 : base(#fun, std::forward<T1>(a), std::forward<T2>(b)) { \
117 if (a.rows() != base::dynamic && b.rows() != base::dynamic) { \
118 check_size_match(#fun, "Rows of ", "a", a.rows(), "rows of ", "b", \
119 b.rows()); \
120 } \
121 if (a.cols() != base::dynamic && b.cols() != base::dynamic) { \
122 check_size_match(#fun, "Columns of ", "a", a.cols(), "columns of ", \
123 "b", b.cols()); \
124 } \
125 } \
126 inline auto deep_copy() const { \
127 auto&& arg1_copy = this->template get_arg<0>().deep_copy(); \
128 auto&& arg2_copy = this->template get_arg<1>().deep_copy(); \
129 return fun##_<std::remove_reference_t<decltype(arg1_copy)>, \
130 std::remove_reference_t<decltype(arg2_copy)>>{ \
131 std::move(arg1_copy), std::move(arg2_copy)}; \
132 } \
133 inline std::pair<int, int> extreme_diagonals() const { \
134 return {-rows() + 1, cols() - 1}; \
135 } \
136 }; \
137 \
138 template <typename T1, typename T2, \
139 require_all_kernel_expressions_t<T1, T2>* = nullptr, \
140 require_any_not_stan_scalar_t<T1, T2>* = nullptr> \
141 inline fun##_<as_operation_cl_t<T1>, as_operation_cl_t<T2>> fun(T1&& a, \
142 T2&& b) { \
143 return fun##_<as_operation_cl_t<T1>, as_operation_cl_t<T2>>( \
144 as_operation_cl(std::forward<T1>(a)), \
145 as_operation_cl(std::forward<T2>(b))); \
146 } \
147 template <typename T1, typename T2> \
148 const std::vector<const char*> fun##_<T1, T2>::includes{__VA_ARGS__};
149
/**
 * Generates a class and a factory function for a general unary
 * element-wise function.
 *
 * For `fun` the macro defines:
 * - class template `fun_<T>` deriving from `elt_function_cl` with `double`
 *   as the scalar type. The constructor passes the stringified name `#fun`
 *   and the forwarded argument to the base. `deep_copy()` rebuilds the
 *   operation from a deep copy of the argument, and `extreme_diagonals()`
 *   returns `{-rows() + 1, cols() - 1}`.
 * - free function `fun(a)`, constrained through the defaulted template
 *   parameter `require_all_kernel_expressions_and_none_scalar_t<T>`, which
 *   wraps the argument with `as_operation_cl` and constructs the class.
 * - the out-of-class definition of the static member `includes`,
 *   brace-initialized from the variadic arguments.
 *
 * @param fun name of the function
 * @param ... include strings placed in `includes` (may be empty)
 */
156#define ADD_UNARY_FUNCTION_WITH_INCLUDES(fun, ...) \
157 template <typename T> \
158 class fun##_ : public elt_function_cl<fun##_<T>, double, T> { \
159 using base = elt_function_cl<fun##_<T>, double, T>; \
160 using base::arguments_; \
161 \
162 public: \
163 using base::rows; \
164 using base::cols; \
165 static const std::vector<const char*> includes; \
166 explicit fun##_(T&& a) : base(#fun, std::forward<T>(a)) {} \
167 inline auto deep_copy() const { \
168 auto&& arg_copy = this->template get_arg<0>().deep_copy(); \
169 return fun##_<std::remove_reference_t<decltype(arg_copy)>>{ \
170 std::move(arg_copy)}; \
171 } \
172 inline std::pair<int, int> extreme_diagonals() const { \
173 return {-rows() + 1, cols() - 1}; \
174 } \
175 }; \
176 \
177 template <typename T, typename Cond \
178 = require_all_kernel_expressions_and_none_scalar_t<T>> \
179 inline fun##_<as_operation_cl_t<T>> fun(T&& a) { \
180 return fun##_<as_operation_cl_t<T>>(as_operation_cl(std::forward<T>(a))); \
181 } \
182 template <typename T> \
183 const std::vector<const char*> fun##_<T>::includes{__VA_ARGS__};
184
/**
 * Generates a class and a factory function for a unary function that needs
 * no include strings: expands to `ADD_UNARY_FUNCTION_WITH_INCLUDES(fun)`
 * with an empty variadic list.
 * @param fun name of the function
 */
190#define ADD_UNARY_FUNCTION(fun) ADD_UNARY_FUNCTION_WITH_INCLUDES(fun)
191
/**
 * Generates a class and a factory function for a unary element-wise
 * function, in the variant used for functions tagged "pass zero".
 *
 * Compared with the class produced by `ADD_UNARY_FUNCTION_WITH_INCLUDES`,
 * the generated `fun_<T>`:
 * - declares `view_transitivness` as `std::make_tuple(true)`;
 * - does not define its own `extreme_diagonals()`;
 * - always has an empty `includes` list (the macro takes no include
 *   arguments).
 *
 * The constructor, `deep_copy()` and the constrained factory function
 * `fun(a)` have the same shape as in the general unary macro.
 *
 * NOTE(review): presumably "pass zero" means f(0) = 0, which would let a
 * triangular view on the input carry over to the output; confirm against
 * the project documentation.
 *
 * @param fun name of the function
 */
198#define ADD_UNARY_FUNCTION_PASS_ZERO(fun) \
199 template <typename T> \
200 class fun##_ : public elt_function_cl<fun##_<T>, double, T> { \
201 using base = elt_function_cl<fun##_<T>, double, T>; \
202 using base::arguments_; \
203 \
204 public: \
205 using base::rows; \
206 using base::cols; \
207 static constexpr auto view_transitivness = std::make_tuple(true); \
208 static const std::vector<const char*> includes; \
209 explicit fun##_(T&& a) : base(#fun, std::forward<T>(a)) {} \
210 inline auto deep_copy() const { \
211 auto&& arg_copy = this->template get_arg<0>().deep_copy(); \
212 return fun##_<std::remove_reference_t<decltype(arg_copy)>>{ \
213 std::move(arg_copy)}; \
214 } \
215 }; \
216 \
217 template <typename T, typename Cond \
218 = require_all_kernel_expressions_and_none_scalar_t<T>> \
219 inline fun##_<as_operation_cl_t<T>> fun(T&& a) { \
220 return fun##_<as_operation_cl_t<T>>(as_operation_cl(std::forward<T>(a))); \
221 } \
222 template <typename T> \
223 const std::vector<const char*> fun##_<T>::includes{};
224
/**
 * Generates a class and a factory function for a unary classification
 * function, i.e. one whose result scalar type is `bool`.
 *
 * The generated `fun_<T>` derives from `elt_function_cl` with `bool` as the
 * scalar type, declares `view_transitivness` as `std::make_tuple(true)`,
 * and has an empty `includes` list. The variadic arguments are pasted
 * verbatim after `return` in `extreme_diagonals()`, so the caller supplies
 * the expression (or braced initializer) that yields the
 * `std::pair<int, int>` result.
 *
 * @param fun name of the function
 * @param ... expression returned from `extreme_diagonals()`
 */
231#define ADD_CLASSIFICATION_FUNCTION(fun, ...) \
232 template <typename T> \
233 class fun##_ : public elt_function_cl<fun##_<T>, bool, T> { \
234 using base = elt_function_cl<fun##_<T>, bool, T>; \
235 using base::arguments_; \
236 \
237 public: \
238 using base::rows; \
239 using base::cols; \
240 static constexpr auto view_transitivness = std::make_tuple(true); \
241 static const std::vector<const char*> includes; \
242 explicit fun##_(T&& a) : base(#fun, std::forward<T>(a)) {} \
243 inline auto deep_copy() const { \
244 auto&& arg_copy = this->template get_arg<0>().deep_copy(); \
245 return fun##_<std::remove_reference_t<decltype(arg_copy)>>{ \
246 std::move(arg_copy)}; \
247 } \
248 inline std::pair<int, int> extreme_diagonals() const { \
249 return __VA_ARGS__; \
250 } \
251 }; \
252 \
253 template <typename T, typename Cond \
254 = require_all_kernel_expressions_and_none_scalar_t<T>> \
255 inline fun##_<as_operation_cl_t<T>> fun(T&& a) { \
256 return fun##_<as_operation_cl_t<T>>(as_operation_cl(std::forward<T>(a))); \
257 } \
258 template <typename T> \
259 const std::vector<const char*> fun##_<T>::includes{};
260
264
268
273
286
291
297
299 opencl_kernels::digamma_device_function)
/*
 * Unary element-wise functions registered together with the include strings
 * that end up in the generated kernel's `includes`.
 *
 * NOTE(review): in this listing several invocations are missing their
 * opening line (only the trailing include arguments are shown), so the
 * function being registered cannot always be read off from the lines
 * below. The entries that are fully visible are `log1m`, `logit`, `Phi`
 * and `inv_Phi`.
 *
 * The `square` entry passes its device function inline as a string literal
 * wrapped in an #ifndef/#define guard instead of naming an
 * `opencl_kernels::` constant.
 */
300ADD_UNARY_FUNCTION_WITH_INCLUDES(log1m, opencl_kernels::log1m_device_function)
302 opencl_kernels::log1p_exp_device_function,
303 opencl_kernels::log_inv_logit_device_function)
305 opencl_kernels::log1m_exp_device_function)
307 opencl_kernels::log1p_exp_device_function)
309 opencl_kernels::inv_square_device_function)
311 opencl_kernels::inv_logit_device_function)
312ADD_UNARY_FUNCTION_WITH_INCLUDES(logit, opencl_kernels::log1m_device_function,
313 opencl_kernels::logit_device_function)
314ADD_UNARY_FUNCTION_WITH_INCLUDES(Phi, opencl_kernels::phi_device_function)
316 opencl_kernels::inv_logit_device_function,
317 opencl_kernels::phi_approx_device_function)
320 opencl_kernels::std_normal_lcdf_device_function)
323 opencl_kernels::std_normal_lcdf_device_function)
324ADD_UNARY_FUNCTION_WITH_INCLUDES(inv_Phi, opencl_kernels::log1m_device_function,
325 opencl_kernels::phi_device_function,
326 opencl_kernels::inv_phi_device_function)
328 log1m_inv_logit, opencl_kernels::log1p_exp_device_function,
329 opencl_kernels::log1m_inv_logit_device_function)
331 opencl_kernels::trigamma_device_function)
333 square,
334 "\n#ifndef STAN_MATH_OPENCL_KERNELS_DEVICE_FUNCTIONS_SQUARE\n"
335 "#define STAN_MATH_OPENCL_KERNELS_DEVICE_FUNCTIONS_SQUARE\n"
336 "double square(double x){return x*x;}\n"
337 "#endif\n")
338
/*
 * Classification functions (bool result). The second macro argument is the
 * value returned by the generated `extreme_diagonals()`:
 * - `isfinite` returns `{-rows() + 1, cols() - 1}`;
 * - `isinf` and `isnan` forward the `extreme_diagonals()` of their argument.
 * The braced pair for `isfinite` contains a comma, so it reaches the macro
 * as two variadic arguments that `__VA_ARGS__` pastes back together.
 */
339ADD_CLASSIFICATION_FUNCTION(isfinite, {-rows() + 1, cols() - 1})
340ADD_CLASSIFICATION_FUNCTION(isinf,
341 this->template get_arg<0>().extreme_diagonals())
342ADD_CLASSIFICATION_FUNCTION(isnan,
343 this->template get_arg<0>().extreme_diagonals())
344
/*
 * Binary element-wise functions.
 * The first eight are registered with no include strings (presumably
 * because OpenCL provides them as built-ins; confirm against the OpenCL C
 * specification).
 */
345ADD_BINARY_FUNCTION_WITH_INCLUDES(fdim)
346ADD_BINARY_FUNCTION_WITH_INCLUDES(fmax)
347ADD_BINARY_FUNCTION_WITH_INCLUDES(fmin)
348ADD_BINARY_FUNCTION_WITH_INCLUDES(fmod)
349ADD_BINARY_FUNCTION_WITH_INCLUDES(hypot)
350ADD_BINARY_FUNCTION_WITH_INCLUDES(ldexp)
351ADD_BINARY_FUNCTION_WITH_INCLUDES(pow)
352ADD_BINARY_FUNCTION_WITH_INCLUDES(copysign)
353
/*
 * Binary functions that need device-function include strings. The strings
 * are stored in `includes` in the order listed here.
 */
354ADD_BINARY_FUNCTION_WITH_INCLUDES(
355 beta, stan::math::opencl_kernels::beta_device_function)
356ADD_BINARY_FUNCTION_WITH_INCLUDES(
357 binomial_coefficient_log,
358 stan::math::opencl_kernels::lgamma_stirling_device_function,
359 stan::math::opencl_kernels::lgamma_stirling_diff_device_function,
360 stan::math::opencl_kernels::lbeta_device_function,
361 stan::math::opencl_kernels::binomial_coefficient_log_device_function)
/**
 * Binary element-wise operation for `lbeta`, written out by hand rather
 * than through `ADD_BINARY_FUNCTION_WITH_INCLUDES`.
 *
 * It has the same members as the macro-generated binary classes, with one
 * visible difference: the function name passed to the base (and therefore
 * written into the generated kernel code) is "stan_lbeta", while the
 * size-check error messages use "lbeta".
 *
 * @tparam T1 type of the first argument
 * @tparam T2 type of the second argument
 */
362template <typename T1, typename T2>
363class lbeta_ : public elt_function_cl<lbeta_<T1, T2>, double, T1, T2> {
364 using base = elt_function_cl<lbeta_<T1, T2>, double, T1, T2>;
365 using base::arguments_;
366
367 public:
368 using base::cols;
369 using base::rows;
 // Include strings for the generated kernel; defined outside the class.
370 static const std::vector<const char*> includes;
 /**
  * Constructor. Forwards both arguments to the base and checks that their
  * row counts and column counts match; each check is skipped when either
  * operand reports `base::dynamic` for that dimension.
  * @param a first argument
  * @param b second argument
  */
371 explicit lbeta_(T1&& a, T2&& b)
372 : base("stan_lbeta", std::forward<T1>(a), std::forward<T2>(b)) {
373 if (a.rows() != base::dynamic && b.rows() != base::dynamic) {
374 check_size_match("lbeta", "Rows of ", "a", a.rows(), "rows of ", "b",
375 b.rows());
376 }
377 if (a.cols() != base::dynamic && b.cols() != base::dynamic) {
378 check_size_match("lbeta", "Columns of ", "a", a.cols(), "columns of ",
379 "b", b.cols());
380 }
381 }
 /**
  * Builds a new `lbeta_` from deep copies of both arguments.
  * @return the copied operation, parameterized on the copies' types
  */
382 inline auto deep_copy() const {
383 auto&& arg1_copy = this->template get_arg<0>().deep_copy();
384 auto&& arg2_copy = this->template get_arg<1>().deep_copy();
385 return lbeta_<std::remove_reference_t<decltype(arg1_copy)>,
386 std::remove_reference_t<decltype(arg2_copy)>>{
387 std::move(arg1_copy), std::move(arg2_copy)};
388 }
 /**
  * @return the pair `{-rows() + 1, cols() - 1}`
  */
389 inline std::pair<int, int> extreme_diagonals() const {
390 return {-rows() + 1, cols() - 1};
391 }
392};
393
/**
 * Creates an `lbeta_` kernel generator operation from two arguments.
 *
 * Participates in overload resolution only when both
 * `require_all_kernel_expressions_t<T1, T2>` and
 * `require_any_not_stan_scalar_t<T1, T2>` are satisfied.
 *
 * @tparam T1 type of the first argument
 * @tparam T2 type of the second argument
 * @param a first argument, wrapped with `as_operation_cl`
 * @param b second argument, wrapped with `as_operation_cl`
 * @return `lbeta_` operation over the wrapped arguments
 */
394template <typename T1, typename T2,
395 require_all_kernel_expressions_t<T1, T2>* = nullptr,
396 require_any_not_stan_scalar_t<T1, T2>* = nullptr>
397inline lbeta_<as_operation_cl_t<T1>, as_operation_cl_t<T2>> lbeta(T1&& a,
398 T2&& b) {
399 return lbeta_<as_operation_cl_t<T1>, as_operation_cl_t<T2>>(
400 as_operation_cl(std::forward<T1>(a)),
401 as_operation_cl(std::forward<T2>(b)));
402}
403
/**
 * Include strings for `lbeta_`: the two `lgamma_stirling` device functions
 * followed by the `lbeta` device function, stored in this order.
 */
404template <typename T1, typename T2>
405const std::vector<const char*> lbeta_<T1, T2>::includes{
406 stan::math::opencl_kernels::lgamma_stirling_device_function,
407 stan::math::opencl_kernels::lgamma_stirling_diff_device_function,
408 stan::math::opencl_kernels::lbeta_device_function};
/*
 * Further binary functions registered with the device-function include
 * strings listed after each name; the strings are stored in `includes` in
 * the order given.
 */
409ADD_BINARY_FUNCTION_WITH_INCLUDES(
410 log_inv_logit_diff, opencl_kernels::log1p_exp_device_function,
411 opencl_kernels::log1m_exp_device_function,
412 opencl_kernels::log_inv_logit_diff_device_function)
413ADD_BINARY_FUNCTION_WITH_INCLUDES(log_diff_exp,
414 opencl_kernels::log1m_exp_device_function,
415 opencl_kernels::log_diff_exp_device_function)
416ADD_BINARY_FUNCTION_WITH_INCLUDES(
417 multiply_log, stan::math::opencl_kernels::multiply_log_device_function)
418ADD_BINARY_FUNCTION_WITH_INCLUDES(
419 lmultiply, stan::math::opencl_kernels::lmultiply_device_function)
420
// The generator macros are only needed within this header; undefine all
// five so they are not visible to code that includes it.
421#undef ADD_BINARY_FUNCTION_WITH_INCLUDES
422#undef ADD_UNARY_FUNCTION_WITH_INCLUDES
423#undef ADD_UNARY_FUNCTION
424#undef ADD_UNARY_FUNCTION_PASS_ZERO
425#undef ADD_CLASSIFICATION_FUNCTION
426
428} // namespace math
429} // namespace stan
430#endif
431#endif
elt_function_cl(const std::string &fun, T &&... args)
Constructor.
kernel_parts generate(const std::string &row_index_name, const std::string &col_index_name, const bool view_handled, std::conditional_t< false, T, const std::string & >... var_names_arg) const
Generates kernel code for this expression.
Represents an element-wise function in kernel generator expressions.
Derived & derived()
Casts the instance into its derived type.
Base for all kernel generator operations.
rsqrt_< as_operation_cl_t< T > > rsqrt(T &&a)
#define ADD_UNARY_FUNCTION(fun)
Generates a class and function for a general unary function that is defined by OpenCL.
#define ADD_UNARY_FUNCTION_PASS_ZERO(fun)
Generates a class and function for a unary function, defined by OpenCL, with the special property that it passes through zero, i.e. f(0) = 0.
#define ADD_UNARY_FUNCTION_WITH_INCLUDES(fun,...)
Generates a class and function for a general unary function that is defined by OpenCL or in the included code.
std_normal_lcdf_dscaled_impl_< as_operation_cl_t< T > > std_normal_lcdf_dscaled_impl(T &&a)
std_normal_lcdf_scaled_impl_< as_operation_cl_t< T > > std_normal_lcdf_scaled_impl(T &&a)
fvar< T > acos(const fvar< T > &x)
Definition acos.hpp:16
fvar< T > sin(const fvar< T > &x)
Definition sin.hpp:16
fvar< T > acosh(const fvar< T > &x)
Definition acosh.hpp:16
fvar< T > logit(const fvar< T > &x)
Definition logit.hpp:14
fvar< T > expm1(const fvar< T > &x)
Definition expm1.hpp:14
fvar< T > atanh(const fvar< T > &x)
Return inverse hyperbolic tangent of specified value.
Definition atanh.hpp:26
fvar< T > inv_square(const fvar< T > &x)
fvar< T > exp2(const fvar< T > &x)
Definition exp2.hpp:14
constexpr double log2()
Return natural logarithm of two.
Definition log2.hpp:17
fvar< T > log1m_exp(const fvar< T > &x)
Return the natural logarithm of one minus the exponentiation of the specified argument.
Definition log1m_exp.hpp:22
fvar< T > asinh(const fvar< T > &x)
Definition asinh.hpp:16
fvar< T > cosh(const fvar< T > &x)
Definition cosh.hpp:16
fvar< T > log(const fvar< T > &x)
Definition log.hpp:18
fvar< T > erf(const fvar< T > &x)
Definition erf.hpp:16
auto inv_logit(T &&x)
Returns the inverse logit function applied to the argument.
Definition inv_logit.hpp:20
fvar< T > Phi_approx(const fvar< T > &x)
Return an approximation of the unit normal cumulative distribution function (CDF).
fvar< T > log_inv_logit(const fvar< T > &x)
fvar< T > cbrt(const fvar< T > &x)
Return cube root of specified argument.
Definition cbrt.hpp:20
fvar< T > sinh(const fvar< T > &x)
Definition sinh.hpp:15
fvar< T > log1p_exp(const fvar< T > &x)
Definition log1p_exp.hpp:14
fvar< T > sqrt(const fvar< T > &x)
Definition sqrt.hpp:18
fvar< T > atan(const fvar< T > &x)
Definition atan.hpp:16
fvar< T > trigamma(const fvar< T > &u)
Return the value of the trigamma function at the specified argument (i.e., the second derivative of the log gamma function at the specified argument).
Definition trigamma.hpp:25
fvar< T > tan(const fvar< T > &x)
Definition tan.hpp:16
fvar< T > Phi(const fvar< T > &x)
Definition Phi.hpp:16
fvar< T > erfc(const fvar< T > &x)
Definition erfc.hpp:16
fvar< T > log1p(const fvar< T > &x)
Definition log1p.hpp:12
fvar< T > inv_Phi(const fvar< T > &p)
Definition inv_Phi.hpp:16
fvar< T > floor(const fvar< T > &x)
Definition floor.hpp:13
fvar< T > lgamma(const fvar< T > &x)
Return the natural logarithm of the gamma function applied to the specified argument.
Definition lgamma.hpp:21
static constexpr double log10()
Returns the natural logarithm of ten.
fvar< T > tanh(const fvar< T > &x)
Definition tanh.hpp:15
fvar< T > cos(const fvar< T > &x)
Definition cos.hpp:16
fvar< T > round(const fvar< T > &x)
Return the closest integer to the specified argument, with halfway cases rounded away from zero.
Definition round.hpp:24
fvar< T > tgamma(const fvar< T > &x)
Return the result of applying the gamma function to the specified argument.
Definition tgamma.hpp:21
fvar< T > asin(const fvar< T > &x)
Definition asin.hpp:16
fvar< T > ceil(const fvar< T > &x)
Definition ceil.hpp:13
fvar< T > log1m(const fvar< T > &x)
Definition log1m.hpp:12
fvar< T > log1m_inv_logit(const fvar< T > &x)
Return the natural logarithm of one minus the inverse logit of the specified argument.
fvar< T > digamma(const fvar< T > &x)
Return the derivative of the log gamma function at the specified argument.
Definition digamma.hpp:23
fvar< T > square(const fvar< T > &x)
Definition square.hpp:12
fvar< T > trunc(const fvar< T > &x)
Return the nearest integral value that is not larger in magnitude than the specified argument.
Definition trunc.hpp:20
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:15
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
STL namespace.
Parts of an OpenCL kernel, generated by an expression.