1#ifndef STAN_MATH_OPENCL_KERNELS_CATEGORICAL_LOGIT_GLM_LPMF_HPP 
    2#define STAN_MATH_OPENCL_KERNELS_CATEGORICAL_LOGIT_GLM_LPMF_HPP 
   10namespace opencl_kernels {
 
   // OpenCL C source for the categorical logit GLM kernel, passed through the
   // STRINGIFY macro to initialise a const char*.
   // NOTE(review): this text looks like an extraction of a rendered source
   // listing: an integer is fused to the front of most lines (13, 43, 44, ...),
   // statements are split across lines, and several listing lines are absent
   // (the start of the kernel signature, closing braces, the declarations of
   // class_idx and logp). Confirm every note below against the upstream header.
   13static constexpr const char* categorical_logit_glm_kernel_code = 
STRINGIFY(
 
   // Kernel parameters, described by how the visible lines use them:
   //   logp_global            - written at [wg_id] with the reduced logp.
   //   exp_lin_global         - written at [i * N_instances + gid] with
   //                            exp(lin - lin_max) for class i.
   //   inv_sum_exp_lin_global - written at [gid] with 1 / sum_exp_lin.
   //   neg_softmax_lin_global - written at [i * N_instances + gid] only when
   //                            need_neg_softmax_lin_global is non-zero.
   //   alpha_derivative       - written at [i + wg_id * N_classes] only when
   //                            need_alpha_derivative is non-zero.
   //   y_global               - read at [gid * is_y_vector]; 1 is subtracted
   //                            to obtain class_idx.
   //   x_beta_global          - read at [class * N_instances + gid].
   //   alpha_global           - read at [class].
   //   N_attributes           - not referenced in the visible lines.
   43        __global 
double* logp_global, __global 
double* exp_lin_global,
 
   44        __global 
double* inv_sum_exp_lin_global,
 
   45        __global 
double* neg_softmax_lin_global,
 
   46        __global 
double* alpha_derivative, 
const __global 
int* y_global,
 
   47        const __global 
double* x_beta_global,
 
   48        const __global 
double* alpha_global, 
const int N_instances,
 
   49        const int N_attributes, 
const int N_classes, 
const int is_y_vector,
 
   50        const int need_alpha_derivative,
 
   51        const int need_neg_softmax_lin_global) {
 
   52      const int gid = get_global_id(0);
 
   53      const int lid = get_local_id(0);
 
   54      const int lsize = get_local_size(0);
 
   55      const int wg_id = get_group_id(0);
 
   56      const int ngroups = get_num_groups(0);
 
   // Scratch array used by both reductions below; it is indexed by lid, so
   // lsize presumably must not exceed LOCAL_SIZE_ - confirm in the launch code.
   58      __local 
double local_storage[LOCAL_SIZE_];
 
   61      double inv_sum_exp_lin;
 
   // Per-instance work is done only by work items with gid < N_instances.
   65      if (gid < N_instances) {
 
   // First pass over the classes. lin_max starts at -INFINITY and is later
   // subtracted from lin, so this loop presumably tracks the largest lin; the
   // update statement itself is not in the visible lines.
   66        double lin_max = -INFINITY;
 
   67        for (
int i = 0; i < N_classes; i++) {
 
   68          double lin = x_beta_global[i * N_instances + gid] + alpha_global[i];
 
   // NOTE(review): `alpha` is not referenced by any visible line, and it
   // indexes alpha_global by gid while every other read indexes it by class.
   // This looks like a dead and possibly out-of-range read - confirm.
   73        double alpha = alpha_global[gid];
 
   74        double sum_exp_lin = 0;
 
   // Second pass: store exp(lin - lin_max) for each class and accumulate the
   // sum. Subtracting lin_max first is presumably to keep exp() from
   // overflowing.
   75        for (
int i = 0; i < N_classes; i++) {
 
   76          double lin = x_beta_global[i * N_instances + gid] + alpha_global[i];
 
   77          double exp_lin = 
exp(lin - lin_max);
 
   78          sum_exp_lin += exp_lin;
 
   79          exp_lin_global[i * N_instances + gid] = exp_lin;
 
   81        inv_sum_exp_lin = 1 / sum_exp_lin;
 
   82        inv_sum_exp_lin_global[gid] = inv_sum_exp_lin;
 
   // y_global is indexed by gid * is_y_vector, so is_y_vector == 0 makes every
   // work item read element 0.
   84        class_idx = y_global[gid * is_y_vector] - 1;
 
   // NOTE(review): class_idx == N_classes passes this check, yet class_idx is
   // used below to index alpha_global and x_beta_global, which the loops above
   // only index with i < N_classes. `>=` looks intended - confirm. The body of
   // this branch is not in the visible lines.
   85        if (class_idx < 0 || class_idx > N_classes) {
 
   // logp = log(1 / sum_exp_lin) - lin_max + (x_beta + alpha) of class_idx.
   88          logp = 
log(inv_sum_exp_lin) - lin_max
 
   89                 + x_beta_global[class_idx * N_instances + gid]
 
   90                 + alpha_global[class_idx];
 
   93      barrier(CLK_GLOBAL_MEM_FENCE);
 
   // NOTE(review): neg_softmax_lin_sum is not referenced in the visible lines.
   94      double neg_softmax_lin_sum = 0;
 
   95      if (need_alpha_derivative || need_neg_softmax_lin_global) {
 
   96        for (
int i = 0; i < N_classes; i++) {
 
   // Stays 0 for work items with gid >= N_instances.
   97          double neg_softmax_lin = 0;
 
   98          if (gid < N_instances) {
 
   99            int idx = i * N_instances + gid;
 
   // Stored exp(lin - lin_max) times 1 / sum_exp_lin, negated.
   100            neg_softmax_lin = -exp_lin_global[idx] * inv_sum_exp_lin;
 
   101            if (need_neg_softmax_lin_global) {
 
   102              neg_softmax_lin_global[idx] = neg_softmax_lin;
 
   105          if (need_alpha_derivative) {
 
   // This work item's term for class i: neg_softmax_lin, plus 1 when
   // class_idx equals i.
   106            local_storage[lid] = neg_softmax_lin + (class_idx == i);
 
   107            barrier(CLK_LOCAL_MEM_FENCE);
 
   // Reduction over local_storage: each pass adds the slots at lid + step,
   // lid + 2 * step, ... (REDUCTION_STEP_SIZE - 1 of them) into slot lid, then
   // divides step by REDUCTION_STEP_SIZE. The conditions selecting which work
   // items take part are not in the visible lines.
   108            for (
int step = lsize / REDUCTION_STEP_SIZE; 
step > 0;
 
   109                 step /= REDUCTION_STEP_SIZE) {
 
   // NOTE(review): this `int i` has the same name as the class index `i` of
   // the enclosing loop and hides it inside this inner loop.
   111                for (
int i = 1; i < REDUCTION_STEP_SIZE; i++) {
 
   112                  local_storage[lid] += local_storage[lid + 
step * i];
 
   115              barrier(CLK_LOCAL_MEM_FENCE);
 
   // Slot 0 is stored as this work group's entry for class i (presumably the
   // class index of the outer loop, once the inner loop has ended).
   118              alpha_derivative[i + wg_id * N_classes] = local_storage[0];
 
   120            barrier(CLK_LOCAL_MEM_FENCE);
 
   // Same reduction pattern applied to logp; slot 0 is stored per work group.
   127      local_storage[lid] = logp;
 
   128      barrier(CLK_LOCAL_MEM_FENCE);
 
   129      for (
int step = lsize / REDUCTION_STEP_SIZE; 
step > 0;
 
   130           step /= REDUCTION_STEP_SIZE) {
 
   132          for (
int i = 1; i < REDUCTION_STEP_SIZE; i++) {
 
   133            local_storage[lid] += local_storage[lid + 
step * i];
 
   136        barrier(CLK_LOCAL_MEM_FENCE);
 
   139        logp_global[wg_id] = local_storage[0];
 
  // kernel_cl object built from categorical_logit_glm_kernel_code.
  // The template list has five out_buffer, three in_buffer and six int
  // entries, matching the kernel's five non-const __global pointers, three
  // const __global pointers and six const int parameters, in that order.
  // The option pairs set REDUCTION_STEP_SIZE to 4 and LOCAL_SIZE_ to 64; both
  // names are used as constants inside the kernel source.
  // NOTE(review): the listing line carrying the variable name and the kernel
  // name string is not present here - confirm against the upstream header.
  150const kernel_cl<out_buffer, out_buffer, out_buffer, out_buffer, out_buffer,
 
  151                in_buffer, in_buffer, in_buffer, int, int, int, int, int, 
int>
 
  153                          {categorical_logit_glm_kernel_code},
 
  154                          {{
"REDUCTION_STEP_SIZE", 4}, {
"LOCAL_SIZE_", 64}});
 
  // OpenCL C source for the beta-derivative kernel of the categorical logit
  // GLM, held in a std::string.
  // NOTE(review): this text looks like an extraction of a rendered source
  // listing: integers are fused to the front of lines and several listing
  // lines are absent (the initialiser and kernel signature start, closing
  // braces, the declaration of `res`). Confirm against the upstream header.
  157static const std::string categorical_logit_glm_beta_derivative_kernel_code
 
  // Kernel parameters, described by how the visible lines use them:
  //   beta_derivative - updated with += at [i * N_attributes + wg_id].
  //   temp            - scratch; each work item owns the N_classes slots
  //                     starting at gid * N_classes.
  //   y               - read at [i * is_y_vector]; 1 is subtracted to get pos.
  //   x               - read at [wg_id * N_instances + i].
  175            __global 
double* beta_derivative, __global 
double* temp,
 
  176            const __global 
int* y, 
const __global 
double* x,
 
  177            const int N_instances, 
const int N_attributes, 
const int N_classes,
 
  178            const int is_y_vector) {
 
  179          const int gid = get_global_id(0);
 
  180          const int lid = get_local_id(0);
 
  181          const int lsize = get_local_size(0);
 
  182          const int wg_id = get_group_id(0);
 
  // Zero this work item's N_classes slots in temp.
  184          for (
int i = 0; i < N_classes; i++) {
 
  185            temp[gid * N_classes + i] = 0;
 
  // Each work item visits instances lid, lid + lsize, ... and adds the x value
  // at [wg_id * N_instances + i] into the temp slot selected by pos.
  // NOTE(review): no range check on pos is visible before it is used as an
  // index - confirm that y is validated by the caller.
  187          for (
int i = lid; i < N_instances; i += lsize) {
 
  188            int pos = y[i * is_y_vector] - 1;
 
  189            temp[gid * N_classes + pos] += x[wg_id * N_instances + i];
 
  191          barrier(CLK_GLOBAL_MEM_FENCE);
 
  // Each work item handles classes lid, lid + lsize, ... and sums the temp
  // slots at (wg_id * lsize + j) for j < lsize, i.e. presumably those of every
  // work item in this work group (assuming no global offset) - confirm.
  192          for (
int i = lid; i < N_classes; i += lsize) {
 
  194            for (
int j = 0; j < lsize; j++) {
 
  195              res += temp[(wg_id * lsize + j) * N_classes + i];
 
  // Accumulated with +=, so the previous content of beta_derivative is kept;
  // presumably the caller initialises it - confirm.
  197            beta_derivative[i * N_attributes + wg_id] += res;
 
  // kernel_cl object for the kernel named
  // "categorical_logit_glm_beta_derivative", built from
  // categorical_logit_glm_beta_derivative_kernel_code. The first two template
  // arguments are in_out_buffer; the kernel source both reads and writes its
  // first two pointer parameters (beta_derivative via +=, temp via = and +=).
  // NOTE(review): the rest of the template argument list and the variable name
  // are not present in this listing - confirm against the upstream header.
  208const kernel_cl<in_out_buffer, in_out_buffer, in_buffer, in_buffer, int, int,
 
  211        "categorical_logit_glm_beta_derivative",
 
  212        {categorical_logit_glm_beta_derivative_kernel_code});
 
const kernel_cl< out_buffer, out_buffer, out_buffer, out_buffer, out_buffer, in_buffer, in_buffer, in_buffer, int, int, int, int, int, int > categorical_logit_glm("categorical_logit_glm", {categorical_logit_glm_kernel_code}, {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}})
See the docs for categorical_logit_glm().
 
const kernel_cl< in_out_buffer, in_out_buffer, in_buffer, in_buffer, int, int, int, int > categorical_logit_glm_beta_derivative("categorical_logit_glm_beta_derivative", {categorical_logit_glm_beta_derivative_kernel_code})
See the docs for categorical_logit_glm_beta_derivative().
 
T step(const T &y)
The step, or Heaviside, function.
 
fvar< T > log(const fvar< T > &x)
 
fvar< T > exp(const fvar< T > &x)
 
The lgamma implementation in stan-math is based on either the reentrant-safe lgamma_r implementation ...