1#ifndef STAN_MATH_REV_FUN_CSR_MATRIX_TIMES_VECTOR_HPP
2#define STAN_MATH_REV_FUN_CSR_MATRIX_TIMES_VECTOR_HPP
31template <
typename Result_,
typename WMat_,
typename B_>
33 std::decay_t<Result_>
res_;
37 template <
typename T1,
typename T2,
typename T3>
42 b_(
std::forward<T3>(b)) {}
58 template <
typename Result,
typename WMat,
typename B,
62 w_mat.adj() += res.adj() * b.val().transpose();
63 b.adj() += w_mat.val().transpose() * res.adj();
78 template <
typename Result,
typename WMat,
typename B,
82 w_mat.adj() += res.adj() * b.transpose();
97 template <
typename Result,
typename WMat,
typename B,
101 b.adj() += w_mat.transpose() * res.adj();
119template <
typename Result_,
typename WMat_,
typename B_>
122 std::forward<Result_>(res), std::forward<WMat_>(w_mat),
123 std::forward<B_>(b));
158template <
typename T1,
typename T2, require_any_rev_matrix_t<T1, T2>* =
nullptr>
160 const std::vector<int>& v,
161 const std::vector<int>& u,
const T2& b) {
163 = Eigen::Map<const Eigen::SparseMatrix<double, Eigen::RowMajor>>;
164 using sparse_dense_mul_type
165 =
decltype((std::declval<sparse_val_mat>() *
value_of(b)).
eval());
174 u[m - 1] +
csr_u_to_z(u, m - 1) - 1,
"v", v.size());
176 check_range(
"csr_matrix_times_vector",
"v[]", n, i);
178 std::vector<int, arena_allocator<int>> v_arena(v.size());
179 std::transform(v.begin(), v.end(), v_arena.begin(),
180 [](
auto&& x) { return x - 1; });
181 std::vector<int, arena_allocator<int>> u_arena(u.size());
182 std::transform(u.begin(), u.end(), u_arena.begin(),
183 [](
auto&& x) { return x - 1; });
184 using sparse_var_value_t
188 sparse_var_value_t w_mat_arena
189 = to_soa_sparse_matrix<Eigen::RowMajor>(m, n, w, u_arena, v_arena);
192 return return_t(res);
196 sparse_val_mat w_val_mat(m, n, w_val_arena.size(), u_arena.data(),
197 v_arena.data(), w_val_arena.data());
200 return return_t(res);
202 sparse_var_value_t w_mat_arena
203 = to_soa_sparse_matrix<Eigen::RowMajor>(m, n, w, u_arena, v_arena);
207 return return_t(res);
require_t< is_rev_matrix< std::decay_t< T > > > require_rev_matrix_t
Require type satisfies is_rev_matrix.
require_not_t< is_rev_matrix< std::decay_t< T > > > require_not_rev_matrix_t
Require type does not satisfy is_rev_matrix.
void make_csr_adjoint(Result_ &&res, WMat_ &&w_mat, B_ &&b)
Helper function to construct the csr_adjoint struct.
T eval(T &&arg)
Inputs which have a plain_type equal to the own time are forwarded unmodified (for Eigen expressions ...
T value_of(const fvar< T > &v)
Return the value of the specified variable.
arena_t< T > to_arena(const T &a)
Converts given argument into a type that either has any dynamic allocation on AD stack or schedules i...
void check_positive(const char *function, const char *name, const T_y &y)
Check if y is positive.
void check_range(const char *function, const char *name, int max, int index, int nested_level, const char *error_msg)
Check if specified index is within range.
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
Eigen::Matrix< return_type_t< T1, T2 >, Eigen::Dynamic, 1 > csr_matrix_times_vector(int m, int n, const T1 &w, const std::vector< int > &v, const std::vector< int > &u, const T2 &b)
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
std::conditional_t< is_any_var_matrix< ReturnType, Types... >::value, stan::math::var_value< stan::math::promote_scalar_t< double, plain_type_t< ReturnType > > >, stan::math::promote_scalar_t< stan::math::var_value< double >, plain_type_t< ReturnType > > > return_var_matrix_t
Given an Eigen type and several inputs, determine if a matrix should be var<Matrix> or Matrix<var>.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the ...
std::decay_t< Result_ > res_
csr_adjoint(T1 &&res, T2 &&w_mat, T3 &&b)
void chain_internal(Result &&res, WMat &&w_mat, B &&b)
Overload for calculating adjoints of w_mat and b
std::decay_t< WMat_ > w_mat_
vari for csr_matrix_times_vector