#ifndef STAN_MATH_REV_FUNCTOR_REDUCE_SUM_HPP
#define STAN_MATH_REV_FUNCTOR_REDUCE_SUM_HPP

#include <stan/math/rev/meta.hpp>
#include <stan/math/rev/core.hpp>
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/functor/reduce_sum.hpp>

#include <tbb/task_arena.h>
#include <tbb/parallel_reduce.h>
#include <tbb/blocked_range.h>

#include <memory>
#include <sstream>
#include <tuple>
#include <utility>
#include <vector>
namespace stan {
namespace math {
namespace internal {

/**
 * Specialization of reduce_sum_impl for reverse-mode autodiff (var
 * return type). Partial sums over ranges of the sliced argument are
 * evaluated in parallel with the Intel TBB and combined into one var.
 */
template <typename ReduceFunction, typename ReturnType, typename Vec,
          typename... Args>
struct reduce_sum_impl<ReduceFunction, require_var_t<ReturnType>, ReturnType,
                       Vec, Args...> {
  // Holds a deep copy of the shared arguments together with the
  // ScopedChainableStack on which that copy was allocated, so the copied
  // vars never point back to the main autodiff stack.
  struct scoped_args_tuple {
    ScopedChainableStack stack_;
    using args_tuple_t
        = std::tuple<decltype(deep_copy_vars(std::declval<Args>()))...>;
    std::unique_ptr<args_tuple_t> args_tuple_holder_;

    scoped_args_tuple() : stack_(), args_tuple_holder_(nullptr) {}
  };
  // Reducer used by the TBB to accumulate partial sums over consecutive
  // ranges of the input; the TBB may split it (splitting constructor) and
  // later merge partial results with join().
  struct recursive_reducer {
    const size_t num_vars_per_term_;
    const size_t num_vars_shared_terms_;  // Number of vars in shared arguments
    double* sliced_partials_;  // Points to adjoints of the sliced terms
    Vec vmapped_;
    std::stringstream msgs_;
    std::tuple<Args...> args_tuple_;
    scoped_args_tuple local_args_tuple_scope_;
    double sum_{0.0};
    Eigen::VectorXd args_adjoints_{0};
    template <typename VecT, typename... ArgsT>
    recursive_reducer(size_t num_vars_per_term, size_t num_vars_shared_terms,
                      double* sliced_partials, VecT&& vmapped, ArgsT&&... args)
        : num_vars_per_term_(num_vars_per_term),
          num_vars_shared_terms_(num_vars_shared_terms),
          sliced_partials_(sliced_partials),
          vmapped_(std::forward<VecT>(vmapped)),
          args_tuple_(std::forward<ArgsT>(args)...) {}
    // Splitting constructor required by tbb::parallel_reduce: the new
    // reducer starts with an empty partial sum but reuses the same sliced
    // partials pointer and copies of the shared arguments.
    recursive_reducer(recursive_reducer& other, tbb::split)
        : num_vars_per_term_(other.num_vars_per_term_),
          num_vars_shared_terms_(other.num_vars_shared_terms_),
          sliced_partials_(other.sliced_partials_),
          vmapped_(other.vmapped_),
          args_tuple_(other.args_tuple_) {}
    // Compute, using nested autodiff, the value and gradient of
    // ReduceFunction over the range r and accumulate them into sum_,
    // sliced_partials_, and args_adjoints_.
    inline void operator()(const tbb::blocked_range<size_t>& r) {
      if (r.empty()) {
        return;
      }

      if (args_adjoints_.size() == 0) {
        args_adjoints_ = Eigen::VectorXd::Zero(num_vars_shared_terms_);
      }
      // Obtain a local deep copy of all shared arguments that does not
      // point back to the main autodiff stack.
      if (!local_args_tuple_scope_.args_tuple_holder_) {
        // The shared arguments have not been copied into this reducer's
        // scope yet; the fresh copy starts with all adjoints at zero.
        local_args_tuple_scope_.stack_.execute([&]() {
          math::apply(
              [&](auto&&... args) {
                local_args_tuple_scope_.args_tuple_holder_ = std::make_unique<
                    typename scoped_args_tuple::args_tuple_t>(
                    deep_copy_vars(args)...);
              },
              args_tuple_);
        });
      } else {
        // Reset the adjoints of the previously copied shared arguments.
        local_args_tuple_scope_.stack_.execute([] { set_zero_all_adjoints(); });
      }
128 auto& args_tuple_local = *(local_args_tuple_scope_.args_tuple_holder_);
      // Create nested autodiff copies of the sliced terms in this range so
      // they do not point back to the main autodiff stack.
      std::decay_t<Vec> local_sub_slice;
      local_sub_slice.reserve(r.size());
      for (size_t i = r.begin(); i < r.end(); ++i) {
        local_sub_slice.emplace_back(deep_copy_vars(vmapped_[i]));
      }
      // Evaluate the partial sum on the nested stack.
      var sub_sum_v = math::apply(
          [&](auto&&... args) {
            return ReduceFunction()(local_sub_slice, r.begin(), r.end() - 1,
                                    &msgs_, args...);
          },
          args_tuple_local);

      // Compute the gradient of the partial sum.
      sub_sum_v.grad();

      // Accumulate the value of the partial sum.
      sum_ += sub_sum_v.val();

      // Accumulate the adjoints of the sliced terms in this range.
      accumulate_adjoints(sliced_partials_ + r.begin() * num_vars_per_term_,
                          std::move(local_sub_slice));
161 [&](
auto&&... args) {
    // Join reducers: accumulate the value (sum_) and the shared-argument
    // adjoints (args_adjoints_) of the other partial sum.
    inline void join(const recursive_reducer& rhs) {
      sum_ += rhs.sum_;
      if (args_adjoints_.size() != 0 && rhs.args_adjoints_.size() != 0) {
        args_adjoints_ += rhs.args_adjoints_;
      } else if (args_adjoints_.size() == 0 && rhs.args_adjoints_.size() != 0) {
        args_adjoints_ = rhs.args_adjoints_;
      }
      msgs_ << rhs.msgs_.str();
    }
  };
  // Call an instance of ReduceFunction on every element of the input
  // sequence and sum the results, evaluating the partial sums in parallel.
  inline var operator()(Vec&& vmapped, bool auto_partitioning, int grainsize,
                        std::ostream* msgs, Args&&... args) const {
    if (vmapped.empty()) {
      return var(0.0);
    }
    const std::size_t num_terms = vmapped.size();
    const std::size_t num_vars_per_term = count_vars(vmapped[0]);
    const std::size_t num_vars_sliced_terms = num_terms * num_vars_per_term;
    const std::size_t num_vars_shared_terms = count_vars(args...);
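    // The varis/partials buffers below are laid out as [sliced | shared]:
    // the first num_vars_sliced_terms entries belong to the sliced argument,
    // the remaining num_vars_shared_terms entries to the shared arguments.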
    vari** varis = ChainableStack::instance_->memalloc_.alloc_array<vari*>(
        num_vars_sliced_terms + num_vars_shared_terms);
    double* partials
        = ChainableStack::instance_->memalloc_.alloc_array<double>(
            num_vars_sliced_terms + num_vars_shared_terms);

    save_varis(varis, vmapped);
    save_varis(varis + num_vars_sliced_terms, args...);
    for (size_t i = 0; i < num_vars_sliced_terms; ++i) {
      partials[i] = 0.0;
    }
    recursive_reducer worker(num_vars_per_term, num_vars_shared_terms,
                             partials, std::forward<Vec>(vmapped),
                             std::forward<Args>(args)...);
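    // Evaluate the partial sums in parallel. With auto_partitioning the TBB
    // picks block sizes adaptively (grainsize is only a hint); otherwise a
    // simple_partitioner with parallel_deterministic_reduce splits the work
    // into fixed grainsize-sized blocks so the summation order, and hence
    // the floating point result, is reproducible.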
    // Task isolation keeps unrelated TBB tasks from running on this thread
    // (and touching its thread-local AD tape) while the reducers execute.
    tbb::this_task_arena::isolate([&] {
      if (auto_partitioning) {
        tbb::parallel_reduce(
            tbb::blocked_range<std::size_t>(0, num_terms, grainsize), worker);
      } else {
        tbb::simple_partitioner partitioner;
        tbb::parallel_deterministic_reduce(
            tbb::blocked_range<std::size_t>(0, num_terms, grainsize), worker,
            partitioner);
      }
    });
    for (size_t i = 0; i < num_vars_shared_terms; ++i) {
      partials[num_vars_sliced_terms + i] = worker.args_adjoints_.coeff(i);
    }
    if (msgs) {
      *msgs << worker.msgs_.str();
    }
    return var(new precomputed_gradients_vari(
        worker.sum_, num_vars_sliced_terms + num_vars_shared_terms, varis,
        partials));
  }
};

}  // namespace internal
}  // namespace math
}  // namespace stan

#endif
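// A minimal usage sketch (not part of the header): assuming a hypothetical
// reducer functor `count_lpdf` over integer count data, a reduce_sum call
// with a var-typed shared argument is dispatched to the specialization above.
//
//   struct count_lpdf {
//     template <typename T>
//     T operator()(const std::vector<int>& slice, std::size_t start,
//                  std::size_t end, std::ostream* msgs,
//                  const T& lambda) const {
//       // Log density of the terms in the slice [start, end].
//       return stan::math::poisson_lpmf(slice, lambda);
//     }
//   };
//
//   std::vector<int> data(10000, 3);
//   stan::math::var lambda = 5.0;
//   stan::math::var lp
//       = stan::math::reduce_sum<count_lpdf>(data, 1, nullptr, lambda);
//   lp.grad();  // the adjoint of lambda now holds d(lp)/d(lambda)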