#ifndef STAN_MATH_PRIM_FUN_GRAD_2F1_HPP
#define STAN_MATH_PRIM_FUN_GRAD_2F1_HPP
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err/check_2F1_converges.hpp>
#include <stan/math/prim/err/throw_domain_error.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/hypergeometric_2F1.hpp>
#include <stan/math/prim/fun/value_of_rec.hpp>
#include <boost/optional.hpp>
#include <cmath>

namespace stan {
namespace math {
namespace internal {

/**
 * Implementation function to calculate the gradients of the hypergeometric
 * function (2F1) with respect to the a1, a2, and b1 parameters.
 */
template <bool calc_a1, bool calc_a2, bool calc_b1, typename T1, typename T2,
          typename T3, typename T_z,
          typename ScalarT = return_type_t<T1, T2, T3, T_z>,
          typename TupleT = std::tuple<ScalarT, ScalarT, ScalarT>>
TupleT grad_2F1_impl_ab(const T1& a1, const T2& a2, const T3& b1,
                        const T_z& z, double precision = 1e-14,
                        int max_steps = 1e6) {
  TupleT grad_tuple = TupleT(0, 0, 0);

  if (z == 0) {
    return grad_tuple;
  }
  using ScalarArrayT = Eigen::Array<ScalarT, 3, 1>;
  ScalarArrayT log_g_old = ScalarArrayT::Constant(3, 1, NEGATIVE_INFTY);

  ScalarT log_t_old = 0.0;
  ScalarT log_t_new = 0.0;
  int sign_z = sign(z);
  auto log_z = log(fabs(z));

  int log_t_new_sign = 1;
  int log_t_old_sign = 1;

  Eigen::Array<int, 3, 1> log_g_old_sign = Eigen::Array<int, 3, 1>::Ones(3);

  int sign_zk = sign_z;
  int k = 0;
  const int min_steps = 5;
  ScalarT inner_diff = 1;
  ScalarArrayT g_current = ScalarArrayT::Zero(3);

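  // The series is accumulated in log space for numerical stability: with
  // t_k the k-th series term and p_k = (a1 + k)(a2 + k) / ((b1 + k)(1 + k)),
  // the term recursion t_{k+1} = t_k * p_k * z becomes
  //   log|t_{k+1}| = log|t_k| + log|p_k| + log|z|,
  // with signs tracked separately. Differentiating the recursion gives, for
  // the a1 gradient,
  //   (dt_{k+1}/da1) / t_{k+1} = (dt_k/da1) / t_k + 1 / (a1 + k),
  // which is the term_a1 = g_old / t_old + inv(a1 + k) update in the loop
  // below; b1 enters through the denominator, hence the inv(-(b1 + k)) term.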
  while ((inner_diff > precision || k < min_steps) && k < max_steps) {
    ScalarT p = ((a1 + k) * (a2 + k) / ((b1 + k) * (1.0 + k)));
    if (p == 0) {
      return grad_tuple;
    }
    log_t_new += log(fabs(p)) + log_z;
    log_t_new_sign = sign(value_of_rec(p)) * log_t_new_sign;

    if (calc_a1) {
      ScalarT term_a1
          = log_g_old_sign(0) * log_t_old_sign * exp(log_g_old(0) - log_t_old)
            + inv(a1 + k);
      log_g_old(0) = log_t_new + log(abs(term_a1));
      log_g_old_sign(0) = sign(value_of_rec(term_a1)) * log_t_new_sign;
      g_current(0) = log_g_old_sign(0) * exp(log_g_old(0)) * sign_zk;
      std::get<0>(grad_tuple) += g_current(0);
    }

    if (calc_a2) {
      ScalarT term_a2
          = log_g_old_sign(1) * log_t_old_sign * exp(log_g_old(1) - log_t_old)
            + inv(a2 + k);
      log_g_old(1) = log_t_new + log(abs(term_a2));
      log_g_old_sign(1) = sign(value_of_rec(term_a2)) * log_t_new_sign;
      g_current(1) = log_g_old_sign(1) * exp(log_g_old(1)) * sign_zk;
      std::get<1>(grad_tuple) += g_current(1);
    }

    if (calc_b1) {
      ScalarT term_b1
          = log_g_old_sign(2) * log_t_old_sign * exp(log_g_old(2) - log_t_old)
            + inv(-(b1 + k));
      log_g_old(2) = log_t_new + log(abs(term_b1));
      log_g_old_sign(2) = sign(value_of_rec(term_b1)) * log_t_new_sign;
      g_current(2) = log_g_old_sign(2) * exp(log_g_old(2)) * sign_zk;
      std::get<2>(grad_tuple) += g_current(2);
    }

    inner_diff = g_current.array().abs().maxCoeff();

    log_t_old = log_t_new;
    log_t_old_sign = log_t_new_sign;
    sign_zk *= sign_z;
    ++k;
  }

129 " iterations, hypergeometric function gradient "
130 "did not converge.");
/**
 * Implementation function to calculate the gradients of the hypergeometric
 * function (2F1) with respect to the a1, a2, b1, and z parameters.
 */
template <bool calc_a1, bool calc_a2, bool calc_b1, bool calc_z, typename T1,
          typename T2, typename T3, typename T_z,
          typename ScalarT = return_type_t<T1, T2, T3, T_z>,
          typename TupleT = std::tuple<ScalarT, ScalarT, ScalarT, ScalarT>>
TupleT grad_2F1_impl(const T1& a1, const T2& a2, const T3& b1, const T_z& z,
                     double precision = 1e-14, int max_steps = 1e6) {
  bool euler_transform = false;
  try {
    check_2F1_converges("hypergeometric_2F1", a1, a2, b1, z);
  } catch (const std::exception& e) {
    // The series does not converge for the given arguments, so check whether
    // it converges under Euler's transformation instead
    check_2F1_converges("hypergeometric_2F1", b1 - a1, a2, b1, z / (z - 1));
    euler_transform = true;
  }

  std::tuple<ScalarT, ScalarT, ScalarT> grad_tuple_ab;
  TupleT grad_tuple_rtn = TupleT(0, 0, 0, 0);
  if (euler_transform) {
    ScalarT a1_euler = a2;
    ScalarT a2_euler = b1 - a1;
    ScalarT z_euler = z / (z - 1);
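    // Euler's transformation:
    //   2F1(a1, a2; b1; z) = (1 - z)^(-a2) * 2F1(a2, b1 - a1; b1; z / (z - 1))
    // The z gradient below applies the product and chain rules to this
    // identity, using d(z / (z - 1))/dz = 1 / (z - 1) - z / (z - 1)^2 and
    // d/dz 2F1(a, b; c; z) = (a * b / c) * 2F1(a + 1, b + 1; c + 1; z).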
    if (calc_z) {
      ScalarT hyper1 = hypergeometric_2F1(a1_euler, a2_euler, b1, z_euler);
      ScalarT hyper2
          = hypergeometric_2F1(a1_euler + 1, a2_euler + 1, b1 + 1, z_euler);
      std::get<3>(grad_tuple_rtn)
          = a2 * pow(1 - z, -1 - a2) * hyper1
            + (a2 * (b1 - a1) * pow(1 - z, -a2)
               * (inv(z - 1) - z / square(z - 1)) * hyper2)
                  / b1;
    }
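    // For the parameter gradients, write F = (1 - z)^(-a2) * G with
    // G = 2F1(a1_euler, a2_euler; b1; z_euler). By the chain rule:
    //   dF/da1 = -(1 - z)^(-a2) * dG/da2_euler,          (a2_euler = b1 - a1)
    //   dF/da2 = -log(1 - z) * (1 - z)^(-a2) * G
    //            + (1 - z)^(-a2) * dG/da1_euler,         (a1_euler = a2)
    //   dF/db1 = (1 - z)^(-a2) * (dG/da2_euler + dG/db1).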
    if (calc_a1 || calc_a2 || calc_b1) {
      // The transformed arguments mix a1, a2, and b1, so the transformed
      // gradients are computed for a superset of the requested parameters
      constexpr bool calc_a1_euler = calc_a1 || calc_a2;
      constexpr bool calc_a2_euler = calc_a1 || calc_a2 || calc_b1;
      grad_tuple_ab = grad_2F1_impl_ab<calc_a1_euler, calc_a2_euler, calc_b1>(
          a1_euler, a2_euler, b1, z_euler);

      auto pre_mult_ab = inv(pow(1.0 - z, a2));
      if (calc_a1) {
        std::get<0>(grad_tuple_rtn)
            = -pre_mult_ab * std::get<1>(grad_tuple_ab);
      }
      if (calc_a2) {
        auto hyper_da2 = hypergeometric_2F1(a1_euler, a2_euler, b1, z_euler);
        std::get<1>(grad_tuple_rtn)
            = -pre_mult_ab * hyper_da2 * log1m(z)
              + pre_mult_ab * std::get<0>(grad_tuple_ab);
      }
      if (calc_b1) {
        std::get<2>(grad_tuple_rtn)
            = pre_mult_ab
              * (std::get<1>(grad_tuple_ab) + std::get<2>(grad_tuple_ab));
      }
    }
  } else {
    if (calc_z) {
      // d/dz 2F1(a1, a2; b1; z) = (a1 * a2 / b1) * 2F1(a1+1, a2+1; b1+1; z)
      auto hyper_2f1_dz = hypergeometric_2F1(a1 + 1.0, a2 + 1.0, b1 + 1.0, z);
      std::get<3>(grad_tuple_rtn) = (a1 * a2 * hyper_2f1_dz) / b1;
    }
    if (calc_a1 || calc_a2 || calc_b1) {
      grad_tuple_ab
          = grad_2F1_impl_ab<calc_a1, calc_a2, calc_b1>(a1, a2, b1, z);
      if (calc_a1) {
        std::get<0>(grad_tuple_rtn) = std::get<0>(grad_tuple_ab);
      }
      if (calc_a2) {
        std::get<1>(grad_tuple_rtn) = std::get<1>(grad_tuple_ab);
      }
      if (calc_b1) {
        std::get<2>(grad_tuple_rtn) = std::get<2>(grad_tuple_ab);
      }
    }
  }
  return grad_tuple_rtn;
}

}  // namespace internal

/**
 * Calculate the gradients of the hypergeometric function (2F1) as the power
 * series, stopping when the series converges to within `precision` or
 * throwing when the function takes `max_steps` steps.
 *
 * Overload selected when gradients are only needed for the autodiff
 * argument types.
 */
template <bool ReturnSameT, typename T1, typename T2, typename T3,
          typename T_z,
          require_not_t<std::integral_constant<bool, ReturnSameT>>* = nullptr>
inline auto grad_2F1(const T1& a1, const T2& a2, const T3& b1, const T_z& z,
                     double precision = 1e-14, int max_steps = 1e6) {
  return internal::grad_2F1_impl<is_autodiff_v<T1>, is_autodiff_v<T2>,
                                 is_autodiff_v<T3>, is_autodiff_v<T_z>>(
      a1, a2, b1, z, precision, max_steps);
}
/**
 * Overload selected when gradients should be calculated for all four
 * arguments, regardless of their types.
 */
template <bool ReturnSameT, typename T1, typename T2, typename T3,
          typename T_z,
          require_t<std::integral_constant<bool, ReturnSameT>>* = nullptr>
inline auto grad_2F1(const T1& a1, const T2& a2, const T3& b1, const T_z& z,
                     double precision = 1e-14, int max_steps = 1e6) {
  return internal::grad_2F1_impl<true, true, true, true>(a1, a2, b1, z,
                                                         precision, max_steps);
}
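/**
 * A minimal usage sketch for the default overload below (the argument values
 * are arbitrary and purely illustrative). Note that with plain double inputs
 * the ReturnSameT = true overload above must be selected explicitly, since
 * the default overload only computes gradients for autodiff argument types:
 *
 * @code
 * auto grads = stan::math::grad_2F1<true>(1.1, 3.1, 2.3, 0.5);
 * double g_a1 = std::get<0>(grads);  // gradient wrt a1
 * double g_a2 = std::get<1>(grads);  // gradient wrt a2
 * double g_b1 = std::get<2>(grads);  // gradient wrt b1
 * double g_z = std::get<3>(grads);   // gradient wrt z
 * @endcode
 */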
template <typename T1, typename T2, typename T3, typename T_z>
inline auto grad_2F1(const T1& a1, const T2& a2, const T3& b1, const T_z& z,
                     double precision = 1e-14, int max_steps = 1e6) {
  return grad_2F1<false>(a1, a2, b1, z, precision, max_steps);
}

}  // namespace math
}  // namespace stan

#endif