Automatic Differentiation
 
Loading...
Searching...
No Matches
adjoint_results.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_REV_ADJOINT_RESULTS_HPP
2#define STAN_MATH_OPENCL_REV_ADJOINT_RESULTS_HPP
3#ifdef STAN_OPENCL
4
8#include <tuple>
9#include <utility>
10
11namespace stan {
12namespace math {
13
/**
 * Represents results that are adjoints of vars in kernel generator
 * expressions. Incrementing an object of this class by an expressions_cl
 * object (operator+= below) adds each expression into the adjoint of the
 * result at the same position.
 *
 * @tparam T_results types of the results; they are stored in, and accessed
 * through, the protected results_cl<T_results...> base
 */
18template <typename... T_results>
19class adjoint_results_cl : protected results_cl<T_results...> {
20 public:
/**
 * Constructor.
 *
 * @param results results whose adjoints will be incremented; forwarded
 * unchanged to the results_cl<T_results...> base constructor
 */
25 explicit adjoint_results_cl(T_results&&... results)
26 : results_cl<T_results...>(std::forward<T_results>(results)...) {}
27
/**
 * Increments the adjoints of the stored results by the given expressions.
 * Results and expressions are paired by position; the enable_if below
 * requires both packs to have the same size.
 *
 * - Non-scalar results whose scalar type is var: the expression is added
 *   into the result's adjoint on the device (assign_op_cl::plus_equals).
 * - Scalar var results: the expression is reduced with sum_2d into a
 *   temporary matrix_cl<double>. That matrix is then copied back with
 *   from_matrix_cl, summed with sum, and the total is added to the var's
 *   adjoint on the host.
 * - Any other result contributes no assignment (the selectors return an
 *   empty tuple).
 *
 * All device-side assignments are issued through a single assignment_impl
 * call.
 *
 * NOTE(review): listing line 40 is missing from this rendering. According
 * to the member index it is presumably the signature
 * `void operator+=(const expressions_cl<T_expressions...>& exprs) {`.
 * Listing line 61 (the make_assignment_pair template argument) is also
 * missing.
 *
 * @tparam T_expressions types of the expressions
 * @param exprs expressions_cl holding one expression per result
 */
37 template <typename... T_expressions,
38 typename = std::enable_if_t<sizeof...(T_results)
39 == sizeof...(T_expressions)>>
41 index_apply<sizeof...(T_expressions)>([&](auto... Is) {
// Each element of `scalars` is a tuple (adjoint&, expression, temp matrix).
42 auto scalars = std::tuple_cat(select_scalar_assignments(
43 std::get<Is>(this->results_), std::get<Is>(exprs.expressions_))...);
// Each element of `nonscalars_tmp` is a (adjoint, expression) pair that
// will be evaluated with plus_equals.
44 auto nonscalars_tmp = std::tuple_cat(
45 select_nonscalar_assignments<assign_op_cl::plus_equals>(
46 std::get<Is>(this->results_),
47 std::get<Is>(exprs.expressions_))...);
48
49 index_apply<std::tuple_size<decltype(nonscalars_tmp)>::value>(
50 [&](auto... Is_nonscal) {
51 auto nonscalars = std::make_tuple(
52 std::make_pair(std::get<Is_nonscal>(nonscalars_tmp).first,
53 std::get<Is_nonscal>(nonscalars_tmp).second)...);
54
55 index_apply<std::tuple_size<decltype(scalars)>::value>(
56 [&](auto... Is_scal) {
57 // evaluate all expressions
58 this->assignment_impl(std::tuple_cat(
59 nonscalars,
60 this->template make_assignment_pair<
62 std::get<2>(std::get<Is_scal>(scalars)),
63 sum_2d(std::get<1>(std::get<Is_scal>(scalars))))...));
64
65 // copy results from the OpenCL device and increment the
66 // adjoints
// NOTE(review): every right-hand value is computed from each adjoint's
// current value before any assignment happens. If the same var can appear
// more than once among the results, only the last of its increments
// survives. Applying `+=` to each adjoint in turn would accumulate all of
// them.
67 std::tie(std::get<0>(std::get<Is_scal>(scalars))...)
68 = std::make_tuple(std::get<0>(std::get<Is_scal>(scalars))
69 + sum(from_matrix_cl(std::get<2>(
70 std::get<Is_scal>(scalars))))...);
71 });
72 });
73 });
74 }
75
76 private:
/**
 * Selects assignments that have scalar var results.
 *
 * @param result scalar var whose adjoint will be incremented
 * @param expression expression to add to the adjoint
 * @return one-element tuple holding a 3-tuple of: a reference to the var's
 * adjoint, the expression, and a default-constructed matrix_cl<double> that
 * operator+= uses as the destination of the sum_2d reduction
 */
85 template <typename T_expression>
86 auto select_scalar_assignments(const var& result, T_expression&& expression) {
87 return std::make_tuple(std::tuple<double&, T_expression, matrix_cl<double>>(
88 result.adj(), std::forward<T_expression>(expression), {}));
89 }
/**
 * Fallback overload for results that are not scalar vars. It selects no
 * assignment.
 *
 * NOTE(review): this overload's SFINAE constraint (listing line 100) is
 * missing from this rendering. Per the member index it presumably uses
 * require_all_not_same_t to keep var arguments on the overload above;
 * confirm.
 *
 * @return empty tuple
 */
99 template <typename T_result, typename T_expression,
101 auto select_scalar_assignments(T_result&& result, T_expression&& expression) {
102 return std::make_tuple();
103 }
104
/**
 * Selects assignments that have non-scalar var results. It pairs the
 * result's adjoint with the expression via make_assignment_pair, using the
 * assignment op chosen by the caller (operator+= passes plus_equals).
 *
 * NOTE(review): listing line 117 is missing from this rendering. Per the
 * member index it is presumably a require_not_stan_scalar_t constraint.
 *
 * @tparam AssignOp how the expression is assigned to the adjoint
 * @param result result whose scalar type is var
 * @param expression expression to assign into the result's adjoint
 * @return tuple wrapping a std::pair of the adjoint and the expression
 */
116 template <assign_op_cl AssignOp, typename T_result, typename T_expression,
118 require_st_var<T_result>* = nullptr>
119 auto select_nonscalar_assignments(T_result&& result,
120 T_expression&& expression) {
121 return results_cl<T_results...>::template make_assignment_pair<AssignOp>(
122 result.adj(), std::forward<T_expression>(expression));
123 }
/**
 * Overload for results that are scalars or whose scalar type is not var.
 * It selects no assignment.
 *
 * NOTE(review): this condition applies is_stan_scalar / scalar_type_t to
 * T_result directly, and T_result may be deduced as an lvalue reference.
 * The require_* aliases used by the other overload wrap T in
 * std::decay_t. Presumably these traits decay internally; confirm.
 *
 * @return empty tuple
 */
135 template <
136 assign_op_cl AssignOp, typename T_result, typename T_expression,
137 std::enable_if_t<is_stan_scalar<T_result>::value
138 || !is_var<scalar_type_t<T_result>>::value>* = nullptr>
139 auto select_nonscalar_assignments(T_result&& result,
140 T_expression&& expression) {
141 return std::make_tuple();
142 }
143};
144
150template <typename... T_results>
151adjoint_results_cl<T_results...> adjoint_results(T_results&&... results) {
152 return adjoint_results_cl<T_results...>(std::forward<T_results>(results)...);
153}
154
155} // namespace math
156} // namespace stan
157
158#endif
159#endif
auto select_scalar_assignments(T_result &&result, T_expression &&expression)
Fallback for results that are not scalar vars: selects no assignment (returns an empty tuple).
adjoint_results_cl(T_results &&... results)
Constructor.
auto select_nonscalar_assignments(T_result &&result, T_expression &&expression)
Selects assignments that have non-scalar var results.
auto select_scalar_assignments(const var &result, T_expression &&expression)
Selects assignments that have scalar var results.
void operator+=(const expressions_cl< T_expressions... > &exprs)
Incrementing an adjoint_results_cl object by an expressions_cl object executes one or two kernels that eval...
auto select_nonscalar_assignments(T_result &&result, T_expression &&expression)
Selects assignments that have non-scalar var results.
Represents results that are adjoints of vars in kernel generator expressions.
std::tuple< T_expressions... > expressions_
Represents multiple expressions that will be calculated in same kernel.
Represents an arithmetic matrix on the OpenCL device.
Definition matrix_cl.hpp:47
static void assignment_impl(const std::tuple< std::pair< T_res, T_expressions >... > &assignment_pairs)
Implementation of assignments of expressions to results.
static auto make_assignment_pair(T_result &&result, T_expression &&expression)
Makes a std::pair of one result and one expression and wraps it into a tuple.
Represents results that will be calculated in same kernel.
auto sum_2d(T &&a)
Two dimensional sum - reduction of a kernel generator expression.
results_cl< T_results... > results(T_results &&... results)
Deduces types for constructing results_cl object.
auto from_matrix_cl(const T &src)
Copies the source matrix that is stored on the OpenCL device to the destination Eigen matrix.
Definition copy.hpp:61
require_all_not_t< std::is_same< std::decay_t< T >, std::decay_t< Types > >... > require_all_not_same_t
Require none of the Types and T satisfy std::is_same.
require_not_t< is_stan_scalar< std::decay_t< T > > > require_not_stan_scalar_t
Require type does not satisfy is_stan_scalar.
require_t< is_var< scalar_type_t< std::decay_t< T > > > > require_st_var
Require scalar_type satisfies is_var.
Definition is_var.hpp:111
adjoint_results_cl< T_results... > adjoint_results(T_results &&... results)
Deduces types for constructing adjoint_results_cl object.
fvar< T > sum(const std::vector< fvar< T > > &m)
Return the sum of the entries of the specified standard vector.
Definition sum.hpp:22
constexpr auto index_apply(F &&f)
Calls given callable with an index sequence.
assign_op_cl
Ops that decide the type of assignment for LHS operations.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
STL namespace.
Defines a static member named value which is defined to be false as the primitive scalar types cannot...
Definition is_var.hpp:14