Automatic Differentiation
 
Loading...
Searching...
No Matches
binomial_coefficient_log.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_FUN_BINOMIAL_COEFFICIENT_LOG_HPP
2#define STAN_MATH_PRIM_FUN_BINOMIAL_COEFFICIENT_LOG_HPP
3
15
16namespace stan {
17namespace math {
18
/**
 * Return the natural logarithm of the binomial coefficient "n choose k" for
 * real-valued scalar arguments, propagating partial derivatives with respect
 * to both arguments through make_partials_propagator.
 *
 * After the NaN and symmetry handling below, the value is computed as
 *   0                                                  if k == 0,
 *   lgamma(n + 1) - lgamma(k + 1) - lgamma(n + 1 - k)  if n + 1 < lgamma_stirling_diff_useful,
 *   -lbeta(n + 1 - k, k + 1) - log1p(n)                otherwise.
 *
 * When n > -1 and k > n / 2 + 1e-8, the symmetry C(n, k) == C(n, n - k) is
 * applied by a recursive call with (n, n - k), so the smaller second argument
 * is the one evaluated.
 *
 * NOTE(review): this listing omits source lines 81, 116, 126 and 138. Line 81
 * presumably holds the return type, the function name and `const T_n n`. The
 * unmatched closing braces at listing lines 137, 147 and 148 imply that lines
 * 126, 138 and 116 each opened a block around the derivative code (presumably
 * guards that skip work for constant arguments). Verify against the upstream
 * header before editing.
 *
 * @tparam T_n type of the first argument (a Stan scalar)
 * @tparam T_k type of the second argument (a Stan scalar)
 * @param n first argument; must satisfy n >= -1
 * @param k second argument; must satisfy k >= -1 and n - k + 1 >= 0
 * @return log of the binomial coefficient, or NaN if either argument is NaN
 * @throw an exception raised by check_greater_or_equal when any of the three
 *   domain conditions above is violated
 */
79template <typename T_n, typename T_k,
80 require_all_stan_scalar_t<T_n, T_k>* = nullptr>
82 const T_k k) {
83 using T_partials_return = partials_return_t<T_n, T_k>;
84
85 if (is_any_nan(n, k)) {
86 return NOT_A_NUMBER;
87 }
88
89 // Choosing the more stable of the symmetric branches
// NOTE(review): this swap runs before the domain checks below. An
// out-of-range k (k > n + 1 with n > -1) is therefore rejected by the
// recursive call as an invalid "second argument" (n - k < -1), not as a
// violation of the caller's own n - k + 1 >= 0 condition.
90 if (n > -1 && k > value_of_rec(n) / 2.0 + 1e-8) {
91 return binomial_coefficient_log(n, n - k);
92 }
93
94 const T_partials_return n_dbl = value_of(n);
95 const T_partials_return k_dbl = value_of(k);
96 const T_partials_return n_plus_1 = n_dbl + 1;
97 const T_partials_return n_plus_1_mk = n_plus_1 - k_dbl;
98
99 static constexpr const char* function = "binomial_coefficient_log";
100 check_greater_or_equal(function, "first argument", n, -1);
101 check_greater_or_equal(function, "second argument", k, -1);
102 check_greater_or_equal(function, "(first argument - second argument + 1)",
103 n_plus_1_mk, 0.0);
104
105 auto ops_partials = make_partials_propagator(n, k);
106
107 T_partials_return value;
// log C(n, 0) is exactly 0. Otherwise use the direct lgamma difference for
// small n + 1, and the lbeta form above the threshold (presumably where the
// Stirling-corrected lbeta is more accurate -- see lgamma_stirling_diff_useful).
108 if (k_dbl == 0) {
109 value = 0;
110 } else if (n_plus_1 < lgamma_stirling_diff_useful) {
111 value = lgamma(n_plus_1) - lgamma(k_dbl + 1) - lgamma(n_plus_1_mk);
112 } else {
// C(n, k) == 1 / ((n + 1) * B(n + 1 - k, k + 1)), taken in log space.
113 value = -lbeta(n_plus_1_mk, k_dbl + 1) - log1p(n_dbl);
114 }
115
117 // Branching on all the edge cases.
118 // In direct computation many of those would be NaN
119 // But one-sided limits from within the domain exist, all of the below
120 // follows from lim x->0 from above digamma(x) == -Inf
121 //
122 // Note that we have k <= n / 2 (up to 1e-8) whenever n > -1 (see the
123 // first branch in this function), so we can ignore the n == k - 1 edge case.
// With n == -1 the domain checks force k into [-1, 0]. The only reachable
// n == k - 1 case (n == -1, k == 0) is covered by the explicit n_dbl == -1.0
// branches below.
// digamma(n + 1 - k) is shared by both partial derivatives.
124 T_partials_return digamma_n_plus_1_mk = digamma(n_plus_1_mk);
125
// Partial w.r.t. n: digamma(n + 1) - digamma(n + 1 - k), with limits at n == -1.
127 if (n_dbl == -1.0) {
128 if (k_dbl == 0) {
129 partials<0>(ops_partials)[0] = 0;
130 } else {
131 partials<0>(ops_partials)[0] = NEGATIVE_INFTY;
132 }
133 } else {
134 partials<0>(ops_partials)[0]
135 = (digamma(n_plus_1) - digamma_n_plus_1_mk);
136 }
137 }
// Partial w.r.t. k: digamma(n + 1 - k) - digamma(k + 1), with limits at the
// boundary points (n == -1, k == 0) and k == -1.
139 if (k_dbl == 0 && n_dbl == -1.0) {
140 partials<1>(ops_partials)[0] = NEGATIVE_INFTY;
141 } else if (k_dbl == -1) {
142 partials<1>(ops_partials)[0] = INFTY;
143 } else {
144 partials<1>(ops_partials)[0]
145 = (digamma_n_plus_1_mk - digamma(k_dbl + 1));
146 }
147 }
148 }
149
150 return ops_partials.build(value);
151}
152
163template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr>
164inline auto binomial_coefficient_log(const T1& a, const T2& b) {
165 return apply_scalar_binary(a, b, [&](const auto& c, const auto& d) {
166 return binomial_coefficient_log(c, d);
167 });
168}
169
170} // namespace math
171} // namespace stan
172#endif
binomial_coefficient_log_< as_operation_cl_t< T1 >, as_operation_cl_t< T2 > > binomial_coefficient_log(T1 &&a, T2 &&b)
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
double value_of_rec(const fvar< T > &v)
Return the value of the specified variable.
static constexpr double NOT_A_NUMBER
(Quiet) not-a-number value.
Definition constants.hpp:56
static constexpr double e()
Return the base of the natural logarithm.
Definition constants.hpp:20
fvar< T > lbeta(const fvar< T > &x1, const fvar< T > &x2)
Definition lbeta.hpp:14
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
static constexpr double NEGATIVE_INFTY
Negative infinity.
Definition constants.hpp:51
void check_greater_or_equal(const char *function, const char *name, const T_y &y, const T_low &low, Idxs... idxs)
Throw an exception if y is not greater than or equal to low.
fvar< T > log1p(const fvar< T > &x)
Definition log1p.hpp:12
fvar< T > lgamma(const fvar< T > &x)
Return the natural logarithm of the gamma function applied to the specified argument.
Definition lgamma.hpp:21
constexpr double lgamma_stirling_diff_useful
auto apply_scalar_binary(const T1 &x, const T2 &y, const F &f)
Base template function for vectorization of binary scalar functions defined by applying a functor to ...
bool is_any_nan(const T &x)
Returns true if the input is NaN and false otherwise.
auto make_partials_propagator(Ops &&... ops)
Construct a partials_propagator.
static constexpr double INFTY
Positive infinity.
Definition constants.hpp:46
fvar< T > digamma(const fvar< T > &x)
Return the derivative of the log gamma function at the specified argument.
Definition digamma.hpp:23
typename partials_return_type< Args... >::type partials_return_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
Extends std::true_type when instantiated with zero or more template parameters, all of which extend t...