Automatic Differentiation
 
Loading...
Searching...
No Matches
append_col.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_APPEND_COL_HPP
2#define STAN_MATH_REV_FUN_APPEND_COL_HPP
3
9#include <vector>
10
11namespace stan {
12namespace math {
13
34template <typename T1, typename T2, require_any_var_matrix_t<T1, T2>* = nullptr>
35inline auto append_col(const T1& A, const T2& B) {
36 check_size_match("append_col", "columns of A", A.rows(), "columns of B",
37 B.rows());
38 if (!is_constant<T1>::value && !is_constant<T2>::value) {
39 arena_t<promote_scalar_t<var, T1>> arena_A = A;
40 arena_t<promote_scalar_t<var, T2>> arena_B = B;
41 return make_callback_var(
42 append_col(value_of(arena_A), value_of(arena_B)),
43 [arena_A, arena_B](auto& vi) mutable {
44 arena_A.adj() += vi.adj().leftCols(arena_A.cols());
45 arena_B.adj() += vi.adj().rightCols(arena_B.cols());
46 });
47 } else if (!is_constant<T1>::value) {
48 arena_t<promote_scalar_t<var, T1>> arena_A = A;
49 return make_callback_var(append_col(value_of(arena_A), value_of(B)),
50 [arena_A](auto& vi) mutable {
51 arena_A.adj()
52 += vi.adj().leftCols(arena_A.cols());
53 });
54 } else {
55 arena_t<promote_scalar_t<var, T2>> arena_B = B;
56 return make_callback_var(append_col(value_of(A), value_of(arena_B)),
57 [arena_B](auto& vi) mutable {
58 arena_B.adj()
59 += vi.adj().rightCols(arena_B.cols());
60 });
61 }
62}
63
78template <typename Scal, typename RowVec,
79 require_stan_scalar_t<Scal>* = nullptr,
80 require_t<is_eigen_row_vector<RowVec>>* = nullptr>
81inline auto append_col(const Scal& A, const var_value<RowVec>& B) {
83 var arena_A = A;
85 return make_callback_var(append_col(value_of(arena_A), value_of(arena_B)),
86 [arena_A, arena_B](auto& vi) mutable {
87 arena_A.adj() += vi.adj().coeff(0);
88 arena_B.adj() += vi.adj().tail(arena_B.size());
89 });
90 } else if (!is_constant<Scal>::value) {
91 var arena_A = A;
92 return make_callback_var(
93 append_col(value_of(arena_A), value_of(B)),
94 [arena_A](auto& vi) mutable { arena_A.adj() += vi.adj().coeff(0); });
95 } else {
97 return make_callback_var(append_col(value_of(A), value_of(arena_B)),
98 [arena_B](auto& vi) mutable {
99 arena_B.adj() += vi.adj().tail(arena_B.size());
100 });
101 }
102}
103
118template <typename RowVec, typename Scal,
121inline auto append_col(const var_value<RowVec>& A, const Scal& B) {
124 var arena_B = B;
125 return make_callback_var(append_col(value_of(arena_A), value_of(arena_B)),
126 [arena_A, arena_B](auto& vi) mutable {
127 arena_A.adj() += vi.adj().head(arena_A.size());
128 arena_B.adj()
129 += vi.adj().coeff(vi.adj().size() - 1);
130 });
131 } else if (!is_constant<RowVec>::value) {
133 return make_callback_var(append_col(value_of(arena_A), value_of(B)),
134 [arena_A](auto& vi) mutable {
135 arena_A.adj() += vi.adj().head(arena_A.size());
136 });
137 } else {
138 var arena_B = B;
139 return make_callback_var(append_col(value_of(A), value_of(arena_B)),
140 [arena_B](auto& vi) mutable {
141 arena_B.adj()
142 += vi.adj().coeff(vi.adj().size() - 1);
143 });
144 }
145}
146
147} // namespace math
148} // namespace stan
149
150#endif
auto append_col(Ta &&a, Tb &&b)
Stack the cols of the arguments.
Definition append.hpp:346
require_t< is_stan_scalar< std::decay_t< T > > > require_stan_scalar_t
Require type satisfies is_stan_scalar.
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback functor.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
std::enable_if_t< Check::value > require_t
If condition is true, template is enabled.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the C++ `const` sense).