1#ifndef STAN_MATH_PRIM_FUN_GRAD_2F1_HPP
2#define STAN_MATH_PRIM_FUN_GRAD_2F1_HPP
16#include <boost/optional.hpp>
// Accumulates the gradients of the Gauss hypergeometric function
// 2F1(a1, a2; b1; z) with respect to a1, a2, and b1 (each selected by the
// matching calc_* template flag), returned as a std::tuple of three
// ScalarT values.  Series terms are tracked in log space with separate
// integer sign trackers so very large / very small terms neither overflow
// nor underflow.
// NOTE(review): this extract is missing interior lines (the function name
// and value-parameter list, the loop counter k, log_z / sign_zk, and the
// term_a1 / term_a2 / term_b1 recurrences) -- the comments below describe
// only what is visible here; confirm against the full file.
48template <
bool calc_a1,
bool calc_a2,
bool calc_b1,
typename T1,
49 typename T2,
50 typename T3,
typename T_z,
// Scalar return type promoted over all four argument types.
51 typename ScalarT = return_type_t<T1, T2, T3, T_z>,
// Gradient triple (d/da1, d/da2, d/db1).
53 typename TupleT = std::tuple<ScalarT, ScalarT, ScalarT>>
double precision = 1
e-14,
int max_steps = 1e6) {
// Running gradient accumulators, all initialized to zero.
54 TupleT grad_tuple = TupleT(0, 0, 0);
60 using ScalarArrayT = Eigen::Array<ScalarT, 3, 1>;
// log of the previous gradient term per parameter; -inf == "no term yet"
// (exp(-inf) contributes 0 on the first iteration).
61 ScalarArrayT log_g_old = ScalarArrayT::Constant(3, 1,
NEGATIVE_INFTY);
// log of the previous / current series term t_k (t_0 = 1, so log = 0).
63 ScalarT log_t_old = 0.0;
64 ScalarT log_t_new = 0.0;
// Explicit signs of the log-space quantities (+1 / -1).
68 int log_t_new_sign = 1.0;
69 int log_t_old_sign = 1.0;
71 Eigen::Array<int, 3, 1> log_g_old_sign = Eigen::Array<int, 3, 1>::Ones(3);
// Always take at least min_steps terms before testing convergence.
75 const int min_steps = 5;
// Largest gradient term from the current iteration; drives convergence.
76 ScalarT inner_diff = 1;
77 ScalarArrayT g_current = ScalarArrayT::Zero(3);
// Iterate until the largest per-parameter term drops below `precision`
// (but never before min_steps) or max_steps is exhausted.
79 while ((inner_diff > precision || k < min_steps) && k < max_steps) {
// Ratio t_{k+1}/t_k of successive series terms (before the z factor).
80 ScalarT p = ((a1 + k) * (a2 + k) / ((b1 + k) * (1.0 + k)));
84 log_t_new +=
log(
fabs(p)) + log_z;
// d/da1 term (recurrence elided in this extract):
89 = log_g_old_sign(0) * log_t_old_sign *
exp(log_g_old(0) - log_t_old)
91 log_g_old(0) = log_t_new +
log(
abs(term_a1));
93 g_current(0) = log_g_old_sign(0) *
exp(log_g_old(0)) * sign_zk;
94 std::get<0>(grad_tuple) += g_current(0);
// d/da2 term, same scheme as above:
99 = log_g_old_sign(1) * log_t_old_sign *
exp(log_g_old(1) - log_t_old)
101 log_g_old(1) = log_t_new +
log(
abs(term_a2));
103 g_current(1) = log_g_old_sign(1) *
exp(log_g_old(1)) * sign_zk;
104 std::get<1>(grad_tuple) += g_current(1);
// d/db1 term, same scheme as above:
109 = log_g_old_sign(2) * log_t_old_sign *
exp(log_g_old(2) - log_t_old)
111 log_g_old(2) = log_t_new +
log(
abs(term_b1));
113 g_current(2) = log_g_old_sign(2) *
exp(log_g_old(2)) * sign_zk;
114 std::get<2>(grad_tuple) += g_current(2);
// Convergence metric: largest absolute term added this iteration.
117 inner_diff = g_current.array().abs().maxCoeff();
// Roll the "new" term into the "old" slot for the next iteration.
119 log_t_old = log_t_new;
120 log_t_old_sign = log_t_new_sign;
// Tail of the throw_domain_error message emitted when max_steps is hit
// without convergence (the call site is elided in this extract).
128 " iterations, hypergeometric function gradient "
129 "did not converge.");
// Four-gradient driver: computes d/da1, d/da2, d/db1, and d/dz of
// 2F1(a1, a2; b1; z), returned as a 4-tuple.  When the (elided) try-block
// convergence check throws for the raw arguments, the visible substitutions
// (a1_euler = a2, a2_euler = b1 - a1, z_euler = z / (z - 1), and the
// (1 - z)^(-a2) prefactor) apply Euler's transformation
//   2F1(a1, a2; b1; z) = (1 - z)^(-a2) 2F1(b1 - a1, a2; b1; z / (z - 1))
// and the a/b gradients from grad_2F1_impl_ab are mapped back via the
// chain rule; otherwise the series is evaluated directly.
// NOTE(review): interior lines are missing from this extract (the try
// body, the hyper1 / hyper2 / hyper_da2 / hyper_2f1_dz definitions, and
// several statements) -- comments describe only what is visible.
166template <
bool calc_a1,
bool calc_a2,
bool calc_b1,
bool calc_z,
typename T1,
167 typename T2,
typename T3,
typename T_z,
// Gradient quadruple (d/da1, d/da2, d/db1, d/dz).
169 typename TupleT = std::tuple<ScalarT, ScalarT, ScalarT, ScalarT>>
171 double precision = 1
e-14,
int max_steps = 1e6) {
// Fall back to the Euler-transformed series only if the direct series
// fails its convergence check (exception-driven, see catch below).
172 bool euler_transform =
false;
175 }
catch (
const std::exception&
e) {
180 euler_transform =
true;
// Three a/b gradients from the inner implementation, mapped into the
// four-slot return tuple below.
183 std::tuple<ScalarT, ScalarT, ScalarT> grad_tuple_ab;
184 TupleT grad_tuple_rtn = TupleT(0, 0, 0, 0);
185 if (euler_transform) {
// Transformed arguments: 2F1(a2, b1 - a1; b1; z / (z - 1)).
186 ScalarT a1_euler = a2;
187 ScalarT a2_euler = b1 - a1;
188 ScalarT z_euler = z / (z - 1);
// d/dz by product + chain rule through the (1 - z)^(-a2) prefactor and
// the transformed argument z / (z - 1); hyper1 / hyper2 are elided here.
192 std::get<3>(grad_tuple_rtn)
193 = a2 *
pow(1 - z, -1 - a2) * hyper1
194 + (a2 * (b1 - a1) *
pow(1 - z, -a2)
195 * (
inv(z - 1) - z /
square(z - 1)) * hyper2)
198 if (calc_a1 || calc_a2 || calc_b1) {
// d/da1 and d/da2 of the original function both flow through the
// transformed a2_euler = b1 - a1 slot, hence the merged flags.
201 constexpr bool calc_a1_euler = calc_a1 || calc_a2;
203 constexpr bool calc_a2_euler = calc_a1 || calc_a2 || calc_b1;
204 grad_tuple_ab = grad_2F1_impl_ab<calc_a1_euler, calc_a2_euler, calc_b1>(
205 a1_euler, a2_euler, b1, z_euler);
// Shared prefactor (1 - z)^(-a2) from Euler's transformation.
207 auto pre_mult_ab =
inv(
pow(1.0 - z, a2));
// d/da1: only the b1 - a1 slot depends on a1 (with a -1 inner factor).
209 std::get<0>(grad_tuple_rtn) = -pre_mult_ab * std::get<1>(grad_tuple_ab);
// d/da2: prefactor derivative (log1m(z) term) plus the transformed-a1
// slot; hyper_da2 is defined in elided lines.
213 std::get<1>(grad_tuple_rtn)
214 = -pre_mult_ab * hyper_da2 *
log1m(z)
215 + pre_mult_ab * std::get<0>(grad_tuple_ab);
// d/db1: b1 appears both as b1 - a1 and as the transformed b1.
218 std::get<2>(grad_tuple_rtn)
220 * (std::get<1>(grad_tuple_ab) + std::get<2>(grad_tuple_ab));
// Direct branch: d/dz via the contiguous relation
// d/dz 2F1(a1, a2; b1; z) = (a1 a2 / b1) 2F1(a1+1, a2+1; b1+1; z);
// hyper_2f1_dz is defined in elided lines.
226 std::get<3>(grad_tuple_rtn) = (a1 * a2 * hyper_2f1_dz) / b1;
228 if (calc_a1 || calc_a2 || calc_b1) {
// Direct branch: a/b gradients need no remapping.
230 = grad_2F1_impl_ab<calc_a1, calc_a2, calc_b1>(a1, a2, b1, z);
232 std::get<0>(grad_tuple_rtn) = std::get<0>(grad_tuple_ab);
235 std::get<1>(grad_tuple_rtn) = std::get<1>(grad_tuple_ab);
238 std::get<2>(grad_tuple_rtn) = std::get<2>(grad_tuple_ab);
242 return grad_tuple_rtn;
// Overload of grad_2F1 dispatching on the ReturnSameT flag; its body lies
// outside this extract (a companion overload later in the file forwards
// to internal::grad_2F1_impl).  Signature mirrors the other overloads:
// four hypergeometric arguments plus convergence precision and an
// iteration cap.
269template <
bool ReturnSameT,
typename T1,
typename T2,
typename T3,
typename T_z,
271auto grad_2F1(
const T1& a1,
const T2& a2,
const T3& b1,
const T_z& z,
272 double precision = 1
e-14,
int max_steps = 1e6) {
// grad_2F1 overload: computes all four gradients (w.r.t. a1, a2, b1, z)
// by forwarding to the internal implementation with every calc_* flag
// enabled, passing through the caller's precision and max_steps.
302template <
bool ReturnSameT,
typename T1,
typename T2,
typename T3,
typename T_z,
304auto grad_2F1(
const T1& a1,
const T2& a2,
const T3& b1,
const T_z& z,
305 double precision = 1
e-14,
int max_steps = 1e6) {
306 return internal::grad_2F1_impl<true, true, true, true>(a1, a2, b1, z,
307 precision, max_steps);
// Public entry point: computes the gradients of 2F1(a1, a2; b1; z) by
// delegating to the ReturnSameT = false overload with the caller's
// precision (series convergence tolerance) and max_steps (iteration cap).
328template <
typename T1,
typename T2,
typename T3,
typename T_z>
329auto grad_2F1(
const T1& a1,
const T2& a2,
const T3& b1,
const T_z& z,
330 double precision = 1
e-14,
int max_steps = 1e6) {
331 return grad_2F1<false>(a1, a2, b1, z, precision, max_steps);
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
TupleT grad_2F1_impl(const T1 &a1, const T2 &a2, const T3 &b1, const T_z &z, double precision=1e-14, int max_steps=1e6)
Implementation function to calculate the gradients of the hypergeometric function, 2F1, with respect to the arguments a1, a2, b1, and z.
TupleT grad_2F1_impl_ab(const T1 &a1, const T2 &a2, const T3 &b1, const T_z &z, double precision=1e-14, int max_steps=1e6)
Implementation function to calculate the gradients of the hypergeometric function, 2F1, with respect to the arguments a1, a2, and b1.
double value_of_rec(const fvar< T > &v)
Return the value of the specified variable.
fvar< T > abs(const fvar< T > &x)
static constexpr double e()
Return the base of the natural logarithm.
auto sign(const T &x)
Returns signs of the arguments.
auto pow(const T1 &x1, const T2 &x2)
T value_of(const fvar< T > &v)
Return the value of the specified variable.
fvar< T > log(const fvar< T > &x)
return_type_t< Ta1, Ta2, Tb, Tz > hypergeometric_2F1(const Ta1 &a1, const Ta2 &a2, const Tb &b, const Tz &z)
Returns the Gauss hypergeometric function applied to the input arguments: \(_2F_1(a_1, a_2; b; z)\).
static constexpr double NEGATIVE_INFTY
Negative infinity.
void throw_domain_error(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw a domain error with a consistently formatted message.
void check_2F1_converges(const char *function, const T_a1 &a1, const T_a2 &a2, const T_b1 &b1, const T_z &z)
Check if the hypergeometric function (2F1) called with supplied arguments will converge,...
fvar< T > log1m(const fvar< T > &x)
fvar< T > inv(const fvar< T > &x)
fvar< T > fabs(const fvar< T > &x)
fvar< T > square(const fvar< T > &x)
auto grad_2F1(const T1 &a1, const T2 &a2, const T3 &b1, const T_z &z, double precision=1e-14, int max_steps=1e6)
Calculate the gradients of the hypergeometric function (2F1) as the power series stopping when the series converges to within precision, or throwing when the function takes max_steps steps.
fvar< T > exp(const fvar< T > &x)
std::enable_if_t<!Check::value > require_not_t
If condition is false, template is disabled.
std::enable_if_t< Check::value > require_t
If condition is true, template is enabled.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the ...