Automatic Differentiation
 
Loading...
Searching...
No Matches
integrate_1d.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUNCTOR_integrate_1d_HPP
2#define STAN_MATH_REV_FUNCTOR_integrate_1d_HPP
3
13#include <cmath>
14#include <functional>
15#include <ostream>
16#include <string>
17#include <type_traits>
18#include <vector>
19
20namespace stan {
21namespace math {
22
39template <typename F, typename T_a, typename T_b, typename... Args,
40 require_any_st_var<T_a, T_b, Args...> * = nullptr>
41inline return_type_t<T_a, T_b, Args...> integrate_1d_impl(
42 const F &f, const T_a &a, const T_b &b, double relative_tolerance,
43 std::ostream *msgs, const Args &... args) {
44 static constexpr const char *function = "integrate_1d";
45 check_less_or_equal(function, "lower limit", a, b);
46
47 double a_val = value_of(a);
48 double b_val = value_of(b);
49
50 if (unlikely(a_val == b_val)) {
51 if (is_inf(a_val)) {
52 throw_domain_error(function, "Integration endpoints are both", a_val, "",
53 "");
54 }
55 return var(0.0);
56 } else {
57 auto args_val_tuple = std::make_tuple(value_of(args)...);
58
59 double integral = integrate(
60 [&](const auto &x, const auto &xc) {
61 return math::apply(
62 [&](auto &&... val_args) { return f(x, xc, msgs, val_args...); },
63 args_val_tuple);
64 },
65 a_val, b_val, relative_tolerance);
66
67 constexpr size_t num_vars_ab = is_var<T_a>::value + is_var<T_b>::value;
68 size_t num_vars_args = count_vars(args...);
70 num_vars_ab + num_vars_args);
72 num_vars_ab + num_vars_args);
73 // We move this pointer up based on whether we a or b is a var type.
74 double *partials_ptr = partials;
75
76 save_varis(varis, a, b, args...);
77
78 for (size_t i = 0; i < num_vars_ab + num_vars_args; ++i) {
79 partials[i] = 0.0;
80 }
81
82 if (is_var<T_a>::value && !is_inf(a)) {
83 *partials_ptr = math::apply(
84 [&f, a_val, msgs](auto &&... val_args) {
85 return -f(a_val, 0.0, msgs, val_args...);
86 },
87 args_val_tuple);
88 partials_ptr++;
89 }
90
91 if (!is_inf(b) && is_var<T_b>::value) {
92 *partials_ptr = math::apply(
93 [&f, b_val, msgs](auto &&... val_args) {
94 return f(b_val, 0.0, msgs, val_args...);
95 },
96 args_val_tuple);
97 partials_ptr++;
98 }
99
100 {
101 nested_rev_autodiff argument_nest;
102 // The arguments copy is used multiple times in the following nests, so
103 // do it once in a separate nest for efficiency
104 auto args_tuple_local_copy = std::make_tuple(deep_copy_vars(args)...);
105
106 // Save the varis so it's easy to efficiently access the nth adjoint
107 std::vector<vari *> local_varis(num_vars_args);
109 [&](const auto &... args) {
110 save_varis(local_varis.data(), args...);
111 },
112 args_tuple_local_copy);
113
114 for (size_t n = 0; n < num_vars_args; ++n) {
115 // This computes the integral of the gradient of f with respect to the
116 // nth parameter in args using a nested nested reverse mode autodiff
117 *partials_ptr = integrate(
118 [&](const auto &x, const auto &xc) {
119 argument_nest.set_zero_all_adjoints();
120
121 nested_rev_autodiff gradient_nest;
122 var fx = math::apply(
123 [&f, &x, &xc, msgs](auto &&... local_args) {
124 return f(x, xc, msgs, local_args...);
125 },
126 args_tuple_local_copy);
127 fx.grad();
128
129 double gradient = local_varis[n]->adj();
130
131 // Gradients that evaluate to NaN are set to zero if the function
132 // itself evaluates to zero. If the function is not zero and the
133 // gradient evaluates to NaN, a std::domain_error is thrown
134 if (is_nan(gradient)) {
135 if (fx.val() == 0) {
136 gradient = 0;
137 } else {
138 throw_domain_error("gradient_of_f", "The gradient of f", n,
139 "is nan for parameter ", "");
140 }
141 }
142 return gradient;
143 },
144 a_val, b_val, relative_tolerance);
145 partials_ptr++;
146 }
147 }
148
149 return make_callback_var(
150 integral,
151 [total_vars = num_vars_ab + num_vars_args, varis, partials](auto &vi) {
152 for (size_t i = 0; i < total_vars; ++i) {
153 varis[i]->adj_ += partials[i] * vi.adj();
154 }
155 });
156 }
157}
158
214template <typename F, typename T_a, typename T_b, typename T_theta,
215 typename = require_any_var_t<T_a, T_b, T_theta>>
216inline return_type_t<T_a, T_b, T_theta> integrate_1d(
217 const F &f, const T_a &a, const T_b &b, const std::vector<T_theta> &theta,
218 const std::vector<double> &x_r, const std::vector<int> &x_i,
219 std::ostream *msgs, const double relative_tolerance = std::sqrt(EPSILON)) {
220 return integrate_1d_impl(integrate_1d_adapter<F>(f), a, b, relative_tolerance,
221 msgs, theta, x_r, x_i);
222}
223
224} // namespace math
225} // namespace stan
226
227#endif
T * alloc_array(size_t n)
Allocate an array on the arena of the specified size to hold values of the specified template parameter type.
#define unlikely(x)
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
require_any_t< is_var< scalar_type_t< std::decay_t< Types > > >... > require_any_st_var
Require any of the scalar types satisfy is_var.
Definition is_var.hpp:131
void check_less_or_equal(const char *function, const char *name, const T_y &y, const T_high &high, Idxs... idxs)
Throw an exception if y is not less than high.
bool is_nan(T &&x)
Returns 1 if the input's value is NaN and 0 otherwise.
Definition is_nan.hpp:22
static constexpr double EPSILON
Smallest positive value.
Definition constants.hpp:41
void gradient(const F &f, const Eigen::Matrix< T, Eigen::Dynamic, 1 > &x, T &fx, Eigen::Matrix< T, Eigen::Dynamic, 1 > &grad_fx)
Calculate the value and the gradient of the specified function at the specified argument.
Definition gradient.hpp:40
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback functor.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
vari_value< double > vari
Definition vari.hpp:197
void throw_domain_error(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw a domain error with a consistently formatted message.
return_type_t< T_a, T_b, Args... > integrate_1d_impl(const F &f, const T_a &a, const T_b &b, double relative_tolerance, std::ostream *msgs, const Args &... args)
Return the integral of f from a to b to the given relative tolerance.
Arith deep_copy_vars(Arith &&arg)
Forward arguments that do not contain vars.
size_t count_vars(Pargs &&... args)
Count the number of vars in the input argument list.
vari ** save_varis(vari **dest, const var &x, Pargs &&... args)
Save the vari pointer in x into the memory pointed to by dest, increment the dest storage pointer,...
var_value< double > var
Definition var.hpp:1187
return_type_t< T_a, T_b, T_theta > integrate_1d(const F &f, const T_a &a, const T_b &b, const std::vector< T_theta > &theta, const std::vector< double > &x_r, const std::vector< int > &x_i, std::ostream *msgs, const double relative_tolerance)
Compute the integral of the single variable function f from a to b to within a specified relative tolerance.
constexpr auto & partials(internal::partials_propagator< Types... > &x) noexcept
Access the partials for an edge of an partials_propagator
double integrate(const F &f, double a, double b, double relative_tolerance)
Integrate a single variable function f from a to b to within a specified relative tolerance.
int is_inf(const fvar< T > &x)
Returns 1 if the input's value is infinite and 0 otherwise.
Definition is_inf.hpp:21
constexpr decltype(auto) apply(F &&f, Tuple &&t, PreArgs &&... pre_args)
Definition apply.hpp:52
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Adapt the non-variadic integrate_1d arguments to the variadic integrate_1d_impl interface.
static thread_local AutodiffStackStorage * instance_