Automatic Differentiation
 
Loading...
Searching...
No Matches
wiener5_lpdf.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_PROB_WIENER5_LPDF_HPP
2#define STAN_MATH_PRIM_PROB_WIENER5_LPDF_HPP
3
6
7namespace stan {
8namespace math {
9namespace internal {
10
/**
 * Two-valued switch passed as a template argument throughout this file,
 * e.g. as the GradW parameter of the series helpers, where ON selects the
 * variant used by the gradient with respect to w.
 */
enum GradientCalc {
  OFF = 0,
  ON = 1
};
12
29template <typename T_y, typename T_a, typename T_v, typename T_w, typename T_sv>
30inline auto wiener5_compute_log_error_term(T_y&& y, T_a&& a, T_v&& v, T_w&& w,
31 T_sv&& sv) noexcept {
32 const auto one_m_w = 1.0 - w;
33 const auto neg_v = -v;
34 const auto sv_sqr = square(sv);
35 const auto one_plus_svsqr_y = 1 + sv_sqr * y;
36 const auto two_avw = 2.0 * a * neg_v * one_m_w;
37 const auto two_log_a = 2.0 * log(a);
38 return stan::math::eval(
39 (sv_sqr * square(a * one_m_w) - two_avw - square(neg_v) * y) / 2.0
40 / one_plus_svsqr_y
41 - two_log_a - 0.5 * log(one_plus_svsqr_y));
42}
43
// Number of terms for the small-time series of the wiener5 density or of
// one of its gradients. Returns ceil(max(n_1, n_2)).
//
// Template flags:
//   Density -- true when called for the density (wiener5_density passes
//              GradientCalc::ON); selects factor 2.0 instead of 3.0 for n_1
//              and the fmin(-1.0, ...) form of u_eps.
//   GradW   -- ON for the gradient w.r.t. w (wiener5_grad_w); also selects
//              the fmin(-1.0, ...) form and uses +one_m_w in n_2.
// Arguments:
//   y     -- time argument (callers in this file pass y - t0 down to here)
//   a     -- boundary separation
//   w     -- a-priori bias
//   error -- log error target; enters u_eps as 2 * error
60template <bool Density, GradientCalc GradW, typename T_y, typename T_a,
61 typename T_w, typename T_err>
62inline auto wiener5_n_terms_small_t(T_y&& y, T_a&& a, T_w&& w,
63 T_err error) noexcept {
64 const auto two_error = 2.0 * error;
65 const auto y_asq = y / square(a);
66 const auto two_log_a = 2.0 * log(a);
67 const auto log_y_asq = log(y) - two_log_a;
68 const auto one_m_w = 1.0 - w;
69
 // n_1: term count that does not depend on the error target.
70 constexpr auto n_1_factor = Density ? 2.0 : 3.0;
71 const auto n_1 = (sqrt(n_1_factor * y_asq) + one_m_w) / 2.0;
 // u_eps is capped at -1 (density / GradW) or -3 (other gradients), which
 // keeps -2 * u_eps - 2 non-negative for the sqrt below.
72 auto u_eps = (Density || GradW)
73 ? fmin(-1.0, LOG_TWO + LOG_PI + 2.0 * log_y_asq + two_error)
74 : fmin(-3.0, (log(8.0) - log(27.0) + LOG_PI + 4.0 * log_y_asq
75 + two_error));
76 const auto arg_mult = (Density || GradW) ? 1 : 3;
77 const auto arg = -arg_mult * y_asq * (u_eps - sqrt(-2.0 * u_eps - 2.0));
78
 // n_2: error-driven term count; falls back to n_1 when arg is not positive.
79 const auto n_2 = (arg > 0) ? GradW ? 0.5 * (sqrt(arg) + one_m_w)
80 : 0.5 * (sqrt(arg) - one_m_w)
81 : n_1;
82
83 return ceil(fmax(n_1, n_2));
84}
85
100template <typename T_y, typename T_a, typename T_w, typename T_err>
101inline auto wiener5_density_large_reaction_time_terms(T_y&& y, T_a&& a, T_w&& w,
102 T_err error) noexcept {
103 const auto y_asq = y / square(a);
104 const auto log_y_asq = log(y) - 2.0 * log(a);
105 static constexpr double PI_SQUARED = pi() * pi();
106 auto n_1 = 1.0 / (pi() * sqrt(y_asq));
107 const auto two_log_piy = -2.0 * (LOG_PI + log_y_asq + error);
108 auto n_2
109 = (two_log_piy >= 0) ? sqrt(two_log_piy / (PI_SQUARED * y_asq)) : 0.0;
110 return ceil(fmax(n_1, n_2));
111}
112
// Number of terms for the large-time series of a wiener5 gradient
// (wiener5_gradient_large_reaction_time_terms). Returns ceil(max(n_1, n_2)).
//
//   GradW -- ON for the gradient w.r.t. w (wiener5_grad_w): factor 2.0 for
//            n_1 and the first u_eps_arg form; OFF for the gradients w.r.t.
//            t and a (wiener5_grad_t, wiener5_grad_a).
//   error -- log error target.
// w is not used in the visible body of this function.
128template <GradientCalc GradW, typename T_y, typename T_a, typename T_w,
129 typename T_err>
131 T_w&& w,
132 T_err error) noexcept {
133 const auto y_asq = y / square(a);
134 const auto log_y_asq = log(y) - 2.0 * log(a);
135 static constexpr double PI_SQUARED = pi() * pi();
136 static constexpr auto n_1_factor = GradW ? 2.0 : 3.0;
 // n_1: term count that does not depend on the error target.
137 auto n_1 = sqrt(n_1_factor / y_asq) / pi();
138 const auto two_error = 2.0 * error;
 // NOTE(review): the GradW form adds two_error while the other form adds
 // error (not 2 * error) -- confirm this asymmetry is intended.
139 const auto u_eps_arg
140 = GradW ? log(4.0) - log(9.0) + 2.0 * LOG_PI + 3.0 * log_y_asq + two_error
141 : log(3.0) - log(5.0) + LOG_PI + 2.0 * log_y_asq + error;
 // Capped at -1 so that -2 * u_eps - 2 stays non-negative for the sqrt.
142 const auto u_eps = fmin(-1, u_eps_arg);
143 if constexpr (GradW) {
144 const auto arg = -(u_eps - sqrt(-2.0 * u_eps - 2.0));
145 auto n_2 = (arg > 0) ? sqrt(arg / y_asq) / pi() : n_1;
146 return ceil(fmax(n_1, n_2));
147 } else {
148 const auto arg
149 = -(2.0 / PI_SQUARED / y_asq) * (u_eps - sqrt(-2.0 * u_eps - 2.0));
150 auto n_2 = (arg > 0) ? sqrt(arg) : n_1;
151 return ceil(fmax(n_1, n_2));
152 }
153}
154
// Log-magnitude and sign of the series sum shared by the wiener5 density
// and its gradients. Returns std::make_pair(log|sum|, sign), where sign is
// -1 when the negative part dominates (fplus < fminus) and +1 otherwise.
//
// Series choice: the small-time series is used when
// 2 * n_terms_small_t <= n_terms_large_t (Density) or
// 2 * n_terms_small_t <  n_terms_large_t (gradients); otherwise the
// large-time series with n_terms_large_t terms is used.
// Positive and negative summands are accumulated separately in log space
// (fplus / fminus, of the scalar type ret_t) and combined at the end with
// log_diff_exp. Equal finite accumulators cancel and give -infinity.
//
//   Density -- true for the density; selects mult = 1 in both series.
//   GradW   -- ON for the gradient w.r.t. w; switches the small-time
//              summand factor and uses cos instead of sin for large t.
173template <bool Density, GradientCalc GradW, typename T_y, typename T_a,
174 typename T_w, typename T_nsmall, typename T_nlarge>
175inline auto wiener5_log_sum_exp(T_y&& y, T_a&& a, T_w&& w,
176 T_nsmall&& n_terms_small_t,
177 T_nlarge&& n_terms_large_t) noexcept {
179
180 const auto y_asq = y / square(a);
181 const auto one_m_w = 1.0 - w;
182 const bool small_n_terms_small_t
183 = Density ? (2.0 * n_terms_small_t <= n_terms_large_t)
184 : (2.0 * n_terms_small_t < n_terms_large_t);
 // Exponent scale: 1 / (2 y/a^2) for small t, (y/a^2) / 2 for large t.
185 const auto scaling = small_n_terms_small_t ? inv(2.0 * y_asq) : y_asq / 2.0;
186 ret_t fplus = NEGATIVE_INFTY;
187 ret_t fminus = NEGATIVE_INFTY;
188 int current_sign;
189 if (small_n_terms_small_t) {
 // Small-time series: terms at (1 - w) +/- 2k for k = n_terms_small_t..1,
 // plus the k = 0 term after the loop.
190 constexpr double mult = Density ? 1.0 : 3.0;
191 if constexpr (GradW) {
 // Each summand carries the factor ((1 - w) +/- 2k)^2 - y/a^2; its sign
 // decides fplus vs fminus, and exactly-zero factors are skipped.
192 for (auto k = n_terms_small_t; k >= 1; k--) {
193 const auto w_plus_2_k = one_m_w + 2.0 * k;
194 const auto w_minus_2_k = one_m_w - 2.0 * k;
195 const auto square_w_plus_2_k_minus_offset = square(w_plus_2_k) - y_asq;
196 if (square_w_plus_2_k_minus_offset > 0) {
197 const auto summand_plus = log(square_w_plus_2_k_minus_offset)
198 - square(w_plus_2_k) * scaling;
199 fplus = log_sum_exp(fplus, summand_plus);
200 } else if (square_w_plus_2_k_minus_offset < 0) {
201 const auto summand_plus = log(-square_w_plus_2_k_minus_offset)
202 - square(w_plus_2_k) * scaling;
203 fminus = log_sum_exp(fminus, summand_plus);
204 }
205 const auto square_w_minus_2_k_minus_offset
206 = square(w_minus_2_k) - y_asq;
207 if (square_w_minus_2_k_minus_offset > 0) {
208 const auto summand_minus = log(square_w_minus_2_k_minus_offset)
209 - square(w_minus_2_k) * scaling;
210 fplus = log_sum_exp(fplus, summand_minus);
211 } else if (square_w_minus_2_k_minus_offset < 0) {
212 const auto summand_minus = log(-square_w_minus_2_k_minus_offset)
213 - square(w_minus_2_k) * scaling;
214 fminus = log_sum_exp(fminus, summand_minus);
215 }
216 }
217 const auto square_w_minus_offset = square(one_m_w) - y_asq;
218 if (square_w_minus_offset > 0) {
219 const auto new_val
220 = log(square_w_minus_offset) - square(one_m_w) * scaling;
221 fplus = log_sum_exp(fplus, new_val);
222 } else if (square_w_minus_offset < 0) {
223 const auto new_val
224 = log(-square_w_minus_offset) - square(one_m_w) * scaling;
225 fminus = log_sum_exp(fminus, new_val);
226 }
227 } else {
 // Positive terms use ((1 - w) + 2k)^mult, negative terms use
 // (2k - (1 - w))^mult. fminus is accumulated with an inline
 // log-sum-exp that skips -infinity summands.
228 for (auto k = n_terms_small_t; k >= 1; k--) {
229 const auto w_plus_2_k = one_m_w + 2.0 * k;
230 const auto w_minus_2_k = one_m_w - 2.0 * k;
231 const auto summand_plus
232 = mult * log(w_plus_2_k) - square(w_plus_2_k) * scaling;
233 fplus = log_sum_exp(fplus, summand_plus);
234 const auto summand_minus
235 = mult * log(-w_minus_2_k) - square(w_minus_2_k) * scaling;
236 if (fminus <= NEGATIVE_INFTY) {
237 fminus = summand_minus;
238 } else if (summand_minus <= NEGATIVE_INFTY) {
239 continue;
240 } else if (fminus > summand_minus) {
241 fminus = fminus + log1p_exp(summand_minus - fminus);
242 } else {
243 fminus = summand_minus + log1p_exp(fminus - summand_minus);
244 }
245 }
246 const auto new_val = mult * log(one_m_w) - square(one_m_w) * scaling;
247 fplus = log_sum_exp(fplus, new_val);
248 }
249 } else { // for large t
 // Large-time series: terms k = n_terms_large_t..1 weighted by k^mult and
 // by sin(k*pi*(1 - w)) (cos for GradW); that factor's sign selects
 // fplus or fminus.
250 constexpr double mult = (Density ? 1.0 : (GradW ? 2.0 : 3.0));
251 for (auto k = n_terms_large_t; k >= 1; k--) {
252 const auto pi_k = k * pi();
253 const auto check = (GradW) ? cos(pi_k * one_m_w) : sin(pi_k * one_m_w);
254 if (check > 0) {
255 fplus = log_sum_exp(
256 fplus, mult * log(k) - square(pi_k) * scaling + log(check));
257 } else if ((GradW && check < 0) || !GradW) {
258 fminus = log_sum_exp(
259 fminus, mult * log(k) - square(pi_k) * scaling + log(-check));
260 }
261 }
262 }
 // Combine the two accumulators; the sign follows the larger one.
263 current_sign = (fplus < fminus) ? -1 : 1;
264 if (fplus == NEGATIVE_INFTY) {
265 return std::make_pair(fminus, current_sign);
266 } else if (fminus == NEGATIVE_INFTY) {
267 return std::make_pair(fplus, current_sign);
268 } else if (fplus > fminus) {
269 return std::make_pair(log_diff_exp(fplus, fminus), current_sign);
270 } else if (fplus < fminus) {
271 return std::make_pair(log_diff_exp(fminus, fplus), current_sign);
272 } else {
273 return std::make_pair(ret_t(NEGATIVE_INFTY), current_sign);
274 }
275}
276
298template <bool NaturalScale = false, typename T_y, typename T_a, typename T_w,
299 typename T_v, typename T_sv, typename T_err>
300inline auto wiener5_density(const T_y& y, const T_a& a, const T_v& v,
301 const T_w& w, const T_sv& sv,
302 T_err log_err = log(1e-12)) noexcept {
303 const auto log_error_term = wiener5_compute_log_error_term(y, a, v, w, sv);
304 const auto log_error = (log_err - log_error_term);
305 const auto n_terms_small_t
306 = wiener5_n_terms_small_t<GradientCalc::ON, GradientCalc::OFF>(y, a, w,
307 log_error);
308 const auto n_terms_large_t
309 = wiener5_density_large_reaction_time_terms(y, a, w, log_error);
310
311 auto res = wiener5_log_sum_exp<GradientCalc::ON, GradientCalc::OFF>(
312 y, a, w, n_terms_small_t, n_terms_large_t)
313 .first;
314 if (2 * n_terms_small_t <= n_terms_large_t) {
315 auto log_density = log_error_term - 0.5 * LOG_TWO - LOG_SQRT_PI
316 - 1.5 * (log(y) - 2.0 * log(a)) + res;
317 return NaturalScale ? exp(log_density) : log_density;
318 } else {
319 auto log_density = log_error_term + res + LOG_PI;
320 return NaturalScale ? exp(log_density) : log_density;
321 }
322}
323
// Derivative of the wiener5 log-density with respect to the time argument y.
// wiener_lpdf uses it for y and, negated, for t0. With WrtLog = true the
// result is multiplied by exp(log_density).
//
// ans = analytic part (density_part_one, plus -1.5 / y on the small-time
// branch) +/- sign * exp(series - log_density), where the series comes from
// wiener5_log_sum_exp<OFF, OFF>. The branch test 2 * n_small < n_large is
// the one wiener5_log_sum_exp uses when Density is false.
// Params: y time argument, a boundary separation, v drift rate,
// w a-priori bias, sv inter-trial drift variability, log_err log error
// target.
345template <bool WrtLog = false, typename T_y, typename T_a, typename T_w,
346 typename T_v, typename T_sv, typename T_err>
347inline auto wiener5_grad_t(const T_y& y, const T_a& a, const T_v& v,
348 const T_w& w, const T_sv& sv,
349 T_err log_err = log(1e-12)) noexcept {
350 const auto two_log_a = 2.0 * log(a);
351 const auto log_y_asq = log(y) - two_log_a;
352 const auto log_error_term = wiener5_compute_log_error_term(y, a, v, w, sv);
353 const auto one_m_w = 1.0 - w;
354 const auto neg_v = -v;
355 const auto sv_sqr = square(sv);
356 const auto one_plus_svsqr_y = 1 + sv_sqr * y;
 // Closed-form part contributed by the log prefactor.
357 const auto density_part_one
358 = -0.5
359 * (square(sv_sqr) * (y + square(a * one_m_w))
360 + sv_sqr * (1.0 - (2.0 * a * neg_v * one_m_w)) + square(neg_v))
361 / square(one_plus_svsqr_y);
362 const auto log_error = (log_err - log_error_term) + two_log_a;
363 const auto n_terms_small_t
364 = wiener5_n_terms_small_t<GradientCalc::OFF, GradientCalc::OFF>(
365 y, a, w, log_error);
366 const auto n_terms_large_t
367 = wiener5_gradient_large_reaction_time_terms<GradientCalc::OFF>(
368 y, a, w, log_error);
369 auto wiener_res = wiener5_log_sum_exp<GradientCalc::OFF, GradientCalc::OFF>(
370 y, a, w, n_terms_small_t, n_terms_large_t);
371 auto&& result = wiener_res.first;
372 auto&& newsign = wiener_res.second;
 // The density used for rescaling is computed with its error target shifted
 // by the log of the larger analytic magnitude.
373 const auto error_log_density
374 = log(fmax(fabs(density_part_one - 1.5 / y), fabs(density_part_one)));
375 const auto log_density = wiener5_density<GradientCalc::OFF>(
376 y, a, v, w, sv, log_err - error_log_density);
377 if (2.0 * n_terms_small_t < n_terms_large_t) {
378 auto ans
379 = density_part_one - 1.5 / y
380 + newsign
381 * exp(log_error_term - two_log_a - 1.5 * LOG_TWO - LOG_SQRT_PI
382 - 3.5 * log_y_asq + result - log_density);
383 return WrtLog ? ans * exp(log_density) : ans;
384 } else {
385 auto ans = density_part_one
386 - newsign
387 * exp(log_error_term - two_log_a + 3.0 * LOG_PI - LOG_TWO
388 + result - log_density);
389 return WrtLog ? ans * exp(log_density) : ans;
390 }
391}
392
// Derivative of the wiener5 log-density with respect to the boundary
// separation a. With WrtLog = true the result is multiplied by
// exp(log_density).
//
// ans combines an analytic part (density_part_one + 1 / a on the
// small-time branch, density_part_one - 2 / a on the large-time branch)
// with the series from wiener5_log_sum_exp<OFF, OFF>, rescaled by
// exp(... - log_density). The branch test 2 * n_small < n_large is the one
// wiener5_log_sum_exp uses when Density is false.
// Params: y time argument, a boundary separation, v drift rate,
// w a-priori bias, sv inter-trial drift variability, log_err log error
// target.
414template <bool WrtLog = false, typename T_y, typename T_a, typename T_w,
415 typename T_v, typename T_sv, typename T_err>
416inline auto wiener5_grad_a(const T_y& y, const T_a& a, const T_v& v,
417 const T_w& w, const T_sv& sv,
418 T_err log_err = log(1e-12)) noexcept {
419 const auto two_log_a = 2.0 * log(a);
420 const auto log_error_term = wiener5_compute_log_error_term(y, a, v, w, sv);
421 const auto one_m_w = 1.0 - w;
422 const auto sv_sqr = square(sv);
423 const auto one_plus_svsqr_y = 1.0 + sv_sqr * y;
 // Closed-form part contributed by the log prefactor.
424 const auto density_part_one
425 = (v * one_m_w + sv_sqr * square(one_m_w) * a) / one_plus_svsqr_y;
426 const auto log_error
427 = log_err - log_error_term + 3.0 * log(a) - log(y) - LOG_TWO;
428
429 const auto n_terms_small_t
430 = wiener5_n_terms_small_t<GradientCalc::OFF, GradientCalc::OFF>(
431 y, a, w, log_error);
432 const auto n_terms_large_t
433 = wiener5_gradient_large_reaction_time_terms<GradientCalc::OFF>(
434 y, a, w, log_error);
435 auto wiener_res = wiener5_log_sum_exp<GradientCalc::OFF, GradientCalc::OFF>(
436 y, a, w, n_terms_small_t, n_terms_large_t);
437 auto&& result = wiener_res.first;
438 auto&& newsign = wiener_res.second;
 // Error target for the rescaling density, shifted by the log of the larger
 // of the two branch-specific analytic magnitudes.
439 const auto log_error_log_density = log(
440 fmax(fabs(density_part_one + 1.0 / a), fabs(density_part_one - 2.0 / a)));
441 const auto log_density = wiener5_density<GradientCalc::OFF>(
442 y, a, v, w, sv, log_err - log_error_log_density);
443 if (2.0 * n_terms_small_t < n_terms_large_t) {
444 auto ans = density_part_one + 1.0 / a
445 - newsign
446 * exp(-0.5 * LOG_TWO - LOG_SQRT_PI - 2.5 * log(y)
447 + 2.0 * two_log_a + log_error_term + result
448 - log_density);
449 return WrtLog ? ans * exp(log_density) : ans;
450 } else {
451 auto ans = density_part_one - 2.0 / a
452 + newsign
453 * exp(log(y) + log_error_term - 3.0 * (log(a) - LOG_PI)
454 + result - log_density);
455 return WrtLog ? ans * exp(log_density) : ans;
456 }
457}
458
480template <bool WrtLog = false, typename T_y, typename T_a, typename T_w,
481 typename T_v, typename T_sv, typename T_err>
482inline auto wiener5_grad_v(const T_y& y, const T_a& a, const T_v& v,
483 const T_w& w, const T_sv& sv,
484 T_err log_err = log(1e-12)) noexcept {
485 auto ans = (a * (1 - w) - v * y) / (1.0 + square(sv) * y);
486 if constexpr (WrtLog) {
487 return ans * wiener5_density<true>(y, a, v, w, sv, log_err);
488 } else {
489 return ans;
490 }
491}
492
// Derivative of the wiener5 log-density with respect to the a-priori bias w.
// With WrtLog = true the result is multiplied by exp(log_density).
//
// Uses the GradW = ON variants of the term-count helpers and of
// wiener5_log_sum_exp. ans = -(density_part_one -/+ sign * exp(series -
// (log_density - log_error_term) + ...)), with the small-time branch chosen
// by 2 * n_small < n_large.
// Params: y time argument, a boundary separation, v drift rate,
// w a-priori bias, sv inter-trial drift variability, log_err log error
// target.
514template <bool WrtLog = false, typename T_y, typename T_a, typename T_w,
515 typename T_v, typename T_sv, typename T_err>
516inline auto wiener5_grad_w(const T_y& y, const T_a& a, const T_v& v,
517 const T_w& w, const T_sv& sv,
518 T_err log_err = log(1e-12)) noexcept {
519 const auto two_log_a = 2.0 * log(a);
520 const auto log_y_asq = log(y) - two_log_a;
521 const auto log_error_term = wiener5_compute_log_error_term(y, a, v, w, sv);
522 const auto one_m_w = 1.0 - w;
523 const auto sv_sqr = square(sv);
524 const auto one_plus_svsqr_y = 1.0 + sv_sqr * y;
 // Closed-form part contributed by the log prefactor.
525 const auto density_part_one
526 = (v * a + sv_sqr * square(a) * one_m_w) / one_plus_svsqr_y;
527 const auto log_error = (log_err - log_error_term);
528
529 const auto n_terms_small_t
530 = wiener5_n_terms_small_t<GradientCalc::OFF, GradientCalc::ON>(y, a, w,
531 log_error);
532 const auto n_terms_large_t
533 = wiener5_gradient_large_reaction_time_terms<GradientCalc::ON>(y, a, w,
534 log_error);
535 auto wiener_res = wiener5_log_sum_exp<GradientCalc::OFF, GradientCalc::ON>(
536 y, a, w, n_terms_small_t, n_terms_large_t);
537 auto&& result = wiener_res.first;
538 auto&& newsign = wiener_res.second;
 // NOTE(review): when density_part_one == 0 (e.g. v == 0 and sv == 0),
 // log(fabs(density_part_one)) is -inf and the density below is requested
 // with an infinite error target -- confirm this is acceptable.
539 const auto log_density = wiener5_density<GradientCalc::OFF>(
540 y, a, v, w, sv, log_err - log(fabs(density_part_one)));
541 if (2.0 * n_terms_small_t < n_terms_large_t) {
542 auto ans = -(density_part_one
543 - newsign
544 * exp(result - (log_density - log_error_term)
545 - 2.5 * log_y_asq - 0.5 * LOG_TWO - 0.5 * LOG_PI));
546 return WrtLog ? ans * exp(log_density) : ans;
547 } else {
548 auto ans = -(
549 density_part_one
550 + newsign
551 * exp(result - (log_density - log_error_term) + 2.0 * LOG_PI));
552 return WrtLog ? ans * exp(log_density) : ans;
553 }
554}
555
577template <bool WrtLog = false, typename T_y, typename T_a, typename T_w,
578 typename T_v, typename T_sv, typename T_err>
579inline auto wiener5_grad_sv(const T_y& y, const T_a& a, const T_v& v,
580 const T_w& w, const T_sv& sv,
581 T_err log_err = log(1e-12)) noexcept {
582 const auto one_plus_svsqr_y = 1.0 + square(sv) * y;
583 const auto one_m_w = 1.0 - w;
584 const auto neg_v = -v;
585 const auto t1 = -y / one_plus_svsqr_y;
586 const auto t2 = (square(a * one_m_w) + 2.0 * a * neg_v * one_m_w * y
587 + square(neg_v * y))
588 / square(one_plus_svsqr_y);
589 const auto ans = sv * (t1 + t2);
590 return WrtLog ? ans * wiener5_density<true>(y, a, v, w, sv, log_err) : ans;
591}
592
/**
 * Overwrite a scalar error argument with a new error value.
 *
 * The argument is taken by reference. With a by-value parameter, the
 * assignment would only change a local copy and the caller (e.g. the
 * argument tuple built in estimate_with_err_check) would never see the new
 * error.
 *
 * @tparam NestedIndex unused for plain scalars; present so that both
 *   overloads share the same call syntax
 * @param arg value to overwrite (modified in place)
 * @param err new error value
 */
template <size_t NestedIndex, typename Scalar1, typename Scalar2>
inline void assign_err(Scalar1& arg, Scalar2 err) {
  arg = err;
}

/**
 * Overwrite element NestedIndex of a tuple-valued error argument.
 *
 * @tparam NestedIndex index of the element to overwrite
 * @param args_tuple tuple holding the error value (modified in place)
 * @param err new error value
 */
template <size_t NestedIndex, typename Scalar, typename... TArgs>
inline void assign_err(std::tuple<TArgs...>& args_tuple, Scalar err) {
  std::get<NestedIndex>(args_tuple) = err;
}
623
641template <int ErrIndex, size_t NestedIndex = 0,
642 GradientCalc GradW7 = GradientCalc::OFF, bool LogResult = true,
643 typename F, typename T_err, typename... ArgsTupleT>
644inline auto estimate_with_err_check(F&& functor, T_err&& log_err,
645 ArgsTupleT&&... args_tuple) {
646 auto result = functor(args_tuple...);
647 auto log_fabs_result = LogResult ? log(fabs(result)) : fabs(result);
648 if (log_fabs_result < log_err) {
649 log_fabs_result = is_inf(log_fabs_result) ? 0 : log_fabs_result;
650 auto err_args_tuple = std::make_tuple(args_tuple...);
651 const auto new_error = GradW7 ? log_err + log_fabs_result + LOG_TWO
652 : log_err + log_fabs_result;
653 if constexpr (NestedIndex != -1) {
654 assign_err<NestedIndex>(std::get<ErrIndex>(err_args_tuple), new_error);
655 }
656 result
657 = math::apply([](auto&& func, auto&&... args) { return func(args...); },
658 err_args_tuple, functor);
659 }
660 return result;
661}
662} // namespace internal
663
// wiener_lpdf: log density of the 5-parameter Wiener model (with inter-trial
// drift variability sv). The result is summed over all elements of the
// (possibly vectorized) arguments, and partial derivatives are propagated
// through make_partials_propagator for every autodiff argument.
//
// @tparam propto if include_summand<propto, ...> is false, returns 0
// @param y   "Random variable"; positive, finite and greater than t0
// @param a   "Boundary separation"; positive and finite
// @param t0  "Nondecision time"; non-negative and finite
// @param w   "A-priori bias"; strictly between 0 and 1
// @param v   "Drift rate"; finite
// @param sv  "Inter-trial variability in drift rate"; non-negative, finite
// @param precision_derivatives precision target for the gradients (enters
//   as log(precision_derivatives))
// @return sum of per-element log densities (ops_partials.build(log_density))
// @throw via the check_* helpers and throw_domain_error when an argument is
//   invalid or some y <= t0
//
// NOTE(review): the include_summand early return happens before any
// argument check, so invalid arguments are not reported in that case.
// NOTE(review): T_partials_return and ret_t are used below but are not
// declared in the lines shown here -- presumably aliases at the top of the
// function; confirm.
686template <bool propto = false, typename T_y, typename T_a, typename T_t0,
687 typename T_w, typename T_v, typename T_sv>
688inline auto wiener_lpdf(const T_y& y, const T_a& a, const T_t0& t0,
689 const T_w& w, const T_v& v, const T_sv& sv,
690 const double& precision_derivatives = 1e-4) {
693 using T_y_ref = ref_type_t<T_y>;
694 using T_a_ref = ref_type_t<T_a>;
695 using T_t0_ref = ref_type_t<T_t0>;
696 using T_w_ref = ref_type_t<T_w>;
697 using T_v_ref = ref_type_t<T_v>;
698 using T_sv_ref = ref_type_t<T_sv>;
700
701 T_y_ref y_ref = y;
702 T_a_ref a_ref = a;
703 T_t0_ref t0_ref = t0;
704 T_w_ref w_ref = w;
705 T_v_ref v_ref = v;
706 T_sv_ref sv_ref = sv;
707
 // Value-only views of the arguments, used for the argument checks.
708 auto y_val = to_ref(as_value_column_array_or_scalar(y_ref));
709 auto a_val = to_ref(as_value_column_array_or_scalar(a_ref));
710 auto v_val = to_ref(as_value_column_array_or_scalar(v_ref));
711 auto w_val = to_ref(as_value_column_array_or_scalar(w_ref));
712 auto t0_val = to_ref(as_value_column_array_or_scalar(t0_ref));
713 auto sv_val = to_ref(as_value_column_array_or_scalar(sv_ref));
714
715 if constexpr (!include_summand<propto, T_y, T_a, T_t0, T_w, T_v,
716 T_sv>::value) {
717 return ret_t(0.0);
718 }
719
720 static constexpr const char* function_name = "wiener5_lpdf";
721
722 check_consistent_sizes(function_name, "Random variable", y,
723 "Boundary separation", a, "Drift rate", v,
724 "A-priori bias", w, "Nondecision time", t0,
725 "Inter-trial variability in drift rate", sv);
726 check_positive_finite(function_name, "Random variable", y_val);
727 check_positive_finite(function_name, "Boundary separation", a_val);
728 check_finite(function_name, "Drift rate", v_val);
729 check_less(function_name, "A-priori bias", w_val, 1);
730 check_greater(function_name, "A-priori bias", w_val, 0);
731 check_nonnegative(function_name, "Nondecision time", t0_val);
732 check_finite(function_name, "Nondecision time", t0_val);
733 check_nonnegative(function_name, "Inter-trial variability in drift rate",
734 sv_val);
735 check_finite(function_name, "Inter-trial variability in drift rate", sv_val);
736
737 if (size_zero(y, a, t0, w, v, sv)) {
738 return ret_t(0.0);
739 }
740 const size_t N = max_size(y, a, t0, w, v, sv);
741 if (N == 0) {
742 return ret_t(0.0);
743 }
744
745 scalar_seq_view<T_y_ref> y_vec(y_ref);
746 scalar_seq_view<T_a_ref> a_vec(a_ref);
747 scalar_seq_view<T_t0_ref> t0_vec(t0_ref);
748 scalar_seq_view<T_w_ref> w_vec(w_ref);
749 scalar_seq_view<T_v_ref> v_vec(v_ref);
750 scalar_seq_view<T_sv_ref> sv_vec(sv_ref);
751 const size_t N_y_t0 = max_size(y, t0);
752
 // Every response time must exceed its nondecision time.
753 for (size_t i = 0; i < N_y_t0; ++i) {
754 if (y_vec[i] <= t0_vec[i]) {
755 std::stringstream msg;
756 msg << ", but must be greater than nondecision time = " << t0_vec[i];
757 std::string msg_str(msg.str());
758 throw_domain_error(function_name, "Random variable", y_vec[i], " = ",
759 msg_str.c_str());
760 }
761 }
762
763 // for precs. 1e-6, 1e-12, see Hartmann et al. (2021), Henrich et al. (2023)
764 const auto log_error_density = log(1e-6);
765 const auto log_error_derivative = log(precision_derivatives);
766 const double log_error_absolute_val = log(1e-12);
767 const T_partials_return log_error_absolute = log_error_absolute_val;
768 T_partials_return log_density = 0.0;
769 auto ops_partials
770 = make_partials_propagator(y_ref, a_ref, t0_ref, w_ref, v_ref, sv_ref);
771
772 const double LOG_FOUR = std::log(4.0);
773
774 // calculate density and partials
775 for (size_t i = 0; i < N; i++) {
776 // Calculate 4-parameter model without inter-trial variabilities (if
777 // sv_vec[i] == 0) or 5-parameter model with inter-trial variability in
778 // drift rate (if sv_vec[i] != 0)
779
780 const auto y_value = y_vec.val(i);
781 const auto a_value = a_vec.val(i);
782 const auto t0_value = t0_vec.val(i);
783 const auto w_value = w_vec.val(i);
784 const auto v_value = v_vec.val(i);
785 const auto sv_value = sv_vec.val(i);
 // NOTE(review): LogResult is OFF here, so estimate_with_err_check compares
 // |l_density| (>= 0) with log_error_density - LOG_TWO (< 0); its retry
 // branch therefore cannot trigger for the density -- confirm intended.
786 auto l_density = internal::estimate_with_err_check<5, 0, GradientCalc::OFF,
787 GradientCalc::OFF>(
788 [](auto&&... args) {
789 return internal::wiener5_density<GradientCalc::OFF>(args...);
790 },
791 log_error_density - LOG_TWO, y_value - t0_value, a_value, v_value,
792 w_value, sv_value, log_error_absolute);
793
794 log_density += l_density;
795
 // Error target for the derivatives: l_density + log(precision_derivatives)
 // - log(4).
796 const auto new_est_err = l_density + log_error_derivative - LOG_FOUR;
797
798 // computation of derivatives and precision checks
799 // computation of derivative for t and precision check; the value is used
800 // as partials<0> (y) and, negated, as partials<2> (t0)
801 const auto deriv_y
802 = internal::estimate_with_err_check<5, 0, GradientCalc::OFF,
803 GradientCalc::ON>(
804 [](auto&&... args) {
805 return internal::wiener5_grad_t<GradientCalc::OFF>(args...);
806 },
807 new_est_err, y_value - t0_value, a_value, v_value, w_value,
808 sv_value, log_error_absolute);
809
810 // computation of derivatives and precision checks
811 if constexpr (is_autodiff_v<T_y>) {
812 partials<0>(ops_partials)[i] = deriv_y;
813 }
814 if constexpr (is_autodiff_v<T_a>) {
815 partials<1>(ops_partials)[i]
816 = internal::estimate_with_err_check<5, 0, GradientCalc::OFF,
817 GradientCalc::ON>(
818 [](auto&&... args) {
819 return internal::wiener5_grad_a<GradientCalc::OFF>(args...);
820 },
821 new_est_err, y_value - t0_value, a_value, v_value, w_value,
822 sv_value, log_error_absolute);
823 }
824 if constexpr (is_autodiff_v<T_t0>) {
825 partials<2>(ops_partials)[i] = -deriv_y;
826 }
827 if constexpr (is_autodiff_v<T_w>) {
828 partials<3>(ops_partials)[i]
829 = internal::estimate_with_err_check<5, 0, GradientCalc::OFF,
830 GradientCalc::ON>(
831 [](auto&&... args) {
832 return internal::wiener5_grad_w<GradientCalc::OFF>(args...);
833 },
834 new_est_err, y_value - t0_value, a_value, v_value, w_value,
835 sv_value, log_error_absolute);
836 }
 // The v and sv derivatives are closed-form (no series), so they are
 // evaluated directly without an error check.
837 if constexpr (is_autodiff_v<T_v>) {
838 partials<4>(ops_partials)[i]
839 = internal::wiener5_grad_v<GradientCalc::OFF>(
840 y_value - t0_value, a_value, v_value, w_value, sv_value,
841 log_error_absolute_val);
842 }
843 if constexpr (is_autodiff_v<T_sv>) {
844 partials<5>(ops_partials)[i]
845 = internal::wiener5_grad_sv<GradientCalc::OFF>(
846 y_value - t0_value, a_value, v_value, w_value, sv_value,
847 log_error_absolute_val);
848 }
849 } // end for loop
850 return ops_partials.build(log_density);
851} // end wiener_lpdf
852
853} // namespace math
854} // namespace stan
855#endif
scalar_seq_view provides a uniform sequence-like wrapper around either a scalar or a sequence of scal...
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
auto wiener5_density(const T_y &y, const T_a &a, const T_v &v, const T_w &w, const T_sv &sv, T_err log_err=log(1e-12)) noexcept
Calculate the wiener5 density.
auto wiener5_grad_v(const T_y &y, const T_a &a, const T_v &v, const T_w &w, const T_sv &sv, T_err log_err=log(1e-12)) noexcept
Calculate the derivative of the wiener5 density w.r.t. the drift rate v.
auto wiener5_log_sum_exp(T_y &&y, T_a &&a, T_w &&w, T_nsmall &&n_terms_small_t, T_nlarge &&n_terms_large_t) noexcept
Calculate the 'result' term and its sign for a wiener5 density or gradient.
auto estimate_with_err_check(F &&functor, T_err &&log_err, ArgsTupleT &&... args_tuple)
Utility function for estimating a function with a given set of arguments, checking the result against...
auto wiener5_gradient_large_reaction_time_terms(T_y &&y, T_a &&a, T_w &&w, T_err error) noexcept
Calculate the 'n_terms_large_t' term for a wiener5 gradient.
auto wiener5_density_large_reaction_time_terms(T_y &&y, T_a &&a, T_w &&w, T_err error) noexcept
Calculate the 'n_terms_large_t' term for a wiener5 density.
auto wiener5_grad_t(const T_y &y, const T_a &a, const T_v &v, const T_w &w, const T_sv &sv, T_err log_err=log(1e-12)) noexcept
Calculate the derivative of the wiener5 density w.r.t. the time argument y.
auto wiener5_grad_w(const T_y &y, const T_a &a, const T_v &v, const T_w &w, const T_sv &sv, T_err log_err=log(1e-12)) noexcept
Calculate the derivative of the wiener5 density w.r.t. the a-priori bias w.
void assign_err(Scalar1 arg, Scalar2 err)
Utility function for replacing a value with a specified error value.
auto wiener5_grad_a(const T_y &y, const T_a &a, const T_v &v, const T_w &w, const T_sv &sv, T_err log_err=log(1e-12)) noexcept
Calculate the derivative of the wiener5 density w.r.t. the boundary separation a.
auto wiener5_grad_sv(const T_y &y, const T_a &a, const T_v &v, const T_w &w, const T_sv &sv, T_err log_err=log(1e-12)) noexcept
Calculate the derivative of the wiener5 density w.r.t. the inter-trial drift variability sv.
auto wiener5_n_terms_small_t(T_y &&y, T_a &&a, T_w &&w, T_err error) noexcept
Calculate the 'n_terms_small_t' term for a wiener5 density or gradient.
auto wiener5_compute_log_error_term(T_y &&y, T_a &&a, T_v &&v, T_w &&w, T_sv &&sv) noexcept
Calculate the 'log_error_term' term for a wiener5 density or gradient.
fvar< T > sin(const fvar< T > &x)
Definition sin.hpp:16
void check_nonnegative(const char *function, const char *name, const T_y &y)
Check if y is non-negative.
bool size_zero(const T &x)
Returns true if the input has length 0, false otherwise.
Definition size_zero.hpp:19
fvar< T > fmin(const fvar< T > &x1, const fvar< T > &x2)
Definition fmin.hpp:14
static constexpr double e()
Return the base of the natural logarithm.
Definition constants.hpp:20
fvar< T > arg(const std::complex< fvar< T > > &z)
Return the phase angle of the complex argument.
Definition arg.hpp:19
T eval(T &&arg)
Inputs whose plain_type equals their own type are forwarded unmodified (for Eigen expressions ...
Definition eval.hpp:20
fvar< T > log(const fvar< T > &x)
Definition log.hpp:18
static constexpr double NEGATIVE_INFTY
Negative infinity.
Definition constants.hpp:51
void throw_domain_error(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw a domain error with a consistently formatted message.
static constexpr double LOG_TWO
The natural logarithm of 2, \( \log 2 \).
Definition constants.hpp:80
fvar< T > log1p_exp(const fvar< T > &x)
Definition log1p_exp.hpp:14
auto as_value_column_array_or_scalar(T &&a)
Extract the value from an object and for eigen vectors and std::vectors convert to an eigen column ar...
void check_consistent_sizes(const char *)
Trivial no input case, this function is a no-op.
fvar< T > sqrt(const fvar< T > &x)
Definition sqrt.hpp:18
static constexpr double LOG_SQRT_PI
The natural logarithm of the square root of \( \pi \), \( \log \sqrt{\pi} \).
static constexpr double LOG_PI
The natural logarithm of \( \pi \), \( \log \pi \).
Definition constants.hpp:86
void check_finite(const char *function, const char *name, const T_y &y)
Check that all values in y are finite.
fvar< T > fmax(const fvar< T > &x1, const fvar< T > &x2)
Return the greater of the two specified arguments.
Definition fmax.hpp:23
fvar< T > log_diff_exp(const fvar< T > &x1, const fvar< T > &x2)
fvar< T > cos(const fvar< T > &x)
Definition cos.hpp:16
static constexpr double pi()
Return the value of pi.
Definition constants.hpp:36
auto wiener_lpdf(const T_y &y, const T_a &a, const T_t0 &t0, const T_w &w, const T_v &v, const T_sv &sv, const double &precision_derivatives=1e-4)
Log-density function for the 5-parameter Wiener density.
int is_inf(const fvar< T > &x)
Returns 1 if the input's value is infinite and 0 otherwise.
Definition is_inf.hpp:21
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:18
void check_less(const char *function, const char *name, const T_y &y, const T_high &high, Idxs... idxs)
Throw an exception if y is not strictly less than high.
fvar< T > ceil(const fvar< T > &x)
Definition ceil.hpp:13
int64_t max_size(const T1 &x1, const Ts &... xs)
Calculate the size of the largest input.
Definition max_size.hpp:20
void check_greater(const char *function, const char *name, const T_y &y, const T_low &low, Idxs... idxs)
Throw an exception if y is not strictly greater than low.
fvar< T > inv(const fvar< T > &x)
Definition inv.hpp:13
auto make_partials_propagator(Ops &&... ops)
Construct a partials_propagator.
constexpr decltype(auto) apply(F &&f, Tuple &&t, PreArgs &&... pre_args)
Definition apply.hpp:51
void check_positive_finite(const char *function, const char *name, const T_y &y)
Check if y is positive and finite.
fvar< T > fabs(const fvar< T > &x)
Definition fabs.hpp:16
fvar< T > square(const fvar< T > &x)
Definition square.hpp:12
fvar< T > log_sum_exp(const fvar< T > &x1, const fvar< T > &x2)
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:15
typename ref_type_if< true, T >::type ref_type_t
Definition ref_type.hpp:56
typename partials_return_type< Args... >::type partials_return_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Template metaprogram to calculate whether a summand needs to be included in a proportional (log) prob...