grad_pFq.hpp
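This header defines grad_pFq, which returns the gradients of the generalized hypergeometric function pFq(a; b; z) with respect to the parameter vectors a and b and the scalar argument z. As a brief orientation (the file's own Doxygen comment is elided from this listing), the series and the gradients that the code below accumulates can be written in standard form as

\[
{}_pF_q(a; b; z) = \sum_{k=0}^{\infty}
    \frac{\prod_{i=1}^{p} (a_i)_k}{\prod_{j=1}^{q} (b_j)_k} \frac{z^k}{k!},
\]
\[
\frac{\partial}{\partial a_i}\, {}_pF_q
  = \sum_{k=0}^{\infty} \left[\psi(a_i + k) - \psi(a_i)\right]
    \frac{\prod_{l} (a_l)_k}{\prod_{j} (b_j)_k} \frac{z^k}{k!},
\qquad
\frac{\partial}{\partial b_j}\, {}_pF_q
  = -\sum_{k=0}^{\infty} \left[\psi(b_j + k) - \psi(b_j)\right]
     \frac{\prod_{i} (a_i)_k}{\prod_{l} (b_l)_k} \frac{z^k}{k!},
\]
\[
\frac{\partial}{\partial z}\, {}_pF_q
  = \frac{\prod_{i} a_i}{\prod_{j} b_j} \, {}_pF_q(a + 1; b + 1; z),
\]

where (x)_k is the rising factorial and \psi is the digamma function. The gradients with respect to a and b are accumulated term by term in log space inside the loop below, while the gradient with respect to z uses the closed-form expression directly in the calc_z branch.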
#ifndef STAN_MATH_PRIM_FUN_GRAD_PFQ_HPP
#define STAN_MATH_PRIM_FUN_GRAD_PFQ_HPP

// (include directives elided in this listing)

namespace stan {
namespace math {

// (Doxygen documentation comment elided in this listing)
template <bool calc_a = true, bool calc_b = true, bool calc_z = true,
          typename TpFq, typename Ta, typename Tb, typename Tz,
          typename T_Rtn = return_type_t<Ta, Tb, Tz>,
          typename Ta_Rtn = promote_scalar_t<T_Rtn, plain_type_t<Ta>>,
          typename Tb_Rtn = promote_scalar_t<T_Rtn, plain_type_t<Tb>>>
std::tuple<Ta_Rtn, Tb_Rtn, T_Rtn> grad_pFq(const TpFq& pfq_val, const Ta& a,
                                            const Tb& b, const Tz& z,
                                            double precision = 1e-14,
                                            int max_steps = 1e6) {
  using std::max;
  using Ta_Array = Eigen::Array<return_type_t<Ta>, -1, 1>;
  using Tb_Array = Eigen::Array<return_type_t<Tb>, -1, 1>;

  Ta_Array a_array = as_column_vector_or_scalar(a).array();
  Tb_Array b_array = as_column_vector_or_scalar(b).array();

  std::tuple<Ta_Rtn, Tb_Rtn, T_Rtn> ret_tuple;

  if (calc_a || calc_b) {
    std::get<0>(ret_tuple).setConstant(a.size(), -pfq_val);
    std::get<1>(ret_tuple).setConstant(b.size(), pfq_val);
    Eigen::Array<T_Rtn, -1, 1> a_grad(a.size());
    Eigen::Array<T_Rtn, -1, 1> b_grad(b.size());

    int k = 0;
    int base_sign = 1;

    static constexpr double dbl_min = std::numeric_limits<double>::min();
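    // Work with |a_j + k| and |b_j + k|, replacing exact zeros with the
    // smallest positive double so that log() and inv() below stay finite.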
    Ta_Array a_k = (a_array == 0.0).select(dbl_min, abs(a_array));
    Tb_Array b_k = (b_array == 0.0).select(dbl_min, abs(b_array));
    Tz log_z = log(abs(z));

    // Identify the number of iterations needed for each element to sign-flip
    // from negative to positive, rather than checking at each iteration
    Eigen::ArrayXi a_pos_k = (a_array < 0.0)
                                 .select(-floor(value_of_rec(a_array)), 0)
                                 .template cast<int>();
    int all_a_pos_k = a_pos_k.maxCoeff();
    Eigen::ArrayXi b_pos_k = (b_array < 0.0)
                                 .select(-floor(value_of_rec(b_array)), 0)
                                 .template cast<int>();
    int all_b_pos_k = b_pos_k.maxCoeff();
    Eigen::ArrayXi a_sign = select(a_pos_k == 0, 1, -1);
    Eigen::ArrayXi b_sign = select(b_pos_k == 0, 1, -1);
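    // For example, a_j = -2.5 gives a_pos_k[j] = 3: the factors (a_j + k) are
    // negative for k = 0, 1, 2 and first become positive at k = 3, so
    // a_sign[j] is -1 for those first three iterations and +1 afterwards.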

    int z_sign = sign(value_of_rec(z));

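    // digamma_a / digamma_b accumulate 1 + psi(a_j + k) - psi(a_j) term by
    // term via the recurrence psi(x + 1) = psi(x) + 1 / x; the extra 1 sums
    // to pFq over the whole series, which is what the -/+ pfq_val
    // initialization of the gradients above cancels out.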
    Ta_Array digamma_a = Ta_Array::Ones(a.size());
    Tb_Array digamma_b = Tb_Array::Ones(b.size());

    T_Rtn curr_log_prec = NEGATIVE_INFTY;
    T_Rtn log_base = 0;
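    // Each iteration of the loop handles one series term
    //   T_k = prod((a)_k) * z^k / (prod((b)_k) * k!);
    // log_base holds log|T_k| (starting at 0 for T_0 = 1) and base_sign its
    // sign, so the gradient contributions can be accumulated in log space.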
    while ((k < 10 || curr_log_prec > log(precision)) && (k <= max_steps)) {
      curr_log_prec = NEGATIVE_INFTY;
      if (calc_a) {
        a_grad = log(abs(digamma_a)) + log_base;
        std::get<0>(ret_tuple).array()
            += exp(a_grad) * base_sign * sign(value_of_rec(digamma_a));

        curr_log_prec = max(curr_log_prec, a_grad.maxCoeff());
        digamma_a += inv(a_k) * a_sign;
      }

      if (calc_b) {
        b_grad = log(abs(digamma_b)) + log_base;
        std::get<1>(ret_tuple).array()
            -= exp(b_grad) * base_sign * sign(value_of_rec(digamma_b));

        curr_log_prec = max(curr_log_prec, b_grad.maxCoeff());
        digamma_b += inv(b_k) * b_sign;
      }

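      // Advance to the log of the next series term using the ratio
      //   T_{k+1} / T_k = prod(a + k) * z / (prod(b + k) * (k + 1)),
      // tracking the overall sign separately in base_sign.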
      log_base += (sum(log(a_k)) + log_z) - (sum(log(b_k)) + log1p(k));
      base_sign *= z_sign * a_sign.prod() * b_sign.prod();

      // Wrap negative-value handling in a conditional on the iteration number
      // so that branch prediction will likely skip it once all values are
      // positive
      if (k < all_a_pos_k) {
        // Avoid log(0) and 1/0 in the next iteration by using the smallest
        // positive double - this is smaller than EPSILON, so adding 1.0 in
        // the following iteration still yields 1.0
        a_k = (a_k == 1.0 && a_sign == -1)
                  .select(dbl_min, (a_k < 1.0 && a_sign == -1)
                                       .select(1.0 - a_k, a_k + 1.0 * a_sign));
        a_sign = select(k == a_pos_k - 1, 1, a_sign);
      } else {
        a_k += 1.0;

        if (k == all_a_pos_k) {
          a_sign.setOnes();
        }
      }

      if (k < all_b_pos_k) {
        b_k = (b_k == 1.0 && b_sign == -1)
                  .select(dbl_min, (b_k < 1.0 && b_sign == -1)
                                       .select(1.0 - b_k, b_k + 1.0 * b_sign));
        b_sign = select(k == b_pos_k - 1, 1, b_sign);
      } else {
        b_k += 1.0;

        if (k == all_b_pos_k) {
          b_sign.setOnes();
        }
      }

      k += 1;
    }
  }
  if (calc_z) {
    std::get<2>(ret_tuple)
        = hypergeometric_pFq(add(a, 1.0), add(b, 1.0), z) * prod(a) / prod(b);
  }
  return ret_tuple;
}

}  // namespace math
}  // namespace stan
#endif
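For orientation, a minimal usage sketch follows. It assumes the header is available through the usual top-level <stan/math.hpp> include and that the function value is computed separately with hypergeometric_pFq and passed in as pfq_val; the parameter values are arbitrary illustrative numbers rather than anything taken from this file.

#include <stan/math.hpp>

#include <iostream>
#include <tuple>

int main() {
  Eigen::VectorXd a(2);
  a << 4.0, 2.0;
  Eigen::VectorXd b(2);
  b << 6.0, 3.0;
  double z = 4.0;

  // grad_pFq expects the function value to be precomputed and passed in.
  double pfq_val = stan::math::hypergeometric_pFq(a, b, z);

  // Returns {gradient wrt a (vector), gradient wrt b (vector),
  //          gradient wrt z (scalar)}.
  auto grads = stan::math::grad_pFq(pfq_val, a, b, z);

  std::cout << "d/da: " << std::get<0>(grads).transpose() << "\n"
            << "d/db: " << std::get<1>(grads).transpose() << "\n"
            << "d/dz: " << std::get<2>(grads) << "\n";
  return 0;
}

The three leading boolean template parameters can be used to skip gradients that are not needed; for example, grad_pFq<false, false, true>(pfq_val, a, b, z) computes only the gradient with respect to z.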