Automatic Differentiation
 
Loading...
Searching...
No Matches
log_mix.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_FWD_FUN_LOG_MIX_HPP
2#define STAN_MATH_FWD_FUN_LOG_MIX_HPP
3
8#include <cmath>
9#include <type_traits>
10
11namespace stan {
12namespace math {
13
14/* Returns an array of size N with partials of log_mix wrt to its
15 * parameters instantiated as fvar<T>
16 *
17 * @tparam T_theta theta scalar type
18 * @tparam T_lambda1 lambda_1 scalar type
19 * @tparam T_lambda2 lambda_2 scalar type
20 *
21 * @param[in] N output array size
22 * @param[in] theta_d mixing proportion theta
23 * @param[in] lambda1_d log_density with mixing proportion theta
24 * @param[in] lambda2_d log_density with mixing proportion 1.0 - theta
25 * @param[out] partials_array array of partials derivatives
26 */
27template <typename T_theta, typename T_lambda1, typename T_lambda2, int N>
29 const T_theta& theta, const T_lambda1& lambda1, const T_lambda2& lambda2,
31 using std::exp;
32 using partial_return_type = promote_args_t<T_theta, T_lambda1, T_lambda2>;
33 auto lam2_m_lam1 = lambda2 - lambda1;
34 auto exp_lam2_m_lam1 = exp(lam2_m_lam1);
35 auto one_m_exp_lam2_m_lam1 = 1.0 - exp_lam2_m_lam1;
36 auto one_m_t = 1.0 - theta;
37 auto one_m_t_prod_exp_lam2_m_lam1 = one_m_t * exp_lam2_m_lam1;
38 auto t_plus_one_m_t_prod_exp_lam2_m_lam1
39 = theta + one_m_t_prod_exp_lam2_m_lam1;
40 auto one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1
41 = 1.0 / t_plus_one_m_t_prod_exp_lam2_m_lam1;
42
43 unsigned int offset = 0;
44 if (std::is_same<T_theta, partial_return_type>::value) {
45 partials_array[offset]
46 = one_m_exp_lam2_m_lam1 * one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1;
47 ++offset;
48 }
49 if (std::is_same<T_lambda1, partial_return_type>::value) {
50 partials_array[offset] = theta * one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1;
51 ++offset;
52 }
53 if (std::is_same<T_lambda2, partial_return_type>::value) {
54 partials_array[offset] = one_m_t_prod_exp_lam2_m_lam1
55 * one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1;
56 }
57}
58
97template <typename T>
98inline fvar<T> log_mix(const fvar<T>& theta, const fvar<T>& lambda1,
99 const fvar<T>& lambda2) {
100 if (lambda1.val_ > lambda2.val_) {
101 fvar<T> partial_deriv_array[3];
102 log_mix_partial_helper(theta, lambda1, lambda2, partial_deriv_array);
103 return fvar<T>(log_mix(theta.val_, lambda1.val_, lambda2.val_),
104 theta.d_ * value_of(partial_deriv_array[0])
105 + lambda1.d_ * value_of(partial_deriv_array[1])
106 + lambda2.d_ * value_of(partial_deriv_array[2]));
107 } else {
108 fvar<T> partial_deriv_array[3];
109 log_mix_partial_helper(1.0 - theta, lambda2, lambda1, partial_deriv_array);
110 return fvar<T>(log_mix(theta.val_, lambda1.val_, lambda2.val_),
111 -theta.d_ * value_of(partial_deriv_array[0])
112 + lambda1.d_ * value_of(partial_deriv_array[2])
113 + lambda2.d_ * value_of(partial_deriv_array[1]));
114 }
115}
116
117template <typename T, typename P, require_all_arithmetic_t<P>* = nullptr>
118inline fvar<T> log_mix(const fvar<T>& theta, const fvar<T>& lambda1,
119 P lambda2) {
120 if (lambda1.val_ > lambda2) {
121 fvar<T> partial_deriv_array[2];
122 log_mix_partial_helper(theta, lambda1, lambda2, partial_deriv_array);
123 return fvar<T>(log_mix(theta.val_, lambda1.val_, lambda2),
124 theta.d_ * value_of(partial_deriv_array[0])
125 + lambda1.d_ * value_of(partial_deriv_array[1]));
126 } else {
127 fvar<T> partial_deriv_array[2];
128 log_mix_partial_helper(1.0 - theta, lambda2, lambda1, partial_deriv_array);
129 return fvar<T>(log_mix(theta.val_, lambda1.val_, lambda2),
130 -theta.d_ * value_of(partial_deriv_array[0])
131 + lambda1.d_ * value_of(partial_deriv_array[1]));
132 }
133}
134
135template <typename T, typename P, require_all_arithmetic_t<P>* = nullptr>
136inline fvar<T> log_mix(const fvar<T>& theta, P lambda1,
137 const fvar<T>& lambda2) {
138 if (lambda1 > lambda2.val_) {
139 fvar<T> partial_deriv_array[2];
140 log_mix_partial_helper(theta, lambda1, lambda2, partial_deriv_array);
141 return fvar<T>(log_mix(theta.val_, lambda1, lambda2.val_),
142 theta.d_ * value_of(partial_deriv_array[0])
143 + lambda2.d_ * value_of(partial_deriv_array[1]));
144 } else {
145 fvar<T> partial_deriv_array[2];
146 log_mix_partial_helper(1.0 - theta, lambda2, lambda1, partial_deriv_array);
147 return fvar<T>(log_mix(theta.val_, lambda1, lambda2.val_),
148 -theta.d_ * value_of(partial_deriv_array[0])
149 + lambda2.d_ * value_of(partial_deriv_array[1]));
150 }
151}
152
153template <typename T, typename P, require_all_arithmetic_t<P>* = nullptr>
154inline fvar<T> log_mix(P theta, const fvar<T>& lambda1,
155 const fvar<T>& lambda2) {
156 if (lambda1.val_ > lambda2.val_) {
157 fvar<T> partial_deriv_array[2];
158 log_mix_partial_helper(theta, lambda1, lambda2, partial_deriv_array);
159 return fvar<T>(log_mix(theta, lambda1.val_, lambda2.val_),
160 lambda1.d_ * value_of(partial_deriv_array[0])
161 + lambda2.d_ * value_of(partial_deriv_array[1]));
162 } else {
163 fvar<T> partial_deriv_array[2];
164 log_mix_partial_helper(1.0 - theta, lambda2, lambda1, partial_deriv_array);
165 return fvar<T>(log_mix(theta, lambda1.val_, lambda2.val_),
166 lambda1.d_ * value_of(partial_deriv_array[1])
167 + lambda2.d_ * value_of(partial_deriv_array[0]));
168 }
169}
170
171template <typename T, typename P1, typename P2,
173inline fvar<T> log_mix(const fvar<T>& theta, P1 lambda1, P2 lambda2) {
174 if (lambda1 > lambda2) {
175 fvar<T> partial_deriv_array[1];
176 log_mix_partial_helper(theta, lambda1, lambda2, partial_deriv_array);
177 return fvar<T>(log_mix(theta.val_, lambda1, lambda2),
178 theta.d_ * value_of(partial_deriv_array[0]));
179 } else {
180 fvar<T> partial_deriv_array[1];
181 log_mix_partial_helper(1.0 - theta, lambda2, lambda1, partial_deriv_array);
182 return fvar<T>(log_mix(theta.val_, lambda1, lambda2),
183 -theta.d_ * value_of(partial_deriv_array[0]));
184 }
185}
186
187template <typename T, typename P1, typename P2,
189inline fvar<T> log_mix(P1 theta, const fvar<T>& lambda1, P2 lambda2) {
190 if (lambda1.val_ > lambda2) {
191 fvar<T> partial_deriv_array[1];
192 log_mix_partial_helper(theta, lambda1, lambda2, partial_deriv_array);
193 return fvar<T>(log_mix(theta, lambda1.val_, lambda2),
194 lambda1.d_ * value_of(partial_deriv_array[0]));
195 } else {
196 fvar<T> partial_deriv_array[1];
197 log_mix_partial_helper(1.0 - theta, lambda2, lambda1, partial_deriv_array);
198 return fvar<T>(log_mix(theta, lambda1.val_, lambda2),
199 lambda1.d_ * value_of(partial_deriv_array[0]));
200 }
201}
202
203template <typename T, typename P1, typename P2,
205inline fvar<T> log_mix(P1 theta, P2 lambda1, const fvar<T>& lambda2) {
206 if (lambda1 > lambda2.val_) {
207 fvar<T> partial_deriv_array[1];
208 log_mix_partial_helper(theta, lambda1, lambda2, partial_deriv_array);
209 return fvar<T>(log_mix(theta, lambda1, lambda2.val_),
210 lambda2.d_ * value_of(partial_deriv_array[0]));
211 } else {
212 fvar<T> partial_deriv_array[1];
213 log_mix_partial_helper(1.0 - theta, lambda2, lambda1, partial_deriv_array);
214 return fvar<T>(log_mix(theta, lambda1, lambda2.val_),
215 lambda2.d_ * value_of(partial_deriv_array[0]));
216 }
217}
218} // namespace math
219} // namespace stan
220#endif
require_all_t< std::is_arithmetic< std::decay_t< Types > >... > require_all_arithmetic_t
Require all of the types satisfy std::is_arithmetic.
typename boost::math::tools::promote_args< Args... >::type promote_args_t
Convenience alias for boost tools promote_args.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
fvar< T > log_mix(const fvar< T > &theta, const fvar< T > &lambda1, const fvar< T > &lambda2)
Return the log mixture density with specified mixing proportion and log densities and its derivative ...
Definition log_mix.hpp:98
void log_mix_partial_helper(const T_theta &theta, const T_lambda1 &lambda1, const T_lambda2 &lambda2, promote_args_t< T_theta, T_lambda1, T_lambda2 >(&partials_array)[N])
Definition log_mix.hpp:28
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:13
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Scalar val_
The value of this variable.
Definition fvar.hpp:49
Scalar d_
The tangent (derivative) of this variable.
Definition fvar.hpp:61
This template class represents scalars used in forward-mode automatic differentiation,...
Definition fvar.hpp:40