Automatic Differentiation
 
Loading...
Searching...
No Matches
categorical_logit_glm_lpmf.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNELS_CATEGORICAL_LOGIT_GLM_LPMF_HPP
2#define STAN_MATH_OPENCL_KERNELS_CATEGORICAL_LOGIT_GLM_LPMF_HPP
3#ifdef STAN_OPENCL
4
6#include <string>
7
8namespace stan {
9namespace math {
10namespace opencl_kernels {
11
12// \cond
13static constexpr const char* categorical_logit_glm_kernel_code = STRINGIFY(
14 // \endcond
42 __kernel void categorical_logit_glm(
43 __global double* logp_global, __global double* exp_lin_global,
44 __global double* inv_sum_exp_lin_global,
45 __global double* neg_softmax_lin_global,
46 __global double* alpha_derivative, const __global int* y_global,
47 const __global double* x_beta_global,
48 const __global double* alpha_global, const int N_instances,
49 const int N_attributes, const int N_classes, const int is_y_vector,
50 const int need_alpha_derivative,
51 const int need_neg_softmax_lin_global) {
52 const int gid = get_global_id(0);
53 const int lid = get_local_id(0);
54 const int lsize = get_local_size(0);
55 const int wg_id = get_group_id(0);
56 const int ngroups = get_num_groups(0);
57
58 __local double local_storage[LOCAL_SIZE_];
59
60 double logp = 0;
61 double inv_sum_exp_lin;
62 int class_idx = -1;
63 // Most calculations only happen for relevant data within next if.
64 // Exceptions are reductions between threads that need barriers.
65 if (gid < N_instances) {
66 double lin_max = -INFINITY;
67 for (int i = 0; i < N_classes; i++) {
68 double lin = x_beta_global[i * N_instances + gid] + alpha_global[i];
69 if (lin > lin_max) {
70 lin_max = lin;
71 }
72 }
73 double alpha = alpha_global[gid];
74 double sum_exp_lin = 0;
75 for (int i = 0; i < N_classes; i++) {
76 double lin = x_beta_global[i * N_instances + gid] + alpha_global[i];
77 double exp_lin = exp(lin - lin_max);
78 sum_exp_lin += exp_lin;
79 exp_lin_global[i * N_instances + gid] = exp_lin;
80 }
81 inv_sum_exp_lin = 1 / sum_exp_lin;
82 inv_sum_exp_lin_global[gid] = inv_sum_exp_lin;
83
84 class_idx = y_global[gid * is_y_vector] - 1;
85 if (class_idx < 0 || class_idx > N_classes) {
86 logp = NAN;
87 } else {
88 logp = log(inv_sum_exp_lin) - lin_max
89 + x_beta_global[class_idx * N_instances + gid]
90 + alpha_global[class_idx];
91 }
92 }
93 barrier(CLK_GLOBAL_MEM_FENCE);
94 double neg_softmax_lin_sum = 0;
95 if (need_alpha_derivative || need_neg_softmax_lin_global) {
96 for (int i = 0; i < N_classes; i++) {
97 double neg_softmax_lin = 0;
98 if (gid < N_instances) {
99 int idx = i * N_instances + gid;
100 neg_softmax_lin = -exp_lin_global[idx] * inv_sum_exp_lin;
101 if (need_neg_softmax_lin_global) {
102 neg_softmax_lin_global[idx] = neg_softmax_lin;
103 }
104 }
105 if (need_alpha_derivative) {
106 local_storage[lid] = neg_softmax_lin + (class_idx == i);
107 barrier(CLK_LOCAL_MEM_FENCE);
108 for (int step = lsize / REDUCTION_STEP_SIZE; step > 0;
109 step /= REDUCTION_STEP_SIZE) {
110 if (lid < step) {
111 for (int i = 1; i < REDUCTION_STEP_SIZE; i++) {
112 local_storage[lid] += local_storage[lid + step * i];
113 }
114 }
115 barrier(CLK_LOCAL_MEM_FENCE);
116 }
117 if (lid == 0) {
118 alpha_derivative[i + wg_id * N_classes] = local_storage[0];
119 }
120 barrier(CLK_LOCAL_MEM_FENCE);
121 }
122 }
123 }
124 // Sum logp, calculated by different threads.
125 // Since we can't sum between different work groups, we emit one number
126 // per work group. These must be summed on CPU for final result.
127 local_storage[lid] = logp;
128 barrier(CLK_LOCAL_MEM_FENCE);
129 for (int step = lsize / REDUCTION_STEP_SIZE; step > 0;
130 step /= REDUCTION_STEP_SIZE) {
131 if (lid < step) {
132 for (int i = 1; i < REDUCTION_STEP_SIZE; i++) {
133 local_storage[lid] += local_storage[lid + step * i];
134 }
135 }
136 barrier(CLK_LOCAL_MEM_FENCE);
137 }
138 if (lid == 0) {
139 logp_global[wg_id] = local_storage[0];
140 }
141 }
142 // \cond
143);
144// \endcond
145
150const kernel_cl<out_buffer, out_buffer, out_buffer, out_buffer, out_buffer,
151 in_buffer, in_buffer, in_buffer, int, int, int, int, int, int>
152 categorical_logit_glm("categorical_logit_glm",
153 {categorical_logit_glm_kernel_code},
154 {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}});
155
156// \cond
157static const std::string categorical_logit_glm_beta_derivative_kernel_code
158 = STRINGIFY(
159 // \endcond
// NOTE(review): the `__kernel void categorical_logit_glm_beta_derivative(`
// signature line is hidden by the doc extractor; this span is the parameter
// list and body only. One work group handles one attribute (wg_id), and
// `temp` is a global scratch buffer — presumably sized
// (total work items) * N_classes; confirm against the caller.
175 __global double* beta_derivative, __global double* temp,
176 const __global int* y, const __global double* x,
177 const int N_instances, const int N_attributes, const int N_classes,
178 const int is_y_vector) {
179 const int gid = get_global_id(0);
180 const int lid = get_local_id(0);
181 const int lsize = get_local_size(0);
182 const int wg_id = get_group_id(0);
183
// Phase 1: zero this work item's private N_classes-wide slice of temp.
184 for (int i = 0; i < N_classes; i++) {
185 temp[gid * N_classes + i] = 0;
186 }
// Phase 2: stride over instances; accumulate this attribute's x value into
// the slot of the observed class. y is 1-based; is_y_vector==0 broadcasts
// a scalar y. NOTE(review): pos is not bounds-checked here — an out-of-range
// y would write outside temp; presumably the host validates y beforehand.
187 for (int i = lid; i < N_instances; i += lsize) {
188 int pos = y[i * is_y_vector] - 1;
189 temp[gid * N_classes + pos] += x[wg_id * N_instances + i];
190 }
// barrier syncs only within this work group — sufficient because phase 3
// reads only this group's own lsize slices of temp.
191 barrier(CLK_GLOBAL_MEM_FENCE);
// Phase 3: stride over classes; sum the group's lsize partials for class i
// and add into beta_derivative (class-major, one column per attribute).
192 for (int i = lid; i < N_classes; i += lsize) {
193 double res = 0;
194 for (int j = 0; j < lsize; j++) {
195 res += temp[(wg_id * lsize + j) * N_classes + i];
196 }
197 beta_derivative[i * N_attributes + wg_id] += res;
198 }
199 }
200 // \cond
201 ); // NOLINT
202// \endcond
203
208const kernel_cl<in_out_buffer, in_out_buffer, in_buffer, in_buffer, int, int,
209 int, int>
211 "categorical_logit_glm_beta_derivative",
212 {categorical_logit_glm_beta_derivative_kernel_code});
213
214} // namespace opencl_kernels
215
216} // namespace math
217} // namespace stan
218#endif
219#endif
const kernel_cl< out_buffer, out_buffer, out_buffer, out_buffer, out_buffer, in_buffer, in_buffer, in_buffer, int, int, int, int, int, int > categorical_logit_glm("categorical_logit_glm", {categorical_logit_glm_kernel_code}, {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}})
See the docs for categorical_logit_glm().
const kernel_cl< in_out_buffer, in_out_buffer, in_buffer, in_buffer, int, int, int, int > categorical_logit_glm_beta_derivative("categorical_logit_glm_beta_derivative", {categorical_logit_glm_beta_derivative_kernel_code})
See the docs for categorical_logit_glm_beta_derivative().
T step(const T &y)
The step, or Heaviside, function.
Definition step.hpp:31
fvar< T > log(const fvar< T > &x)
Definition log.hpp:18
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:15
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
#define STRINGIFY(...)
Definition stringify.hpp:9