1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_ELT_FUNCTION_CL_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_ELT_FUNCTION_CL_HPP
53template <
typename Derived,
typename Scal,
typename... T>
77 const std::string& row_index_name,
const std::string& col_index_name,
78 const bool view_handled,
79 std::conditional_t<false, T, const std::string&>... var_names_arg)
const {
// Build the argument list of the generated call. Every argument variable
// name gets ", " appended so the pieces can simply be concatenated.
85 std::array<std::string,
sizeof...(T)> var_names_arg_arr
86 = {(var_names_arg +
", ")...};
// Concatenate all suffixed names, starting from an empty string.
87 std::string var_names_list = std::accumulate(
88 var_names_arg_arr.begin(), var_names_arg_arr.end(), std::string());
// Emit one statement of the form
//   <type_str<Scalar>()> <var_name_> = <fun_>((double)<names>);
// substr(0, size() - 2) drops the trailing ", " left over from the join.
// NOTE(review): the "(double)" text is emitted exactly once, directly in
// front of the first name; any further names in the list are passed without
// a cast prefix. Confirm this is intended for multi-argument functions.
89 res.body = type_str<Scalar>() +
" " +
var_name_ +
" = " +
fun_ +
"((double)"
90 + var_names_list.substr(0, var_names_list.size() - 2) +
");\n";
/**
 * Generates a class template `fun##_<T1, T2>` and a free function template
 * `fun` for a binary element-wise function.
 *
 * Generated class:
 *  - derives from `elt_function_cl<fun##_<T1, T2>, double, T1, T2>` and hands
 *    the stringized name `#fun` plus both forwarded arguments to the base
 *    constructor;
 *  - in its constructor, starts a `check_size_match` call on the rows of `a`
 *    and `b` when neither row count equals `base::dynamic`, and likewise on
 *    the columns;
 *  - `deep_copy()` deep-copies argument 0 and argument 1 and moves the copies
 *    into a new `fun##_` instantiated on the copies' types;
 *  - `extreme_diagonals()` returns `{-rows() + 1, cols() - 1}`;
 *  - declares a static `includes` vector, defined after the class and
 *    initialized from the variadic macro arguments.
 *
 * Generated function: constrained by `require_all_kernel_expressions_t` and
 * `require_any_not_stan_scalar_t`; it wraps both arguments with
 * `as_operation_cl` and constructs the class from the results.
 *
 * NOTE(review): the constructor body reads `a.rows()`, `a.cols()`,
 * `b.rows()` and `b.cols()` after `a` and `b` were passed through
 * `std::forward` to the base constructor. Confirm the arguments are still
 * safe to query at that point when `T1`/`T2` are non-reference types.
 *
 * @param fun function name; used as the identifier of the generated function,
 *   as the prefix of the class name, and (stringized) as the name given to
 *   the base class
 * @param ... initializers (`const char*`) for the static `includes` vector;
 *   may be empty
 */
104#define ADD_BINARY_FUNCTION_WITH_INCLUDES(fun, ...) \
105 template <typename T1, typename T2> \
106 class fun##_ : public elt_function_cl<fun##_<T1, T2>, double, T1, T2> { \
107 using base = elt_function_cl<fun##_<T1, T2>, double, T1, T2>; \
108 using base::arguments_; \
113 static const std::vector<const char*> includes; \
114 explicit fun##_(T1&& a, T2&& b) \
115 : base(#fun, std::forward<T1>(a), std::forward<T2>(b)) { \
116 if (a.rows() != base::dynamic && b.rows() != base::dynamic) { \
117 check_size_match(#fun, "Rows of ", "a", a.rows(), "rows of ", "b", \
120 if (a.cols() != base::dynamic && b.cols() != base::dynamic) { \
121 check_size_match(#fun, "Columns of ", "a", a.cols(), "columns of ", \
125 inline auto deep_copy() const { \
126 auto&& arg1_copy = this->template get_arg<0>().deep_copy(); \
127 auto&& arg2_copy = this->template get_arg<1>().deep_copy(); \
128 return fun##_<std::remove_reference_t<decltype(arg1_copy)>, \
129 std::remove_reference_t<decltype(arg2_copy)>>{ \
130 std::move(arg1_copy), std::move(arg2_copy)}; \
132 inline std::pair<int, int> extreme_diagonals() const { \
133 return {-rows() + 1, cols() - 1}; \
137 template <typename T1, typename T2, \
138 require_all_kernel_expressions_t<T1, T2>* = nullptr, \
139 require_any_not_stan_scalar_t<T1, T2>* = nullptr> \
140 inline fun##_<as_operation_cl_t<T1>, as_operation_cl_t<T2>> fun(T1&& a, \
142 return fun##_<as_operation_cl_t<T1>, as_operation_cl_t<T2>>( \
143 as_operation_cl(std::forward<T1>(a)), \
144 as_operation_cl(std::forward<T2>(b))); \
146 template <typename T1, typename T2> \
147 const std::vector<const char*> fun##_<T1, T2>::includes{__VA_ARGS__};
/**
 * Generates a class template `fun##_<T>` and a free function template `fun`
 * for a unary element-wise function.
 *
 * Generated class:
 *  - derives from `elt_function_cl<fun##_<T>, double, T>` and passes the
 *    stringized name `#fun` and the forwarded argument to the base
 *    constructor;
 *  - `deep_copy()` deep-copies argument 0 and moves the copy into a new
 *    `fun##_` instantiated on the copy's type;
 *  - `extreme_diagonals()` returns `{-rows() + 1, cols() - 1}`;
 *  - declares a static `includes` vector, defined after the class and
 *    initialized from the variadic macro arguments.
 *
 * Generated function: constrained through the defaulted template parameter
 * `Cond = require_all_kernel_expressions_and_none_scalar_t<T>`; it wraps the
 * argument with `as_operation_cl` and constructs the class from the result.
 *
 * @param fun function name; used as the identifier of the generated function,
 *   as the prefix of the class name, and (stringized) as the name given to
 *   the base class
 * @param ... initializers (`const char*`) for the static `includes` vector;
 *   may be empty
 */
155#define ADD_UNARY_FUNCTION_WITH_INCLUDES(fun, ...) \
156 template <typename T> \
157 class fun##_ : public elt_function_cl<fun##_<T>, double, T> { \
158 using base = elt_function_cl<fun##_<T>, double, T>; \
159 using base::arguments_; \
164 static const std::vector<const char*> includes; \
165 explicit fun##_(T&& a) : base(#fun, std::forward<T>(a)) {} \
166 inline auto deep_copy() const { \
167 auto&& arg_copy = this->template get_arg<0>().deep_copy(); \
168 return fun##_<std::remove_reference_t<decltype(arg_copy)>>{ \
169 std::move(arg_copy)}; \
171 inline std::pair<int, int> extreme_diagonals() const { \
172 return {-rows() + 1, cols() - 1}; \
176 template <typename T, typename Cond \
177 = require_all_kernel_expressions_and_none_scalar_t<T>> \
178 inline fun##_<as_operation_cl_t<T>> fun(T&& a) { \
179 return fun##_<as_operation_cl_t<T>>(as_operation_cl(std::forward<T>(a))); \
181 template <typename T> \
182 const std::vector<const char*> fun##_<T>::includes{__VA_ARGS__};
/**
 * Shorthand for `ADD_UNARY_FUNCTION_WITH_INCLUDES(fun)` with no variadic
 * arguments, so the generated class's `includes` vector is initialized from
 * an empty list.
 *
 * @param fun function name
 */
189#define ADD_UNARY_FUNCTION(fun) ADD_UNARY_FUNCTION_WITH_INCLUDES(fun)
/**
 * Generates a class template `fun##_<T>` and a free function template `fun`
 * for a unary element-wise function, like the general unary generator but
 * with two visible differences:
 *  - the class declares
 *    `static constexpr auto view_transitivness = std::make_tuple(true);`
 *  - the static `includes` vector is always initialized empty (`{}`); the
 *    macro takes no variadic arguments.
 *
 * The class derives from `elt_function_cl<fun##_<T>, double, T>`, passes the
 * stringized name `#fun` to the base constructor, and `deep_copy()`
 * deep-copies argument 0 into a new `fun##_`. The free function is
 * constrained through `Cond =
 * require_all_kernel_expressions_and_none_scalar_t<T>` and wraps its argument
 * with `as_operation_cl`.
 *
 * NOTE(review): judging by the macro name this is meant for functions that
 * map zero to zero, which would explain the `view_transitivness` flag.
 * Confirm against the kernel generator's view handling.
 *
 * @param fun function name; used as the identifier of the generated function,
 *   as the prefix of the class name, and (stringized) as the name given to
 *   the base class
 */
197#define ADD_UNARY_FUNCTION_PASS_ZERO(fun) \
198 template <typename T> \
199 class fun##_ : public elt_function_cl<fun##_<T>, double, T> { \
200 using base = elt_function_cl<fun##_<T>, double, T>; \
201 using base::arguments_; \
206 static constexpr auto view_transitivness = std::make_tuple(true); \
207 static const std::vector<const char*> includes; \
208 explicit fun##_(T&& a) : base(#fun, std::forward<T>(a)) {} \
209 inline auto deep_copy() const { \
210 auto&& arg_copy = this->template get_arg<0>().deep_copy(); \
211 return fun##_<std::remove_reference_t<decltype(arg_copy)>>{ \
212 std::move(arg_copy)}; \
216 template <typename T, typename Cond \
217 = require_all_kernel_expressions_and_none_scalar_t<T>> \
218 inline fun##_<as_operation_cl_t<T>> fun(T&& a) { \
219 return fun##_<as_operation_cl_t<T>>(as_operation_cl(std::forward<T>(a))); \
221 template <typename T> \
222 const std::vector<const char*> fun##_<T>::includes{};
/**
 * Generates a class template `fun##_<T>` and a free function template `fun`
 * for a unary element-wise function whose scalar type is `bool`.
 *
 * Generated class:
 *  - derives from `elt_function_cl<fun##_<T>, bool, T>` and passes the
 *    stringized name `#fun` and the forwarded argument to the base
 *    constructor;
 *  - declares
 *    `static constexpr auto view_transitivness = std::make_tuple(true);`
 *  - `deep_copy()` deep-copies argument 0 into a new `fun##_`;
 *  - `extreme_diagonals()` returns the expression supplied through the
 *    variadic macro arguments (`return __VA_ARGS__;`). Because the arguments
 *    are pasted back verbatim, a braced pair containing a comma can be given
 *    directly;
 *  - the static `includes` vector is always initialized empty (`{}`).
 *
 * Generated function: constrained through `Cond =
 * require_all_kernel_expressions_and_none_scalar_t<T>`; it wraps the argument
 * with `as_operation_cl` and constructs the class from the result.
 *
 * @param fun function name; used as the identifier of the generated function,
 *   as the prefix of the class name, and (stringized) as the name given to
 *   the base class
 * @param ... expression returned by the generated `extreme_diagonals()`;
 *   must be convertible to `std::pair<int, int>`
 */
230#define ADD_CLASSIFICATION_FUNCTION(fun, ...) \
231 template <typename T> \
232 class fun##_ : public elt_function_cl<fun##_<T>, bool, T> { \
233 using base = elt_function_cl<fun##_<T>, bool, T>; \
234 using base::arguments_; \
239 static constexpr auto view_transitivness = std::make_tuple(true); \
240 static const std::vector<const char*> includes; \
241 explicit fun##_(T&& a) : base(#fun, std::forward<T>(a)) {} \
242 inline auto deep_copy() const { \
243 auto&& arg_copy = this->template get_arg<0>().deep_copy(); \
244 return fun##_<std::remove_reference_t<decltype(arg_copy)>>{ \
245 std::move(arg_copy)}; \
247 inline std::pair<int, int> extreme_diagonals() const { \
248 return __VA_ARGS__; \
252 template <typename T, typename Cond \
253 = require_all_kernel_expressions_and_none_scalar_t<T>> \
254 inline fun##_<as_operation_cl_t<T>> fun(T&& a) { \
255 return fun##_<as_operation_cl_t<T>>(as_operation_cl(std::forward<T>(a))); \
257 template <typename T> \
258 const std::vector<const char*> fun##_<T>::includes{};
298 opencl_kernels::digamma_device_function)
301 opencl_kernels::log1p_exp_device_function,
302 opencl_kernels::log_inv_logit_device_function)
304 opencl_kernels::log1m_exp_device_function)
306 opencl_kernels::log1p_exp_device_function)
308 opencl_kernels::inv_square_device_function)
310 opencl_kernels::inv_logit_device_function)
312 opencl_kernels::logit_device_function)
315 opencl_kernels::inv_logit_device_function,
316 opencl_kernels::phi_approx_device_function)
318 opencl_kernels::phi_device_function,
319 opencl_kernels::inv_phi_device_function)
322 opencl_kernels::log1m_inv_logit_device_function)
324 opencl_kernels::trigamma_device_function)
327 "\n
#ifndef STAN_MATH_OPENCL_KERNELS_DEVICE_FUNCTIONS_SQUARE\n"
328 "#define STAN_MATH_OPENCL_KERNELS_DEVICE_FUNCTIONS_SQUARE\n"
329 "double square(double x){return x*x;}\n"
// Classification functions (bool result). Everything after the function name
// becomes the return expression of the generated extreme_diagonals():
// isfinite reports {-rows() + 1, cols() - 1}, while isinf and isnan forward
// the extreme diagonals of their argument. The braced pair for isfinite
// contains a comma, so the preprocessor splits it into two variadic
// arguments that __VA_ARGS__ pastes back together unchanged.
332ADD_CLASSIFICATION_FUNCTION(isfinite, {-rows() + 1, cols() - 1})
333ADD_CLASSIFICATION_FUNCTION(isinf,
334 this->template get_arg<0>().extreme_diagonals())
335ADD_CLASSIFICATION_FUNCTION(isnan,
336 this->template get_arg<0>().extreme_diagonals())
// Binary functions generated with no variadic arguments, i.e. with an empty
// `includes` list.
338ADD_BINARY_FUNCTION_WITH_INCLUDES(fdim)
339ADD_BINARY_FUNCTION_WITH_INCLUDES(fmax)
340ADD_BINARY_FUNCTION_WITH_INCLUDES(fmin)
341ADD_BINARY_FUNCTION_WITH_INCLUDES(fmod)
342ADD_BINARY_FUNCTION_WITH_INCLUDES(hypot)
343ADD_BINARY_FUNCTION_WITH_INCLUDES(ldexp)
344ADD_BINARY_FUNCTION_WITH_INCLUDES(pow)
345ADD_BINARY_FUNCTION_WITH_INCLUDES(copysign)
// Binary functions whose `includes` list is populated with the listed
// `*_device_function` entries, in the order given.
// NOTE(review): these are presumably OpenCL source snippets that define the
// function and its helpers; confirm in the opencl_kernels headers.
347ADD_BINARY_FUNCTION_WITH_INCLUDES(
348 beta, stan::math::opencl_kernels::beta_device_function)
349ADD_BINARY_FUNCTION_WITH_INCLUDES(
350 binomial_coefficient_log,
351 stan::math::opencl_kernels::lgamma_stirling_device_function,
352 stan::math::opencl_kernels::lgamma_stirling_diff_device_function,
353 stan::math::opencl_kernels::lbeta_device_function,
354 stan::math::opencl_kernels::binomial_coefficient_log_device_function)
355ADD_BINARY_FUNCTION_WITH_INCLUDES(
356 lbeta, stan::math::opencl_kernels::lgamma_stirling_device_function,
357 stan::math::opencl_kernels::lgamma_stirling_diff_device_function,
358 stan::math::opencl_kernels::lbeta_device_function)
359ADD_BINARY_FUNCTION_WITH_INCLUDES(
360 log_inv_logit_diff, opencl_kernels::log1p_exp_device_function,
361 opencl_kernels::log1m_exp_device_function,
362 opencl_kernels::log_inv_logit_diff_device_function)
363ADD_BINARY_FUNCTION_WITH_INCLUDES(log_diff_exp,
364 opencl_kernels::log1m_exp_device_function,
365 opencl_kernels::log_diff_exp_device_function)
366ADD_BINARY_FUNCTION_WITH_INCLUDES(
367 multiply_log, stan::math::opencl_kernels::multiply_log_device_function)
368ADD_BINARY_FUNCTION_WITH_INCLUDES(
369 lmultiply, stan::math::opencl_kernels::lmultiply_device_function)
// Undefine the five generator macros defined in this header so that they are
// not visible to code that follows.
371#undef ADD_BINARY_FUNCTION_WITH_INCLUDES
372#undef ADD_UNARY_FUNCTION_WITH_INCLUDES
373#undef ADD_UNARY_FUNCTION
374#undef ADD_UNARY_FUNCTION_PASS_ZERO
375#undef ADD_CLASSIFICATION_FUNCTION
elt_function_cl(const std::string &fun, T &&... args)
Constructor.
kernel_parts generate(const std::string &row_index_name, const std::string &col_index_name, const bool view_handled, std::conditional_t< false, T, const std::string & >... var_names_arg) const
Generates kernel code for this expression.
Represents an element-wise function in kernel generator expressions.
Derived & derived()
Casts the instance into its derived type.
Base for all kernel generator operations.
rsqrt_< as_operation_cl_t< T > > rsqrt(T &&a)
#define ADD_UNARY_FUNCTION(fun)
Generates a class and function for a general unary function that is defined by OpenCL.
#define ADD_UNARY_FUNCTION_PASS_ZERO(fun)
Generates a class and function for a unary function, defined by OpenCL, with the special property that it passes through zero, i.e. f(0) = 0.
#define ADD_UNARY_FUNCTION_WITH_INCLUDES(fun,...)
Generates a class and function for a general unary function that is defined by OpenCL or in the included code.
fvar< T > acos(const fvar< T > &x)
fvar< T > sin(const fvar< T > &x)
fvar< T > acosh(const fvar< T > &x)
fvar< T > logit(const fvar< T > &x)
fvar< T > expm1(const fvar< T > &x)
fvar< T > atanh(const fvar< T > &x)
Return inverse hyperbolic tangent of specified value.
fvar< T > inv_square(const fvar< T > &x)
fvar< T > exp2(const fvar< T > &x)
constexpr double log2()
Return natural logarithm of two.
fvar< T > log1m_exp(const fvar< T > &x)
Return the natural logarithm of one minus the exponentiation of the specified argument.
fvar< T > asinh(const fvar< T > &x)
fvar< T > cosh(const fvar< T > &x)
fvar< T > log(const fvar< T > &x)
fvar< T > erf(const fvar< T > &x)
fvar< T > Phi_approx(const fvar< T > &x)
Return an approximation of the unit normal cumulative distribution function (CDF).
fvar< T > log_inv_logit(const fvar< T > &x)
fvar< T > cbrt(const fvar< T > &x)
Return cube root of specified argument.
fvar< T > sinh(const fvar< T > &x)
fvar< T > log1p_exp(const fvar< T > &x)
fvar< T > sqrt(const fvar< T > &x)
fvar< T > atan(const fvar< T > &x)
fvar< T > trigamma(const fvar< T > &u)
Return the value of the trigamma function at the specified argument (i.e., the second derivative of the log Gamma function).
fvar< T > tan(const fvar< T > &x)
fvar< T > Phi(const fvar< T > &x)
fvar< T > erfc(const fvar< T > &x)
fvar< T > log1p(const fvar< T > &x)
fvar< T > inv_Phi(const fvar< T > &p)
fvar< T > floor(const fvar< T > &x)
fvar< T > lgamma(const fvar< T > &x)
Return the natural logarithm of the gamma function applied to the specified argument.
static constexpr double log10()
Returns the natural logarithm of ten.
fvar< T > tanh(const fvar< T > &x)
fvar< T > cos(const fvar< T > &x)
fvar< T > round(const fvar< T > &x)
Return the closest integer to the specified argument, with halfway cases rounded away from zero.
fvar< T > tgamma(const fvar< T > &x)
Return the result of applying the gamma function to the specified argument.
fvar< T > asin(const fvar< T > &x)
fvar< T > inv_logit(const fvar< T > &x)
Returns the inverse logit function applied to the argument.
fvar< T > ceil(const fvar< T > &x)
fvar< T > log1m(const fvar< T > &x)
fvar< T > log1m_inv_logit(const fvar< T > &x)
Return the natural logarithm of one minus the inverse logit of the specified argument.
fvar< T > digamma(const fvar< T > &x)
Return the derivative of the log gamma function at the specified argument.
fvar< T > square(const fvar< T > &x)
fvar< T > trunc(const fvar< T > &x)
Return the nearest integral value that is not larger in magnitude than the specified argument.
fvar< T > exp(const fvar< T > &x)
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Parts of an OpenCL kernel, generated by an expression.