Automatic Differentiation
 
Loading...
Searching...
No Matches
ordered_logistic_glm_lpmf.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNELS_ORDERED_LOGISTIC_GLM_LPMF_HPP
2#define STAN_MATH_OPENCL_KERNELS_ORDERED_LOGISTIC_GLM_LPMF_HPP
3#ifdef STAN_OPENCL
4
9
10namespace stan {
11namespace math {
12namespace opencl_kernels {
13
14// \cond
15static constexpr const char* ordered_logistic_glm_kernel_code = STRINGIFY(
16 // \endcond
43 __kernel void ordered_logistic_glm(
44 __global double* location_sum, __global double* logp_global,
45 __global double* location_derivative, __global double* cuts_derivative,
46 const __global int* y_global, const __global double* x,
47 const __global double* beta, const __global double* cuts,
48 const int N_instances, const int N_attributes, const int N_classes,
49 const int is_y_vector, const int need_location_derivative,
50 const int need_cuts_derivative) {
51 const int gid = get_global_id(0);
52 const int lid = get_local_id(0);
53 const int lsize = get_local_size(0);
54 const int wg_id = get_group_id(0);
55 const int ngroups = get_num_groups(0);
56
57 __local double local_storage[LOCAL_SIZE_];
58
59 double logp = 0;
60 double d1 = 0;
61 double d2 = 0;
62 double location = 0;
63 int y;
64 // Most calculations only happen for relevant data within next if.
65 // Exceptions are reductions between threads that need barriers.
66 if (gid < N_instances) {
67 for (int i = 0, j = 0; i < N_attributes; i++, j += N_instances) {
68 location += x[j + gid] * beta[i];
69 }
70 y = y_global[gid * is_y_vector];
71 if (y < 1 || y > N_classes) {
72 location = NAN;
73 } else {
74 const double cut_y1 = y == N_classes ? INFINITY : cuts[y - 1];
75 const double cut_y2 = y == 1 ? -INFINITY : cuts[y - 2];
76 const double cut1 = location - cut_y1;
77 const double cut2 = location - cut_y2;
78
79 if (y != N_classes) {
80 logp -= log1p_exp(cut1);
81 }
82 if (y != 1) {
83 logp -= log1p_exp(-cut2);
84 }
85 if (y != 1 && y != N_classes) {
86 logp += log1m_exp(cut1 - cut2);
87 }
88
89 if (need_location_derivative || need_cuts_derivative) {
90 double exp_cuts_diff = exp(cut_y2 - cut_y1);
91 d1 = inv_logit(-cut2);
92 d1 -= exp_cuts_diff / (exp_cuts_diff - 1);
93 d2 = 1 / (1 - exp_cuts_diff);
94 d2 -= inv_logit(-cut1);
95
96 if (need_location_derivative) {
97 location_derivative[gid] = d1 - d2;
98 }
99 }
100 }
101 }
102 if (need_cuts_derivative) {
103 for (int i = 0; i < N_classes - 1; i++) {
104 local_storage[lid] = 0;
105 if (gid < N_instances) {
106 if (y - 1 == i) {
107 local_storage[lid] = d2;
108 } else if (y - 2 == i) {
109 local_storage[lid] = -d1;
110 }
111 }
112 // Sum cuts_derivative, calculated by different threads.
113 // Since we can't sum between different work groups, we emit one
114 // number per work group. These must be summed on CPU for final
115 // result.
116 barrier(CLK_LOCAL_MEM_FENCE);
117 for (int step = lsize / REDUCTION_STEP_SIZE; step > 0;
118 step /= REDUCTION_STEP_SIZE) {
119 if (lid < step) {
120 for (int i = 1; i < REDUCTION_STEP_SIZE; i++) {
121 local_storage[lid] += local_storage[lid + step * i];
122 }
123 }
124 barrier(CLK_LOCAL_MEM_FENCE);
125 }
126 if (lid == 0) {
127 cuts_derivative[(N_classes - 1) * wg_id + i] = local_storage[0];
128 }
129 barrier(CLK_LOCAL_MEM_FENCE);
130 }
131 }
132 local_storage[lid] = logp;
133 barrier(CLK_LOCAL_MEM_FENCE);
134 for (int step = lsize / REDUCTION_STEP_SIZE; step > 0;
135 step /= REDUCTION_STEP_SIZE) {
136 if (lid < step) {
137 for (int i = 1; i < REDUCTION_STEP_SIZE; i++) {
138 local_storage[lid] += local_storage[lid + step * i];
139 }
140 }
141 barrier(CLK_LOCAL_MEM_FENCE);
142 }
143 if (lid == 0) {
144 logp_global[wg_id] = local_storage[0];
145 }
146
147 barrier(CLK_LOCAL_MEM_FENCE);
148 local_storage[lid] = location;
149 barrier(CLK_LOCAL_MEM_FENCE);
150 for (int step = lsize / REDUCTION_STEP_SIZE; step > 0;
151 step /= REDUCTION_STEP_SIZE) {
152 if (lid < step) {
153 for (int i = 1; i < REDUCTION_STEP_SIZE; i++) {
154 local_storage[lid] += local_storage[lid + step * i];
155 }
156 }
157 barrier(CLK_LOCAL_MEM_FENCE);
158 }
159 if (lid == 0) {
160 location_sum[wg_id] = local_storage[0];
161 }
162 }
163 // \cond
164);
165// \endcond
166
171const kernel_cl<out_buffer, out_buffer, out_buffer, out_buffer, in_buffer,
172 in_buffer, in_buffer, in_buffer, int, int, int, int, int, int>
173 ordered_logistic_glm("ordered_logistic_glm",
174 {log1p_exp_device_function, log1m_exp_device_function,
175 inv_logit_device_function,
176 ordered_logistic_glm_kernel_code},
177 {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}});
178
179} // namespace opencl_kernels
180} // namespace math
181} // namespace stan
182
183#endif
184#endif
const kernel_cl< out_buffer, out_buffer, out_buffer, out_buffer, in_buffer, in_buffer, in_buffer, in_buffer, int, int, int, int, int, int > ordered_logistic_glm("ordered_logistic_glm", {log1p_exp_device_function, log1m_exp_device_function, inv_logit_device_function, ordered_logistic_glm_kernel_code}, {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}})
See the docs for ordered_logistic_glm().
double log1m_exp(double a)
Calculates the natural logarithm of one minus the exponential of the specified value without overflow...
Definition log1m_exp.hpp:31
double log1p_exp(double a)
Calculates the log of 1 plus the exponential of the specified value without overflow.
Definition log1p_exp.hpp:28
double beta(double a, double b)
Return the beta function applied to the specified arguments.
Definition beta.hpp:25
double inv_logit(double x)
Returns the inverse logit function applied to the specified value.
Definition inv_logit.hpp:57
T step(const T &y)
The step, or Heaviside, function.
Definition step.hpp:31
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:15
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
#define STRINGIFY(...)
Definition stringify.hpp:9