1#ifndef STAN_MATH_PRIM_FUN_GRAD_F32_HPP
2#define STAN_MATH_PRIM_FUN_GRAD_F32_HPP
/**
 * Gradients of the hypergeometric function, 3F2.
 *
 * The visible code walks a series index k from 0 to max_steps and, for each
 * enabled argument, adds one signed contribution per step into g. Magnitudes
 * are carried as logarithms (log_t_new, log_g_old) and signs are carried
 * separately (log_t_new_sign, log_g_old_sign).
 *
 * NOTE(review): this listing has gaps. Several statements appear without
 * their beginning or end (for example the right-hand sides that start with
 * `=` below), the bodies of the two initialisation loops are not shown, and
 * no closing braces are visible. Comments here describe only what is
 * visible; confirm the rest against the complete file.
 *
 * @tparam grad_a1 when true, the a1 contribution is accumulated into g[0]
 * @tparam grad_a2 when true, the a2 contribution is accumulated into g[1]
 * @tparam grad_a3 when true, the a3 contribution is accumulated into g[2]
 * @tparam grad_b1 when true, the b1 contribution is accumulated into g[3]
 * @tparam grad_b2 when true, the b2 contribution is accumulated into g[4]
 * @tparam grad_z when true, the z contribution is accumulated into g[5]
 * @tparam T1 element type of g and of the internal working values
 * @tparam T2 type of a1
 * @tparam T3 type of a2
 * @tparam T4 type of a3
 * @tparam T5 type of b1
 * @tparam T6 type of b2
 * @tparam T7 type of z
 * @tparam T8 type of precision (defaults to double)
 *
 * @param[out] g output buffer; indexed at positions 0 through 5
 * @param[in] a1 first numerator argument of the per-step ratio
 * @param[in] a2 second numerator argument of the per-step ratio
 * @param[in] a3 third numerator argument of the per-step ratio
 * @param[in] b1 first denominator argument of the per-step ratio
 * @param[in] b2 second denominator argument of the per-step ratio
 * @param[in] z series argument (its use is only partly visible here;
 *   presumably log_z is log(z) -- TODO confirm)
 * @param[in] precision threshold; log_t_new is compared against
 *   log(precision) after every step
 * @param[in] max_steps inclusive upper bound on the series index k
 */
51template <
bool grad_a1 =
true,
bool grad_a2 =
true,
bool grad_a3 =
true,
52 bool grad_b1 =
true,
bool grad_b2 =
true,
bool grad_z =
true,
53 typename T1,
typename T2,
typename T3,
typename T4,
typename T5,
54 typename T6,
typename T7,
typename T8 =
double>
55void grad_F32(T1* g,
const T2& a1,
const T3& a2,
const T4& a3,
const T5& b1,
56 const T6& b2,
const T7& z,
const T8& precision = 1
e-6,
57 int max_steps = 1e5) {
// Loop over the six gradient slots; the loop body is not visible in this
// listing (presumably it initialises g -- TODO confirm).
60 for (
int i = 0; i < 6; ++i) {
// Loop over every entry of log_g_old; its declaration and the loop body are
// not visible in this listing.
65 for (
auto& x : log_g_old) {
// Signs of the current and previous series term, both starting positive.
74 T1 log_t_new_sign = 1.0;
75 T1 log_t_old_sign = 1.0;
// Every per-gradient sign also starts positive.
77 for (
int i = 0; i < 6; ++i) {
78 log_g_old_sign[i] = 1.0;
// One scratch value per gradient slot, reused on every step.
80 std::array<T1, 6> term{0};
81 for (
int k = 0; k <= max_steps; ++k) {
// Factor that advances the series term from step k to step k + 1, without
// the contribution of z (log_z is added separately below).
82 T1 p = (a1 + k) * (a2 + k) * (a3 + k) / ((b1 + k) * (b2 + k) * (1 + k));
// Accumulate the log-magnitude of the term; the sign of p is folded into
// log_t_new_sign because log(fabs(p)) discards it.
87 log_t_new +=
log(
fabs(p)) + log_z;
88 log_t_new_sign = p >= 0.0 ? log_t_new_sign : -log_t_new_sign;
// Slot 0 (a1). The statement starting with `=` is the visible part of an
// assignment whose left-hand side and trailing operand are missing from this
// listing (presumably it sets term[0] -- TODO confirm). The visible factor
// rescales the previous gradient contribution by the previous term.
89 if constexpr (grad_a1) {
91 = log_g_old_sign[0] * log_t_old_sign *
exp(log_g_old[0] - log_t_old)
// Store log-magnitude and sign of the new contribution, then add it to g[0].
93 log_g_old[0] = log_t_new +
log(
fabs(term[0]));
94 log_g_old_sign[0] = term[0] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
95 g[0] += log_g_old_sign[0] *
exp(log_g_old[0]);
// Slot 1 (a2): same visible pattern as slot 0.
98 if constexpr (grad_a2) {
100 = log_g_old_sign[1] * log_t_old_sign *
exp(log_g_old[1] - log_t_old)
102 log_g_old[1] = log_t_new +
log(
fabs(term[1]));
103 log_g_old_sign[1] = term[1] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
104 g[1] += log_g_old_sign[1] *
exp(log_g_old[1]);
// Slot 2 (a3): same visible pattern as slot 0.
107 if constexpr (grad_a3) {
109 = log_g_old_sign[2] * log_t_old_sign *
exp(log_g_old[2] - log_t_old)
111 log_g_old[2] = log_t_new +
log(
fabs(term[2]));
112 log_g_old_sign[2] = term[2] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
113 g[2] += log_g_old_sign[2] *
exp(log_g_old[2]);
// Slot 3 (b1): same visible pattern as slot 0.
116 if constexpr (grad_b1) {
118 = log_g_old_sign[3] * log_t_old_sign *
exp(log_g_old[3] - log_t_old)
120 log_g_old[3] = log_t_new +
log(
fabs(term[3]));
121 log_g_old_sign[3] = term[3] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
122 g[3] += log_g_old_sign[3] *
exp(log_g_old[3]);
// Slot 4 (b2): same visible pattern as slot 0.
125 if constexpr (grad_b2) {
127 = log_g_old_sign[4] * log_t_old_sign *
exp(log_g_old[4] - log_t_old)
129 log_g_old[4] = log_t_new +
log(
fabs(term[4]));
130 log_g_old_sign[4] = term[4] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
131 g[4] += log_g_old_sign[4] *
exp(log_g_old[4]);
// Slot 5 (z): same visible pattern as slot 0.
134 if constexpr (grad_z) {
136 = log_g_old_sign[5] * log_t_old_sign *
exp(log_g_old[5] - log_t_old)
138 log_g_old[5] = log_t_new +
log(
fabs(term[5]));
139 log_g_old_sign[5] = term[5] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
140 g[5] += log_g_old_sign[5] *
exp(log_g_old[5]);
// Convergence test on the log-magnitude of the current term. The action
// taken when it succeeds is not visible in this listing (presumably the
// function stops here -- TODO confirm).
143 if (log_t_new <=
log(precision)) {
// Carry the current term forward as the "previous" term for the next step.
147 log_t_old = log_t_new;
148 log_t_old_sign = log_t_new_sign;
// Tail of an error-reporting call whose beginning is not visible in this
// listing (presumably throw_domain_error, reached once k passes max_steps
// -- TODO confirm).
151 " iterations, hypergeometric function gradient "
152 "did not converge.");
void check_3F2_converges(const char *function, const T_a1 &a1, const T_a2 &a2, const T_a3 &a3, const T_b1 &b1, const T_b2 &b2, const T_z &z)
Check if the hypergeometric function (3F2) called with supplied arguments will converge,...
static constexpr double e()
Return the base of the natural logarithm.
void grad_F32(T1 *g, const T2 &a1, const T3 &a2, const T4 &a3, const T5 &b1, const T6 &b2, const T7 &z, const T8 &precision=1e-6, int max_steps=1e5)
Gradients of the hypergeometric function, 3F2.
fvar< T > log(const fvar< T > &x)
static constexpr double NEGATIVE_INFTY
Negative infinity.
void throw_domain_error(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw a domain error with a consistently formatted message.
fvar< T > inv(const fvar< T > &x)
fvar< T > fabs(const fvar< T > &x)
fvar< T > exp(const fvar< T > &x)
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...