Automatic Differentiation
 
Loading...
Searching...
No Matches
grad_2F1.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_FUN_GRAD_2F1_HPP
2#define STAN_MATH_PRIM_FUN_GRAD_2F1_HPP
3
15#include <cmath>
16#include <boost/optional.hpp>
17
18namespace stan {
19namespace math {
20
21namespace internal {
48template <bool calc_a1, bool calc_a2, bool calc_b1, typename T1, typename T2,
49 typename T3, typename T_z,
50 typename ScalarT = return_type_t<T1, T2, T3, T_z>,
51 typename TupleT = std::tuple<ScalarT, ScalarT, ScalarT>>
52inline TupleT grad_2F1_impl_ab(const T1& a1, const T2& a2, const T3& b1,
53 const T_z& z, double precision = 1e-14,
54 int max_steps = 1e6) {
55 TupleT grad_tuple = TupleT(0, 0, 0);
56
57 if (z == 0) {
58 return grad_tuple;
59 }
60
61 using ScalarArrayT = Eigen::Array<ScalarT, 3, 1>;
62 ScalarArrayT log_g_old = ScalarArrayT::Constant(3, 1, NEGATIVE_INFTY);
63
64 ScalarT log_t_old = 0.0;
65 ScalarT log_t_new = 0.0;
66 int sign_z = sign(z);
67 auto log_z = log(abs(z));
68
69 int log_t_new_sign = 1.0;
70 int log_t_old_sign = 1.0;
71
72 Eigen::Array<int, 3, 1> log_g_old_sign = Eigen::Array<int, 3, 1>::Ones(3);
73
74 int sign_zk = sign_z;
75 int k = 0;
76 const int min_steps = 5;
77 ScalarT inner_diff = 1;
78 ScalarArrayT g_current = ScalarArrayT::Zero(3);
79
80 while ((inner_diff > precision || k < min_steps) && k < max_steps) {
81 ScalarT p = ((a1 + k) * (a2 + k) / ((b1 + k) * (1.0 + k)));
82 if (p == 0) {
83 return grad_tuple;
84 }
85 log_t_new += log(fabs(p)) + log_z;
86 log_t_new_sign = sign(value_of_rec(p)) * log_t_new_sign;
87
88 if (calc_a1) {
89 ScalarT term_a1
90 = log_g_old_sign(0) * log_t_old_sign * exp(log_g_old(0) - log_t_old)
91 + inv(a1 + k);
92 log_g_old(0) = log_t_new + log(abs(term_a1));
93 log_g_old_sign(0) = sign(value_of_rec(term_a1)) * log_t_new_sign;
94 g_current(0) = log_g_old_sign(0) * exp(log_g_old(0)) * sign_zk;
95 std::get<0>(grad_tuple) += g_current(0);
96 }
97
98 if (calc_a2) {
99 ScalarT term_a2
100 = log_g_old_sign(1) * log_t_old_sign * exp(log_g_old(1) - log_t_old)
101 + inv(a2 + k);
102 log_g_old(1) = log_t_new + log(abs(term_a2));
103 log_g_old_sign(1) = sign(value_of_rec(term_a2)) * log_t_new_sign;
104 g_current(1) = log_g_old_sign(1) * exp(log_g_old(1)) * sign_zk;
105 std::get<1>(grad_tuple) += g_current(1);
106 }
107
108 if (calc_b1) {
109 ScalarT term_b1
110 = log_g_old_sign(2) * log_t_old_sign * exp(log_g_old(2) - log_t_old)
111 + inv(-(b1 + k));
112 log_g_old(2) = log_t_new + log(abs(term_b1));
113 log_g_old_sign(2) = sign(value_of_rec(term_b1)) * log_t_new_sign;
114 g_current(2) = log_g_old_sign(2) * exp(log_g_old(2)) * sign_zk;
115 std::get<2>(grad_tuple) += g_current(2);
116 }
117
118 inner_diff = g_current.array().abs().maxCoeff();
119
120 log_t_old = log_t_new;
121 log_t_old_sign = log_t_new_sign;
122 sign_zk *= sign_z;
123 ++k;
124 }
125
126 if (k > max_steps) {
127 throw_domain_error("grad_2F1", "k (internal counter)", max_steps,
128 "exceeded ",
129 " iterations, hypergeometric function gradient "
130 "did not converge.");
131 }
132 return grad_tuple;
133}
134
167template <bool calc_a1, bool calc_a2, bool calc_b1, bool calc_z, typename T1,
168 typename T2, typename T3, typename T_z,
169 typename ScalarT = return_type_t<T1, T2, T3, T_z>,
170 typename TupleT = std::tuple<ScalarT, ScalarT, ScalarT, ScalarT>>
171inline TupleT grad_2F1_impl(const T1& a1, const T2& a2, const T3& b1,
172 const T_z& z, double precision = 1e-14,
173 int max_steps = 1e6) {
174 bool euler_transform = false;
175 try {
176 check_2F1_converges("hypergeometric_2F1", a1, a2, b1, z);
177 } catch (const std::exception& e) {
178 // Apply Euler's hypergeometric transformation if function
179 // will not converge with current arguments
180 check_2F1_converges("hypergeometric_2F1 (euler transform)", b1 - a1, a2, b1,
181 z / (z - 1));
182 euler_transform = true;
183 }
184
185 std::tuple<ScalarT, ScalarT, ScalarT> grad_tuple_ab;
186 TupleT grad_tuple_rtn = TupleT(0, 0, 0, 0);
187 if (euler_transform) {
188 ScalarT a1_euler = a2;
189 ScalarT a2_euler = b1 - a1;
190 ScalarT z_euler = z / (z - 1);
191 if (calc_z) {
192 auto hyper1 = hypergeometric_2F1(a1_euler, a2_euler, b1, z_euler);
193 auto hyper2 = hypergeometric_2F1(1 + a2, 1 - a1 + b1, 1 + b1, z_euler);
194 std::get<3>(grad_tuple_rtn)
195 = a2 * pow(1 - z, -1 - a2) * hyper1
196 + (a2 * (b1 - a1) * pow(1 - z, -a2)
197 * (inv(z - 1) - z / square(z - 1)) * hyper2)
198 / b1;
199 }
200 if (calc_a1 || calc_a2 || calc_b1) {
201 // 'a' gradients under Euler transform are constructed using the gradients
202 // of both elements, so need to compute both if any are required
203 constexpr bool calc_a1_euler = calc_a1 || calc_a2;
204 // 'b' gradients under Euler transform require gradients from 'a2'
205 constexpr bool calc_a2_euler = calc_a1 || calc_a2 || calc_b1;
206 grad_tuple_ab = grad_2F1_impl_ab<calc_a1_euler, calc_a2_euler, calc_b1>(
207 a1_euler, a2_euler, b1, z_euler);
208
209 auto pre_mult_ab = inv(pow(1.0 - z, a2));
210 if (calc_a1) {
211 std::get<0>(grad_tuple_rtn) = -pre_mult_ab * std::get<1>(grad_tuple_ab);
212 }
213 if (calc_a2) {
214 auto hyper_da2 = hypergeometric_2F1(a1_euler, a2, b1, z_euler);
215 std::get<1>(grad_tuple_rtn)
216 = -pre_mult_ab * hyper_da2 * log1m(z)
217 + pre_mult_ab * std::get<0>(grad_tuple_ab);
218 }
219 if (calc_b1) {
220 std::get<2>(grad_tuple_rtn)
221 = pre_mult_ab
222 * (std::get<1>(grad_tuple_ab) + std::get<2>(grad_tuple_ab));
223 }
224 }
225 } else {
226 if (calc_z) {
227 auto hyper_2f1_dz = hypergeometric_2F1(a1 + 1.0, a2 + 1.0, b1 + 1.0, z);
228 std::get<3>(grad_tuple_rtn) = (a1 * a2 * hyper_2f1_dz) / b1;
229 }
230 if (calc_a1 || calc_a2 || calc_b1) {
231 grad_tuple_ab
232 = grad_2F1_impl_ab<calc_a1, calc_a2, calc_b1>(a1, a2, b1, z);
233 if (calc_a1) {
234 std::get<0>(grad_tuple_rtn) = std::get<0>(grad_tuple_ab);
235 }
236 if (calc_a2) {
237 std::get<1>(grad_tuple_rtn) = std::get<1>(grad_tuple_ab);
238 }
239 if (calc_b1) {
240 std::get<2>(grad_tuple_rtn) = std::get<2>(grad_tuple_ab);
241 }
242 }
243 }
244 return grad_tuple_rtn;
245}
246} // namespace internal
247
/**
 * Calculate the gradients of the hypergeometric function 2F1(a1, a2; b1; z)
 * with respect to a1, a2, b1 and z.
 *
 * This overload strips every argument with value_of() before delegating to
 * internal::grad_2F1_impl. It requests a gradient for an argument only when
 * is_autodiff_v reports that argument's type as an autodiff type; the other
 * tuple entries come back as zero.
 *
 * NOTE(review): the line following the template header, which presumably
 * holds the constraint that selects this overload by ReturnSameT, is missing
 * from this rendering. The template parameter list is unterminated here;
 * restore it from the real header before editing.
 *
 * @tparam ReturnSameT selects between the grad_2F1 overloads (the selecting
 *   constraint is not visible here)
 * @param a1 first numerator parameter
 * @param a2 second numerator parameter
 * @param b1 denominator parameter
 * @param z argument of the function
 * @param precision convergence threshold for the series summation
 * @param max_steps maximum number of series terms
 * @return tuple (d/da1, d/da2, d/db1, d/dz) computed from the argument values
 */
271template <bool ReturnSameT, typename T1, typename T2, typename T3, typename T_z,
273inline auto grad_2F1(const T1& a1, const T2& a2, const T3& b1, const T_z& z,
274 double precision = 1e-14, int max_steps = 1e6) {
275 return internal::grad_2F1_impl<is_autodiff_v<T1>, is_autodiff_v<T2>,
276 is_autodiff_v<T3>, is_autodiff_v<T_z>>(
277 value_of(a1), value_of(a2), value_of(b1), value_of(z), precision,
278 max_steps);
279}
280
/**
 * Calculate the gradients of the hypergeometric function 2F1(a1, a2; b1; z)
 * with respect to a1, a2, b1 and z.
 *
 * This overload passes the arguments to internal::grad_2F1_impl unchanged,
 * without value_of(). It always requests all four gradients, so the results
 * keep the arguments' scalar type.
 *
 * NOTE(review): the line following the template header, which presumably
 * holds the constraint that selects this overload by ReturnSameT, is missing
 * from this rendering. The template parameter list is unterminated here;
 * restore it from the real header before editing.
 *
 * @tparam ReturnSameT selects between the grad_2F1 overloads (the selecting
 *   constraint is not visible here)
 * @param a1 first numerator parameter
 * @param a2 second numerator parameter
 * @param b1 denominator parameter
 * @param z argument of the function
 * @param precision convergence threshold for the series summation
 * @param max_steps maximum number of series terms
 * @return tuple (d/da1, d/da2, d/db1, d/dz)
 */
304template <bool ReturnSameT, typename T1, typename T2, typename T3, typename T_z,
306inline auto grad_2F1(const T1& a1, const T2& a2, const T3& b1, const T_z& z,
307 double precision = 1e-14, int max_steps = 1e6) {
308 return internal::grad_2F1_impl<true, true, true, true>(a1, a2, b1, z,
309 precision, max_steps);
310}
311
/**
 * Calculate the gradients of the hypergeometric function 2F1(a1, a2; b1; z)
 * with respect to a1, a2, b1 and z.
 *
 * Convenience entry point: it forwards every argument unchanged to the
 * grad_2F1 overload selected by ReturnSameT == false.
 *
 * @tparam T1 type of a1
 * @tparam T2 type of a2
 * @tparam T3 type of b1
 * @tparam T_z type of z
 * @param a1 first numerator parameter
 * @param a2 second numerator parameter
 * @param b1 denominator parameter
 * @param z argument of the function
 * @param precision convergence threshold for the series summation
 * @param max_steps maximum number of series terms
 * @return whatever grad_2F1<false> returns for these arguments
 */
template <typename T1, typename T2, typename T3, typename T_z>
inline auto grad_2F1(const T1& a1, const T2& a2, const T3& b1, const T_z& z,
                     double precision = 1e-14, int max_steps = 1e6) {
  constexpr bool return_same_type = false;
  return grad_2F1<return_same_type>(a1, a2, b1, z, precision, max_steps);
}
335
336} // namespace math
337} // namespace stan
338#endif
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
TupleT grad_2F1_impl(const T1 &a1, const T2 &a2, const T3 &b1, const T_z &z, double precision=1e-14, int max_steps=1e6)
Implementation function to calculate the gradients of the hypergeometric function, 2F1, with respect to the a1, a2, b1, and z parameters.
Definition grad_2F1.hpp:171
TupleT grad_2F1_impl_ab(const T1 &a1, const T2 &a2, const T3 &b1, const T_z &z, double precision=1e-14, int max_steps=1e6)
Implementation function to calculate the gradients of the hypergeometric function, 2F1, with respect to the a1, a2, and b1 parameters.
Definition grad_2F1.hpp:52
double value_of_rec(const fvar< T > &v)
Return the value of the specified variable.
fvar< T > abs(const fvar< T > &x)
Definition abs.hpp:15
static constexpr double e()
Return the base of the natural logarithm.
Definition constants.hpp:20
auto sign(const T &x)
Returns signs of the arguments.
Definition sign.hpp:18
auto pow(const T1 &x1, const T2 &x2)
Definition pow.hpp:32
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
fvar< T > log(const fvar< T > &x)
Definition log.hpp:18
static constexpr double NEGATIVE_INFTY
Negative infinity.
Definition constants.hpp:51
void throw_domain_error(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw a domain error with a consistently formatted message.
void check_2F1_converges(const char *function, const T_a1 &a1, const T_a2 &a2, const T_b1 &b1, const T_z &z)
Check if the hypergeometric function (2F1) called with supplied arguments will converge,...
return_type_t< Ta1, Ta2, Tb, Tz > hypergeometric_2F1(const Ta1 &a1, const Ta2 &a2, const Tb &b, const Tz &z)
Returns the Gauss hypergeometric function applied to the input arguments: $_2F_1(a_1, a_2; b; z)$.
fvar< T > log1m(const fvar< T > &x)
Definition log1m.hpp:12
fvar< T > inv(const fvar< T > &x)
Definition inv.hpp:13
fvar< T > fabs(const fvar< T > &x)
Definition fabs.hpp:16
fvar< T > square(const fvar< T > &x)
Definition square.hpp:12
auto grad_2F1(const T1 &a1, const T2 &a2, const T3 &b1, const T_z &z, double precision=1e-14, int max_steps=1e6)
Calculate the gradients of the hypergeometric function (2F1) as a power series, stopping when the series has converged to the requested precision or max_steps iterations have been reached.
Definition grad_2F1.hpp:273
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:15
std::enable_if_t<!Check::value > require_not_t
If condition is false, template is disabled.
std::enable_if_t< Check::value > require_t
If condition is true, template is enabled.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...