Automatic Differentiation
 
Loading...
Searching...
No Matches
csr_matrix_times_vector.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_CSR_MATRIX_TIMES_VECTOR_HPP
2#define STAN_MATH_REV_FUN_CSR_MATRIX_TIMES_VECTOR_HPP
3
9#include <vector>
10
11namespace stan {
12namespace math {
13
14namespace internal {
/**
 * Reverse-mode `vari` node for `csr_matrix_times_vector`. It records the
 * product `res = w_mat * b` of a sparse matrix and a dense vector. During the
 * reverse pass it accumulates adjoints into whichever operands are autodiff
 * types (see the three `chain_internal` overloads below).
 *
 * @tparam Result_ type of the product result; every `chain_internal` overload
 *   reads `res.adj()`, so it must be an autodiff type
 * @tparam WMat_ type of the sparse matrix operand (autodiff or plain values)
 * @tparam B_ type of the dense vector operand (autodiff or plain values)
 */
30template <typename Result_, typename WMat_, typename B_>
31struct csr_adjoint : public vari {
 // Forward result of the product; its adjoint is the input to the reverse step.
32 std::decay_t<Result_> res_;
 // Sparse matrix operand.
33 std::decay_t<WMat_> w_mat_;
 // Dense vector operand.
34 std::decay_t<B_> b_;
35
 /**
  * Store the operands by perfect-forwarding each argument into its decayed
  * member (moved from rvalues, copied from lvalues).
  *
  * The base `vari` is constructed with the value 0.0.
  * NOTE(review): nothing visible here reads that scalar value; the actual
  * result lives in `res_`.
  */
36 template <typename T1, typename T2, typename T3>
37 csr_adjoint(T1&& res, T2&& w_mat, T3&& b)
38 : vari(0.0),
39 res_(std::forward<T1>(res)),
40 w_mat_(std::forward<T2>(w_mat)),
41 b_(std::forward<T3>(b)) {}
42
 // NOTE(review): listing line 43 is missing from this scrape. No visible code
 // calls `chain_internal`, so that line is presumably the `chain()` override
 // that dispatches `chain_internal(res_, w_mat_, b_)` -- confirm upstream.
44
 /**
  * Reverse step when both `w_mat` and `b` are autodiff types (both have their
  * `adj()` updated). For `res = W * b`:
  *   adj(W) += adj(res) * val(b)^T
  *   adj(b) += val(W)^T * adj(res)
  *
  * NOTE(review): listing line 58 is missing; it presumably holds the
  * `require_rev_matrix_t<WMat>* = nullptr,` constraint -- confirm.
  * NOTE(review): `adj(res) * val(b)^T` is a dense outer product while `W` is
  * sparse; presumably the `+=` on the sparse adjoint only touches the stored
  * non-zeros -- confirm.
  */
57 template <typename Result, typename WMat, typename B,
59 require_rev_matrix_t<B>* = nullptr>
60 inline void chain_internal(Result&& res, WMat&& w_mat, B&& b) {
61 w_mat.adj() += res.adj() * b.val().transpose();
62 b.adj() += w_mat.val().transpose() * res.adj();
63 }
64
 /**
  * Reverse step when only `w_mat` is an autodiff type. `b` is used directly
  * (no `.val()`), so it presumably holds plain values.
  *   adj(W) += adj(res) * b^T
  *
  * NOTE(review): listing lines 78-79 are missing; they presumably hold the
  * `require_rev_matrix_t<WMat>` / `require_not_rev_matrix_t<B>` constraints.
  */
77 template <typename Result, typename WMat, typename B,
80 inline void chain_internal(Result&& res, WMat&& w_mat, B&& b) {
81 w_mat.adj() += res.adj() * b.transpose();
82 }
83
 /**
  * Reverse step when only `b` is an autodiff type. `w_mat` is used directly
  * (no `.val()`), so it presumably holds plain values.
  *   adj(b) += W^T * adj(res)
  *
  * NOTE(review): listing line 97 is missing; it presumably holds the
  * `require_not_rev_matrix_t<WMat>* = nullptr,` constraint.
  */
96 template <typename Result, typename WMat, typename B,
98 require_rev_matrix_t<B>* = nullptr>
99 inline void chain_internal(Result&& res, WMat&& w_mat, B&& b) {
100 b.adj() += w_mat.transpose() * res.adj();
101 }
102};
103
118template <typename Result_, typename WMat_, typename B_>
119inline void make_csr_adjoint(Result_&& res, WMat_&& w_mat, B_&& b) {
120 new csr_adjoint<std::decay_t<Result_>, std::decay_t<WMat_>, std::decay_t<B_>>(
121 std::forward<Result_>(res), std::forward<WMat_>(w_mat),
122 std::forward<B_>(b));
123 return;
124}
125} // namespace internal
126
/**
 * Multiply a sparse matrix given in compressed sparse row (CSR) form by a
 * dense vector, recording the product for reverse-mode autodiff.
 *
 * This overload is enabled via `require_any_rev_matrix_t<T1, T2>`, i.e. when
 * `w` or `b` is a reverse-mode matrix type.
 *
 * @tparam T1 type of the non-zero values `w`
 * @tparam T2 type of the dense vector `b`
 * @param m number of rows; must be positive
 * @param n number of columns; must be positive and equal to `b.size()`
 * @param w non-zero values of the matrix; `w.size()` must equal `v.size()`
 * @param v column index of each non-zero; each entry is range-checked against
 *   `n` and then decremented by one (presumably 1-based input -- confirm)
 * @param u row-start offsets, of size `m + 1`; each entry is likewise
 *   decremented by one
 * @param b dense vector with `n` elements
 * @return the product `W * b`, converted to `return_t`.
 *   NOTE(review): the definition of `return_t` (listing line 165) is missing
 *   from this scrape. The cross-reference to `return_var_matrix_t` suggests
 *   it chooses `var<Matrix>` vs `Matrix<var>` from the inputs -- confirm.
 * @throw (through the `check_*` helpers) if `m` or `n` is not positive, if the
 *   sizes of `b`, `w`/`v` or `u` are inconsistent, or if any entry of `v` is
 *   out of range
 */
157template <typename T1, typename T2, require_any_rev_matrix_t<T1, T2>* = nullptr>
158inline auto csr_matrix_times_vector(int m, int n, const T1& w,
159 const std::vector<int>& v,
160 const std::vector<int>& u, const T2& b) {
 // Non-owning CSR view over plain double values.
161 using sparse_val_mat
162 = Eigen::Map<const Eigen::SparseMatrix<double, Eigen::RowMajor>>;
 // Dense type produced by multiplying the sparse values by the values of `b`.
163 using sparse_dense_mul_type
164 = decltype((std::declval<sparse_val_mat>() * value_of(b)).eval());
 // NOTE(review): listing line 165 (presumably the `return_t` alias) is
 // missing from this scrape.
166
167 check_positive("csr_matrix_times_vector", "m", m);
168 check_positive("csr_matrix_times_vector", "n", n);
169 check_size_match("csr_matrix_times_vector", "n", n, "b", b.size());
170 check_size_match("csr_matrix_times_vector", "w", w.size(), "v", v.size());
 // NOTE(review): if `u` is empty, `u.size() - 1` wraps around (unsigned). The
 // check still fails because `m > 0`, but the reported size is misleading.
171 check_size_match("csr_matrix_times_vector", "m", m, "u", u.size() - 1);
 // The previous check requires u.size() == m + 1, so u[m - 1] is in bounds
 // once it passes. Presumably `csr_u_to_z(u, m - 1)` is the non-zero count of
 // the last row, so this compares the implied total with `v.size()`.
172 check_size_match("csr_matrix_times_vector", "u/z",
173 u[m - 1] + csr_u_to_z(u, m - 1) - 1, "v", v.size());
174 for (int i : v) {
175 check_range("csr_matrix_times_vector", "v[]", n, i);
176 }
 // Shift the index arrays down by one into arena-allocated storage. They must
 // outlive this call: the Eigen::Map built in the second branch points at
 // u_arena/v_arena data and is stored in the adjoint node.
177 std::vector<int, arena_allocator<int>> v_arena(v.size());
178 std::transform(v.begin(), v.end(), v_arena.begin(),
179 [](auto&& x) { return x - 1; });
180 std::vector<int, arena_allocator<int>> u_arena(u.size());
181 std::transform(u.begin(), u.end(), u_arena.begin(),
182 [](auto&& x) { return x - 1; });
 // NOTE(review): listing lines 184-186 are missing from this scrape. They
 // presumably complete the `sparse_var_value_t` alias, open the first branch
 // (condition not visible) and declare `b_arena`.
183 using sparse_var_value_t
 // First branch: `w` is kept as an autodiff sparse matrix built from the
 // shifted index arrays.
187 sparse_var_value_t w_mat_arena
188 = to_soa_sparse_matrix<Eigen::RowMajor>(m, n, w, u_arena, v_arena);
189 arena_t<return_t> res = w_mat_arena.val() * value_of(b_arena);
190 stan::math::internal::make_csr_adjoint(res, w_mat_arena, b_arena);
191 return return_t(res);
192 } else if (!is_constant<T2>::value) {
 // `b` is an autodiff type here. NOTE(review): listing line 193 (presumably
 // the declaration of `b_arena`) is missing from this scrape.
 // `w` enters only as plain values (`value_of(w)`) copied to the arena, so
 // only `b` can receive adjoints.
194 auto w_val_arena = to_arena(value_of(w));
195 sparse_val_mat w_val_mat(m, n, w_val_arena.size(), u_arena.data(),
196 v_arena.data(), w_val_arena.data());
197 arena_t<return_t> res = w_val_mat * value_of(b_arena);
198 stan::math::internal::make_csr_adjoint(res, w_val_mat, b_arena);
199 return return_t(res);
200 } else {
 // `b` enters only as plain values (`value_of(b)`), so only `w` can receive
 // adjoints.
201 sparse_var_value_t w_mat_arena
202 = to_soa_sparse_matrix<Eigen::RowMajor>(m, n, w, u_arena, v_arena);
203 auto b_arena = to_arena(value_of(b));
204 arena_t<return_t> res = w_mat_arena.val() * b_arena;
205 stan::math::internal::make_csr_adjoint(res, w_mat_arena, b_arena);
206 return return_t(res);
207 }
208}
209
210} // namespace math
211} // namespace stan
212
213#endif
int csr_u_to_z(const std::vector< int > &u, int i)
Return the z vector computed from the specified u vector at the index for the z vector.
require_t< is_rev_matrix< std::decay_t< T > > > require_rev_matrix_t
Require type satisfies is_rev_matrix.
require_not_t< is_rev_matrix< std::decay_t< T > > > require_not_rev_matrix_t
Require type does not satisfy is_rev_matrix.
void make_csr_adjoint(Result_ &&res, WMat_ &&w_mat, B_ &&b)
Helper function to construct the csr_adjoint struct.
T eval(T &&arg)
Inputs whose plain_type equals their own type are forwarded unmodified (for Eigen expressions ...
Definition eval.hpp:20
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
arena_t< T > to_arena(const T &a)
Converts the given argument into a type that either has any dynamic allocation on the AD stack or schedules i...
Definition to_arena.hpp:25
void check_positive(const char *function, const char *name, const T_y &y)
Check if y is positive.
void check_range(const char *function, const char *name, int max, int index, int nested_level, const char *error_msg)
Check if specified index is within range.
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
Eigen::Matrix< return_type_t< T1, T2 >, Eigen::Dynamic, 1 > csr_matrix_times_vector(int m, int n, const T1 &w, const std::vector< int > &v, const std::vector< int > &u, const T2 &b)
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
std::conditional_t< is_any_var_matrix< ReturnType, Types... >::value, stan::math::var_value< stan::math::promote_scalar_t< double, plain_type_t< ReturnType > > >, stan::math::promote_scalar_t< stan::math::var_value< double >, plain_type_t< ReturnType > > > return_var_matrix_t
Given an Eigen type and several inputs, determine if a matrix should be var<Matrix> or Matrix<var>.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
STL namespace.
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the ...
csr_adjoint(T1 &&res, T2 &&w_mat, T3 &&b)
void chain_internal(Result &&res, WMat &&w_mat, B &&b)
Overload for calculating adjoints of w_mat and b
vari for csr_matrix_times_vector