Automatic Differentiation
 
Loading...
Searching...
No Matches
ordered_logistic_lpmf.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNELS_ORDERED_LOGISTIC_LPMF_HPP
2#define STAN_MATH_OPENCL_KERNELS_ORDERED_LOGISTIC_LPMF_HPP
3#ifdef STAN_OPENCL
4
9
10namespace stan {
11namespace math {
12namespace opencl_kernels {
13
14// \cond
15static constexpr const char* ordered_logistic_kernel_code = STRINGIFY(
16 // \endcond
 /**
  * OpenCL kernel computing per-instance terms of the ordered logistic log
  * probability mass function and, on request, their derivatives.
  *
  * Each work item whose global id is below N_instances handles one instance:
  * it reads lambda_global[gid] and an outcome y, computes that instance's log
  * probability contribution and (optionally) derivative terms. Work items
  * with gid >= N_instances contribute zero but still take part in the
  * work-group reductions so that every work item reaches every barrier.
  *
  * The reductions use local_storage[LOCAL_SIZE_] indexed by the local id, so
  * the work-group size must not exceed LOCAL_SIZE_.
  * NOTE(review): the reduction loops only visit every element when the
  * work-group size is a power of REDUCTION_STEP_SIZE - confirm the launch
  * configuration guarantees this.
  *
  * @param[out] logp_global one partial sum of log probability contributions
  *   per work group, written at index wg_id. An instance with y outside
  *   [1, N_classes] or a non-finite lambda contributes NAN.
  * @param[out] lambda_derivative per-instance value d1 - d2 written at index
  *   gid; only written when need_lambda_derivative or need_cuts_derivative
  *   enables the derivative computation, need_lambda_derivative is nonzero
  *   and the instance passed the validity check.
  * @param[out] cuts_derivative if is_cuts_matrix is nonzero, N_classes - 1
  *   values per instance starting at (N_classes - 1) * gid; otherwise
  *   N_classes - 1 partial sums per work group starting at
  *   (N_classes - 1) * wg_id, which must be summed on the CPU.
  * @param[in] y_global outcomes, read at index gid * is_y_vector (a single
  *   shared value at index 0 when is_y_vector is 0).
  * @param[in] lambda_global one value per instance, read at index gid.
  * @param[in] cuts cut values; N_classes - 1 per instance when is_cuts_matrix
  *   is nonzero, otherwise one shared set of N_classes - 1 values.
  * @param N_instances number of instances to process.
  * @param N_classes number of outcome classes; valid outcomes are
  *   1..N_classes.
  * @param is_y_vector used as an index multiplier, so expected to be 0 or 1.
  * @param is_cuts_matrix used as an index multiplier and as a flag, so
  *   expected to be 0 or 1.
  * @param need_lambda_derivative nonzero to write lambda_derivative.
  * @param need_cuts_derivative nonzero to write cuts_derivative.
  */
41 __kernel void ordered_logistic(
42 __global double* logp_global, __global double* lambda_derivative,
43 __global double* cuts_derivative, const __global int* y_global,
44 const __global double* lambda_global, const __global double* cuts,
45 const int N_instances, const int N_classes, const int is_y_vector,
46 const int is_cuts_matrix, const int need_lambda_derivative,
47 const int need_cuts_derivative) {
48 const int gid = get_global_id(0);
49 const int lid = get_local_id(0);
50 const int lsize = get_local_size(0);
51 const int wg_id = get_group_id(0);
52 const int ngroups = get_num_groups(0);
53
 // Scratch space shared by the work group for the sum reductions below.
54 __local double local_storage[LOCAL_SIZE_];
55
 // Per-work-item results; they stay 0 for work items without an instance
 // so that those work items do not affect the reductions.
56 double logp = 0;
57 double d1 = 0;
58 double d2 = 0;
 // y is only assigned, and only read, when gid < N_instances.
59 int y;
 // Offset of this instance's cuts; 0 when one set of cuts is shared.
60 int cuts_start = (N_classes - 1) * gid * is_cuts_matrix;
61 // Most calculations only happen for relevant data within next if.
62 // Exceptions are reductions between threads that need barriers.
63 if (gid < N_instances) {
64 double lambda = lambda_global[gid];
65 y = y_global[gid * is_y_vector];
 // Invalid input is reported through a NAN log probability; d1 and d2
 // stay 0 and lambda_derivative[gid] is not written in that case.
66 if (y < 1 || y > N_classes || !isfinite(lambda)) {
67 logp = NAN;
68 } else {
 // cut_y1 is the cut at index y - 1 and cut_y2 the cut at index y - 2;
 // the last class has no cut at y - 1 and the first class none at y - 2,
 // so +/-INFINITY stands in for the missing one.
69 const double cut_y1
70 = y == N_classes ? INFINITY : cuts[cuts_start + y - 1];
71 const double cut_y2 = y == 1 ? -INFINITY : cuts[cuts_start + y - 2];
72 const double cut1 = lambda - cut_y1;
73 const double cut2 = lambda - cut_y2;
74
 // Terms that would involve an infinite cut are skipped.
75 if (y != N_classes) {
76 logp -= log1p_exp(cut1);
77 }
78 if (y != 1) {
79 logp -= log1p_exp(-cut2);
80 }
81 if (y != 1 && y != N_classes) {
82 logp += log1m_exp(cut1 - cut2);
83 }
84
85 if (need_lambda_derivative || need_cuts_derivative) {
 // With an infinite stand-in cut (and finite real cuts) the exponent
 // is -INFINITY, so exp_cuts_diff is 0.
86 double exp_cuts_diff = exp(cut_y2 - cut_y1);
87 d1 = inv_logit(-cut2);
88 d1 -= exp_cuts_diff / (exp_cuts_diff - 1);
89 d2 = 1 / (1 - exp_cuts_diff);
90 d2 -= inv_logit(-cut1);
91
92 if (need_lambda_derivative) {
93 lambda_derivative[gid] = d1 - d2;
94 }
95 }
96 }
97 }
 // need_cuts_derivative and is_cuts_matrix are kernel arguments, so all
 // work items take the same branch and reach the same barriers.
98 if (need_cuts_derivative) {
99 if (is_cuts_matrix) {
 // Per-instance cuts: every instance writes its own row. Only the cuts
 // at index y - 1 (d2) and y - 2 (-d1) get nonzero entries.
100 if (gid < N_instances) {
101 for (int i = 0; i < N_classes - 1; i++) {
102 if (y - 1 == i) {
103 cuts_derivative[cuts_start + i] = d2;
104 } else if (y - 2 == i) {
105 cuts_derivative[cuts_start + i] = -d1;
106 } else {
107 cuts_derivative[cuts_start + i] = 0.0;
108 }
109 }
110 }
111 } else {
 // Shared cuts: for each cut index i, sum the contributions of all
 // work items in the work group.
112 for (int i = 0; i < N_classes - 1; i++) {
113 local_storage[lid] = 0;
114 if (gid < N_instances) {
115 if (y - 1 == i) {
116 local_storage[lid] = d2;
117 } else if (y - 2 == i) {
118 local_storage[lid] = -d1;
119 }
120 }
121 // Sum cuts_derivative, calculated by different threads.
122 // Since we can't sum between different work groups, we emit one
123 // number per work group. These must be summed on CPU for final
124 // result.
125 barrier(CLK_LOCAL_MEM_FENCE);
126 for (int step = lsize / REDUCTION_STEP_SIZE; step > 0;
127 step /= REDUCTION_STEP_SIZE) {
128 if (lid < step) {
 // This i shadows the cut index i of the enclosing loop; the
 // outer i is visible again once this inner loop ends.
129 for (int i = 1; i < REDUCTION_STEP_SIZE; i++) {
130 local_storage[lid] += local_storage[lid + step * i];
131 }
132 }
133 barrier(CLK_LOCAL_MEM_FENCE);
134 }
 // Here i is the cut index again; one value per (work group, cut).
135 if (lid == 0) {
136 cuts_derivative[(N_classes - 1) * wg_id + i] = local_storage[0];
137 }
 // Synchronize before local_storage is overwritten in the next
 // iteration.
138 barrier(CLK_LOCAL_MEM_FENCE);
139 }
140 }
141 }
 // Sum the log probability contributions of the work group with the same
 // reduction scheme; the first work item writes the per-group result.
142 local_storage[lid] = logp;
143 barrier(CLK_LOCAL_MEM_FENCE);
144 for (int step = lsize / REDUCTION_STEP_SIZE; step > 0;
145 step /= REDUCTION_STEP_SIZE) {
146 if (lid < step) {
147 for (int i = 1; i < REDUCTION_STEP_SIZE; i++) {
148 local_storage[lid] += local_storage[lid + step * i];
149 }
150 }
151 barrier(CLK_LOCAL_MEM_FENCE);
152 }
153 if (lid == 0) {
154 logp_global[wg_id] = local_storage[0];
155 }
156 }
157 // \cond
158);
159// \endcond
160
/**
 * Kernel object for the "ordered_logistic" OpenCL kernel.
 *
 * The template arguments mirror the kernel's parameter list: three output
 * buffers (logp_global, lambda_derivative, cuts_derivative), three input
 * buffers (y_global, lambda_global, cuts) and six int arguments
 * (N_instances, N_classes, is_y_vector, is_cuts_matrix,
 * need_lambda_derivative, need_cuts_derivative).
 *
 * The source list places the log1p_exp, log1m_exp and inv_logit device
 * functions before the kernel code that calls them. The options supply the
 * REDUCTION_STEP_SIZE (4) and LOCAL_SIZE_ (64) names used in the kernel
 * code - presumably as compile-time definitions; confirm against kernel_cl.
 */
165const kernel_cl<out_buffer, out_buffer, out_buffer, in_buffer, in_buffer,
166 in_buffer, int, int, int, int, int, int>
167 ordered_logistic("ordered_logistic",
168 {log1p_exp_device_function, log1m_exp_device_function,
169 inv_logit_device_function, ordered_logistic_kernel_code},
170 {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}});
171
172} // namespace opencl_kernels
173} // namespace math
174} // namespace stan
175
176#endif
177#endif
isfinite_< as_operation_cl_t< T > > isfinite(T &&a)
const kernel_cl< out_buffer, out_buffer, out_buffer, in_buffer, in_buffer, in_buffer, int, int, int, int, int, int > ordered_logistic("ordered_logistic", {log1p_exp_device_function, log1m_exp_device_function, inv_logit_device_function, ordered_logistic_kernel_code}, {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}})
See the docs for ordered_logistic() .
double log1m_exp(double a)
Calculates the natural logarithm of one minus the exponential of the specified value without overflow.
Definition log1m_exp.hpp:31
double log1p_exp(double a)
Calculates the log of 1 plus the exponential of the specified value without overflow.
Definition log1p_exp.hpp:28
double inv_logit(double x)
Returns the inverse logit function applied to the kernel generator expression.
Definition inv_logit.hpp:57
T step(const T &y)
The step, or Heaviside, function.
Definition step.hpp:31
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:13
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
#define STRINGIFY(...)
Definition stringify.hpp:9