Automatic Differentiation
 
Loading...
Searching...
No Matches
mergesort.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNELS_MERGESORT_HPP
2#define STAN_MATH_OPENCL_KERNELS_MERGESORT_HPP
3#ifdef STAN_OPENCL
4
8#include <string>
9
10namespace stan {
11namespace math {
12namespace opencl_kernels {
13
14// \cond
15static constexpr const char* mergesort_kernel_code = STRINGIFY(
16 // \endcond
/**
 * Merges two sorted runs into a single sorted run of combined length.
 *
 * Elements are taken from A while `A[i] CMP B[j]` holds and from B
 * otherwise, so when neither element compares before the other the one
 * from B is written first.
 *
 * @param A first sorted run
 * @param B second sorted run
 * @param res destination; receives A_size + B_size elements
 * @param A_size number of elements in A
 * @param B_size number of elements in B
 */
void merge(__global SCAL* A, __global SCAL* B, __global SCAL* res,
           int A_size, int B_size) {
  int a_pos = 0;
  int b_pos = 0;
  // Interleave for as long as both runs still have unread elements.
  while (a_pos < A_size && b_pos < B_size) {
    if (A[a_pos] CMP B[b_pos]) {
      res[a_pos + b_pos] = A[a_pos];
      ++a_pos;
    } else {
      res[a_pos + b_pos] = B[b_pos];
      ++b_pos;
    }
  }
  // At most one of the two runs has elements left; append them unchanged.
  while (a_pos < A_size) {
    res[a_pos + b_pos] = A[a_pos];
    ++a_pos;
  }
  while (b_pos < B_size) {
    res[a_pos + b_pos] = B[b_pos];
    ++b_pos;
  }
}
53
/**
 * Bisects the half-open range [start, end) of `input` for a split point.
 *
 * Returns the smallest index in the range at which `value CMP input[index]`
 * holds, or `end` when it holds for no element. With `CMP` defined as `<`
 * that is the first element strictly larger than `value`.
 * NOTE(review): the bisection presumes the range is ordered consistently
 * with CMP - confirm against callers.
 *
 * @param input buffer to search in
 * @param start first index of the searched range (inclusive)
 * @param end last index of the searched range (exclusive)
 * @param value value to locate
 * @return index of the split point
 */
int binary_search(__global SCAL* input, int start, int end, SCAL value) {
  int lo = start;
  int hi = end;
  while (lo < hi) {
    // Midpoint written as an offset from lo so that lo + hi is never formed.
    int probe = lo + (hi - lo) / 2;
    if (value CMP input[probe]) {
      hi = probe;
    } else {
      lo = probe + 1;
    }
  }
  return lo;
}
73
/**
 * Merges sorted runs into longer sorted runs.
 *
 * Task `t` merges the run starting at `2 * t * run_len` with the run that
 * follows it; both run ends are clamped to `size`. When there are at least
 * as many tasks as threads, each thread processes tasks `gid`,
 * `gid + n_threads`, and so on. Otherwise the threads are spread over the
 * tasks and every thread merges only a slice of its task: the first run is
 * cut into equal pieces and the matching pieces of the second run are
 * located with `binary_search`.
 *
 * @param output buffer the merged runs are written to
 * @param input buffer holding the runs to merge
 * @param run_len length of one input run
 * @param size number of elements in `input`
 * @param tasks number of run pairs to merge
 */
__kernel void merge_step(__global SCAL* output, __global SCAL* input,
                         int run_len, int size, int tasks) {
  // Nothing to merge; also keeps the `% tasks` and `/ tasks` below defined.
  if (tasks <= 0) {
    return;
  }
  int gid = get_global_id(0);
  int n_threads = get_global_size(0);
  if (tasks >= n_threads) {
    // divide tasks between threads
    for (int task_id = gid; task_id < tasks; task_id += n_threads) {
      // NOTE(review): start_a is assumed to lie inside the input (and so to
      // fit in int) for every valid task_id - confirm against the caller.
      int start_a = (2 * task_id) * run_len;
      // The unclamped run ends can exceed INT_MAX for large inputs, so they
      // are computed in 64 bits; the clamped results are at most `size`.
      int start_b
          = min((long)size, (2 * (long)task_id + 1) * run_len);  // NOLINT
      int end
          = min((long)size, (2 * (long)task_id + 2) * run_len);  // NOLINT
      merge(input + start_a, input + start_b, output + start_a,
            start_b - start_a, end - start_b);
    }
  } else {
    // divide threads between tasks
    int task_id = gid % tasks;
    int threads_with_task_id = n_threads / tasks;
    // The first (n_threads % tasks) tasks each get one extra thread.
    threads_with_task_id
        += (n_threads - threads_with_task_id * tasks) > task_id;
    int id_in_task = gid / tasks;

    int start_a = (2 * task_id) * run_len;
    // Same 64-bit clamping as in the branch above.
    int start_b
        = min((long)size, (2 * (long)task_id + 1) * run_len);  // NOLINT
    int end
        = min((long)size, (2 * (long)task_id + 2) * run_len);  // NOLINT

    // divide a task between threads working on it
    int my_start_a = start_a
                     + (start_b - start_a) * (long)id_in_task  // NOLINT
                           / threads_with_task_id;
    int my_end_a = start_a
                   + (start_b - start_a) * (long)(id_in_task + 1)  // NOLINT
                         / threads_with_task_id;
    // The first and last slices take the outer ends of the second run
    // directly; inner boundaries are found by searching the second run for
    // the boundary element of the first run.
    int my_start_b = id_in_task == 0 ? start_b
                                     : binary_search(input, start_b, end,
                                                     input[my_start_a]);
    int my_end_b
        = id_in_task == threads_with_task_id - 1
              ? end
              : binary_search(input, start_b, end, input[my_end_a]);

    // Output offset = elements of the first run before this slice plus
    // elements of the second run before this slice. Written without the
    // intermediate `start_a + my_start_a` sum, which could exceed INT_MAX.
    int output_start = my_start_a + (my_start_b - start_b);

    merge(input + my_start_a, input + my_start_b, output + output_start,
          my_end_a - my_start_a, my_end_b - my_start_b);
  }
}
130 // \cond
131);
132// \endcond
133
137template <typename Scalar, typename = void>
138struct sort_asc {};
139
140template <typename T>
141struct sort_asc<double, T> {
143};
144template <typename T>
147 {"#define SCAL double\n", "#define CMP <\n",
148 mergesort_kernel_code});
149
150template <typename T>
151struct sort_asc<int, T> {
153};
154template <typename T>
156 sort_asc<int, T>::merge_step("merge_step",
157 {"#define SCAL int\n", "#define CMP <\n",
158 mergesort_kernel_code});
159
163template <typename Scalar, typename = void>
164struct sort_desc {};
165
166template <typename T>
167struct sort_desc<double, T> {
169};
170template <typename T>
172 sort_desc<double, T>::merge_step("merge_step", {"#define SCAL double\n",
173 "#define CMP >\n",
174 mergesort_kernel_code});
175
176template <typename T>
177struct sort_desc<int, T> {
179};
180template <typename T>
183 {"#define SCAL int\n", "#define CMP >\n",
184 mergesort_kernel_code});
185
186} // namespace opencl_kernels
187} // namespace math
188} // namespace stan
189
190#endif
191#endif
__kernel void merge_step(__global SCAL *output, __global SCAL *input, int run_len, int size, int tasks)
Merges sorted runs into longer sorted runs.
Definition mergesort.hpp:83
int64_t size(const T &m)
Returns the size (number of the elements) of a matrix_cl or var_value<matrix_cl<T>>.
Definition size.hpp:19
int binary_search(__global SCAL *input, int start, int end, SCAL value)
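Searches the given range for the index of the first element that is ordered strictly after the given value under CMP (the first element larger than the value when sorting in ascending order); returns the end of the range if there is no such element.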
Definition mergesort.hpp:62
void merge(__global SCAL *A, __global SCAL *B, __global SCAL *res, int A_size, int B_size)
Merges two sorted runs into a single sorted run of combined length.
Definition mergesort.hpp:25
auto min(T1 x, T2 y)
Returns the minimum coefficient of the two specified scalar arguments.
Definition min.hpp:24
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
#define STRINGIFY(...)
Definition stringify.hpp:9
Creates functor for kernels.
static const kernel_cl< out_buffer, in_buffer, int, int, int > merge_step
static const kernel_cl< out_buffer, in_buffer, int, int, int > merge_step
struct containing sort_asc kernels, grouped by scalar type.
static const kernel_cl< out_buffer, in_buffer, int, int, int > merge_step
static const kernel_cl< out_buffer, in_buffer, int, int, int > merge_step
struct containing sort_desc kernels, grouped by scalar type.