Automatic Differentiation
 
Loading...
Searching...
No Matches
ordered_constrain.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_CONSTRAINT_ORDERED_CONSTRAIN_HPP
2#define STAN_MATH_REV_CONSTRAINT_ORDERED_CONSTRAIN_HPP
3
9#include <cmath>
10#include <tuple>
11#include <vector>
12
13namespace stan {
14namespace math {
15
24template <typename T, require_rev_col_vector_t<T>* = nullptr>
25inline auto ordered_constrain(const T& x) {
26 using ret_type = plain_type_t<T>;
27
28 using std::exp;
29
30 size_t N = x.size();
31 if (unlikely(N == 0)) {
32 return ret_type(x);
33 }
34
35 Eigen::VectorXd y_val(N);
36 arena_t<T> arena_x = x;
37 arena_t<Eigen::VectorXd> exp_x(N - 1);
38
39 y_val.coeffRef(0) = arena_x.val().coeff(0);
40 for (Eigen::Index n = 1; n < N; ++n) {
41 exp_x.coeffRef(n - 1) = exp(arena_x.val().coeff(n));
42 y_val.coeffRef(n) = y_val.coeff(n - 1) + exp_x.coeff(n - 1);
43 }
44
45 arena_t<ret_type> y = y_val;
46
47 reverse_pass_callback([arena_x, y, exp_x]() mutable {
48 double rolling_adjoint_sum = 0.0;
49
50 for (int n = arena_x.size() - 1; n > 0; --n) {
51 rolling_adjoint_sum += y.adj().coeff(n);
52 arena_x.adj().coeffRef(n) += exp_x.coeff(n - 1) * rolling_adjoint_sum;
53 }
54 arena_x.adj().coeffRef(0) += rolling_adjoint_sum + y.adj().coeff(0);
55 });
56
57 return ret_type(y);
58}
59
72template <typename VarVec, require_var_col_vector_t<VarVec>* = nullptr>
73auto ordered_constrain(const VarVec& x, scalar_type_t<VarVec>& lp) {
74 if (x.size() > 1) {
75 lp += sum(x.tail(x.size() - 1));
76 }
77 return ordered_constrain(x);
78}
79
80} // namespace math
81} // namespace stan
82#endif
#define unlikely(x)
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
fvar< T > sum(const std::vector< fvar< T > > &m)
Return the sum of the entries of the specified standard vector.
Definition sum.hpp:22
plain_type_t< EigVec > ordered_constrain(const EigVec &x)
Return an increasing ordered vector derived from the specified free vector.
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:13
typename plain_type< T >::type plain_type_t
typename scalar_type< T >::type scalar_type_t
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...