1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_ELT_FUNCTION_CL_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_ELT_FUNCTION_CL_HPP
54template <
typename Derived,
typename Scal,
typename... T>
78 const std::string& row_index_name,
const std::string& col_index_name,
79 const bool view_handled,
80 std::conditional_t<false, T, const std::string&>... var_names_arg)
const {
86 std::array<std::string,
sizeof...(T)> var_names_arg_arr
87 = {(var_names_arg +
", ")...};
88 std::string var_names_list = std::accumulate(
89 var_names_arg_arr.begin(), var_names_arg_arr.end(), std::string());
90 res.body = type_str<Scalar>() +
" " +
var_name_ +
" = " +
fun_ +
"((double)"
91 + var_names_list.substr(0, var_names_list.size() - 2) +
");\n";
105#define ADD_BINARY_FUNCTION_WITH_INCLUDES(fun, ...) \
106 template <typename T1, typename T2> \
107 class fun##_ : public elt_function_cl<fun##_<T1, T2>, double, T1, T2> { \
108 using base = elt_function_cl<fun##_<T1, T2>, double, T1, T2>; \
109 using base::arguments_; \
114 static const std::vector<const char*> includes; \
115 explicit fun##_(T1&& a, T2&& b) \
116 : base(#fun, std::forward<T1>(a), std::forward<T2>(b)) { \
117 if (a.rows() != base::dynamic && b.rows() != base::dynamic) { \
118 check_size_match(#fun, "Rows of ", "a", a.rows(), "rows of ", "b", \
121 if (a.cols() != base::dynamic && b.cols() != base::dynamic) { \
122 check_size_match(#fun, "Columns of ", "a", a.cols(), "columns of ", \
126 inline auto deep_copy() const { \
127 auto&& arg1_copy = this->template get_arg<0>().deep_copy(); \
128 auto&& arg2_copy = this->template get_arg<1>().deep_copy(); \
129 return fun##_<std::remove_reference_t<decltype(arg1_copy)>, \
130 std::remove_reference_t<decltype(arg2_copy)>>{ \
131 std::move(arg1_copy), std::move(arg2_copy)}; \
133 inline std::pair<int, int> extreme_diagonals() const { \
134 return {-rows() + 1, cols() - 1}; \
138 template <typename T1, typename T2, \
139 require_all_kernel_expressions_t<T1, T2>* = nullptr, \
140 require_any_not_stan_scalar_t<T1, T2>* = nullptr> \
141 inline fun##_<as_operation_cl_t<T1>, as_operation_cl_t<T2>> fun(T1&& a, \
143 return fun##_<as_operation_cl_t<T1>, as_operation_cl_t<T2>>( \
144 as_operation_cl(std::forward<T1>(a)), \
145 as_operation_cl(std::forward<T2>(b))); \
147 template <typename T1, typename T2> \
148 const std::vector<const char*> fun##_<T1, T2>::includes{__VA_ARGS__};
156#define ADD_UNARY_FUNCTION_WITH_INCLUDES(fun, ...) \
157 template <typename T> \
158 class fun##_ : public elt_function_cl<fun##_<T>, double, T> { \
159 using base = elt_function_cl<fun##_<T>, double, T>; \
160 using base::arguments_; \
165 static const std::vector<const char*> includes; \
166 explicit fun##_(T&& a) : base(#fun, std::forward<T>(a)) {} \
167 inline auto deep_copy() const { \
168 auto&& arg_copy = this->template get_arg<0>().deep_copy(); \
169 return fun##_<std::remove_reference_t<decltype(arg_copy)>>{ \
170 std::move(arg_copy)}; \
172 inline std::pair<int, int> extreme_diagonals() const { \
173 return {-rows() + 1, cols() - 1}; \
177 template <typename T, typename Cond \
178 = require_all_kernel_expressions_and_none_scalar_t<T>> \
179 inline fun##_<as_operation_cl_t<T>> fun(T&& a) { \
180 return fun##_<as_operation_cl_t<T>>(as_operation_cl(std::forward<T>(a))); \
182 template <typename T> \
183 const std::vector<const char*> fun##_<T>::includes{__VA_ARGS__};
190#define ADD_UNARY_FUNCTION(fun) ADD_UNARY_FUNCTION_WITH_INCLUDES(fun)
198#define ADD_UNARY_FUNCTION_PASS_ZERO(fun) \
199 template <typename T> \
200 class fun##_ : public elt_function_cl<fun##_<T>, double, T> { \
201 using base = elt_function_cl<fun##_<T>, double, T>; \
202 using base::arguments_; \
207 static constexpr auto view_transitivness = std::make_tuple(true); \
208 static const std::vector<const char*> includes; \
209 explicit fun##_(T&& a) : base(#fun, std::forward<T>(a)) {} \
210 inline auto deep_copy() const { \
211 auto&& arg_copy = this->template get_arg<0>().deep_copy(); \
212 return fun##_<std::remove_reference_t<decltype(arg_copy)>>{ \
213 std::move(arg_copy)}; \
217 template <typename T, typename Cond \
218 = require_all_kernel_expressions_and_none_scalar_t<T>> \
219 inline fun##_<as_operation_cl_t<T>> fun(T&& a) { \
220 return fun##_<as_operation_cl_t<T>>(as_operation_cl(std::forward<T>(a))); \
222 template <typename T> \
223 const std::vector<const char*> fun##_<T>::includes{};
231#define ADD_CLASSIFICATION_FUNCTION(fun, ...) \
232 template <typename T> \
233 class fun##_ : public elt_function_cl<fun##_<T>, bool, T> { \
234 using base = elt_function_cl<fun##_<T>, bool, T>; \
235 using base::arguments_; \
240 static constexpr auto view_transitivness = std::make_tuple(true); \
241 static const std::vector<const char*> includes; \
242 explicit fun##_(T&& a) : base(#fun, std::forward<T>(a)) {} \
243 inline auto deep_copy() const { \
244 auto&& arg_copy = this->template get_arg<0>().deep_copy(); \
245 return fun##_<std::remove_reference_t<decltype(arg_copy)>>{ \
246 std::move(arg_copy)}; \
248 inline std::pair<int, int> extreme_diagonals() const { \
249 return __VA_ARGS__; \
253 template <typename T, typename Cond \
254 = require_all_kernel_expressions_and_none_scalar_t<T>> \
255 inline fun##_<as_operation_cl_t<T>> fun(T&& a) { \
256 return fun##_<as_operation_cl_t<T>>(as_operation_cl(std::forward<T>(a))); \
258 template <typename T> \
259 const std::vector<const char*> fun##_<T>::includes{};
299 opencl_kernels::digamma_device_function)
302 opencl_kernels::log1p_exp_device_function,
303 opencl_kernels::log_inv_logit_device_function)
305 opencl_kernels::log1m_exp_device_function)
307 opencl_kernels::log1p_exp_device_function)
309 opencl_kernels::inv_square_device_function)
311 opencl_kernels::inv_logit_device_function)
313 opencl_kernels::logit_device_function)
316 opencl_kernels::inv_logit_device_function,
317 opencl_kernels::phi_approx_device_function)
320 opencl_kernels::std_normal_lcdf_device_function)
323 opencl_kernels::std_normal_lcdf_device_function)
325 opencl_kernels::phi_device_function,
326 opencl_kernels::inv_phi_device_function)
329 opencl_kernels::log1m_inv_logit_device_function)
331 opencl_kernels::trigamma_device_function)
334 "\n
#ifndef STAN_MATH_OPENCL_KERNELS_DEVICE_FUNCTIONS_SQUARE\n"
335 "#define STAN_MATH_OPENCL_KERNELS_DEVICE_FUNCTIONS_SQUARE\n"
336 "double square(double x){return x*x;}\n"
// Element-wise classification functions. Each invocation expands to a class
// template `fun_<T>` (an elt_function_cl whose scalar type is bool) and a
// free function `fun(T&&)` that wraps its argument with as_operation_cl.
// The tokens after the function name are pasted verbatim as the expression
// returned by the generated extreme_diagonals(); the macro is variadic so
// that a braced initializer containing a comma can be passed.
//
// isfinite reports the full diagonal range {-rows() + 1, cols() - 1}.
// NOTE(review): presumably because isfinite(0) is true, so the result can be
// nonzero outside the band of the argument -- confirm.
339ADD_CLASSIFICATION_FUNCTION(isfinite, {-rows() + 1, cols() - 1})
// isinf forwards the extreme diagonals of its argument unchanged.
340ADD_CLASSIFICATION_FUNCTION(isinf,
341 this->template get_arg<0>().extreme_diagonals())
// isnan forwards the extreme diagonals of its argument unchanged.
342ADD_CLASSIFICATION_FUNCTION(isnan,
343 this->template get_arg<0>().extreme_diagonals())
// Binary element-wise functions registered without any include sources: the
// variadic part of ADD_BINARY_FUNCTION_WITH_INCLUDES is empty here, so the
// static `includes` vector of each generated class is empty. The function
// name is stringized (#fun) and passed to the elt_function_cl base.
// NOTE(review): presumably these are all OpenCL built-in functions, which is
// why no device-function source is needed -- confirm.
345ADD_BINARY_FUNCTION_WITH_INCLUDES(fdim)
346ADD_BINARY_FUNCTION_WITH_INCLUDES(fmax)
347ADD_BINARY_FUNCTION_WITH_INCLUDES(fmin)
348ADD_BINARY_FUNCTION_WITH_INCLUDES(fmod)
349ADD_BINARY_FUNCTION_WITH_INCLUDES(hypot)
350ADD_BINARY_FUNCTION_WITH_INCLUDES(ldexp)
351ADD_BINARY_FUNCTION_WITH_INCLUDES(pow)
352ADD_BINARY_FUNCTION_WITH_INCLUDES(copysign)
// Binary functions registered together with device-function entries. Every
// argument after the function name initializes the generated class's static
// `includes` vector (std::vector<const char*>), in the order written here.
354ADD_BINARY_FUNCTION_WITH_INCLUDES(
355 beta, stan::math::opencl_kernels::beta_device_function)
// binomial_coefficient_log lists four entries.
// NOTE(review): the order presumably reflects dependencies between the
// device functions (helpers first) -- confirm before reordering.
356ADD_BINARY_FUNCTION_WITH_INCLUDES(
357 binomial_coefficient_log,
358 stan::math::opencl_kernels::lgamma_stirling_device_function,
359 stan::math::opencl_kernels::lgamma_stirling_diff_device_function,
360 stan::math::opencl_kernels::lbeta_device_function,
361 stan::math::opencl_kernels::binomial_coefficient_log_device_function)
362template <typename T1, typename T2>
363class lbeta_ : public elt_function_cl<lbeta_<T1, T2>, double, T1, T2> {
364 using base = elt_function_cl<lbeta_<T1, T2>, double, T1, T2>;
365 using base::arguments_;
370 static const std::vector<const char*> includes;
371 explicit lbeta_(T1&& a, T2&& b)
372 : base("stan_lbeta", std::forward<T1>(a), std::forward<T2>(b)) {
373 if (a.rows() != base::dynamic && b.rows() != base::dynamic) {
374 check_size_match("lbeta", "Rows of ", "a", a.rows(), "rows of ", "b",
377 if (a.cols() != base::dynamic && b.cols() != base::dynamic) {
378 check_size_match("lbeta", "Columns of ", "a", a.cols(), "columns of ",
382 inline auto deep_copy() const {
383 auto&& arg1_copy = this->template get_arg<0>().deep_copy();
384 auto&& arg2_copy = this->template get_arg<1>().deep_copy();
385 return lbeta_<std::remove_reference_t<decltype(arg1_copy)>,
386 std::remove_reference_t<decltype(arg2_copy)>>{
387 std::move(arg1_copy), std::move(arg2_copy)};
389 inline std::pair<int, int> extreme_diagonals() const {
390 return {-rows() + 1, cols() - 1};
394template <typename T1, typename T2,
395 require_all_kernel_expressions_t<T1, T2>* = nullptr,
396 require_any_not_stan_scalar_t<T1, T2>* = nullptr>
397inline lbeta_<as_operation_cl_t<T1>, as_operation_cl_t<T2>> lbeta(T1&& a,
399 return lbeta_<as_operation_cl_t<T1>, as_operation_cl_t<T2>>(
400 as_operation_cl(std::forward<T1>(a)),
401 as_operation_cl(std::forward<T2>(b)));
404template <typename T1, typename T2>
405const std::vector<const char*> lbeta_<T1, T2>::includes{
406 stan::math::opencl_kernels::lgamma_stirling_device_function,
407 stan::math::opencl_kernels::lgamma_stirling_diff_device_function,
408 stan::math::opencl_kernels::lbeta_device_function};
// More binary functions that carry device-function entries. The arguments
// after each function name become the elements of the generated class's
// static `includes` vector, in the order written.
// NOTE(review): the listed order presumably puts helper functions before the
// function that uses them -- confirm before reordering.
409ADD_BINARY_FUNCTION_WITH_INCLUDES(
410 log_inv_logit_diff, opencl_kernels::log1p_exp_device_function,
411 opencl_kernels::log1m_exp_device_function,
412 opencl_kernels::log_inv_logit_diff_device_function)
413ADD_BINARY_FUNCTION_WITH_INCLUDES(log_diff_exp,
414 opencl_kernels::log1m_exp_device_function,
415 opencl_kernels::log_diff_exp_device_function)
416ADD_BINARY_FUNCTION_WITH_INCLUDES(
417 multiply_log, stan::math::opencl_kernels::multiply_log_device_function)
418ADD_BINARY_FUNCTION_WITH_INCLUDES(
419 lmultiply, stan::math::opencl_kernels::lmultiply_device_function)
// The generator macros defined in this header are only needed to stamp out
// the function classes; undefine all of them so they do not leak into code
// that includes this header.
421#undef ADD_BINARY_FUNCTION_WITH_INCLUDES
422#undef ADD_UNARY_FUNCTION_WITH_INCLUDES
423#undef ADD_UNARY_FUNCTION
424#undef ADD_UNARY_FUNCTION_PASS_ZERO
425#undef ADD_CLASSIFICATION_FUNCTION
elt_function_cl(const std::string &fun, T &&... args)
Constructor.
kernel_parts generate(const std::string &row_index_name, const std::string &col_index_name, const bool view_handled, std::conditional_t< false, T, const std::string & >... var_names_arg) const
Generates kernel code for this expression.
Represents an element-wise function in kernel generator expressions.
Derived & derived()
Casts the instance into its derived type.
Base for all kernel generator operations.
rsqrt_< as_operation_cl_t< T > > rsqrt(T &&a)
#define ADD_UNARY_FUNCTION(fun)
Generates a class and function for a general unary function that is defined by OpenCL.
#define ADD_UNARY_FUNCTION_PASS_ZERO(fun)
Generates a class and function for a unary function, defined by OpenCL, with the special property that it...
#define ADD_UNARY_FUNCTION_WITH_INCLUDES(fun,...)
Generates a class and function for a general unary function that is defined by OpenCL or in the included code.
std_normal_lcdf_dscaled_impl_< as_operation_cl_t< T > > std_normal_lcdf_dscaled_impl(T &&a)
std_normal_lcdf_scaled_impl_< as_operation_cl_t< T > > std_normal_lcdf_scaled_impl(T &&a)
fvar< T > acos(const fvar< T > &x)
fvar< T > sin(const fvar< T > &x)
fvar< T > acosh(const fvar< T > &x)
fvar< T > logit(const fvar< T > &x)
fvar< T > expm1(const fvar< T > &x)
fvar< T > atanh(const fvar< T > &x)
Return inverse hyperbolic tangent of specified value.
fvar< T > inv_square(const fvar< T > &x)
fvar< T > exp2(const fvar< T > &x)
constexpr double log2()
Return natural logarithm of two.
fvar< T > log1m_exp(const fvar< T > &x)
Return the natural logarithm of one minus the exponentiation of the specified argument.
fvar< T > asinh(const fvar< T > &x)
fvar< T > cosh(const fvar< T > &x)
fvar< T > log(const fvar< T > &x)
fvar< T > erf(const fvar< T > &x)
auto inv_logit(T &&x)
Returns the inverse logit function applied to the argument.
fvar< T > Phi_approx(const fvar< T > &x)
Return an approximation of the unit normal cumulative distribution function (CDF).
fvar< T > log_inv_logit(const fvar< T > &x)
fvar< T > cbrt(const fvar< T > &x)
Return cube root of specified argument.
fvar< T > sinh(const fvar< T > &x)
fvar< T > log1p_exp(const fvar< T > &x)
fvar< T > sqrt(const fvar< T > &x)
fvar< T > atan(const fvar< T > &x)
fvar< T > trigamma(const fvar< T > &u)
Return the value of the trigamma function at the specified argument (i.e., the second derivative of the log gamma function at the specified argument).
fvar< T > tan(const fvar< T > &x)
fvar< T > Phi(const fvar< T > &x)
fvar< T > erfc(const fvar< T > &x)
fvar< T > log1p(const fvar< T > &x)
fvar< T > inv_Phi(const fvar< T > &p)
fvar< T > floor(const fvar< T > &x)
fvar< T > lgamma(const fvar< T > &x)
Return the natural logarithm of the gamma function applied to the specified argument.
static constexpr double log10()
Returns the natural logarithm of ten.
fvar< T > tanh(const fvar< T > &x)
fvar< T > cos(const fvar< T > &x)
fvar< T > round(const fvar< T > &x)
Return the closest integer to the specified argument, with halfway cases rounded away from zero.
fvar< T > tgamma(const fvar< T > &x)
Return the result of applying the gamma function to the specified argument.
fvar< T > asin(const fvar< T > &x)
fvar< T > ceil(const fvar< T > &x)
fvar< T > log1m(const fvar< T > &x)
fvar< T > log1m_inv_logit(const fvar< T > &x)
Return the natural logarithm of one minus the inverse logit of the specified argument.
fvar< T > digamma(const fvar< T > &x)
Return the derivative of the log gamma function at the specified argument.
fvar< T > square(const fvar< T > &x)
fvar< T > trunc(const fvar< T > &x)
Return the nearest integral value that is not larger in magnitude than the specified argument.
fvar< T > exp(const fvar< T > &x)
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Parts of an OpenCL kernel, generated by an expression.