Automatic Differentiation
 
Loading...
Searching...
No Matches
grad_pFq.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_FUN_GRAD_PFQ_HPP
2#define STAN_MATH_PRIM_FUN_GRAD_PFQ_HPP
3
17
18namespace stan {
19namespace math {
20
87template <bool calc_a = true, bool calc_b = true, bool calc_z = true,
88 typename TpFq, typename Ta, typename Tb, typename Tz,
89 typename T_Rtn = return_type_t<Ta, Tb, Tz>,
90 typename Ta_Rtn = promote_scalar_t<T_Rtn, plain_type_t<Ta>>,
91 typename Tb_Rtn = promote_scalar_t<T_Rtn, plain_type_t<Tb>>>
92inline std::tuple<Ta_Rtn, Tb_Rtn, T_Rtn> grad_pFq(const TpFq& pfq_val,
93 const Ta& a, const Tb& b,
94 const Tz& z,
95 double precision = 1e-14,
96 int max_steps = 1e6) {
97 using std::max;
98 using Ta_Array = Eigen::Array<return_type_t<Ta>, -1, 1>;
99 using Tb_Array = Eigen::Array<return_type_t<Tb>, -1, 1>;
100
101 Ta_Array a_array = as_column_vector_or_scalar(a).array();
102 Tb_Array b_array = as_column_vector_or_scalar(b).array();
103
104 std::tuple<Ta_Rtn, Tb_Rtn, T_Rtn> ret_tuple;
105
106 if (calc_a || calc_b) {
107 std::get<0>(ret_tuple).setConstant(a.size(), -pfq_val);
108 std::get<1>(ret_tuple).setConstant(b.size(), pfq_val);
109 Eigen::Array<T_Rtn, -1, 1> a_grad(a.size());
110 Eigen::Array<T_Rtn, -1, 1> b_grad(b.size());
111
112 int k = 0;
113 int base_sign = 1;
114
115 static constexpr double dbl_min = std::numeric_limits<double>::min();
116 Ta_Array a_k = (a_array == 0.0).select(dbl_min, abs(a_array));
117 Tb_Array b_k = (b_array == 0.0).select(dbl_min, abs(b_array));
118 Tz log_z = log(abs(z));
119
120 // Identify the number of iterations to needed for each element to sign
121 // flip from negative to positive - rather than checking at each iteration
122 Eigen::ArrayXi a_pos_k = (a_array < 0.0)
123 .select(-floor(value_of_rec(a_array)), 0)
124 .template cast<int>();
125 int all_a_pos_k = a_pos_k.maxCoeff();
126 Eigen::ArrayXi b_pos_k = (b_array < 0.0)
127 .select(-floor(value_of_rec(b_array)), 0)
128 .template cast<int>();
129 int all_b_pos_k = b_pos_k.maxCoeff();
130 Eigen::ArrayXi a_sign = select(a_pos_k == 0, 1, -1);
131 Eigen::ArrayXi b_sign = select(b_pos_k == 0, 1, -1);
132
133 int z_sign = sign(value_of_rec(z));
134
135 Ta_Array digamma_a = Ta_Array::Ones(a.size());
136 Tb_Array digamma_b = Tb_Array::Ones(b.size());
137
138 T_Rtn curr_log_prec = NEGATIVE_INFTY;
139 T_Rtn log_base = 0;
140 while ((k < 10 || curr_log_prec > log(precision)) && (k <= max_steps)) {
141 curr_log_prec = NEGATIVE_INFTY;
142 if (calc_a) {
143 a_grad = log(abs(digamma_a)) + log_base;
144 std::get<0>(ret_tuple).array()
145 += exp(a_grad) * base_sign * sign(value_of_rec(digamma_a));
146
147 curr_log_prec = max(curr_log_prec, a_grad.maxCoeff());
148 digamma_a += inv(a_k) * a_sign;
149 }
150
151 if (calc_b) {
152 b_grad = log(abs(digamma_b)) + log_base;
153 std::get<1>(ret_tuple).array()
154 -= exp(b_grad) * base_sign * sign(value_of_rec(digamma_b));
155
156 curr_log_prec = max(curr_log_prec, b_grad.maxCoeff());
157 digamma_b += inv(b_k) * b_sign;
158 }
159
160 log_base += (sum(log(a_k)) + log_z) - (sum(log(b_k)) + log1p(k));
161 base_sign *= z_sign * a_sign.prod() * b_sign.prod();
162
163 // Wrap negative value handling in a conditional on iteration number so
164 // branch prediction likely to ignore once positive
165 if (k < all_a_pos_k) {
166 // Avoid log(0) and 1/0 in next iteration by using smallest double
167 // - This is smaller than EPSILON, so the following iteration will
168 // still be 1.0
169 a_k = (a_k == 1.0 && a_sign == -1)
170 .select(dbl_min, (a_k < 1.0 && a_sign == -1)
171 .select(1.0 - a_k, a_k + 1.0 * a_sign));
172 a_sign = select(k == a_pos_k - 1, 1, a_sign);
173 } else {
174 a_k += 1.0;
175
176 if (k == all_a_pos_k) {
177 a_sign.setOnes();
178 }
179 }
180
181 if (k < all_a_pos_k) {
182 b_k = (b_k == 1.0 && b_sign == -1)
183 .select(dbl_min, (b_k < 1.0 && b_sign == -1)
184 .select(1.0 - b_k, b_k + 1.0 * b_sign));
185 b_sign = select(k == b_pos_k - 1, 1, b_sign);
186 } else {
187 b_k += 1.0;
188
189 if (k == all_b_pos_k) {
190 b_sign.setOnes();
191 }
192 }
193
194 k += 1;
195 }
196 }
197 if (calc_z) {
198 std::get<2>(ret_tuple)
199 = hypergeometric_pFq(add(a, 1.0), add(b, 1.0), z) * prod(a) / prod(b);
200 }
201 return ret_tuple;
202}
203
204} // namespace math
205} // namespace stan
206#endif
select_< as_operation_cl_t< T_condition >, as_operation_cl_t< T_then >, as_operation_cl_t< T_else > > select(T_condition &&condition, T_then &&then, T_else &&els)
Selection operation on kernel generator expressions.
Definition select.hpp:148
auto as_column_vector_or_scalar(T &&a)
as_column_vector_or_scalar of a kernel generator expression.
addition_< as_operation_cl_t< T_a >, as_operation_cl_t< T_b > > add(T_a &&a, T_b &&b)
double value_of_rec(const fvar< T > &v)
Return the value of the specified variable.
fvar< T > abs(const fvar< T > &x)
Definition abs.hpp:15
static constexpr double e()
Return the base of the natural logarithm.
Definition constants.hpp:20
value_type_t< T > prod(const T &m)
Calculates product of given kernel generator expression elements.
Definition prod.hpp:21
auto sign(const T &x)
Returns signs of the arguments.
Definition sign.hpp:18
fvar< T > log(const fvar< T > &x)
Definition log.hpp:15
static constexpr double NEGATIVE_INFTY
Negative infinity.
Definition constants.hpp:51
auto max(T1 x, T2 y)
Returns the maximum value of the two specified scalar arguments.
Definition max.hpp:25
fvar< T > log1p(const fvar< T > &x)
Definition log1p.hpp:12
fvar< T > floor(const fvar< T > &x)
Definition floor.hpp:12
FvarT hypergeometric_pFq(const Ta &a, const Tb &b, const Tz &z)
Returns the generalized hypergeometric (pFq) function applied to the input arguments.
auto sum(const std::vector< T > &m)
Return the sum of the entries of the specified standard vector.
Definition sum.hpp:23
std::tuple< Ta_Rtn, Tb_Rtn, T_Rtn > grad_pFq(const TpFq &pfq_val, const Ta &a, const Tb &b, const Tz &z, double precision=1e-14, int max_steps=1e6)
Returns the gradient of the generalized hypergeometric function \({}_pF_q(a_1,\ldots,a_p;\,b_1,\ldots,b_q;\,z)\) with respect to its input arguments.
Definition grad_pFq.hpp:92
fvar< T > inv(const fvar< T > &x)
Definition inv.hpp:12
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:13
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...