Automatic Differentiation
 
Loading...
Searching...
No Matches
integrate_1d.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUNCTOR_integrate_1d_HPP
2#define STAN_MATH_REV_FUNCTOR_integrate_1d_HPP
3
13#include <cmath>
14#include <functional>
15#include <ostream>
16#include <string>
17#include <type_traits>
18#include <vector>
19
20namespace stan {
21namespace math {
22
39template <typename F, typename T_a, typename T_b, typename... Args,
40 require_any_st_var<T_a, T_b, Args...> * = nullptr>
41inline return_type_t<T_a, T_b, Args...> integrate_1d_impl(
42 const F &f, const T_a &a, const T_b &b, double relative_tolerance,
43 std::ostream *msgs, const Args &... args) {
44 static constexpr const char *function = "integrate_1d";
45 check_less_or_equal(function, "lower limit", a, b);
46
47 double a_val = value_of(a);
48 double b_val = value_of(b);
49
50 if (unlikely(a_val == b_val)) {
51 if (is_inf(a_val)) {
52 throw_domain_error(function, "Integration endpoints are both", a_val, "",
53 "");
54 }
55 return var(0.0);
56 } else {
57 auto args_val_tuple = std::make_tuple(value_of(args)...);
58
59 double integral = integrate(
60 [&](const auto &x, const auto &xc) {
61 return math::apply(
62 [&](auto &&... val_args) { return f(x, xc, msgs, val_args...); },
63 args_val_tuple);
64 },
65 a_val, b_val, relative_tolerance);
66
67 constexpr size_t num_vars_ab = is_var<T_a>::value + is_var<T_b>::value;
68 size_t num_vars_args = count_vars(args...);
70 num_vars_ab + num_vars_args);
72 num_vars_ab + num_vars_args);
73 // We move this pointer up based on whether we a or b is a var type.
74 double *partials_ptr = partials;
75
76 save_varis(varis, a, b, args...);
77
78 for (size_t i = 0; i < num_vars_ab + num_vars_args; ++i) {
79 partials[i] = 0.0;
80 }
81
82 if constexpr (is_var<T_a>::value) {
83 if (!is_inf(a)) {
84 *partials_ptr = math::apply(
85 [&f, a_val, msgs](auto &&... val_args) {
86 return -f(a_val, 0.0, msgs, val_args...);
87 },
88 args_val_tuple);
89 partials_ptr++;
90 }
91 }
92
93 if constexpr (is_var<T_b>::value) {
94 if (!is_inf(b)) {
95 *partials_ptr = math::apply(
96 [&f, b_val, msgs](auto &&... val_args) {
97 return f(b_val, 0.0, msgs, val_args...);
98 },
99 args_val_tuple);
100 partials_ptr++;
101 }
102 }
103
104 {
105 nested_rev_autodiff argument_nest;
106 // The arguments copy is used multiple times in the following nests, so
107 // do it once in a separate nest for efficiency
108 auto args_tuple_local_copy = std::make_tuple(deep_copy_vars(args)...);
109
110 // Save the varis so it's easy to efficiently access the nth adjoint
111 std::vector<vari *> local_varis(num_vars_args);
113 [&](const auto &... args) {
114 save_varis(local_varis.data(), args...);
115 },
116 args_tuple_local_copy);
117
118 for (size_t n = 0; n < num_vars_args; ++n) {
119 // This computes the integral of the gradient of f with respect to the
120 // nth parameter in args using a nested nested reverse mode autodiff
121 *partials_ptr = integrate(
122 [&](const auto &x, const auto &xc) {
123 argument_nest.set_zero_all_adjoints();
124
125 nested_rev_autodiff gradient_nest;
126 var fx = math::apply(
127 [&f, &x, &xc, msgs](auto &&... local_args) {
128 return f(x, xc, msgs, local_args...);
129 },
130 args_tuple_local_copy);
131 fx.grad();
132
133 double gradient = local_varis[n]->adj();
134
135 // Gradients that evaluate to NaN are set to zero if the function
136 // itself evaluates to zero. If the function is not zero and the
137 // gradient evaluates to NaN, a std::domain_error is thrown
138 if (is_nan(gradient)) {
139 if (fx.val() == 0) {
140 gradient = 0;
141 } else {
142 throw_domain_error("gradient_of_f", "The gradient of f", n,
143 "is nan for parameter ", "");
144 }
145 }
146 return gradient;
147 },
148 a_val, b_val, relative_tolerance);
149 partials_ptr++;
150 }
151 }
152
153 return make_callback_var(
154 integral,
155 [total_vars = num_vars_ab + num_vars_args, varis, partials](auto &vi) {
156 for (size_t i = 0; i < total_vars; ++i) {
157 varis[i]->adj_ += partials[i] * vi.adj();
158 }
159 });
160 }
161}
162
218template <typename F, typename T_a, typename T_b, typename T_theta,
219 typename = require_any_var_t<T_a, T_b, T_theta>>
220inline return_type_t<T_a, T_b, T_theta> integrate_1d(
221 const F &f, const T_a &a, const T_b &b, const std::vector<T_theta> &theta,
222 const std::vector<double> &x_r, const std::vector<int> &x_i,
223 std::ostream *msgs, const double relative_tolerance = std::sqrt(EPSILON)) {
224 return integrate_1d_impl(integrate_1d_adapter<F>(f), a, b, relative_tolerance,
225 msgs, theta, x_r, x_i);
226}
227
228} // namespace math
229} // namespace stan
230
231#endif
T * alloc_array(size_t n)
Allocate an array on the arena of the specified size to hold values of the specified template parameter type.
#define unlikely(x)
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
require_any_t< is_var< scalar_type_t< std::decay_t< Types > > >... > require_any_st_var
Require any of the scalar types satisfy is_var.
Definition is_var.hpp:196
void check_less_or_equal(const char *function, const char *name, const T_y &y, const T_high &high, Idxs... idxs)
Throw an exception if y is not less than high.
bool is_nan(T &&x)
Returns 1 if the input's value is NaN and 0 otherwise.
Definition is_nan.hpp:22
static constexpr double EPSILON
Smallest positive value.
Definition constants.hpp:41
void gradient(const F &f, const Eigen::Matrix< T, Eigen::Dynamic, 1 > &x, T &fx, Eigen::Matrix< T, Eigen::Dynamic, 1 > &grad_fx)
Calculate the value and the gradient of the specified function at the specified argument.
Definition gradient.hpp:40
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback functor.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
vari_value< double > vari
Definition vari.hpp:197
void throw_domain_error(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw a domain error with a consistently formatted message.
return_type_t< T_a, T_b, Args... > integrate_1d_impl(const F &f, const T_a &a, const T_b &b, double relative_tolerance, std::ostream *msgs, const Args &... args)
Return the integral of f from a to b to the given relative tolerance.
Arith deep_copy_vars(Arith &&arg)
Forward arguments that do not contain vars.
size_t count_vars(Pargs &&... args)
Count the number of vars in the input argument list.
vari ** save_varis(vari **dest, const var &x, Pargs &&... args)
Save the vari pointer in x into the memory pointed to by dest, increment the dest storage pointer,...
var_value< double > var
Definition var.hpp:1187
return_type_t< T_a, T_b, T_theta > integrate_1d(const F &f, const T_a &a, const T_b &b, const std::vector< T_theta > &theta, const std::vector< double > &x_r, const std::vector< int > &x_i, std::ostream *msgs, const double relative_tolerance)
Compute the integral of the single variable function f from a to b to within a specified relative tolerance.
constexpr auto & partials(internal::partials_propagator< Types... > &x) noexcept
Access the partials for an edge of an partials_propagator
double integrate(const F &f, double a, double b, double relative_tolerance)
Integrate a single variable function f from a to b to within a specified relative tolerance.
int is_inf(const fvar< T > &x)
Returns 1 if the input's value is infinite and 0 otherwise.
Definition is_inf.hpp:21
constexpr decltype(auto) apply(F &&f, Tuple &&t, PreArgs &&... pre_args)
Definition apply.hpp:51
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Adapt the non-variadic integrate_1d arguments to the variadic integrate_1d_impl interface.
static thread_local AutodiffStackStorage * instance_