Automatic Differentiation
 
Loading...
Searching...
No Matches
grad_F32.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_FUN_GRAD_F32_HPP
2#define STAN_MATH_PRIM_FUN_GRAD_F32_HPP
3
#include <cmath>
#include <limits>
12
13namespace stan {
14namespace math {
15
51template <bool grad_a1 = true, bool grad_a2 = true, bool grad_a3 = true,
52 bool grad_b1 = true, bool grad_b2 = true, bool grad_z = true,
53 typename T1, typename T2, typename T3, typename T4, typename T5,
54 typename T6, typename T7, typename T8 = double>
55void grad_F32(T1* g, const T2& a1, const T3& a2, const T4& a3, const T5& b1,
56 const T6& b2, const T7& z, const T8& precision = 1e-6,
57 int max_steps = 1e5) {
58 check_3F2_converges("grad_F32", a1, a2, a3, b1, b2, z);
59
60 for (int i = 0; i < 6; ++i) {
61 g[i] = 0.0;
62 }
63
64 T1 log_g_old[6];
65 for (auto& x : log_g_old) {
67 }
68
69 T1 log_t_old = 0.0;
70 T1 log_t_new = 0.0;
71
72 T7 log_z = log(z);
73
74 T1 log_t_new_sign = 1.0;
75 T1 log_t_old_sign = 1.0;
76 T1 log_g_old_sign[6];
77 for (int i = 0; i < 6; ++i) {
78 log_g_old_sign[i] = 1.0;
79 }
80 std::array<T1, 6> term{0};
81 for (int k = 0; k <= max_steps; ++k) {
82 T1 p = (a1 + k) * (a2 + k) * (a3 + k) / ((b1 + k) * (b2 + k) * (1 + k));
83 if (p == 0) {
84 return;
85 }
86
87 log_t_new += log(fabs(p)) + log_z;
88 log_t_new_sign = p >= 0.0 ? log_t_new_sign : -log_t_new_sign;
89 if constexpr (grad_a1) {
90 term[0]
91 = log_g_old_sign[0] * log_t_old_sign * exp(log_g_old[0] - log_t_old)
92 + inv(a1 + k);
93 log_g_old[0] = log_t_new + log(fabs(term[0]));
94 log_g_old_sign[0] = term[0] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
95 g[0] += log_g_old_sign[0] * exp(log_g_old[0]);
96 }
97
98 if constexpr (grad_a2) {
99 term[1]
100 = log_g_old_sign[1] * log_t_old_sign * exp(log_g_old[1] - log_t_old)
101 + inv(a2 + k);
102 log_g_old[1] = log_t_new + log(fabs(term[1]));
103 log_g_old_sign[1] = term[1] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
104 g[1] += log_g_old_sign[1] * exp(log_g_old[1]);
105 }
106
107 if constexpr (grad_a3) {
108 term[2]
109 = log_g_old_sign[2] * log_t_old_sign * exp(log_g_old[2] - log_t_old)
110 + inv(a3 + k);
111 log_g_old[2] = log_t_new + log(fabs(term[2]));
112 log_g_old_sign[2] = term[2] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
113 g[2] += log_g_old_sign[2] * exp(log_g_old[2]);
114 }
115
116 if constexpr (grad_b1) {
117 term[3]
118 = log_g_old_sign[3] * log_t_old_sign * exp(log_g_old[3] - log_t_old)
119 - inv(b1 + k);
120 log_g_old[3] = log_t_new + log(fabs(term[3]));
121 log_g_old_sign[3] = term[3] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
122 g[3] += log_g_old_sign[3] * exp(log_g_old[3]);
123 }
124
125 if constexpr (grad_b2) {
126 term[4]
127 = log_g_old_sign[4] * log_t_old_sign * exp(log_g_old[4] - log_t_old)
128 - inv(b2 + k);
129 log_g_old[4] = log_t_new + log(fabs(term[4]));
130 log_g_old_sign[4] = term[4] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
131 g[4] += log_g_old_sign[4] * exp(log_g_old[4]);
132 }
133
134 if constexpr (grad_z) {
135 term[5]
136 = log_g_old_sign[5] * log_t_old_sign * exp(log_g_old[5] - log_t_old)
137 + inv(z);
138 log_g_old[5] = log_t_new + log(fabs(term[5]));
139 log_g_old_sign[5] = term[5] >= 0.0 ? log_t_new_sign : -log_t_new_sign;
140 g[5] += log_g_old_sign[5] * exp(log_g_old[5]);
141 }
142
143 if (log_t_new <= log(precision)) {
144 return; // implicit abs
145 }
146
147 log_t_old = log_t_new;
148 log_t_old_sign = log_t_new_sign;
149 }
150 throw_domain_error("grad_F32", "k (internal counter)", max_steps, "exceeded ",
151 " iterations, hypergeometric function gradient "
152 "did not converge.");
153 return;
154}
155
156} // namespace math
157} // namespace stan
158#endif
void check_3F2_converges(const char *function, const T_a1 &a1, const T_a2 &a2, const T_a3 &a3, const T_b1 &b1, const T_b2 &b2, const T_z &z)
Check if the hypergeometric function (3F2) called with supplied arguments will converge,...
static constexpr double e()
Return the base of the natural logarithm.
Definition constants.hpp:20
void grad_F32(T1 *g, const T2 &a1, const T3 &a2, const T4 &a3, const T5 &b1, const T6 &b2, const T7 &z, const T8 &precision=1e-6, int max_steps=1e5)
Gradients of the generalized hypergeometric function 3F2 with respect to its arguments.
Definition grad_F32.hpp:55
fvar< T > log(const fvar< T > &x)
Definition log.hpp:18
static constexpr double NEGATIVE_INFTY
Negative infinity.
Definition constants.hpp:51
void throw_domain_error(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw a domain error with a consistently formatted message.
fvar< T > inv(const fvar< T > &x)
Definition inv.hpp:13
fvar< T > fabs(const fvar< T > &x)
Definition fabs.hpp:16
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:15
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...