Automatic Differentiation
 
Loading...
Searching...
No Matches
inv_inc_beta.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_INV_INC_BETA_HPP
2#define STAN_MATH_REV_FUN_INV_INC_BETA_HPP
3
17
18namespace stan {
19namespace math {
20
54template <typename T1, typename T2, typename T3,
55 require_all_stan_scalar_t<T1, T2, T3>* = nullptr,
56 require_any_var_t<T1, T2, T3>* = nullptr>
57inline var inv_inc_beta(const T1& a, const T2& b, const T3& p) {
58 double a_val = value_of(a);
59 double b_val = value_of(b);
60 double p_val = value_of(p);
61 double w = inv_inc_beta(a_val, b_val, p_val);
62 return make_callback_var(w, [a, b, p, a_val, b_val, w](auto& vi) {
63 double log_w = log(w);
64 double log1m_w = log1m(w);
65 double one_m_a = 1 - a_val;
66 double one_m_b = 1 - b_val;
67 double one_m_w = 1 - w;
68 double ap1 = a_val + 1;
69 double bp1 = b_val + 1;
70 double lbeta_ab = lbeta(a_val, b_val);
71 double digamma_apb = digamma(a_val + b_val);
72
74 double da1 = exp(one_m_b * log1m_w + one_m_a * log_w);
75 double da2
76 = a_val * log_w + 2 * lgamma(a_val)
77 + log(hypergeometric_3F2({a_val, a_val, one_m_b}, {ap1, ap1}, w))
78 - 2 * lgamma(ap1);
79 double da3 = inc_beta(a_val, b_val, w) * exp(lbeta_ab)
80 * (log_w - digamma(a_val) + digamma_apb);
81
82 forward_as<var>(a).adj() += vi.adj() * da1 * (exp(da2) - da3);
83 }
84
86 double db1 = (w - 1) * exp(-b_val * log1m_w + one_m_a * log_w);
87 double db2 = 2 * lgamma(b_val)
88 + log(hypergeometric_3F2({b_val, b_val, one_m_a}, {bp1, bp1},
89 one_m_w))
90 - 2 * lgamma(bp1) + b_val * log1m_w;
91
92 double db3 = inc_beta(b_val, a_val, one_m_w) * exp(lbeta_ab)
93 * (log1m_w - digamma(b_val) + digamma_apb);
94
95 forward_as<var>(b).adj() += vi.adj() * db1 * (exp(db2) - db3);
96 }
97
99 forward_as<var>(p).adj()
100 += vi.adj() * exp(one_m_b * log1m_w + one_m_a * log_w + lbeta_ab);
101 }
102 });
103}
104
105} // namespace math
106} // namespace stan
107#endif
auto hypergeometric_3F2(const Ta &a, const Tb &b, const Tz &z)
Hypergeometric function (3F2).
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback functor.
fvar< T > lbeta(const fvar< T > &x1, const fvar< T > &x2)
Definition lbeta.hpp:14
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
fvar< T > log(const fvar< T > &x)
Definition log.hpp:15
fvar< partials_return_t< T1, T2, T3 > > inv_inc_beta(const T1 &a, const T2 &b, const T3 &p)
The inverse of the normalized incomplete beta function of a, b, with probability p.
fvar< T > inc_beta(const fvar< T > &a, const fvar< T > &b, const fvar< T > &x)
Definition inc_beta.hpp:19
fvar< T > lgamma(const fvar< T > &x)
Return the natural logarithm of the gamma function applied to the specified argument.
Definition lgamma.hpp:21
fvar< T > log1m(const fvar< T > &x)
Definition log1m.hpp:12
fvar< T > digamma(const fvar< T > &x)
Return the derivative of the log gamma function at the specified argument.
Definition digamma.hpp:23
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:13
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Extends std::true_type when instantiated with zero or more template parameters, all of which extend std::true_type; extends std::false_type otherwise.