Automatic Differentiation
 
Loading...
Searching...
No Matches
accumulator.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_FUN_ACCUMULATOR_HPP
2#define STAN_MATH_PRIM_FUN_ACCUMULATOR_HPP
3
7#include <vector>
8#include <type_traits>
9
10namespace stan {
11namespace math {
12
23template <typename T, typename = void>
25 private:
26 std::vector<T> buf_;
27
28 public:
39 template <typename S, typename = require_stan_scalar_t<S>>
40 inline void add(S x) {
41 buf_.push_back(x);
42 }
43
51 template <typename S, require_matrix_t<S>* = nullptr>
52 inline void add(const S& m) {
53 buf_.push_back(stan::math::sum(m));
54 }
55
65 template <typename S>
66 inline void add(const std::vector<S>& xs) {
67 for (size_t i = 0; i < xs.size(); ++i) {
68 this->add(xs[i]);
69 }
70 }
71
72#ifdef STAN_OPENCL
73
79 template <typename S,
81 inline void add(const S& xs) {
82 buf_.push_back(stan::math::sum(xs));
83 }
84
85#endif
86
92 inline T sum() const { return stan::math::sum(buf_); }
93};
94
95} // namespace math
96} // namespace stan
97
98#endif
void add(S x)
Add the specified scalar value (any Stan scalar type) to the buffer, converting it to the class type T.
void add(const std::vector< S > &xs)
Recursively add each entry in the specified standard vector to the buffer.
std::vector< T > buf_
void add(const S &m)
Add the sum of the entries in the specified matrix, vector, or row vector to the buffer as a single value.
void add(const S &xs)
Sum the entries of the specified kernel generator expression and push the result to the buffer.
T sum() const
Return the sum of the accumulated values.
Class to accumulate values and eventually return their sum.
require_all_t< is_kernel_expression_and_not_scalar< Types >... > require_all_kernel_expressions_and_none_scalar_t
Enables a template if all given types are non-scalar types that are valid kernel generator expressions.
fvar< T > sum(const std::vector< fvar< T > > &m)
Return the sum of the entries of the specified standard vector.
Definition sum.hpp:22
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9