Automatic Differentiation
 
Loading...
Searching...
No Matches
reduce_sum.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_FUNCTOR_REDUCE_SUM_HPP
2#define STAN_MATH_PRIM_FUNCTOR_REDUCE_SUM_HPP
3
7
#include <algorithm>
#include <cstddef>
#include <ostream>
#include <sstream>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include <tbb/blocked_range.h>
#include <tbb/parallel_reduce.h>
#include <tbb/task_arena.h>
15
16namespace stan {
17namespace math {
18
19namespace internal {
20
/**
 * Implementation functor behind reduce_sum.
 *
 * Only the primary template is declared here. Concrete behaviour comes from
 * partial specializations, chosen through the Enable parameter (e.g. a
 * require_arithmetic_t<ReturnType> constraint).
 */
template <class ReduceFunction, class Enable, class ReturnType, class Vec,
          class... Args>
struct reduce_sum_impl;
24
33template <typename ReduceFunction, typename ReturnType, typename Vec,
34 typename... Args>
35struct reduce_sum_impl<ReduceFunction, require_arithmetic_t<ReturnType>,
36 ReturnType, Vec, Args...> {
46 struct recursive_reducer {
48 std::stringstream msgs_;
49 std::tuple<Args...> args_tuple_;
50 return_type_t<Vec, Args...> sum_{0.0};
51
52 recursive_reducer(Vec&& vmapped, std::ostream* msgs, Args&&... args)
53 : vmapped_(std::forward<Vec>(vmapped)),
54 args_tuple_(std::forward<Args>(args)...) {}
55
62 recursive_reducer(recursive_reducer& other, tbb::split)
63 : vmapped_(other.vmapped_), args_tuple_(other.args_tuple_) {}
64
73 inline void operator()(const tbb::blocked_range<size_t>& r) {
74 if (r.empty()) {
75 return;
76 }
77
78 std::decay_t<Vec> sub_slice;
79 sub_slice.reserve(r.size());
80 for (size_t i = r.begin(); i < r.end(); ++i) {
81 sub_slice.emplace_back(vmapped_[i]);
82 }
83
84 sum_ += math::apply(
85 [&](auto&&... args) {
86 return ReduceFunction()(sub_slice, r.begin(), r.end() - 1, &msgs_,
87 args...);
88 },
89 args_tuple_);
90 }
91
97 inline void join(const recursive_reducer& rhs) {
98 sum_ += rhs.sum_;
99 msgs_ << rhs.msgs_.str();
100 }
101 };
102
145 inline ReturnType operator()(Vec&& vmapped, bool auto_partitioning,
146 int grainsize, std::ostream* msgs,
147 Args&&... args) const {
148 const std::size_t num_terms = vmapped.size();
149 if (vmapped.empty()) {
150 return 0.0;
151 }
152 recursive_reducer worker(std::forward<Vec>(vmapped), msgs,
153 std::forward<Args>(args)...);
154
155 if (auto_partitioning) {
156 tbb::parallel_reduce(
157 tbb::blocked_range<std::size_t>(0, num_terms, grainsize), worker);
158 } else {
159 tbb::simple_partitioner partitioner;
160 tbb::parallel_deterministic_reduce(
161 tbb::blocked_range<std::size_t>(0, num_terms, grainsize), worker,
162 partitioner);
163 }
164 if (msgs) {
165 *msgs << worker.msgs_.str();
166 }
167
168 return worker.sum_;
169 }
170};
171
172} // namespace internal
173
198template <typename ReduceFunction, typename Vec,
199 typename = require_vector_like_t<Vec>, typename... Args>
200inline auto reduce_sum(Vec&& vmapped, int grainsize, std::ostream* msgs,
201 Args&&... args) {
202 using return_type = return_type_t<Vec, Args...>;
203
204 check_positive("reduce_sum", "grainsize", grainsize);
205
206#ifdef STAN_THREADS
207 return internal::reduce_sum_impl<ReduceFunction, void, return_type, Vec,
208 ref_type_t<Args&&>...>()(
209 std::forward<Vec>(vmapped), true, grainsize, msgs,
210 std::forward<Args>(args)...);
211#else
212 if (vmapped.empty()) {
213 return return_type(0.0);
214 }
215
216 return ReduceFunction()(std::forward<Vec>(vmapped), 0, vmapped.size() - 1,
217 msgs, std::forward<Args>(args)...);
218#endif
219}
220
221} // namespace math
222} // namespace stan
223
224#endif
require_t< std::is_arithmetic< std::decay_t< T > > > require_arithmetic_t
Require type satisfies std::is_arithmetic.
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
require_t< is_vector_like< std::decay_t< T > > > require_vector_like_t
Require type satisfies is_vector_like.
auto reduce_sum(Vec &&vmapped, int grainsize, std::ostream *msgs, Args &&... args)
Call an instance of the function ReduceFunction on every element of an input sequence and sum these terms.
void check_positive(const char *function, const char *name, const T_y &y)
Check if y is positive.
typename ref_type_if< true, T >::type ref_type_t
Definition ref_type.hpp:55
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
STL namespace.
ReturnType operator()(Vec &&vmapped, bool auto_partitioning, int grainsize, std::ostream *msgs, Args &&... args) const
Call an instance of the function ReduceFunction on every element of an input sequence and sum these terms.
void operator()(const tbb::blocked_range< size_t > &r)
Compute the value of ReduceFunction over the range defined by r and accumulate it in the member variable sum_.
recursive_reducer(recursive_reducer &other, tbb::split)
This is the splitting constructor required by the imperative form of tbb::parallel_reduce.
reduce_sum_impl implementation for any autodiff type.
Template metaprogram to calculate the base scalar return type resulting from promoting all the scalar types of the template parameters.