Automatic Differentiation
 
Loading...
Searching...
No Matches
to_matrix.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_PRIM_TO_MATRIX_HPP
2#define STAN_MATH_OPENCL_PRIM_TO_MATRIX_HPP
3#ifdef STAN_OPENCL
4
7
8namespace stan {
9namespace math {
10
19template <typename T_x,
20 require_nonscalar_prim_or_rev_kernel_expression_t<T_x>* = nullptr>
21inline T_x to_matrix(T_x&& x) {
22 return std::forward<T_x>(x);
23}
24
38template <typename T_x,
40inline matrix_cl<return_type_t<T_x>> to_matrix(const T_x& x, int m, int n) {
41 using res_scal = return_type_t<T_x>;
42 check_size_match("to_matrix", "rows * columns", "", m * n, "input size", "",
43 x.size());
44 matrix_cl<res_scal> res(m, n);
45 matrix_cl<res_scal> tmp(res.buffer(), x.rows(), x.cols());
46 tmp = x;
47 for (cl::Event e : tmp.write_events()) {
48 res.add_write_event(e);
49 }
50 return res;
51}
52
68template <typename T_x,
70inline auto to_matrix(const T_x& x, int m, int n, bool col_major)
71 -> decltype(to_matrix(x, m, n)) {
72 if (col_major) {
73 return to_matrix(x, m, n);
74 } else {
75 return transpose(to_matrix(transpose(x), n, m));
76 }
77}
78
79} // namespace math
80} // namespace stan
81#endif
82#endif
const cl::Buffer & buffer() const
void add_write_event(cl::Event new_event) const
Add an event to the write event stack.
const tbb::concurrent_vector< cl::Event > & write_events() const
Get the events from the event stacks.
Represents an arithmetic matrix on the OpenCL device.
Definition matrix_cl.hpp:47
require_t< is_nonscalar_prim_or_rev_kernel_expression< std::decay_t< T > > > require_nonscalar_prim_or_rev_kernel_expression_t
Requires that the type satisfies is_nonscalar_prim_or_rev_kernel_expression.
auto transpose(Arg &&a)
Transposes a kernel generator expression.
require_all_t< is_kernel_expression_and_not_scalar< Types >... > require_all_kernel_expressions_and_none_scalar_t
Enables a template if all given types are non-scalar types that are valid kernel generator expressions.
T_x to_matrix(T_x &&x)
Returns the input matrix unchanged.
Definition to_matrix.hpp:21
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
static constexpr double e()
Return the base of the natural logarithm.
Definition constants.hpp:20
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9