Automatic Differentiation
 
Loading...
Searching...
No Matches
multiply.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_PRIM_MULTIPLY_HPP
2#define STAN_MATH_OPENCL_PRIM_MULTIPLY_HPP
3#ifdef STAN_OPENCL
4
13#include <algorithm>
14
15namespace stan {
16namespace math {
17
36template <typename T1, typename T2,
37 typename = require_all_kernel_expressions_and_none_scalar_t<T1, T2>>
38inline matrix_cl<return_type_t<T1, T2>> multiply(const T1& A, const T2& B) {
39 check_size_match("multiply ((OpenCL))", "A.cols()", A.cols(), "B.rows()",
40 B.rows());
41 if (A.size() == 0 || B.size() == 0) {
42 return constant(0.0, A.rows(), B.cols());
43 }
44 matrix_cl<return_type_t<T1, T2>> temp(A.rows(), B.cols(),
45 either(A.view(), B.view()));
46 if (A.rows() == 1) {
47 const int local_size
48 = opencl_kernels::row_vector_matrix_multiply.get_option("LOCAL_SIZE_");
49 try {
51 cl::NDRange(temp.cols() * local_size), cl::NDRange(local_size),
52 A.eval(), B.eval(), temp, B.rows(), B.cols(), A.view(), B.view());
53 } catch (cl::Error& e) {
54 check_opencl_error("row_vector - matrix multiply", e);
55 }
56 return temp;
57 }
58 if (B.cols() == 1) {
59 temp = matrix_vector_multiply(A, B);
60 return temp;
61 }
62 int local = opencl_kernels::matrix_multiply.get_option("THREAD_BLOCK_SIZE");
63 const int Mpad = ((A.rows() + local - 1) / local) * local;
64 const int Npad = ((B.cols() + local - 1) / local) * local;
65 const int wpt = opencl_kernels::matrix_multiply.get_option("WORK_PER_THREAD");
66 const int wgs = Mpad / local * Npad / local;
67 const int split = std::min(
68 A.cols() / local,
70 * static_cast<int>(opencl_context.device()[0]
71 .getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>())
72 + wgs - 1)
73 / wgs);
74 try {
75 if (split <= 1) {
76 opencl_kernels::matrix_multiply(cl::NDRange(Mpad, Npad / wpt),
77 cl::NDRange(local, local / wpt), A.eval(),
78 B.eval(), temp, A.rows(), B.cols(),
79 B.rows(), A.view(), B.view());
80 } else {
81 matrix_cl<return_type_t<T1, T2>> tempSplit(A.rows(), B.cols() * split);
82 opencl_kernels::matrix_multiply(cl::NDRange(Mpad, Npad / wpt, split),
83 cl::NDRange(local, local / wpt, 1),
84 A.eval(), B.eval(), tempSplit, A.rows(),
85 B.cols(), B.rows(), A.view(), B.view());
86 opencl_kernels::add_batch(cl::NDRange(A.rows(), B.cols()), temp,
87 tempSplit, A.rows(), B.cols(), split);
88 }
89 } catch (cl::Error& e) {
90 check_opencl_error("multiply", e);
91 }
92 return temp;
93}
94
104template <typename T_a, typename T_b,
107 const T_b& b) {
108 // no need for perfect forwarding as operations are evaluated
110}
111
120template <typename T_a, typename T_b, require_stan_scalar_t<T_a>* = nullptr,
121 require_all_kernel_expressions_and_none_scalar_t<T_b>* = nullptr,
122 require_all_not_var_t<T_a, T_b>* = nullptr>
123inline matrix_cl<return_type_t<T_a, T_b>> multiply(const T_a& a, const T_b& b) {
124 return a * b;
125}
126
135template <typename T_a, typename T_b, require_stan_scalar_t<T_b>* = nullptr,
136 require_all_kernel_expressions_and_none_scalar_t<T_a>* = nullptr,
137 require_all_not_var_t<T_a, T_b>* = nullptr>
138inline matrix_cl<return_type_t<T_a, T_b>> multiply(const T_a& a, const T_b& b) {
139 return a * b;
140}
141
142} // namespace math
143} // namespace stan
144#endif
145#endif
Represents an arithmetic matrix on the OpenCL device.
Definition matrix_cl.hpp:47
The API to access the methods and values in opencl_context_base.
void check_opencl_error(const char *function, const cl::Error &e)
Throws a domain error specifying the OpenCL error that occurred.
std::vector< cl::Device > & device() noexcept
Returns a vector containing the OpenCL device used to create the context.
opencl_context_base::tuning_struct & tuning_opts() noexcept
Returns the tuning options struct, which includes the thread block size for the Cholesky decomposition's L_11.
auto constant(const T a, int rows, int cols)
Matrix of repeated values in kernel generator expressions.
Definition constant.hpp:130
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
require_all_t< is_kernel_expression_and_not_scalar< Types >... > require_all_kernel_expressions_and_none_scalar_t
Enables a template if all given types are non-scalar types that are a valid kernel generator expressi...
const kernel_cl< in_buffer, in_buffer, out_buffer, int, int, matrix_cl_view, matrix_cl_view > row_vector_matrix_multiply("row_vector_matrix_multiply", {view_kernel_helpers, row_vector_matrix_multiply_kernel_code}, {{"LOCAL_SIZE_", 64}, {"REDUCTION_STEP_SIZE", 4}})
See the docs for row_vector_matrix_multiply() .
const kernel_cl< out_buffer, in_buffer, int, int, int > add_batch("add_batch", {indexing_helpers, add_batch_kernel_code})
See the docs for add_batch() .
const kernel_cl< in_buffer, in_buffer, out_buffer, int, int, int, matrix_cl_view, matrix_cl_view > matrix_multiply("matrix_multiply", {thread_block_helpers, view_kernel_helpers, matrix_multiply_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}})
See the docs for matrix_multiply() .
const matrix_cl_view either(const matrix_cl_view left_view, const matrix_cl_view right_view)
Determines which parts are nonzero in any of the input views.
fvar< T > operator*(const fvar< T > &x, const fvar< T > &y)
Return the product of the two arguments.
static constexpr double e()
Return the base of the natural logarithm.
Definition constants.hpp:20
T eval(T &&arg)
Inputs whose plain_type is equal to their own type are forwarded unmodified (for Eigen expressions ...
Definition eval.hpp:20
auto multiply(const Mat1 &m1, const Mat2 &m2)
Return the product of the specified matrices.
Definition multiply.hpp:18
auto matrix_vector_multiply(T_matrix &&matrix, T_vector &&vector)
Multiplies a matrix and a vector on an OpenCL device.
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9