Automatic Differentiation
 
reduce_sum.hpp
Go to the documentation of this file.
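This header provides the reverse-mode (var) specialization of stan::math::internal::reduce_sum_impl: it evaluates a user-supplied ReduceFunction over slices of an input sequence in parallel via Intel TBB, sums the partial results, and assembles the gradient from per-worker adjoint accumulators.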
#ifndef STAN_MATH_REV_FUNCTOR_REDUCE_SUM_HPP
#define STAN_MATH_REV_FUNCTOR_REDUCE_SUM_HPP

// Stan Math internal headers (shown only as links in the original
// listing; reconstructed here on the assumption that this var
// specialization needs the rev-mode core/meta machinery and the
// generic reduce_sum declaration):
#include <stan/math/rev/meta.hpp>
#include <stan/math/rev/core.hpp>
#include <stan/math/prim/functor/reduce_sum.hpp>

#include <tbb/task_arena.h>
#include <tbb/parallel_reduce.h>
#include <tbb/blocked_range.h>

#include <tuple>
#include <memory>
#include <utility>
#include <vector>
namespace stan {
namespace math {
namespace internal {

/**
 * Specialization of reduce_sum_impl for reverse-mode autodiff (var)
 * return types.
 */
template <typename ReduceFunction, typename ReturnType, typename Vec,
          typename... Args>
struct reduce_sum_impl<ReduceFunction, require_var_t<ReturnType>, ReturnType,
                       Vec, Args...> {
  // Deep copies of the shared arguments kept on a reducer-local
  // autodiff stack, so worker threads never write to the main tape.
  struct scoped_args_tuple {
    ScopedChainableStack stack_;
    using args_tuple_t
        = std::tuple<decltype(deep_copy_vars(std::declval<Args>()))...>;
    std::unique_ptr<args_tuple_t> args_tuple_holder_;

    scoped_args_tuple() : stack_(), args_tuple_holder_(nullptr) {}
  };

  /**
   * Worker object for tbb::parallel_reduce: accumulates the partial
   * sum and the argument adjoints for a sub-range of the input terms.
   */
  struct recursive_reducer {
    const size_t num_vars_per_term_;
    const size_t num_vars_shared_terms_;  // Number of vars in shared arguments
    double* sliced_partials_;  // Points to adjoints of the partial calculations
    Vec vmapped_;
    std::stringstream msgs_;
    std::tuple<Args...> args_tuple_;
    scoped_args_tuple local_args_tuple_scope_;
    double sum_{0.0};
    Eigen::VectorXd args_adjoints_{0};

    template <typename VecT, typename... ArgsT>
    recursive_reducer(size_t num_vars_per_term, size_t num_vars_shared_terms,
                      double* sliced_partials, VecT&& vmapped, ArgsT&&... args)
        : num_vars_per_term_(num_vars_per_term),
          num_vars_shared_terms_(num_vars_shared_terms),
          sliced_partials_(sliced_partials),
          vmapped_(std::forward<VecT>(vmapped)),
          args_tuple_(std::forward<ArgsT>(args)...) {}

    /*
     * This is the copy constructor required by the imperative form of
     * tbb::parallel_reduce. It leaves sum_ and args_adjoints_ at their
     * zero defaults, since the newly created reducer accumulates an
     * independent partial sum.
     */
    recursive_reducer(recursive_reducer& other, tbb::split)
        : num_vars_per_term_(other.num_vars_per_term_),
          num_vars_shared_terms_(other.num_vars_shared_terms_),
          sliced_partials_(other.sliced_partials_),
          vmapped_(other.vmapped_),
          args_tuple_(other.args_tuple_) {}

    /**
     * Compute, using nested autodiff, the value and Jacobian of
     * ReduceFunction called over the range defined by r, and
     * accumulate the results into this reducer.
     */
    inline void operator()(const tbb::blocked_range<size_t>& r) {
      if (r.empty()) {
        return;
      }

      if (args_adjoints_.size() == 0) {
        args_adjoints_ = Eigen::VectorXd::Zero(num_vars_shared_terms_);
      }

      // Obtain reference to a local copy of all shared arguments that
      // does not point back to the main autodiff stack
      if (!local_args_tuple_scope_.args_tuple_holder_) {
        // shared arguments need to be copied to reducer-specific
        // scope. In this case no need for zeroing adjoints, since the
        // fresh copy has all adjoints set to zero.
        local_args_tuple_scope_.stack_.execute([&]() {
          math::apply(
              [&](auto&&... args) {
                local_args_tuple_scope_.args_tuple_holder_ = std::make_unique<
                    typename scoped_args_tuple::args_tuple_t>(
                    deep_copy_vars(args)...);
              },
              args_tuple_);
        });
      } else {
        // set adjoints of shared arguments to zero
        local_args_tuple_scope_.stack_.execute([] { set_zero_all_adjoints(); });
      }

      auto& args_tuple_local = *(local_args_tuple_scope_.args_tuple_holder_);

      // Initialize nested autodiff stack
      const nested_rev_autodiff begin_nest;

      // Create nested autodiff copies of the sliced argument that do
      // not point back to the main autodiff stack
      std::decay_t<Vec> local_sub_slice;
      local_sub_slice.reserve(r.size());
      for (size_t i = r.begin(); i < r.end(); ++i) {
        local_sub_slice.emplace_back(deep_copy_vars(vmapped_[i]));
      }

      // Perform calculation
      var sub_sum_v = math::apply(
          [&](auto&&... args) {
            return ReduceFunction()(local_sub_slice, r.begin(), r.end() - 1,
                                    &msgs_, args...);
          },
          args_tuple_local);

      // Compute Jacobian
      sub_sum_v.grad();

      // Accumulate value of reduce_sum
      sum_ += sub_sum_v.val();

      // Accumulate adjoints of sliced arguments
      accumulate_adjoints(sliced_partials_ + r.begin() * num_vars_per_term_,
                          std::move(local_sub_slice));

      // Accumulate adjoints of shared arguments
      math::apply(
          [&](auto&&... args) {
            accumulate_adjoints(args_adjoints_.data(), args...);
          },
          args_tuple_local);
    }

    /**
     * Join the partial sum, the shared-argument adjoints, and any
     * messages accumulated by rhs into this reducer.
     */
    inline void join(const recursive_reducer& rhs) {
      sum_ += rhs.sum_;
      if (args_adjoints_.size() != 0 && rhs.args_adjoints_.size() != 0) {
        args_adjoints_ += rhs.args_adjoints_;
      } else if (args_adjoints_.size() == 0 && rhs.args_adjoints_.size() != 0) {
        args_adjoints_ = rhs.args_adjoints_;
      }
      msgs_ << rhs.msgs_.str();
    }
  };

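  // Illustrative sketch (not in the original header) of how
  // tbb::parallel_reduce drives the reducer above:
  //
  //   recursive_reducer left(...);                  // original worker
  //   recursive_reducer right(left, tbb::split{});  // fresh partial state
  //   left(range_a);  right(range_b);               // disjoint sub-ranges
  //   left.join(right);                             // merge partial results
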
  /**
   * Call an instance of the function ReduceFunction on every element
   * of the input sequence and sum these terms, evaluating the partial
   * sums in parallel. Returns a single var whose gradient carries the
   * accumulated adjoints of the sliced and shared arguments.
   */
  inline var operator()(Vec&& vmapped, bool auto_partitioning, int grainsize,
                        std::ostream* msgs, Args&&... args) const {
    if (vmapped.empty()) {
      return var(0.0);
    }

    const std::size_t num_terms = vmapped.size();
    const std::size_t num_vars_per_term = count_vars(vmapped[0]);
    const std::size_t num_vars_sliced_terms = num_terms * num_vars_per_term;
    const std::size_t num_vars_shared_terms = count_vars(args...);

    vari** varis = ChainableStack::instance_->memalloc_.alloc_array<vari*>(
        num_vars_sliced_terms + num_vars_shared_terms);
    double* partials = ChainableStack::instance_->memalloc_.alloc_array<double>(
        num_vars_sliced_terms + num_vars_shared_terms);

    save_varis(varis, vmapped);
    save_varis(varis + num_vars_sliced_terms, args...);

    for (size_t i = 0; i < num_vars_sliced_terms; ++i) {
      partials[i] = 0.0;
    }

    recursive_reducer worker(num_vars_per_term, num_vars_shared_terms, partials,
                             std::forward<Vec>(vmapped),
                             std::forward<Args>(args)...);

    // We must use task isolation as described here:
    // https://software.intel.com/content/www/us/en/develop/documentation/tbb-documentation/top/intel-threading-building-blocks-developer-guide/task-isolation.html
    // This ensures that the thread-local AD tape resource is not
    // modified from a different task, which may happen whenever this
    // function is itself used in a parallel context (like running
    // multiple chains for Stan).
    tbb::this_task_arena::isolate([&] {
      if (auto_partitioning) {
        tbb::parallel_reduce(
            tbb::blocked_range<std::size_t>(0, num_terms, grainsize), worker);
      } else {
        tbb::simple_partitioner partitioner;
        tbb::parallel_deterministic_reduce(
            tbb::blocked_range<std::size_t>(0, num_terms, grainsize), worker,
            partitioner);
      }
    });
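    // With auto_partitioning the chunk boundaries, and hence the
    // floating-point summation order, can vary between runs; the
    // deterministic branch with a simple_partitioner always splits
    // ranges the same way for a fixed grainsize, so the grouping of
    // the partial sums is reproducible.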

    for (size_t i = 0; i < num_vars_shared_terms; ++i) {
      partials[num_vars_sliced_terms + i] = worker.args_adjoints_.coeff(i);
    }

    if (msgs) {
      *msgs << worker.msgs_.str();
    }

    return var(new precomputed_gradients_vari(
        worker.sum_, num_vars_sliced_terms + num_vars_shared_terms, varis,
        partials));
  }
};
}  // namespace internal

}  // namespace math
}  // namespace stan

#endif
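A minimal usage sketch (not part of this header): the specialization above is reached through the public stan::math::reduce_sum entry point, whose ReduceFunction receives a sub-slice of the sliced argument, the inclusive index range [start, end] of that sub-slice in the full sequence, a message stream, and the shared arguments. The functor name scaled_sum and the concrete numbers below are illustrative only.

#include <stan/math.hpp>
#include <vector>

// Partial-sum functor: returns scale * sum(sub_slice) for one chunk.
struct scaled_sum {
  template <typename Slice, typename Scale>
  auto operator()(const Slice& sub_slice, std::size_t start, std::size_t end,
                  std::ostream* msgs, const Scale& scale) const {
    return scale * stan::math::sum(sub_slice);
  }
};

int main() {
  std::vector<stan::math::var> terms(1000, 1.0);
  stan::math::var scale = 2.0;
  // grainsize 1 with automatic partitioning lets TBB pick chunk sizes
  stan::math::var total
      = stan::math::reduce_sum<scaled_sum>(terms, 1, nullptr, scale);
  total.grad();  // terms[i].adj() == 2.0, scale.adj() == 1000.0
  stan::math::recover_memory();
  return 0;
}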
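How the pieces fit together (a sketch of the bookkeeping, not code from the header): operator() lays out one flat array of vari pointers and one of partial derivatives, with the sliced terms first and the shared arguments after them:

  varis:    [ vars of vmapped[0..N-1] | vars of args...       ]
  partials: [ d sum / d sliced vars   | d sum / d shared vars ]

The returned var wraps a precomputed_gradients_vari over these arrays; on the reverse sweep its chain() step adds adj * partials[i] into varis[i]->adj_ for every slot, which propagates the Jacobian computed in parallel back onto the main autodiff tape.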