Automatic Differentiation
 
Loading...
Searching...
No Matches
diagonal.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_DIAGONAL_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_DIAGONAL_HPP
3#ifdef STAN_OPENCL
4
11#include <algorithm>
12#include <map>
13#include <string>
14#include <tuple>
15#include <type_traits>
16#include <utility>
17
18namespace stan {
19namespace math {
20
/**
 * Represents diagonal of a matrix (as column vector) in kernel generator
 * expressions.
 *
 * Derives from operation_cl_lhs, passing itself as the first template
 * argument, the argument's Scalar type as the second and T as the third.
 *
 * NOTE(review): the numeric prefixes on each line are the listing's own line
 * numbers. They jump from 31 to 33 and from 36 to 38, so the class-head line
 * (presumably `class diagonal_`) and the declaration of the `base` alias used
 * below are missing from this extract - restore them from the original header.
 *
 * @tparam T type of the argument expression wrapped by this operation
 */
31template <typename T>
33 : public operation_cl_lhs<diagonal_<T>,
34 typename std::remove_reference_t<T>::Scalar, T> {
35 public:
  /** Element type, taken from the Scalar member type of the argument. */
36 using Scalar = typename std::remove_reference_t<T>::Scalar;
  // Re-export members of the base class (alias `base` is not visible here).
38 using base::var_name_;
39 using base::operator=;
40
  /**
   * Constructor. Forwards the argument to the base class.
   * @param a argument expression, perfectly forwarded to `base`
   */
45 explicit diagonal_(T&& a) : base(std::forward<T>(a)) {}
46
  /**
   * Creates a deep copy of this expression.
   * Calls deep_copy() on argument 0 and wraps the result in a new diagonal_
   * whose template argument is the (non-reference) type of that copy.
   * @return copy of \c *this built on the copied argument
   */
51 inline auto deep_copy() const {
52 auto&& arg_copy = this->template get_arg<0>().deep_copy();
53 return diagonal_<std::remove_reference_t<decltype(arg_copy)>>{
54 std::move(arg_copy)};
55 }
56
  /**
   * Sets col_index_name to value of row_index_name.
   * After the call both names are identical; presumably this makes the
   * argument be addressed at (i, i), i.e. on its diagonal - confirm against
   * the base class's use of these names.
   * @param[in] row_index_name row index name; read only, left unchanged
   * @param[out] col_index_name column index name; overwritten
   */
63 inline void modify_argument_indices(std::string& row_index_name,
64 std::string& col_index_name) const {
65 col_index_name = row_index_name;
66 }
67
  /**
   * Number of rows of a matrix that would be the result of evaluating this
   * expression.
   * @return the smaller of the argument's rows() and cols()
   */
73 inline int rows() const {
74 return std::min(this->template get_arg<0>().rows(),
75 this->template get_arg<0>().cols());
76 }
77
  /**
   * Number of columns of a matrix that would be the result of evaluating this
   * expression.
   * @return always 1
   */
83 inline int cols() const { return 1; }
84
  /**
   * Sets the view of the underlying matrix depending on which of its parts are
   * written to.
   * All four parameters are accepted but unused: the body is empty, so this
   * call does nothing.
   * @param bottom_diagonal unused
   * @param top_diagonal unused
   * @param bottom_zero_diagonal unused
   * @param top_zero_diagonal unused
   */
100 inline void set_view(int bottom_diagonal, int top_diagonal,
101 int bottom_zero_diagonal, int top_zero_diagonal) const {}
102
  /**
   * Determine indices of extreme sub- and superdiagonals written.
   * @return the pair {1 - rows(), 1}
   *   (NOTE(review): presumably ordered bottom, top - confirm)
   */
107 inline std::pair<int, int> extreme_diagonals() const {
108 return {1 - rows(), 1};
109 }
110
  /**
   * Checks if desired dimensions match dimensions of the block.
   * Delegates to check_size_match twice: once for the row count and once for
   * the column count. NOTE(review): behavior on mismatch is defined by
   * check_size_match, which is not visible here (presumably it throws).
   * @param rows desired number of rows, compared against this->rows()
   * @param cols desired number of columns, compared against 1
   */
118 inline void check_assign_dimensions(int rows, int cols) const {
      // `rows` the parameter shadows the member function, hence this->rows().
119 check_size_match("diagonal_.check_assign_dimensions", "Rows of ",
120 "diagonal", this->rows(), "rows of ", "expression", rows);
121 check_size_match("diagonal_.check_assign_dimensions", "Columns of ",
122 "diagonal", 1, "columns of ", "expression", cols);
123 }
124};
125
/**
 * Diagonal of a kernel generator expression.
 *
 * Converts the argument to an operation with as_operation_cl, takes a deep
 * copy of it and wraps that copy in a diagonal_ object.
 *
 * NOTE(review): the numeric prefixes are the listing's own line numbers. They
 * jump from 134 to 136, so the remainder of the template parameter list is
 * missing from this extract (presumably a constraint on T such as
 * require_all_kernel_expressions_and_none_scalar_t - confirm against the
 * original header).
 *
 * @tparam T type of the argument
 * @param a argument; perfectly forwarded to as_operation_cl
 * @return diagonal_ object holding the deep-copied operation
 */
134template <typename T,
136inline auto diagonal(T&& a) {
137 auto&& a_operation = as_operation_cl(std::forward<T>(a)).deep_copy();
138 return diagonal_<std::remove_reference_t<decltype(a_operation)>>(
139 std::move(a_operation));
140}
142} // namespace math
143} // namespace stan
144
145#endif
146#endif
auto deep_copy() const
Creates a deep copy of this expression.
Definition diagonal.hpp:51
std::pair< int, int > extreme_diagonals() const
Determine indices of extreme sub- and superdiagonals written.
Definition diagonal.hpp:107
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
Definition diagonal.hpp:83
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
Definition diagonal.hpp:73
typename std::remove_reference_t< T >::Scalar Scalar
Definition diagonal.hpp:36
void modify_argument_indices(std::string &row_index_name, std::string &col_index_name) const
Sets col_index_name to value of row_index_name.
Definition diagonal.hpp:63
void check_assign_dimensions(int rows, int cols) const
Checks if desired dimensions match dimensions of the block.
Definition diagonal.hpp:118
void set_view(int bottom_diagonal, int top_diagonal, int bottom_zero_diagonal, int top_zero_diagonal) const
Sets the view of the underlying matrix depending on which of its parts are written to.
Definition diagonal.hpp:100
diagonal_(T &&a)
Constructor.
Definition diagonal.hpp:45
Represents diagonal of a matrix (as column vector) in kernel generator expressions.
Definition diagonal.hpp:34
Base for all kernel generator operations that can be used on left hand side of an expression.
auto diagonal(T &&a)
Diagonal of a kernel generator expression.
Definition diagonal.hpp:136
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
require_all_t< is_kernel_expression_and_not_scalar< Types >... > require_all_kernel_expressions_and_none_scalar_t
Enables a template if all given types are non-scalar types that are valid kernel generator expressions.
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
STL namespace.