Automatic Differentiation
 
Loading...
Searching...
No Matches
map_rect_combine.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_FUNCTOR_MAP_RECT_COMBINE_HPP
2#define STAN_MATH_PRIM_FUNCTOR_MAP_RECT_COMBINE_HPP
3
9
10#include <vector>
11
12namespace stan {
13namespace math {
14namespace internal {
15
16/* Template class for the combine step of map_rect which implements
17 * the CombineF concept. The concept requires that
18 *
19 * - A nullary constructor creates a null combiner (used on the
20 * children)
21 *
22 * - A constructor which takes the shared and job-specific parameters
23 * (used on the root/main process)
24 *
25 * - Provides an operator() which takes two arguments: (i) the
26 * function outputs as a ragged matrix and (ii) the output sizes of
27 * each function evaluation.
28 *
29 * This functor inserts the concatenated outputs of all reduce
30 * operations into the autodiff stack. The concatenated results are
31 * stored in a double-only matrix that is ragged according to the
32 * output sizes of each job.
33 *
34 * @tparam F type of user functor
35 * @tparam T_shared_param type of shared parameters
36 * @tparam T_job_param type of job specific parameters
37 */
38template <typename F, typename T_shared_param, typename T_job_param,
39 require_eigen_col_vector_t<T_shared_param>* = nullptr>
43 Eigen::Matrix<T_job_param, Eigen::Dynamic, 1>>;
44 std::vector<ops_partials_t> ops_partials_;
45
46 const std::size_t num_shared_operands_;
47 const std::size_t num_job_operands_;
48
49 public:
50 using result_t = Eigen::Matrix<return_type_t<T_shared_param, T_job_param>,
51 Eigen::Dynamic, 1>;
52
56 const T_shared_param& shared_params,
57 const std::vector<Eigen::Matrix<T_job_param, Eigen::Dynamic, 1>>&
58 job_params)
59 : ops_partials_(),
60 num_shared_operands_(shared_params.rows()),
61 num_job_operands_(dims(job_params)[1]) {
62 ops_partials_.reserve(job_params.size());
63 for (const auto& job_param : job_params) {
64 ops_partials_.emplace_back(shared_params, job_param);
65 }
66 }
67
68 result_t operator()(const matrix_d& world_result,
69 const std::vector<int>& world_f_out) {
70 const std::size_t num_jobs = world_f_out.size();
71 const std::size_t offset_job_params
73 const std::size_t size_world_f_out = sum(world_f_out);
74
75 result_t out(size_world_f_out);
76
77 for (std::size_t i = 0, ij = 0; i != num_jobs; ++i) {
78 for (int j = 0; j != world_f_out[i]; ++j, ++ij) {
80 edge<0>(ops_partials_[i]).partials_
81 = world_result.block(1, ij, num_shared_operands_, 1);
82 }
83
85 edge<1>(ops_partials_[i]).partials_
86 = world_result.block(offset_job_params, ij, num_job_operands_, 1);
87 }
88
89 out(ij) = ops_partials_[i].build(world_result(0, ij));
90 }
91 }
92
93 return out;
94 }
95};
96
97} // namespace internal
98} // namespace math
99} // namespace stan
100
101#endif
std::vector< ops_partials_t > ops_partials_
map_rect_combine(const T_shared_param &shared_params, const std::vector< Eigen::Matrix< T_job_param, Eigen::Dynamic, 1 > > &job_params)
Eigen::Matrix< return_type_t< T_shared_param, T_job_param >, Eigen::Dynamic, 1 > result_t
result_t operator()(const matrix_d &world_result, const std::vector< int > &world_f_out)
int rows(const T_x &x)
Returns the number of rows in the specified kernel generator expression.
Definition rows.hpp:21
void dims(const T_x &x, std::vector< int > &result)
matrix_cl overload of the dims helper function in prim/fun/dims.hpp.
Definition dims.hpp:21
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
fvar< T > sum(const std::vector< fvar< T > > &m)
Return the sum of the entries of the specified standard vector.
Definition sum.hpp:22
Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic > matrix_d
Type for matrix of double values.
Definition typedefs.hpp:19
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
Extends std::true_type when instantiated with zero or more template parameters, all of which extend t...