Automatic Differentiation
 
Loading...
Searching...
No Matches
broadcast.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_BROADCAST_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_BROADCAST_HPP
3#ifdef STAN_OPENCL
4
11#include <limits>
12#include <string>
13#include <type_traits>
14#include <set>
15#include <utility>
16
17namespace stan {
18namespace math {
19
30template <typename T, bool Colwise, bool Rowwise>
32 : public operation_cl<broadcast_<T, Colwise, Rowwise>,
33 typename std::remove_reference_t<T>::Scalar, T> {
34 public:
35 using Scalar = typename std::remove_reference_t<T>::Scalar;
37 using base::var_name_;
38
43 explicit broadcast_(T&& a) : base(std::forward<T>(a)) {
44 const char* function = "broadcast";
45 if (Colwise) {
46 check_size_match(function, "Rows of ", "a", a.rows(), "", "", 1);
47 }
48 if (Rowwise) {
49 check_size_match(function, "Columns of ", "a", a.cols(), "", "", 1);
50 }
51 }
52
57 inline auto deep_copy() const {
58 auto&& arg_copy = this->template get_arg<0>().deep_copy();
59 return broadcast_<std::remove_reference_t<decltype(arg_copy)>, Colwise,
60 Rowwise>{std::move(arg_copy)};
61 }
62
68 inline void modify_argument_indices(std::string& row_index_name,
69 std::string& col_index_name) const {
70 if (Colwise) {
71 row_index_name = "0";
72 }
73 if (Rowwise) {
74 col_index_name = "0";
75 }
76 }
77
83 inline int rows() const {
84 return Colwise ? base::dynamic : this->template get_arg<0>().rows();
85 }
86
92 inline int cols() const {
93 return Rowwise ? base::dynamic : this->template get_arg<0>().cols();
94 }
95
100 inline std::pair<int, int> extreme_diagonals() const {
101 std::pair<int, int> arg_diags
102 = this->template get_arg<0>().extreme_diagonals();
103 return {Colwise ? std::numeric_limits<int>::min() : arg_diags.first,
104 Rowwise ? std::numeric_limits<int>::max() : arg_diags.second};
105 }
106};
107
123template <bool Colwise, bool Rowwise, typename T,
125inline auto broadcast(T&& a) {
126 auto&& a_operation = as_operation_cl(std::forward<T>(a)).deep_copy();
127 return broadcast_<std::remove_reference_t<decltype(a_operation)>, Colwise,
128 Rowwise>(std::move(a_operation));
129}
130
143template <typename T,
145inline auto rowwise_broadcast(T&& a) {
146 return broadcast<false, true>(std::forward<T>(a));
147}
148
161template <typename T,
163inline auto colwise_broadcast(T&& a) {
164 return broadcast<true, false>(std::forward<T>(a));
165}
167} // namespace math
168} // namespace stan
169#endif
170#endif
auto deep_copy() const
Creates a deep copy of this expression.
Definition broadcast.hpp:57
std::pair< int, int > extreme_diagonals() const
Determine indices of extreme sub- and superdiagonals written.
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
Definition broadcast.hpp:92
broadcast_(T &&a)
Constructor.
Definition broadcast.hpp:43
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
Definition broadcast.hpp:83
typename std::remove_reference_t< T >::Scalar Scalar
Definition broadcast.hpp:35
void modify_argument_indices(std::string &row_index_name, std::string &col_index_name) const
Sets the index/indices along the broadcast dimension(s) to 0.
Definition broadcast.hpp:68
Represents a broadcasting operation in kernel generator expressions.
Definition broadcast.hpp:33
static constexpr int dynamic
Base for all kernel generator operations.
auto rowwise_broadcast(T &&a)
Broadcasts an expression in the rowwise dimension.
auto colwise_broadcast(T &&a)
Broadcasts an expression in the colwise dimension.
auto broadcast(T &&a)
Broadcast an expression in specified dimension(s).
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
require_all_t< is_kernel_expression_and_not_scalar< Types >... > require_all_kernel_expressions_and_none_scalar_t
Enables a template if all given types are non-scalar types that are valid kernel generator expressions.
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
STL namespace.