Automatic Differentiation
 
Loading...
Searching...
No Matches
wiener4_lcdf_defective.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_PROB_WIENER4_LCDF_DEFECTIVE_HPP
2#define STAN_MATH_PRIM_PROB_WIENER4_LCDF_DEFECTIVE_HPP
3
5
6namespace stan {
7namespace math {
8namespace internal {
9
/**
 * Clamp a value from above at the largest finite value of its type.
 *
 * A value strictly below `std::numeric_limits<T_x>::max()` is returned
 * unchanged. Anything else is replaced by that maximum; this covers
 * +infinity and also NaN, because the comparison is false for NaN.
 * There is no lower clamp, so -infinity is returned unchanged.
 *
 * @tparam T_x type of the value; `std::numeric_limits<T_x>` must be usable
 * @param x value to clamp
 * @return `x` if it is below the maximum of `T_x`, otherwise that maximum
 */
template <typename T_x>
inline auto make_finite(const T_x& x) {
  const T_x largest = std::numeric_limits<T_x>::max();
  return (x < largest) ? x : largest;
}
24
33template <typename T_a, typename T_w, typename T_v>
34inline auto log_probability_distribution(const T_a& a, const T_v& v,
35 const T_w& w) {
36 using ret_t = return_type_t<T_a, T_w, T_v>;
37 if (fabs(v) == 0.0) {
38 return ret_t(log1m(w));
39 }
40 auto two_va = 2.0 * v * a;
41 auto minus_two_va_one_minus_w = -two_va * (1.0 - w);
42 // This split prevents abort errors
43 if (minus_two_va_one_minus_w < 0) {
44 const auto exp_arg = exp(minus_two_va_one_minus_w);
45 auto two_vaw = two_va * w;
46 if (two_vaw > minus_two_va_one_minus_w) {
47 return log1m(exp_arg) - log_diff_exp(two_vaw, minus_two_va_one_minus_w);
48 } else if (two_vaw < minus_two_va_one_minus_w) {
49 return log1m(exp_arg) - log_diff_exp(minus_two_va_one_minus_w, two_vaw);
50 } else {
51 return log1m(exp_arg) - NEGATIVE_INFTY;
52 }
53 } else {
54 return log1m_exp(-minus_two_va_one_minus_w) - log1m_exp(two_va);
55 }
56}
57
66template <typename T_a, typename T_w, typename T_v>
67inline auto log_probability_GradAV(const T_a& a, const T_v& v, const T_w& w) {
68 using ret_t = return_type_t<T_a, T_w, T_v>;
69 if (fabs(v) == 0.0) {
70 return ret_t(-w);
71 }
72 auto nearly_one = ret_t(1.0 - 1.1 * 1.0e-5);
73 ret_t log_prob;
74 if (v < 0) {
75 const auto two_av = 2.0 * a * v;
76 const auto two_va_one_minus_w = (two_av * (1.0 - w));
77 const auto two_avw = two_av * w;
78 const auto exp_two_va_one_minus_w = exp(two_va_one_minus_w);
79 const auto exp_two_avw = exp(two_avw);
80 const auto exp_two_av = exp(two_av);
81 if (((exp_two_va_one_minus_w >= nearly_one) || (exp_two_avw >= nearly_one))
82 || (exp_two_av >= nearly_one)) {
83 return ret_t(-w);
84 }
85 log_prob = LOG_TWO + two_va_one_minus_w - log1m(exp_two_va_one_minus_w);
86 auto log_quotient = log1m(exp_two_avw) - log1m(exp_two_av);
87 if (log(w) > log_quotient) {
88 return exp(log_prob) * (w - exp(log_quotient));
89 } else {
90 return -exp(log_prob) * (exp(log_quotient) - w);
91 }
92 } else {
93 const auto minus_two_av = -2.0 * a * v;
94 const auto minus_two_va_one_minus_w = minus_two_av * (1.0 - w);
95 const auto exp_minus_two_va_one_minus_w = exp(minus_two_va_one_minus_w);
96 const auto exp_minus_two_av = exp(minus_two_av);
97 if ((exp_minus_two_va_one_minus_w >= nearly_one)
98 || (exp_minus_two_av >= nearly_one)) {
99 return ret_t(-w);
100 }
101 log_prob = LOG_TWO - log1m(exp_minus_two_va_one_minus_w);
102 ret_t log_quotient;
103 if (minus_two_va_one_minus_w > minus_two_av) {
104 log_quotient = log_diff_exp(minus_two_va_one_minus_w, minus_two_av)
105 - log1m(exp_minus_two_av);
106 } else if (minus_two_va_one_minus_w < minus_two_av) {
107 log_quotient = log_diff_exp(minus_two_av, minus_two_va_one_minus_w)
108 - log1m(exp_minus_two_av);
109 } else {
110 log_quotient = NEGATIVE_INFTY;
111 }
112 if (log(w) > log_quotient) {
113 return -exp(log_prob + log_diff_exp(log(w), log_quotient));
114 } else {
115 return exp(log_prob + log_diff_exp(log_quotient, log(w)));
116 }
117 }
118}
119
/**
 * Log of Mill's ratio for the normal distribution.
 *
 * Computed as `std_normal_lcdf(-x) - std_normal_lpdf(x)`, i.e. the log of
 * the upper-tail probability at `x` minus the log of the density at `x`.
 *
 * @tparam T_x type of the argument
 * @param x point at which the ratio is evaluated
 * @return the log of Mill's ratio at `x`
 */
template <typename T_x>
inline auto logMill(T_x&& x) {
  return std_normal_lcdf(-x) - std_normal_lpdf(x);
}
130
145template <bool NaturalScale = false, typename T_y, typename T_a, typename T_w,
146 typename T_v, typename T_err>
147inline auto wiener4_distribution(const T_y& y, const T_a& a, const T_v& v,
148 const T_w& w, T_err log_err = log(1e-12)) {
150 const auto neg_v = -v;
151 const auto one_m_w = 1.0 - w;
152
153 const auto one_m_w_a_neg_v = one_m_w * a * neg_v;
154
155 const auto K1 = 0.5 * (fabs(neg_v) / a * y - one_m_w);
156 const auto arg = fmax(
157 0.0, fmin(1.0, exp(one_m_w_a_neg_v + square(neg_v) * y / 2.0 + log_err)
158 / 2.0));
159 const auto K2 = (arg == 0) ? INFTY
160 : (arg == 1) ? NEGATIVE_INFTY
161 : -sqrt(y) / 2.0 / a * inv_Phi(arg);
162 const auto K_small_value = ceil(fmax(K1, K1 + K2));
163
164 const auto api = a / pi();
165 const auto v_square = square(neg_v);
166 const auto sqrtL1 = sqrt(1.0 / y) * api;
167 const auto sqrtL2 = sqrt(fmax(
168 1.0, -2.0 / y * square(api)
169 * (log_err + log(pi() * y / 2.0 * (v_square + square(pi() / a)))
170 + one_m_w_a_neg_v + v_square * y / 2.0)));
171 const auto K_large_value = ceil(fmax(sqrtL1, sqrtL2));
172
173 auto lg = LOG_TWO + LOG_PI - 2.0 * log(a);
174
175 if (3 * K_small_value < K_large_value) {
176 const auto sqrt_y = sqrt(y);
177 const auto neg_vy = neg_v * y;
178 ret_t fplus = NEGATIVE_INFTY;
179 ret_t fminus = NEGATIVE_INFTY;
180 for (auto k = K_small_value; k >= 0; --k) {
181 auto rj = a * (2.0 * k + one_m_w);
182 auto dj = std_normal_lpdf(rj / sqrt_y);
183 auto pos1 = dj + logMill((rj - neg_vy) / sqrt_y);
184 auto pos2 = dj + logMill((rj + neg_vy) / sqrt_y);
185 fplus = log_sum_exp(fplus, log_sum_exp(pos1, pos2));
186 rj = a * (2.0 * k + 2.0 - one_m_w);
187 dj = std_normal_lpdf(rj / sqrt_y);
188 auto neg1 = dj + logMill((rj - neg_vy) / sqrt_y);
189 auto neg2 = dj + logMill((rj + neg_vy) / sqrt_y);
190 fminus = log_sum_exp(fminus, log_sum_exp(neg1, neg2));
191 }
192 auto ans = ret_t(0.0);
193 ans = fplus > fminus ? log_diff_exp(fplus, fminus)
194 : log_diff_exp(fminus, fplus);
195 ret_t log_distribution = ans - one_m_w_a_neg_v - square(neg_v) * y / 2;
196 return NaturalScale ? exp(log_distribution) : log_distribution;
197 }
198 const auto log_a = log(a);
199 const auto log_v = log(fabs(neg_v));
200 ret_t fplus = NEGATIVE_INFTY;
201 ret_t fminus = NEGATIVE_INFTY;
202 for (auto k = K_large_value; k > 0; --k) {
203 auto log_k = log(k);
204 auto k_pi = k * pi();
205 auto sin_k_pi_w = sin(k_pi * one_m_w);
206 if (sin_k_pi_w > 0) {
207 fplus = log_sum_exp(
208 fplus, log_k
209 - log_sum_exp(2.0 * log_v, 2.0 * (log_k + LOG_PI - log_a))
210 - 0.5 * square(k_pi / a) * y + log(sin_k_pi_w));
211 } else if (sin_k_pi_w < 0) {
212 fminus = log_sum_exp(
213 fminus, log_k
214 - log_sum_exp(2.0 * log_v, 2.0 * (log_k + LOG_PI - log_a))
215 - 0.5 * square(k_pi / a) * y + log(-sin_k_pi_w));
216 }
217 }
218 ret_t ans = NEGATIVE_INFTY;
219 ans = fplus > fminus ? log_diff_exp(fplus, fminus)
220 : log_diff_exp(fminus, fplus);
221 auto summand_1 = log_probability_distribution(a, neg_v, one_m_w);
222 auto summand_2 = lg + (ans - one_m_w_a_neg_v - 0.5 * square(neg_v) * y);
223 ret_t log_distribution = NEGATIVE_INFTY;
224 if (summand_1 > summand_2) {
225 log_distribution = log_diff_exp(summand_1, summand_2);
226 } else if (summand_1 < summand_2) {
227 log_distribution = log_diff_exp(summand_2, summand_1);
228 }
229 return NaturalScale ? exp(log_distribution) : log_distribution;
230}
231
244template <typename T_y, typename T_a, typename T_v, typename T_w,
245 typename T_cdf, typename T_err>
246inline auto wiener4_cdf_grad_a(const T_y& y, const T_a& a, const T_v& v,
247 const T_w& w, T_cdf&& cdf,
248 T_err log_err = log(1e-12)) {
250 const auto neg_v = -v;
251 const auto one_m_w = 1 - w;
252
253 const auto one_m_w_neg_v = one_m_w * neg_v;
254 const auto one_m_w_a_neg_v = one_m_w_neg_v * a;
255
256 const auto log_y = log(y);
257 const auto log_a = log(a);
258 auto C1 = ret_t(
259 LOG_TWO - log_sum_exp(2.0 * log(fabs(neg_v)), 2.0 * (LOG_PI - log_a)));
260 C1 = log_sum_exp(C1, log_y);
261 const auto factor = one_m_w_a_neg_v + square(neg_v) * y / 2.0 + log_err;
262 const auto alphK = fmin(factor + LOG_PI + log_y + log_a - LOG_TWO - C1, 0.0);
263 const auto K = a / pi() / sqrt(y);
264 const auto K_large_value
265 = ceil(fmax(fmax(sqrt(-2.0 * alphK / y) * a / pi(), K), ret_t(1.0)));
266
267 const auto sqrt_y = sqrt(y);
268 const auto wdash = fmin(one_m_w, w);
269 const auto ueps
270 = fmin(-1.0, 2.0 * (factor + log(a) - log1p(square(neg_v) * y)) + LOG_PI);
271 const auto K_small
272 = (sqrt_y * sqrt(-(ueps - sqrt(-2.0 * ueps - 2.0))) - a * wdash) / a;
273 const auto K_large = sqrt_y / a - wdash;
274 const auto K_small_value = ceil(fmax(fmax(K_small, K_large), ret_t(1.0)));
275
276 // Depending on the Ks use formula for small reaction times or large
277 // reaction times (see Navarro & Fuss, 2009)
278 if (K_large_value > 4 * K_small_value) {
279 const auto neg_vy = neg_v * y;
280 auto ans = ret_t(0.0);
281 auto F_k = ret_t(0.0);
282 for (auto k = K_small_value; k >= 0; --k) {
283 auto r_k_1 = a * (2.0 * k + one_m_w);
284 auto x_1 = r_k_1 - neg_vy;
285 auto x_over_sqrt_y_1 = x_1 / sqrt_y;
286 auto d_k_1 = std_normal_lpdf(r_k_1 / sqrt_y);
287 auto temp_1 = make_finite(exp(d_k_1 + logMill(x_over_sqrt_y_1)));
288 auto exp_d_k_1 = exp(d_k_1);
289 auto ans_1 = -temp_1 * neg_vy - sqrt_y * exp_d_k_1;
290
291 auto x_2 = r_k_1 + neg_vy;
292 auto x_over_sqrt_y_2 = x_2 / sqrt_y;
293 auto temp_2 = make_finite(exp(d_k_1 + logMill(x_over_sqrt_y_2)));
294 auto ans_2 = temp_2 * neg_vy - sqrt_y * exp_d_k_1;
295 auto r_k_2 = a * (2.0 * k + 1.0 + w);
296 auto d_k_2 = std_normal_lpdf(r_k_2 / sqrt_y);
297
298 auto x_3 = r_k_2 - neg_vy;
299 auto x_over_sqrt_y_3 = x_3 / sqrt_y;
300 auto temp_3 = make_finite(exp(d_k_2 + logMill(x_over_sqrt_y_3)));
301 auto exp_d_k_2 = exp(d_k_2);
302 auto ans_3 = -temp_3 * neg_vy - sqrt_y * exp_d_k_2;
303
304 auto x_4 = r_k_2 + neg_vy;
305 auto x_over_sqrt_y_4 = x_4 / sqrt_y;
306 auto temp_4 = make_finite(exp(d_k_2 + logMill(x_over_sqrt_y_4)));
307 auto ans_4 = temp_4 * neg_vy - sqrt_y * exp_d_k_2;
308
309 ans += (ans_1 + ans_2 + ans_3 - ans_4) * (2.0 * k + one_m_w)
310 + ans_3 * one_m_w;
311 }
312 F_k = make_finite(exp(one_m_w_a_neg_v + 0.5 * square(neg_v) * y));
313 const auto summands_small_y = ans / (y * F_k);
314 return -one_m_w_neg_v * cdf + summands_small_y;
315 }
316 ret_t ans = 0.0;
317 for (auto k = K_large_value; k > 0; --k) {
318 const auto kpi = k * pi();
319 const auto kpia2 = square(kpi / a);
320 const auto denom = square(neg_v) + kpia2;
321 auto last = (square(kpi) / pow(a, 3) * (y + 2.0 / denom)) * k / denom
322 * exp(-0.5 * kpia2 * y);
323 ans -= last * sin(kpi * one_m_w);
324 }
325 const ret_t prob
326 = make_finite(exp(log_probability_distribution(a, neg_v, one_m_w)));
327 const auto dav = log_probability_GradAV(a, neg_v, one_m_w);
328 auto dav_neg_v = dav * neg_v;
329 auto prob_deriv = fabs(neg_v) == 0
330 ? ret_t(0.0)
331 : is_inf(dav_neg_v) ? NEGATIVE_INFTY : dav_neg_v * prob;
332 ans = (-2.0 / a - one_m_w_neg_v) * (cdf - prob)
333 + ans * (2.0 * pi() / square(a))
334 * exp(-one_m_w_a_neg_v - 0.5 * square(neg_v) * y);
335 return prob_deriv + ans;
336}
337
350template <typename T_y, typename T_a, typename T_v, typename T_w,
351 typename T_cdf, typename T_err>
352inline auto wiener4_cdf_grad_v(const T_y& y, const T_a& a, const T_v& v,
353 const T_w& w, T_cdf&& cdf,
354 T_err log_err = log(1e-12)) {
356 const auto neg_v = -v;
357 const auto one_m_w = 1.0 - w;
358
359 const auto one_m_w_a = one_m_w * a;
360 const auto one_m_w_a_neg_v = one_m_w_a * neg_v;
361
362 const auto log_y = log(y);
363 const auto factor = one_m_w_a_neg_v + square(neg_v) * y / 2.0 + log_err;
364
365 const auto log_a = log(a);
366 auto K_large_value = ret_t(1.0);
367 if (neg_v != 0) {
368 const auto temp = -make_finite(exp(log_a - LOG_PI - 0.5 * log_y));
369 const auto log_v = log(fabs(neg_v));
370 auto alphK_large = fmin(exp(factor + 0.5 * (7 * LOG_PI + log_y)
371 - 2.5 * LOG_TWO - 3 * log_a - log_v),
372 1.0);
373 K_large_value
374 = fmax(ceil((alphK_large == 0)
375 ? ret_t(INFTY)
376 : (alphK_large == 1) ? ret_t(NEGATIVE_INFTY)
377 : temp * inv_Phi(alphK_large)),
378 ret_t(1.0));
379 }
380
381 const auto sqrt_y = sqrt(y);
382 const auto wdash = fmin(one_m_w, w);
383 auto K_large = fabs(neg_v) / a * y - wdash;
384 const auto alphK_small = factor + 0.5 * (LOG_TWO - log_y + LOG_PI);
385 const auto K_small
386 = (alphK_small < 0) ? sqrt_y * sqrt(-2.0 * alphK_small) / a - wdash : 0;
387 const auto K_small_value = ceil(fmax(fmax(K_small, K_large), ret_t(1.0)));
388 if (K_large_value > 4 * K_small_value) {
389 const auto sqrt_y = sqrt(y);
390 const auto neg_vy = neg_v * y;
391 auto ans = ret_t(0.0);
392 auto F_k = ret_t(0.0);
393 for (auto k = K_small_value; k >= 0; --k) {
394 auto r_k_1 = a * (2.0 * k + one_m_w);
395 auto d_k_1 = std_normal_lpdf(r_k_1 / sqrt_y);
396 auto x_1 = r_k_1 - neg_vy;
397 auto x_over_sqrt_y_1 = x_1 / sqrt_y;
398 auto ans_1 = make_finite(exp(d_k_1 + logMill(x_over_sqrt_y_1)));
399
400 auto x_2 = r_k_1 + neg_vy;
401 auto x_over_sqrt_y_2 = x_2 / sqrt_y;
402 auto ans_2 = make_finite(exp(d_k_1 + logMill(x_over_sqrt_y_2)));
403 auto r_k_2 = a * (2.0 * k + 1.0 + w);
404 auto d_k_2 = std_normal_lpdf(r_k_2 / sqrt_y);
405
406 auto x_3 = r_k_2 - neg_vy;
407 auto x_over_sqrt_y_3 = x_3 / sqrt_y;
408 auto ans_3 = make_finite(exp(d_k_2 + logMill(x_over_sqrt_y_3)));
409
410 auto x_4 = r_k_2 + neg_vy;
411 auto x_over_sqrt_y_4 = x_4 / sqrt_y;
412 auto ans_4 = make_finite(exp(d_k_2 + logMill(x_over_sqrt_y_4)));
413 ans += -ans_1 * x_1 + ans_2 * x_2 + ans_3 * x_3 - ans_4 * x_4;
414 }
415 F_k = make_finite(exp(one_m_w_a_neg_v + 0.5 * square(neg_v) * y));
416 const auto summands_small_y = ans / F_k;
417 return (one_m_w_a + neg_vy) * cdf - summands_small_y;
418 }
419 ret_t ans = 0.0;
420 for (auto k = K_large_value; k > 0; --k) {
421 const auto kpi = k * pi();
422 const auto kpia2 = square(kpi / a);
423 const auto ekpia2y = exp(-0.5 * kpia2 * y);
424 const auto denom = square(neg_v) + kpia2;
425 const auto denomk = k / denom;
426 auto last = denomk * ekpia2y / denom;
427 ans -= last * sin(kpi * one_m_w);
428 }
429 const ret_t prob
430 = make_finite(exp(log_probability_distribution(a, neg_v, one_m_w)));
431 const auto dav = log_probability_GradAV(a, neg_v, one_m_w);
432 auto dav_a = dav * a;
433 auto prob_deriv = is_inf(dav_a) ? ret_t(NEGATIVE_INFTY) : dav_a * prob;
434 ans = (-one_m_w_a + v * y) * (cdf - prob)
435 + ans * 4.0 * v * pi() / square(a)
436 * exp(-one_m_w_a_neg_v - 0.5 * square(neg_v) * y);
437 return -(prob_deriv + ans);
438}
439
452template <typename T_y, typename T_a, typename T_v, typename T_w,
453 typename T_cdf, typename T_err>
454inline auto wiener4_cdf_grad_w(const T_y& y, const T_a& a, const T_v& v,
455 const T_w& w, T_cdf&& cdf,
456 T_err log_err = log(1e-12)) {
458 const auto neg_v = -v;
459 const auto one_m_w = 1 - w;
460
461 const auto one_m_w_a_neg_v = one_m_w * a * neg_v;
462
463 const auto factor = one_m_w_a_neg_v + square(neg_v) * y / 2.0 + log_err;
464
465 const auto log_y = log(y);
466 const auto log_a = log(a);
467 const auto temp = -make_finite(exp(log_a - LOG_PI - 0.5 * log_y));
468 auto alphK_large
469 = fmin(exp(factor + 0.5 * (LOG_PI + log_y) - 1.5 * LOG_TWO - log_a), 1.0);
470 alphK_large = fmax(0.0, alphK_large);
471 const auto K_large_value
472 = fmax(ceil((alphK_large == 0)
473 ? ret_t(INFTY)
474 : (alphK_large == 1) ? ret_t(NEGATIVE_INFTY)
475 : temp * inv_Phi(alphK_large)),
476 ret_t(1.0));
477
478 const auto sqrt_y = sqrt(y);
479 const auto wdash = fmin(one_m_w, w);
480 const auto K_large = fabs(neg_v) / a * y - wdash;
481 const auto lv = log1p(square(neg_v) * y);
482 const auto alphK_small = factor - LOG_TWO - lv;
483 const auto arg = fmin(exp(alphK_small), 1.0);
484 const auto K_small
485 = (arg == 0)
486 ? INFTY
487 : (arg == 1) ? NEGATIVE_INFTY : -sqrt_y / a * inv_Phi(arg) - wdash;
488 const auto K_small_value = ceil(fmax(fmax(K_small, K_large), ret_t(1.0)));
489
490 if (K_large_value > 4 * K_small_value) {
491 const auto sqrt_y = sqrt(y);
492 const auto neg_vy = neg_v * y;
493 auto ans = ret_t(0.0);
494 auto F_k = ret_t(0.0);
495 for (auto k = K_small_value; k >= 0; --k) {
496 auto r_k_1 = a * (2.0 * k + one_m_w);
497 auto d_k_1 = std_normal_lpdf(r_k_1 / sqrt_y);
498 auto x_1 = r_k_1 - neg_vy;
499 auto x_over_sqrt_y_1 = x_1 / sqrt_y;
500 auto temp_1 = make_finite(exp(d_k_1 + logMill(x_over_sqrt_y_1)));
501 auto exp_d_k_1 = exp(d_k_1);
502 auto ans_1 = -temp_1 * neg_vy - sqrt_y * exp_d_k_1;
503
504 auto x_2 = r_k_1 + neg_vy;
505 auto x_over_sqrt_y_2 = x_2 / sqrt_y;
506 auto temp_2 = make_finite(exp(d_k_1 + logMill(x_over_sqrt_y_2)));
507 auto ans_2 = temp_2 * neg_vy - sqrt_y * exp_d_k_1;
508 auto r_k_2 = a * (2.0 * k + 1.0 + w);
509 auto d_k_2 = std_normal_lpdf(r_k_2 / sqrt_y);
510
511 auto x_3 = r_k_2 - neg_vy;
512 auto x_over_sqrt_y_3 = x_3 / sqrt_y;
513 auto temp_3 = make_finite(exp(d_k_2 + logMill(x_over_sqrt_y_3)));
514 auto exp_d_k_2 = exp(d_k_2);
515 auto ans_3 = -temp_3 * neg_vy - sqrt_y * exp_d_k_2;
516
517 auto x_4 = r_k_2 + neg_vy;
518 auto x_over_sqrt_y_4 = x_4 / sqrt_y;
519 auto temp_4 = make_finite(exp(d_k_2 + logMill(x_over_sqrt_y_4)));
520 auto ans_4 = temp_4 * neg_vy - sqrt_y * exp_d_k_2;
521
522 ans += (ans_1 + ans_2 + ans_3 - ans_4) * a;
523 }
524 F_k = make_finite(exp(one_m_w_a_neg_v + 0.5 * square(neg_v) * y));
525 const auto summands_small_y = ans / (y * F_k);
526 return neg_v * a * cdf - summands_small_y;
527 }
528 ret_t ans = 0.0;
529 for (auto k = K_large_value; k > 0; --k) {
530 const auto kpi = k * pi();
531 const auto kpia2 = square(kpi / a);
532 const auto ekpia2y = exp(-0.5 * kpia2 * y);
533 const auto denom = square(neg_v) + kpia2;
534 const auto denomk = k / denom;
535 auto last = kpi;
536 last *= denomk * ekpia2y;
537 ans -= last * cos(kpi * one_m_w);
538 }
539 const auto evaw = exp(-one_m_w_a_neg_v - 0.5 * square(neg_v) * y);
540 const ret_t prob
541 = make_finite(exp(log_probability_distribution(a, neg_v, one_m_w)));
542
543 // Calculate the probability term 'P' on log scale
544 auto dav = ret_t(-1 / w);
545 if (neg_v != 0) {
546 auto nearly_one = ret_t(1.0 - 1.0e-6);
547 const auto sign_v = (neg_v < 0) ? 1 : -1;
548 const auto sign_two_va_one_minus_w = sign_v * (2.0 * neg_v * a * w);
549 const auto exp_arg = exp(sign_two_va_one_minus_w);
550 if (exp_arg >= nearly_one) {
551 dav = -1.0 / w;
552 } else {
553 auto prob = LOG_TWO + log(fabs(neg_v)) + log(a) - log1m(exp_arg);
554 if (neg_v < 0) {
555 prob += sign_two_va_one_minus_w;
556 }
557 dav = -exp(prob);
558 }
559 }
560
561 const auto pia2 = 2.0 * pi() / square(a);
562 auto prob_deriv = dav * prob;
563 ans = v * a * (cdf - prob) + ans * pia2 * evaw;
564 return -(prob_deriv + ans);
565}
566} // namespace internal
567
/**
 * Log-CDF function for the 4-parameter Wiener distribution.
 *
 * For every element the log CDF is evaluated at `y - t0` and summed, and
 * the partial derivatives with respect to y, a, t0, w and v are stored in
 * the partials propagator as "derivative on natural scale divided by the
 * CDF".
 *
 * NOTE(review): this block comes from a documentation listing. Each line
 * still carries its listing number, and listing lines 598-599, 613, 685,
 * 695, 698, 702, 713 and 724 are absent, so `ret_t`, `GradientCalc` and
 * several opening `if (...) {` lines are not visible here. Restore them
 * from the original header before compiling.
 *
 * @tparam propto not referenced in the visible lines (presumably used by
 * the missing condition at listing line 613 - TODO confirm)
 * @param y "Random variable"; must be positive, finite and greater than t0
 * @param a "Boundary separation"; must be positive and finite
 * @param t0 "Nondecision time"; must be non-negative and finite
 * @param w "A-priori bias"; must lie strictly between 0 and 1
 * @param v "Drift rate"; must be finite
 * @param precision_derivatives error tolerance whose log enters the error
 * bound used for the derivative estimates
 * @return result of `ops_partials.build(lcdf)`: the summed log CDF together
 * with the stored partials
 * @throw whatever the `check_*` functions throw for invalid arguments, and
 * a domain error (via `throw_domain_error`) if any `y <= t0`
 */
587template <bool propto = false, typename T_y, typename T_a, typename T_t0,
588 typename T_w, typename T_v>
589inline auto wiener_lcdf_defective(const T_y& y, const T_a& a, const T_t0& t0,
590 const T_w& w, const T_v& v,
591 const double& precision_derivatives = 1e-4) {
592 using T_partials_return = partials_return_t<T_y, T_a, T_t0, T_w, T_v>;
593 using T_y_ref = ref_type_t<T_y>;
594 using T_a_ref = ref_type_t<T_a>;
595 using T_t0_ref = ref_type_t<T_t0>;
596 using T_w_ref = ref_type_t<T_w>;
597 using T_v_ref = ref_type_t<T_v>;
// NOTE(review): listing lines 598-599 are missing here; presumably they
// declare `ret_t` and bring `GradientCalc` into scope - confirm.
600
601 T_y_ref y_ref = y;
602 T_a_ref a_ref = a;
603 T_t0_ref t0_ref = t0;
604 T_w_ref w_ref = w;
605 T_v_ref v_ref = v;
606
// Plain values (no autodiff information) used for the argument checks.
607 auto y_val = to_ref(as_value_column_array_or_scalar(y_ref));
608 auto a_val = to_ref(as_value_column_array_or_scalar(a_ref));
609 auto v_val = to_ref(as_value_column_array_or_scalar(v_ref));
610 auto w_val = to_ref(as_value_column_array_or_scalar(w_ref));
611 auto t0_val = to_ref(as_value_column_array_or_scalar(t0_ref));
612
// NOTE(review): listing line 613 (the `if (...) {` guarding this early
// return) is missing - confirm its condition against the original header.
614 return ret_t(0.0);
615 }
616
617 static constexpr const char* function_name = "wiener4_lcdf";
// Nothing to evaluate when any argument is empty.
618 if (size_zero(y, a, t0, w, v)) {
619 return ret_t(0.0);
620 }
621
622 check_consistent_sizes(function_name, "Random variable", y,
623 "Boundary separation", a, "Drift rate", v,
624 "A-priori bias", w, "Nondecision time", t0);
625 check_positive_finite(function_name, "Random variable", y_val);
626 check_positive_finite(function_name, "Boundary separation", a_val);
627 check_finite(function_name, "Drift rate", v_val);
628 check_less(function_name, "A-priori bias", w_val, 1);
629 check_greater(function_name, "A-priori bias", w_val, 0);
630 check_nonnegative(function_name, "Nondecision time", t0_val);
631 check_finite(function_name, "Nondecision time", t0_val);
632
633 const size_t N = max_size(y, a, t0, w, v);
634
635 scalar_seq_view<T_y_ref> y_vec(y_ref);
636 scalar_seq_view<T_a_ref> a_vec(a_ref);
637 scalar_seq_view<T_t0_ref> t0_vec(t0_ref);
638 scalar_seq_view<T_w_ref> w_vec(w_ref);
639 scalar_seq_view<T_v_ref> v_vec(v_ref);
640 const size_t N_y_t0 = max_size(y, t0);
641
// Every y must be strictly greater than its nondecision time.
642 for (size_t i = 0; i < N_y_t0; ++i) {
643 if (y_vec[i] <= t0_vec[i]) {
644 std::stringstream msg;
645 msg << ", but must be greater than nondecision time = " << t0_vec[i];
646 std::string msg_str(msg.str());
647 throw_domain_error(function_name, "Random variable", y_vec[i], " = ",
648 msg_str.c_str());
649 }
650 }
651
652 // for precs. 1e-6, 1e-12, see Hartmann et al. (2021), Henrich et al. (2023)
653 const auto log_error_cdf = log(1e-6);
654 const auto log_error_derivative = log(precision_derivatives);
655 const T_partials_return log_error_absolute = log(1e-12);
656 T_partials_return lcdf = 0.0;
657 auto ops_partials
658 = make_partials_propagator(y_ref, a_ref, t0_ref, w_ref, v_ref);
659
660 const double LOG_FOUR = std::log(4.0);
661
662 // calculate distribution and partials
663 for (size_t i = 0; i < N; i++) {
664 const auto y_value = y_vec.val(i);
665 const auto a_value = a_vec.val(i);
666 const auto t0_value = t0_vec.val(i);
667 const auto w_value = w_vec.val(i);
668 const auto v_value = v_vec.val(i);
669
// Log CDF at the decision time y - t0.
670 const T_partials_return log_cdf
671 = internal::estimate_with_err_check<4, 0, GradientCalc::OFF,
672 GradientCalc::OFF>(
673 [](auto&&... args) {
674 return internal::wiener4_distribution<GradientCalc::OFF>(args...);
675 },
676 log_error_cdf - LOG_TWO, y_value - t0_value, a_value, v_value,
677 w_value, log_error_absolute);
678
679 const T_partials_return cdf = exp(log_cdf);
680
681 lcdf += log_cdf;
682
// Error bound handed to the derivative estimates below.
683 const auto new_est_err = log_cdf + log_error_derivative - LOG_FOUR;
684
// NOTE(review): listing line 685 (an opening `if (...) {`) is missing here.
686 const auto deriv_y
687 = internal::estimate_with_err_check<5, 0, GradientCalc::OFF,
688 GradientCalc::ON>(
689 [](auto&&... args) {
690 return internal::wiener5_density<GradientCalc::ON>(args...);
691 },
692 new_est_err, y_value - t0_value, a_value, v_value, w_value, 0,
693 log_error_absolute);
694
// NOTE(review): listing line 695 (an opening `if (...) {`) is missing here.
// Partial with respect to y.
696 partials<0>(ops_partials)[i] = deriv_y / cdf;
697 }
// NOTE(review): listing line 698 (an opening `if (...) {`) is missing here.
// Partial with respect to t0: same magnitude as for y, opposite sign.
699 partials<2>(ops_partials)[i] = -deriv_y / cdf;
700 }
701 }
// NOTE(review): listing line 702 (an opening `if (...) {`) is missing here.
// Partial with respect to a.
703 partials<1>(ops_partials)[i]
704 = internal::estimate_with_err_check<5, 0, GradientCalc::OFF,
705 GradientCalc::ON>(
706 [](auto&&... args) {
707 return internal::wiener4_cdf_grad_a(args...);
708 },
709 new_est_err, y_value - t0_value, a_value, v_value, w_value, cdf,
710 log_error_absolute)
711 / cdf;
712 }
// NOTE(review): listing line 713 (an opening `if (...) {`) is missing here.
// Partial with respect to w.
714 partials<3>(ops_partials)[i]
715 = internal::estimate_with_err_check<5, 0, GradientCalc::OFF,
716 GradientCalc::ON>(
717 [](auto&&... args) {
718 return internal::wiener4_cdf_grad_w(args...);
719 },
720 new_est_err, y_value - t0_value, a_value, v_value, w_value, cdf,
721 log_error_absolute)
722 / cdf;
723 }
// NOTE(review): listing line 724 (an opening `if (...) {`) is missing here.
// Partial with respect to v; unlike the others this calls the gradient
// function directly rather than through estimate_with_err_check.
725 partials<4>(ops_partials)[i]
726 = internal::wiener4_cdf_grad_v(y_value - t0_value, a_value, v_value,
727 w_value, cdf, log_error_absolute)
728 / cdf;
729 }
730 } // for loop
731 return ops_partials.build(lcdf);
732}
733} // namespace math
734} // namespace stan
735#endif
scalar_seq_view provides a uniform sequence-like wrapper around either a scalar or a sequence of scal...
return_type_t< T_y_cl > std_normal_lcdf(const T_y_cl &y)
Returns the log of the standard normal cumulative distribution function.
return_type_t< T_y_cl > std_normal_lpdf(const T_y_cl &y)
The log of the normal density for the specified scalar(s) given a location of 0 and a scale of 1.
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
auto wiener4_cdf_grad_w(const T_y &y, const T_a &a, const T_v &v, const T_w &w, T_cdf &&cdf, T_err log_err=log(1e-12))
Calculate derivative of the wiener4 distribution w.r.t. 'w'.
auto wiener4_distribution(const T_y &y, const T_a &a, const T_v &v, const T_w &w, T_err log_err=log(1e-12))
Calculate the wiener4 distribution.
auto wiener4_cdf_grad_v(const T_y &y, const T_a &a, const T_v &v, const T_w &w, T_cdf &&cdf, T_err log_err=log(1e-12))
Calculate derivative of the wiener4 distribution w.r.t. 'v'.
auto estimate_with_err_check(F &&functor, T_err &&log_err, ArgsTupleT &&... args_tuple)
Utility function for estimating a function with a given set of arguments, checking the result against...
auto logMill(T_x &&x)
Log of Mill's ratio for the normal distribution.
auto make_finite(const T_x &x)
Make the expression finite.
auto wiener4_cdf_grad_a(const T_y &y, const T_a &a, const T_v &v, const T_w &w, T_cdf &&cdf, T_err log_err=log(1e-12))
Calculate derivative of the wiener4 distribution w.r.t. 'a'.
auto log_probability_distribution(const T_a &a, const T_v &v, const T_w &w)
Calculate the probability term 'P' on log scale for distribution.
auto log_probability_GradAV(const T_a &a, const T_v &v, const T_w &w)
Calculate the probability term 'P' on log scale for grad_a and grad_v.
fvar< T > sin(const fvar< T > &x)
Definition sin.hpp:16
void check_nonnegative(const char *function, const char *name, const T_y &y)
Check if y is non-negative.
bool size_zero(const T &x)
Returns 1 if input is of length 0, returns 0 otherwise.
Definition size_zero.hpp:19
fvar< T > fmin(const fvar< T > &x1, const fvar< T > &x2)
Definition fmin.hpp:14
static constexpr double e()
Return the base of the natural logarithm.
Definition constants.hpp:20
fvar< T > log1m_exp(const fvar< T > &x)
Return the natural logarithm of one minus the exponentiation of the specified argument.
Definition log1m_exp.hpp:22
auto pow(const T1 &x1, const T2 &x2)
Definition pow.hpp:32
fvar< T > arg(const std::complex< fvar< T > > &z)
Return the phase angle of the complex argument.
Definition arg.hpp:19
fvar< T > log(const fvar< T > &x)
Definition log.hpp:18
static constexpr double NEGATIVE_INFTY
Negative infinity.
Definition constants.hpp:51
void throw_domain_error(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw a domain error with a consistently formatted message.
auto wiener_lcdf_defective(const T_y &y, const T_a &a, const T_t0 &t0, const T_w &w, const T_v &v, const double &precision_derivatives=1e-4)
Log-CDF function for the 4-parameter Wiener distribution.
static constexpr double LOG_TWO
The natural logarithm of 2, $\log 2$.
Definition constants.hpp:80
auto as_value_column_array_or_scalar(T &&a)
Extract the value from an object and for eigen vectors and std::vectors convert to an eigen column ar...
void check_consistent_sizes(const char *)
Trivial no input case, this function is a no-op.
fvar< T > sqrt(const fvar< T > &x)
Definition sqrt.hpp:18
static constexpr double LOG_PI
The natural logarithm of $\pi$, $\log \pi$.
Definition constants.hpp:86
fvar< T > log1p(const fvar< T > &x)
Definition log1p.hpp:12
fvar< T > inv_Phi(const fvar< T > &p)
Definition inv_Phi.hpp:16
void check_finite(const char *function, const char *name, const T_y &y)
Return true if all values in y are finite.
fvar< T > fmax(const fvar< T > &x1, const fvar< T > &x2)
Return the greater of the two specified arguments.
Definition fmax.hpp:23
fvar< T > log_diff_exp(const fvar< T > &x1, const fvar< T > &x2)
fvar< T > cos(const fvar< T > &x)
Definition cos.hpp:16
static constexpr double pi()
Return the value of pi.
Definition constants.hpp:36
int is_inf(const fvar< T > &x)
Returns 1 if the input's value is infinite and 0 otherwise.
Definition is_inf.hpp:21
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:18
void check_less(const char *function, const char *name, const T_y &y, const T_high &high, Idxs... idxs)
Throw an exception if y is not strictly less than high.
fvar< T > ceil(const fvar< T > &x)
Definition ceil.hpp:13
int64_t max_size(const T1 &x1, const Ts &... xs)
Calculate the size of the largest input.
Definition max_size.hpp:20
fvar< T > log1m(const fvar< T > &x)
Definition log1m.hpp:12
void check_greater(const char *function, const char *name, const T_y &y, const T_low &low, Idxs... idxs)
Throw an exception if y is not strictly greater than low.
auto make_partials_propagator(Ops &&... ops)
Construct a partials_propagator.
void check_positive_finite(const char *function, const char *name, const T_y &y)
Check if y is positive and finite.
static constexpr double INFTY
Positive infinity.
Definition constants.hpp:46
fvar< T > fabs(const fvar< T > &x)
Definition fabs.hpp:16
fvar< T > square(const fvar< T > &x)
Definition square.hpp:12
fvar< T > log_sum_exp(const fvar< T > &x1, const fvar< T > &x2)
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:15
typename ref_type_if< true, T >::type ref_type_t
Definition ref_type.hpp:56
typename partials_return_type< Args... >::type partials_return_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Extends std::true_type when instantiated with zero or more template parameters, all of which extend t...
Template metaprogram to calculate whether a summand needs to be included in a proportional (log) prob...