Automatic Differentiation
 
Loading...
Searching...
No Matches
mdivide_left.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_MDIVIDE_LEFT_HPP
2#define STAN_MATH_REV_FUN_MDIVIDE_LEFT_HPP
3
11#include <vector>
12
13namespace stan {
14namespace math {
15
/**
 * Return the solution X of A * X = B when A and/or B hold reverse-mode
 * autodiff variables.
 *
 * This overload is enabled only when both T1 and T2 are matrix types
 * (require_all_matrix_t) and at least one of them has var scalars
 * (require_any_st_var). The forward solve uses a Householder QR
 * decomposition of the values of A. That decomposition is wrapped with
 * make_chainable_ptr so it stays alive for the reverse pass, where it is
 * reused instead of re-factoring A.
 *
 * @tparam T1 matrix type of A
 * @tparam T2 matrix type of B
 * @param A square left-hand matrix
 * @param B right-hand side; must be multiplicable by A (A.cols() == B.rows())
 * @return X such that A * X = B, as ret_type
 *   (an empty 0 x B.cols() result when A is empty)
 * @throws whatever check_square / check_multiplicable throw when A is not
 *   square or A and B are not multiplicable (exception type not visible here)
 */
28template <typename T1, typename T2, require_all_matrix_t<T1, T2>* = nullptr,
29 require_any_st_var<T1, T2>* = nullptr>
30inline auto mdivide_left(const T1& A, const T2& B) {
31 using ret_val_type = plain_type_t<decltype(value_of(A) * value_of(B))>;
// NOTE(review): ret_type is used below, but its declaration (original line 32)
// is missing from this listing. It is presumably
// promote_var_matrix_t<ret_val_type, T1, T2> -- confirm against the header.
33
34 check_square("mdivide_left", "A", A);
35 check_multiplicable("mdivide_left", "A", A, "B", B);
36
// Empty A: return an empty 0 x B.cols() result without registering any
// reverse-pass callback.
37 if (A.size() == 0) {
38 return ret_type(ret_val_type(0, B.cols()));
39 }
40
// NOTE(review): original lines 41-43 are missing here. They presumably open
// the branch where both A and B hold vars and declare arena_A / arena_B
// (AD-stack copies of A and B) -- confirm against the header.
44
// Factor value(A) once; the QR object is kept alive on the chainable stack so
// the reverse pass below can reuse it.
45 auto hqr_A_ptr = make_chainable_ptr(arena_A.val().householderQr());
46 arena_t<ret_type> res = hqr_A_ptr->solve(arena_B.val());
47 reverse_pass_callback([arena_A, arena_B, hqr_A_ptr, res]() mutable {
// NOTE(review): the declaration of adjB (original line 48) is missing; the
// expression below is its initializer. With A = Q * R (Q orthogonal),
// Q * R^{-T} * res.adj() == A^{-T} * res.adj(), the adjoint of B.
49 = hqr_A_ptr->householderQ()
50 * hqr_A_ptr->matrixQR()
51 .template triangularView<Eigen::Upper>()
52 .transpose()
53 .solve(res.adj());
// Adjoint of A is -(A^{-T} * res.adj()) * X^T; the adjoint of B is adjB itself.
54 arena_A.adj() -= adjB * res.val_op().transpose();
55 arena_B.adj() += adjB;
56 });
57
58 return ret_type(res);
59 } else if (!is_constant<T2>::value) {
// Only B carries adjoints in this branch; A enters through its values only.
// NOTE(review): the declaration of arena_B (original line 60) is missing from
// this listing -- confirm against the header.
61
62 auto hqr_A_ptr = make_chainable_ptr(value_of(A).householderQr());
63 arena_t<ret_type> res = hqr_A_ptr->solve(arena_B.val());
64 reverse_pass_callback([arena_B, hqr_A_ptr, res]() mutable {
// B's adjoint gains A^{-T} * res.adj(), computed as Q * R^{-T} * res.adj().
65 arena_B.adj() += hqr_A_ptr->householderQ()
66 * hqr_A_ptr->matrixQR()
67 .template triangularView<Eigen::Upper>()
68 .transpose()
69 .solve(res.adj());
70 });
71 return ret_type(res);
72 } else {
// B is constant here, so only A carries adjoints; B enters via value_of(B).
// NOTE(review): the declaration of arena_A (original line 73) is missing from
// this listing -- confirm against the header.
74
75 auto hqr_A_ptr = make_chainable_ptr(arena_A.val().householderQr());
76 arena_t<ret_type> res = hqr_A_ptr->solve(value_of(B));
77 reverse_pass_callback([arena_A, hqr_A_ptr, res]() mutable {
// A's adjoint loses (A^{-T} * res.adj()) * X^T.
78 arena_A.adj() -= hqr_A_ptr->householderQ()
79 * hqr_A_ptr->matrixQR()
80 .template triangularView<Eigen::Upper>()
81 .transpose()
82 .solve(res.adj())
83 * res.val_op().transpose();
84 });
85 return ret_type(res);
86 }
87}
88
89} // namespace math
90} // namespace stan
91#endif
void check_square(const char *function, const char *name, const T_y &y)
Check if the specified matrix is square.
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
typename promote_scalar_type< std::decay_t< T >, std::decay_t< S > >::type promote_scalar_t
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
auto make_chainable_ptr(T &&obj)
Store the given object in a chainable_object so it is destructed only when the chainable stack memory...
Eigen::Matrix< value_type_t< T1 >, T1::RowsAtCompileTime, T2::ColsAtCompileTime > mdivide_left(const T1 &A, const T2 &b)
std::conditional_t< is_any_var_matrix< ReturnType, Types... >::value, stan::math::var_value< stan::math::promote_scalar_t< double, ReturnType > >, stan::math::promote_scalar_t< stan::math::var_value< double >, ReturnType > > promote_var_matrix_t
Given an Eigen type and several inputs, determine if a matrix should be var<Matrix> or Matrix<var>.
typename plain_type< T >::type plain_type_t
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the ...