1#ifndef STAN_MATH_PRIM_FUN_GRAD_PFQ_HPP
2#define STAN_MATH_PRIM_FUN_GRAD_PFQ_HPP
87template <
bool calc_a =
true,
bool calc_b =
true,
bool calc_z =
true,
88 typename TpFq,
typename Ta,
typename Tb,
typename Tz,
89 typename T_Rtn = return_type_t<Ta, Tb, Tz>,
90 typename Ta_Rtn = promote_scalar_t<T_Rtn, plain_type_t<Ta>>,
91 typename Tb_Rtn = promote_scalar_t<T_Rtn, plain_type_t<Tb>>>
92std::tuple<Ta_Rtn, Tb_Rtn, T_Rtn>
grad_pFq(
const TpFq& pfq_val,
const Ta& a,
93 const Tb& b,
const Tz& z,
94 double precision = 1
e-14,
95 int max_steps = 1e6) {
97 using Ta_Array = Eigen::Array<return_type_t<Ta>, -1, 1>;
98 using Tb_Array = Eigen::Array<return_type_t<Tb>, -1, 1>;
103 std::tuple<Ta_Rtn, Tb_Rtn, T_Rtn> ret_tuple;
105 if (calc_a || calc_b) {
106 std::get<0>(ret_tuple).setConstant(a.size(), -pfq_val);
107 std::get<1>(ret_tuple).setConstant(b.size(), pfq_val);
108 Eigen::Array<T_Rtn, -1, 1> a_grad(a.size());
109 Eigen::Array<T_Rtn, -1, 1> b_grad(b.size());
114 static constexpr double dbl_min = std::numeric_limits<double>::min();
115 Ta_Array a_k = (a_array == 0.0).
select(dbl_min,
abs(a_array));
116 Tb_Array b_k = (b_array == 0.0).
select(dbl_min,
abs(b_array));
121 Eigen::ArrayXi a_pos_k = (a_array < 0.0)
123 .
template cast<int>();
124 int all_a_pos_k = a_pos_k.maxCoeff();
125 Eigen::ArrayXi b_pos_k = (b_array < 0.0)
127 .template cast<int>();
128 int all_b_pos_k = b_pos_k.maxCoeff();
129 Eigen::ArrayXi a_sign =
select(a_pos_k == 0, 1, -1);
130 Eigen::ArrayXi b_sign =
select(b_pos_k == 0, 1, -1);
134 Ta_Array digamma_a = Ta_Array::Ones(a.size());
135 Tb_Array digamma_b = Tb_Array::Ones(b.size());
139 while ((k < 10 || curr_log_prec >
log(precision)) && (k <= max_steps)) {
142 a_grad =
log(
abs(digamma_a)) + log_base;
143 std::get<0>(ret_tuple).array()
146 curr_log_prec =
max(curr_log_prec, a_grad.maxCoeff());
147 digamma_a +=
inv(a_k) * a_sign;
151 b_grad =
log(
abs(digamma_b)) + log_base;
152 std::get<1>(ret_tuple).array()
155 curr_log_prec =
max(curr_log_prec, b_grad.maxCoeff());
156 digamma_b +=
inv(b_k) * b_sign;
160 base_sign *= z_sign * a_sign.prod() * b_sign.prod();
164 if (k < all_a_pos_k) {
168 a_k = (a_k == 1.0 && a_sign == -1)
169 .
select(dbl_min, (a_k < 1.0 && a_sign == -1)
170 .select(1.0 - a_k, a_k + 1.0 * a_sign));
171 a_sign =
select(k == a_pos_k - 1, 1, a_sign);
175 if (k == all_a_pos_k) {
180 if (k < all_a_pos_k) {
181 b_k = (b_k == 1.0 && b_sign == -1)
182 .
select(dbl_min, (b_k < 1.0 && b_sign == -1)
183 .select(1.0 - b_k, b_k + 1.0 * b_sign));
184 b_sign =
select(k == b_pos_k - 1, 1, b_sign);
188 if (k == all_b_pos_k) {
197 std::get<2>(ret_tuple)
select_< as_operation_cl_t< T_condition >, as_operation_cl_t< T_then >, as_operation_cl_t< T_else > > select(T_condition &&condition, T_then &&then, T_else &&els)
Selection operation on kernel generator expressions.
auto as_column_vector_or_scalar(T &&a)
as_column_vector_or_scalar of a kernel generator expression.
addition_< as_operation_cl_t< T_a >, as_operation_cl_t< T_b > > add(T_a &&a, T_b &&b)
double value_of_rec(const fvar< T > &v)
Return the value of the specified variable.
fvar< T > abs(const fvar< T > &x)
static constexpr double e()
Return the base of the natural logarithm.
value_type_t< T > prod(const T &m)
Calculates product of given kernel generator expression elements.
auto sign(const T &x)
Returns signs of the arguments.
fvar< T > log(const fvar< T > &x)
static constexpr double NEGATIVE_INFTY
Negative infinity.
auto max(T1 x, T2 y)
Returns the maximum value of the two specified scalar arguments.
fvar< T > sum(const std::vector< fvar< T > > &m)
Return the sum of the entries of the specified standard vector.
fvar< T > log1p(const fvar< T > &x)
fvar< T > floor(const fvar< T > &x)
FvarT hypergeometric_pFq(const Ta &a, const Tb &b, const Tz &z)
Returns the generalized hypergeometric (pFq) function applied to the input arguments.
std::tuple< Ta_Rtn, Tb_Rtn, T_Rtn > grad_pFq(const TpFq &pfq_val, const Ta &a, const Tb &b, const Tz &z, double precision=1e-14, int max_steps=1e6)
Returns the gradient of the generalized hypergeometric function ${}_pF_q(a_1, \dots, a_p; b_1, \dots, b_q; z)$ with respect to its input arguments $a$, $b$, and $z$.
fvar< T > inv(const fvar< T > &x)
fvar< T > exp(const fvar< T > &x)
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...