1#ifndef STAN_MATH_OPENCL_PRIM_MULTIPLY_HPP
2#define STAN_MATH_OPENCL_PRIM_MULTIPLY_HPP
36template <
typename T1,
typename T2,
37 typename = require_all_kernel_expressions_and_none_scalar_t<T1, T2>>
41 if (A.size() == 0 || B.size() == 0) {
42 return constant(0.0, A.rows(), B.cols());
45 either(A.view(), B.view()));
51 cl::NDRange(temp.
cols() * local_size), cl::NDRange(local_size),
52 A.eval(), B.eval(), temp, B.
rows(), B.cols(), A.view(), B.view());
53 }
catch (cl::Error&
e) {
63 const int Mpad = ((A.rows() + local - 1) / local) * local;
64 const int Npad = ((B.cols() + local - 1) / local) * local;
66 const int wgs = Mpad / local * Npad / local;
67 const int split = std::min(
71 .getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>())
77 cl::NDRange(local, local / wpt), A.eval(),
78 B.eval(), temp, A.
rows(), B.cols(),
79 B.rows(), A.view(), B.view());
83 cl::NDRange(local, local / wpt, 1),
84 A.eval(), B.eval(), tempSplit, A.
rows(),
85 B.cols(), B.rows(), A.view(), B.view());
87 tempSplit, A.
rows(), B.cols(), split);
89 }
catch (cl::Error&
e) {
104template <
typename T_a,
typename T_b,
120template <
typename T_a,
typename T_b, require_stan_scalar_t<T_a>* =
nullptr,
121 require_all_kernel_expressions_and_none_scalar_t<T_b>* =
nullptr,
122 require_all_not_var_t<T_a, T_b>* =
nullptr>
135template <
typename T_a,
typename T_b, require_stan_scalar_t<T_b>* =
nullptr,
136 require_all_kernel_expressions_and_none_scalar_t<T_a>* =
nullptr,
137 require_all_not_var_t<T_a, T_b>* =
nullptr>
138inline matrix_cl<return_type_t<T_a, T_b>>
multiply(
const T_a& a,
const T_b& b) {
Represents an arithmetic matrix on the OpenCL device.
The API to access the methods and values in opencl_context_base.
void check_opencl_error(const char *function, const cl::Error &e)
Throws a domain error specifying the OpenCL error that occurred.
std::vector< cl::Device > & device() noexcept
Returns a vector containing the OpenCL device used to create the context.
opencl_context_base::tuning_struct & tuning_opts() noexcept
Returns a reference to the struct of OpenCL kernel tuning options.
auto constant(const T a, int rows, int cols)
Matrix of repeated values in kernel generator expressions.
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
require_all_t< is_kernel_expression_and_not_scalar< Types >... > require_all_kernel_expressions_and_none_scalar_t
Enables a template if all given types are non-scalar types that are a valid kernel generator expressi...
const kernel_cl< in_buffer, in_buffer, out_buffer, int, int, matrix_cl_view, matrix_cl_view > row_vector_matrix_multiply("row_vector_matrix_multiply", {view_kernel_helpers, row_vector_matrix_multiply_kernel_code}, {{"LOCAL_SIZE_", 64}, {"REDUCTION_STEP_SIZE", 4}})
See the docs for row_vector_matrix_multiply() .
const kernel_cl< out_buffer, in_buffer, int, int, int > add_batch("add_batch", {indexing_helpers, add_batch_kernel_code})
See the docs for add_batch() .
const kernel_cl< in_buffer, in_buffer, out_buffer, int, int, int, matrix_cl_view, matrix_cl_view > matrix_multiply("matrix_multiply", {thread_block_helpers, view_kernel_helpers, matrix_multiply_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}})
See the docs for matrix_multiply() .
const matrix_cl_view either(const matrix_cl_view left_view, const matrix_cl_view right_view)
Determines which parts are nonzero in any of the input views.
fvar< T > operator*(const fvar< T > &x, const fvar< T > &y)
Return the product of the two arguments.
static constexpr double e()
Return the base of the natural logarithm.
T eval(T &&arg)
Inputs whose plain_type is equal to their own type are forwarded unmodified (for Eigen expressions ...
auto multiply(const Mat1 &m1, const Mat2 &m2)
Return the product of the specified matrices.
auto matrix_vector_multiply(T_matrix &&matrix, T_vector &&vector)
Multiplies a matrix and a vector on an OpenCL device.
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
int multiply_wgs_per_compute_unit