Automatic Differentiation
 
Loading...
Searching...
No Matches
wiener_full_lpdf.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_PROB_WIENER_FULL_LPDF_HPP
2#define STAN_MATH_PRIM_PROB_WIENER_FULL_LPDF_HPP
3
7
8namespace stan {
9namespace math {
10namespace internal {
11
32template <typename T_y, typename T_a, typename T_v, typename T_w, typename T_sv,
33 typename T_sw, typename T_err>
34inline auto wiener7_grad_sw(const T_y& y, const T_a& a, const T_v& v,
35 const T_w& w, const T_sv& sv, const T_sw& sw,
36 T_err log_error) {
37 auto low = w - sw / 2.0;
38 const auto lower_value
39 = wiener5_density<GradientCalc::ON>(y, a, v, low, sv, log_error);
40 auto high = w + sw / 2.0;
41 const auto upper_value
42 = wiener5_density<GradientCalc::ON>(y, a, v, high, sv, log_error);
43 return 0.5 * (lower_value + upper_value) / sw;
44}
45
74template <GradientCalc GradSW, typename F, typename T_y, typename T_a,
75 typename T_v, typename T_w, typename T_sv, typename T_sw,
76 typename T_err, std::enable_if_t<!GradSW>* = nullptr>
77inline auto conditionally_grad_sw(F&& functor, T_y&& y_diff, T_a&& a, T_v&& v,
78 T_w&& w, T_sv&& sv, T_sw&& sw,
79 T_err&& log_error) {
80 return functor(y_diff, a, v, w, sv, log_error);
81}
82
// Call a functor that additionally takes `sw`; this is the overload selected
// when `GradSW` is ON. It is invoked as
// functor(y_diff, a, v, w, sv, sw, log_error), e.g. with
// internal::wiener7_grad_sw. Returns whatever `functor` returns.
111template <GradientCalc GradSW, typename F, typename T_y, typename T_a,
112 typename T_v, typename T_w, typename T_sv, typename T_sw,
113 typename T_err, std::enable_if_t<GradSW>* = nullptr>
114inline auto conditionally_grad_sw(F&& functor, T_y&& y_diff, T_a&& a, T_v&& v,
115 T_w&& w, T_sv&& sv, T_sw&& sw,
116 T_err&& log_error) {
117 return functor(y_diff, a, v, w, sv, sw, log_error);
118}
119
136template <GradientCalc GradSW, GradientCalc GradW7 = GradientCalc::OFF,
137 typename Wiener7FunctorT, typename T_err, typename... TArgs>
138inline auto wiener7_integrate(const Wiener7FunctorT& wiener7_functor,
139 T_err&& hcubature_err, TArgs&&... args) {
140 const auto functor = [&wiener7_functor](auto&&... integration_args) {
141 return hcubature(
142 [&wiener7_functor](auto&& x, auto&& y, auto&& a, auto&& v, auto&& w,
143 auto&& t0, auto&& sv, auto&& sw, auto&& st0,
144 auto&& log_error) {
145 using ret_t = return_type_t<decltype(x), decltype(a), decltype(v),
146 decltype(w), decltype(t0), decltype(sv),
147 decltype(sw), decltype(st0),
148 decltype(st0), decltype(log_error)>;
149 scalar_seq_view<decltype(x)> x_vec(x);
150 const auto sw_val = GradSW ? 0 : sw;
151 const auto new_t0 = t0 + st0 * x_vec[(sw_val != 0) ? 1 : 0];
152 if (y - new_t0 <= 0) {
153 return ret_t(0.0);
154 } else {
155 const auto new_w = w + sw_val * (x_vec[0] - 0.5);
156 return ret_t(conditionally_grad_sw<GradSW>(
157 wiener7_functor, y - new_t0, a, v, new_w, sv, sw, log_error));
158 }
159 },
160 integration_args...);
161 };
162 return estimate_with_err_check<0, 8, GradW7, GradientCalc::ON>(
163 functor, hcubature_err, args...);
164}
165} // namespace internal
166
/**
 * Log density of the Wiener diffusion model with inter-trial variability in
 * the drift rate (sv), the a-priori bias (sw) and the non-decision time
 * (st0).
 *
 * Observations with sw == 0 and st0 == 0 are delegated to the 5-parameter
 * wiener_lpdf overload. All other observations are evaluated by numerically
 * integrating wiener5 quantities over sw and/or st0 with
 * internal::wiener7_integrate. The partials are computed the same way.
 *
 * @tparam propto if the summand is not needed for proportionality
 *   (include_summand is false), 0 is returned immediately
 * @param y response time; positive, finite and greater than t0
 * @param a boundary separation; positive and finite
 * @param t0 non-decision time; non-negative and finite
 * @param w a-priori bias; strictly inside (0, 1), with w - sw/2 > 0 and
 *   w + sw/2 < 1
 * @param v drift rate; finite
 * @param sv inter-trial variability in drift rate; non-negative and finite
 * @param sw inter-trial variability in a-priori bias; in [0, 1]
 * @param st0 inter-trial variability in non-decision time; non-negative and
 *   finite
 * @param precision_derivatives precision requested for the derivative
 *   integrals (sets the hcubature relative error)
 * @return the sum of the log densities over all observations
 * @throw via throw_domain_error / the check_* helpers when sizes are
 *   inconsistent or any constraint above is violated
 *
 * NOTE(review): this listing is missing several source lines. Examples are
 * the `ret_t` alias, the continuation of `T_partials_return`, and the
 * opening `if` of most partial-derivative sections; this is evident from the
 * skipped line numbers and unmatched closing braces. The comments below
 * describe only the visible code.
 */
317template <bool propto = false, typename T_y, typename T_a, typename T_t0,
318 typename T_w, typename T_v, typename T_sv, typename T_sw,
319 typename T_st0>
320inline auto wiener_lpdf(const T_y& y, const T_a& a, const T_t0& t0,
321 const T_w& w, const T_v& v, const T_sv& sv,
322 const T_sw& sw, const T_st0& st0,
323 const double& precision_derivatives = 1e-4) {
// Nothing to compute when propto drops every term.
325 if (!include_summand<propto, T_y, T_a, T_v, T_w, T_t0, T_sv, T_sw,
326 T_st0>::value) {
327 return ret_t(0);
328 }
329
330 using T_y_ref = ref_type_if_t<!is_constant<T_y>::value, T_y>;
331 using T_a_ref = ref_type_if_t<!is_constant<T_a>::value, T_a>;
332 using T_v_ref = ref_type_if_t<!is_constant<T_v>::value, T_v>;
333 using T_w_ref = ref_type_if_t<!is_constant<T_w>::value, T_w>;
334 using T_t0_ref = ref_type_if_t<!is_constant<T_t0>::value, T_t0>;
335 using T_sv_ref = ref_type_if_t<!is_constant<T_sv>::value, T_sv>;
336 using T_sw_ref = ref_type_if_t<!is_constant<T_sw>::value, T_sw>;
337 using T_st0_ref = ref_type_if_t<!is_constant<T_st0>::value, T_st0>;
338
339 using T_partials_return
341
342 static constexpr const char* function_name = "wiener_lpdf";
343 check_consistent_sizes(function_name, "Random variable", y,
344 "Boundary separation", a, "Drift rate", v,
345 "A-priori bias", w, "Nondecision time", t0,
346 "Inter-trial variability in drift rate", sv,
347 "Inter-trial variability in A-priori bias", sw,
348 "Inter-trial variability in Nondecision time", st0);
349
350 T_y_ref y_ref = y;
351 T_a_ref a_ref = a;
352 T_v_ref v_ref = v;
353 T_w_ref w_ref = w;
354 T_t0_ref t0_ref = t0;
355 T_sv_ref sv_ref = sv;
356 T_sw_ref sw_ref = sw;
357 T_st0_ref st0_ref = st0;
358
// Value-only views of each argument, used for the range checks below.
359 decltype(auto) y_val = to_ref(as_value_column_array_or_scalar(y_ref));
360 decltype(auto) a_val = to_ref(as_value_column_array_or_scalar(a_ref));
361 decltype(auto) v_val = to_ref(as_value_column_array_or_scalar(v_ref));
362 decltype(auto) w_val = to_ref(as_value_column_array_or_scalar(w_ref));
363 decltype(auto) t0_val = to_ref(as_value_column_array_or_scalar(t0_ref));
364 decltype(auto) sv_val = to_ref(as_value_column_array_or_scalar(sv_ref));
365 decltype(auto) sw_val = to_ref(as_value_column_array_or_scalar(sw_ref));
366 decltype(auto) st0_val = to_ref(as_value_column_array_or_scalar(st0_ref));
367 check_positive_finite(function_name, "Random variable", y_val);
368 check_positive_finite(function_name, "Boundary separation", a_val);
369 check_finite(function_name, "Drift rate", v_val);
370 check_less(function_name, "A-priori bias", w_val, 1);
371 check_greater(function_name, "A-priori bias", w_val, 0);
372 check_nonnegative(function_name, "Nondecision time", t0_val);
373 check_finite(function_name, "Nondecision time", t0_val);
374 check_nonnegative(function_name, "Inter-trial variability in drift rate",
375 sv_val);
376 check_finite(function_name, "Inter-trial variability in drift rate", sv_val);
377 check_bounded(function_name, "Inter-trial variability in A-priori bias",
378 sw_val, 0, 1);
379 check_nonnegative(function_name,
380 "Inter-trial variability in Nondecision time", st0_val);
381 check_finite(function_name, "Inter-trial variability in Nondecision time",
382 st0_val);
383
384 const size_t N = max_size(y, a, v, w, t0, sv, sw, st0);
385 if (N == 0) {
386 return ret_t(0);
387 }
388 scalar_seq_view<T_y_ref> y_vec(y_ref);
389 scalar_seq_view<T_a_ref> a_vec(a_ref);
390 scalar_seq_view<T_v_ref> v_vec(v_ref);
391 scalar_seq_view<T_w_ref> w_vec(w_ref);
392 scalar_seq_view<T_t0_ref> t0_vec(t0_ref);
393 scalar_seq_view<T_sv_ref> sv_vec(sv_ref);
394 scalar_seq_view<T_sw_ref> sw_vec(sw_ref);
395 scalar_seq_view<T_st0_ref> st0_vec(st0_ref);
396 const size_t N_y_t0 = max_size(y, t0, st0);
397
// Every response time must exceed its non-decision time.
398 for (size_t i = 0; i < N_y_t0; ++i) {
399 if (y_vec[i] <= t0_vec[i]) {
400 [&]() STAN_COLD_PATH {
401 std::stringstream msg;
402 msg << ", but must be greater than nondecision time = " << t0_vec[i];
403 std::string msg_str(msg.str());
404 throw_domain_error(function_name, "Random variable", y_vec[i], " = ",
405 msg_str.c_str());
406 }();
407 }
408 }
// The bias interval [w - sw/2, w + sw/2] must lie strictly inside (0, 1).
409 size_t N_beta_sw = max_size(w, sw);
410 for (size_t i = 0; i < N_beta_sw; ++i) {
411 if (unlikely(w_vec[i] - .5 * sw_vec[i] <= 0)) {
412 [&]() STAN_COLD_PATH {
413 std::stringstream msg;
414 msg << ", but must be smaller than 2*(A-priori bias) = "
415 << 2 * w_vec[i];
416 std::string msg_str(msg.str());
417 throw_domain_error(function_name,
418 "Inter-trial variability in A-priori bias",
419 sw_vec[i], " = ", msg_str.c_str());
420 }();
421 }
422 if (unlikely(w_vec[i] + .5 * sw_vec[i] >= 1)) {
423 [&]() STAN_COLD_PATH {
424 std::stringstream msg;
425 msg << ", but must be smaller than 2*(1-A-priori bias) = "
426 << 2 * (1 - w_vec[i]);
427 std::string msg_str(msg.str());
428 throw_domain_error(function_name,
429 "Inter-trial variability in A-priori bias",
430 sw_vec[i], " = ", msg_str.c_str());
431 }();
432 }
433 }
434 // precision for density
435 const T_partials_return log_error_density = log(1e-6);
436 // precision for derivatives (controllable by user)
437 const auto error_bound = precision_derivatives;
438 const auto log_error_derivative = log(error_bound);
439 const T_partials_return absolute_error_hcubature = 0.0;
440 // eps_rel(Integration)
441 const T_partials_return relative_error_hcubature = .9 * error_bound;
442 const T_partials_return log_error_absolute = log(1e-12);
443 const int maximal_evaluations_hcubature = 6000;
444 T_partials_return log_density = 0.0;
// Partials order: 0 y, 1 a, 2 t0, 3 w, 4 v, 5 sv, 6 sw, 7 st0.
445 auto ops_partials = make_partials_propagator(y_ref, a_ref, t0_ref, w_ref,
446 v_ref, sv_ref, sw_ref, st0_ref);
447 ret_t result = 0;
448
449 // calculate density and partials
450 for (size_t i = 0; i < N; i++) {
// Without sw/st0 variability no integration is needed: add the 5-parameter
// log density (with its own gradients) directly.
451 if (sw_vec[i] == 0 && st0_vec[i] == 0) {
452 result += wiener_lpdf<propto>(y_vec[i], a_vec[i], t0_vec[i], w_vec[i],
453 v_vec[i], sv_vec[i], precision_derivatives);
454 continue;
455 }
456 const T_partials_return y_value = y_vec.val(i);
457 const T_partials_return a_value = a_vec.val(i);
458 const T_partials_return v_value = v_vec.val(i);
459 const T_partials_return w_value = w_vec.val(i);
460 const T_partials_return t0_value = t0_vec.val(i);
461 const T_partials_return sv_value = sv_vec.val(i);
462 const T_partials_return sw_value = sw_vec.val(i);
463 const T_partials_return st0_value = st0_vec.val(i);
// One integration dimension for each nonzero variability among sw and st0.
464 const int dim = (sw_value != 0) + (st0_value != 0);
465 check_positive(function_name,
466 "(Inter-trial variability in A-priori bias) + "
467 "(Inter-trial variability in nondecision time)",
468 dim);
469
470 Eigen::Matrix<T_partials_return, -1, 1> xmin = Eigen::VectorXd::Zero(dim);
471 Eigen::Matrix<T_partials_return, -1, 1> xmax = Eigen::VectorXd::Ones(dim);
// Cap the last (st0) coordinate so that t0 + st0 * x never exceeds y.
472 if (st0_value != 0) {
473 xmax[dim - 1] = fmin(1.0, (y_value - t0_value) / st0_value);
474 }
475
476 T_partials_return hcubature_err
477 = log_error_absolute - log_error_density + LOG_TWO + 1;
479 const auto params = std::make_tuple(y_value, a_value, v_value, w_value,
480 t0_value, sv_value, sw_value, st0_value,
481 log_error_absolute - LOG_TWO);
482 T_partials_return density
483 = internal::wiener7_integrate<GradientCalc::OFF, GradientCalc::OFF>(
484 [](auto&&... args) {
485 return internal::wiener5_density<GradientCalc::ON>(args...);
486 },
487 hcubature_err, params, dim, xmin, xmax,
488 maximal_evaluations_hcubature, absolute_error_hcubature,
489 relative_error_hcubature / 2);
490 log_density += log(density);
// The error tolerance for the derivative integrals grows with log|density|.
491 hcubature_err = log_error_absolute - log_error_derivative
492 + log(fabs(density)) + LOG_TWO + 1;
493
494 // computation of derivative for t and precision check in order to give
495 // the value as deriv_t to edge1 and as -deriv_t to edge5
496 const T_partials_return deriv_t_7
497 = internal::wiener7_integrate<GradientCalc::OFF, GradientCalc::OFF>(
498 [](auto&&... args) {
499 return internal::wiener5_grad_t<GradientCalc::ON>(args...);
500 },
501 hcubature_err, params, dim, xmin, xmax,
502 maximal_evaluations_hcubature, absolute_error_hcubature,
503 relative_error_hcubature / 2)
504 / density;
505
506 // computation of derivatives and precision checks
// Each partial below is (integrated wiener5 gradient) / density; the y
// partial is deriv_t_7 and the t0 partial is -deriv_t_7.
507 T_partials_return derivative;
509 partials<0>(ops_partials)[i] = deriv_t_7;
510 }
512 partials<1>(ops_partials)[i]
513 = internal::wiener7_integrate<GradientCalc::OFF, GradientCalc::OFF>(
514 [](auto&&... args) {
515 return internal::wiener5_grad_a<GradientCalc::ON>(args...);
516 },
517 hcubature_err, params, dim, xmin, xmax,
518 maximal_evaluations_hcubature, absolute_error_hcubature,
519 relative_error_hcubature / 2)
520 / density;
521 }
523 partials<2>(ops_partials)[i] = -deriv_t_7;
524 }
526 partials<3>(ops_partials)[i]
527 = internal::wiener7_integrate<GradientCalc::OFF, GradientCalc::ON>(
528 [](auto&&... args) {
529 return internal::wiener5_grad_w<GradientCalc::ON>(args...);
530 },
531 hcubature_err, params, dim, xmin, xmax,
532 maximal_evaluations_hcubature, absolute_error_hcubature,
533 relative_error_hcubature / 2)
534 / density;
535 }
537 partials<4>(ops_partials)[i]
538 = internal::wiener7_integrate<GradientCalc::OFF, GradientCalc::OFF>(
539 [](auto&&... args) {
540 return internal::wiener5_grad_v<GradientCalc::ON>(args...);
541 },
542 hcubature_err, params, dim, xmin, xmax,
543 maximal_evaluations_hcubature, absolute_error_hcubature,
544 relative_error_hcubature / 2)
545 / density;
546 }
548 partials<5>(ops_partials)[i]
549 = internal::wiener7_integrate<GradientCalc::OFF, GradientCalc::OFF>(
550 [](auto&&... args) {
551 return internal::wiener5_grad_sv<GradientCalc::ON>(args...);
552 },
553 hcubature_err, params, dim, xmin, xmax,
554 maximal_evaluations_hcubature, absolute_error_hcubature,
555 relative_error_hcubature / 2)
556 / density;
557 }
// sw partial: derivative / density - 1 / sw. Here `derivative` is
// internal::wiener7_grad_sw, evaluated directly when st0 == 0 and otherwise
// integrated over st0 only (GradSW ON, dim 1).
559 if (sw_value == 0) {
560 partials<6>(ops_partials)[i] = 0;
561 } else {
562 if (st0_value == 0) {
564 6, 0, GradientCalc::OFF, GradientCalc::ON>(
565 [](auto&&... args) { return internal::wiener7_grad_sw(args...); },
566 hcubature_err, y_value - t0_value, a_value, v_value, w_value,
567 sv_value, sw_value, log_error_absolute - LOG_TWO);
568 } else {
569 derivative = internal::wiener7_integrate<GradientCalc::ON,
570 GradientCalc::OFF>(
571 [](auto&&... args) { return internal::wiener7_grad_sw(args...); },
572 hcubature_err, params, 1, xmin, xmax,
573 maximal_evaluations_hcubature, absolute_error_hcubature,
574 relative_error_hcubature / 2);
575 }
576 partials<6>(ops_partials)[i] = derivative / density - 1.0 / sw_value;
577 }
578 }
// st0 partial: -1/st0 + f / (st0 * density). Here f is the density at
// y - (t0 + st0), integrated over sw when sw != 0. When y <= t0 + st0 the
// f term is dropped.
580 T_partials_return f;
581 if (st0_value == 0) {
582 partials<7>(ops_partials)[i] = 0;
583 } else if (y_value - (t0_value + st0_value) <= 0) {
584 partials<7>(ops_partials)[i] = -1 / st0_value;
585 } else {
586 const T_partials_return t0_st0 = t0_value + st0_value;
587 if (sw_value == 0) {
588 f = internal::estimate_with_err_check<5, 0, GradientCalc::OFF,
589 GradientCalc::ON>(
590 [](auto&&... args) {
591 return internal::wiener5_density<GradientCalc::ON>(args...);
592 },
593 log_error_derivative + log(st0_value), y_value - t0_st0, a_value,
594 v_value, w_value, sv_value, log_error_absolute - LOG_TWO);
595 } else {
596 const T_partials_return new_error = log_error_absolute - LOG_TWO;
597 auto params_st
598 = std::make_tuple(y_value, a_value, v_value, w_value, t0_st0,
599 sv_value, sw_value, 0.0, new_error);
600 f = internal::wiener7_integrate<GradientCalc::OFF, GradientCalc::OFF>(
601 [](auto&&... args) {
602 return internal::wiener5_density<GradientCalc::ON>(args...);
603 },
604 hcubature_err, params_st, 1, xmin, xmax,
605 maximal_evaluations_hcubature, absolute_error_hcubature,
606 relative_error_hcubature / 2.0);
607 }
608 partials<7>(ops_partials)[i] = -1 / st0_value + f / st0_value / density;
609 }
610 }
611 }
// The 5-parameter contributions (in `result`) plus the integrated log
// densities with their partials.
612 return result + ops_partials.build(log_density);
613}
614} // namespace math
615} // namespace stan
616#endif
scalar_seq_view provides a uniform sequence-like wrapper around either a scalar or a sequence of scal...
#define STAN_COLD_PATH
#define unlikely(x)
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
auto wiener7_integrate(const Wiener7FunctorT &wiener7_functor, T_err &&hcubature_err, TArgs &&... args)
Implementation function for preparing arguments and functor to be passed to the hcubature() function ...
auto wiener7_grad_sw(const T_y &y, const T_a &a, const T_v &v, const T_w &w, const T_sv &sv, const T_sw &sw, T_err log_error)
Calculate the derivative of the wiener7 density w.r.t. 'sw'.
auto conditionally_grad_sw(F &&functor, T_y &&y_diff, T_a &&a, T_v &&v, T_w &&w, T_sv &&sv, T_sw &&sw, T_err &&log_error)
Helper function for agnostically calling wiener5 functions (to be integrated over) or directly callin...
auto estimate_with_err_check(F &&functor, T_err &&err, ArgsTupleT &&... args_tuple)
Utility function for estimating a function with a given set of arguments, checking the result against...
void derivative(const F &f, const T &x, T &fx, T &dfx_dx)
Return the derivative of the specified univariate function at the specified argument.
size_t max_size(const T1 &x1, const Ts &... xs)
Calculate the size of the largest input.
Definition max_size.hpp:19
void check_nonnegative(const char *function, const char *name, const T_y &y)
Check if y is non-negative.
fvar< T > fmin(const fvar< T > &x1, const fvar< T > &x2)
Definition fmin.hpp:14
void check_bounded(const char *function, const char *name, const T_y &y, const T_low &low, const T_high &high)
Check if the value is between the low and high values, inclusively.
static constexpr double e()
Return the base of the natural logarithm.
Definition constants.hpp:20
auto hcubature(const F &integrand, const ParsTuple &pars, const int dim, const Eigen::Matrix< T_a, Eigen::Dynamic, 1 > &a, const Eigen::Matrix< T_b, Eigen::Dynamic, 1 > &b, const int max_eval, const TAbsErr reqAbsError, const TRelErr reqRelError)
Compute the [dim]-dimensional integral of the function from $a$ to $b$ within the specified relative and absolute error tolerances (using at most max_eval evaluations).
fvar< T > log(const fvar< T > &x)
Definition log.hpp:15
void throw_domain_error(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw a domain error with a consistently formatted message.
static constexpr double LOG_TWO
The natural logarithm of 2, $\log(2)$.
Definition constants.hpp:80
auto as_value_column_array_or_scalar(T &&a)
Extract the value from an object and for eigen vectors and std::vectors convert to an eigen column ar...
void check_consistent_sizes(const char *)
Trivial no input case, this function is a no-op.
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:17
void check_finite(const char *function, const char *name, const T_y &y)
Check if all values in y are finite.
void check_positive(const char *function, const char *name, const T_y &y)
Check if y is positive.
auto wiener_lpdf(const T_y &y, const T_a &a, const T_t0 &t0, const T_w &w, const T_v &v, const T_sv &sv, const double &precision_derivatives=1e-4)
Log-density function for the 5-parameter Wiener density.
void check_less(const char *function, const char *name, const T_y &y, const T_high &high, Idxs... idxs)
Throw an exception if y is not strictly less than high.
void check_greater(const char *function, const char *name, const T_y &y, const T_low &low, Idxs... idxs)
Throw an exception if y is not strictly greater than low.
auto make_partials_propagator(Ops &&... ops)
Construct a partials_propagator.
void check_positive_finite(const char *function, const char *name, const T_y &y)
Check if y is positive and finite.
fvar< T > fabs(const fvar< T > &x)
Definition fabs.hpp:15
typename ref_type_if< Condition, T >::type ref_type_if_t
Definition ref_type.hpp:58
typename partials_return_type< Args... >::type partials_return_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
Extends std::true_type when instantiated with zero or more template parameters, all of which extend t...
Template metaprogram to calculate whether a summand needs to be included in a proportional (log) prob...