1#ifndef STAN_MATH_OPENCL_KERNELS_NEG_BINOMIAL_2_LOG_GLM_LPMF_HPP
2#define STAN_MATH_OPENCL_KERNELS_NEG_BINOMIAL_2_LOG_GLM_LPMF_HPP
11namespace opencl_kernels {
// OpenCL kernel source text, stored as a C string built with STRINGIFY, for
// the kernel registered under the neg_binomial_2_log_glm name.
// NOTE(review): this listing looks like an extracted rendering. Lines carry
// leading listing numbers and the numbering has gaps, so some statements
// (the kernel name line, some declarations, guards and closing braces) are
// not visible here. The comments below describe only the statements shown.
14static constexpr const char* neg_binomial_2_log_glm_kernel_code =
STRINGIFY(
// Kernel parameters, as used by the visible statements:
//  logp_global              out: one value per work group, indexed by the
//                           group id, taken from res_loc[0]
//  theta_derivative_global  out: one value per work item, indexed by the
//                           global id, written if need_theta_derivative
//  theta_derivative_sum     out: one reduced value per work group, written
//                           if need_theta_derivative_sum
//  phi_derivative_global    out: per work item value if need_phi_derivative
//                           and not need_phi_derivative_sum; one reduced
//                           value per work group if need_phi_derivative_sum
//  y_global, alpha,         in: read at index gid * is_<name>_vector, so a
//  phi_global               flag of 0 makes every work item read element 0
//  x                        in: read at x[i * N + gid] for i in [0, M);
//                           presumably an N by M column-major matrix,
//                           TODO confirm against the caller
//  beta                     in: read at beta[i] for i in [0, M)
//  N                        stride in x between consecutive values of i
//  M                        number of terms summed into theta
//  need_logp1..need_logp4   not referenced by any visible line; presumably
//                           they gate individual logp terms, TODO confirm
59 __global
double* logp_global, __global
double* theta_derivative_global,
60 __global
double* theta_derivative_sum,
61 __global
double* phi_derivative_global,
const __global
int* y_global,
62 const __global
double* x,
const __global
double* alpha,
63 const __global
double*
beta,
const __global
double* phi_global,
64 const int N,
const int M,
const int is_y_vector,
65 const int is_alpha_vector,
const int is_phi_vector,
66 const int need_theta_derivative,
const int need_theta_derivative_sum,
67 const int need_phi_derivative,
const int need_phi_derivative_sum,
68 const int need_logp1,
const int need_logp2,
const int need_logp3,
69 const int need_logp4) {
// Work item coordinates: global id, id within the work group, work group
// size and work group id, all along dimension 0.
70 const int gid = get_global_id(0);
71 const int lid = get_local_id(0);
72 const int lsize = get_local_size(0);
73 const int wgid = get_group_id(0);
// Scratch array in __local memory with LOCAL_SIZE_ entries; the reductions
// further down accumulate into it.
75 __local
double res_loc[LOCAL_SIZE_];
// Start at 0 so that the values stored into res_loc for the reductions are
// defined even when the derivative assignments below do not run.
77 double phi_derivative = 0;
78 double theta_derivative = 0;
// Accumulate sum over i of x[i * N + gid] * beta[i] into theta.
// NOTE(review): the declaration of theta is not visible in this listing.
84 for (
int i = 0, j = 0; i < M; i++, j += N) {
85 theta += x[j + gid] *
beta[i];
// Per work item inputs; the is_*_vector flags select between element gid
// and element 0.
87 double phi = phi_global[gid * is_phi_vector];
88 double y = y_global[gid * is_y_vector];
92 theta += alpha[gid * is_alpha_vector];
93 double log_phi =
log(phi);
// Both assignments below evaluate log(exp(theta) + phi), with the larger of
// theta and log_phi factored out before calling log1p_exp.
// NOTE(review): the line separating the two branches is not visible.
94 double logsumexp_theta_logphi;
95 if (theta > log_phi) {
96 logsumexp_theta_logphi = theta +
log1p_exp(log_phi - theta);
98 logsumexp_theta_logphi = log_phi +
log1p_exp(theta - log_phi);
100 double y_plus_phi = y + phi;
// Visible contributions to logp. NOTE(review): the declaration of logp and
// any conditions guarding these terms are not visible in this listing.
107 logp += phi *
log(phi);
110 logp -= y_plus_phi * logsumexp_theta_logphi;
115 logp +=
lgamma(y_plus_phi);
// theta_derivative = y - exp(theta) * (y + phi) / (exp(theta) + phi)
117 double theta_exp =
exp(theta);
118 theta_derivative = y - theta_exp * y_plus_phi / (theta_exp + phi);
119 if (need_theta_derivative) {
120 theta_derivative_global[gid] = theta_derivative;
// NOTE(review): the phi_derivative expression below has no terminator in
// the visible lines, so it appears to continue on a line that is not shown.
// The per work item value is stored only when no summed result is wanted.
122 if (need_phi_derivative) {
123 phi_derivative = 1 - y_plus_phi / (theta_exp + phi) + log_phi
124 - logsumexp_theta_logphi +
digamma(y_plus_phi)
126 if (!need_phi_derivative_sum) {
127 phi_derivative_global[gid] = phi_derivative;
// Work group reduction over res_loc. Each pass of the outer loop divides
// step by REDUCTION_STEP_SIZE, and the inner loop adds the entries at
// lid + step * i for i = 1 .. REDUCTION_STEP_SIZE - 1 into res_loc[lid].
// barrier(CLK_LOCAL_MEM_FENCE) separates the passes.
// NOTE(review): the store that fills res_loc for this first reduction and
// any guard on which work items take part are not visible in this listing.
136 barrier(CLK_LOCAL_MEM_FENCE);
137 for (
int step = lsize / REDUCTION_STEP_SIZE;
step > 0;
138 step /= REDUCTION_STEP_SIZE) {
140 for (
int i = 1; i < REDUCTION_STEP_SIZE; i++) {
141 res_loc[lid] += res_loc[lid +
step * i];
144 barrier(CLK_LOCAL_MEM_FENCE);
// Entry 0 of the scratch array is written out, one value per work group.
147 logp_global[wgid] = res_loc[0];
// Same reduction pattern applied to theta_derivative. The barrier before
// the store keeps the scratch array from being overwritten while the
// previous reduction result may still be read.
150 if (need_theta_derivative_sum) {
152 barrier(CLK_LOCAL_MEM_FENCE);
153 res_loc[lid] = theta_derivative;
154 barrier(CLK_LOCAL_MEM_FENCE);
155 for (
int step = lsize / REDUCTION_STEP_SIZE;
step > 0;
156 step /= REDUCTION_STEP_SIZE) {
158 for (
int i = 1; i < REDUCTION_STEP_SIZE; i++) {
159 res_loc[lid] += res_loc[lid +
step * i];
162 barrier(CLK_LOCAL_MEM_FENCE);
165 theta_derivative_sum[wgid] = res_loc[0];
// Same reduction pattern applied to phi_derivative. Here the result is
// indexed by work group id in phi_derivative_global, the same buffer that
// receives per work item values when no sum is requested.
169 if (need_phi_derivative_sum) {
171 barrier(CLK_LOCAL_MEM_FENCE);
172 res_loc[lid] = phi_derivative;
173 barrier(CLK_LOCAL_MEM_FENCE);
174 for (
int step = lsize / REDUCTION_STEP_SIZE;
step > 0;
175 step /= REDUCTION_STEP_SIZE) {
177 for (
int i = 1; i < REDUCTION_STEP_SIZE; i++) {
178 res_loc[lid] += res_loc[lid +
step * i];
181 barrier(CLK_LOCAL_MEM_FENCE);
184 phi_derivative_global[wgid] = res_loc[0];
// Kernel wrapper declaration. The template arguments list four out_buffer,
// five in_buffer and thirteen int entries, which lines up in order with the
// four writable pointers, five const pointers and thirteen const int
// parameters of the kernel source above.
// The first braced list supplies the source strings: the digamma and
// log1p_exp device functions followed by the kernel code. The second braced
// list supplies the values substituted for REDUCTION_STEP_SIZE (4) and
// LOCAL_SIZE_ (64) in the kernel source.
// NOTE(review): the line that names the object (between listing lines 198
// and 200) is not visible in this listing.
196const kernel_cl<out_buffer, out_buffer, out_buffer, out_buffer, in_buffer,
197 in_buffer, in_buffer, in_buffer, in_buffer, int, int, int, int,
198 int, int, int, int, int, int, int, int,
int>
200 {digamma_device_function, log1p_exp_device_function,
201 neg_binomial_2_log_glm_kernel_code},
202 {{
"REDUCTION_STEP_SIZE", 4}, {
"LOCAL_SIZE_", 64}});
isfinite_< as_operation_cl_t< T > > isfinite(T &&a)
const kernel_cl< out_buffer, out_buffer, out_buffer, out_buffer, in_buffer, in_buffer, in_buffer, in_buffer, in_buffer, int, int, int, int, int, int, int, int, int, int, int, int, int > neg_binomial_2_log_glm("neg_binomial_2_log_glm", {digamma_device_function, log1p_exp_device_function, neg_binomial_2_log_glm_kernel_code}, {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}})
See the documentation for neg_binomial_2_log_glm_lpmf().
double log1p_exp(double a)
Calculates the log of 1 plus the exponential of the specified value without overflow.
double digamma(double x)
Calculates the digamma function, i.e. the derivative of the logarithm of the gamma function.
double beta(double a, double b)
Return the beta function applied to the specified arguments.
T step(const T &y)
The step, or Heaviside, function.
fvar< T > log(const fvar< T > &x)
fvar< T > lgamma(const fvar< T > &x)
Return the natural logarithm of the gamma function applied to the specified argument.
fvar< T > exp(const fvar< T > &x)
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...