Automatic Differentiation
 
Loading...
Searching...
No Matches
inv_inc_beta.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_FWD_FUN_INV_INC_BETA_HPP
2#define STAN_MATH_FWD_FUN_INV_INC_BETA_HPP
3
15
16namespace stan {
17namespace math {
18
32template <typename T1, typename T2, typename T3,
33 require_all_stan_scalar_t<T1, T2, T3>* = nullptr,
34 require_any_fvar_t<T1, T2, T3>* = nullptr>
36 const T2& b,
37 const T3& p) {
38 using T_return = partials_return_t<T1, T2, T3>;
39 auto a_val = value_of(a);
40 auto b_val = value_of(b);
41 auto p_val = value_of(p);
42 T_return w = inv_inc_beta(a_val, b_val, p_val);
43 T_return log_w = log(w);
44 T_return log1m_w = log1m(w);
45 auto one_m_a = 1 - a_val;
46 auto one_m_b = 1 - b_val;
47 T_return one_m_w = 1 - w;
48 auto ap1 = a_val + 1;
49 auto bp1 = b_val + 1;
50 auto lbeta_ab = lbeta(a_val, b_val);
51 auto digamma_apb = digamma(a_val + b_val);
52
53 T_return inv_d_(0);
54
56 std::vector<T_return> da_a{a_val, a_val, one_m_b};
57 std::vector<T_return> da_b{ap1, ap1};
58 auto da1 = exp(one_m_b * log1m_w + one_m_a * log_w);
59 auto da2 = exp(a_val * log_w + 2 * lgamma(a_val)
60 + log(hypergeometric_3F2(da_a, da_b, w)) - 2 * lgamma(ap1));
61 auto da3 = inc_beta(a_val, b_val, w) * exp(lbeta_ab)
62 * (log_w - digamma(a_val) + digamma_apb);
63 inv_d_ += forward_as<fvar<T_return>>(a).d_ * da1 * (da2 - da3);
64 }
65
67 std::vector<T_return> db_a{b_val, b_val, one_m_a};
68 std::vector<T_return> db_b{bp1, bp1};
69 auto db1 = (w - 1) * exp(-b_val * log1m_w + one_m_a * log_w);
70 auto db2 = 2 * lgamma(b_val) + log(hypergeometric_3F2(db_a, db_b, one_m_w))
71 - 2 * lgamma(bp1) + b_val * log1m_w;
72
73 auto db3 = inc_beta(b_val, a_val, one_m_w) * exp(lbeta_ab)
74 * (log1m_w - digamma(b_val) + digamma_apb);
75
76 inv_d_ += forward_as<fvar<T_return>>(b).d_ * db1 * (exp(db2) - db3);
77 }
78
80 inv_d_ += forward_as<fvar<T_return>>(p).d_
81 * exp(one_m_b * log1m_w + one_m_a * log_w + lbeta_ab);
82 }
83
84 return fvar<T_return>(w, inv_d_);
85}
86
87} // namespace math
88} // namespace stan
89#endif
auto hypergeometric_3F2(const Ta &a, const Tb &b, const Tz &z)
Hypergeometric function (3F2).
fvar< T > lbeta(const fvar< T > &x1, const fvar< T > &x2)
Definition lbeta.hpp:14
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
fvar< T > log(const fvar< T > &x)
Definition log.hpp:15
fvar< partials_return_t< T1, T2, T3 > > inv_inc_beta(const T1 &a, const T2 &b, const T3 &p)
The inverse of the normalized incomplete beta function of a, b, with probability p.
fvar< T > inc_beta(const fvar< T > &a, const fvar< T > &b, const fvar< T > &x)
Definition inc_beta.hpp:19
fvar< T > lgamma(const fvar< T > &x)
Return the natural logarithm of the gamma function applied to the specified argument.
Definition lgamma.hpp:21
fvar< T > log1m(const fvar< T > &x)
Definition log1m.hpp:12
fvar< T > digamma(const fvar< T > &x)
Return the derivative of the log gamma function at the specified argument.
Definition digamma.hpp:23
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:13
typename partials_return_type< Args... >::type partials_return_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Defines a static member function type which is defined to be false as the primitive scalar types cann...
Definition is_fvar.hpp:15
This template class represents scalars used in forward-mode automatic differentiation,...
Definition fvar.hpp:40