1#ifndef STAN_MATH_OPENCL_KERNELS_MERGESORT_HPP
2#define STAN_MATH_OPENCL_KERNELS_MERGESORT_HPP
12namespace opencl_kernels {
15static constexpr const char* mergesort_kernel_code =
STRINGIFY(
25 void merge(__global SCAL* A, __global SCAL* B, __global SCAL* res,
26 int A_size,
int B_size) {
29 if (B_i < B_size && A_i < A_size) {
31 if (A[A_i] CMP B[B_i]) {
32 res[A_i + B_i] = A[A_i];
38 res[A_i + B_i] = B[B_i];
46 for (; A_i < A_size; A_i++) {
47 res[A_i + B_i] = A[A_i];
49 for (; B_i < B_size; B_i++) {
50 res[A_i + B_i] = B[B_i];
62 int binary_search(__global SCAL* input,
int start,
int end, SCAL value) {
64 int mid = start + (end - start) / 2;
65 if (value CMP input[mid]) {
83 __kernel
void merge_step(__global SCAL* output, __global SCAL* input,
84 int run_len,
int size,
int tasks) {
85 int gid = get_global_id(0);
86 int n_threads = get_global_size(0);
87 if (tasks >= n_threads) {
89 for (
int task_id = gid; task_id < tasks; task_id += n_threads) {
90 int start_a = (2 * task_id) * run_len;
91 int start_b =
min(
size, (2 * task_id + 1) * run_len);
92 int end =
min(
size, (2 * task_id + 2) * run_len);
93 merge(input + start_a, input + start_b, output + start_a,
94 start_b - start_a, end - start_b);
98 int task_id = gid % tasks;
99 int threads_with_task_id = n_threads / tasks;
101 += n_threads - threads_with_task_id * tasks > task_id;
102 int id_in_task = gid / tasks;
104 int start_a = (2 * task_id) * run_len;
105 int start_b =
min(
size, (2 * task_id + 1) * run_len);
106 int end =
min(
size, (2 * task_id + 2) * run_len);
109 int my_start_a = start_a
110 + (start_b - start_a) * (
long)id_in_task
111 / threads_with_task_id;
112 int my_end_a = start_a
113 + (start_b - start_a) * (
long)(id_in_task + 1)
114 / threads_with_task_id;
115 int my_start_b = id_in_task == 0 ? start_b
119 = id_in_task == threads_with_task_id - 1
124 = start_a + my_start_a - start_a + my_start_b - start_b;
126 merge(input + my_start_a, input + my_start_b, output + output_start,
127 my_end_a - my_start_a, my_end_b - my_start_b);
137template <
typename Scalar,
typename =
void>
147 {
"#define SCAL double\n",
"#define CMP <\n",
148 mergesort_kernel_code});
157 {
"#define SCAL int\n",
"#define CMP <\n",
158 mergesort_kernel_code});
163template <
typename Scalar,
typename =
void>
174 mergesort_kernel_code});
183 {
"#define SCAL int\n",
"#define CMP >\n",
184 mergesort_kernel_code});
__kernel void merge_step(__global SCAL *output, __global SCAL *input, int run_len, int size, int tasks)
Merges sorted runs into longer sorted runs.
int64_t size(const T &m)
Returns the size (number of elements) of a matrix_cl or var_value<matrix_cl<T>>.
int binary_search(__global SCAL *input, int start, int end, SCAL value)
Searches for the index of the element that is larger than or equal to a given value in a given range.
void merge(__global SCAL *A, __global SCAL *B, __global SCAL *res, int A_size, int B_size)
Merges two sorted runs into a single sorted run of combined length.
auto min(T1 x, T2 y)
Returns the minimum of the two specified scalar arguments.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Creates functor for kernels.
static const kernel_cl< out_buffer, in_buffer, int, int, int > merge_step
static const kernel_cl< out_buffer, in_buffer, int, int, int > merge_step
struct containing sort_asc kernels, grouped by scalar type.
static const kernel_cl< out_buffer, in_buffer, int, int, int > merge_step
static const kernel_cl< out_buffer, in_buffer, int, int, int > merge_step
struct containing sort_desc kernels, grouped by scalar type.