Automatic Differentiation
 
Loading...
Searching...
No Matches
holder.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_FUN_HOLDER_HPP
2#define STAN_MATH_PRIM_FUN_HOLDER_HPP
3
7#include <memory>
8#include <type_traits>
9#include <utility>
63// This was implemented following the tutorial on adding new expressions to
64// Eigen: https://eigen.tuxfamily.org/dox/TopicNewExpressionType.html
65
// Forward declaration of the Holder expression type, so that the Eigen traits
// specialization that follows can name it before the class is defined.
namespace stan::math {

template <typename ArgType, typename... Ptrs>
class Holder;

}  // namespace stan::math
74
75namespace Eigen {
76namespace internal {
77
// Eigen traits for stan::math::Holder. Every property is copied from the
// wrapped expression type ArgType, except Flags (see below).
78template <class ArgType, typename... Ptrs>
79struct traits<stan::math::Holder<ArgType, Ptrs...>> {
80 typedef typename ArgType::StorageKind StorageKind;
// NOTE(review): source line 81 is absent from this listing. It presumably
// declares `typedef typename ArgType::XprKind XprKind;`, which Eigen's
// dense_xpr_base (Holder's base class) needs -- verify against the header.
82 typedef typename ArgType::StorageIndex StorageIndex;
83 typedef typename ArgType::Scalar Scalar;
84 enum {
85 // Possible flags are documented here:
86 // https://eigen.tuxfamily.org/dox/group__flags.html
// Keep only the listed capability bits of ArgType::Flags and always add
// NestByRefBit. Presumably this is because a Holder owns its heap objects
// through unique_ptrs and so cannot be copied into an enclosing expression;
// enclosing expressions must reference it instead.
87 Flags = (ArgType::Flags
88 & (RowMajorBit | LvalueBit | LinearAccessBit | DirectAccessBit
89 | PacketAccessBit | NoPreferredStorageOrderBit))
90 | NestByRefBit,
// Sizes and strides are those of the wrapped expression.
91 RowsAtCompileTime = ArgType::RowsAtCompileTime,
92 ColsAtCompileTime = ArgType::ColsAtCompileTime,
93 MaxRowsAtCompileTime = ArgType::MaxRowsAtCompileTime,
94 MaxColsAtCompileTime = ArgType::MaxColsAtCompileTime,
95 InnerStrideAtCompileTime = ArgType::InnerStrideAtCompileTime,
96 OuterStrideAtCompileTime = ArgType::OuterStrideAtCompileTime
97 };
98};
99
100} // namespace internal
101} // namespace Eigen
102
103namespace stan {
104namespace internal {
105template <typename T>
106struct is_holder : std::false_type {};
107template <typename ArgType, typename... Ptrs>
108struct is_holder<stan::math::Holder<ArgType, Ptrs...>> : std::true_type {};
109} // namespace internal
110
111template <typename T>
112struct is_holder : internal::is_holder<std::decay_t<T>> {};
113
114template <typename T>
115inline constexpr bool is_holder_v = is_holder<T>::value;
116
117namespace math {
118
126template <typename ArgType, typename... Ptrs>
128 : public Eigen::internal::dense_xpr_base<Holder<ArgType, Ptrs...>>::type {
129 public:
130 typedef typename Eigen::internal::ref_selector<Holder<ArgType, Ptrs...>>::type
132 typename Eigen::internal::ref_selector<ArgType>::non_const_type m_arg;
133 std::tuple<std::unique_ptr<Ptrs>...> m_unique_ptrs;
134
135 explicit Holder(ArgType&& arg, Ptrs*... pointers)
136 : m_arg(arg), m_unique_ptrs(std::unique_ptr<Ptrs>(pointers)...) {}
137
138 // we need to explicitly default the copy and move constructors, as we are
139 // defining copy and move assignment operators
142
143 // all these functions just call the same on the argument
144 Eigen::Index rows() const { return m_arg.rows(); }
145 Eigen::Index cols() const { return m_arg.cols(); }
146 Eigen::Index innerStride() const { return m_arg.innerStride(); }
147 Eigen::Index outerStride() const { return m_arg.outerStride(); }
148 auto* data() { return m_arg.data(); }
149
155 template <typename T, require_eigen_t<T>* = nullptr>
156 inline Holder<ArgType, Ptrs...>& operator=(const T& other) {
157 m_arg = other;
158 return *this;
159 }
160
161 // copy and move assignment operators need to be separately overloaded,
162 // otherwise defaults will be used.
163 inline Holder<ArgType, Ptrs...>& operator=(
164 const Holder<ArgType, Ptrs...>& other) {
165 m_arg = other;
166 return *this;
167 }
168 inline Holder<ArgType, Ptrs...>& operator=(Holder<ArgType, Ptrs...>&& other) {
169 m_arg = std::move(other.m_arg);
170 return *this;
171 }
172};
173
174} // namespace math
175} // namespace stan
176
177namespace Eigen {
178namespace internal {
179
180template <typename ArgType, typename... Ptrs>
181struct evaluator<stan::math::Holder<ArgType, Ptrs...>>
182 : evaluator_base<stan::math::Holder<ArgType, Ptrs...>> {
183 typedef stan::math::Holder<ArgType, Ptrs...> XprType;
184 typedef typename remove_all<ArgType>::type ArgTypeNestedCleaned;
185 typedef typename XprType::CoeffReturnType CoeffReturnType;
186 typedef typename XprType::Scalar Scalar;
187 enum {
188 CoeffReadCost = evaluator<ArgTypeNestedCleaned>::CoeffReadCost,
189 // Possible flags are documented here:
190 // https://eigen.tuxfamily.org/dox/group__flags.html
191 Flags = evaluator<ArgTypeNestedCleaned>::Flags,
192 Alignment = evaluator<ArgTypeNestedCleaned>::Alignment,
193 };
194
195 evaluator<ArgTypeNestedCleaned> m_argImpl;
196
197 explicit evaluator(const XprType& xpr) : m_argImpl(xpr.m_arg) {}
198
199 // all these functions just call the same on the argument
200 EIGEN_STRONG_INLINE CoeffReturnType coeff(Index row, Index col) const {
201 return m_argImpl.coeff(row, col);
202 }
203 EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
204 return m_argImpl.coeff(index);
205 }
206
207 EIGEN_STRONG_INLINE Scalar& coeffRef(Index row, Index col) {
208 return m_argImpl.coeffRef(row, col);
209 }
210 EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
211 return m_argImpl.coeffRef(index);
212 }
213
214 template <int LoadMode, typename PacketType>
215 EIGEN_STRONG_INLINE PacketType packet(Index row, Index col) const {
216 return m_argImpl.template packet<LoadMode, PacketType>(row, col);
217 }
218 template <int LoadMode, typename PacketType>
219 EIGEN_STRONG_INLINE PacketType packet(Index index) const {
220 return m_argImpl.template packet<LoadMode, PacketType>(index);
221 }
222
223 template <int StoreMode, typename PacketType>
224 EIGEN_STRONG_INLINE void writePacket(Index row, Index col,
225 const PacketType& x) {
226 return m_argImpl.template writePacket<StoreMode, PacketType>(row, col, x);
227 }
228 template <int StoreMode, typename PacketType>
229 EIGEN_STRONG_INLINE void writePacket(Index index, const PacketType& x) {
230 return m_argImpl.template writePacket<StoreMode, PacketType>(index, x);
231 }
232};
233
234} // namespace internal
235} // namespace Eigen
236
237namespace stan {
238namespace math {
239
250template <typename T, typename... Ptrs,
251 std::enable_if_t<sizeof...(Ptrs) >= 1>* = nullptr>
252inline Holder<T, Ptrs...> holder(T&& arg, Ptrs*... pointers) {
253 return Holder<T, Ptrs...>(std::forward<T>(arg), pointers...);
254}
/**
 * Trivial overload used when there are no pointers to take ownership of:
 * no Holder object is constructed.
 *
 * @tparam T type of the argument (deduced as a forwarding reference)
 * @param arg argument to pass through
 * @return for an lvalue, a reference to `arg` itself. For an rvalue, a new
 * `std::decay_t<T>` object constructed from `std::forward<T>(arg)`, so the
 * result never refers to the caller's temporary.
 */
template <typename T>
inline decltype(auto) holder(T&& arg) {
  if constexpr (!std::is_rvalue_reference<T&&>::value) {
    // lvalue: pass it through unchanged, preserving the reference type
    return std::forward<T>(arg);
  } else {
    // rvalue: materialize an owned value so nothing dangles
    using value_type = std::decay_t<T>;
    return value_type(std::forward<T>(arg));
  }
}
264
265namespace internal {
266// the function holder_handle_element is also used in holder_cl
// holder_handle_element(a, res) prepares one argument for make_holder().
// Each overload stores in `res` a pointer to the object the functor should
// use. It returns a tuple of the pointers that the resulting holder must take
// ownership of; the tuple is empty when nothing was allocated.

// Lvalue argument: use it in place; nothing to own.
276template <typename T>
277inline auto holder_handle_element(T& a, T*& res) {
278 res = &a;
279 return std::make_tuple();
280}
// Argument whose Eigen traits do not set NestByRefBit (presumably lazy
// expressions, which are nested by value): use it in place; nothing to own.
// NOTE(review): for an rvalue, `res` points at the caller's temporary. This
// relies on make_holder_impl() consuming it before that temporary dies.
281template <typename T,
282 std::enable_if_t<!(Eigen::internal::traits<std::decay_t<T>>::Flags
283 & Eigen::NestByRefBit)>* = nullptr>
284inline auto holder_handle_element(T&& a, std::remove_reference_t<T>*& res) {
285 res = &a;
286 return std::make_tuple();
287}
288
// Rvalue Eigen object whose traits set NestByRefBit: an expression built on
// it would only reference it, so it is moved into a new heap object. The new
// pointer is returned so the holder can take ownership of it.
299template <typename T, require_t<std::is_rvalue_reference<T&&>>* = nullptr,
300 std::enable_if_t<
301 static_cast<bool>(Eigen::internal::traits<std::decay_t<T>>::Flags&
302 Eigen::NestByRefBit)>* = nullptr>
303inline auto holder_handle_element(T&& a, T*& res) {
304 res = new T(std::move(a));
305 return std::make_tuple(res);
306}
// Rvalue non-Eigen object: likewise moved to the heap, and its pointer is
// returned for ownership.
// NOTE(review): both heap-moving overloads use raw `new`. The object is only
// owned once a Holder adopts it into a unique_ptr, so a throw before that
// point leaks it.
307template <typename T, require_t<std::is_rvalue_reference<T&&>>* = nullptr,
308 require_not_eigen_t<T>* = nullptr>
309inline auto holder_handle_element(T&& a, T*& res) {
310 res = new T(std::move(a));
311 return std::make_tuple(res);
312}
313
// Second step of building a holder from a functor. It calls holder() with
// `expr` and every pointer stored in `ptrs` (Is... indexes that tuple). When
// `ptrs` is empty, the multi-pointer holder() overload is disabled, so the
// single-argument overload is used and no Holder object is created.
// NOTE(review): the declarator line (source line 326: return type and name)
// is missing from this listing. The name is presumably
// make_holder_impl_construct_object -- verify against the real header.
325template <typename T, std::size_t... Is, typename... Args>
327 T&& expr, std::index_sequence<Is...>, const std::tuple<Args*...>& ptrs) {
328 return holder(std::forward<T>(expr), std::get<Is>(ptrs)...);
329}
330
// Implementation of make_holder() for functors that return an expression.
// It runs in three stages:
//  1. Each argument is prepared with holder_handle_element(). Rvalue Eigen
//     objects with NestByRefBit and rvalue non-Eigen objects are moved to the
//     heap; everything else is used in place.
//  2. `func` is called on the prepared arguments.
//  3. The result and the heap pointers are passed to the second
//     construction step.
// NOTE(review): source line 346 (the start of the return statement) is
// missing from this listing. It presumably reads
// `return make_holder_impl_construct_object(` -- verify.
// NOTE(review): the heap copies are raw allocations that become owned only
// when the resulting Holder adopts them. If `func` (or a later allocation)
// throws first, they leak.
340template <typename F, std::size_t... Is, typename... Args>
341inline auto make_holder_impl(F&& func, std::index_sequence<Is...>,
342 Args&&... args) {
// One slot per argument: each ends up pointing either at the argument itself
// or at a heap-allocated copy of it.
343 std::tuple<std::remove_reference_t<Args>*...> res;
// Only the heap-allocated copies are collected here; overloads that do not
// allocate contribute empty tuples.
344 auto ptrs = std::tuple_cat(
345 holder_handle_element(std::forward<Args>(args), std::get<Is>(res))...);
347 std::forward<F>(func)(*std::get<Is>(res)...),
348 std::make_index_sequence<std::tuple_size<decltype(ptrs)>::value>(), ptrs);
349}
350
351} // namespace internal
352
365template <
366 typename F, typename... Args,
367 require_not_plain_type_t<std::invoke_result_t<F, Args&&...>>* = nullptr>
368inline auto make_holder(F&& func, Args&&... args) {
369 return internal::make_holder_impl(std::forward<F>(func),
370 std::make_index_sequence<sizeof...(Args)>(),
371 std::forward<Args>(args)...);
372}
373
384template <typename F, typename... Args,
385 require_plain_type_t<std::invoke_result_t<F, Args&&...>>* = nullptr>
386inline auto make_holder(F&& func, Args&&... args) {
387 return std::forward<F>(func)(std::forward<Args>(args)...);
388}
389
390} // namespace math
391} // namespace stan
392
393#endif
Eigen::internal::ref_selector< Holder< ArgType, Ptrs... > >::type Nested
Definition holder.hpp:131
Holder< ArgType, Ptrs... > & operator=(const Holder< ArgType, Ptrs... > &other)
Definition holder.hpp:163
Eigen::Index outerStride() const
Definition holder.hpp:147
Eigen::Index rows() const
Definition holder.hpp:144
Eigen::Index innerStride() const
Definition holder.hpp:146
Holder< ArgType, Ptrs... > & operator=(Holder< ArgType, Ptrs... > &&other)
Definition holder.hpp:168
Holder(const Holder< ArgType, Ptrs... > &)=default
Holder< ArgType, Ptrs... > & operator=(const T &other)
Assignment operator assigns expressions.
Definition holder.hpp:156
Eigen::Index cols() const
Definition holder.hpp:145
std::tuple< std::unique_ptr< Ptrs >... > m_unique_ptrs
Definition holder.hpp:133
Eigen::internal::ref_selector< ArgType >::non_const_type m_arg
Definition holder.hpp:132
Holder(Holder< ArgType, Ptrs... > &&)=default
Holder(ArgType &&arg, Ptrs *... pointers)
Definition holder.hpp:135
A no-op Eigen operation.
Definition holder.hpp:128
require_not_t< is_plain_type< std::decay_t< T > > > require_not_plain_type_t
Require type does not satisfy is_plain_type.
require_t< is_plain_type< std::decay_t< T > > > require_plain_type_t
Require type satisfies is_plain_type.
(Expert) Numerical traits for algorithmic differentiation variables.
auto make_holder_impl_construct_object(T &&expr, std::index_sequence< Is... >, const std::tuple< Args *... > &ptrs)
Second step in the implementation of constructing a holder from a functor.
Definition holder.hpp:326
auto holder_handle_element(T &a, T *&res)
Handles single element (moving rvalue non-expressions to heap) for construction of holder or holder_c...
Definition holder.hpp:277
auto make_holder_impl(F &&func, std::index_sequence< Is... >, Args &&... args)
Implementation of constructing a holder from a functor.
Definition holder.hpp:341
fvar< T > arg(const std::complex< fvar< T > > &z)
Return the phase angle of the complex argument.
Definition arg.hpp:19
auto make_holder(F &&func, Args &&... args)
Constructs an expression from given arguments using given functor.
Definition holder.hpp:368
Holder< T, Ptrs... > holder(T &&arg, Ptrs *... pointers)
Definition holder.hpp:252
constexpr bool is_holder_v
Definition holder.hpp:115
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
STL namespace.
void writePacket(Index row, Index col, const PacketType &x)
Definition holder.hpp:224
CoeffReturnType coeff(Index row, Index col) const
Definition holder.hpp:200