1#ifndef STAN_MATH_OPENCL_KERNELS_NEG_BINOMIAL_2_LOG_GLM_LPMF_HPP 
    2#define STAN_MATH_OPENCL_KERNELS_NEG_BINOMIAL_2_LOG_GLM_LPMF_HPP 
   11namespace opencl_kernels {
 
   14static constexpr const char* neg_binomial_2_log_glm_kernel_code = 
STRINGIFY(
 
   59        __global 
double* logp_global, __global 
double* theta_derivative_global,
 
   60        __global 
double* theta_derivative_sum,
 
   61        __global 
double* phi_derivative_global, 
const __global 
int* y_global,
 
   62        const __global 
double* x, 
const __global 
double* alpha,
 
   63        const __global 
double* 
beta, 
const __global 
double* phi_global,
 
   64        const int N, 
const int M, 
const int is_y_vector,
 
   65        const int is_alpha_vector, 
const int is_phi_vector,
 
   66        const int need_theta_derivative, 
const int need_theta_derivative_sum,
 
   67        const int need_phi_derivative, 
const int need_phi_derivative_sum,
 
   68        const int need_logp1, 
const int need_logp2, 
const int need_logp3,
 
   69        const int need_logp4) {
 
   70      const int gid = get_global_id(0);
 
   71      const int lid = get_local_id(0);
 
   72      const int lsize = get_local_size(0);
 
   73      const int wgid = get_group_id(0);
 
   75      __local 
double res_loc[LOCAL_SIZE_];
 
   77      double phi_derivative = 0;
 
   78      double theta_derivative = 0;
 
   84        for (
int i = 0, j = 0; i < M; i++, j += N) {
 
   85          theta += x[j + gid] * 
beta[i];
 
   87        double phi = phi_global[gid * is_phi_vector];
 
   88        double y = y_global[gid * is_y_vector];
 
   92        theta += alpha[gid * is_alpha_vector];
 
   93        double log_phi = 
log(phi);
 
   94        double logsumexp_theta_logphi;
 
   95        if (theta > log_phi) {
 
   96          logsumexp_theta_logphi = theta + 
log1p_exp(log_phi - theta);
 
   98          logsumexp_theta_logphi = log_phi + 
log1p_exp(theta - log_phi);
 
  100        double y_plus_phi = y + phi;
 
  107            logp += phi * 
log(phi);
 
  110        logp -= y_plus_phi * logsumexp_theta_logphi;
 
  115          logp += 
lgamma(y_plus_phi);
 
  117        double theta_exp = 
exp(theta);
 
  118        theta_derivative = y - theta_exp * y_plus_phi / (theta_exp + phi);
 
  119        if (need_theta_derivative) {
 
  120          theta_derivative_global[gid] = theta_derivative;
 
  122        if (need_phi_derivative) {
 
  123          phi_derivative = 1 - y_plus_phi / (theta_exp + phi) + log_phi
 
  124                           - logsumexp_theta_logphi + 
digamma(y_plus_phi)
 
  126          if (!need_phi_derivative_sum) {
 
  127            phi_derivative_global[gid] = phi_derivative;
 
  136      barrier(CLK_LOCAL_MEM_FENCE);
 
  137      for (
int step = lsize / REDUCTION_STEP_SIZE; 
step > 0;
 
  138           step /= REDUCTION_STEP_SIZE) {
 
  140          for (
int i = 1; i < REDUCTION_STEP_SIZE; i++) {
 
  141            res_loc[lid] += res_loc[lid + 
step * i];
 
  144        barrier(CLK_LOCAL_MEM_FENCE);
 
  147        logp_global[wgid] = res_loc[0];
 
  150      if (need_theta_derivative_sum) {
 
  152        barrier(CLK_LOCAL_MEM_FENCE);
 
  153        res_loc[lid] = theta_derivative;
 
  154        barrier(CLK_LOCAL_MEM_FENCE);
 
  155        for (
int step = lsize / REDUCTION_STEP_SIZE; 
step > 0;
 
  156             step /= REDUCTION_STEP_SIZE) {
 
  158            for (
int i = 1; i < REDUCTION_STEP_SIZE; i++) {
 
  159              res_loc[lid] += res_loc[lid + 
step * i];
 
  162          barrier(CLK_LOCAL_MEM_FENCE);
 
  165          theta_derivative_sum[wgid] = res_loc[0];
 
  169      if (need_phi_derivative_sum) {
 
  171        barrier(CLK_LOCAL_MEM_FENCE);
 
  172        res_loc[lid] = phi_derivative;
 
  173        barrier(CLK_LOCAL_MEM_FENCE);
 
  174        for (
int step = lsize / REDUCTION_STEP_SIZE; 
step > 0;
 
  175             step /= REDUCTION_STEP_SIZE) {
 
  177            for (
int i = 1; i < REDUCTION_STEP_SIZE; i++) {
 
  178              res_loc[lid] += res_loc[lid + 
step * i];
 
  181          barrier(CLK_LOCAL_MEM_FENCE);
 
  184          phi_derivative_global[wgid] = res_loc[0];
 
  196const kernel_cl<out_buffer, out_buffer, out_buffer, out_buffer, in_buffer,
 
  197                in_buffer, in_buffer, in_buffer, in_buffer, int, int, int, int,
 
  198                int, int, int, int, int, int, int, int, 
int>
 
  200                           {digamma_device_function, log1p_exp_device_function,
 
  201                            neg_binomial_2_log_glm_kernel_code},
 
  202                           {{
"REDUCTION_STEP_SIZE", 4}, {
"LOCAL_SIZE_", 64}});
 
isfinite_< as_operation_cl_t< T > > isfinite(T &&a)
 
const kernel_cl< out_buffer, out_buffer, out_buffer, out_buffer, in_buffer, in_buffer, in_buffer, in_buffer, in_buffer, int, int, int, int, int, int, int, int, int, int, int, int, int > neg_binomial_2_log_glm("neg_binomial_2_log_glm", {digamma_device_function, log1p_exp_device_function, neg_binomial_2_log_glm_kernel_code}, {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}})
See the docs for neg_binomial_2_log_glm_lpmf().
 
double log1p_exp(double a)
Calculates the log of 1 plus the exponential of the specified value without overflow.
 
double digamma(double x)
Calculates the digamma function, i.e. the derivative of the logarithm of the gamma function.
 
double beta(double a, double b)
Return the beta function applied to the specified arguments.
 
T step(const T &y)
The step, or Heaviside, function.
 
fvar< T > log(const fvar< T > &x)
 
fvar< T > lgamma(const fvar< T > &x)
Return the natural logarithm of the gamma function applied to the specified argument.
 
fvar< T > exp(const fvar< T > &x)
 
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...