Automatic Differentiation
 
Loading...
Searching...
No Matches
beta.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_BETA_HPP
2#define STAN_MATH_REV_FUN_BETA_HPP
3
8
9namespace stan {
10namespace math {
11
37inline var beta(const var& a, const var& b) {
38 double digamma_ab = digamma(a.val() + b.val());
39 double digamma_a = digamma(a.val()) - digamma_ab;
40 double digamma_b = digamma(b.val()) - digamma_ab;
41 return make_callback_var(beta(a.val(), b.val()),
42 [a, b, digamma_a, digamma_b](auto& vi) mutable {
43 const double adj_val = vi.adj() * vi.val();
44 a.adj() += adj_val * digamma_a;
45 b.adj() += adj_val * digamma_b;
46 });
47}
48
67inline var beta(const var& a, double b) {
68 auto digamma_ab = digamma(a.val()) - digamma(a.val() + b);
69 return make_callback_var(beta(a.val(), b), [a, digamma_ab](auto& vi) mutable {
70 a.adj() += vi.adj() * digamma_ab * vi.val();
71 });
72}
73
92inline var beta(double a, const var& b) {
93 auto beta_val = beta(a, b.val());
94 auto digamma_ab = (digamma(b.val()) - digamma(a + b.val())) * beta_val;
95 return make_callback_var(beta_val, [b, digamma_ab](auto& vi) mutable {
96 b.adj() += vi.adj() * digamma_ab;
97 });
98}
99
// Reverse-mode beta overload taking two arguments of template types Mat1 and
// Mat2. Each visible branch returns a make_callback_var result whose callback
// accumulates adjoints from coefficient-wise (.array()) digamma differences.
// NOTE(review): listing lines 101-102, 104-106, 120-121 and 132-133 are absent
// from this extract; presumably they hold the remaining template parameters,
// the first branch condition and the arena_a / arena_b declarations - confirm
// against the real header.
100template <typename Mat1, typename Mat2,
103inline auto beta(const Mat1& a, const Mat2& b) {
// First branch: both arena_a and arena_b receive adjoint updates.
// NOTE(review): beta_val is not referenced again in this branch; the same
// beta(...) expression is evaluated a second time in the make_callback_var
// call below.
107 auto beta_val = beta(arena_a.val(), arena_b.val());
// digamma(a + b), coefficient-wise, passed through to_arena and captured by
// the callback.
108 auto digamma_ab
109 = to_arena(digamma(arena_a.val().array() + arena_b.val().array()));
// Callback: adj_val = adj * val (coefficient-wise, materialised with .eval()
// and used for both updates); each operand's adjoint then gains
// adj_val * (digamma(operand) - digamma(a + b)).
110 return make_callback_var(
111 beta(arena_a.val(), arena_b.val()),
112 [arena_a, arena_b, digamma_ab](auto& vi) mutable {
113 const auto adj_val = (vi.adj().array() * vi.val().array()).eval();
114 arena_a.adj().array()
115 += adj_val * (digamma(arena_a.val().array()) - digamma_ab);
116 arena_b.adj().array()
117 += adj_val * (digamma(arena_b.val().array()) - digamma_ab);
118 });
119 } else if (!is_constant<Mat1>::value) {
// Branch taken when Mat1 is not constant: arena_b is used without .val() and
// only arena_a's adjoint is updated. Here digamma_ab holds the difference
// digamma(a) - digamma(a + b).
122 auto digamma_ab
123 = to_arena(digamma(arena_a.val()).array()
124 - digamma(arena_a.val().array() + arena_b.array()));
// NOTE(review): arena_b is captured but not used in the visible callback body.
125 return make_callback_var(beta(arena_a.val(), arena_b),
126 [arena_a, arena_b, digamma_ab](auto& vi) mutable {
127 arena_a.adj().array() += vi.adj().array()
128 * digamma_ab
129 * vi.val().array();
130 });
131 } else if (!is_constant<Mat2>::value) {
// Branch taken when Mat2 is not constant: arena_a is used without .val() and
// only arena_b's adjoint is updated. beta_val is computed once and reused both
// as the returned value and as the scale folded into digamma_ab, so the
// callback multiplies by vi.adj() only.
134 auto beta_val = beta(arena_a, arena_b.val());
135 auto digamma_ab
136 = to_arena((digamma(arena_b.val()).array()
137 - digamma(arena_a.array() + arena_b.val().array()))
138 * beta_val.array());
// NOTE(review): arena_a is captured but not used in the visible callback body.
139 return make_callback_var(
140 beta_val, [arena_a, arena_b, digamma_ab](auto& vi) mutable {
141 arena_b.adj().array() += vi.adj().array() * digamma_ab.array();
142 });
143 }
144}
145
// Reverse-mode beta overload taking a first argument of template type Scalar
// and a second argument of template type VarMat. Each visible branch returns a
// make_callback_var result; wherever `a` receives an adjoint it is the .sum()
// of the coefficient-wise contributions.
// NOTE(review): listing lines 147-148, 150, 152, 166 and 177 are absent from
// this extract; presumably they hold the remaining template parameters, the
// first branch condition and the arena_b declarations - confirm against the
// real header.
146template <typename Scalar, typename VarMat,
149inline auto beta(const Scalar& a, const VarMat& b) {
// First branch: `a` is held as a var and both operands receive adjoint updates.
151 var arena_a = a;
// NOTE(review): beta_val is not referenced again in this branch; the same
// beta(...) expression is evaluated a second time in the make_callback_var
// call below.
153 auto beta_val = beta(arena_a.val(), arena_b.val());
// digamma(a + b), coefficient-wise over b, passed through to_arena.
154 auto digamma_ab = to_arena(digamma(arena_a.val() + arena_b.val().array()));
// Callback: adj_val = adj * val (coefficient-wise, materialised with .eval());
// arena_a's adjoint gains the sum of adj_val * (digamma(a) - digamma(a + b)),
// arena_b's adjoint gains adj_val * (digamma(b) - digamma(a + b)).
155 return make_callback_var(
156 beta(arena_a.val(), arena_b.val()),
157 [arena_a, arena_b, digamma_ab](auto& vi) mutable {
158 const auto adj_val = (vi.adj().array() * vi.val().array()).eval();
159 arena_a.adj()
160 += (adj_val * (digamma(arena_a.val()) - digamma_ab)).sum();
161 arena_b.adj().array()
162 += adj_val * (digamma(arena_b.val().array()) - digamma_ab);
163 });
164 } else if (!is_constant<Scalar>::value) {
// Branch taken when Scalar is not constant: arena_b is used without .val() and
// only arena_a's adjoint is updated. digamma_ab holds
// digamma(a) - digamma(a + b).
165 var arena_a = a;
167 auto digamma_ab = to_arena(digamma(arena_a.val())
168 - digamma(arena_a.val() + arena_b.array()));
// NOTE(review): arena_b is captured but not used in the visible callback body.
169 return make_callback_var(
170 beta(arena_a.val(), arena_b),
171 [arena_a, arena_b, digamma_ab](auto& vi) mutable {
172 arena_a.adj()
173 += (vi.adj().array() * digamma_ab * vi.val().array()).sum();
174 });
175 } else if (!is_constant<VarMat>::value) {
// Branch taken when VarMat is not constant: `a` is reduced to a double via
// value_of and only arena_b's adjoint is updated. beta_val is computed once
// and reused both as the returned value and as the scale folded into
// digamma_ab.
176 double arena_a = value_of(a);
178 auto beta_val = beta(arena_a, arena_b.val());
179 auto digamma_ab = to_arena((digamma(arena_b.val()).array()
180 - digamma(arena_a + arena_b.val().array()))
181 * beta_val.array());
182 return make_callback_var(beta_val, [arena_b, digamma_ab](auto& vi) mutable {
183 arena_b.adj().array() += vi.adj().array() * digamma_ab.array();
184 });
185 }
186}
187
// Reverse-mode beta overload taking a first argument of template type VarMat
// and a second argument of template type Scalar. Each visible branch returns a
// make_callback_var result; wherever `b` receives an adjoint it is the .sum()
// of the coefficient-wise contributions.
// NOTE(review): listing lines 189-190, 192-193, 207 and 217 are absent from
// this extract; presumably they hold the remaining template parameters, the
// first branch condition and the arena_a declarations - confirm against the
// real header.
188template <typename VarMat, typename Scalar,
191inline auto beta(const VarMat& a, const Scalar& b) {
// First branch: `b` is held as a var and both operands receive adjoint updates.
194 var arena_b = b;
// NOTE(review): beta_val is not referenced again in this branch; the same
// beta(...) expression is evaluated a second time in the make_callback_var
// call below.
195 auto beta_val = beta(arena_a.val(), arena_b.val());
// digamma(a + b), coefficient-wise over a, passed through to_arena.
196 auto digamma_ab = to_arena(digamma(arena_a.val().array() + arena_b.val()));
// Callback: adj_val = adj * val (coefficient-wise, materialised with .eval());
// arena_a's adjoint gains adj_val * (digamma(a) - digamma(a + b)),
// arena_b's adjoint gains the sum of adj_val * (digamma(b) - digamma(a + b)).
197 return make_callback_var(
198 beta(arena_a.val(), arena_b.val()),
199 [arena_a, arena_b, digamma_ab](auto& vi) mutable {
200 const auto adj_val = (vi.adj().array() * vi.val().array()).eval();
201 arena_a.adj().array()
202 += adj_val * (digamma(arena_a.val().array()) - digamma_ab);
203 arena_b.adj()
204 += (adj_val * (digamma(arena_b.val()) - digamma_ab)).sum();
205 });
206 } else if (!is_constant<VarMat>::value) {
// Branch taken when VarMat is not constant: `b` is reduced to a double via
// value_of and only arena_a's adjoint is updated. digamma_ab holds
// digamma(a) - digamma(a + b).
208 double arena_b = value_of(b);
209 auto digamma_ab = to_arena(digamma(arena_a.val()).array()
210 - digamma(arena_a.val().array() + arena_b));
211 return make_callback_var(
212 beta(arena_a.val(), arena_b), [arena_a, digamma_ab](auto& vi) mutable {
213 arena_a.adj().array()
214 += vi.adj().array() * digamma_ab * vi.val().array();
215 });
216 } else if (!is_constant<Scalar>::value) {
// Branch taken when Scalar is not constant: arena_a is used without .val() and
// only arena_b's adjoint is updated. beta_val is computed once and reused both
// as the returned value and as the scale folded into digamma_ab.
218 var arena_b = b;
219 auto beta_val = beta(arena_a, arena_b.val());
220 auto digamma_ab = to_arena(
221 (digamma(arena_b.val()) - digamma(arena_a.array() + arena_b.val()))
222 * beta_val.array());
// NOTE(review): arena_a is captured but not used in the visible callback body.
223 return make_callback_var(
224 beta_val, [arena_a, arena_b, digamma_ab](auto& vi) mutable {
225 arena_b.adj() += (vi.adj().array() * digamma_ab.array()).sum();
226 });
227 }
228}
229
230} // namespace math
231} // namespace stan
232#endif
require_all_t< is_matrix< std::decay_t< Types > >... > require_all_matrix_t
Require all of the types satisfy is_matrix.
Definition is_matrix.hpp:38
require_t< is_stan_scalar< std::decay_t< T > > > require_stan_scalar_t
Require type satisfies is_stan_scalar.
require_t< is_var_matrix< std::decay_t< T > > > require_var_matrix_t
Require type satisfies is_var_matrix.
require_any_t< is_var_matrix< std::decay_t< Types > >... > require_any_var_matrix_t
Require any of the types satisfy is_var_matrix.
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback functor.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
arena_t< T > to_arena(const T &a)
Converts given argument into a type that either has any dynamic allocation on AD stack or schedules i...
Definition to_arena.hpp:25
fvar< T > beta(const fvar< T > &x1, const fvar< T > &x2)
Return fvar with the beta function applied to the specified arguments and its gradient.
Definition beta.hpp:51
fvar< T > digamma(const fvar< T > &x)
Return the derivative of the log gamma function at the specified argument.
Definition digamma.hpp:23
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the ...