Automatic Differentiation
 
Loading...
Searching...
No Matches
hypergeometric_2F1.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_FUN_HYPERGEOMETRIC_2F1_HPP
2#define STAN_MATH_PRIM_FUN_HYPERGEOMETRIC_2F1_HPP
3
#include <boost/optional.hpp>
#include <cmath>
21
22namespace stan {
23namespace math {
24namespace internal {
25
45template <typename Ta1, typename Ta2, typename Tb, typename Tz,
46 typename RtnT = boost::optional<return_type_t<Ta1, Ta1, Tb, Tz>>,
47 require_all_arithmetic_t<Ta1, Ta2, Tb, Tz>* = nullptr>
48inline RtnT hyper_2F1_special_cases(const Ta1& a1, const Ta2& a2, const Tb& b,
49 const Tz& z) {
50 // https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric2F1/03/01/
51 // // NOLINT
52 if (z == 0.0) {
53 return 1.0;
54 }
55
56 // https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric2F1/03/06/01/0001/
57 // // NOLINT
58 if (a1 == b) {
59 return inv(pow(1.0 - z, a2));
60 }
61
62 // https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric2F1/03/06/01/0003/
63 // // NOLINT
64 if (b == (a2 - 1.0)) {
65 return (pow((1.0 - z), -a1 - 1.0) * (a2 + z * (a1 - a2 + 1.0) - 1.0))
66 / (a2 - 1);
67 }
68
69 if (a1 == a2) {
70 // https://www.wolframalpha.com/input?i=Hypergeometric2F1%281%2C+1%2C+2%2C+-z%29
71 // // NOLINT
72 if (a1 == 1.0 && b == 2.0 && z < 0) {
73 auto pos_z = abs(z);
74 return log1p(pos_z) / pos_z;
75 }
76
77 if (a1 == 0.5 && b == 1.5 && z < 1.0) {
78 auto sqrt_z = sqrt(abs(z));
79 auto numerator
80 = (z > 0.0)
81 // https://www.wolframalpha.com/input?i=Hypergeometric2F1%281%2F2%2C+1%2F2%2C+3%2F2%2C+z%29
82 // // NOLINT
83 ? asin(sqrt_z)
84 // https://www.wolframalpha.com/input?i=Hypergeometric2F1%281%2F2%2C+1%2F2%2C+3%2F2%2C+-z%29
85 // // NOLINT
86 : asinh(sqrt_z);
87 return numerator / sqrt_z;
88 }
89
90 // https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric2F1/03/04/03/
91 // // NOLINT
92 if (b == (a1 + 1) && z == 0.5) {
93 return pow(2, a1 - 1) * a1
94 * (digamma((a1 + 1) / 2.0) - digamma(a1 / 2.0));
95 }
96 }
97
98 if (z == 1.0) {
99 // https://www.wolframalpha.com/input?i=Hypergeometric2F1%28a1%2C+a2%2C+a1+%2B+a2+%2B+2%2C+1%29
100 // // NOLINT
101 if (b == (a1 + a2 + 2)) {
102 auto log_2f1 = lgamma(b) - (lgamma(a1 + 2) + lgamma(a2 + 2));
103 return exp(log_2f1);
104 // https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric2F1/03/02/0001/
105 // // NOLINT
106 } else if (b > (a1 + a2)) {
107 auto log_2f1 = (lgamma(b) + lgamma(b - a1 - a2))
108 - (lgamma(b - a1) + lgamma(b - a2));
109 return exp(log_2f1);
110 }
111 }
112
113 // https://www.wolframalpha.com/input?i=Hypergeometric2F1%283%2F2%2C+2%2C+3%2C+-z%29
114 // // NOLINT
115 if (a1 == 1.5 && a2 == 2.0 && b == 3.0 && z < 0.0) {
116 auto abs_z = abs(z);
117 auto sqrt_1pz = sqrt(1 + abs_z);
118 return -4 * (2 * sqrt_1pz + z - 2) / (sqrt_1pz * square(z));
119 }
120
121 return {};
122}
123} // namespace internal
124
/**
 * Returns the Gauss hypergeometric function 2F1(a1, a2; b; z).
 *
 * Evaluation order:
 *  1. all arguments are validated with check_finite and check_not_nan;
 *  2. internal::hyper_2F1_special_cases is tried with (a1, a2) and then with
 *     a1 and a2 swapped, since 2F1(a1, a2; b; z) = 2F1(a2, a1; b; z);
 *  3. otherwise, if check_2F1_converges accepts the arguments, the series is
 *     evaluated with hypergeometric_pFq;
 *  4. if anything in step 3 throws, the series is evaluated instead for the
 *     arguments of Pfaff's transformation
 *       2F1(a1, a2; b; z) = (1 - z)^(-a2) * 2F1(b - a1, a2; b; z / (z - 1)).
 *     Convergence is re-checked for the transformed arguments. That check is
 *     outside the try block, so its exception propagates to the caller.
 *
 * NOTE(review): ScalarT is return_type_t<Ta1, Ta1, Tb, Tz>; the second Ta1
 * looks like it was meant to be Ta2 -- confirm. It is likely harmless while
 * the arguments are arithmetic, which hyper_2F1_special_cases requires.
 *
 * @param a1 first upper parameter
 * @param a2 second upper parameter
 * @param b lower parameter
 * @param z argument
 * @return 2F1(a1, a2; b; z)
 */
150template <typename Ta1, typename Ta2, typename Tb, typename Tz,
151 typename ScalarT = return_type_t<Ta1, Ta1, Tb, Tz>,
152 typename OptT = boost::optional<ScalarT>,
155 const Ta2& a2,
156 const Tb& b,
157 const Tz& z) {
158 check_finite("hypergeometric_2F1", "a1", a1);
159 check_finite("hypergeometric_2F1", "a2", a2);
160 check_finite("hypergeometric_2F1", "b", b);
161 check_finite("hypergeometric_2F1", "z", z);
162
// NOTE(review): check_finite above should already reject NaN, which would
// make the following checks redundant -- confirm before removing them.
163 check_not_nan("hypergeometric_2F1", "a1", a1);
164 check_not_nan("hypergeometric_2F1", "a2", a2);
165 check_not_nan("hypergeometric_2F1", "b", b);
166 check_not_nan("hypergeometric_2F1", "z", z);
167
168 // Check whether value can be calculated by any special-case rules
169 // before estimating infinite sum
170 OptT special_case_a1a2 = internal::hyper_2F1_special_cases(a1, a2, b, z);
171 if (special_case_a1a2.is_initialized()) {
172 return special_case_a1a2.get();
173 }
174
175 // Check whether any special case rules apply with 'a' arguments reversed
176 // as 2F1(a1, a2, b, z) = 2F1(a2, a1, b, z)
177 OptT special_case_a2a1 = internal::hyper_2F1_special_cases(a2, a1, b, z);
178 if (special_case_a2a1.is_initialized()) {
179 return special_case_a2a1.get();
180 }
181
// Parameter vectors for hypergeometric_pFq: a_args holds the two upper
// parameters and b_args the single lower parameter. Both are (re)filled
// immediately before each hypergeometric_pFq call below.
182 Eigen::Matrix<double, 2, 1> a_args(2);
183 Eigen::Matrix<double, 1, 1> b_args(1);
184
185 try {
186 check_2F1_converges("hypergeometric_2F1", a1, a2, b, z);
187
188 a_args << a1, a2;
189 b_args << b;
190 return hypergeometric_pFq(a_args, b_args, z);
// NOTE(review): this catches every std::exception raised in the try block,
// including any thrown by hypergeometric_pFq itself, not only a failed
// convergence check. The caught object `e` is unused.
191 } catch (const std::exception& e) {
192 // Apply Pfaff's transformation 2F1(a1, a2; b; z) =
193 // (1 - z)^(-a2) * 2F1(b - a1, a2; b; z / (z - 1)) when the series does not converge for the original arguments
194 ScalarT a1_t = b - a1;
195 ScalarT a2_t = a2;
196 ScalarT b_t = b;
197 ScalarT z_t = z / (z - 1);
198
199 check_2F1_converges("hypergeometric_2F1", a1_t, a2_t, b_t, z_t);
200
201 a_args << a1_t, a2_t;
202 b_args << b_t;
// Dividing by (1 - z)^a2 supplies the (1 - z)^(-a2) prefactor of the
// transformation.
203 return hypergeometric_pFq(a_args, b_args, z_t) / pow(1 - z, a2);
204 }
205}
206} // namespace math
207} // namespace stan
208#endif
require_all_t< std::is_arithmetic< std::decay_t< Types > >... > require_all_arithmetic_t
Require that all of the types satisfy std::is_arithmetic.
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
RtnT hyper_2F1_special_cases(const Ta1 &a1, const Ta2 &a2, const Tb &b, const Tz &z)
Calculate the Gauss Hypergeometric (2F1) function for special-case combinations of parameters which c...
fvar< T > abs(const fvar< T > &x)
Definition abs.hpp:15
static constexpr double e()
Return the base of the natural logarithm.
Definition constants.hpp:20
fvar< T > asinh(const fvar< T > &x)
Definition asinh.hpp:15
return_type_t< Ta1, Ta1, Tb, Tz > hypergeometric_2F1(const Ta1 &a1, const Ta2 &a2, const Tb &b, const Tz &z)
Returns the Gauss hypergeometric function applied to the input arguments: $_2F_1(a_1, a_2; b; z)$.
fvar< T > sqrt(const fvar< T > &x)
Definition sqrt.hpp:17
fvar< T > log1p(const fvar< T > &x)
Definition log1p.hpp:12
void check_finite(const char *function, const char *name, const T_y &y)
Check that all values in y are finite.
fvar< T > lgamma(const fvar< T > &x)
Return the natural logarithm of the gamma function applied to the specified argument.
Definition lgamma.hpp:21
void check_not_nan(const char *function, const char *name, const T_y &y)
Check if y is not NaN.
FvarT hypergeometric_pFq(const Ta &a, const Tb &b, const Tz &z)
Returns the generalized hypergeometric (pFq) function applied to the input arguments.
fvar< T > pow(const fvar< T > &x1, const fvar< T > &x2)
Definition pow.hpp:19
void check_2F1_converges(const char *function, const T_a1 &a1, const T_a2 &a2, const T_b1 &b1, const T_z &z)
Check if the hypergeometric function (2F1) called with supplied arguments will converge,...
fvar< T > asin(const fvar< T > &x)
Definition asin.hpp:15
fvar< T > inv(const fvar< T > &x)
Definition inv.hpp:12
fvar< T > digamma(const fvar< T > &x)
Return the derivative of the log gamma function at the specified argument.
Definition digamma.hpp:23
fvar< T > square(const fvar< T > &x)
Definition square.hpp:12
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:13
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9