Automatic Differentiation
wiener_lpdf.hpp
Go to the documentation of this file.
1 // Original code from which Stan's code is derived:
2 // Copyright (c) 2013, Joachim Vandekerckhove.
3 // All rights reserved.
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted
7 // provided that the following conditions are met:
8 //
9 // * Redistributions of source code must retain the above copyright notice,
10 // * this list of conditions and the following disclaimer.
11 // * Redistributions in binary form must reproduce the above copyright notice,
12 // * this list of conditions and the following disclaimer in the
13 // * documentation and/or other materials provided with the distribution.
14 // * Neither the name of the University of California, Irvine nor the names
15 // * of its contributors may be used to endorse or promote products derived
16 // * from this software without specific prior written permission.
17 //
18 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21 // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
22 // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23 // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24 // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25 // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26 // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27 // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
28 // THE POSSIBILITY OF SUCH DAMAGE.
29 
30 #ifndef STAN_MATH_PRIM_PROB_WIENER_LPDF_HPP
31 #define STAN_MATH_PRIM_PROB_WIENER_LPDF_HPP
32 
33 #include <stan/math/prim/meta.hpp>
34 #include <stan/math/prim/err.hpp>
44 #include <algorithm>
45 #include <cmath>
46 #include <string>
47 
48 namespace stan {
49 namespace math {
50 
75 template <bool propto, typename T_y, typename T_alpha, typename T_tau,
76  typename T_beta, typename T_delta>
78  const T_y& y, const T_alpha& alpha, const T_tau& tau, const T_beta& beta,
79  const T_delta& delta) {
81  using T_y_ref = ref_type_t<T_y>;
82  using T_alpha_ref = ref_type_t<T_alpha>;
83  using T_tau_ref = ref_type_t<T_tau>;
84  using T_beta_ref = ref_type_t<T_beta>;
85  using T_delta_ref = ref_type_t<T_delta>;
86  using std::ceil;
87  using std::exp;
88  using std::floor;
89  using std::log;
90  using std::sin;
91  using std::sqrt;
92  static const char* function = "wiener_lpdf";
93  check_consistent_sizes(function, "Random variable", y, "Boundary separation",
94  alpha, "A-priori bias", beta, "Nondecision time", tau,
95  "Drift rate", delta);
96 
97  T_y_ref y_ref = y;
98  T_alpha_ref alpha_ref = alpha;
99  T_tau_ref tau_ref = tau;
100  T_beta_ref beta_ref = beta;
101  T_delta_ref delta_ref = delta;
102 
103  check_positive(function, "Random variable", value_of(y_ref));
104  check_positive_finite(function, "Boundary separation", value_of(alpha_ref));
105  check_positive_finite(function, "Nondecision time", value_of(tau_ref));
106  check_bounded(function, "A-priori bias", value_of(beta_ref), 0, 1);
107  check_finite(function, "Drift rate", value_of(delta_ref));
108 
109  if (size_zero(y, alpha, beta, tau, delta)) {
110  return 0;
111  }
112 
113  T_return_type lp(0.0);
114 
115  size_t N = max_size(y, alpha, beta, tau, delta);
116  if (!N) {
117  return 0.0;
118  }
119 
120  scalar_seq_view<T_y_ref> y_vec(y_ref);
121  scalar_seq_view<T_alpha_ref> alpha_vec(alpha_ref);
122  scalar_seq_view<T_beta_ref> beta_vec(beta_ref);
123  scalar_seq_view<T_tau_ref> tau_vec(tau_ref);
124  scalar_seq_view<T_delta_ref> delta_vec(delta_ref);
125  size_t N_y_tau = max_size(y, tau);
126 
127  for (size_t i = 0; i < N_y_tau; ++i) {
128  if (y_vec[i] <= tau_vec[i]) {
129  std::stringstream msg;
130  msg << ", but must be greater than nondecision time = " << tau_vec[i];
131  std::string msg_str(msg.str());
132  throw_domain_error(function, "Random variable", y_vec[i], " = ",
133  msg_str.c_str());
134  }
135  }
136 
138  return 0;
139  }
140 
141  static const double WIENER_ERR = 0.000001;
142  static const double PI_TIMES_WIENER_ERR = pi() * WIENER_ERR;
143  static const double LOG_PI_LOG_WIENER_ERR = LOG_PI + log(WIENER_ERR);
144  static const double TWO_TIMES_SQRT_TWO_PI_TIMES_WIENER_ERR
145  = 2.0 * SQRT_TWO_PI * WIENER_ERR;
146  static const double LOG_TWO_OVER_TWO_PLUS_LOG_SQRT_PI
147  = LOG_TWO / 2 + LOG_SQRT_PI;
148  static const double SQUARE_PI_OVER_TWO = square(pi()) * 0.5;
149  static const double TWO_TIMES_LOG_SQRT_PI = 2.0 * LOG_SQRT_PI;
150 
151  for (size_t i = 0; i < N; i++) {
152  typename scalar_type<T_beta>::type one_minus_beta = 1.0 - beta_vec[i];
153  typename scalar_type<T_alpha>::type alpha2 = square(alpha_vec[i]);
154  T_return_type x = (y_vec[i] - tau_vec[i]) / alpha2;
155  T_return_type kl, ks, tmp = 0;
156  T_return_type k, K;
157  T_return_type sqrt_x = sqrt(x);
158  T_return_type log_x = log(x);
159  T_return_type one_over_pi_times_sqrt_x = 1.0 / pi() * sqrt_x;
160 
161  // calculate number of terms needed for large t:
162  // if error threshold is set low enough
163  if (PI_TIMES_WIENER_ERR * x < 1) {
164  // compute bound
165  kl = sqrt(-2.0 * SQRT_PI * (LOG_PI_LOG_WIENER_ERR + log_x)) / sqrt_x;
166  // ensure boundary conditions met
167  kl = (kl > one_over_pi_times_sqrt_x) ? kl : one_over_pi_times_sqrt_x;
168  } else {
169  kl = one_over_pi_times_sqrt_x; // set to boundary condition
170  }
171  // calculate number of terms needed for small t:
172  // if error threshold is set low enough
173  T_return_type tmp_expr0 = TWO_TIMES_SQRT_TWO_PI_TIMES_WIENER_ERR * sqrt_x;
174  if (tmp_expr0 < 1) {
175  // compute bound
176  ks = 2.0 + sqrt_x * sqrt(-2 * log(tmp_expr0));
177  // ensure boundary conditions are met
178  T_return_type sqrt_x_plus_one = sqrt_x + 1.0;
179  ks = (ks > sqrt_x_plus_one) ? ks : sqrt_x_plus_one;
180  } else { // if error threshold was set too high
181  ks = 2.0; // minimal kappa for that case
182  }
183  if (ks < kl) { // small t
184  K = ceil(ks); // round to smallest integer meeting error
185  T_return_type tmp_expr1 = (K - 1.0) / 2.0;
186  T_return_type tmp_expr2 = ceil(tmp_expr1);
187  for (k = -floor(tmp_expr1); k <= tmp_expr2; k++) {
188  tmp += (one_minus_beta + 2.0 * k)
189  * exp(-(square(one_minus_beta + 2.0 * k)) * 0.5 / x);
190  }
191  tmp = log(tmp) - LOG_TWO_OVER_TWO_PLUS_LOG_SQRT_PI - 1.5 * log_x;
192  } else { // if large t is better...
193  K = ceil(kl); // round to smallest integer meeting error
194  for (k = 1; k <= K; ++k) {
195  tmp += k * exp(-(square(k)) * (SQUARE_PI_OVER_TWO * x))
196  * sin(k * pi() * one_minus_beta);
197  }
198  tmp = log(tmp) + TWO_TIMES_LOG_SQRT_PI;
199  }
200 
201  // convert to f(t|v,a,w) and return result
202  lp += delta_vec[i] * alpha_vec[i] * one_minus_beta
203  - square(delta_vec[i]) * x * alpha2 / 2.0 - log(alpha2) + tmp;
204  }
205  return lp;
206 }
207 
208 template <typename T_y, typename T_alpha, typename T_tau, typename T_beta,
209  typename T_delta>
211  const T_y& y, const T_alpha& alpha, const T_tau& tau, const T_beta& beta,
212  const T_delta& delta) {
213  return wiener_lpdf<false>(y, alpha, tau, beta, delta);
214 }
215 
216 } // namespace math
217 } // namespace stan
218 #endif
stan::return_type_t
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
Definition: return_type.hpp:206
stan::math::ceil
fvar< T > ceil(const fvar< T > &x)
Definition: ceil.hpp:12
max_size.hpp
stan::scalar_type::type
std::decay_t< T > type
Definition: scalar_type.hpp:21
meta.hpp
ceil.hpp
stan::math::LOG_TWO
static constexpr double LOG_TWO
The natural logarithm of 2, \( \log 2 \).
Definition: constants.hpp:86
stan::math::beta
fvar< T > beta(const fvar< T > &x1, const fvar< T > &x2)
Return fvar with the beta function applied to the specified arguments and its gradient.
Definition: beta.hpp:51
err.hpp
stan::math::log
fvar< T > log(const fvar< T > &x)
Definition: log.hpp:15
stan::math::wiener_lpdf
return_type_t< T_y, T_alpha, T_tau, T_beta, T_delta > wiener_lpdf(const T_y &y, const T_alpha &alpha, const T_tau &tau, const T_beta &beta, const T_delta &delta)
Definition: wiener_lpdf.hpp:77
stan::math::sin
fvar< T > sin(const fvar< T > &x)
Definition: sin.hpp:14
stan::scalar_seq_view
scalar_seq_view provides a uniform sequence-like wrapper around either a scalar or a sequence of scalars.
Definition: scalar_seq_view.hpp:18
stan::math::include_summand
Definition: include_summand.hpp:37
stan::math::check_bounded
void check_bounded(const char *function, const char *name, const T_y &y, const T_low &low, const T_high &high)
Check if the value is between the low and high values, inclusively.
Definition: check_bounded.hpp:75
stan::math::LOG_PI
const double LOG_PI
The natural logarithm of \( \pi \), \( \log \pi \).
Definition: constants.hpp:80
square.hpp
stan::math::square
fvar< T > square(const fvar< T > &x)
Definition: square.hpp:12
stan::math::check_positive_finite
void check_positive_finite(const char *function, const char *name, const T_y &y)
Check if y is positive and finite.
Definition: check_positive_finite.hpp:22
stan::math::max_size
size_t max_size(const T1 &x1, const Ts &... xs)
Calculate the size of the largest input.
Definition: max_size.hpp:19
stan::math::throw_domain_error
void throw_domain_error(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw a domain error with a consistently formatted message.
Definition: throw_domain_error.hpp:26
stan::ref_type_t
typename ref_type_if< true, T >::type ref_type_t
Definition: ref_type.hpp:54
stan::math::exp
fvar< T > exp(const fvar< T > &x)
Definition: exp.hpp:13
constants.hpp
stan::math::SQRT_PI
static constexpr double SQRT_PI
The value of the square root of \( \pi \), \( \sqrt{\pi} \).
Definition: constants.hpp:128
stan::math::pi
static constexpr double pi()
Return the value of pi.
Definition: constants.hpp:36
stan::math::floor
fvar< T > floor(const fvar< T > &x)
Definition: floor.hpp:12
stan::math::SQRT_TWO_PI
static constexpr double SQRT_TWO_PI
The value of the square root of \( 2\pi \), \( \sqrt{2\pi} \).
Definition: constants.hpp:135
exp.hpp
stan::math::size_zero
bool size_zero(const T &x)
Returns true if the input is of length 0, false otherwise.
Definition: size_zero.hpp:19
stan::math::check_consistent_sizes
void check_consistent_sizes(const char *)
Trivial no-input case; this function is a no-op.
Definition: check_consistent_sizes.hpp:15
stan
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition: fvar.hpp:9
size_zero.hpp
stan::math::check_positive
void check_positive(const char *function, const char *name, const T_y &y)
Check if y is positive.
Definition: check_positive.hpp:27
value_of.hpp
stan::math::check_finite
void check_finite(const char *function, const char *name, const T_y &y)
Check if all values in y are finite.
Definition: check_finite.hpp:28
log.hpp
scalar_seq_view.hpp
stan::math::sqrt
fvar< T > sqrt(const fvar< T > &x)
Definition: sqrt.hpp:17
stan::math::value_of
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition: value_of.hpp:18
stan::math::LOG_SQRT_PI
const double LOG_SQRT_PI
The natural logarithm of the square root of \( \pi \), \( \log \sqrt{\pi} \).
Definition: constants.hpp:110