grad_2F1.hpp
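This file computes the gradients of the Gauss hypergeometric function 2F1 with respect to its parameters by accumulating the defining power series term by term. For reference, the series and the parameter gradients it yields (standard identities) are

\[
{}_2F_1(a_1, a_2; b_1; z) = \sum_{k=0}^{\infty} t_k, \qquad
t_k = \frac{(a_1)_k \, (a_2)_k}{(b_1)_k} \, \frac{z^k}{k!},
\]
\[
\frac{\partial \, {}_2F_1}{\partial a_1} = \sum_{k=1}^{\infty} t_k \sum_{j=0}^{k-1} \frac{1}{a_1 + j},
\qquad
\frac{\partial \, {}_2F_1}{\partial b_1} = -\sum_{k=1}^{\infty} t_k \sum_{j=0}^{k-1} \frac{1}{b_1 + j},
\]

with the analogous sum for a2. The implementation below tracks the log-magnitude and sign of each term so these series can be accumulated stably.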
#ifndef STAN_MATH_PRIM_FUN_GRAD_2F1_HPP
#define STAN_MATH_PRIM_FUN_GRAD_2F1_HPP

// Stan helper headers used below (reconstructed include list)
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err/check_2F1_converges.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/hypergeometric_2F1.hpp>
#include <stan/math/prim/fun/inv.hpp>
#include <stan/math/prim/fun/log1m.hpp>
#include <stan/math/prim/fun/sign.hpp>
#include <stan/math/prim/fun/square.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <stan/math/prim/fun/value_of_rec.hpp>
#include <cmath>
#include <boost/optional.hpp>

namespace stan {
namespace math {

namespace internal {
/**
 * Implementation function to calculate the gradients of the hypergeometric
 * function 2F1 with respect to the parameters a1, a2, and b1, accumulating
 * the power series until the largest gradient increment falls below
 * `precision` or throwing once `max_steps` iterations are exceeded.
 */
template <bool calc_a1, bool calc_a2, bool calc_b1, typename T1, typename T2,
          typename T3, typename T_z,
          typename ScalarT = return_type_t<T1, T2, T3, T_z>,
          typename TupleT = std::tuple<ScalarT, ScalarT, ScalarT>>
TupleT grad_2F1_impl_ab(const T1& a1, const T2& a2, const T3& b1, const T_z& z,
                        double precision = 1e-14, int max_steps = 1e6) {
  TupleT grad_tuple = TupleT(0, 0, 0);

  if (z == 0) {
    return grad_tuple;
  }

  using ScalarArrayT = Eigen::Array<ScalarT, 3, 1>;
  ScalarArrayT log_g_old = ScalarArrayT::Constant(3, 1, NEGATIVE_INFTY);

  ScalarT log_t_old = 0.0;
  ScalarT log_t_new = 0.0;
  int sign_z = sign(z);
  auto log_z = log(abs(z));

  int log_t_new_sign = 1;
  int log_t_old_sign = 1;

  Eigen::Array<int, 3, 1> log_g_old_sign = Eigen::Array<int, 3, 1>::Ones(3);

  int sign_zk = sign_z;
  int k = 0;
  const int min_steps = 5;
  ScalarT inner_diff = 1;
  ScalarArrayT g_current = ScalarArrayT::Zero(3);

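  // Accumulation scheme: with t_k the k-th series term, each pass updates
  //   log|t_{k+1}| = log|t_k| + log|(a1 + k)(a2 + k) / ((b1 + k)(1 + k))| + log|z|
  // and, for each requested parameter, the gradient contribution
  //   g_{k+1} = t_{k+1} * (g_k / t_k + 1 / (a1 + k))   (a2 analogous;
  //                                                     -1 / (b1 + k) for b1),
  // carrying log-magnitudes with separate sign trackers so large terms
  // neither overflow nor lose their sign.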
  while ((inner_diff > precision || k < min_steps) && k < max_steps) {
    ScalarT p = ((a1 + k) * (a2 + k) / ((b1 + k) * (1.0 + k)));
    if (p == 0) {
      return grad_tuple;
    }
    log_t_new += log(fabs(p)) + log_z;
    log_t_new_sign = sign(value_of_rec(p)) * log_t_new_sign;

    if (calc_a1) {
      ScalarT term_a1
          = log_g_old_sign(0) * log_t_old_sign * exp(log_g_old(0) - log_t_old)
            + inv(a1 + k);
      log_g_old(0) = log_t_new + log(abs(term_a1));
      log_g_old_sign(0) = sign(value_of_rec(term_a1)) * log_t_new_sign;
      g_current(0) = log_g_old_sign(0) * exp(log_g_old(0)) * sign_zk;
      std::get<0>(grad_tuple) += g_current(0);
    }

    if (calc_a2) {
      ScalarT term_a2
          = log_g_old_sign(1) * log_t_old_sign * exp(log_g_old(1) - log_t_old)
            + inv(a2 + k);
      log_g_old(1) = log_t_new + log(abs(term_a2));
      log_g_old_sign(1) = sign(value_of_rec(term_a2)) * log_t_new_sign;
      g_current(1) = log_g_old_sign(1) * exp(log_g_old(1)) * sign_zk;
      std::get<1>(grad_tuple) += g_current(1);
    }

    if (calc_b1) {
      ScalarT term_b1
          = log_g_old_sign(2) * log_t_old_sign * exp(log_g_old(2) - log_t_old)
            + inv(-(b1 + k));
      log_g_old(2) = log_t_new + log(abs(term_b1));
      log_g_old_sign(2) = sign(value_of_rec(term_b1)) * log_t_new_sign;
      g_current(2) = log_g_old_sign(2) * exp(log_g_old(2)) * sign_zk;
      std::get<2>(grad_tuple) += g_current(2);
    }

    inner_diff = g_current.abs().maxCoeff();

    log_t_old = log_t_new;
    log_t_old_sign = log_t_new_sign;
    sign_zk *= sign_z;
    ++k;
  }

  if (k >= max_steps && inner_diff > precision) {
    throw_domain_error("grad_2F1", "k (internal counter)", max_steps,
                       "exceeded ",
                       " iterations, hypergeometric function gradient "
                       "did not converge.");
  }
  return grad_tuple;
}

/**
 * Implementation function to calculate the gradients of the hypergeometric
 * function 2F1 with respect to the arguments a1, a2, b1, and z, applying
 * Euler's transformation when the series at the supplied arguments would
 * not converge.
 */
template <bool calc_a1, bool calc_a2, bool calc_b1, bool calc_z, typename T1,
          typename T2, typename T3, typename T_z,
          typename ScalarT = return_type_t<T1, T2, T3, T_z>,
          typename TupleT = std::tuple<ScalarT, ScalarT, ScalarT, ScalarT>>
TupleT grad_2F1_impl(const T1& a1, const T2& a2, const T3& b1, const T_z& z,
                     double precision = 1e-14, int max_steps = 1e6) {
  bool euler_transform = false;
  try {
    check_2F1_converges("hypergeometric_2F1", a1, a2, b1, z);
  } catch (const std::exception& e) {
    // Apply Euler's hypergeometric transformation if the function
    // will not converge with the current arguments
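    // The transformation used below is the Pfaff/Euler identity (standard,
    // stated here for reference):
    //   2F1(a1, a2; b1; z) = (1 - z)^(-a2) * 2F1(b1 - a1, a2; b1; z / (z - 1))
    // which evaluates the series at z / (z - 1) and can converge when the
    // series at z does not.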
    check_2F1_converges("hypergeometric_2F1 (euler transform)", b1 - a1, a2,
                        b1, z / (z - 1));
    euler_transform = true;
  }

  std::tuple<ScalarT, ScalarT, ScalarT> grad_tuple_ab;
  TupleT grad_tuple_rtn = TupleT(0, 0, 0, 0);
  if (euler_transform) {
    ScalarT a1_euler = a2;
    ScalarT a2_euler = b1 - a1;
    ScalarT z_euler = z / (z - 1);
    if (calc_z) {
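      // dF/dz from the transformed form F = (1 - z)^(-a2) * G(z / (z - 1)):
      // the first summand below is the product-rule term from the prefactor;
      // the second applies
      //   d/dz 2F1(a1, a2; b1; z) = (a1 * a2 / b1) * 2F1(a1+1, a2+1; b1+1; z)
      // to G together with the chain rule through z / (z - 1).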
      auto hyper1 = hypergeometric_2F1(a1_euler, a2_euler, b1, z_euler);
      auto hyper2 = hypergeometric_2F1(1 + a2, 1 - a1 + b1, 1 + b1, z_euler);
      std::get<3>(grad_tuple_rtn)
          = a2 * pow(1 - z, -1 - a2) * hyper1
            + (a2 * (b1 - a1) * pow(1 - z, -a2)
               * (inv(z - 1) - z / square(z - 1)) * hyper2)
              / b1;
    }
    if (calc_a1 || calc_a2 || calc_b1) {
      // The 'a' gradients under the Euler transform are constructed from the
      // gradients of both transformed 'a' arguments, so both are needed if
      // either is required
      constexpr bool calc_a1_euler = calc_a1 || calc_a2;
      // The 'b' gradient under the Euler transform requires the gradient
      // with respect to the transformed 'a2' argument
      constexpr bool calc_a2_euler = calc_a1 || calc_a2 || calc_b1;
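      // Chain rule back to the original parameters, writing
      // G(alpha, beta) = 2F1(alpha, beta; b1; z / (z - 1)) with
      // alpha = a2, beta = b1 - a1, and prefactor (1 - z)^(-a2):
      //   dF/da1 = -(1 - z)^(-a2) * dG/dbeta
      //   dF/da2 =  (1 - z)^(-a2) * (dG/dalpha - log(1 - z) * G)
      //   dF/db1 =  (1 - z)^(-a2) * (dG/dbeta + dG/db1)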
      grad_tuple_ab = grad_2F1_impl_ab<calc_a1_euler, calc_a2_euler, calc_b1>(
          a1_euler, a2_euler, b1, z_euler);

      auto pre_mult_ab = inv(pow(1.0 - z, a2));
      if (calc_a1) {
        std::get<0>(grad_tuple_rtn) = -pre_mult_ab * std::get<1>(grad_tuple_ab);
      }
      if (calc_a2) {
        auto hyper_da2 = hypergeometric_2F1(a1_euler, a2_euler, b1, z_euler);
        std::get<1>(grad_tuple_rtn)
            = -pre_mult_ab * hyper_da2 * log1m(z)
              + pre_mult_ab * std::get<0>(grad_tuple_ab);
      }
      if (calc_b1) {
        std::get<2>(grad_tuple_rtn)
            = pre_mult_ab
              * (std::get<1>(grad_tuple_ab) + std::get<2>(grad_tuple_ab));
      }
    }
  } else {
    if (calc_z) {
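      // dF/dz via the standard contiguous relation:
      //   d/dz 2F1(a1, a2; b1; z) = (a1 * a2 / b1) * 2F1(a1+1, a2+1; b1+1; z)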
      auto hyper_2f1_dz = hypergeometric_2F1(a1 + 1.0, a2 + 1.0, b1 + 1.0, z);
      std::get<3>(grad_tuple_rtn) = (a1 * a2 * hyper_2f1_dz) / b1;
    }
    if (calc_a1 || calc_a2 || calc_b1) {
      grad_tuple_ab
          = grad_2F1_impl_ab<calc_a1, calc_a2, calc_b1>(a1, a2, b1, z);
      if (calc_a1) {
        std::get<0>(grad_tuple_rtn) = std::get<0>(grad_tuple_ab);
      }
      if (calc_a2) {
        std::get<1>(grad_tuple_rtn) = std::get<1>(grad_tuple_ab);
      }
      if (calc_b1) {
        std::get<2>(grad_tuple_rtn) = std::get<2>(grad_tuple_ab);
      }
    }
  }
  return grad_tuple_rtn;
}
}  // namespace internal

/**
 * Calculate the gradients of the hypergeometric function (2F1) as the power
 * series, stopping when the series converges to within `precision` or
 * throwing when the function takes more than `max_steps` steps.
 *
 * Overload for when the gradients should be returned as doubles: partials are
 * computed only for non-constant (autodiff) argument types, with all inputs
 * reduced to their values.
 */
template <bool ReturnSameT, typename T1, typename T2, typename T3, typename T_z,
          require_not_t<std::integral_constant<bool, ReturnSameT>>* = nullptr>
auto grad_2F1(const T1& a1, const T2& a2, const T3& b1, const T_z& z,
              double precision = 1e-14, int max_steps = 1e6) {
  return internal::grad_2F1_impl<
      !is_constant<T1>::value, !is_constant<T2>::value, !is_constant<T3>::value,
      !is_constant<T_z>::value>(value_of(a1), value_of(a2), value_of(b1),
                                value_of(z), precision, max_steps);
}

/**
 * Calculate the gradients of the hypergeometric function (2F1) as the power
 * series, stopping when the series converges to within `precision` or
 * throwing when the function takes more than `max_steps` steps.
 *
 * Overload for when the gradients should be returned in the same scalar type
 * as the input arguments (ReturnSameT = true); all four partials are computed.
 */
template <bool ReturnSameT, typename T1, typename T2, typename T3, typename T_z,
          require_t<std::integral_constant<bool, ReturnSameT>>* = nullptr>
auto grad_2F1(const T1& a1, const T2& a2, const T3& b1, const T_z& z,
              double precision = 1e-14, int max_steps = 1e6) {
  return internal::grad_2F1_impl<true, true, true, true>(a1, a2, b1, z,
                                                         precision, max_steps);
}

/**
 * Calculate the gradients of the hypergeometric function (2F1) as the power
 * series, stopping when the series converges to within `precision` or
 * throwing when the function takes more than `max_steps` steps.
 */
template <typename T1, typename T2, typename T3, typename T_z>
auto grad_2F1(const T1& a1, const T2& a2, const T3& b1, const T_z& z,
              double precision = 1e-14, int max_steps = 1e6) {
  return grad_2F1<false>(a1, a2, b1, z, precision, max_steps);
}

}  // namespace math
}  // namespace stan
#endif
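
A minimal usage sketch, assuming the umbrella header <stan/math/prim.hpp> and illustrative argument values. Calling the ReturnSameT = true overload with plain doubles computes all four gradients:

#include <stan/math/prim.hpp>
#include <iostream>
#include <tuple>

int main() {
  // Gradients of 2F1(a1, a2; b1; z) at a1 = 1, a2 = 1, b1 = 1.5, z = 0.25;
  // the result is a tuple (d/da1, d/da2, d/db1, d/dz) of doubles.
  auto grads = stan::math::grad_2F1<true>(1.0, 1.0, 1.5, 0.25);
  std::cout << "d/da1: " << std::get<0>(grads) << "\n"
            << "d/da2: " << std::get<1>(grads) << "\n"
            << "d/db1: " << std::get<2>(grads) << "\n"
            << "d/dz:  " << std::get<3>(grads) << "\n";
}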