1#ifndef STAN_MATH_REV_CORE_VAR_HPP
2#define STAN_MATH_REV_CORE_VAR_HPP
22template <
typename Vari>
81 template <
typename S, require_convertible_t<S&, value_type>* =
nullptr>
95 inline const auto&
val() const noexcept {
return vi_->val(); }
105 inline auto&
adj() const noexcept {
return vi_->adj(); }
115 inline auto&
adj() noexcept {
return vi_->adj_; }
135 for (
size_t i = 0; i < x.size(); ++i) {
136 g[i] = x[i].vi_->adj_;
291 if (v.vi_ ==
nullptr) {
292 return os <<
"uninitialized";
294 return os << v.val();
320class var_value<T, internal::require_matrix_var_value<T>> {
323 using vari_type = std::conditional_t<is_plain_type<value_type>::value,
326 static constexpr int RowsAtCompileTime{vari_type::RowsAtCompileTime};
327 static constexpr int ColsAtCompileTime{vari_type::ColsAtCompileTime};
365 template <
typename S, require_assignable_t<value_type, S>* =
nullptr>
375 template <
typename S,
392 template <
typename S,
399 other_vi->adj_ += this_vi->adj_;
410 template <
typename S,
typename T_ = T,
416 [this_vi = this->vi_, other_vi = other.vi_]()
mutable {
417 other_vi->adj_ += this_vi->adj_;
428 template <
typename S,
typename T_ = T,
444 inline const auto&
val() const noexcept {
return vi_->val(); }
445 inline auto&
val_op() noexcept {
return vi_->val_op(); }
455 inline auto&
adj() noexcept {
return vi_->adj(); }
456 inline auto&
adj() const noexcept {
return vi_->adj(); }
457 inline auto&
adj_op() noexcept {
return vi_->adj(); }
459 inline Eigen::Index
rows() const noexcept {
return vi_->rows(); }
460 inline Eigen::Index
cols() const noexcept {
return vi_->cols(); }
461 inline Eigen::Index
size() const noexcept {
return vi_->size(); }
528 template <
typename S, require_st_var<S>* =
nullptr>
542 template <
typename S, require_st_arithmetic<S>* =
nullptr>
603 inline auto block(Eigen::Index start_row, Eigen::Index start_col,
604 Eigen::Index num_rows, Eigen::Index num_cols)
const {
606 =
decltype(vi_->block(start_row, start_col, num_rows, num_cols));
609 new vari_sub(vi_->block(start_row, start_col, num_rows, num_cols)));
611 inline auto block(Eigen::Index start_row, Eigen::Index start_col,
612 Eigen::Index num_rows, Eigen::Index num_cols) {
614 =
decltype(vi_->block(start_row, start_col, num_rows, num_cols));
617 new vari_sub(vi_->block(start_row, start_col, num_rows, num_cols)));
624 using vari_sub =
decltype(vi_->transpose());
626 return var_sub(
new vari_sub(vi_->transpose()));
629 using vari_sub =
decltype(vi_->transpose());
631 return var_sub(
new vari_sub(vi_->transpose()));
638 inline auto head(Eigen::Index n)
const {
639 using vari_sub =
decltype(vi_->head(n));
641 return var_sub(
new vari_sub(vi_->head(n)));
643 inline auto head(Eigen::Index n) {
644 using vari_sub =
decltype(vi_->head(n));
646 return var_sub(
new vari_sub(vi_->head(n)));
653 inline auto tail(Eigen::Index n)
const {
654 using vari_sub =
decltype(vi_->tail(n));
656 return var_sub(
new vari_sub(vi_->tail(n)));
658 inline auto tail(Eigen::Index n) {
659 using vari_sub =
decltype(vi_->tail(n));
661 return var_sub(
new vari_sub(vi_->tail(n)));
669 inline auto segment(Eigen::Index i, Eigen::Index n)
const {
670 using vari_sub =
decltype(vi_->segment(i, n));
672 return var_sub(
new vari_sub(vi_->segment(i, n)));
674 inline auto segment(Eigen::Index i, Eigen::Index n) {
675 using vari_sub =
decltype(vi_->segment(i, n));
677 return var_sub(
new vari_sub(vi_->segment(i, n)));
684 inline auto row(Eigen::Index i)
const {
685 using vari_sub =
decltype(vi_->row(i));
687 return var_sub(
new vari_sub(vi_->row(i)));
689 inline auto row(Eigen::Index i) {
690 using vari_sub =
decltype(vi_->row(i));
692 return var_sub(
new vari_sub(vi_->row(i)));
699 inline auto col(Eigen::Index i)
const {
700 using vari_sub =
decltype(vi_->col(i));
702 return var_sub(
new vari_sub(vi_->col(i)));
704 inline auto col(Eigen::Index i) {
705 using vari_sub =
decltype(vi_->col(i));
707 return var_sub(
new vari_sub(vi_->col(i)));
715 using vari_sub =
decltype(vi_->diagonal());
717 return var_sub(
new vari_sub(vi_->diagonal()));
720 using vari_sub =
decltype(vi_->diagonal());
722 return var_sub(
new vari_sub(vi_->diagonal()));
729 using vari_sub =
decltype(vi_->as_column_vector_or_scalar());
731 return var_sub(
new vari_sub(vi_->as_column_vector_or_scalar()));
734 using vari_sub =
decltype(vi_->as_column_vector_or_scalar());
736 return var_sub(
new vari_sub(vi_->as_column_vector_or_scalar()));
746 inline auto coeff(Eigen::Index i)
const {
747 using vari_sub =
decltype(vi_->coeff(i));
748 vari_sub* vari_coeff =
new vari_sub(vi_->coeff(i));
750 this_vi->adj_.coeffRef(i) += vari_coeff->adj_;
755 using vari_sub =
decltype(vi_->coeff(i));
756 vari_sub* vari_coeff =
new vari_sub(vi_->coeff(i));
758 this_vi->adj_.coeffRef(i) += vari_coeff->adj_;
771 inline auto coeff(Eigen::Index i, Eigen::Index j)
const {
772 using vari_sub =
decltype(vi_->coeff(i, j));
773 vari_sub* vari_coeff =
new vari_sub(vi_->coeff(i, j));
775 this_vi->adj_.coeffRef(i, j) += vari_coeff->adj_;
779 inline auto coeff(Eigen::Index i, Eigen::Index j) {
780 using vari_sub =
decltype(vi_->coeff(i, j));
781 vari_sub* vari_coeff =
new vari_sub(vi_->coeff(i, j));
783 this_vi->adj_.coeffRef(i, j) += vari_coeff->adj_;
795 inline auto operator()(Eigen::Index i)
const {
return this->coeff(i); }
796 inline auto operator()(Eigen::Index i) {
return this->coeff(i); }
806 inline auto operator()(Eigen::Index i, Eigen::Index j)
const {
807 return this->coeff(i, j);
810 return this->coeff(i, j);
820 inline auto coeffRef(Eigen::Index i)
const {
return this->coeff(i); }
821 inline auto coeffRef(Eigen::Index i) {
return this->coeff(i); }
831 inline auto coeffRef(Eigen::Index i, Eigen::Index j)
const {
832 return this->coeff(i, j);
834 inline auto coeffRef(Eigen::Index i, Eigen::Index j) {
835 return this->coeff(i, j);
842 using vari_sub =
decltype(vi_->rowwise_reverse());
844 return var_sub(
new vari_sub(vi_->rowwise_reverse()));
847 using vari_sub =
decltype(vi_->rowwise_reverse());
849 return var_sub(
new vari_sub(vi_->rowwise_reverse()));
856 using vari_sub =
decltype(vi_->colwise_reverse());
858 return var_sub(
new vari_sub(vi_->colwise_reverse()));
861 using vari_sub =
decltype(vi_->colwise_reverse());
863 return var_sub(
new vari_sub(vi_->colwise_reverse()));
871 using vari_sub =
decltype(vi_->reverse());
873 return var_sub(
new vari_sub(vi_->reverse()));
876 using vari_sub =
decltype(vi_->reverse());
878 return var_sub(
new vari_sub(vi_->reverse()));
886 using vari_sub =
decltype(vi_->topRows(n));
888 return var_sub(
new vari_sub(vi_->topRows(n)));
891 using vari_sub =
decltype(vi_->topRows(n));
893 return var_sub(
new vari_sub(vi_->topRows(n)));
901 using vari_sub =
decltype(vi_->bottomRows(n));
903 return var_sub(
new vari_sub(vi_->bottomRows(n)));
906 using vari_sub =
decltype(vi_->bottomRows(n));
908 return var_sub(
new vari_sub(vi_->bottomRows(n)));
916 inline auto middleRows(Eigen::Index start_row, Eigen::Index n)
const {
917 using vari_sub =
decltype(vi_->middleRows(start_row, n));
919 return var_sub(
new vari_sub(vi_->middleRows(start_row, n)));
921 inline auto middleRows(Eigen::Index start_row, Eigen::Index n) {
922 using vari_sub =
decltype(vi_->middleRows(start_row, n));
924 return var_sub(
new vari_sub(vi_->middleRows(start_row, n)));
932 using vari_sub =
decltype(vi_->leftCols(n));
934 return var_sub(
new vari_sub(vi_->leftCols(n)));
937 using vari_sub =
decltype(vi_->leftCols(n));
939 return var_sub(
new vari_sub(vi_->leftCols(n)));
947 using vari_sub =
decltype(vi_->rightCols(n));
949 return var_sub(
new vari_sub(vi_->rightCols(n)));
952 using vari_sub =
decltype(vi_->rightCols(n));
954 return var_sub(
new vari_sub(vi_->rightCols(n)));
962 inline auto middleCols(Eigen::Index start_col, Eigen::Index n)
const {
963 using vari_sub =
decltype(vi_->middleCols(start_col, n));
965 return var_sub(
new vari_sub(vi_->middleCols(start_col, n)));
967 inline auto middleCols(Eigen::Index start_col, Eigen::Index n) {
968 using vari_sub =
decltype(vi_->middleCols(start_col, n));
970 return var_sub(
new vari_sub(vi_->middleCols(start_col, n)));
977 using vari_sub =
decltype(vi_->array());
979 return var_sub(
new vari_sub(vi_->array()));
982 using vari_sub =
decltype(vi_->array());
984 return var_sub(
new vari_sub(vi_->array()));
991 using vari_sub =
decltype(vi_->matrix());
993 return var_sub(
new vari_sub(vi_->matrix()));
996 using vari_sub =
decltype(vi_->matrix());
998 return var_sub(
new vari_sub(vi_->matrix()));
1010 if (v.vi_ ==
nullptr) {
1011 return os <<
"uninitialized";
1013 return os << v.val();
1020 template <
typename U = T,
1022 inline auto rows() const noexcept {
1030 template <
typename U = T,
1032 inline auto cols() const noexcept {
1043 template <
typename S, require_assignable_t<value_type, S>* =
nullptr,
1044 require_all_plain_type_t<T, S>* =
nullptr,
1045 require_same_t<plain_type_t<T>, plain_type_t<S>>* =
nullptr>
1060 template <
typename S,
typename T_ = T,
1066 EIGEN_PREDICATE_SAME_MATRIX_SIZE(T, S),
1067 "You mixed matrices of different sizes that are not assignable.");
1079 template <
typename S,
typename T_ = T,
1090 prev_val.deep_copy(vi_->val_);
1091 vi_->val_.deep_copy(other.val());
1096 [this_vi = this->vi_, other_vi = other.vi_, prev_val]()
mutable {
1097 this_vi->val_.deep_copy(prev_val);
1103 prev_val.deep_copy(this_vi->adj_);
1104 this_vi->adj_.setZero();
1105 other_vi->adj_ += prev_val;
1119 template <
typename S,
typename T_ = T,
1126 throw std::domain_error(
1127 "var_value<matrix>::operator=(var_value<expression>):"
1128 " Internal Bug! Please report this with an example"
1129 " of your model to the Stan math github repository.");
1133 vi_->val_ = other.val();
1138 [this_vi = this->vi_, other_vi = other.vi_, prev_val]()
mutable {
1139 this_vi->val_ = prev_val;
1145 prev_val = this_vi->adj_;
1146 this_vi->adj_.setZero();
1147 other_vi->adj_ += prev_val;
1155 template <
typename T_ = T, require_plain_type_t<T_>* =
nullptr>
1159 template <
typename T_ = T, require_plain_type_t<T_>* =
nullptr>
1160 inline const auto&
eval() const noexcept {
1167 template <
typename T_ = T, require_not_plain_type_t<T_>* =
nullptr>
1171 template <
typename T_ = T, require_not_plain_type_t<T_>* =
nullptr>
1182 return operator=<T>(other);
1198template <
typename T>
auto reverse() const
Return an expression to reverse the order of the coefficients inside of a vari matrix.
auto segment(Eigen::Index i, Eigen::Index n) const
View a block of n elements starting at position i.
auto matrix() const
Return a Matrix.
auto coeff(Eigen::Index i, Eigen::Index j) const
View element of eigen matrices.
var_value< T > & operator*=(const var_value< T > &b)
The compound multiply/assignment operator for variables (C++).
auto coeff(Eigen::Index i) const
View element of eigen matrices.
var_value< T > & operator*=(T b)
The compound multiply/assignment operator for scalars (C++).
auto as_column_vector_or_scalar() const
View a matrix_cl as a column vector.
var_value(const var_value< S > &other)
Construct from a var_value with different inner vari_type
auto & adj() noexcept
Return a reference to the derivative of the root expression with respect to this expression.
auto operator()(Eigen::Index i)
std::conditional_t< is_plain_type< value_type >::value, vari_value< value_type >, vari_view< T > > vari_type
vari_type * operator->()
Return a pointer to the underlying implementation of this variable.
auto operator()(Eigen::Index i) const
View element of eigen matrices.
auto rows() const noexcept
Returns number of rows.
vari_type * vi_
Pointer to the implementation of this variable.
auto diagonal() const
View diagonal of eigen matrices.
const auto & eval() const noexcept
auto segment(Eigen::Index i, Eigen::Index n)
const auto & val() const noexcept
Return a constant reference to the value of this variable.
auto middleCols(Eigen::Index start_col, Eigen::Index n) const
Return a block consisting of columns in the middle.
var_value(const var_value< S > &other)
Construct a var_value with a plain type from another var_value containing an expression.
Eigen::Index size() const noexcept
auto leftCols(Eigen::Index n) const
Return a block consisting of the left-most columns.
auto coeff(Eigen::Index i)
var_value< T > & operator+=(const var_value< T > &b)
The compound add/assignment operator for variables (C++).
auto cols() const noexcept
Returns number of columns.
auto leftCols(Eigen::Index n)
auto array() const
Return an Array.
auto as_column_vector_or_scalar()
auto row(Eigen::Index i) const
View row of eigen matrices.
bool is_uninitialized() noexcept
Return true if this variable has been declared, but not been defined.
auto operator()(Eigen::Index i, Eigen::Index j)
auto eval()
For non-plain types evaluate to the plain type.
auto & adj() const noexcept
auto colwise_reverse() const
Return an expression that operates on the columns of the matrix vari
auto block(Eigen::Index start_row, Eigen::Index start_col, Eigen::Index num_rows, Eigen::Index num_cols)
auto coeffRef(Eigen::Index i)
auto col(Eigen::Index i) const
View column of eigen matrices.
var_value< T > & operator=(const var_value< S > &other)
Assignment of one plain type to another when one side's compile-time columns differ from the other's.
auto bottomRows(Eigen::Index n) const
Return a block consisting of the bottom rows.
auto head(Eigen::Index n) const
View of the head of Eigen vector types.
auto tail(Eigen::Index n)
auto tail(Eigen::Index n) const
View of the tail of the Eigen vector types.
auto middleRows(Eigen::Index start_row, Eigen::Index n) const
Return a block consisting of rows in the middle.
auto coeff(Eigen::Index i, Eigen::Index j)
auto head(Eigen::Index n)
vari_type & operator*()
Return a reference to underlying implementation of this variable.
auto rightCols(Eigen::Index n) const
Return a block consisting of the right-most columns.
var_value(const S &val, const S &adj)
Construct a var_value with premade arena_matrix types.
auto rightCols(Eigen::Index n)
var_value(S &&x)
Construct a variable from the specified floating point argument by constructing a new vari_value<valu...
friend std::ostream & operator<<(std::ostream &os, const var_value< T > &v)
Write the value of this autodiff variable and its adjoint to the specified output stream.
auto middleRows(Eigen::Index start_row, Eigen::Index n)
auto operator()(Eigen::Index i, Eigen::Index j) const
View element of eigen matrices.
auto topRows(Eigen::Index n)
Eigen::Index rows() const noexcept
var_value< T > & operator-=(const S &b)
The compound subtract/assignment operator for variables (C++).
Eigen::Index cols() const noexcept
auto coeffRef(Eigen::Index i, Eigen::Index j)
var_value< T > & operator=(const var_value< S > &other)
Assignment of another plain var value, when this also contains a plain type.
auto bottomRows(Eigen::Index n)
var_value< T > & operator=(const var_value< T > &other)
Copy assignment operator delegates to general assignment operator.
auto block(Eigen::Index start_row, Eigen::Index start_col, Eigen::Index num_rows, Eigen::Index num_cols) const
A block view of the underlying Eigen matrices.
auto & eval() noexcept
No-op to match with Eigen methods which call eval.
var_value(vari_type *vi)
Construct a variable from a pointer to a variable implementation.
var_value()
Construct a variable for later assignment.
var_value(const var_value< S > &other)
Copy constructor for var_val when the vari_type from other is directly assignable.
auto coeffRef(Eigen::Index i) const
View element of eigen matrices.
auto coeffRef(Eigen::Index i, Eigen::Index j) const
View element of eigen matrices.
auto rowwise_reverse() const
Return an expression that operates on the rows of the matrix vari
var_value< T > & operator+=(T b)
The compound add/assignment operator for scalars (C++).
auto middleCols(Eigen::Index start_col, Eigen::Index n)
auto topRows(Eigen::Index n) const
Return a block consisting of the top rows.
auto transpose() const
View transpose of eigen matrix.
void grad(std::vector< var_value< T > > &x, std::vector< value_type > &g)
Compute the gradient of this (dependent) variable with respect to the specified vector of (independen...
auto & adj() const noexcept
Return a reference to the derivative of the root expression with respect to this expression.
var_value(vari_type *vi)
Construct a variable from a pointer to a variable implementation.
vari_type * operator->()
Return a pointer to the underlying implementation of this variable.
vari_type * vi_
Pointer to the implementation of this variable.
bool is_uninitialized()
Return true if this variable has been declared, but not been defined.
var_value()
Construct a variable for later assignment.
friend std::ostream & operator<<(std::ostream &os, const var_value< T > &v)
Write the value of this autodiff variable and its adjoint to the specified output stream.
vari_type & operator*()
Return a reference to underlying implementation of this variable.
void grad()
Compute the gradient of this (dependent) variable with respect to all (independent) variables.
std::decay_t< T > value_type
var_value(S x)
Construct a variable from the specified floating point argument by constructing a new vari_value<valu...
const auto & val() const noexcept
Return a constant reference to the value of this variable.
auto & adj() noexcept
Return a reference to the derivative of the root expression with respect to this expression.
A vari_view is used to read from a slice of a vari_value with an inner eigen type.
require_t< is_arena_matrix< std::decay_t< T > > > require_arena_matrix_t
Require type satisfies is_arena_matrix.
require_t< std::is_assignable< std::decay_t< T >, std::decay_t< S > > > require_assignable_t
Require types T and S satisfy std::is_assignable.
require_not_t< std::is_assignable< std::decay_t< T >, std::decay_t< S > > > require_not_assignable_t
Require types T and S do not satisfy std::is_assignable.
require_t< std::is_constructible< std::decay_t< T >, std::decay_t< S > > > require_constructible_t
Require types T and S satisfy std::is_constructible.
require_t< std::is_floating_point< std::decay_t< T > > > require_floating_point_t
Require type satisfies std::is_floating_point.
require_not_t< is_plain_type< std::decay_t< T > > > require_not_plain_type_t
Require type does not satisfy is_plain_type.
require_all_t< is_plain_type< std::decay_t< Types > >... > require_all_plain_type_t
Require all of the types satisfy is_plain_type.
require_t< is_plain_type< std::decay_t< T > > > require_plain_type_t
Require type satisfies is_plain_type.
require_not_t< std::is_same< std::decay_t< T >, std::decay_t< S > > > require_not_same_t
Require types T and S do not satisfy std::is_same.
typename value_type< T >::type value_type_t
Helper function for accessing underlying type.
std::integral_constant< bool, B > bool_constant
Alias for structs that wrap a static constant of bool.
require_t< bool_constant<(is_eigen< T >::value||is_kernel_expression_and_not_scalar< T >::value) &&std::is_floating_point< value_type_t< T > >::value > > require_matrix_var_value
T1 operator+=(T1 &&a, T2 &&b)
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T1 operator-=(T1 &&a, T2 &&b)
T1 operator*=(T1 &&a, T2 &&b)
static void grad()
Compute the gradient for all variables starting from the end of the AD tape.
std::enable_if_t< math::disjunction< Checks... >::value > require_any_t
If any condition is true, template is enabled.
typename plain_type< T >::type plain_type_t
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
std::enable_if_t< Check::value > require_t
If condition is true, template is enabled.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Check if type derives from EigenBase
Determines whether a type is non-scalar type that is a valid kernel generator expression.
Checks if the decayed type of T is a matrix_cl.
Metaprogram structure to determine the base scalar type of a template argument.