Automatic Differentiation
 
Loading...
Searching...
No Matches
wiener_lpdf.hpp
Go to the documentation of this file.
1// Original code from which Stan's code is derived:
2// Copyright (c) 2013, Joachim Vandekerckhove.
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted
7// provided that the following conditions are met:
8//
9// * Redistributions of source code must retain the above copyright notice,
10// * this list of conditions and the following disclaimer.
11// * Redistributions in binary form must reproduce the above copyright notice,
12// * this list of conditions and the following disclaimer in the
13// * documentation and/or other materials provided with the distribution.
14// * Neither the name of the University of California, Irvine nor the names
15// * of its contributors may be used to endorse or promote products derived
16// * from this software without specific prior written permission.
17//
18// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
22// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
28// THE POSSIBILITY OF SUCH DAMAGE.
29
30#ifndef STAN_MATH_PRIM_PROB_WIENER_LPDF_HPP
31#define STAN_MATH_PRIM_PROB_WIENER_LPDF_HPP
32
46#include <algorithm>
47#include <cmath>
48#include <string>
49
50namespace stan {
51namespace math {
52
77template <bool propto, typename T_y, typename T_alpha, typename T_tau,
78 typename T_beta, typename T_delta>
80 const T_y& y, const T_alpha& alpha, const T_tau& tau, const T_beta& beta,
81 const T_delta& delta) {
83 using T_y_ref = ref_type_t<T_y>;
84 using T_alpha_ref = ref_type_t<T_alpha>;
85 using T_tau_ref = ref_type_t<T_tau>;
86 using T_beta_ref = ref_type_t<T_beta>;
87 using T_delta_ref = ref_type_t<T_delta>;
88 static constexpr const char* function = "wiener_lpdf";
89 check_consistent_sizes(function, "Random variable", y, "Boundary separation",
90 alpha, "A-priori bias", beta, "Nondecision time", tau,
91 "Drift rate", delta);
92
93 T_y_ref y_ref = y;
94 T_alpha_ref alpha_ref = alpha;
95 T_tau_ref tau_ref = tau;
96 T_beta_ref beta_ref = beta;
97 T_delta_ref delta_ref = delta;
98
99 check_positive(function, "Random variable", value_of(y_ref));
100 check_positive_finite(function, "Boundary separation", value_of(alpha_ref));
101 check_positive_finite(function, "Nondecision time", value_of(tau_ref));
102 check_bounded(function, "A-priori bias", value_of(beta_ref), 0, 1);
103 check_finite(function, "Drift rate", value_of(delta_ref));
104
105 if (size_zero(y, alpha, beta, tau, delta)) {
106 return 0;
107 }
108
109 T_return_type lp(0.0);
110
111 size_t N = max_size(y, alpha, beta, tau, delta);
112 if (!N) {
113 return 0.0;
114 }
115
116 scalar_seq_view<T_y_ref> y_vec(y_ref);
117 scalar_seq_view<T_alpha_ref> alpha_vec(alpha_ref);
118 scalar_seq_view<T_beta_ref> beta_vec(beta_ref);
119 scalar_seq_view<T_tau_ref> tau_vec(tau_ref);
120 scalar_seq_view<T_delta_ref> delta_vec(delta_ref);
121 size_t N_y_tau = max_size(y, tau);
122
123 for (size_t i = 0; i < N_y_tau; ++i) {
124 if (y_vec[i] <= tau_vec[i]) {
125 std::stringstream msg;
126 msg << ", but must be greater than nondecision time = " << tau_vec[i];
127 std::string msg_str(msg.str());
128 throw_domain_error(function, "Random variable", y_vec[i], " = ",
129 msg_str.c_str());
130 }
131 }
132
133 if constexpr (!include_summand<propto, T_y, T_alpha, T_tau, T_beta,
134 T_delta>::value) {
135 return 0;
136 }
137
138 static constexpr double WIENER_ERR = 0.000001;
139 static constexpr double PI_TIMES_WIENER_ERR = pi() * WIENER_ERR;
140 static constexpr double LOG_PI_LOG_WIENER_ERR = LOG_PI - 6 * LOG_TEN;
141 static constexpr double TWO_TIMES_SQRT_TWO_PI_TIMES_WIENER_ERR
142 = 2.0 * SQRT_TWO_PI * WIENER_ERR;
143 static constexpr double LOG_TWO_OVER_TWO_PLUS_LOG_SQRT_PI
144 = LOG_TWO / 2 + LOG_SQRT_PI;
145 // square(pi()) * 0.5
146 static constexpr double SQUARE_PI_OVER_TWO = (pi() * pi()) / 2;
147 static constexpr double TWO_TIMES_LOG_SQRT_PI = 2.0 * LOG_SQRT_PI;
148
149 for (size_t i = 0; i < N; i++) {
150 typename scalar_type<T_beta>::type one_minus_beta = 1.0 - beta_vec[i];
151 typename scalar_type<T_alpha>::type alpha2 = square(alpha_vec[i]);
152 T_return_type x = (y_vec[i] - tau_vec[i]) / alpha2;
153 T_return_type kl, ks, tmp = 0;
154 T_return_type k, K;
155 T_return_type sqrt_x = sqrt(x);
156 T_return_type log_x = log(x);
157 T_return_type one_over_pi_times_sqrt_x = 1.0 / pi() * sqrt_x;
158
159 // calculate number of terms needed for large t:
160 // if error threshold is set low enough
161 if (PI_TIMES_WIENER_ERR * x < 1) {
162 // compute bound
163 kl = sqrt(-2.0 * SQRT_PI * (LOG_PI_LOG_WIENER_ERR + log_x)) / sqrt_x;
164 // ensure boundary conditions met
165 kl = (kl > one_over_pi_times_sqrt_x) ? kl : one_over_pi_times_sqrt_x;
166 } else {
167 kl = one_over_pi_times_sqrt_x; // set to boundary condition
168 }
169 // calculate number of terms needed for small t:
170 // if error threshold is set low enough
171 T_return_type tmp_expr0 = TWO_TIMES_SQRT_TWO_PI_TIMES_WIENER_ERR * sqrt_x;
172 if (tmp_expr0 < 1) {
173 // compute bound
174 ks = 2.0 + sqrt_x * sqrt(-2 * log(tmp_expr0));
175 // ensure boundary conditions are met
176 T_return_type sqrt_x_plus_one = sqrt_x + 1.0;
177 ks = (ks > sqrt_x_plus_one) ? ks : sqrt_x_plus_one;
178 } else { // if error threshold was set too high
179 ks = 2.0; // minimal kappa for that case
180 }
181 if (ks < kl) { // small t
182 K = ceil(ks); // round to smallest integer meeting error
183 T_return_type tmp_expr1 = (K - 1.0) / 2.0;
184 T_return_type tmp_expr2 = ceil(tmp_expr1);
185 for (k = -floor(tmp_expr1); k <= tmp_expr2; k++) {
186 tmp += (one_minus_beta + 2.0 * k)
187 * exp(-(square(one_minus_beta + 2.0 * k)) * 0.5 / x);
188 }
189 tmp = log(tmp) - LOG_TWO_OVER_TWO_PLUS_LOG_SQRT_PI - 1.5 * log_x;
190 } else { // if large t is better...
191 K = ceil(kl); // round to smallest integer meeting error
192 for (k = 1; k <= K; ++k) {
193 tmp += k * exp(-(square(k)) * (SQUARE_PI_OVER_TWO * x))
194 * sin(k * pi() * one_minus_beta);
195 }
196 tmp = log(tmp) + TWO_TIMES_LOG_SQRT_PI;
197 }
198
199 // convert to f(t|v,a,w) and return result
200 lp += delta_vec[i] * alpha_vec[i] * one_minus_beta
201 - square(delta_vec[i]) * x * alpha2 / 2.0 - log(alpha2) + tmp;
202 }
203 return lp;
204}
205
206template <typename T_y, typename T_alpha, typename T_tau, typename T_beta,
207 typename T_delta>
209 const T_y& y, const T_alpha& alpha, const T_tau& tau, const T_beta& beta,
210 const T_delta& delta) {
211 return wiener_lpdf<false>(y, alpha, tau, beta, delta);
212}
213
214} // namespace math
215} // namespace stan
216#endif
scalar_seq_view provides a uniform sequence-like wrapper around either a scalar or a sequence of scal...
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
fvar< T > sin(const fvar< T > &x)
Definition sin.hpp:16
bool size_zero(const T &x)
Returns true if the input has length 0, false otherwise.
Definition size_zero.hpp:19
void check_bounded(const char *function, const char *name, const T_y &y, const T_low &low, const T_high &high)
Check if the value is between the low and high values, inclusively.
static constexpr double LOG_TEN
The natural logarithm of 10, $\log 10$.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
fvar< T > log(const fvar< T > &x)
Definition log.hpp:18
void throw_domain_error(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw a domain error with a consistently formatted message.
static constexpr double SQRT_PI
The value of the square root of $\pi$, $\sqrt{\pi}$.
static constexpr double LOG_TWO
The natural logarithm of 2, $\log 2$.
Definition constants.hpp:80
void check_consistent_sizes(const char *)
Trivial no input case, this function is a no-op.
fvar< T > sqrt(const fvar< T > &x)
Definition sqrt.hpp:18
static constexpr double LOG_SQRT_PI
The natural logarithm of the square root of $\pi$, $\log(\sqrt{\pi})$.
static constexpr double LOG_PI
The natural logarithm of $\pi$, $\log \pi$.
Definition constants.hpp:86
void check_finite(const char *function, const char *name, const T_y &y)
Check if all values in y are finite.
static constexpr double SQRT_TWO_PI
The value of the square root of $2\pi$, $\sqrt{2\pi}$.
fvar< T > floor(const fvar< T > &x)
Definition floor.hpp:13
void check_positive(const char *function, const char *name, const T_y &y)
Check if y is positive.
static constexpr double pi()
Return the value of pi.
Definition constants.hpp:36
auto wiener_lpdf(const T_y &y, const T_a &a, const T_t0 &t0, const T_w &w, const T_v &v, const T_sv &sv, const double &precision_derivatives=1e-4)
Log-density function for the 5-parameter Wiener density.
fvar< T > ceil(const fvar< T > &x)
Definition ceil.hpp:13
int64_t max_size(const T1 &x1, const Ts &... xs)
Calculate the size of the largest input.
Definition max_size.hpp:20
fvar< T > beta(const fvar< T > &x1, const fvar< T > &x2)
Return fvar with the beta function applied to the specified arguments and its gradient.
Definition beta.hpp:51
void check_positive_finite(const char *function, const char *name, const T_y &y)
Check if y is positive and finite.
fvar< T > square(const fvar< T > &x)
Definition square.hpp:12
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:15
typename ref_type_if< true, T >::type ref_type_t
Definition ref_type.hpp:56
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Template metaprogram to calculate whether a summand needs to be included in a proportional (log) prob...
std::decay_t< T > type