#ifndef STAN_MATH_OPENCL_KERNELS_CATEGORICAL_LOGIT_GLM_LPMF_HPP
#define STAN_MATH_OPENCL_KERNELS_CATEGORICAL_LOGIT_GLM_LPMF_HPP
#ifdef STAN_OPENCL

#include <stan/math/opencl/kernel_cl.hpp>
#include <string>

namespace stan {
namespace math {
namespace opencl_kernels {

// Source code of the OpenCL kernel that computes the log-likelihood of the
// categorical logit GLM and, optionally, intermediate results needed for the
// derivatives.
static constexpr const char* categorical_logit_glm_kernel_code = STRINGIFY(
    __kernel void categorical_logit_glm(
        __global double* logp_global, __global double* exp_lin_global,
        __global double* inv_sum_exp_lin_global,
        __global double* neg_softmax_lin_global,
        __global double* alpha_derivative, const __global int* y_global,
        const __global double* x_beta_global,
        const __global double* alpha_global, const int N_instances,
        const int N_attributes, const int N_classes, const int is_y_vector,
        const int need_alpha_derivative,
        const int need_neg_softmax_lin_global) {
      const int gid = get_global_id(0);
      const int lid = get_local_id(0);
      const int lsize = get_local_size(0);
      const int wg_id = get_group_id(0);
      const int ngroups = get_num_groups(0);

      __local double local_storage[LOCAL_SIZE_];

      double logp = 0;
      double inv_sum_exp_lin;
      // Initialized to an invalid class so that threads past N_instances never
      // match a class in the alpha derivative reduction below.
      int class_idx = -1;
      // Most of the calculation is only done by threads that correspond to an
      // actual data instance; the exceptions are the reductions below, which
      // need barriers executed by all threads of the work group.
      if (gid < N_instances) {
        // log-sum-exp trick: subtract the largest linear predictor before
        // exponentiating to avoid overflow.
        double lin_max = -INFINITY;
        for (int i = 0; i < N_classes; i++) {
          double lin = x_beta_global[i * N_instances + gid] + alpha_global[i];
          if (lin > lin_max) {
            lin_max = lin;
          }
        }
        double sum_exp_lin = 0;
        for (int i = 0; i < N_classes; i++) {
          double lin = x_beta_global[i * N_instances + gid] + alpha_global[i];
          double exp_lin = exp(lin - lin_max);
          sum_exp_lin += exp_lin;
          exp_lin_global[i * N_instances + gid] = exp_lin;
        }
        inv_sum_exp_lin = 1 / sum_exp_lin;
        inv_sum_exp_lin_global[gid] = inv_sum_exp_lin;

        class_idx = y_global[gid * is_y_vector] - 1;
        if (class_idx < 0 || class_idx >= N_classes) {
          logp = NAN;
        } else {
          // log softmax of the linear predictor, evaluated at the observed
          // class.
          logp = log(inv_sum_exp_lin) - lin_max
                 + x_beta_global[class_idx * N_instances + gid]
                 + alpha_global[class_idx];
        }
      }
      barrier(CLK_GLOBAL_MEM_FENCE);
      double neg_softmax_lin_sum = 0;
      // Optionally compute the negative softmax and the derivative with
      // respect to alpha. The alpha derivative is reduced within each work
      // group; one partial result per work group is written out.
      if (need_alpha_derivative || need_neg_softmax_lin_global) {
        for (int i = 0; i < N_classes; i++) {
          double neg_softmax_lin = 0;
          if (gid < N_instances) {
            int idx = i * N_instances + gid;
            neg_softmax_lin = -exp_lin_global[idx] * inv_sum_exp_lin;
            if (need_neg_softmax_lin_global) {
              neg_softmax_lin_global[idx] = neg_softmax_lin;
            }
          }
          if (need_alpha_derivative) {
            local_storage[lid] = neg_softmax_lin + (class_idx == i);
            barrier(CLK_LOCAL_MEM_FENCE);
            // Tree reduction over the work group.
            for (int step = lsize / REDUCTION_STEP_SIZE; step > 0;
                 step /= REDUCTION_STEP_SIZE) {
              if (lid < step) {
                for (int j = 1; j < REDUCTION_STEP_SIZE; j++) {
                  local_storage[lid] += local_storage[lid + step * j];
                }
              }
              barrier(CLK_LOCAL_MEM_FENCE);
            }
            if (lid == 0) {
              alpha_derivative[i + wg_id * N_classes] = local_storage[0];
            }
            barrier(CLK_LOCAL_MEM_FENCE);
          }
        }
      }
      // Sum the log probabilities computed by the threads of this work group.
      // Work groups cannot synchronize with each other, so each work group
      // writes one partial sum; the partial sums are added together on the
      // host.
      local_storage[lid] = logp;
      barrier(CLK_LOCAL_MEM_FENCE);
      for (int step = lsize / REDUCTION_STEP_SIZE; step > 0;
           step /= REDUCTION_STEP_SIZE) {
        if (lid < step) {
          for (int i = 1; i < REDUCTION_STEP_SIZE; i++) {
            local_storage[lid] += local_storage[lid + step * i];
          }
        }
        barrier(CLK_LOCAL_MEM_FENCE);
      }
      if (lid == 0) {
        logp_global[wg_id] = local_storage[0];
      }
    });
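// Usage sketch (not part of this header): the functor declared below is
// typically invoked from host code with the global and local ND ranges first,
// followed by the kernel arguments in the order of the signature above. The
// buffer names (logp_cl, exp_lin_cl, ...) and the sizes local_size and wgs are
// hypothetical caller-side values; logp_cl receives one partial sum per work
// group, which still has to be summed on the host.
//
//   opencl_kernels::categorical_logit_glm(
//       cl::NDRange(local_size * wgs), cl::NDRange(local_size), logp_cl,
//       exp_lin_cl, inv_sum_exp_lin_cl, neg_softmax_lin_cl,
//       alpha_derivative_cl, y_cl, x_beta_cl, alpha_cl, N_instances,
//       N_attributes, N_classes, is_y_vector, need_alpha_derivative,
//       need_neg_softmax_lin);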
// Functor used to compile and launch the categorical_logit_glm kernel.
const kernel_cl<out_buffer, out_buffer, out_buffer, out_buffer, out_buffer,
                in_buffer, in_buffer, in_buffer, int, int, int, int, int, int>
    categorical_logit_glm("categorical_logit_glm",
                          {categorical_logit_glm_kernel_code},
                          {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}});
// Source code of the OpenCL kernel that computes part of the derivative of the
// categorical logit GLM with respect to beta.
static const std::string categorical_logit_glm_beta_derivative_kernel_code
    = STRINGIFY(
        __kernel void categorical_logit_glm_beta_derivative(
            __global double* beta_derivative, __global double* temp,
            const __global int* y, const __global double* x,
            const int N_instances, const int N_attributes,
            const int N_classes, const int is_y_vector) {
          const int gid = get_global_id(0);
          const int lid = get_local_id(0);
          const int lsize = get_local_size(0);
          const int wg_id = get_group_id(0);
          // Each thread zeroes its own row of per-class scratch sums.
          for (int i = 0; i < N_classes; i++) {
            temp[gid * N_classes + i] = 0;
          }
          // Each work group handles one attribute (column of x). Accumulate
          // the attribute values of the instances assigned to this thread into
          // the scratch row of their observed class.
          for (int i = lid; i < N_instances; i += lsize) {
            int pos = y[i * is_y_vector] - 1;
            temp[gid * N_classes + pos] += x[wg_id * N_instances + i];
          }
          barrier(CLK_GLOBAL_MEM_FENCE);
          // Reduce the scratch rows of the work group and add the result into
          // the beta derivative for this attribute.
          for (int i = lid; i < N_classes; i += lsize) {
            double res = 0;
            for (int j = 0; j < lsize; j++) {
              res += temp[(wg_id * lsize + j) * N_classes + i];
            }
            beta_derivative[i * N_attributes + wg_id] += res;
          }
        });
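// Usage sketch (not part of this header): the functor declared below is
// launched with one work group per attribute, matching the wg_id indexing in
// the kernel above, so the global size is local_size * N_attributes. The
// buffer names (beta_derivative_cl, temp_cl, y_cl, x_cl) and local_size are
// hypothetical caller-side values; temp must provide at least
// local_size * N_attributes * N_classes doubles of scratch space.
//
//   opencl_kernels::categorical_logit_glm_beta_derivative(
//       cl::NDRange(local_size * N_attributes), cl::NDRange(local_size),
//       beta_derivative_cl, temp_cl, y_cl, x_cl, N_instances, N_attributes,
//       N_classes, is_y_vector);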
// Functor used to compile and launch the categorical_logit_glm_beta_derivative
// kernel.
const kernel_cl<in_out_buffer, in_out_buffer, in_buffer, in_buffer, int, int,
                int, int>
    categorical_logit_glm_beta_derivative(
        "categorical_logit_glm_beta_derivative",
        {categorical_logit_glm_beta_derivative_kernel_code});

}  // namespace opencl_kernels
}  // namespace math
}  // namespace stan

#endif
#endif