Automatic Differentiation
 
Loading...
Searching...
No Matches
wiener4_lccdf_defective.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_PROB_WIENER4_LCCDF_DEFECTIVE_HPP
2#define STAN_MATH_PRIM_PROB_WIENER4_LCCDF_DEFECTIVE_HPP
3
5
6namespace stan {
7namespace math {
8namespace internal {
9
22template <typename T_a, typename T_w, typename T_v>
23inline auto log_wiener_prob_hit_upper(const T_a& a, const T_v& v,
24 const T_w& w) {
25 using ret_t = return_type_t<T_a, T_w, T_v>;
26 const auto neg_v = -v;
27 const auto one_m_w = 1.0 - w;
28 if (fabs(v) == 0.0) {
29 return ret_t(log(w));
30 }
31 const auto exponent = 2.0 * v * a * w;
32 // This branch is for numeric stability
33 if (exponent < 0) {
34 return ret_t(log1m_exp(exponent)
35 - log_diff_exp(2.0 * neg_v * a * one_m_w, exponent));
36 } else {
37 return ret_t(log1m_exp(-exponent) - log1m_exp(2.0 * neg_v * a));
38 }
39}
40
54template <typename T_a, typename T_w, typename T_v>
55inline auto wiener_prob_derivative_term(const T_a& a, const T_v& v,
56 const T_w& w) noexcept {
57 using ret_t = return_type_t<T_a, T_w, T_v>;
58 const auto exponent_m1 = log1m(1.1 * 1.0e-8);
59 const auto neg_v = -v;
60 const auto one_m_w = 1 - w;
61 int sign_v = neg_v < 0 ? 1 : -1;
62 const auto two_a_neg_v = 2.0 * a * neg_v;
63 const auto exponent_with_1mw = sign_v * two_a_neg_v * w;
64 const auto exponent = sign_v * two_a_neg_v;
65 const auto exponent_with_w = two_a_neg_v * one_m_w;
66 // truncating longer calculations, for numerical stability
67 if (unlikely((exponent_with_1mw >= exponent_m1)
68 || ((exponent_with_w >= exponent_m1) && (sign_v == 1))
69 || (exponent >= exponent_m1) || neg_v == 0)) {
70 return ret_t(-one_m_w);
71 }
72 ret_t ans;
73 ret_t diff_term;
74 const auto log_w = log(one_m_w);
75 if (neg_v < 0) {
76 ans = LOG_TWO + exponent_with_1mw - log1m_exp(exponent_with_1mw);
77 diff_term = log1m_exp(exponent_with_w) - log1m_exp(exponent);
78 } else /* neg_v > 0 */ {
79 ans = LOG_TWO - log1m_exp(exponent_with_1mw);
80 diff_term = log_diff_exp(exponent_with_1mw, exponent) - log1m_exp(exponent);
81 }
82 if (log_w > diff_term) {
83 ans = sign_v * exp(ans + log_diff_exp(log_w, diff_term));
84 } else {
85 ans = -sign_v * exp(ans + log_diff_exp(diff_term, log_w));
86 }
87 if (unlikely(!is_scal_finite(ans))) {
88 return ret_t(NEGATIVE_INFTY);
89 }
90 return ans;
91}
92
104template <typename T_y, typename T_a, typename T_w, typename T_v,
105 typename T_err>
106inline auto wiener4_ccdf(const T_y& y, const T_a& a, const T_v& v, const T_w& w,
107 T_err log_err = log(1e-12)) noexcept {
108 const auto prob_hit_upper = exp(log_wiener_prob_hit_upper(a, v, w));
109 const auto cdf
110 = internal::wiener4_distribution<GradientCalc::ON>(y, a, v, w, log_err);
111 return prob_hit_upper - cdf;
112}
113
126template <typename T_y, typename T_a, typename T_w, typename T_v,
127 typename T_cdf, typename T_err>
128inline auto wiener4_ccdf_grad_a(const T_y& y, const T_a& a, const T_v& v,
129 const T_w& w, T_cdf&& cdf,
130 T_err log_err = log(1e-12)) noexcept {
131 using ret_t = return_type_t<T_a, T_w, T_v>;
132
133 // derivative of the wiener probability w.r.t. 'a' (on log-scale)
134 auto prob_grad_a = -wiener_prob_derivative_term(a, v, w) * v;
135 if (!is_scal_finite(prob_grad_a)) {
136 prob_grad_a = ret_t(NEGATIVE_INFTY);
137 }
138 const auto log_prob_hit_upper = log_wiener_prob_hit_upper(a, v, w);
139 const auto cdf_grad_a = wiener4_cdf_grad_a(y, a, v, w, cdf, log_err);
140 return prob_grad_a * exp(log_prob_hit_upper) - cdf_grad_a;
141}
142
155template <typename T_y, typename T_a, typename T_w, typename T_v,
156 typename T_cdf, typename T_err>
157inline auto wiener4_ccdf_grad_v(const T_y& y, const T_a& a, const T_v& v,
158 const T_w& w, T_cdf&& cdf,
159 T_err log_err = log(1e-12)) noexcept {
160 using ret_t = return_type_t<T_a, T_w, T_v>;
161 const auto log_prob_hit_upper = log_wiener_prob_hit_upper(a, v, w);
162 // derivative of the wiener probability w.r.t. 'v' (on log-scale)
163 auto prob_grad_v = -wiener_prob_derivative_term(a, v, w) * a;
164 if (!is_scal_finite(fabs(prob_grad_v))) {
165 prob_grad_v = ret_t(NEGATIVE_INFTY);
166 }
167
168 const auto cdf_grad_v = wiener4_cdf_grad_v(y, a, v, w, cdf, log_err);
169 return prob_grad_v * exp(log_prob_hit_upper) - cdf_grad_v;
170}
171
184template <typename T_y, typename T_a, typename T_w, typename T_v,
185 typename T_cdf, typename T_err>
186inline auto wiener4_ccdf_grad_w(const T_y& y, const T_a& a, const T_v& v,
187 const T_w& w, T_cdf&& cdf,
188 T_err log_err = log(1e-12)) noexcept {
189 using ret_t = return_type_t<T_a, T_w, T_v>;
190 const auto log_prob_hit_upper = log_wiener_prob_hit_upper(a, v, w);
191 // derivative of the wiener probability w.r.t. 'v' (on log-scale)
192 const auto exponent = -sign(v) * 2.0 * v * a * w;
193 auto prob_grad_w
194 = (v != 0) ? exp(LOG_TWO + log(fabs(v)) + log(a) - log1m_exp(exponent))
195 : ret_t(1 / w);
196 if (v > 0) {
197 prob_grad_w *= exp(exponent);
198 }
199
200 const auto cdf_grad_w = wiener4_cdf_grad_w(y, a, v, w, cdf, log_err);
201 return prob_grad_w * exp(log_prob_hit_upper) - cdf_grad_w;
202}
203
204} // namespace internal
205
225template <bool propto = false, typename T_y, typename T_a, typename T_t0,
226 typename T_w, typename T_v>
227inline auto wiener_lccdf_defective(const T_y& y, const T_a& a, const T_t0& t0,
228 const T_w& w, const T_v& v,
229 const double& precision_derivatives = 1e-4) {
230 using T_partials_return = partials_return_t<T_y, T_a, T_t0, T_w, T_v>;
232 using T_y_ref = ref_type_t<T_y>;
233 using T_a_ref = ref_type_t<T_a>;
234 using T_t0_ref = ref_type_t<T_t0>;
235 using T_w_ref = ref_type_t<T_w>;
236 using T_v_ref = ref_type_t<T_v>;
238
239 T_y_ref y_ref = y;
240 T_a_ref a_ref = a;
241 T_t0_ref t0_ref = t0;
242 T_w_ref w_ref = w;
243 T_v_ref v_ref = v;
244
245 auto y_val = to_ref(as_value_column_array_or_scalar(y_ref));
246 auto a_val = to_ref(as_value_column_array_or_scalar(a_ref));
247 auto v_val = to_ref(as_value_column_array_or_scalar(v_ref));
248 auto w_val = to_ref(as_value_column_array_or_scalar(w_ref));
249 auto t0_val = to_ref(as_value_column_array_or_scalar(t0_ref));
250
251 static constexpr const char* function_name = "wiener4_lccdf";
252 if (size_zero(y, a, t0, w, v)) {
253 return ret_t(0.0);
254 }
255
257 return ret_t(0.0);
258 }
259
260 check_consistent_sizes(function_name, "Random variable", y,
261 "Boundary separation", a, "Drift rate", v,
262 "A-priori bias", w, "Nondecision time", t0);
263 check_positive_finite(function_name, "Random variable", y_val);
264 check_positive_finite(function_name, "Boundary separation", a_val);
265 check_finite(function_name, "Drift rate", v_val);
266 check_less(function_name, "A-priori bias", w_val, 1);
267 check_greater(function_name, "A-priori bias", w_val, 0);
268 check_nonnegative(function_name, "Nondecision time", t0_val);
269 check_finite(function_name, "Nondecision time", t0_val);
270
271 const size_t N = max_size(y, a, t0, w, v);
272
273 scalar_seq_view<T_y_ref> y_vec(y_ref);
274 scalar_seq_view<T_a_ref> a_vec(a_ref);
275 scalar_seq_view<T_t0_ref> t0_vec(t0_ref);
276 scalar_seq_view<T_w_ref> w_vec(w_ref);
277 scalar_seq_view<T_v_ref> v_vec(v_ref);
278 const size_t N_y_t0 = max_size(y, t0);
279
280 for (size_t i = 0; i < N_y_t0; ++i) {
281 if (y_vec[i] <= t0_vec[i]) {
282 std::stringstream msg;
283 msg << ", but must be greater than nondecision time = " << t0_vec[i];
284 std::string msg_str(msg.str());
285 throw_domain_error(function_name, "Random variable", y_vec[i], " = ",
286 msg_str.c_str());
287 }
288 }
289
290 // for precs. 1e-6, 1e-12, see Hartmann et al. (2021), Henrich et al. (2023)
291 const auto log_error_cdf = log(1e-6);
292 const auto log_error_derivative = log(precision_derivatives);
293 const T_partials_return log_error_absolute = log(1e-12);
294 T_partials_return lccdf = 0.0;
295 auto ops_partials
296 = make_partials_propagator(y_ref, a_ref, t0_ref, w_ref, v_ref);
297
298 const double LOG_FOUR = std::log(4.0);
299
300 // calculate distribution and partials
301 for (size_t i = 0; i < N; i++) {
302 const auto y_value = y_vec.val(i);
303 const auto a_value = a_vec.val(i);
304 const auto t0_value = t0_vec.val(i);
305 const auto w_value = w_vec.val(i);
306 const auto v_value = v_vec.val(i);
307
308 const T_partials_return cdf
309 = internal::estimate_with_err_check<4, 0, GradientCalc::OFF,
310 GradientCalc::OFF>(
311 [](auto&&... args) {
312 return internal::wiener4_distribution<GradientCalc::ON>(args...);
313 },
314 log_error_cdf - LOG_TWO, y_value - t0_value, a_value, v_value,
315 w_value, log_error_absolute);
316
317 const auto prob_hit_upper
318 = exp(internal::log_wiener_prob_hit_upper(a_value, v_value, w_value));
319 const auto ccdf = prob_hit_upper - cdf;
320 const auto log_ccdf_single_value = log(ccdf);
321
322 lccdf += log_ccdf_single_value;
323
324 const auto new_est_err
325 = log_ccdf_single_value + log_error_derivative - LOG_FOUR;
326
328 const auto deriv_y = internal::estimate_with_err_check<5, 0>(
329 [](auto&&... args) {
330 return internal::wiener5_density<GradientCalc::ON>(args...);
331 },
332 new_est_err, y_value - t0_value, a_value, v_value, w_value, 0.0,
333 log_error_absolute);
335 partials<0>(ops_partials)[i] = -deriv_y / ccdf;
336 }
338 partials<2>(ops_partials)[i] = deriv_y / ccdf;
339 }
340 }
342 partials<1>(ops_partials)[i]
343 = internal::estimate_with_err_check<5, 0>(
344 [](auto&&... args) {
345 return internal::wiener4_ccdf_grad_a(args...);
346 },
347 new_est_err, y_value - t0_value, a_value, v_value, w_value, cdf,
348 log_error_absolute)
349 / ccdf;
350 }
352 partials<3>(ops_partials)[i]
353 = internal::estimate_with_err_check<5, 0>(
354 [](auto&&... args) {
355 return internal::wiener4_ccdf_grad_w(args...);
356 },
357 new_est_err, y_value - t0_value, a_value, v_value, w_value, cdf,
358 log_error_absolute)
359 / ccdf;
360 }
362 partials<4>(ops_partials)[i]
363 = internal::wiener4_ccdf_grad_v(y_value - t0_value, a_value, v_value,
364 w_value, cdf, log_error_absolute)
365 / ccdf;
366 }
367 } // for loop
368 return ops_partials.build(lccdf);
369}
370} // namespace math
371} // namespace stan
372#endif
scalar_seq_view provides a uniform sequence-like wrapper around either a scalar or a sequence of scal...
#define unlikely(x)
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
auto wiener4_cdf_grad_w(const T_y &y, const T_a &a, const T_v &v, const T_w &w, T_cdf &&cdf, T_err log_err=log(1e-12))
Calculate derivative of the wiener4 distribution w.r.t. 'w'.
auto wiener_prob_derivative_term(const T_a &a, const T_v &v, const T_w &w) noexcept
Calculate parts of the partial derivatives for wiener_prob_grad_a and wiener_prob_grad_v (on log-scale).
auto wiener4_ccdf_grad_w(const T_y &y, const T_a &a, const T_v &v, const T_w &w, T_cdf &&cdf, T_err log_err=log(1e-12)) noexcept
Calculate derivative of the wiener4 ccdf w.r.t. 'w'.
auto wiener4_cdf_grad_v(const T_y &y, const T_a &a, const T_v &v, const T_w &w, T_cdf &&cdf, T_err log_err=log(1e-12))
Calculate derivative of the wiener4 distribution w.r.t. 'v'.
auto estimate_with_err_check(F &&functor, T_err &&log_err, ArgsTupleT &&... args_tuple)
Utility function for estimating a function with a given set of arguments, checking the result against...
auto log_wiener_prob_hit_upper(const T_a &a, const T_v &v, const T_w &w)
Log of probability of reaching the upper bound in diffusion process.
auto wiener4_ccdf(const T_y &y, const T_a &a, const T_v &v, const T_w &w, T_err log_err=log(1e-12)) noexcept
Calculate wiener4 ccdf (natural-scale)
auto wiener4_ccdf_grad_a(const T_y &y, const T_a &a, const T_v &v, const T_w &w, T_cdf &&cdf, T_err log_err=log(1e-12)) noexcept
Calculate derivative of the wiener4 ccdf w.r.t. 'a'.
auto wiener4_cdf_grad_a(const T_y &y, const T_a &a, const T_v &v, const T_w &w, T_cdf &&cdf, T_err log_err=log(1e-12))
Calculate derivative of the wiener4 distribution w.r.t. 'a'.
auto wiener4_ccdf_grad_v(const T_y &y, const T_a &a, const T_v &v, const T_w &w, T_cdf &&cdf, T_err log_err=log(1e-12)) noexcept
Calculate derivative of the wiener4 ccdf w.r.t. 'v'.
void check_nonnegative(const char *function, const char *name, const T_y &y)
Check if y is non-negative.
bool size_zero(const T &x)
Returns 1 if input is of length 0, returns 0 otherwise.
Definition size_zero.hpp:19
static constexpr double e()
Return the base of the natural logarithm.
Definition constants.hpp:20
fvar< T > log1m_exp(const fvar< T > &x)
Return the natural logarithm of one minus the exponentiation of the specified argument.
Definition log1m_exp.hpp:22
auto sign(const T &x)
Returns signs of the arguments.
Definition sign.hpp:18
fvar< T > log(const fvar< T > &x)
Definition log.hpp:18
static constexpr double NEGATIVE_INFTY
Negative infinity.
Definition constants.hpp:51
void throw_domain_error(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw a domain error with a consistently formatted message.
static constexpr double LOG_TWO
The natural logarithm of 2, $\log 2$.
Definition constants.hpp:80
auto as_value_column_array_or_scalar(T &&a)
Extract the value from an object and for eigen vectors and std::vectors convert to an eigen column ar...
bool is_scal_finite(const T_y &y)
Return true if y is finite.
void check_consistent_sizes(const char *)
Trivial no input case, this function is a no-op.
void check_finite(const char *function, const char *name, const T_y &y)
Return true if all values in y are finite.
fvar< T > log_diff_exp(const fvar< T > &x1, const fvar< T > &x2)
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:18
void check_less(const char *function, const char *name, const T_y &y, const T_high &high, Idxs... idxs)
Throw an exception if y is not strictly less than high.
int64_t max_size(const T1 &x1, const Ts &... xs)
Calculate the size of the largest input.
Definition max_size.hpp:20
fvar< T > log1m(const fvar< T > &x)
Definition log1m.hpp:12
auto wiener_lccdf_defective(const T_y &y, const T_a &a, const T_t0 &t0, const T_w &w, const T_v &v, const double &precision_derivatives=1e-4)
Log-CCDF for the 4-parameter Wiener distribution.
void check_greater(const char *function, const char *name, const T_y &y, const T_low &low, Idxs... idxs)
Throw an exception if y is not strictly greater than low.
auto make_partials_propagator(Ops &&... ops)
Construct an partials_propagator.
void check_positive_finite(const char *function, const char *name, const T_y &y)
Check if y is positive and finite.
fvar< T > fabs(const fvar< T > &x)
Definition fabs.hpp:16
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:15
typename ref_type_if< true, T >::type ref_type_t
Definition ref_type.hpp:56
typename partials_return_type< Args... >::type partials_return_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Extends std::true_type when instantiated with zero or more template parameters, all of which extend t...
Template metaprogram to calculate whether a summand needs to be included in a proportional (log) prob...