Automatic Differentiation
 
Loading...
Searching...
No Matches
rows_dot_product.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_ROWS_DOT_PRODUCT_HPP
2#define STAN_MATH_REV_FUN_ROWS_DOT_PRODUCT_HPP
3
11#include <type_traits>
12
13namespace stan {
14namespace math {
15
30template <typename Mat1, typename Mat2,
31 require_all_eigen_t<Mat1, Mat2>* = nullptr,
32 require_any_eigen_vt<is_var, Mat1, Mat2>* = nullptr>
33inline Eigen::Matrix<var, Mat1::RowsAtCompileTime, 1> rows_dot_product(
34 const Mat1& v1, const Mat2& v2) {
35 check_matching_sizes("dot_product", "v1", v1, "v2", v2);
36 Eigen::Matrix<var, Mat1::RowsAtCompileTime, 1> ret(v1.rows(), 1);
37 for (size_type j = 0; j < v1.rows(); ++j) {
38 ret.coeffRef(j) = dot_product(v1.row(j), v2.row(j));
39 }
40 return ret;
41}
42
/**
 * Returns the dot product of rows of the specified matrices.
 *
 * The result is the row-wise sum of the element-wise product of the values of
 * `v1` and `v2`. Its type, `return_t`, is chosen by `return_var_matrix_t`
 * from that expression's type together with `Mat1` and `Mat2`.
 *
 * A callback registered with `reverse_pass_callback` accumulates into the
 * adjoints of the operands. Because res(i) = sum_j v1(i, j) * v2(i, j), the
 * increment for one operand is the other operand's values with row i scaled
 * by res.adj()(i), written here as `res.adj().asDiagonal() * other`.
 *
 * NOTE(review): this listing is missing several source lines (the embedded
 * listing numbers skip 61-62, 70-72, 78, 83, 92-93, 98, 107-108 and 113).
 * The missing lines appear to hold the rest of the template header, the `if`
 * headers and the declarations of `arena_v1` / `arena_v2` -- confirm against
 * the original header before relying on this text.
 *
 * @tparam Mat1 type of the first argument
 * @tparam Mat2 type of the second argument
 * @param v1 first matrix
 * @param v2 second matrix
 * @return row-wise dot products of `v1` and `v2`, as a `return_t`
 */
60template <typename Mat1, typename Mat2,
// NOTE(review): listing lines 61-62 (rest of the template parameter list) are
// absent here.
63inline auto rows_dot_product(const Mat1& v1, const Mat2& v2) {
// Size check is reported under the name "rows_dot_product".
64 check_matching_sizes("rows_dot_product", "v1", v1, "v2", v2);
65
// Return type derived from the type of the row-wise-sum expression.
66 using return_t = return_var_matrix_t<
67 decltype((v1.val().array() * v2.val().array()).rowwise().sum().matrix()),
68 Mat1, Mat2>;
69
// NOTE(review): listing lines 70-72 are absent; presumably the first `if`
// header and the arena_v1 / arena_v2 declarations for this branch -- confirm.
73
// Branch where both operands expose .val() and .adj().
74 return_t res
75 = (arena_v1.val().array() * arena_v2.val().array()).rowwise().sum();
76
77 reverse_pass_callback([arena_v1, arena_v2, res]() mutable {
// NOTE(review): listing line 78 (the `if` paired with the `else` below) is
// absent. Both arms add the same product to arena_v1.adj(); they differ only
// in the use of .noalias().
79 arena_v1.adj().noalias() += res.adj().asDiagonal() * arena_v2.val();
80 } else {
81 arena_v1.adj() += res.adj().asDiagonal() * arena_v2.val();
82 }
// NOTE(review): listing line 83 (the matching `if` for arena_v2) is absent.
84 arena_v2.adj().noalias() += res.adj().asDiagonal() * arena_v1.val();
85 } else {
86 arena_v2.adj() += res.adj().asDiagonal() * arena_v1.val();
87 }
88 });
89
90 return res;
// Branch taken when Mat2 is not constant: arena_v1 is used directly as values
// (no .val() / .adj()), so only arena_v2's adjoint is updated.
91 } else if (!is_constant<Mat2>::value) {
// NOTE(review): listing lines 92-93 (arena_v1 / arena_v2 declarations) are
// absent.
94
95 return_t res = (arena_v1.array() * arena_v2.val().array()).rowwise().sum();
96
97 reverse_pass_callback([arena_v1, arena_v2, res]() mutable {
// NOTE(review): listing line 98 (the `if` paired with the `else` below) is
// absent.
99 arena_v2.adj().noalias() += res.adj().asDiagonal() * arena_v1;
100 } else {
101 arena_v2.adj() += res.adj().asDiagonal() * arena_v1;
102 }
103 });
104
105 return res;
// Remaining branch: arena_v2 is used directly as values, so only arena_v1's
// adjoint is updated.
106 } else {
// NOTE(review): listing lines 107-108 (arena_v1 / arena_v2 declarations) are
// absent.
109
110 return_t res = (arena_v1.val().array() * arena_v2.array()).rowwise().sum();
111
112 reverse_pass_callback([arena_v1, arena_v2, res]() mutable {
// NOTE(review): listing line 113 (the `if` paired with the `else` below) is
// absent.
114 arena_v1.adj().noalias() += res.adj().asDiagonal() * arena_v2;
115 } else {
116 arena_v1.adj() += res.adj().asDiagonal() * arena_v2;
117 }
118 });
119
120 return res;
121 }
122}
123
124} // namespace math
125} // namespace stan
126#endif
require_all_t< is_matrix< std::decay_t< Types > >... > require_all_matrix_t
Require that all of the types satisfy is_matrix.
Definition is_matrix.hpp:38
require_any_t< is_var_matrix< std::decay_t< Types > >... > require_any_var_matrix_t
Require that at least one of the types satisfies is_var_matrix.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
auto rows_dot_product(T_a &&a, T_b &&b)
Returns the dot product of rows of the specified matrices.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
void check_matching_sizes(const char *function, const char *name1, const T_y1 &y1, const char *name2, const T_y2 &y2)
Check if two structures are the same size.
Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic >::Index size_type
Type for sizes and indexes in an Eigen matrix with double elements.
Definition typedefs.hpp:11
auto dot_product(const T_a &a, const T_b &b)
Returns the dot product of the specified vectors.
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
std::conditional_t< is_any_var_matrix< ReturnType, Types... >::value, stan::math::var_value< stan::math::promote_scalar_t< double, plain_type_t< ReturnType > > >, stan::math::promote_scalar_t< stan::math::var_value< double >, plain_type_t< ReturnType > > > return_var_matrix_t
Given an Eigen type and several inputs, determine if a matrix should be var<Matrix> or Matrix<var>.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the ...
Check if a type is a var_value whose value_type is derived from Eigen::EigenBase