Automatic Differentiation
 
Loading...
Searching...
No Matches
holder.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_FUN_HOLDER_HPP
2#define STAN_MATH_PRIM_FUN_HOLDER_HPP
3
7#include <memory>
8#include <type_traits>
9#include <utility>
63// This was implemented following the tutorial on adding new expression types to
64// Eigen: https://eigen.tuxfamily.org/dox/TopicNewExpressionType.html
65
// Forward declaration: the Eigen traits and evaluator specializations below
// must name stan::math::Holder before the class template itself is defined.
namespace stan::math {

template <typename ArgType, typename... Ptrs>
class Holder;

}  // namespace stan::math
74
75namespace Eigen {
76namespace internal {
77
78template <class ArgType, typename... Ptrs>
79struct traits<stan::math::Holder<ArgType, Ptrs...>> {
80 typedef typename ArgType::StorageKind StorageKind;
82 typedef typename ArgType::StorageIndex StorageIndex;
83 typedef typename ArgType::Scalar Scalar;
84 enum {
85 // Possible flags are documented here:
86 // https://eigen.tuxfamily.org/dox/group__flags.html
87 Flags = (ArgType::Flags
88 & (RowMajorBit | LvalueBit | LinearAccessBit | DirectAccessBit
89 | PacketAccessBit | NoPreferredStorageOrderBit))
90 | NestByRefBit,
91 RowsAtCompileTime = ArgType::RowsAtCompileTime,
92 ColsAtCompileTime = ArgType::ColsAtCompileTime,
93 MaxRowsAtCompileTime = ArgType::MaxRowsAtCompileTime,
94 MaxColsAtCompileTime = ArgType::MaxColsAtCompileTime,
95 InnerStrideAtCompileTime = ArgType::InnerStrideAtCompileTime,
96 OuterStrideAtCompileTime = ArgType::OuterStrideAtCompileTime
97 };
98};
99
100} // namespace internal
101} // namespace Eigen
102
103namespace stan {
104namespace internal {
105template <typename T>
106struct is_holder : std::false_type {};
107template <typename ArgType, typename... Ptrs>
108struct is_holder<stan::math::Holder<ArgType, Ptrs...>> : std::true_type {};
109} // namespace internal
110
111template <typename T>
112struct is_holder : internal::is_holder<std::decay_t<T>> {};
113
114template <typename T>
115inline constexpr bool is_holder_v = is_holder<T>::value;
116
117template <typename T>
119
120namespace math {
121
// Forward declaration of make_holder for functors whose result is NOT a plain
// type; such results get wrapped so they can own heap-allocated arguments.
// Presumably declared here because code placed before the definitions (the
// Holder operator overloads) calls it.
// The default template argument `= nullptr` appears only in this declaration;
// the definition later in this file repeats the template head without it, so
// the two heads must stay token-for-token equivalent, otherwise the compiler
// may treat them as distinct templates.
122template <
123 typename F, typename... Args,
124 require_not_plain_type_t<std::invoke_result_t<F, Args&&...>>* = nullptr>
125inline auto make_holder(F&& func, Args&&... args);
126
// Forward declaration of make_holder for functors whose result IS a plain
// type; that overload returns the functor's result directly. Keep this head
// equivalent to the definition's for the same reason as the one above.
137template <typename F, typename... Args,
138 require_plain_type_t<std::invoke_result_t<F, Args&&...>>* = nullptr>
139inline auto make_holder(F&& func, Args&&... args);
147template <typename ArgType, typename... Ptrs>
149 : public Eigen::internal::dense_xpr_base<Holder<ArgType, Ptrs...>>::type {
150 public:
151 typedef typename Eigen::internal::ref_selector<Holder<ArgType, Ptrs...>>::type
153 typename Eigen::internal::ref_selector<ArgType>::non_const_type m_arg;
154 std::tuple<std::unique_ptr<Ptrs>...> m_unique_ptrs;
155 explicit Holder(ArgType&& arg, Ptrs*... pointers)
156 : m_arg(std::forward<ArgType>(arg)),
157 m_unique_ptrs(std::unique_ptr<Ptrs>(pointers)...) {}
158
159 // we need to explicitly default copy and move constructors as we are
160 // defining copy and move assignment operators
163
164 // all these functions just call the same on the argument
165 Eigen::Index rows() const { return m_arg.rows(); }
166 Eigen::Index cols() const { return m_arg.cols(); }
167 Eigen::Index innerStride() const { return m_arg.innerStride(); }
168 Eigen::Index outerStride() const { return m_arg.outerStride(); }
169 auto* data() { return m_arg.data(); }
170 const auto* data() const { return m_arg.data(); }
171
177 template <typename T, require_eigen_t<T>* = nullptr>
178 inline Holder<ArgType, Ptrs...>& operator=(const T& other) {
179 m_arg = other;
180 return *this;
181 }
182
183 // copy and move assignment operators need to be separately overloaded,
184 // otherwise defaults will be used.
185 inline Holder<ArgType, Ptrs...>& operator=(
186 const Holder<ArgType, Ptrs...>& other) {
187 m_arg = other;
188 return *this;
189 }
190 inline Holder<ArgType, Ptrs...>& operator=(Holder<ArgType, Ptrs...>&& other) {
191 m_arg = std::move(other.m_arg);
192 return *this;
193 }
194};
195
196template <typename T, require_holder_t<T>* = nullptr>
197inline auto operator-(T&& h) {
198 return make_holder([](auto&& arg) { return -arg; }, std::forward<T>(h).m_arg);
199}
200template <typename T, require_holder_t<T>* = nullptr>
201inline auto operator+(T&& h) {
202 return make_holder([](auto&& arg) { return arg; }, std::forward<T>(h).m_arg);
203}
204
205template <typename T, typename Other, require_holder_t<T>* = nullptr,
206 require_holder_t<Other>* = nullptr>
207inline auto operator-(T&& h, Other&& other) {
208 return make_holder(
209 [](auto&& arg, auto&& other_) {
210 return arg - std::forward<decltype(other_)>(other_);
211 },
212 std::forward<T>(h).m_arg, std::forward<Other>(other));
213}
214template <typename T, typename Other, require_holder_t<T>* = nullptr,
215 require_holder_t<Other>* = nullptr>
216inline auto operator+(T&& h, Other&& other) {
217 return make_holder(
218 [](auto&& arg, auto&& other_) {
219 return arg + std::forward<decltype(other_)>(other_);
220 },
221 std::forward<T>(h).m_arg, std::forward<Other>(other));
222}
223template <typename T, typename Other, require_holder_t<T>* = nullptr,
224 require_holder_t<Other>* = nullptr>
225inline auto operator*(T&& h, Other&& other) {
226 return make_holder(
227 [](auto&& arg, auto&& other_) {
228 return arg * std::forward<decltype(other_)>(other_);
229 },
230 std::forward<T>(h).m_arg, std::forward<Other>(other));
231}
232template <typename T, typename Other, require_holder_t<T>* = nullptr,
233 require_holder_t<Other>* = nullptr>
234inline auto operator/(T&& h, Other&& other) {
235 return make_holder(
236 [](auto&& arg, auto&& other_) {
237 return arg / std::forward<decltype(other_)>(other_);
238 },
239 std::forward<T>(h).m_arg, std::forward<Other>(other));
240}
241
242} // namespace math
243} // namespace stan
244
245namespace Eigen {
246namespace internal {
247
248template <typename ArgType, typename... Ptrs>
249struct evaluator<stan::math::Holder<ArgType, Ptrs...>>
250 : evaluator_base<stan::math::Holder<ArgType, Ptrs...>> {
251 using PlainObjectType = stan::math::Holder<ArgType, Ptrs...>;
252 using XprType = stan::math::Holder<ArgType, Ptrs...>;
253 using ArgTypeNestedCleaned = typename remove_all<ArgType>::type;
254 using CoeffReturnType = typename XprType::CoeffReturnType;
255 using Scalar = typename XprType::Scalar;
256 enum {
257 IsRowMajor = XprType::IsRowMajor,
258 IsColMajor = !IsRowMajor,
259 IsVectorAtCompileTime = XprType::IsVectorAtCompileTime,
260 RowsAtCompileTime = XprType::RowsAtCompileTime,
261 ColsAtCompileTime = XprType::ColsAtCompileTime,
262
263 CoeffReadCost = evaluator<ArgTypeNestedCleaned>::CoeffReadCost,
264 Flags = evaluator<ArgTypeNestedCleaned>::Flags,
265 Alignment = evaluator<ArgTypeNestedCleaned>::Alignment,
266 };
267 enum {
268 // We do not need to know the outer stride for vectors
269 OuterStrideAtCompileTime
270 = IsVectorAtCompileTime
271 ? 0
272 : (IsRowMajor ? ColsAtCompileTime : RowsAtCompileTime)
273 };
274
275 evaluator<ArgTypeNestedCleaned> m_argImpl;
276
277 explicit evaluator(const XprType& xpr) : m_argImpl(xpr.m_arg) {}
278 explicit evaluator(XprType&& xpr)
279 : m_argImpl(std::forward<XprType>(xpr).m_arg) {}
280
281 // all these functions just call the same on the argument
282 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index row,
283 Index col) const {
284 return m_argImpl.coeff(row, col);
285 }
286 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType
287 coeff(Index index) const {
288 return m_argImpl.coeff(index);
289 }
290
291 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index row, Index col) {
292 return m_argImpl.coeffRef(row, col);
293 }
294 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
295 return m_argImpl.coeffRef(index);
296 }
297
298 template <int LoadMode, typename PacketType>
299 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packet(Index row,
300 Index col) const {
301 return m_argImpl.template packet<LoadMode, PacketType>(row, col);
302 }
303 template <int LoadMode, typename PacketType>
304 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packet(Index index) const {
305 return m_argImpl.template packet<LoadMode, PacketType>(index);
306 }
307
308 template <int StoreMode, typename PacketType>
309 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void writePacket(Index row, Index col,
310 const PacketType& x) {
311 return m_argImpl.template writePacket<StoreMode, PacketType>(row, col, x);
312 }
313 template <int StoreMode, typename PacketType>
314 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void writePacket(Index index,
315 const PacketType& x) {
316 return m_argImpl.template writePacket<StoreMode, PacketType>(index, x);
317 }
318
319 template <int LoadMode, typename PacketType>
320 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType
321 packetSegment(Index row, Index col, Index begin, Index count) const {
322 return m_argImpl.template packetSegment<LoadMode, PacketType>(row, col,
323 begin, count);
324 }
325
326 template <int LoadMode, typename PacketType>
327 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType
328 packetSegment(Index index, Index begin, Index count) const {
329 return m_argImpl.template packetSegment<LoadMode, PacketType>(index, begin,
330 count);
331 }
332
333 template <int StoreMode, typename PacketType>
334 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void writePacketSegment(
335 Index row, Index col, const PacketType& x, Index begin, Index count) {
336 return m_argImpl.template writePacketSegment<StoreMode, PacketType>(
337 row, col, x, begin, count);
338 }
339
340 template <int StoreMode, typename PacketType>
341 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void writePacketSegment(
342 Index index, const PacketType& x, Index begin, Index count) {
343 return m_argImpl.template writePacketSegment<StoreMode, PacketType>(
344 index, x, begin, count);
345 }
346};
347
348} // namespace internal
349} // namespace Eigen
350
351namespace stan {
352namespace math {
353
364template <typename T, typename... Ptrs,
365 std::enable_if_t<sizeof...(Ptrs) >= 1>* = nullptr>
366inline Holder<T, Ptrs...> holder(T&& arg, Ptrs*... pointers) {
367 return Holder<T, Ptrs...>(std::forward<T>(arg), pointers...);
368}
// Trivial case: with no pointers to own there is nothing to hold, so no
// Holder object is constructed.
/**
 * Returns `arg` without wrapping it.
 * @tparam T type of the argument
 * @param arg argument
 * @return an lvalue reference to `arg` for lvalue arguments; otherwise a new
 * value of type `std::decay_t<T>` move-constructed from `arg`
 */
template <typename T>
inline decltype(auto) holder(T&& arg) {
  if constexpr (std::is_lvalue_reference<T>::value) {
    // lvalue: hand back a reference to the caller's own object
    return std::forward<T>(arg);
  } else {
    // rvalue: return by value so the result cannot dangle
    return std::decay_t<T>(std::forward<T>(arg));
  }
}
378
379namespace internal {
380// the function holder_handle_element is also used in holder_cl
// holder_handle_element overloads: each one handles a single argument passed
// to make_holder. It stores in `res` a pointer to the object the functor will
// be called with. It returns a tuple of the pointers it heap-allocated, which
// is empty when nothing was allocated. The caller becomes responsible for
// those pointers; make_holder_impl passes them to holder(), whose Holder
// stores them in unique_ptrs.

// Lvalue argument: the functor uses the caller's object directly, so only its
// address is recorded and nothing is allocated.
// @param a argument
// @param[out] res set to `&a`
// @return empty tuple
390template <typename T>
391inline auto holder_handle_element(T& a, T*& res) {
392 res = &a;
393 return std::make_tuple();
394}
// Argument whose Eigen traits do NOT set NestByRefBit. Only its address is
// recorded and nothing is allocated.
// NOTE(review): presumably safe because such expressions are nested by value
// (copied) by whatever the functor builds from them — confirm.
// @param a argument (forwarding reference)
// @param[out] res set to `&a`
// @return empty tuple
395template <typename T,
396 std::enable_if_t<!(Eigen::internal::traits<std::decay_t<T>>::Flags
397 & Eigen::NestByRefBit)>* = nullptr>
398inline auto holder_handle_element(T&& a, std::remove_reference_t<T>*& res) {
399 res = &a;
400 return std::make_tuple();
401}
402
// Rvalue Eigen argument whose traits set NestByRefBit, i.e. an expression
// built from it would keep a reference to it. It is move-constructed onto the
// heap with `new` so it can outlive this call.
// @param a rvalue argument
// @param[out] res set to the new heap object
// @return one-element tuple holding that heap pointer (ownership passes to
// the caller)
413template <typename T, require_t<std::is_rvalue_reference<T&&>>* = nullptr,
414 std::enable_if_t<
415 static_cast<bool>(Eigen::internal::traits<std::decay_t<T>>::Flags&
416 Eigen::NestByRefBit)>* = nullptr>
417inline auto holder_handle_element(T&& a, T*& res) {
418 res = new T(std::move(a));
419 return std::make_tuple(res);
420}
// Rvalue non-Eigen argument: moved onto the heap with `new` so it can outlive
// this call.
// @param a rvalue argument
// @param[out] res set to the new heap object
// @return one-element tuple holding that heap pointer (ownership passes to
// the caller)
421template <typename T, require_t<std::is_rvalue_reference<T&&>>* = nullptr,
422 require_not_eigen_t<T>* = nullptr>
423inline auto holder_handle_element(T&& a, T*& res) {
424 res = new T(std::move(a));
425 return std::make_tuple(res);
426}
427
/**
 * Second step in implementation of construction holder from a functor.
 *
 * Wraps the functor result `expr` with holder(), handing it every pointer in
 * `ptrs` to own. With no pointers, holder() returns the expression itself
 * rather than a Holder.
 *
 * @tparam T type of the functor result
 * @tparam Is index sequence 0 .. sizeof...(Args)-1
 * @tparam Args types pointed to by the pointers in `ptrs`
 * @param expr result of the functor
 * @param ptrs heap-allocated objects the result should own
 * @return holder(expr, ptrs...)
 */
template <typename T, std::size_t... Is, typename... Args>
inline auto make_holder_impl_construct_object(
    T&& expr, std::index_sequence<Is...>, const std::tuple<Args*...>& ptrs) {
  return holder(std::forward<T>(expr), std::get<Is>(ptrs)...);
}
444
/**
 * Implementation of construction holder from a functor.
 *
 * Each argument goes through holder_handle_element. That call records in
 * `res` the object the functor should use, and moves rvalues that need it
 * onto the heap (collecting those raw pointers in `ptrs`). The functor is
 * then called on the recorded objects, and its result is wrapped together
 * with `ptrs` by make_holder_impl_construct_object.
 *
 * Until that wrapper is built, nothing owns the pointers in `ptrs`. If the
 * functor or the construction of the result throws, they are deleted here
 * before the exception propagates, so the heap copies are not leaked. This
 * relies on the Holder being returned as a prvalue: nothing that can throw
 * runs after it takes ownership of the pointers.
 *
 * @tparam F type of the functor
 * @tparam Is index sequence 0 .. sizeof...(Args)-1
 * @tparam Args types of the arguments
 * @param func functor to call
 * @param args arguments for the functor
 * @return `func(args...)` wrapped so that it owns any heap-allocated arguments
 */
template <typename F, std::size_t... Is, typename... Args>
inline auto make_holder_impl(F&& func, std::index_sequence<Is...>,
                             Args&&... args) {
  std::tuple<std::remove_reference_t<Args>*...> res;
  auto ptrs = std::tuple_cat(
      holder_handle_element(std::forward<Args>(args), std::get<Is>(res))...);
  try {
    return make_holder_impl_construct_object(
        std::forward<F>(func)(*std::get<Is>(res)...),
        std::make_index_sequence<std::tuple_size<decltype(ptrs)>::value>(),
        ptrs);
  } catch (...) {
    // No Holder owns the heap copies yet: free them, then rethrow.
    std::apply([](auto*... owned) { (delete owned, ...); }, ptrs);
    throw;
  }
}
464
465} // namespace internal
466
479template <typename F, typename... Args,
480 require_not_plain_type_t<std::invoke_result_t<F, Args&&...>>*>
481inline auto make_holder(F&& func, Args&&... args) {
482 return internal::make_holder_impl(std::forward<F>(func),
483 std::make_index_sequence<sizeof...(Args)>(),
484 std::forward<Args>(args)...);
485}
486
497template <typename F, typename... Args,
498 require_plain_type_t<std::invoke_result_t<F, Args&&...>>*>
499inline auto make_holder(F&& func, Args&&... args) {
500 return std::forward<F>(func)(std::forward<Args>(args)...);
501}
502
503} // namespace math
504} // namespace stan
505
506#endif
Eigen::internal::ref_selector< Holder< ArgType, Ptrs... > >::type Nested
Definition holder.hpp:152
Holder< ArgType, Ptrs... > & operator=(const Holder< ArgType, Ptrs... > &other)
Definition holder.hpp:185
Eigen::Index outerStride() const
Definition holder.hpp:168
Eigen::Index rows() const
Definition holder.hpp:165
Eigen::Index innerStride() const
Definition holder.hpp:167
Holder< ArgType, Ptrs... > & operator=(Holder< ArgType, Ptrs... > &&other)
Definition holder.hpp:190
Holder(const Holder< ArgType, Ptrs... > &)=default
Holder< ArgType, Ptrs... > & operator=(const T &other)
Assignment operator assigns expressions.
Definition holder.hpp:178
Eigen::Index cols() const
Definition holder.hpp:166
std::tuple< std::unique_ptr< Ptrs >... > m_unique_ptrs
Definition holder.hpp:154
const auto * data() const
Definition holder.hpp:170
Eigen::internal::ref_selector< ArgType >::non_const_type m_arg
Definition holder.hpp:153
Holder(Holder< ArgType, Ptrs... > &&)=default
Holder(ArgType &&arg, Ptrs *... pointers)
Definition holder.hpp:155
A no-op Eigen operation.
Definition holder.hpp:149
require_not_t< is_plain_type< std::decay_t< T > > > require_not_plain_type_t
Require type does not satisfy is_plain_type.
require_t< is_plain_type< std::decay_t< T > > > require_plain_type_t
Require type satisfies is_plain_type.
(Expert) Numerical traits for algorithmic differentiation variables.
auto make_holder_impl_construct_object(T &&expr, std::index_sequence< Is... >, const std::tuple< Args *... > &ptrs)
Second step in implementation of construction holder from a functor.
Definition holder.hpp:440
auto holder_handle_element(T &a, T *&res)
Handles single element (moving rvalue non-expressions to heap) for construction of holder or holder_c...
Definition holder.hpp:391
auto make_holder_impl(F &&func, std::index_sequence< Is... >, Args &&... args)
Implementation of construction holder from a functor.
Definition holder.hpp:455
fvar< T > operator/(const fvar< T > &x1, const fvar< T > &x2)
Return the result of dividing the first argument by the second.
fvar< T > operator-(const fvar< T > &x1, const fvar< T > &x2)
Return the difference of the specified arguments.
fvar< T > operator*(const fvar< T > &x, const fvar< T > &y)
Return the product of the two arguments.
fvar< T > arg(const std::complex< fvar< T > > &z)
Return the phase angle of the complex argument.
Definition arg.hpp:19
auto make_holder(F &&func, Args &&... args)
Calls given function with given arguments.
Definition holder.hpp:481
fvar< T > operator+(const fvar< T > &x1, const fvar< T > &x2)
Return the sum of the specified forward mode addends.
Holder< T, Ptrs... > holder(T &&arg, Ptrs *... pointers)
Definition holder.hpp:366
require_t< is_holder< T > > require_holder_t
Definition holder.hpp:118
constexpr bool is_holder_v
Definition holder.hpp:115
std::enable_if_t< Check::value > require_t
If condition is true, template is enabled.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
STL namespace.
void writePacket(Index row, Index col, const PacketType &x)
Definition holder.hpp:309
CoeffReturnType coeff(Index row, Index col) const
Definition holder.hpp:282
void writePacketSegment(Index row, Index col, const PacketType &x, Index begin, Index count)
Definition holder.hpp:334
void writePacketSegment(Index index, const PacketType &x, Index begin, Index count)
Definition holder.hpp:341
PacketType packetSegment(Index row, Index col, Index begin, Index count) const
Definition holder.hpp:321
PacketType packetSegment(Index index, Index begin, Index count) const
Definition holder.hpp:328