1#ifndef STAN_MATH_REV_FUN_BETA_HPP
2#define STAN_MATH_REV_FUN_BETA_HPP
38 double digamma_ab =
digamma(a.val() + b.val());
39 double digamma_a =
digamma(a.val()) - digamma_ab;
40 double digamma_b =
digamma(b.val()) - digamma_ab;
42 [a, b, digamma_a, digamma_b](
auto& vi)
mutable {
43 const double adj_val = vi.adj() * vi.val();
44 a.adj() += adj_val * digamma_a;
45 b.adj() += adj_val * digamma_b;
70 a.adj() += vi.adj() * digamma_ab * vi.val();
93 auto beta_val =
beta(a, b.val());
94 auto digamma_ab = (
digamma(b.val()) -
digamma(a + b.val())) * beta_val;
96 b.adj() += vi.adj() * digamma_ab;
100template <
typename Mat1,
typename Mat2,
103inline auto beta(
const Mat1& a,
const Mat2& b) {
107 auto beta_val =
beta(arena_a.val(), arena_b.val());
111 beta(arena_a.val(), arena_b.val()),
112 [arena_a, arena_b, digamma_ab](
auto& vi)
mutable {
113 const auto adj_val = (vi.adj().array() * vi.val().array()).eval();
114 arena_a.adj().array()
115 += adj_val * (digamma(arena_a.val().array()) - digamma_ab);
116 arena_b.adj().array()
117 += adj_val * (digamma(arena_b.val().array()) - digamma_ab);
124 -
digamma(arena_a.val().array() + arena_b.array()));
126 [arena_a, arena_b, digamma_ab](
auto& vi)
mutable {
127 arena_a.adj().array() += vi.adj().array()
134 auto beta_val =
beta(arena_a, arena_b.val());
137 -
digamma(arena_a.array() + arena_b.val().array()))
140 beta_val, [arena_a, arena_b, digamma_ab](
auto& vi)
mutable {
141 arena_b.adj().array() += vi.adj().array() * digamma_ab.array();
146template <
typename Scalar,
typename VarMat,
149inline auto beta(
const Scalar& a,
const VarMat& b) {
153 auto beta_val =
beta(arena_a.val(), arena_b.val());
154 auto digamma_ab =
to_arena(
digamma(arena_a.val() + arena_b.val().array()));
156 beta(arena_a.val(), arena_b.val()),
157 [arena_a, arena_b, digamma_ab](
auto& vi)
mutable {
158 const auto adj_val = (vi.adj().array() * vi.val().array()).eval();
160 += (adj_val * (digamma(arena_a.val()) - digamma_ab)).sum();
161 arena_b.adj().array()
162 += adj_val * (digamma(arena_b.val().array()) - digamma_ab);
168 -
digamma(arena_a.val() + arena_b.array()));
170 beta(arena_a.val(), arena_b),
171 [arena_a, arena_b, digamma_ab](
auto& vi)
mutable {
173 += (vi.adj().array() * digamma_ab * vi.val().array()).sum();
178 auto beta_val =
beta(arena_a, arena_b.val());
180 -
digamma(arena_a + arena_b.val().array()))
183 arena_b.adj().array() += vi.adj().array() * digamma_ab.array();
188template <
typename VarMat,
typename Scalar,
191inline auto beta(
const VarMat& a,
const Scalar& b) {
195 auto beta_val =
beta(arena_a.val(), arena_b.val());
196 auto digamma_ab =
to_arena(
digamma(arena_a.val().array() + arena_b.val()));
198 beta(arena_a.val(), arena_b.val()),
199 [arena_a, arena_b, digamma_ab](
auto& vi)
mutable {
200 const auto adj_val = (vi.adj().array() * vi.val().array()).eval();
201 arena_a.adj().array()
202 += adj_val * (digamma(arena_a.val().array()) - digamma_ab);
204 += (adj_val * (digamma(arena_b.val()) - digamma_ab)).sum();
210 -
digamma(arena_a.val().array() + arena_b));
212 beta(arena_a.val(), arena_b), [arena_a, digamma_ab](
auto& vi)
mutable {
213 arena_a.adj().array()
214 += vi.adj().array() * digamma_ab * vi.val().array();
219 auto beta_val =
beta(arena_a, arena_b.val());
221 (
digamma(arena_b.val()) -
digamma(arena_a.array() + arena_b.val()))
224 beta_val, [arena_a, arena_b, digamma_ab](
auto& vi)
mutable {
225 arena_b.adj() += (vi.adj().array() * digamma_ab.array()).sum();
require_all_t< is_matrix< std::decay_t< Types > >... > require_all_matrix_t
Require all of the types satisfy is_matrix.
require_t< is_stan_scalar< std::decay_t< T > > > require_stan_scalar_t
Require type satisfies is_stan_scalar.
require_t< is_var_matrix< std::decay_t< T > > > require_var_matrix_t
Require type satisfies is_var_matrix.
require_any_t< is_var_matrix< std::decay_t< Types > >... > require_any_var_matrix_t
Require any of the types satisfy is_var_matrix.
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback funct...
T value_of(const fvar< T > &v)
Return the value of the specified variable.
arena_t< T > to_arena(const T &a)
Converts given argument into a type that either has any dynamic allocation on AD stack or schedules i...
fvar< T > beta(const fvar< T > &x1, const fvar< T > &x2)
Return fvar with the beta function applied to the specified arguments and its gradient.
fvar< T > digamma(const fvar< T > &x)
Return the derivative of the log gamma function at the specified argument.
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the ...