Automatic Differentiation
 
Loading...
Searching...
No Matches
tri_inverse.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_TRI_INVERSE_HPP
2#define STAN_MATH_OPENCL_TRI_INVERSE_HPP
3
4#ifdef STAN_OPENCL
5
17#include <cmath>
18#include <string>
19#include <vector>
20
21namespace stan {
22namespace math {
38template <matrix_cl_view matrix_view = matrix_cl_view::Entire, typename T,
39 require_matrix_cl_st<std::is_floating_point, T>* = nullptr>
40inline plain_type_t<T> tri_inverse(const T& A) {
41 check_square("tri_inverse (OpenCL)", "A", A);
42 // if the triangular view is not specified use the triangularity of
43 // the input matrix
44 matrix_cl_view tri_view = matrix_view;
45 if (matrix_view == matrix_cl_view::Entire) {
46 if (A.view() != matrix_cl_view::Diagonal) {
47 check_triangular("tri_inverse (OpenCL)", "A", A);
48 }
49 tri_view = A.view();
50 }
51 if (tri_view == matrix_cl_view::Diagonal) {
52 plain_type_t<T> inv_mat(A.rows(), A.cols());
53 diagonal(inv_mat) = elt_divide(1.0, diagonal(A));
54 return inv_mat;
55 }
56
57 int thread_block_2D_dim = 32;
58 int max_1D_thread_block_size = opencl_context.max_thread_block_size();
59 // we split the input matrix to 32 blocks
60 int thread_block_size_1D
61 = (((A.rows() / 32) + thread_block_2D_dim - 1) / thread_block_2D_dim)
62 * thread_block_2D_dim;
63 if (max_1D_thread_block_size < thread_block_size_1D) {
64 thread_block_size_1D = max_1D_thread_block_size;
65 }
66 int max_2D_thread_block_dim = std::sqrt(max_1D_thread_block_size);
67 if (max_2D_thread_block_dim < thread_block_2D_dim) {
68 thread_block_2D_dim = max_2D_thread_block_dim;
69 }
70 // for small size split in max 2 parts
71 if (thread_block_size_1D < 64) {
72 thread_block_size_1D = 32;
73 }
74 if (A.rows() < thread_block_size_1D) {
75 thread_block_size_1D = A.rows();
76 }
77
78 // pad the input matrix
79 int A_rows_padded
80 = ((A.rows() + thread_block_size_1D - 1) / thread_block_size_1D)
81 * thread_block_size_1D;
82
83 plain_type_t<T> temp(A_rows_padded, A_rows_padded);
84 plain_type_t<T> inv_padded = constant(0.0, A_rows_padded, A_rows_padded);
85 plain_type_t<T> inv_mat(A);
86 plain_type_t<T> zero_mat
87 = constant(0.0, A_rows_padded - A.rows(), A_rows_padded);
88 if (tri_view == matrix_cl_view::Upper) {
89 inv_mat = transpose(inv_mat).eval();
90 }
91 int work_per_thread
92 = opencl_kernels::inv_lower_tri_multiply.get_option("WORK_PER_THREAD");
93 // the number of blocks in the first step
94 // each block is inverted with using the regular forward substitution
95 int parts = inv_padded.rows() / thread_block_size_1D;
96 block_zero_based(inv_padded, 0, 0, inv_mat.rows(), inv_mat.rows()) = inv_mat;
97 try {
98 // create a batch of identity matrices to be used in the first step
100 cl::NDRange(parts, thread_block_size_1D, thread_block_size_1D), temp,
101 thread_block_size_1D, temp.size());
102 // spawn parts thread blocks, each responsible for one block
103 opencl_kernels::diag_inv(cl::NDRange(parts * thread_block_size_1D),
104 cl::NDRange(thread_block_size_1D), inv_padded,
105 temp, inv_padded.rows());
106 } catch (cl::Error& e) {
107 check_opencl_error("inverse step1", e);
108 }
109 // set the padded part of the matrix and the upper triangular to zeros
110 block_zero_based(inv_padded, inv_mat.rows(), 0, zero_mat.rows(),
111 zero_mat.cols())
112 = zero_mat;
113 inv_padded.template zeros_strict_tri<stan::math::matrix_cl_view::Upper>();
114 if (parts == 1) {
115 inv_mat
116 = block_zero_based(inv_padded, 0, 0, inv_mat.rows(), inv_mat.rows());
117 if (tri_view == matrix_cl_view::Upper) {
118 inv_mat = transpose(inv_mat).eval();
119 }
120 return inv_mat;
121 }
122 using std::ceil;
123 parts = ceil(parts / 2.0);
124
125 auto result_matrix_dim = thread_block_size_1D;
126 auto thread_block_work2d_dim = thread_block_2D_dim / work_per_thread;
127 auto ndrange_2d
128 = cl::NDRange(thread_block_2D_dim, thread_block_work2d_dim, 1);
129 while (parts > 0) {
130 int result_matrix_dim_x = result_matrix_dim;
131 // when calculating the last submatrix
132 // we can reduce the size to the actual size (not the next power of 2)
133 if (parts == 1 && (inv_padded.rows() - result_matrix_dim * 2) < 0) {
134 result_matrix_dim_x = inv_padded.rows() - result_matrix_dim;
135 }
136 auto result_work_dim = result_matrix_dim / work_per_thread;
137 auto result_ndrange
138 = cl::NDRange(result_matrix_dim_x, result_work_dim, parts);
139 opencl_kernels::inv_lower_tri_multiply(result_ndrange, ndrange_2d,
140 inv_padded, temp, inv_padded.rows(),
141 result_matrix_dim);
143 result_ndrange, ndrange_2d, inv_padded, temp, inv_padded.rows(),
144 result_matrix_dim);
145 // if this is the last submatrix, end
146 if (parts == 1) {
147 parts = 0;
148 } else {
149 parts = ceil(parts / 2.0);
150 }
151 result_matrix_dim *= 2;
152 // set the padded part and upper diagonal to zeros
153 block_zero_based(inv_padded, inv_mat.rows(), 0, zero_mat.rows(),
154 zero_mat.cols())
155 = zero_mat;
156 inv_padded.template zeros_strict_tri<stan::math::matrix_cl_view::Upper>();
157 }
158 // un-pad and return
159 inv_mat = block_zero_based(inv_padded, 0, 0, inv_mat.rows(), inv_mat.rows());
160 if (tri_view == matrix_cl_view::Upper) {
161 inv_mat = transpose(inv_mat).eval();
162 }
163 inv_mat.view(tri_view);
164 return inv_mat;
165}
166} // namespace math
167} // namespace stan
168
169#endif
170#endif
The API to access the methods and values in opencl_context_base.
void check_triangular(const char *function, const char *name, const T &A)
Check if the matrix_cl is either upper triangular or lower triangular.
void check_opencl_error(const char *function, const cl::Error &e)
Throws the domain error with specifying the OpenCL error that occurred.
int max_thread_block_size() noexcept
Returns the maximum thread block size defined by CL_DEVICE_MAX_WORK_GROUP_SIZE for the device in the ...
auto block_zero_based(T &&a, int start_row, int start_col, int rows, int cols)
Block of a kernel generator expression.
auto transpose(Arg &&a)
Transposes a kernel generator expression.
elt_divide_< as_operation_cl_t< T_a >, as_operation_cl_t< T_b > > elt_divide(T_a &&a, T_b &&b)
auto constant(const T a, int rows, int cols)
Matrix of repeated values in kernel generator expressions.
Definition constant.hpp:130
auto diagonal(T &&a)
Diagonal of a kernel generator expression.
Definition diagonal.hpp:136
const kernel_cl< in_out_buffer, in_out_buffer, int > diag_inv("diag_inv", {indexing_helpers, diag_inv_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}})
See the docs for diag_inv().
const kernel_cl< out_buffer, int, int > batch_identity("batch_identity", {indexing_helpers, batch_identity_kernel_code})
See the docs for batch_identity().
const kernel_cl< in_buffer, out_buffer, int, int > inv_lower_tri_multiply("inv_lower_tri_multiply", {thread_block_helpers, inv_lower_tri_multiply_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}})
See the docs for inv_lower_tri_multiply().
const kernel_cl< in_out_buffer, in_buffer, int, int > neg_rect_lower_tri_multiply("neg_rect_lower_tri_multiply", {thread_block_helpers, neg_rect_lower_tri_multiply_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}})
See the docs for neg_rect_lower_tri_multiply().
plain_type_t< T > tri_inverse(const T &A)
Computes the inverse of a triangular matrix.
void check_square(const char *function, const char *name, const T_y &y)
Check if the specified matrix is square.
static constexpr double e()
Return the base of the natural logarithm.
Definition constants.hpp:20
fvar< T > ceil(const fvar< T > &x)
Definition ceil.hpp:12
typename plain_type< T >::type plain_type_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...