Automatic Differentiation
 
Loading...
Searching...
No Matches
diag_inv.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNELS_DIAGONAL_INVERSE_LOWER_TRI_HPP
2#define STAN_MATH_OPENCL_KERNELS_DIAGONAL_INVERSE_LOWER_TRI_HPP
3#ifdef STAN_OPENCL
4
7#include <string>
8
9namespace stan {
10namespace math {
11namespace opencl_kernels {
12// \cond
13static constexpr const char* diag_inv_kernel_code = STRINGIFY(
14 // \endcond
 // OpenCL kernel: in-place inversion of the diagonal blocks of A by forward
 // substitution. Each work-group handles the block_size x block_size block
 // of A that starts at row and column group * block_size, where block_size
 // is the local work size.
 //
 // A:       matrix whose diagonal blocks are read during the substitution
 //          and overwritten with the result at the end.
 // tmp_inv: scratch buffer with block_size * block_size entries per group.
 //          According to the comments below it is expected to hold one
 //          identity matrix per block on entry - confirm against the caller.
 // rows:    not referenced by name in this body; presumably consumed by the
 //          A(i, j) indexing macro supplied via indexing_helpers - confirm.
43 __kernel void diag_inv(__global double* A, __global double* tmp_inv,
44 int rows) {
45 int index = get_local_id(0);
46 int group = get_group_id(0);
47 int block_size = get_local_size(0);
 // First row and column of the diagonal block of A owned by this group.
48 int A_offset = group * block_size;
49 // Offset inside the buffer with batched identities. Each work-item owns
 // the block_size consecutive entries of tmp_inv that start here.
50 int tmp_offset = group * block_size * block_size + index * block_size;
51
52 // The following code is the sequential version of forward
53 // substitution with the identity matrix as RHS. Only the innermost loops
54 // are parallelized. The rows are processed sequentially. This loop
55 // processes all the rows:
56 for (int k = 0; k < block_size; k++) {
57 double diag_ele = A(A_offset + k, A_offset + k);
58
59 // Each element under the diagonal of the RHS is divided by diag_ele.
60 // Each thread in a thread block does 1 division.
61 // Threads that are assigned elements above the diagonal
62 // (index > k) skip this division.
63 if (index <= k) {
64 tmp_inv[tmp_offset + k] /= diag_ele;
65 }
 // Every work-item of the group reaches this point before any continues.
66 barrier(CLK_LOCAL_MEM_FENCE);
67 // Each thread updates one column in the RHS matrix. Starting at
68 // max(k + 1, index) skips the entries above the diagonal (i < index).
69 for (int i = max(k + 1, index); i < block_size; i++) { // NOLINT
70 double factor = A(A_offset + i, A_offset + k);
71 tmp_inv[tmp_offset + i] -= tmp_inv[tmp_offset + k] * factor;
72 }
 // Group-wide synchronization again before the next row k is processed.
73 barrier(CLK_LOCAL_MEM_FENCE);
74 }
 // Write the result back over the diagonal block of A.
75 for (int j = 0; j < block_size; j++) {
76 // Each thread copies one column (its own block_size entries of tmp_inv).
77 A(A_offset + j, A_offset + index) = tmp_inv[tmp_offset + j];
78 }
79 }
80 // \cond
81);
82// \endcond
82// \endcond
83
89 "diag_inv", {indexing_helpers, diag_inv_kernel_code},
90 {{"THREAD_BLOCK_SIZE", 32}});
91
92} // namespace opencl_kernels
93} // namespace math
94} // namespace stan
95#endif
96#endif
const kernel_cl< in_out_buffer, in_out_buffer, int > diag_inv("diag_inv", {indexing_helpers, diag_inv_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}})
See the docs for diag_inv().
int64_t rows(const T_x &x)
Returns the number of rows in the specified kernel generator expression.
Definition rows.hpp:22
static const std::string indexing_helpers
Defines helper macros for common matrix indexing operations.
Definition helpers.hpp:14
auto max(T1 x, T2 y)
Returns the maximum value of the two specified scalar arguments.
Definition max.hpp:25
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
#define STRINGIFY(...)
Definition stringify.hpp:9
Creates a functor for kernels.