Automatic Differentiation
 
Loading...
Searching...
No Matches
csr_matrix_times_vector.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_CSR_MATRIX_TIMES_VECTOR_HPP
2#define STAN_MATH_REV_FUN_CSR_MATRIX_TIMES_VECTOR_HPP
3
10#include <vector>
11
12namespace stan {
13namespace math {
14
15namespace internal {
/**
 * Reverse-mode node (`vari`) for csr_matrix_times_vector.
 *
 * Holds the forward result `res_`, the sparse matrix operand `w_mat_` and the
 * dense vector operand `b_` by value (each member's type is std::decay_t of the
 * corresponding template argument), so that their adjoints can be updated in
 * the reverse pass. The forward product computed by the caller is
 * res = W * b, so the updates below are
 *   W.adj += res.adj * b^T   and   b.adj += W^T * res.adj.
 *
 * NOTE(review): this text is a Doxygen source listing with some source lines
 * elided (44, 59, 79-80, 98); see the inline notes at each gap. Confirm
 * against the real header before editing.
 *
 * @tparam Result_ type of the forward result
 * @tparam WMat_ type of the sparse matrix operand
 * @tparam B_ type of the dense vector operand
 */
31template <typename Result_, typename WMat_, typename B_>
32struct csr_adjoint : public vari {
 // Forward result of the product (its adjoint drives the reverse pass).
33 std::decay_t<Result_> res_;
 // Sparse matrix operand (autodiff or plain-value view, depending on caller).
34 std::decay_t<WMat_> w_mat_;
 // Dense vector operand (autodiff or plain values, depending on caller).
35 std::decay_t<B_> b_;
36
 // Perfect-forwards each argument into the matching member. The vari base is
 // given the value 0.0; presumably this node's own value is unused because
 // the result lives in res_ (confirm).
37 template <typename T1, typename T2, typename T3>
38 csr_adjoint(T1&& res, T2&& w_mat, T3&& b)
39 : vari(0.0),
40 res_(std::forward<T1>(res)),
41 w_mat_(std::forward<T2>(w_mat)),
42 b_(std::forward<T3>(b)) {}
43
 // NOTE(review): source line 44 is elided from this listing; presumably it is
 // the chain() override that forwards to chain_internal(res_, w_mat_, b_).
45
 /**
  * Overload for when both operands carry adjoints: accumulates
  * res.adj() * b.val()^T into w_mat's adjoint and
  * w_mat.val()^T * res.adj() into b's adjoint.
  *
  * NOTE(review): source line 59 (presumably the matching constraint on WMat)
  * is elided. The right-hand side of the w_mat update is a dense outer
  * product; presumably the sparse adj() += only accumulates into the
  * existing nonzero pattern (confirm).
  */
58 template <typename Result, typename WMat, typename B,
60 require_rev_matrix_t<B>* = nullptr>
61 inline void chain_internal(Result&& res, WMat&& w_mat, B&& b) {
62 w_mat.adj() += res.adj() * b.val().transpose();
63 b.adj() += w_mat.val().transpose() * res.adj();
64 }
65
 /**
  * Overload for when only w_mat carries adjoints: b is used directly as
  * values (no .val() call), and only w_mat's adjoint is updated.
  *
  * NOTE(review): source lines 79-80 (presumably the require_* constraints
  * selecting this overload and the closing '>') are elided.
  */
78 template <typename Result, typename WMat, typename B,
81 inline void chain_internal(Result&& res, WMat&& w_mat, B&& b) {
82 w_mat.adj() += res.adj() * b.transpose();
83 }
84
 /**
  * Overload for when only b carries adjoints: w_mat is used directly as
  * values (no .val() call), and only b's adjoint is updated.
  *
  * NOTE(review): source line 98 (presumably a require_not_rev_matrix_t
  * constraint on WMat) is elided.
  */
97 template <typename Result, typename WMat, typename B,
99 require_rev_matrix_t<B>* = nullptr>
100 inline void chain_internal(Result&& res, WMat&& w_mat, B&& b) {
101 b.adj() += w_mat.transpose() * res.adj();
102 }
103};
104
119template <typename Result_, typename WMat_, typename B_>
120inline void make_csr_adjoint(Result_&& res, WMat_&& w_mat, B_&& b) {
121 new csr_adjoint<std::decay_t<Result_>, std::decay_t<WMat_>, std::decay_t<B_>>(
122 std::forward<Result_>(res), std::forward<WMat_>(w_mat),
123 std::forward<B_>(b));
124 return;
125}
126} // namespace internal
127
/**
 * Reverse-mode overload of csr_matrix_times_vector.
 *
 * Multiplies an m x n sparse matrix, given in compressed-sparse-row form, by
 * the dense vector b. It then registers a csr_adjoint node so that gradients
 * flow back to whichever of w and b are autodiff types.
 *
 * The sparse matrix is described by three inputs:
 * - its nonzero values w (row-major order);
 * - the 1-based column index of each value, v;
 * - the 1-based offsets of the first nonzero of each row, u.
 * Both index vectors are copied and shifted to 0-based before being given to
 * Eigen.
 *
 * @tparam T1 type of the nonzero values w
 * @tparam T2 type of the dense vector b
 *   (at least one of T1, T2 must be a reverse-mode matrix type)
 * @param m number of rows
 * @param n number of columns
 * @param w nonzero values in row-major order
 * @param v 1-based column index of each entry of w
 * @param u 1-based row start offsets; u.size() must be m + 1
 * @param b dense vector of length n
 * @return the dense product W * b, converted to return_t. The definition of
 *   return_t is elided from this listing; it is presumably
 *   return_var_matrix_t of the double-valued product type with T1, T2.
 * @throw via the check_* helpers if any of the following holds:
 *   - m or n is not positive;
 *   - b.size() != n or w.size() != v.size();
 *   - u.size() - 1 != m;
 *   - the nonzero count implied by u differs from v.size();
 *   - check_range rejects an entry of v against n (presumably requiring
 *     1 <= v[k] <= n).
 */
158template <typename T1, typename T2, require_any_rev_matrix_t<T1, T2>* = nullptr>
159inline auto csr_matrix_times_vector(int m, int n, const T1& w,
160 const std::vector<int>& v,
161 const std::vector<int>& u, const T2& b) {
 // Non-owning, double-valued row-major sparse view over caller-provided arrays.
162 using sparse_val_mat
163 = Eigen::Map<const Eigen::SparseMatrix<double, Eigen::RowMajor>>;
 // Plain (evaluated) type of sparse-values * dense-values.
164 using sparse_dense_mul_type
165 = decltype((std::declval<sparse_val_mat>() * value_of(b)).eval());
 // NOTE(review): source line 166 is elided; return_t (used below) is
 // presumably defined there.
167
168 check_positive("csr_matrix_times_vector", "m", m);
169 check_positive("csr_matrix_times_vector", "n", n);
170 check_size_match("csr_matrix_times_vector", "n", n, "b", b.size());
171 check_size_match("csr_matrix_times_vector", "w", w.size(), "v", v.size());
172 check_size_match("csr_matrix_times_vector", "m", m, "u", u.size() - 1);
 // u[m-1] + z[m-1] - 1 is the nonzero count implied by u; it must match
 // the number of column indices.
173 check_size_match("csr_matrix_times_vector", "u/z",
174 u[m - 1] + csr_u_to_z(u, m - 1) - 1, "v", v.size());
175 for (int i : v) {
176 check_range("csr_matrix_times_vector", "v[]", n, i);
177 }
 // Shift the 1-based indices to 0-based copies. The buffers come from
 // arena_allocator because the b-only branch below stores an Eigen Map with
 // raw pointers into them inside the csr_adjoint node. They must therefore
 // outlive this call; presumably arena storage stays valid until the AD stack
 // is recovered (confirm).
178 std::vector<int, arena_allocator<int>> v_arena(v.size());
179 std::transform(v.begin(), v.end(), v_arena.begin(),
180 [](auto&& x) { return x - 1; });
181 std::vector<int, arena_allocator<int>> u_arena(u.size());
182 std::transform(u.begin(), u.end(), u_arena.begin(),
183 [](auto&& x) { return x - 1; });
184 using sparse_var_value_t
 // NOTE(review): source lines 185-187 are elided here. They presumably hold
 // the rest of the sparse_var_value_t alias, the opening `if` for the case
 // where both w and b are autodiff types, and the declaration of b_arena
 // used in that branch.
188 sparse_var_value_t w_mat_arena
189 = to_soa_sparse_matrix<Eigen::RowMajor>(m, n, w, u_arena, v_arena);
190 arena_t<return_t> res = w_mat_arena.val() * value_of(b_arena);
191 stan::math::internal::make_csr_adjoint(res, w_mat_arena, b_arena);
192 return return_t(res);
193 } else if (!is_constant<T2>::value) {
 // Only b is an autodiff type: copy w's values to the arena and view them,
 // with the 0-based index arrays, through a non-owning sparse Map.
 // NOTE(review): source line 194 (presumably the b_arena declaration for
 // this branch) is elided.
195 auto w_val_arena = to_arena(value_of(w));
196 sparse_val_mat w_val_mat(m, n, w_val_arena.size(), u_arena.data(),
197 v_arena.data(), w_val_arena.data());
198 arena_t<return_t> res = w_val_mat * value_of(b_arena);
199 stan::math::internal::make_csr_adjoint(res, w_val_mat, b_arena);
200 return return_t(res);
201 } else {
 // T2 is constant here, so by the require_any_rev_matrix_t constraint w is
 // the autodiff operand; b is stored on the arena as plain values.
202 sparse_var_value_t w_mat_arena
203 = to_soa_sparse_matrix<Eigen::RowMajor>(m, n, w, u_arena, v_arena);
204 auto b_arena = to_arena(value_of(b));
205 arena_t<return_t> res = w_mat_arena.val() * b_arena;
206 stan::math::internal::make_csr_adjoint(res, w_mat_arena, b_arena);
207 return return_t(res);
208 }
209}
210
211} // namespace math
212} // namespace stan
213
214#endif
int csr_u_to_z(const std::vector< int > &u, int i)
Returns element i of the z vector, computed from the specified u vector.
require_t< is_rev_matrix< std::decay_t< T > > > require_rev_matrix_t
Requires that the type satisfies is_rev_matrix.
require_not_t< is_rev_matrix< std::decay_t< T > > > require_not_rev_matrix_t
Requires that the type does not satisfy is_rev_matrix.
void make_csr_adjoint(Result_ &&res, WMat_ &&w_mat, B_ &&b)
Helper function to construct the csr_adjoint struct.
T eval(T &&arg)
Inputs whose plain_type equals their own type are forwarded unmodified (for Eigen expressions ...
Definition eval.hpp:20
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
arena_t< T > to_arena(const T &a)
Converts the given argument into a type that either places any dynamic allocation on the AD stack or schedules i...
Definition to_arena.hpp:25
void check_positive(const char *function, const char *name, const T_y &y)
Check if y is positive.
void check_range(const char *function, const char *name, int max, int index, int nested_level, const char *error_msg)
Check if the specified index is within range.
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
Eigen::Matrix< return_type_t< T1, T2 >, Eigen::Dynamic, 1 > csr_matrix_times_vector(int m, int n, const T1 &w, const std::vector< int > &v, const std::vector< int > &u, const T2 &b)
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T and that places any dynamic allocations on the AD stack.
std::conditional_t< is_any_var_matrix< ReturnType, Types... >::value, stan::math::var_value< stan::math::promote_scalar_t< double, plain_type_t< ReturnType > > >, stan::math::promote_scalar_t< stan::math::var_value< double >, plain_type_t< ReturnType > > > return_var_matrix_t
Given an Eigen type and several inputs, determine if a matrix should be var<Matrix> or Matrix<var>.
The lgamma implementation in stan-math is based on either the reentrant-safe lgamma_r implementation ...
STL namespace.
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the ...
csr_adjoint(T1 &&res, T2 &&w_mat, T3 &&b)
void chain_internal(Result &&res, WMat &&w_mat, B &&b)
Overload for calculating the adjoints of both w_mat and b.
The vari (reverse-mode node) for csr_matrix_times_vector.