Automatic Differentiation
 
Loading...
Searching...
No Matches
holder.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_FUN_HOLDER_HPP
2#define STAN_MATH_PRIM_FUN_HOLDER_HPP
3
7#include <memory>
8#include <type_traits>
9#include <utility>
10
64// This was implemented following the tutorial on adding new expressions to
65// Eigen: https://eigen.tuxfamily.org/dox/TopicNewExpressionType.html
66
namespace stan {
namespace math {

// Forward declaration: the Eigen traits specialization that follows has to
// name Holder before the class itself can be defined (the class derives from
// an Eigen base that needs those traits).
template <typename ArgType, typename... Ptrs>
class Holder;

}  // namespace math
}  // namespace stan
75
76namespace Eigen {
77namespace internal {
78
79template <class ArgType, typename... Ptrs>
80struct traits<stan::math::Holder<ArgType, Ptrs...>> {
81 typedef typename ArgType::StorageKind StorageKind;
82 typedef typename traits<ArgType>::XprKind XprKind;
83 typedef typename ArgType::StorageIndex StorageIndex;
84 typedef typename ArgType::Scalar Scalar;
85 enum {
86 // Possible flags are documented here:
87 // https://eigen.tuxfamily.org/dox/group__flags.html
88 Flags = (ArgType::Flags
89 & (RowMajorBit | LvalueBit | LinearAccessBit | DirectAccessBit
90 | PacketAccessBit | NoPreferredStorageOrderBit))
91 | NestByRefBit,
92 RowsAtCompileTime = ArgType::RowsAtCompileTime,
93 ColsAtCompileTime = ArgType::ColsAtCompileTime,
94 MaxRowsAtCompileTime = ArgType::MaxRowsAtCompileTime,
95 MaxColsAtCompileTime = ArgType::MaxColsAtCompileTime,
96 InnerStrideAtCompileTime = ArgType::InnerStrideAtCompileTime,
97 OuterStrideAtCompileTime = ArgType::OuterStrideAtCompileTime
98 };
99};
100
101} // namespace internal
102} // namespace Eigen
103
104namespace stan {
105namespace math {
106
/**
 * Eigen expression that wraps another expression (`m_arg`) while owning a set
 * of heap objects through `std::unique_ptr` (`m_unique_ptrs`). The owned
 * objects are deleted only when the holder itself is destroyed, which is what
 * allows the wrapped expression to refer to them. Size, stride and data
 * queries are forwarded to `m_arg`; coefficient and packet access is
 * forwarded by the `Eigen::internal::evaluator` specialization for this class.
 *
 * @tparam ArgType type of the wrapped expression
 * @tparam Ptrs types of the owned heap objects
 *
 * NOTE(review): this listing omits the class-head line (`class Holder`), the
 * name declared by the `ref_selector<Holder ...>::type` typedef (the
 * cross-reference calls it `Nested`), and the two defaulted copy/move
 * constructors that the in-class comment refers to. Confirm against the real
 * header before editing.
 */
114template <typename ArgType, typename... Ptrs>
116 : public Eigen::internal::dense_xpr_base<Holder<ArgType, Ptrs...>>::type {
117 public:
 // Members:
 // - `Nested`: how Eigen nests a Holder inside enclosing expressions.
 // - `m_arg`: the wrapped expression; every query below is forwarded to it.
 // - `m_unique_ptrs`: heap objects owned by this holder, deleted with it.
118 typedef typename Eigen::internal::ref_selector<Holder<ArgType, Ptrs...>>::type
120 typename Eigen::internal::ref_selector<ArgType>::non_const_type m_arg;
121 std::tuple<std::unique_ptr<Ptrs>...> m_unique_ptrs;
122
 // Wraps `arg` and takes ownership of `pointers` (each is wrapped in a
 // std::unique_ptr). `arg` is a named parameter and therefore an lvalue, so
 // `m_arg(arg)` copies it, or binds to it when `m_arg` has reference type.
 // NOTE(review): for an ArgType that Eigen nests by reference (plain objects),
 // ref_selector's non_const_type is presumably a reference, so `m_arg` would
 // refer to the constructor argument itself -- confirm no caller passes a
 // temporary plain object here.
123 explicit Holder(ArgType&& arg, Ptrs*... pointers)
124 : m_arg(arg), m_unique_ptrs(std::unique_ptr<Ptrs>(pointers)...) {}
125
126 // we need to explicitly default copy and move constructors as we are
127 // defining copy and move assignment operators
130
131 // all these functions just call the same on the argument
132 Eigen::Index rows() const { return m_arg.rows(); }
133 Eigen::Index cols() const { return m_arg.cols(); }
134 Eigen::Index innerStride() const { return m_arg.innerStride(); }
135 Eigen::Index outerStride() const { return m_arg.outerStride(); }
136 auto* data() { return m_arg.data(); }
137
 // Assigns the values of another Eigen expression to the wrapped expression.
 // Owned heap objects are not touched.
143 template <typename T, require_eigen_t<T>* = nullptr>
144 inline Holder<ArgType, Ptrs...>& operator=(const T& other) {
145 m_arg = other;
146 return *this;
147 }
148
149 // copy and move assignment operators need to be separately overloaded,
150 // otherwise defaults will be used.
 // Copy assignment: assigns the values of the expression `other` (the whole
 // holder, not `other.m_arg`) to `m_arg`. Ownership of `other`'s heap objects
 // is neither copied nor shared.
151 inline Holder<ArgType, Ptrs...>& operator=(
152 const Holder<ArgType, Ptrs...>& other) {
153 m_arg = other;
154 return *this;
155 }
 // Move assignment: move-assigns `other.m_arg` into `m_arg`. `m_unique_ptrs`
 // is left untouched, so the heap objects stay owned by `other`.
 // NOTE(review): when `m_arg` is a reference to a plain object, this moves
 // the contents out of the object `other` refers to -- confirm intended.
156 inline Holder<ArgType, Ptrs...>& operator=(Holder<ArgType, Ptrs...>&& other) {
157 m_arg = std::move(other.m_arg);
158 return *this;
159 }
160};
161
162} // namespace math
163} // namespace stan
164
165namespace Eigen {
166namespace internal {
167
168template <typename ArgType, typename... Ptrs>
169struct evaluator<stan::math::Holder<ArgType, Ptrs...>>
170 : evaluator_base<stan::math::Holder<ArgType, Ptrs...>> {
171 typedef stan::math::Holder<ArgType, Ptrs...> XprType;
172 typedef typename remove_all<ArgType>::type ArgTypeNestedCleaned;
173 typedef typename XprType::CoeffReturnType CoeffReturnType;
174 typedef typename XprType::Scalar Scalar;
175 enum {
176 CoeffReadCost = evaluator<ArgTypeNestedCleaned>::CoeffReadCost,
177 // Possible flags are documented here:
178 // https://eigen.tuxfamily.org/dox/group__flags.html
179 Flags = evaluator<ArgTypeNestedCleaned>::Flags,
180 Alignment = evaluator<ArgTypeNestedCleaned>::Alignment,
181 };
182
183 evaluator<ArgTypeNestedCleaned> m_argImpl;
184
185 explicit evaluator(const XprType& xpr) : m_argImpl(xpr.m_arg) {}
186
187 // all these functions just call the same on the argument
188 EIGEN_STRONG_INLINE CoeffReturnType coeff(Index row, Index col) const {
189 return m_argImpl.coeff(row, col);
190 }
191 EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
192 return m_argImpl.coeff(index);
193 }
194
195 EIGEN_STRONG_INLINE Scalar& coeffRef(Index row, Index col) {
196 return m_argImpl.coeffRef(row, col);
197 }
198 EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
199 return m_argImpl.coeffRef(index);
200 }
201
202 template <int LoadMode, typename PacketType>
203 EIGEN_STRONG_INLINE PacketType packet(Index row, Index col) const {
204 return m_argImpl.template packet<LoadMode, PacketType>(row, col);
205 }
206 template <int LoadMode, typename PacketType>
207 EIGEN_STRONG_INLINE PacketType packet(Index index) const {
208 return m_argImpl.template packet<LoadMode, PacketType>(index);
209 }
210
211 template <int StoreMode, typename PacketType>
212 EIGEN_STRONG_INLINE void writePacket(Index row, Index col,
213 const PacketType& x) {
214 return m_argImpl.template writePacket<StoreMode, PacketType>(row, col, x);
215 }
216 template <int StoreMode, typename PacketType>
217 EIGEN_STRONG_INLINE void writePacket(Index index, const PacketType& x) {
218 return m_argImpl.template writePacket<StoreMode, PacketType>(index, x);
219 }
220};
221
222} // namespace internal
223} // namespace Eigen
224
225namespace stan {
226namespace math {
227
238template <typename T, typename... Ptrs,
239 std::enable_if_t<sizeof...(Ptrs) >= 1>* = nullptr>
240Holder<T, Ptrs...> holder(T&& arg, Ptrs*... pointers) {
241 return Holder<T, Ptrs...>(std::forward<T>(arg), pointers...);
242}
// Trivial case: with no heap objects to own there is nothing to hold, so no
// Holder is built and the argument is handed straight back. An lvalue argument
// comes back as the same reference; an rvalue argument is moved into the
// returned value.
template <typename T>
T holder(T&& arg) {
  return static_cast<T&&>(arg);
}
248
249namespace internal {
250// the function holder_handle_element is also used in holder_cl
// The holder_handle_element overloads process one argument of make_holder.
// Each stores into `res` a pointer to the object the functor should be called
// with. Each returns a tuple of the heap objects it allocated (an empty tuple
// when it allocated nothing). The pointers in that tuple are raw owning
// pointers: the caller must hand them to a Holder or delete them.
//
// Overload for lvalue arguments: points `res` at the argument itself; nothing
// is allocated.
260template <typename T>
261auto holder_handle_element(T& a, T*& res) {
262 res = &a;
263 return std::make_tuple();
264}
// Overload for arguments whose Eigen traits lack NestByRefBit (presumably
// expression objects, which Eigen nests by value): points `res` at the
// argument itself; nothing is allocated.
265template <typename T,
266 std::enable_if_t<!(Eigen::internal::traits<std::decay_t<T>>::Flags
267 & Eigen::NestByRefBit)>* = nullptr>
268auto holder_handle_element(T&& a, std::remove_reference_t<T>*& res) {
269 res = &a;
270 return std::make_tuple();
271}
272
// Overload for rvalue arguments whose Eigen traits include NestByRefBit
// (presumably plain matrices/vectors): the argument is move-constructed into
// a new heap object so that it outlives this call. `res` points to that
// object, and the pointer is also returned for the caller to own.
283template <typename T, require_t<std::is_rvalue_reference<T&&>>* = nullptr,
284 std::enable_if_t<
285 static_cast<bool>(Eigen::internal::traits<std::decay_t<T>>::Flags&
286 Eigen::NestByRefBit)>* = nullptr>
287auto holder_handle_element(T&& a, T*& res) {
288 res = new T(std::move(a));
289 return std::make_tuple(res);
290}
// Overload for rvalue arguments that are not Eigen types: likewise moved into
// a new heap object whose pointer is stored in `res` and returned for the
// caller to own.
291template <typename T, require_t<std::is_rvalue_reference<T&&>>* = nullptr,
292 require_not_eigen_t<T>* = nullptr>
293auto holder_handle_element(T&& a, T*& res) {
294 res = new T(std::move(a));
295 return std::make_tuple(res);
296}
297
/**
 * Second step of building a holder from a functor: unpacks the tuple of heap
 * objects into a call to `holder`.
 *
 * @tparam T type of the expression produced by the functor
 * @tparam Is indices `0 .. sizeof...(Args) - 1`
 * @tparam Args types of the heap objects
 * @param expr expression to wrap
 * @param ptrs heap objects the returned holder takes ownership of
 * @return a `Holder` owning `ptrs`, or `expr` itself when `ptrs` is empty
 */
template <typename T, std::size_t... Is, typename... Args>
auto make_holder_impl_construct_object(T&& expr, std::index_sequence<Is...>,
                                       const std::tuple<Args*...>& ptrs) {
  auto&& forwarded_expr = std::forward<T>(expr);
  return holder(std::forward<T>(forwarded_expr), std::get<Is>(ptrs)...);
}
314
/**
 * Takes temporary ownership of the heap objects allocated by
 * `holder_handle_element`, so that they are deleted if building the
 * expression throws before a `Holder` has taken them over.
 *
 * @tparam Ts types of the heap-allocated objects
 * @tparam Is indices `0 .. sizeof...(Ts) - 1`
 * @param ptrs raw owning pointers to the heap-allocated objects
 * @return tuple of `std::unique_ptr`s owning the same objects
 */
template <typename... Ts, std::size_t... Is>
std::tuple<std::unique_ptr<Ts>...> make_holder_impl_take_ownership(
    const std::tuple<Ts*...>& ptrs, std::index_sequence<Is...>) {
  return std::tuple<std::unique_ptr<Ts>...>(
      std::unique_ptr<Ts>(std::get<Is>(ptrs))...);
}

/**
 * Gives up the ownership taken by `make_holder_impl_take_ownership` without
 * deleting anything; used right before the objects are handed to a `Holder`.
 *
 * @tparam Ts types of the heap-allocated objects
 * @tparam Is indices `0 .. sizeof...(Ts) - 1`
 * @param guards owners to release
 */
template <typename... Ts, std::size_t... Is>
void make_holder_impl_release_ownership(
    std::tuple<std::unique_ptr<Ts>...>& guards, std::index_sequence<Is...>) {
  // Pack-expansion idiom (no fold expressions before C++17); the leading 0
  // keeps the array non-empty when there is nothing to release.
  int expand[]
      = {0, (static_cast<void>(std::get<Is>(guards).release()), 0)...};
  static_cast<void>(expand);
}

/**
 * Implementation of construction of a holder from a functor.
 *
 * Rvalue arguments that must outlive this call (see `holder_handle_element`)
 * are moved to the heap. `func` is then called with lvalues referring to all
 * arguments, and the resulting expression is wrapped in a `Holder` that owns
 * those heap objects.
 *
 * If `func` throws (e.g. a size check failing), the heap objects allocated for
 * this call are deleted before the exception propagates instead of leaking.
 * NOTE: an allocation failure inside `holder_handle_element` itself, after an
 * earlier argument was already moved to the heap, is not covered.
 *
 * @tparam F type of the functor
 * @tparam Is indices `0 .. sizeof...(Args) - 1`
 * @tparam Args types of the arguments
 * @param func functor building the expression
 * @param args arguments for the functor
 * @return `Holder` owning the heap objects, or the expression itself when no
 * heap object was needed
 */
template <typename F, std::size_t... Is, typename... Args>
auto make_holder_impl(const F& func, std::index_sequence<Is...>,
                      Args&&... args) {
  std::tuple<std::remove_reference_t<Args>*...> res;
  auto ptrs = std::tuple_cat(
      holder_handle_element(std::forward<Args>(args), std::get<Is>(res))...);
  using owned_indices
      = std::make_index_sequence<std::tuple_size<decltype(ptrs)>::value>;
  // Own the heap objects while `func` runs, so an exception frees them.
  auto guards = make_holder_impl_take_ownership(ptrs, owned_indices());
  auto&& expr = func(*std::get<Is>(res)...);
  // The expression was built: ownership now passes to the holder.
  make_holder_impl_release_ownership(guards, owned_indices());
  return make_holder_impl_construct_object(std::forward<decltype(expr)>(expr),
                                           owned_indices(), ptrs);
}
334
335} // namespace internal
336
/**
 * Constructs an expression from the given arguments using the given functor.
 * The work is delegated to `internal::make_holder_impl`, which moves rvalue
 * arguments that must outlive the call to the heap. The expression returned
 * by `func` is wrapped in a `Holder` that owns those heap objects.
 *
 * @tparam F type of the functor
 * @tparam Args types of the arguments
 * @param func functor building the expression; it is called with lvalues
 * @param args arguments for the functor
 * @return expression (possibly wrapped in a `Holder`) built by `func`
 *
 * NOTE(review): the line that opens the constraint on the functor's result
 * type (the template argument ending in the `decltype(...)` below) is missing
 * from this listing. The cross-reference suggests `require_not_plain_type_t<`
 * for this overload -- confirm.
 */
349template <typename F, typename... Args,
351 decltype(std::declval<F>()(std::declval<Args&>()...))>* = nullptr>
352auto make_holder(const F& func, Args&&... args) {
353 return internal::make_holder_impl(func,
354 std::make_index_sequence<sizeof...(Args)>(),
355 std::forward<Args>(args)...);
356}
357
/**
 * Overload of `make_holder` that calls `func` directly and returns its result
 * as is: no `Holder` is built and nothing is moved to the heap.
 *
 * @tparam F type of the functor
 * @tparam Args types of the arguments
 * @param func functor to call
 * @param args arguments forwarded to the functor
 * @return result of `func`
 *
 * NOTE(review): the constraint line is missing from this listing. The
 * cross-reference suggests `require_plain_type_t<`, i.e. presumably this
 * overload is chosen when the functor returns a plain object -- confirm.
 * NOTE(review): the constraint probes `func` with lvalue arguments
 * (`std::declval<Args&>()`), but the call below forwards `args`, so rvalue
 * arguments reach `func` as rvalues. A functor taking non-const lvalue
 * references would pass the constraint and then fail to compile here.
 */
368template <typename F, typename... Args,
370 decltype(std::declval<F>()(std::declval<Args&>()...))>* = nullptr>
371auto make_holder(const F& func, Args&&... args) {
372 return func(std::forward<Args>(args)...);
373}
374
375} // namespace math
376} // namespace stan
377
378#endif
Eigen::internal::ref_selector< Holder< ArgType, Ptrs... > >::type Nested
Definition holder.hpp:119
Holder< ArgType, Ptrs... > & operator=(const Holder< ArgType, Ptrs... > &other)
Definition holder.hpp:151
Eigen::Index outerStride() const
Definition holder.hpp:135
Eigen::Index rows() const
Definition holder.hpp:132
Eigen::Index innerStride() const
Definition holder.hpp:134
Holder< ArgType, Ptrs... > & operator=(Holder< ArgType, Ptrs... > &&other)
Definition holder.hpp:156
Holder(const Holder< ArgType, Ptrs... > &)=default
Holder< ArgType, Ptrs... > & operator=(const T &other)
Assignment operator assigns expressions.
Definition holder.hpp:144
Eigen::Index cols() const
Definition holder.hpp:133
std::tuple< std::unique_ptr< Ptrs >... > m_unique_ptrs
Definition holder.hpp:121
Eigen::internal::ref_selector< ArgType >::non_const_type m_arg
Definition holder.hpp:120
Holder(Holder< ArgType, Ptrs... > &&)=default
Holder(ArgType &&arg, Ptrs *... pointers)
Definition holder.hpp:123
A no-op Eigen operation.
Definition holder.hpp:116
require_not_t< is_plain_type< std::decay_t< T > > > require_not_plain_type_t
Require type does not satisfy is_plain_type.
require_t< is_plain_type< std::decay_t< T > > > require_plain_type_t
Require type satisfies is_plain_type.
(Expert) Numerical traits for algorithmic differentiation variables.
auto make_holder_impl_construct_object(T &&expr, std::index_sequence< Is... >, const std::tuple< Args *... > &ptrs)
Second step in implementation of construction holder from a functor.
Definition holder.hpp:310
auto holder_handle_element(T &a, T *&res)
Handles a single element (moving rvalue non-expressions to the heap) for construction of holder or holder_cl.
Definition holder.hpp:261
auto make_holder_impl(const F &func, std::index_sequence< Is... >, Args &&... args)
Implementation of construction holder from a functor.
Definition holder.hpp:325
auto make_holder(const F &func, Args &&... args)
Constructs an expression from given arguments using given functor.
Definition holder.hpp:352
fvar< T > arg(const std::complex< fvar< T > > &z)
Return the phase angle of the complex argument.
Definition arg.hpp:19
Ptrs holder(T &&arg, Ptrs *... pointers)
Definition holder.hpp:240
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
STL namespace.
void writePacket(Index row, Index col, const PacketType &x)
Definition holder.hpp:212
CoeffReturnType coeff(Index row, Index col) const
Definition holder.hpp:188