Automatic Differentiation
 
Loading...
Searching...
No Matches
transpose.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_TRANSPOSE_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_TRANSPOSE_HPP
3#ifdef STAN_OPENCL
4
11#include <algorithm>
12#include <string>
13#include <set>
14#include <tuple>
15#include <type_traits>
16#include <utility>
17
18namespace stan {
19namespace math {
20
31template <typename Arg>
33 : public operation_cl_lhs<
34 transpose_<Arg>, typename std::remove_reference_t<Arg>::Scalar, Arg> {
35 public:
36 using Scalar = typename std::remove_reference_t<Arg>::Scalar;
38 using base::var_name_;
39 using view_transitivity = std::tuple<std::true_type>;
40 using base::operator=;
41
46 explicit transpose_(Arg&& a) : base(std::forward<Arg>(a)) {}
47
52 inline auto deep_copy() const {
53 auto&& arg_copy = this->template get_arg<0>().deep_copy();
54 return transpose_<std::remove_reference_t<decltype(arg_copy)>>{
55 std::move(arg_copy)};
56 }
57
64 inline void modify_argument_indices(std::string& row_index_name,
65 std::string& col_index_name) const {
66 std::swap(row_index_name, col_index_name);
67 }
68
74 inline int rows() const { return this->template get_arg<0>().cols(); }
75
81 inline int cols() const { return this->template get_arg<0>().rows(); }
82
87 inline std::pair<int, int> extreme_diagonals() const {
88 std::pair<int, int> arg_diags
89 = this->template get_arg<0>().extreme_diagonals();
90 return {-arg_diags.second, -arg_diags.first};
91 }
92
107 void set_view(int bottom_diagonal, int top_diagonal, int bottom_zero_diagonal,
108 int top_zero_diagonal) const {
109 this->template get_arg<0>().set_view(
110 top_diagonal, bottom_diagonal, top_zero_diagonal, bottom_zero_diagonal);
111 }
112
121 void check_assign_dimensions(int rows, int cols) const {
122 this->template get_arg<0>().check_assign_dimensions(cols, rows);
123 }
124};
125
137template <typename Arg,
139inline auto transpose(Arg&& a) {
140 auto&& a_operation = as_operation_cl(std::forward<Arg>(a)).deep_copy();
141 return transpose_<std::remove_reference_t<decltype(a_operation)>>{
142 std::move(a_operation)};
143}
145} // namespace math
146} // namespace stan
147
148#endif
149#endif
Base for all kernel generator operations that can be used on left hand side of an expression.
auto deep_copy() const
Creates a deep copy of this expression.
Definition transpose.hpp:52
transpose_(Arg &&a)
Constructor.
Definition transpose.hpp:46
std::pair< int, int > extreme_diagonals() const
Determines the indices of the extreme sub- and superdiagonals written.
Definition transpose.hpp:87
typename std::remove_reference_t< Arg >::Scalar Scalar
Definition transpose.hpp:36
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
Definition transpose.hpp:74
std::tuple< std::true_type > view_transitivity
Definition transpose.hpp:39
void modify_argument_indices(std::string &row_index_name, std::string &col_index_name) const
Swaps indices row_index_name and col_index_name for the argument expression.
Definition transpose.hpp:64
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
Definition transpose.hpp:81
void set_view(int bottom_diagonal, int top_diagonal, int bottom_zero_diagonal, int top_zero_diagonal) const
Sets the view of the underlying matrix depending on which of its parts are written to.
void check_assign_dimensions(int rows, int cols) const
Sets the dimensions of the underlying expressions if possible.
Represents a transpose in kernel generator expressions.
Definition transpose.hpp:34
auto transpose(Arg &&a)
Transposes a kernel generator expression.
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
require_all_t< is_kernel_expression_and_not_scalar< Types >... > require_all_kernel_expressions_and_none_scalar_t
Enables a template if all given types are non-scalar types that are valid kernel generator expressions.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
STL namespace.