30#ifndef STAN_MATH_PRIM_PROB_WIENER_LPDF_HPP
31#define STAN_MATH_PRIM_PROB_WIENER_LPDF_HPP
75template <
bool propto,
typename T_y,
typename T_alpha,
typename T_tau,
76 typename T_beta,
typename T_delta>
78 const T_y& y,
const T_alpha& alpha,
const T_tau& tau,
const T_beta&
beta,
79 const T_delta& delta) {
92 static constexpr const char* function =
"wiener_lpdf";
94 alpha,
"A-priori bias",
beta,
"Nondecision time", tau,
98 T_alpha_ref alpha_ref = alpha;
99 T_tau_ref tau_ref = tau;
100 T_beta_ref beta_ref =
beta;
101 T_delta_ref delta_ref = delta;
113 T_return_type lp(0.0);
127 for (
size_t i = 0; i < N_y_tau; ++i) {
128 if (y_vec[i] <= tau_vec[i]) {
129 std::stringstream msg;
130 msg <<
", but must be greater than nondecision time = " << tau_vec[i];
131 std::string msg_str(msg.str());
141 static constexpr double WIENER_ERR = 0.000001;
142 static constexpr double PI_TIMES_WIENER_ERR =
pi() * WIENER_ERR;
143 static constexpr double LOG_PI_LOG_WIENER_ERR =
LOG_PI - 6 *
LOG_TEN;
144 static constexpr double TWO_TIMES_SQRT_TWO_PI_TIMES_WIENER_ERR
146 static constexpr double LOG_TWO_OVER_TWO_PLUS_LOG_SQRT_PI
149 static constexpr double SQUARE_PI_OVER_TWO = (
pi() *
pi()) / 2;
150 static constexpr double TWO_TIMES_LOG_SQRT_PI = 2.0 *
LOG_SQRT_PI;
152 for (
size_t i = 0; i < N; i++) {
155 T_return_type x = (y_vec[i] - tau_vec[i]) / alpha2;
156 T_return_type kl, ks, tmp = 0;
158 T_return_type sqrt_x =
sqrt(x);
159 T_return_type log_x =
log(x);
160 T_return_type one_over_pi_times_sqrt_x = 1.0 /
pi() * sqrt_x;
164 if (PI_TIMES_WIENER_ERR * x < 1) {
166 kl =
sqrt(-2.0 *
SQRT_PI * (LOG_PI_LOG_WIENER_ERR + log_x)) / sqrt_x;
168 kl = (kl > one_over_pi_times_sqrt_x) ? kl : one_over_pi_times_sqrt_x;
170 kl = one_over_pi_times_sqrt_x;
174 T_return_type tmp_expr0 = TWO_TIMES_SQRT_TWO_PI_TIMES_WIENER_ERR * sqrt_x;
177 ks = 2.0 + sqrt_x *
sqrt(-2 *
log(tmp_expr0));
179 T_return_type sqrt_x_plus_one = sqrt_x + 1.0;
180 ks = (ks > sqrt_x_plus_one) ? ks : sqrt_x_plus_one;
186 T_return_type tmp_expr1 = (K - 1.0) / 2.0;
187 T_return_type tmp_expr2 =
ceil(tmp_expr1);
188 for (k = -
floor(tmp_expr1); k <= tmp_expr2; k++) {
189 tmp += (one_minus_beta + 2.0 * k)
190 *
exp(-(
square(one_minus_beta + 2.0 * k)) * 0.5 / x);
192 tmp =
log(tmp) - LOG_TWO_OVER_TWO_PLUS_LOG_SQRT_PI - 1.5 * log_x;
195 for (k = 1; k <= K; ++k) {
196 tmp += k *
exp(-(
square(k)) * (SQUARE_PI_OVER_TWO * x))
197 *
sin(k *
pi() * one_minus_beta);
199 tmp =
log(tmp) + TWO_TIMES_LOG_SQRT_PI;
203 lp += delta_vec[i] * alpha_vec[i] * one_minus_beta
204 -
square(delta_vec[i]) * x * alpha2 / 2.0 -
log(alpha2) + tmp;
209template <
typename T_y,
typename T_alpha,
typename T_tau,
typename T_beta,
212 const T_y& y,
const T_alpha& alpha,
const T_tau& tau,
const T_beta&
beta,
213 const T_delta& delta) {
214 return wiener_lpdf<false>(y, alpha, tau,
beta, delta);
scalar_seq_view provides a uniform sequence-like wrapper around either a scalar or a sequence of scalars.
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
fvar< T > sin(const fvar< T > &x)
bool size_zero(const T &x)
Returns true if input is of length 0, false otherwise.
void check_bounded(const char *function, const char *name, const T_y &y, const T_low &low, const T_high &high)
Check if the value is between the low and high values, inclusively.
static constexpr double LOG_TEN
The natural logarithm of 10, $\log 10$.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
fvar< T > log(const fvar< T > &x)
void throw_domain_error(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw a domain error with a consistently formatted message.
static constexpr double SQRT_PI
The value of the square root of $\pi$, $\sqrt{\pi}$.
static constexpr double LOG_TWO
The natural logarithm of 2, $\log 2$.
void check_consistent_sizes(const char *)
Trivial no input case, this function is a no-op.
fvar< T > sqrt(const fvar< T > &x)
static constexpr double LOG_SQRT_PI
The natural logarithm of the square root of $\pi$, $\log \sqrt{\pi}$.
static constexpr double LOG_PI
The natural logarithm of $\pi$, $\log \pi$.
void check_finite(const char *function, const char *name, const T_y &y)
Check if all values in y are finite.
static constexpr double SQRT_TWO_PI
The value of the square root of $2\pi$, $\sqrt{2\pi}$.
fvar< T > floor(const fvar< T > &x)
void check_positive(const char *function, const char *name, const T_y &y)
Check if y is positive.
static constexpr double pi()
Return the value of pi.
auto wiener_lpdf(const T_y &y, const T_a &a, const T_t0 &t0, const T_w &w, const T_v &v, const T_sv &sv, const double &precision_derivatives=1e-4)
Log-density function for the 5-parameter Wiener density.
fvar< T > ceil(const fvar< T > &x)
int64_t max_size(const T1 &x1, const Ts &... xs)
Calculate the size of the largest input.
fvar< T > beta(const fvar< T > &x1, const fvar< T > &x2)
Return fvar with the beta function applied to the specified arguments and its gradient.
void check_positive_finite(const char *function, const char *name, const T_y &y)
Check if y is positive and finite.
fvar< T > square(const fvar< T > &x)
fvar< T > exp(const fvar< T > &x)
typename ref_type_if< true, T >::type ref_type_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Template metaprogram to calculate whether a summand needs to be included in a proportional (log) probability calculation.