1#ifndef STAN_MATH_OPENCL_REV_ADJOINT_RESULTS_HPP
2#define STAN_MATH_OPENCL_REV_ADJOINT_RESULTS_HPP
18template <
typename... T_results>
37 template <
typename... T_expressions,
38 typename = std::enable_if_t<
sizeof...(T_results)
39 ==
sizeof...(T_expressions)>>
41 index_apply<
sizeof...(T_expressions)>([&](
auto... Is) {
44 auto nonscalars_tmp = std::tuple_cat(
45 select_nonscalar_assignments<assign_op_cl::plus_equals>(
49 index_apply<std::tuple_size<
decltype(nonscalars_tmp)>::value>(
50 [&](
auto... Is_nonscal) {
51 auto nonscalars = std::make_tuple(
52 std::make_pair(std::get<Is_nonscal>(nonscalars_tmp).first,
53 std::get<Is_nonscal>(nonscalars_tmp).second)...);
55 index_apply<std::tuple_size<
decltype(scalars)>::value>(
56 [&](
auto... Is_scal) {
62 std::get<2>(std::get<Is_scal>(scalars)),
63 sum_2d(std::get<1>(std::get<Is_scal>(scalars))))...));
67 std::tie(std::get<0>(std::get<Is_scal>(scalars))...)
68 = std::make_tuple(std::get<0>(std::get<Is_scal>(scalars))
70 std::get<Is_scal>(scalars))))...);
85 template <
typename T_expression>
88 result.adj(), std::forward<T_expression>(expression), {}));
99 template <
typename T_result,
typename T_expression,
102 return std::make_tuple();
116 template <
assign_op_cl AssignOp,
typename T_result,
typename T_expression,
120 T_expression&& expression) {
121 return results_cl<T_results...>::template make_assignment_pair<AssignOp>(
122 result.adj(), std::forward<T_expression>(expression));
136 assign_op_cl AssignOp,
typename T_result,
typename T_expression,
137 std::enable_if_t<is_stan_scalar<T_result>::value
140 T_expression&& expression) {
141 return std::make_tuple();
150template <
typename... T_results>
auto select_scalar_assignments(T_result &&result, T_expression &&expression)
Selects assignments that have scalar var results.
adjoint_results_cl(T_results &&... results)
Constructor.
auto select_nonscalar_assignments(T_result &&result, T_expression &&expression)
Selects assignments that have non-scalar var results.
auto select_scalar_assignments(const var &result, T_expression &&expression)
Selects assignments that have scalar var results.
void operator+=(const expressions_cl< T_expressions... > &exprs)
Incrementing adjoint_results_cl object by expressions_cl object executes one or two kernels that eval...
auto select_nonscalar_assignments(T_result &&result, T_expression &&expression)
Selects assignments that have non-scalar var results.
Represents results that are adjoints of vars in kernel generrator expressions.
std::tuple< T_expressions... > expressions_
Represents multiple expressions that will be calculated in same kernel.
Represents an arithmetic matrix on the OpenCL device.
static void assignment_impl(const std::tuple< std::pair< T_res, T_expressions >... > &assignment_pairs)
Implementation of assignments of expressions to results.
std::tuple< T_results... > results_
static auto make_assignment_pair(T_result &&result, T_expression &&expression)
Makes a std::pair of one result and one expression and wraps it into a tuple.
Represents results that will be calculated in same kernel.
auto sum_2d(T &&a)
Two dimensional sum - reduction of a kernel generator expression.
results_cl< T_results... > results(T_results &&... results)
Deduces types for constructing results_cl object.
auto from_matrix_cl(const T &src)
Copies the source matrix that is stored on the OpenCL device to the destination Eigen matrix.
require_all_not_t< std::is_same< std::decay_t< T >, std::decay_t< Types > >... > require_all_not_same_t
Require none of the Types and T satisfy std::is_same.
require_not_t< is_stan_scalar< std::decay_t< T > > > require_not_stan_scalar_t
Require type does not satisfy is_stan_scalar.
require_t< is_var< scalar_type_t< std::decay_t< T > > > > require_st_var
Require scalar_type satisfies is_var.
adjoint_results_cl< T_results... > adjoint_results(T_results &&... results)
Deduces types for constructing adjoint_results_cl object.
constexpr auto index_apply(F &&f)
Calls given callable with an index sequence.
auto sum(const std::vector< T > &m)
Return the sum of the entries of the specified standard vector.
assign_op_cl
Ops that decide the type of assignment for LHS operations.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Defines a static member named value which is defined to be false as the primitive scalar types cannot...