Automatic Differentiation
 
Loading...
Searching...
No Matches
wiener5_lpdf.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_PROB_WIENER5_LPDF_HPP
2#define STAN_MATH_PRIM_PROB_WIENER5_LPDF_HPP
3
5
6namespace stan {
7namespace math {
8namespace internal {
9
/**
 * Two-valued switch passed as a compile-time template argument to the
 * wiener5 helpers. Callers hand it both to GradientCalc template parameters
 * and to bool template parameters, where OFF converts to false and ON to
 * true.
 */
enum GradientCalc {
  OFF = 0,  ///< switch disabled
  ON = 1    ///< switch enabled
};
11
28template <typename T_y, typename T_a, typename T_v, typename T_w, typename T_sv>
29inline auto wiener5_compute_error_term(T_y&& y, T_a&& a, T_v&& v_value,
30 T_w&& w_value, T_sv&& sv) noexcept {
31 const auto w = 1.0 - w_value;
32 const auto v = -v_value;
33 const auto sv_sqr = square(sv);
34 const auto one_plus_svsqr_y = 1 + sv_sqr * y;
35 const auto two_avw = 2.0 * a * v * w;
36 const auto two_log_a = 2.0 * log(a);
37 return stan::math::eval((sv_sqr * square(a * w) - two_avw - square(v) * y)
38 / 2.0 / one_plus_svsqr_y
39 - two_log_a - 0.5 * log(one_plus_svsqr_y));
40}
41
58template <bool Density, GradientCalc GradW, typename T_y, typename T_a,
59 typename T_w_value, typename T_err>
60inline auto wiener5_n_terms_small_t(T_y&& y, T_a&& a, T_w_value&& w_value,
61 T_err&& error) noexcept {
62 const auto two_error = 2.0 * error;
63 const auto y_asq = y / square(a);
64 const auto two_log_a = 2 * log(a);
65 const auto log_y_asq = log(y) - two_log_a;
66 const auto w = 1.0 - w_value;
67
68 const auto n_1_factor = Density ? 2 : 3;
69 const auto n_1 = (sqrt(n_1_factor * y_asq) + w) / 2.0;
70 auto u_eps = (Density || GradW)
71 ? fmin(-1.0, LOG_TWO + LOG_PI + 2.0 * log_y_asq + two_error)
72 : fmin(-3.0, (log(8.0) - log(27.0) + LOG_PI + 4.0 * log_y_asq
73 + two_error));
74 const auto arg_mult = (Density || GradW) ? 1 : 3;
75 const auto arg = -arg_mult * y_asq * (u_eps - sqrt(-2.0 * u_eps - 2.0));
76
77 const auto n_2
78 = (arg > 0) ? GradW ? 0.5 * (sqrt(arg) + w) : 0.5 * (sqrt(arg) - w) : n_1;
79
80 return ceil(fmax(n_1, n_2));
81}
82
97template <typename T_y, typename T_a, typename T_w_value, typename T_err>
98inline auto wiener5_density_large_reaction_time_terms(T_y&& y, T_a&& a,
99 T_w_value&& w_value,
100 T_err&& error) noexcept {
101 const auto y_asq = y / square(a);
102 const auto log_y_asq = log(y) - 2 * log(a);
103 static constexpr double PI_SQUARED = pi() * pi();
104 auto n_1 = 1.0 / (pi() * sqrt(y_asq));
105 const auto two_log_piy = -2.0 * (LOG_PI + log_y_asq + error);
106 auto n_2
107 = (two_log_piy >= 0) ? sqrt(two_log_piy / (PI_SQUARED * y_asq)) : 0.0;
108 return ceil(fmax(n_1, n_2));
109}
110
/**
 * Number of terms needed in the large-time series of a wiener5 gradient.
 *
 * In this file it is instantiated with GradW = OFF by wiener5_grad_t and
 * wiener5_grad_a, and with GradW = ON by wiener5_grad_w. The result is
 * ceil(max(n_1, n_2)):
 *   - n_1 = sqrt(c / (y / a^2)) / pi, with c = 2 when GradW is ON and c = 3
 *     otherwise;
 *   - n_2 is derived from the log error tolerance `error`.
 *
 * NOTE(review): doxygen line 128 is missing from this rendering. It is the
 * line carrying the function name and the first parameters. Per the index
 * entry, the full signature is
 *   wiener5_gradient_large_reaction_time_terms(T_y&& y, T_a&& a,
 *       T_w_value&& w_value, T_err&& error) noexcept
 * `w_value` is not referenced in the visible body.
 */
126template <GradientCalc GradW, typename T_y, typename T_a, typename T_w_value,
127 typename T_err>
129 T_w_value&& w_value,
130 T_err&& error) noexcept {
131 const auto y_asq = y / square(a);
132 const auto log_y_asq = log(y) - 2 * log(a);
133 static constexpr double PI_SQUARED = pi() * pi();
134 const auto n_1_factor = GradW ? 2 : 3;
135 auto n_1 = sqrt(n_1_factor / y_asq) / pi();
136 const auto two_error = 2.0 * error;
// NOTE(review): the GradW branch adds 2 * error while the other branch adds
// error once; confirm this asymmetry against the reference bound.
137 const auto u_eps_arg
138 = GradW ? log(4.0) - log(9.0) + 2.0 * LOG_PI + 3.0 * log_y_asq + two_error
139 : log(3.0) - log(5.0) + LOG_PI + 2.0 * log_y_asq + error;
// Clamp at -1 so that -2 * u_eps - 2 >= 0 inside the sqrt below.
140 const auto u_eps = fmin(-1, u_eps_arg);
141 const auto arg_mult = GradW ? 1 : (2.0 / PI_SQUARED / y_asq);
142 const auto arg = -arg_mult * (u_eps - sqrt(-2.0 * u_eps - 2.0));
// When arg is not positive, fall back to the lower bound n_1.
143 auto n_2 = GradW ? ((arg > 0) ? sqrt(arg / y_asq) / pi() : n_1)
144 : ((arg > 0) ? sqrt(arg) : n_1);
145 return ceil(fmax(n_1, n_2));
146}
147
/**
 * Log of the absolute value of the truncated wiener5 series, and its sign.
 *
 * Representation used:
 *   - small-time, when 2 * n_terms_small_t is at most (Density) or strictly
 *     less than (gradients) n_terms_large_t;
 *   - large-time, otherwise.
 *
 * Positive and negative terms are accumulated separately on the log scale,
 * in fplus and fminus. They are combined at the end with log_diff_exp.
 *
 * @tparam Density true for the density series, false for gradient series
 * @tparam GradW ON for the series of the derivative w.r.t. w, OFF otherwise
 * @param y time argument
 * @param a boundary separation
 * @param w_value a-priori bias (the series uses w = 1 - w_value)
 * @param n_terms_small_t number of terms for the small-time series
 * @param n_terms_large_t number of terms for the large-time series
 * @return std::pair of (log |sum|, sign). The sign is -1 when
 *   fplus < fminus and 1 otherwise.
 *
 * NOTE(review): doxygen line 177 is missing from this rendering; it
 * presumably declares the `ret_t` alias used below.
 */
166template <bool Density, GradientCalc GradW, typename T_y, typename T_a,
167 typename T_w, typename T_nsmall, typename T_nlarge>
168inline auto wiener5_log_sum_exp(T_y&& y, T_a&& a, T_w&& w_value,
169 T_nsmall&& n_terms_small_t,
170 T_nlarge&& n_terms_large_t) noexcept {
171 const auto y_asq = y / square(a);
172 const auto w = 1.0 - w_value;
173 const bool small_n_terms_small_t
174 = Density ? (2 * n_terms_small_t <= n_terms_large_t)
175 : (2 * n_terms_small_t < n_terms_large_t);
// Exponent scale: 1 / (2 * y / a^2) for small t, (y / a^2) / 2 for large t.
176 const auto scaling = small_n_terms_small_t ? inv(2.0 * y_asq) : y_asq / 2.0;
178 ret_t fplus = NEGATIVE_INFTY;
179 ret_t fminus = NEGATIVE_INFTY;
180 int current_sign;
181 if (small_n_terms_small_t) {
// mult is the power applied to the (w +/- 2k) factor: 1 for the density,
// 3 for gradient series.
182 constexpr double mult = Density ? 1.0 : 3.0;
183 if (GradW) {
// Each term carries the factor (w +/- 2k)^2 - y/a^2, whose sign varies.
// Each term is routed to fplus or fminus by sign; zero terms are skipped.
184 for (auto k = n_terms_small_t; k >= 1; k--) {
185 const auto w_plus_2_k = w + 2.0 * k;
186 const auto w_minus_2_k = w - 2.0 * k;
187 const auto square_w_plus_2_k_minus_offset = square(w_plus_2_k) - y_asq;
188 if (square_w_plus_2_k_minus_offset > 0) {
189 const auto summand_plus = log(square_w_plus_2_k_minus_offset)
190 - square(w_plus_2_k) * scaling;
191 fplus = log_sum_exp(fplus, summand_plus);
192 } else if (square_w_plus_2_k_minus_offset < 0) {
193 const auto summand_plus = log(-square_w_plus_2_k_minus_offset)
194 - square(w_plus_2_k) * scaling;
195 fminus = log_sum_exp(fminus, summand_plus);
196 }
197 const auto square_w_minus_2_k_minus_offset
198 = square(w_minus_2_k) - y_asq;
199 if (square_w_minus_2_k_minus_offset > 0) {
200 const auto summand_minus = log(square_w_minus_2_k_minus_offset)
201 - square(w_minus_2_k) * scaling;
202 fplus = log_sum_exp(fplus, summand_minus);
203 } else if (square_w_minus_2_k_minus_offset < 0) {
204 const auto summand_minus = log(-square_w_minus_2_k_minus_offset)
205 - square(w_minus_2_k) * scaling;
206 fminus = log_sum_exp(fminus, summand_minus);
207 }
208 }
// The k = 0 term.
209 const auto square_w_minus_offset = square(w) - y_asq;
210 if (square_w_minus_offset > 0) {
211 const auto new_val = log(square_w_minus_offset) - square(w) * scaling;
212 fplus = log_sum_exp(fplus, new_val);
213 } else if (square_w_minus_offset < 0) {
214 const auto new_val = log(-square_w_minus_offset) - square(w) * scaling;
215 fminus = log_sum_exp(fminus, new_val);
216 }
217 } else {
// For k >= 1, w + 2k > 0 and w - 2k < 0 (w lies in (0, 1), as wiener_lpdf
// enforces). So the latter terms are negative and enter fminus through
// log(-(w - 2k)).
218 for (auto k = n_terms_small_t; k >= 1; k--) {
219 const auto w_plus_2_k = w + 2.0 * k;
220 const auto w_minus_2_k = w - 2.0 * k;
221 const auto summand_plus
222 = mult * log(w_plus_2_k) - square(w_plus_2_k) * scaling;
223 fplus = log_sum_exp(fplus, summand_plus);
224 const auto summand_minus
225 = mult * log(-w_minus_2_k) - square(w_minus_2_k) * scaling;
// Inline log-sum-exp update of fminus that skips -inf terms.
226 if (fminus <= NEGATIVE_INFTY) {
227 fminus = summand_minus;
228 } else if (summand_minus <= NEGATIVE_INFTY) {
229 continue;
230 } else if (fminus > summand_minus) {
231 fminus = fminus + log1p_exp(summand_minus - fminus);
232 } else {
233 fminus = summand_minus + log1p_exp(fminus - summand_minus);
234 }
235 }
// The k = 0 term.
236 const auto new_val = mult * log(w) - square(w) * scaling;
237 fplus = log_sum_exp(fplus, new_val);
238 }
239 } else { // for large t
// mult: 1 for the density, 2 for the w-gradient, 3 for other gradients.
240 constexpr double mult = (Density ? 1 : (GradW ? 2 : 3));
241 for (auto k = n_terms_large_t; k >= 1; k--) {
242 const auto pi_k = k * pi();
// Each term's sign comes from cos (GradW) or sin (otherwise). For !GradW a
// zero value also goes to fminus, where log(0) = -inf contributes nothing.
243 const auto check = (GradW) ? cos(pi_k * w) : sin(pi_k * w);
244 if (check > 0) {
245 fplus = log_sum_exp(
246 fplus, mult * log(k) - square(pi_k) * scaling + log(check));
247 } else if ((GradW && check < 0) || !GradW) {
248 fminus = log_sum_exp(
249 fminus, mult * log(k) - square(pi_k) * scaling + log(-check));
250 }
251 }
252 }
// Combine the two parts:
//   - if only one side is finite, return it;
//   - if both sides are equal, the sum is zero, so return -inf.
253 current_sign = (fplus < fminus) ? -1 : 1;
254 if (fplus == NEGATIVE_INFTY) {
255 return std::make_pair(fminus, current_sign);
256 } else if (fminus == NEGATIVE_INFTY) {
257 return std::make_pair(fplus, current_sign);
258 } else if (fplus > fminus) {
259 return std::make_pair(log_diff_exp(fplus, fminus), current_sign);
260 } else if (fplus < fminus) {
261 return std::make_pair(log_diff_exp(fminus, fplus), current_sign);
262 } else {
263 return std::make_pair(ret_t(NEGATIVE_INFTY), current_sign);
264 }
265}
266
287template <bool NaturalScale = false, typename T_y, typename T_a, typename T_w,
288 typename T_v, typename T_sv, typename T_err>
289inline auto wiener5_density(const T_y& y, const T_a& a, const T_v& v_value,
290 const T_w& w_value, const T_sv& sv,
291 T_err&& err = log(1e-12)) noexcept {
292 const auto error_term
293 = wiener5_compute_error_term(y, a, v_value, w_value, sv);
294 const auto error = (err - error_term);
295 const auto n_terms_small_t
296 = wiener5_n_terms_small_t<GradientCalc::ON, GradientCalc::OFF>(
297 y, a, w_value, error);
298 const auto n_terms_large_t
299 = wiener5_density_large_reaction_time_terms(y, a, w_value, error);
300
301 auto res = wiener5_log_sum_exp<GradientCalc::ON, GradientCalc::OFF>(
302 y, a, w_value, n_terms_small_t, n_terms_large_t)
303 .first;
304 if (2 * n_terms_small_t <= n_terms_large_t) {
305 auto log_density = error_term - 0.5 * LOG_TWO - LOG_SQRT_PI
306 - 1.5 * (log(y) - 2 * log(a)) + res;
307 return NaturalScale ? exp(log_density) : log_density;
308 } else {
309 auto log_density = error_term + res + LOG_PI;
310 return NaturalScale ? exp(log_density) : log_density;
311 }
312}
313
334template <bool WrtLog = false, typename T_y, typename T_a, typename T_w,
335 typename T_v, typename T_sv, typename T_err>
336inline auto wiener5_grad_t(const T_y& y, const T_a& a, const T_v& v_value,
337 const T_w& w_value, const T_sv& sv,
338 T_err&& err = log(1e-12)) noexcept {
339 const auto two_log_a = 2 * log(a);
340 const auto log_y_asq = log(y) - two_log_a;
341 const auto error_term
342 = wiener5_compute_error_term(y, a, v_value, w_value, sv);
343 const auto w = 1.0 - w_value;
344 const auto v = -v_value;
345 const auto sv_sqr = square(sv);
346 const auto one_plus_svsqr_y = 1 + sv_sqr * y;
347 const auto density_part_one
348 = -0.5
349 * (square(sv_sqr) * (y + square(a * w))
350 + sv_sqr * (1 - (2.0 * a * v * w)) + square(v))
351 / square(one_plus_svsqr_y);
352 const auto error = (err - error_term) + two_log_a;
353 const auto n_terms_small_t
354 = wiener5_n_terms_small_t<GradientCalc::OFF, GradientCalc::OFF>(
355 y, a, w_value, error);
356 const auto n_terms_large_t
357 = wiener5_gradient_large_reaction_time_terms<GradientCalc::OFF>(
358 y, a, w_value, error);
359 auto wiener_res = wiener5_log_sum_exp<GradientCalc::OFF, GradientCalc::OFF>(
360 y, a, w_value, n_terms_small_t, n_terms_large_t);
361 auto&& result = wiener_res.first;
362 auto&& newsign = wiener_res.second;
363 const auto error_log_density
364 = log(fmax(fabs(density_part_one - 1.5 / y), fabs(density_part_one)));
365 const auto log_density = wiener5_density<GradientCalc::OFF>(
366 y, a, v_value, w_value, sv, err - error_log_density);
367 if (2 * n_terms_small_t < n_terms_large_t) {
368 auto ans = density_part_one - 1.5 / y
369 + newsign
370 * exp(error_term - two_log_a - 1.5 * LOG_TWO - LOG_SQRT_PI
371 - 3.5 * log_y_asq + result - log_density);
372 return WrtLog ? ans * exp(log_density) : ans;
373 } else {
374 auto ans = density_part_one
375 - newsign
376 * exp(error_term - two_log_a + 3.0 * LOG_PI - LOG_TWO
377 + result - log_density);
378 return WrtLog ? ans * exp(log_density) : ans;
379 }
380}
381
402template <bool WrtLog = false, typename T_y, typename T_a, typename T_w,
403 typename T_v, typename T_sv, typename T_err>
404inline auto wiener5_grad_a(const T_y& y, const T_a& a, const T_v& v_value,
405 const T_w& w_value, const T_sv& sv,
406 T_err&& err = log(1e-12)) noexcept {
407 const auto two_log_a = 2 * log(a);
408 const auto error_term
409 = wiener5_compute_error_term(y, a, v_value, w_value, sv);
410 const auto w = 1.0 - w_value;
411 const auto v = -v_value;
412 const auto sv_sqr = square(sv);
413 const auto one_plus_svsqr_y = 1 + sv_sqr * y;
414 const auto density_part_one
415 = (-v * w + sv_sqr * square(w) * a) / one_plus_svsqr_y;
416 const auto error = err - error_term + 3 * log(a) - log(y) - LOG_TWO;
417
418 const auto n_terms_small_t
419 = wiener5_n_terms_small_t<GradientCalc::OFF, GradientCalc::OFF>(
420 y, a, w_value, error);
421 const auto n_terms_large_t
422 = wiener5_gradient_large_reaction_time_terms<GradientCalc::OFF>(
423 y, a, w_value, error);
424 auto wiener_res = wiener5_log_sum_exp<GradientCalc::OFF, GradientCalc::OFF>(
425 y, a, w_value, n_terms_small_t, n_terms_large_t);
426 auto&& result = wiener_res.first;
427 auto&& newsign = wiener_res.second;
428 const auto error_log_density = log(
429 fmax(fabs(density_part_one + 1.0 / a), fabs(density_part_one - 2.0 / a)));
430 const auto log_density = wiener5_density<GradientCalc::OFF>(
431 y, a, v_value, w_value, sv, err - error_log_density);
432 if (2 * n_terms_small_t < n_terms_large_t) {
433 auto ans
434 = density_part_one + 1.0 / a
435 - newsign
436 * exp(-0.5 * LOG_TWO - LOG_SQRT_PI - 2.5 * log(y)
437 + 2.0 * two_log_a + error_term + result - log_density);
438 return WrtLog ? ans * exp(log_density) : ans;
439 } else {
440 auto ans = density_part_one - 2.0 / a
441 + newsign
442 * exp(log(y) + error_term - 3 * (log(a) - LOG_PI) + result
443 - log_density);
444 return WrtLog ? ans * exp(log_density) : ans;
445 }
446}
447
468template <bool WrtLog = false, typename T_y, typename T_a, typename T_w,
469 typename T_v, typename T_sv, typename T_err>
470inline auto wiener5_grad_v(const T_y& y, const T_a& a, const T_v& v_value,
471 const T_w& w_value, const T_sv& sv,
472 T_err&& err = log(1e-12)) noexcept {
473 auto ans = (a * (1 - w_value) - v_value * y) / (1.0 + square(sv) * y);
474 if (WrtLog) {
475 return ans * wiener5_density<true>(y, a, v_value, w_value, sv, err);
476 } else {
477 return ans;
478 }
479}
480
501template <bool WrtLog = false, typename T_y, typename T_a, typename T_w,
502 typename T_v, typename T_sv, typename T_err>
503inline auto wiener5_grad_w(const T_y& y, const T_a& a, const T_v& v_value,
504 const T_w& w_value, const T_sv& sv,
505 T_err&& err = log(1e-12)) noexcept {
506 const auto two_log_a = 2 * log(a);
507 const auto log_y_asq = log(y) - two_log_a;
508 const auto error_term
509 = wiener5_compute_error_term(y, a, v_value, w_value, sv);
510 const auto w = 1.0 - w_value;
511 const auto v = -v_value;
512 const auto sv_sqr = square(sv);
513 const auto one_plus_svsqr_y = 1 + sv_sqr * y;
514 const auto density_part_one
515 = (-v * a + sv_sqr * square(a) * w) / one_plus_svsqr_y;
516 const auto error = (err - error_term);
517
518 const auto n_terms_small_t
519 = wiener5_n_terms_small_t<GradientCalc::OFF, GradientCalc::ON>(
520 y, a, w_value, error);
521 const auto n_terms_large_t
522 = wiener5_gradient_large_reaction_time_terms<GradientCalc::ON>(
523 y, a, w_value, error);
524 auto wiener_res = wiener5_log_sum_exp<GradientCalc::OFF, GradientCalc::ON>(
525 y, a, w_value, n_terms_small_t, n_terms_large_t);
526 auto&& result = wiener_res.first;
527 auto&& newsign = wiener_res.second;
528 const auto log_density = wiener5_density<GradientCalc::OFF>(
529 y, a, v_value, w_value, sv, err - log(fabs(density_part_one)));
530 if (2 * n_terms_small_t < n_terms_large_t) {
531 auto ans = -(density_part_one
532 - newsign
533 * exp(result - (log_density - error_term)
534 - 2.5 * log_y_asq - 0.5 * LOG_TWO - 0.5 * LOG_PI));
535 return WrtLog ? ans * exp(log_density) : ans;
536 } else {
537 auto ans
538 = -(density_part_one
539 + newsign * exp(result - (log_density - error_term) + 2 * LOG_PI));
540 return WrtLog ? ans * exp(log_density) : ans;
541 }
542}
543
564template <bool WrtLog = false, typename T_y, typename T_a, typename T_w,
565 typename T_v, typename T_sv, typename T_err>
566inline auto wiener5_grad_sv(const T_y& y, const T_a& a, const T_v& v_value,
567 const T_w& w_value, const T_sv& sv,
568 T_err&& err = log(1e-12)) noexcept {
569 const auto one_plus_svsqr_y = 1 + square(sv) * y;
570 const auto w = 1.0 - w_value;
571 const auto v = -v_value;
572 const auto t1 = -y / one_plus_svsqr_y;
573 const auto t2 = (square(a * w) + 2 * a * v * w * y + square(v * y))
574 / square(one_plus_svsqr_y);
575 const auto ans = sv * (t1 + t2);
576 return WrtLog ? ans * wiener5_density<true>(y, a, v_value, w_value, sv, err)
577 : ans;
578}
579
/**
 * Utility function for replacing a value with a specified error value.
 *
 * The target is taken by reference so that the new value is visible to the
 * caller. Taking it by value discarded the assignment, so
 * estimate_with_err_check re-evaluated with the unchanged tolerance.
 *
 * @tparam NestedIndex ignored for non-tuple targets; present so both
 *   overloads share one call syntax
 * @param[in,out] arg value to overwrite
 * @param[in] err new value
 */
template <size_t NestedIndex, typename Scalar1, typename Scalar2>
inline void assign_err(Scalar1& arg, Scalar2 err) {
  arg = err;
}

/**
 * Overload for a tuple target: overwrite element NestedIndex of the tuple.
 *
 * @tparam NestedIndex index of the tuple element to overwrite
 * @param[in,out] args_tuple tuple holding the value to overwrite
 * @param[in] err new value
 */
template <size_t NestedIndex, typename Scalar, typename... TArgs>
inline void assign_err(std::tuple<TArgs...>& args_tuple, Scalar err) {
  std::get<NestedIndex>(args_tuple) = err;
}
610
628template <size_t ErrIndex, size_t NestedIndex = 0,
629 GradientCalc GradW7 = GradientCalc::OFF, bool LogResult = true,
630 typename F, typename T_err, typename... ArgsTupleT>
631inline auto estimate_with_err_check(F&& functor, T_err&& err,
632 ArgsTupleT&&... args_tuple) {
633 auto result = functor(args_tuple...);
634 auto log_fabs_result = LogResult ? log(fabs(result)) : fabs(result);
635 if (log_fabs_result < err) {
636 log_fabs_result = is_inf(log_fabs_result) ? 0 : log_fabs_result;
637 auto err_args_tuple = std::make_tuple(args_tuple...);
638 const auto new_error
639 = GradW7 ? err + log_fabs_result + LOG_TWO : err + log_fabs_result;
640 assign_err<NestedIndex>(std::get<ErrIndex>(err_args_tuple), new_error);
641 result
642 = math::apply([](auto&& func, auto&&... args) { return func(args...); },
643 err_args_tuple, functor);
644 }
645 return result;
646}
647} // namespace internal
648
/**
 * Log density of the 5-parameter Wiener model, with partial derivatives for
 * every non-constant argument.
 *
 * Parameters: boundary separation a, nondecision time t0, a-priori bias w,
 * drift rate v, and inter-trial variability in drift rate sv.
 *
 * @param y random variable (response time); must exceed t0 elementwise
 * @param a boundary separation; positive and finite
 * @param t0 nondecision time; non-negative and finite
 * @param w a-priori bias; strictly between 0 and 1
 * @param v drift rate; finite
 * @param sv inter-trial variability in drift rate; non-negative and finite
 *   (sv == 0 gives the 4-parameter model)
 * @param precision_derivatives tolerance for the numerically evaluated
 *   derivatives (its log is used); default 1e-4
 * @return summed log density with partials attached through
 *   make_partials_propagator
 * @throw domain error (via throw_domain_error and the check_* helpers) on
 *   invalid arguments
 *
 * NOTE(review): error messages use the name "wiener5_lpdf", but this function
 * is wiener_lpdf. Confirm that is intended.
 */
671template <bool propto = false, typename T_y, typename T_a, typename T_t0,
672 typename T_w, typename T_v, typename T_sv>
673inline auto wiener_lpdf(const T_y& y, const T_a& a, const T_t0& t0,
674 const T_w& w, const T_v& v, const T_sv& sv,
675 const double& precision_derivatives = 1e-4) {
// NOTE(review): doxygen lines 676-678 are absent from this rendering. Given
// the later uses of T_partials_return and ret_t, and the block closed at
// line 680, they presumably declare those aliases and open an early-return
// condition. Confirm against the header.
679 return ret_t(0.0);
680 }
681 using T_y_ref = ref_type_if_t<!is_constant<T_y>::value, T_y>;
682 using T_a_ref = ref_type_if_t<!is_constant<T_a>::value, T_a>;
683 using T_t0_ref = ref_type_if_t<!is_constant<T_t0>::value, T_t0>;
684 using T_w_ref = ref_type_if_t<!is_constant<T_w>::value, T_w>;
685 using T_v_ref = ref_type_if_t<!is_constant<T_v>::value, T_v>;
686 using T_sv_ref = ref_type_if_t<!is_constant<T_sv>::value, T_sv>;
687
688 static constexpr const char* function_name = "wiener5_lpdf";
689
690 check_consistent_sizes(function_name, "Random variable", y,
691 "Boundary separation", a, "Drift rate", v,
692 "A-priori bias", w, "Nondecision time", t0,
693 "Inter-trial variability in drift rate", sv);
694
695 T_y_ref y_ref = y;
696 T_a_ref a_ref = a;
697 T_t0_ref t0_ref = t0;
698 T_w_ref w_ref = w;
699 T_v_ref v_ref = v;
700 T_sv_ref sv_ref = sv;
701
702 decltype(auto) y_val = to_ref(as_value_column_array_or_scalar(y_ref));
703 decltype(auto) a_val = to_ref(as_value_column_array_or_scalar(a_ref));
704 decltype(auto) v_val = to_ref(as_value_column_array_or_scalar(v_ref));
705 decltype(auto) w_val = to_ref(as_value_column_array_or_scalar(w_ref));
706 decltype(auto) t0_val = to_ref(as_value_column_array_or_scalar(t0_ref));
707 decltype(auto) sv_val = to_ref(as_value_column_array_or_scalar(sv_ref));
// Argument validation on the values; each check throws on failure.
708 check_positive_finite(function_name, "Random variable", y_val);
709 check_positive_finite(function_name, "Boundary separation", a_val);
710 check_finite(function_name, "Drift rate", v_val);
711 check_less(function_name, "A-priori bias", w_val, 1);
712 check_greater(function_name, "A-priori bias", w_val, 0);
713 check_nonnegative(function_name, "Nondecision time", t0_val);
714 check_finite(function_name, "Nondecision time", t0_val);
715 check_nonnegative(function_name, "Inter-trial variability in drift rate",
716 sv_val);
717 check_finite(function_name, "Inter-trial variability in drift rate", sv_val);
718
719 if (size_zero(y, a, t0, w, v, sv)) {
720 return ret_t(0.0);
721 }
722 const size_t N = max_size(y, a, t0, w, v, sv);
723 if (!N) {
724 return ret_t(0.0);
725 }
726
727 scalar_seq_view<T_y_ref> y_vec(y_ref);
728 scalar_seq_view<T_a_ref> a_vec(a_ref);
729 scalar_seq_view<T_t0_ref> t0_vec(t0_ref);
730 scalar_seq_view<T_w_ref> w_vec(w_ref);
731 scalar_seq_view<T_v_ref> v_vec(v_ref);
732 scalar_seq_view<T_sv_ref> sv_vec(sv_ref);
733 const size_t N_y_t0 = max_size(y, t0);
734
// The response time must exceed the nondecision time. Only y and t0 are
// involved, so this loops over max_size(y, t0).
735 for (size_t i = 0; i < N_y_t0; ++i) {
736 if (y_vec[i] <= t0_vec[i]) {
737 std::stringstream msg;
738 msg << ", but must be greater than nondecision time = " << t0_vec[i];
739 std::string msg_str(msg.str());
740 throw_domain_error(function_name, "Random variable", y_vec[i], " = ",
741 msg_str.c_str());
742 }
743 }
744
// Log-scale tolerances:
//   - 1e-6 for the density;
//   - precision_derivatives for the derivatives;
//   - 1e-12 as the absolute tolerance handed to the series helpers.
745 const auto log_error_density = log(1e-6);
746 const auto log_error_derivative = log(precision_derivatives);
747 const double log_error_absolute_val = log(1e-12);
748 const T_partials_return log_error_absolute = log_error_absolute_val;
749 T_partials_return log_density = 0.0;
750 auto ops_partials
751 = make_partials_propagator(y_ref, a_ref, t0_ref, w_ref, v_ref, sv_ref);
752
753 static constexpr double LOG_FOUR = LOG_TWO + LOG_TWO;
754
755 // calculate density and partials
756 for (size_t i = 0; i < N; i++) {
757 // Calculate 4-parameter model without inter-trial variabilities (if
758 // sv_vec[i] == 0) or 5-parameter model with inter-trial variability in
759 // drift rate (if sv_vec[i] != 0)
760
761 const auto y_value = y_vec.val(i);
762 const auto a_value = a_vec.val(i);
763 const auto t0_value = t0_vec.val(i);
764 const auto w_value = w_vec.val(i);
765 const auto v_value = v_vec.val(i);
766 const auto sv_value = sv_vec.val(i);
// NOTE(review): doxygen line 767 is not shown. GradientCalc is declared in
// namespace internal, so that line presumably brings it into scope here.
// Confirm.
//
// The fourth template argument below (GradientCalc::OFF) binds to
// estimate_with_err_check's `bool LogResult`. The refinement check
// therefore compares fabs(l_density), which is never negative, against the
// negative tolerance log(1e-6) - LOG_TWO, so the re-evaluation can never be
// triggered here.
// NOTE(review): confirm this is intended.
768 auto l_density = internal::estimate_with_err_check<5, 0, GradientCalc::OFF,
769 GradientCalc::OFF>(
770 [](auto&&... args) {
771 return internal::wiener5_density<GradientCalc::OFF>(args...);
772 },
773 log_error_density - LOG_TWO, y_value - t0_value, a_value, v_value,
774 w_value, sv_value, log_error_absolute);
775
776 log_density += l_density;
777
778 const auto new_est_err = l_density + log_error_derivative - LOG_FOUR;
779
780 // derivative w.r.t. y with precision check; the value is used as deriv_y
781 // for y (partials<0>) and as -deriv_y for t0 (partials<2>)
782 const auto deriv_y
783 = internal::estimate_with_err_check<5, 0, GradientCalc::OFF,
784 GradientCalc::ON>(
785 [](auto&&... args) {
786 return internal::wiener5_grad_t<GradientCalc::OFF>(args...);
787 },
788 new_est_err, y_value - t0_value, a_value, v_value, w_value,
789 sv_value, log_error_absolute);
790
791 // computation of derivatives and precision checks
// NOTE(review): doxygen lines 792, 795, 805, 808, 818 and 824 are not shown.
// Each opens a block closed a few lines later, presumably a guard that skips
// the partial for a constant argument. Confirm.
793 partials<0>(ops_partials)[i] = deriv_y;
794 }
796 partials<1>(ops_partials)[i]
797 = internal::estimate_with_err_check<5, 0, GradientCalc::OFF,
798 GradientCalc::ON>(
799 [](auto&&... args) {
800 return internal::wiener5_grad_a<GradientCalc::OFF>(args...);
801 },
802 new_est_err, y_value - t0_value, a_value, v_value, w_value,
803 sv_value, log_error_absolute);
804 }
// The density depends on y - t0, so d/dt0 = -d/dy.
806 partials<2>(ops_partials)[i] = -deriv_y;
807 }
809 partials<3>(ops_partials)[i]
810 = internal::estimate_with_err_check<5, 0, GradientCalc::OFF,
811 GradientCalc::ON>(
812 [](auto&&... args) {
813 return internal::wiener5_grad_w<GradientCalc::OFF>(args...);
814 },
815 new_est_err, y_value - t0_value, a_value, v_value, w_value,
816 sv_value, log_error_absolute);
817 }
// wiener5_grad_v and wiener5_grad_sv are closed-form expressions (no series),
// so they are evaluated directly without estimate_with_err_check.
819 partials<4>(ops_partials)[i]
820 = internal::wiener5_grad_v<GradientCalc::OFF>(
821 y_value - t0_value, a_value, v_value, w_value, sv_value,
822 log_error_absolute_val);
823 }
825 partials<5>(ops_partials)[i]
826 = internal::wiener5_grad_sv<GradientCalc::OFF>(
827 y_value - t0_value, a_value, v_value, w_value, sv_value,
828 log_error_absolute_val);
829 }
830 } // end for loop
831 return ops_partials.build(log_density);
832} // end wiener_lpdf
833
834// TODO: delete the old wiener_lpdf implementation so that this one is used
835// template <bool propto = false, typename T_y, typename T_a, typename T_t0,
836// typename T_w, typename T_v>
837// inline auto wiener_lpdf(const T_y& y, const T_a& a, const T_t0& t0,
838// const T_w& w, const T_v& v,
839// const double& precision_derivatives = 1e-4) {
840// return wiener_lpdf(y, a, t0, w, v, 0, precision_derivatives);
841//} // end wiener_lpdf
842
843} // namespace math
844} // namespace stan
845#endif
scalar_seq_view provides a uniform sequence-like wrapper around either a scalar or a sequence of scal...
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
auto wiener5_grad_v(const T_y &y, const T_a &a, const T_v &v_value, const T_w &w_value, const T_sv &sv, T_err &&err=log(1e-12)) noexcept
Calculate the derivative of the wiener5 density w.r.t. 'v'.
auto wiener5_density_large_reaction_time_terms(T_y &&y, T_a &&a, T_w_value &&w_value, T_err &&error) noexcept
Calculate the 'n_terms_large_t' term for a wiener5 density.
auto wiener5_grad_a(const T_y &y, const T_a &a, const T_v &v_value, const T_w &w_value, const T_sv &sv, T_err &&err=log(1e-12)) noexcept
Calculate the derivative of the wiener5 density w.r.t. 'a'.
auto wiener5_grad_sv(const T_y &y, const T_a &a, const T_v &v_value, const T_w &w_value, const T_sv &sv, T_err &&err=log(1e-12)) noexcept
Calculate the derivative of the wiener5 density w.r.t. 'sv'.
void assign_err(Scalar1 arg, Scalar2 err)
Utility function for replacing a value with a specified error value.
auto wiener5_n_terms_small_t(T_y &&y, T_a &&a, T_w_value &&w_value, T_err &&error) noexcept
Calculate the 'n_terms_small_t' term for a wiener5 density or gradient.
auto wiener5_gradient_large_reaction_time_terms(T_y &&y, T_a &&a, T_w_value &&w_value, T_err &&error) noexcept
Calculate the 'n_terms_large_t' term for a wiener5 gradient.
auto wiener5_grad_t(const T_y &y, const T_a &a, const T_v &v_value, const T_w &w_value, const T_sv &sv, T_err &&err=log(1e-12)) noexcept
Calculate the derivative of the wiener5 density w.r.t. 't' (the time argument y).
auto wiener5_density(const T_y &y, const T_a &a, const T_v &v_value, const T_w &w_value, const T_sv &sv, T_err &&err=log(1e-12)) noexcept
Calculate the wiener5 density.
auto wiener5_grad_w(const T_y &y, const T_a &a, const T_v &v_value, const T_w &w_value, const T_sv &sv, T_err &&err=log(1e-12)) noexcept
Calculate the derivative of the wiener5 density w.r.t. 'w'.
auto estimate_with_err_check(F &&functor, T_err &&err, ArgsTupleT &&... args_tuple)
Utility function for estimating a function with a given set of arguments, checking the result against...
auto wiener5_compute_error_term(T_y &&y, T_a &&a, T_v &&v_value, T_w &&w_value, T_sv &&sv) noexcept
Calculate the 'error_term' term for a wiener5 density or gradient.
auto wiener5_log_sum_exp(T_y &&y, T_a &&a, T_w &&w_value, T_nsmall &&n_terms_small_t, T_nlarge &&n_terms_large_t) noexcept
Calculate the 'result' term and its sign for a wiener5 density or gradient.
fvar< T > sin(const fvar< T > &x)
Definition sin.hpp:14
size_t max_size(const T1 &x1, const Ts &... xs)
Calculate the size of the largest input.
Definition max_size.hpp:19
void check_nonnegative(const char *function, const char *name, const T_y &y)
Check if y is non-negative.
bool size_zero(const T &x)
Returns 1 if input is of length 0, returns 0 otherwise.
Definition size_zero.hpp:19
fvar< T > fmin(const fvar< T > &x1, const fvar< T > &x2)
Definition fmin.hpp:14
static constexpr double e()
Return the base of the natural logarithm.
Definition constants.hpp:20
fvar< T > arg(const std::complex< fvar< T > > &z)
Return the phase angle of the complex argument.
Definition arg.hpp:19
T eval(T &&arg)
Inputs which have a plain_type equal to their own type are forwarded unmodified (for Eigen expressions ...
Definition eval.hpp:20
fvar< T > log(const fvar< T > &x)
Definition log.hpp:15
static constexpr double NEGATIVE_INFTY
Negative infinity.
Definition constants.hpp:51
void throw_domain_error(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw a domain error with a consistently formatted message.
static constexpr double LOG_TWO
The natural logarithm of 2, $\log 2$.
Definition constants.hpp:80
fvar< T > log1p_exp(const fvar< T > &x)
Definition log1p_exp.hpp:13
auto as_value_column_array_or_scalar(T &&a)
Extract the value from an object and for eigen vectors and std::vectors convert to an eigen column ar...
void check_consistent_sizes(const char *)
Trivial no input case, this function is a no-op.
fvar< T > sqrt(const fvar< T > &x)
Definition sqrt.hpp:17
static constexpr double LOG_SQRT_PI
The natural logarithm of the square root of $\pi$, $\log \sqrt{\pi}$.
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:17
static constexpr double LOG_PI
The natural logarithm of $\pi$, $\log \pi$.
Definition constants.hpp:86
void check_finite(const char *function, const char *name, const T_y &y)
Return true if all values in y are finite.
fvar< T > fmax(const fvar< T > &x1, const fvar< T > &x2)
Return the greater of the two specified arguments.
Definition fmax.hpp:23
fvar< T > log_diff_exp(const fvar< T > &x1, const fvar< T > &x2)
fvar< T > cos(const fvar< T > &x)
Definition cos.hpp:14
static constexpr double pi()
Return the value of pi.
Definition constants.hpp:36
auto wiener_lpdf(const T_y &y, const T_a &a, const T_t0 &t0, const T_w &w, const T_v &v, const T_sv &sv, const double &precision_derivatives=1e-4)
Log-density function for the 5-parameter Wiener density.
int is_inf(const fvar< T > &x)
Returns 1 if the input's value is infinite and 0 otherwise.
Definition is_inf.hpp:21
void check_less(const char *function, const char *name, const T_y &y, const T_high &high, Idxs... idxs)
Throw an exception if y is not strictly less than high.
fvar< T > ceil(const fvar< T > &x)
Definition ceil.hpp:12
void check_greater(const char *function, const char *name, const T_y &y, const T_low &low, Idxs... idxs)
Throw an exception if y is not strictly greater than low.
fvar< T > inv(const fvar< T > &x)
Definition inv.hpp:12
auto make_partials_propagator(Ops &&... ops)
Construct a partials_propagator.
constexpr decltype(auto) apply(F &&f, Tuple &&t, PreArgs &&... pre_args)
Definition apply.hpp:52
void check_positive_finite(const char *function, const char *name, const T_y &y)
Check if y is positive and finite.
fvar< T > fabs(const fvar< T > &x)
Definition fabs.hpp:15
fvar< T > square(const fvar< T > &x)
Definition square.hpp:12
fvar< T > log_sum_exp(const fvar< T > &x1, const fvar< T > &x2)
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:13
typename ref_type_if< Condition, T >::type ref_type_if_t
Definition ref_type.hpp:58
typename partials_return_type< Args... >::type partials_return_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
Extends std::true_type when instantiated with zero or more template parameters, all of which extend t...
Template metaprogram to calculate whether a summand needs to be included in a proportional (log) prob...