Automatic Differentiation
 
Loading...
Searching...
No Matches
neg_binomial_2_log_glm_lpmf.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNELS_NEG_BINOMIAL_2_LOG_GLM_LPMF_HPP
2#define STAN_MATH_OPENCL_KERNELS_NEG_BINOMIAL_2_LOG_GLM_LPMF_HPP
3#ifdef STAN_OPENCL
4
8
9namespace stan {
10namespace math {
11namespace opencl_kernels {
12
13// \cond
// OpenCL source text of the neg_binomial_2_log_glm kernel, converted to a
// string literal by STRINGIFY.
//
// Each work item with gid < N handles one observation. It forms
//   theta = sum_i x[i * N + gid] * beta[i] + alpha[...]
// and then accumulates the requested log-probability terms together with the
// derivatives with respect to theta and phi. Work items with gid >= N keep
// the initial value 0 for all three quantities, so they do not change the
// work-group sums.
//
// Outputs, as written by the code below:
//   logp_global[wgid]             - work-group sum of logp (always written).
//   theta_derivative_global[gid]  - written when need_theta_derivative != 0.
//   theta_derivative_sum[wgid]    - written when need_theta_derivative_sum != 0.
//   phi_derivative_global[gid]    - written when need_phi_derivative != 0 and
//                                   need_phi_derivative_sum == 0.
//   phi_derivative_global[wgid]   - written when need_phi_derivative_sum != 0.
//
// The is_*_vector flags are used as index multipliers (gid * flag), so a flag
// of 0 makes every work item read element 0 of the corresponding buffer.
//
// NOTE(review): x is indexed as x[i * N + gid], which looks like a
// column-major N x M matrix - confirm against the caller.
// NOTE(review): the line carrying the kernel name and the opening parenthesis
// of the parameter list is not visible in this listing (the embedded numbering
// jumps from 16 to 59); the parameter list below starts mid-signature.
14static constexpr const char* neg_binomial_2_log_glm_kernel_code = STRINGIFY(
15 // \endcond
16
59 __global double* logp_global, __global double* theta_derivative_global,
60 __global double* theta_derivative_sum,
61 __global double* phi_derivative_global, const __global int* y_global,
62 const __global double* x, const __global double* alpha,
63 const __global double* beta, const __global double* phi_global,
64 const int N, const int M, const int is_y_vector,
65 const int is_alpha_vector, const int is_phi_vector,
66 const int need_theta_derivative, const int need_theta_derivative_sum,
67 const int need_phi_derivative, const int need_phi_derivative_sum,
68 const int need_logp1, const int need_logp2, const int need_logp3,
69 const int need_logp4) {
70 const int gid = get_global_id(0);
71 const int lid = get_local_id(0);
72 const int lsize = get_local_size(0);
73 const int wgid = get_group_id(0);
74
// Scratch buffer in local memory, reused by all three reductions below.
75 __local double res_loc[LOCAL_SIZE_];
// Per-work-item results; they stay 0 for work items with gid >= N.
76 double logp = 0;
77 double phi_derivative = 0;
78 double theta_derivative = 0;
79
80 // Most calculations only happen for relevant data within next if.
81 // Exceptions are reductions between threads that need barriers.
82 if (gid < N) {
// Linear predictor without the intercept: dot product of this work item's
// entries of x (stride N) with beta.
83 double theta = 0;
84 for (int i = 0, j = 0; i < M; i++, j += N) {
85 theta += x[j + gid] * beta[i];
86 }
87 double phi = phi_global[gid * is_phi_vector];
88 double y = y_global[gid * is_y_vector];
// Invalid input is flagged by setting logp to NaN; the additions and
// subtractions below keep it NaN. The check on theta happens before alpha
// is added.
89 if (!isfinite(theta) || y < 0 || !isfinite(phi)) {
90 logp = NAN;
91 }
92 theta += alpha[gid * is_alpha_vector];
93 double log_phi = log(phi);
// log(exp(theta) + phi), factoring out the larger of theta and log_phi so
// that the argument passed to log1p_exp is never positive.
94 double logsumexp_theta_logphi;
95 if (theta > log_phi) {
96 logsumexp_theta_logphi = theta + log1p_exp(log_phi - theta);
97 } else {
98 logsumexp_theta_logphi = log_phi + log1p_exp(theta - log_phi);
99 }
100 double y_plus_phi = y + phi;
// The need_logp* flags switch individual terms of the log-probability on or
// off; the logsumexp term at line 110 is always included.
101 if (need_logp1) {
102 logp -= lgamma(y + 1);
103 }
104 if (need_logp2) {
105 logp -= lgamma(phi);
// Skips phi * log(phi) when phi == 0, where log(phi) is -infinity.
106 if (phi != 0) {
107 logp += phi * log(phi);
108 }
109 }
110 logp -= y_plus_phi * logsumexp_theta_logphi;
111 if (need_logp3) {
112 logp += y * theta;
113 }
114 if (need_logp4) {
115 logp += lgamma(y_plus_phi);
116 }
// theta_derivative is computed unconditionally: it feeds both the optional
// per-item store below and the optional work-group sum further down.
117 double theta_exp = exp(theta);
118 theta_derivative = y - theta_exp * y_plus_phi / (theta_exp + phi);
119 if (need_theta_derivative) {
120 theta_derivative_global[gid] = theta_derivative;
121 }
122 if (need_phi_derivative) {
123 phi_derivative = 1 - y_plus_phi / (theta_exp + phi) + log_phi
124 - logsumexp_theta_logphi + digamma(y_plus_phi)
125 - digamma(phi);
// When the summed form is requested, the per-item value is not stored;
// phi_derivative_global then receives one value per work group instead.
126 if (!need_phi_derivative_sum) {
127 phi_derivative_global[gid] = phi_derivative;
128 }
129 }
130 }
131
132 // Sum logp, calculated by different threads.
133 // Since we can't sum between different work groups, we emit one number
134 // per work group. These must be summed on CPU for final result.
135 res_loc[lid] = logp;
136 barrier(CLK_LOCAL_MEM_FENCE);
// Each pass lets the first `step` work items fold in REDUCTION_STEP_SIZE - 1
// further entries spaced `step` apart, then shrinks `step` by the same
// factor. NOTE(review): assuming the work-group size equals LOCAL_SIZE_ (64)
// and REDUCTION_STEP_SIZE is 4, the passes use step = 16, 4, 1.
137 for (int step = lsize / REDUCTION_STEP_SIZE; step > 0;
138 step /= REDUCTION_STEP_SIZE) {
139 if (lid < step) {
140 for (int i = 1; i < REDUCTION_STEP_SIZE; i++) {
141 res_loc[lid] += res_loc[lid + step * i];
142 }
143 }
144 barrier(CLK_LOCAL_MEM_FENCE);
145 }
// Work item 0 of the group publishes the group total.
146 if (lid == 0) {
147 logp_global[wgid] = res_loc[0];
148 }
149
// The condition is a kernel argument, so every work item of the group takes
// the same branch and reaches the barriers inside it.
150 if (need_theta_derivative_sum) {
151 // Sum theta_derivative, calculated by different threads.
// Barrier before res_loc is reused for the next reduction.
152 barrier(CLK_LOCAL_MEM_FENCE);
153 res_loc[lid] = theta_derivative;
154 barrier(CLK_LOCAL_MEM_FENCE);
155 for (int step = lsize / REDUCTION_STEP_SIZE; step > 0;
156 step /= REDUCTION_STEP_SIZE) {
157 if (lid < step) {
158 for (int i = 1; i < REDUCTION_STEP_SIZE; i++) {
159 res_loc[lid] += res_loc[lid + step * i];
160 }
161 }
162 barrier(CLK_LOCAL_MEM_FENCE);
163 }
164 if (lid == 0) {
165 theta_derivative_sum[wgid] = res_loc[0];
166 }
167 }
168
// Same reduction pattern as above, applied to phi_derivative.
169 if (need_phi_derivative_sum) {
170 // Sum phi_derivative, calculated by different threads.
171 barrier(CLK_LOCAL_MEM_FENCE);
172 res_loc[lid] = phi_derivative;
173 barrier(CLK_LOCAL_MEM_FENCE);
174 for (int step = lsize / REDUCTION_STEP_SIZE; step > 0;
175 step /= REDUCTION_STEP_SIZE) {
176 if (lid < step) {
177 for (int i = 1; i < REDUCTION_STEP_SIZE; i++) {
178 res_loc[lid] += res_loc[lid + step * i];
179 }
180 }
181 barrier(CLK_LOCAL_MEM_FENCE);
182 }
// Indexed by work group here, unlike the per-item store at line 127.
183 if (lid == 0) {
184 phi_derivative_global[wgid] = res_loc[0];
185 }
186 }
187 }
188 // \cond
189);
190// \endcond
191
// Host-side kernel object for the neg_binomial_2_log_glm kernel.
// The template arguments list the kernel's arguments in order: four output
// buffers, five input buffers and thirteen ints, matching the parameter list
// of the kernel source in neg_binomial_2_log_glm_kernel_code.
// The constructor receives the kernel name, the source fragments it depends
// on (the digamma and log1p_exp device functions, followed by the kernel code
// itself) and the values for REDUCTION_STEP_SIZE (4) and LOCAL_SIZE_ (64),
// both of which are used by name inside the kernel source.
196const kernel_cl<out_buffer, out_buffer, out_buffer, out_buffer, in_buffer,
197 in_buffer, in_buffer, in_buffer, in_buffer, int, int, int, int,
198 int, int, int, int, int, int, int, int, int>
199 neg_binomial_2_log_glm("neg_binomial_2_log_glm",
200 {digamma_device_function, log1p_exp_device_function,
201 neg_binomial_2_log_glm_kernel_code},
202 {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}});
203
204} // namespace opencl_kernels
205} // namespace math
206} // namespace stan
207
208#endif
209#endif
isfinite_< as_operation_cl_t< T > > isfinite(T &&a)
const kernel_cl< out_buffer, out_buffer, out_buffer, out_buffer, in_buffer, in_buffer, in_buffer, in_buffer, in_buffer, int, int, int, int, int, int, int, int, int, int, int, int, int > neg_binomial_2_log_glm("neg_binomial_2_log_glm", {digamma_device_function, log1p_exp_device_function, neg_binomial_2_log_glm_kernel_code}, {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}})
See the docs for neg_binomial_2_log_glm_lpmf().
double log1p_exp(double a)
Calculates the log of 1 plus the exponential of the specified value without overflow.
Definition log1p_exp.hpp:28
double digamma(double x)
Calculates the digamma function, the derivative of the logarithm of the gamma function.
Definition digamma.hpp:25
double beta(double a, double b)
Return the beta function applied to the specified arguments.
Definition beta.hpp:25
T step(const T &y)
The step, or Heaviside, function.
Definition step.hpp:31
fvar< T > log(const fvar< T > &x)
Definition log.hpp:15
fvar< T > lgamma(const fvar< T > &x)
Return the natural logarithm of the gamma function applied to the specified argument.
Definition lgamma.hpp:21
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:13
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
#define STRINGIFY(...)
Definition stringify.hpp:9