1#ifndef STAN_MATH_REV_CORE_VARI_HPP
2#define STAN_MATH_REV_CORE_VARI_HPP
16template <
typename T,
typename =
void>
47 static inline void*
operator new(
size_t nbytes)
noexcept {
62 static inline void operator delete(
105 template <
typename S, require_convertible_t<S&, T>* =
nullptr>
125 template <
typename S, require_convertible_t<S&, T>* =
nullptr>
139 inline const auto&
val()
const {
return val_; }
149 inline auto&
adj()
const {
return adj_; }
159 inline auto&
adj() {
return adj_; }
188 return os << v->val_ <<
":" << v->adj_;
192 template <
typename,
typename>
205template <
typename T,
typename =
void>
214template <
typename Derived>
233 return static_cast<const Derived&
>(*this);
244 inline auto block(Eigen::Index start_row, Eigen::Index start_col,
245 Eigen::Index num_rows, Eigen::Index num_cols)
const {
246 using inner_type =
decltype(
247 derived().val_.block(start_row, start_col, num_rows, num_cols));
249 derived().val_.block(start_row, start_col, num_rows, num_cols),
250 derived().adj_.block(start_row, start_col, num_rows, num_cols));
252 inline auto block(Eigen::Index start_row, Eigen::Index start_col,
253 Eigen::Index num_rows, Eigen::Index num_cols) {
254 using inner_type =
decltype(
255 derived().val_.block(start_row, start_col, num_rows, num_cols));
257 derived().val_.block(start_row, start_col, num_rows, num_cols),
258 derived().adj_.block(start_row, start_col, num_rows, num_cols));
265 inline auto head(Eigen::Index n)
const {
266 using inner_type =
decltype(
derived().val_.head(n));
270 inline auto head(Eigen::Index n) {
271 using inner_type =
decltype(
derived().val_.head(n));
280 inline auto tail(Eigen::Index n)
const {
281 using inner_type =
decltype(
derived().val_.tail(n));
285 inline auto tail(Eigen::Index n) {
286 using inner_type =
decltype(
derived().val_.tail(n));
296 inline auto segment(Eigen::Index i, Eigen::Index n)
const {
297 using inner_type =
decltype(
derived().val_.segment(i, n));
301 inline auto segment(Eigen::Index i, Eigen::Index n) {
302 using inner_type =
decltype(
derived().val_.segment(i, n));
311 using inner_type =
decltype(
derived().val_.transpose());
316 using inner_type =
decltype(
derived().val_.transpose());
325 inline auto row(Eigen::Index i)
const {
326 using inner_type =
decltype(
derived().val_.row(i));
329 inline auto row(Eigen::Index i) {
330 using inner_type =
decltype(
derived().val_.row(i));
338 inline auto col(Eigen::Index i)
const {
339 using inner_type =
decltype(
derived().val_.col(i));
342 inline auto col(Eigen::Index i) {
343 using inner_type =
decltype(
derived().val_.col(i));
352 using inner_type =
decltype(
derived().val_.diagonal());
357 using inner_type =
decltype(
derived().val_.diagonal());
367 inline auto coeff(Eigen::Index i, Eigen::Index j)
const {
369 derived().adj_.coeffRef(i, j));
371 inline auto coeff(Eigen::Index i, Eigen::Index j) {
373 derived().adj_.coeffRef(i, j));
380 inline auto coeff(Eigen::Index i)
const {
401 inline auto operator()(Eigen::Index i, Eigen::Index j)
const {
402 return this->
coeff(i, j);
405 return this->
coeff(i, j);
412 using inner_type =
decltype(
derived().val_.rowwise().reverse());
414 derived().adj_.rowwise().reverse());
417 using inner_type =
decltype(
derived().val_.rowwise().reverse());
419 derived().adj_.rowwise().reverse());
426 using inner_type =
decltype(
derived().val_.colwise().reverse());
428 derived().adj_.colwise().reverse());
431 using inner_type =
decltype(
derived().val_.colwise().reverse());
433 derived().adj_.colwise().reverse());
441 using inner_type =
decltype(
derived().val_.reverse());
446 using inner_type =
decltype(
derived().val_.reverse());
456 using inner_type =
decltype(
derived().val_.topRows(n));
461 using inner_type =
decltype(
derived().val_.topRows(n));
471 using inner_type =
decltype(
derived().val_.bottomRows(n));
476 using inner_type =
decltype(
derived().val_.bottomRows(n));
486 inline auto middleRows(Eigen::Index start_row, Eigen::Index n)
const {
487 using inner_type =
decltype(
derived().val_.middleRows(start_row, n));
489 derived().adj_.middleRows(start_row, n));
491 inline auto middleRows(Eigen::Index start_row, Eigen::Index n) {
492 using inner_type =
decltype(
derived().val_.middleRows(start_row, n));
494 derived().adj_.middleRows(start_row, n));
502 using inner_type =
decltype(
derived().val_.leftCols(n));
507 using inner_type =
decltype(
derived().val_.leftCols(n));
517 using inner_type =
decltype(
derived().val_.rightCols(n));
522 using inner_type =
decltype(
derived().val_.rightCols(n));
532 inline auto middleCols(Eigen::Index start_col, Eigen::Index n)
const {
533 using inner_type =
decltype(
derived().val_.middleCols(start_col, n));
535 derived().adj_.middleCols(start_col, n));
537 inline auto middleCols(Eigen::Index start_col, Eigen::Index n) {
538 using inner_type =
decltype(
derived().val_.middleCols(start_col, n));
540 derived().adj_.middleCols(start_col, n));
547 using inner_type =
decltype(
derived().val_.array());
552 using inner_type =
decltype(
derived().val_.array());
561 using inner_type =
decltype(
derived().val_.matrix());
566 using inner_type =
decltype(
derived().val_.matrix());
574 inline Eigen::Index
rows()
const {
return derived().val_.rows(); }
578 inline Eigen::Index
cols()
const {
return derived().val_.cols(); }
582 inline Eigen::Index
size()
const {
return derived().val_.size(); }
596 static constexpr int RowsAtCompileTime = PlainObject::RowsAtCompileTime;
600 static constexpr int ColsAtCompileTime = PlainObject::ColsAtCompileTime;
604 template <
typename S,
typename K,
607 vari_view(
const S& val,
const K& adj) noexcept : val_(val), adj_(adj) {}
614 inline const auto&
val() const noexcept {
return val_; }
615 inline auto&
val_op() noexcept {
return val_; }
625 inline auto&
adj() noexcept {
return adj_; }
626 inline auto&
adj() const noexcept {
return adj_; }
627 inline auto&
adj_op() noexcept {
return adj_; }
648 T, require_all_t<is_plain_type<T>, is_eigen_dense_base<T>>>> {
660 static constexpr int RowsAtCompileTime = PlainObject::RowsAtCompileTime;
664 static constexpr int ColsAtCompileTime = PlainObject::ColsAtCompileTime;
688 template <
typename S, require_assignable_t<T, S>* =
nullptr>
691 adj_((RowsAtCompileTime == 1 && S::ColsAtCompileTime == 1)
692 || (ColsAtCompileTime == 1 && S::RowsAtCompileTime == 1)
695 (RowsAtCompileTime == 1 && S::ColsAtCompileTime == 1)
696 || (ColsAtCompileTime == 1 && S::RowsAtCompileTime == 1)
718 template <
typename S, require_assignable_t<T, S>* =
nullptr>
721 adj_((RowsAtCompileTime == 1 && S::ColsAtCompileTime == 1)
722 || (ColsAtCompileTime == 1 && S::RowsAtCompileTime == 1)
725 (RowsAtCompileTime == 1 && S::ColsAtCompileTime == 1)
726 || (ColsAtCompileTime == 1 && S::RowsAtCompileTime == 1)
749 template <
typename S,
typename K, require_assignable_t<T, S>* =
nullptr,
750 require_assignable_t<T, K>* =
nullptr>
751 explicit vari_value(
const S& val,
const K& adj) : val_(val), adj_(adj) {
756 template <
typename S, require_not_same_t<T, S>* =
nullptr>
765 inline const auto&
val() const noexcept {
return val_; }
766 inline auto&
val_op() noexcept {
return val_; }
776 inline auto&
adj() noexcept {
return adj_; }
777 inline auto&
adj() const noexcept {
return adj_; }
778 inline auto&
adj_op() noexcept {
return adj_; }
806 return os <<
"val: \n" << v->val_ <<
" \nadj: \n" << v->adj_;
810 template <
typename,
typename>
834 static constexpr int RowsAtCompileTime = T::RowsAtCompileTime;
838 static constexpr int ColsAtCompileTime = T::ColsAtCompileTime;
864 template <
typename S, require_convertible_t<S&, T>* =
nullptr>
866 : val_(
std::forward<S>(x)),
867 adj_(val_.
rows(), val_.
cols(), val_.nonZeros(), val_.outerIndexPtr(),
868 val_.innerIndexPtr(),
870 val_.innerNonZeroPtr()) {
876 : val_(val), adj_(adj) {
897 template <
typename S, require_convertible_t<S&, T>* =
nullptr>
899 : val_(
std::forward<S>(x)),
900 adj_(val_.
rows(), val_.
cols(), val_.nonZeros(), val_.outerIndexPtr(),
901 val_.innerIndexPtr(),
903 val_.innerNonZeroPtr()) {
914 Eigen::Index
rows()
const {
return val_.rows(); }
918 Eigen::Index
cols()
const {
return val_.cols(); }
922 Eigen::Index
size()
const {
return val_.size(); }
929 inline const auto&
val()
const {
return val_; }
940 inline auto&
adj() {
return adj_; }
941 inline auto&
adj()
const {
return adj_; }
952 std::fill(adj_.valuePtr(), adj_.valuePtr() + adj_.nonZeros(), 1.0);
961 std::fill(adj_.valuePtr(), adj_.valuePtr() + adj_.nonZeros(), 0.0);
974 return os <<
"val: \n" << v->val_ <<
" \nadj: \n" << v->adj_;
978 template <
typename,
typename>
Equivalent to Eigen::Matrix, except that the data is stored on AD stack.
void * alloc(size_t len)
Return a newly allocated block of memory of the appropriate size managed by the stack allocator.
virtual void chain()=0
Apply the chain rule to this variable based on the variables on which it depends.
virtual void set_zero_adjoint()=0
Abstract base class that all vari_value and it's derived classes inherit.
const auto & val() const noexcept
Return a constant reference to the value of this vari.
auto & adj() const noexcept
vari_value(const vari_value< S > *x)
arena_matrix< PlainObject > val_
The value of this variable.
friend std::ostream & operator<<(std::ostream &os, const vari_value< T > *v)
Insertion operator for vari.
vari_value(const S &val, const K &adj)
Construct a dense Eigen variable implementation from a preconstructed values and adjoints.
void init_dependent()
Initialize the adjoint for this (dependent) variable to 1.
arena_matrix< PlainObject > adj_
The adjoint of this variable, which is the partial derivative of this variable with respect to the ro...
plain_type_t< T > PlainObject
PlainObject represents a user constructible type such as Matrix or Array
virtual void chain()
Apply the chain rule to this variable based on the variables on which it depends.
vari_value(const S &x)
Construct a dense Eigen variable implementation from a value.
void set_zero_adjoint() final
Set the adjoint value of this variable to 0.
auto & adj() noexcept
Return a reference to the derivative of the root expression with respect to this expression.
vari_value(const S &x, bool stacked)
Construct a dense Eigen variable implementation from a value.
value_type_t< PlainObject > eigen_scalar
Eigen::Index rows() const
Return the number of rows for this class's val_ member.
Eigen::Index size() const
Return the size of this class's val_ member.
arena_matrix< PlainObject > val_
The value of this variable.
void init_dependent()
Initialize the adjoint for this (dependent) variable to 1.
auto & adj()
Return a reference to the derivative of the root expression with respect to this expression.
vari_value(const arena_matrix< PlainObject > &val, const arena_matrix< PlainObject > &adj)
vari_value(S &&x, bool stacked)
Construct an sparse Eigen variable implementation from a value.
Eigen::Index cols() const
Return the number of columns for this class's val_ member.
arena_matrix< PlainObject > adj_
The adjoint of this variable, which is the partial derivative of this variable with respect to the ro...
vari_value(S &&x)
Construct a variable implementation from a value.
const auto & val() const
Return a constant reference to the value of this vari.
typename arena_matrix< PlainObject >::InnerIterator InnerIterator
friend std::ostream & operator<<(std::ostream &os, const vari_value< T > *v)
Insertion operator for vari.
plain_type_t< T > PlainObject
void chain()
Apply the chain rule to this variable based on the variables on which it depends.
void set_zero_adjoint() noexcept final
Set the adjoint value of this variable to 0.
void init_dependent() noexcept
Initialize the adjoint for this (dependent) variable to 1.
void set_zero_adjoint() noexcept final
Set the adjoint value of this variable to 0.
const auto & val() const
Return a constant reference to the value of this vari.
void chain()
Apply the chain rule to this variable based on the variables on which it depends.
auto & adj() const
Return a reference of the derivative of the root expression with respect to this expression.
vari_value(S x) noexcept
Construct a variable implementation from a value.
friend std::ostream & operator<<(std::ostream &os, const vari_value< T > *v)
Insertion operator for vari.
const value_type val_
The value of this variable.
auto & adj()
Return a reference to the derivative of the root expression with respect to this expression.
std::decay_t< T > value_type
vari_value(S x, bool stacked) noexcept
Construct a variable implementation from a value.
std::decay_t< T > value_type
plain_type_t< T > PlainObject
auto & adj() const noexcept
auto & adj() noexcept
Return a reference to the derivative of the root expression with respect to this expression.
void chain()
Apply the chain rule to this variable based on the variables on which it depends.
vari_view(const S &val, const K &adj) noexcept
const auto & val() const noexcept
Return a constant reference to the value of this vari.
auto row(Eigen::Index i) const
View row of eigen matrices.
auto rowwise_reverse() const
Return an expression that operates on the rows of the matrix vari
auto middleRows(Eigen::Index start_row, Eigen::Index n) const
Return a block consisting of rows in the middle.
auto col(Eigen::Index i) const
View column of eigen matrices.
auto middleRows(Eigen::Index start_row, Eigen::Index n)
auto array() const
Return an Array expression.
auto diagonal() const
View diagonal of eigen matrices.
auto segment(Eigen::Index i, Eigen::Index n)
auto colwise_reverse() const
Return an expression that operates on the columns of the matrix vari
Eigen::Index rows() const
Return the number of rows for this class's val_ member.
auto topRows(Eigen::Index n)
const Derived & derived() const
Helper function to return a constant reference to the derived type.
auto tail(Eigen::Index n)
auto matrix() const
Return a Matrix expression.
auto coeff(Eigen::Index i) const
Get coefficient of eigen matrices.
auto head(Eigen::Index n)
auto bottomRows(Eigen::Index n)
auto block(Eigen::Index start_row, Eigen::Index start_col, Eigen::Index num_rows, Eigen::Index num_cols) const
A block view of the underlying Eigen matrices.
auto coeff(Eigen::Index i, Eigen::Index j) const
Get coefficient of eigen matrices.
auto transpose() const
View transpose of eigen matrix.
auto operator()(Eigen::Index i)
auto tail(Eigen::Index n) const
View of the tail of the Eigen vector types.
auto topRows(Eigen::Index n) const
Return a block consisting of the top rows.
auto head(Eigen::Index n) const
View of the head of Eigen vector types.
auto segment(Eigen::Index i, Eigen::Index n) const
View block of N elements starting at position i
auto rightCols(Eigen::Index n)
auto block(Eigen::Index start_row, Eigen::Index start_col, Eigen::Index num_rows, Eigen::Index num_cols)
auto leftCols(Eigen::Index n)
Eigen::Index size() const
Return the size of this class's val_ member.
auto bottomRows(Eigen::Index n) const
Return a block consisting of the bottom rows.
auto reverse() const
Return an expression to reverse the order of the coefficients inside of a vari matrix.
Eigen::Index cols() const
Return the number of columns for this class's val_ member.
auto operator()(Eigen::Index i, Eigen::Index j)
auto coeff(Eigen::Index i)
auto operator()(Eigen::Index i) const
Get coefficient of eigen matrices.
auto middleCols(Eigen::Index start_col, Eigen::Index n)
auto coeff(Eigen::Index i, Eigen::Index j)
vari_view_eigen()=default
Making the base constructor private while making the derived class a friend help's catch if derived t...
auto leftCols(Eigen::Index n) const
Return a block consisting of the left-most columns.
auto middleCols(Eigen::Index start_col, Eigen::Index n) const
Return a block consisting of columns in the middle.
auto operator()(Eigen::Index i, Eigen::Index j) const
Get coefficient of eigen matrices.
auto rightCols(Eigen::Index n) const
Return a block consisting of the right-most columns.
Derived & derived()
Helper function to return a reference to the derived type.
This struct is follows the CRTP for methods common to vari_view<> and vari_value<Matrix>.
A vari_view is used to read from a slice of a vari_value with an inner eigen type.
require_t< std::is_assignable< std::decay_t< T >, std::decay_t< S > > > require_assignable_t
Require types T and S satisfies std::is_assignable.
require_t< is_eigen_sparse_base< std::decay_t< T > > > require_eigen_sparse_base_t
Require type satisfies is_eigen_sparse_base.
int64_t cols(const T_x &x)
Returns the number of columns in the specified kernel generator expression.
int64_t rows(const T_x &x)
Returns the number of rows in the specified kernel generator expression.
std::is_same< std::decay_t< S >, plain_type_t< S > > is_plain_type
Checks whether the template type T is an assignable type.
typename value_type< T >::type value_type_t
Helper function for accessing underlying type.
std::integral_constant< bool, B > bool_constant
Alias for structs used for wraps a static constant of bool.
(Expert) Numerical traits for algorithmic differentiation variables.
typename plain_type< T >::type plain_type_t
std::enable_if_t< Check::value > require_t
If condition is true, template is enabled.
std::enable_if_t< math::conjunction< Checks... >::value > require_all_t
If all conditions are true, template is enabled Returns a type void if all conditions are true and ot...
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Checks whether type T is derived from Eigen::DenseBase.
Check if type derives from EigenBase
std::vector< ChainableT * > var_stack_
std::vector< ChainableT * > var_nochain_stack_
static thread_local AutodiffStackStorage * instance_