1#ifndef STAN_MATH_PRIM_FUNCTOR_REDUCE_SUM_HPP
2#define STAN_MATH_PRIM_FUNCTOR_REDUCE_SUM_HPP
8#include <tbb/task_arena.h>
9#include <tbb/parallel_reduce.h>
10#include <tbb/blocked_range.h>
21template <
typename ReduceFunction,
typename Enable,
typename ReturnType,
22 typename Vec,
typename... Args>
23struct reduce_sum_impl;
33template <
typename ReduceFunction,
typename ReturnType,
typename Vec,
36 ReturnType, Vec, Args...> {
46 struct recursive_reducer {
53 : vmapped_(
std::forward<Vec>(vmapped)),
54 args_tuple_(
std::forward<Args>(args)...) {}
63 : vmapped_(other.vmapped_), args_tuple_(other.args_tuple_) {}
73 inline void operator()(
const tbb::blocked_range<size_t>& r) {
78 std::decay_t<Vec> sub_slice;
79 sub_slice.reserve(r.size());
80 for (
size_t i = r.begin(); i < r.end(); ++i) {
81 sub_slice.emplace_back(vmapped_[i]);
86 return ReduceFunction()(sub_slice, r.begin(), r.end() - 1, &msgs_,
97 inline void join(
const recursive_reducer& rhs) {
99 msgs_ << rhs.msgs_.str();
145 inline ReturnType
operator()(Vec&& vmapped,
bool auto_partitioning,
146 int grainsize, std::ostream* msgs,
147 Args&&... args)
const {
148 const std::size_t num_terms = vmapped.size();
149 if (vmapped.empty()) {
152 recursive_reducer worker(std::forward<Vec>(vmapped), msgs,
153 std::forward<Args>(args)...);
155 if (auto_partitioning) {
156 tbb::parallel_reduce(
157 tbb::blocked_range<std::size_t>(0, num_terms, grainsize), worker);
159 tbb::simple_partitioner partitioner;
160 tbb::parallel_deterministic_reduce(
161 tbb::blocked_range<std::size_t>(0, num_terms, grainsize), worker,
165 *msgs << worker.msgs_.str();
198template <
typename ReduceFunction,
typename Vec,
200inline auto reduce_sum(Vec&& vmapped,
int grainsize, std::ostream* msgs,
209 std::forward<Vec>(vmapped),
true, grainsize, msgs,
210 std::forward<Args>(args)...);
212 if (vmapped.empty()) {
216 return ReduceFunction()(std::forward<Vec>(vmapped), 0, vmapped.size() - 1,
217 msgs, std::forward<Args>(args)...);
require_t< std::is_arithmetic< std::decay_t< T > > > require_arithmetic_t
Require type satisfies std::is_arithmetic.
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
require_t< is_vector_like< std::decay_t< T > > > require_vector_like_t
Require type satisfies is_vector_like.
auto reduce_sum(Vec &&vmapped, int grainsize, std::ostream *msgs, Args &&... args)
Call an instance of the function ReduceFunction on every element of an input sequence and sum these t...
void check_positive(const char *function, const char *name, const T_y &y)
Check if y is positive.
typename ref_type_if< true, T >::type ref_type_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
ReturnType operator()(Vec &&vmapped, bool auto_partitioning, int grainsize, std::ostream *msgs, Args &&... args) const
Call an instance of the function ReduceFunction on every element of an input sequence and sum these t...
recursive_reducer(Vec &&vmapped, std::ostream *msgs, Args &&... args)
void operator()(const tbb::blocked_range< size_t > &r)
Compute the value of ReduceFunction over the range defined by r and accumulate it in member va...
void join(const recursive_reducer &rhs)
Join reducers.
std::tuple< Args... > args_tuple_
recursive_reducer(recursive_reducer &other, tbb::split)
This is the splitting copy constructor required by the imperative form of tbb::parallel_reduce.
reduce_sum_impl implementation for any autodiff type.
Template metaprogram to calculate the base scalar return type resulting from promoting all the scalar...