1#ifndef STAN_MATH_PRIM_FUN_GRAD_PFQ_HPP
2#define STAN_MATH_PRIM_FUN_GRAD_PFQ_HPP
87template <
bool calc_a =
true,
bool calc_b =
true,
bool calc_z =
true,
88 typename TpFq,
typename Ta,
typename Tb,
typename Tz,
89 typename T_Rtn = return_type_t<Ta, Tb, Tz>,
90 typename Ta_Rtn = promote_scalar_t<T_Rtn, plain_type_t<Ta>>,
91 typename Tb_Rtn = promote_scalar_t<T_Rtn, plain_type_t<Tb>>>
92inline std::tuple<Ta_Rtn, Tb_Rtn, T_Rtn>
grad_pFq(
const TpFq& pfq_val,
93 const Ta& a,
const Tb& b,
95 double precision = 1
e-14,
96 int max_steps = 1e6) {
98 using Ta_Array = Eigen::Array<return_type_t<Ta>, -1, 1>;
99 using Tb_Array = Eigen::Array<return_type_t<Tb>, -1, 1>;
104 std::tuple<Ta_Rtn, Tb_Rtn, T_Rtn> ret_tuple;
106 if (calc_a || calc_b) {
107 std::get<0>(ret_tuple).setConstant(a.size(), -pfq_val);
108 std::get<1>(ret_tuple).setConstant(b.size(), pfq_val);
109 Eigen::Array<T_Rtn, -1, 1> a_grad(a.size());
110 Eigen::Array<T_Rtn, -1, 1> b_grad(b.size());
115 static constexpr double dbl_min = std::numeric_limits<double>::min();
116 Ta_Array a_k = (a_array == 0.0).
select(dbl_min,
abs(a_array));
117 Tb_Array b_k = (b_array == 0.0).
select(dbl_min,
abs(b_array));
122 Eigen::ArrayXi a_pos_k = (a_array < 0.0)
124 .
template cast<int>();
125 int all_a_pos_k = a_pos_k.maxCoeff();
126 Eigen::ArrayXi b_pos_k = (b_array < 0.0)
128 .template cast<int>();
129 int all_b_pos_k = b_pos_k.maxCoeff();
130 Eigen::ArrayXi a_sign =
select(a_pos_k == 0, 1, -1);
131 Eigen::ArrayXi b_sign =
select(b_pos_k == 0, 1, -1);
135 Ta_Array digamma_a = Ta_Array::Ones(a.size());
136 Tb_Array digamma_b = Tb_Array::Ones(b.size());
140 while ((k < 10 || curr_log_prec >
log(precision)) && (k <= max_steps)) {
143 a_grad =
log(
abs(digamma_a)) + log_base;
144 std::get<0>(ret_tuple).array()
147 curr_log_prec =
max(curr_log_prec, a_grad.maxCoeff());
148 digamma_a +=
inv(a_k) * a_sign;
152 b_grad =
log(
abs(digamma_b)) + log_base;
153 std::get<1>(ret_tuple).array()
156 curr_log_prec =
max(curr_log_prec, b_grad.maxCoeff());
157 digamma_b +=
inv(b_k) * b_sign;
161 base_sign *= z_sign * a_sign.prod() * b_sign.prod();
165 if (k < all_a_pos_k) {
169 a_k = (a_k == 1.0 && a_sign == -1)
170 .
select(dbl_min, (a_k < 1.0 && a_sign == -1)
171 .select(1.0 - a_k, a_k + 1.0 * a_sign));
172 a_sign =
select(k == a_pos_k - 1, 1, a_sign);
176 if (k == all_a_pos_k) {
181 if (k < all_a_pos_k) {
182 b_k = (b_k == 1.0 && b_sign == -1)
183 .
select(dbl_min, (b_k < 1.0 && b_sign == -1)
184 .select(1.0 - b_k, b_k + 1.0 * b_sign));
185 b_sign =
select(k == b_pos_k - 1, 1, b_sign);
189 if (k == all_b_pos_k) {
198 std::get<2>(ret_tuple)
select_< as_operation_cl_t< T_condition >, as_operation_cl_t< T_then >, as_operation_cl_t< T_else > > select(T_condition &&condition, T_then &&then, T_else &&els)
Selection operation on kernel generator expressions.
auto as_column_vector_or_scalar(T &&a)
as_column_vector_or_scalar of a kernel generator expression.
addition_< as_operation_cl_t< T_a >, as_operation_cl_t< T_b > > add(T_a &&a, T_b &&b)
double value_of_rec(const fvar< T > &v)
Return the value of the specified variable.
fvar< T > abs(const fvar< T > &x)
static constexpr double e()
Return the base of the natural logarithm.
value_type_t< T > prod(const T &m)
Calculates product of given kernel generator expression elements.
auto sign(const T &x)
Returns signs of the arguments.
fvar< T > log(const fvar< T > &x)
static constexpr double NEGATIVE_INFTY
Negative infinity.
auto max(T1 x, T2 y)
Returns the maximum value of the two specified scalar arguments.
fvar< T > log1p(const fvar< T > &x)
fvar< T > floor(const fvar< T > &x)
FvarT hypergeometric_pFq(const Ta &a, const Tb &b, const Tz &z)
Returns the generalized hypergeometric (pFq) function applied to the input arguments.
auto sum(const std::vector< T > &m)
Return the sum of the entries of the specified standard vector.
std::tuple< Ta_Rtn, Tb_Rtn, T_Rtn > grad_pFq(const TpFq &pfq_val, const Ta &a, const Tb &b, const Tz &z, double precision=1e-14, int max_steps=1e6)
Returns the gradient of the generalized hypergeometric function with respect to its input arguments: \f$ {}_pF_q(a_1,\dots,a_p;\,b_1,\dots,b_q;\,z) \f$.
fvar< T > inv(const fvar< T > &x)
fvar< T > exp(const fvar< T > &x)
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...