Automatic Differentiation
 
Loading...
Searching...
No Matches
mdivide_left.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_MDIVIDE_LEFT_HPP
2#define STAN_MATH_REV_FUN_MDIVIDE_LEFT_HPP
3
12#include <vector>
13
14namespace stan {
15namespace math {
16
// Reverse-mode mdivide_left: returns X solving A * X = B (i.e. A^{-1} * B)
// for matrix arguments where at least one of A, B has autodiff (var) scalars
// (enforced by require_all_matrix_t / require_any_st_var).
//
// The forward value is computed from a Householder QR factorization of the
// value of A. That factorization is kept alive via make_chainable_ptr so the
// reverse-pass callback can reuse it instead of refactoring A.
//
// NOTE(review): this is an extracted documentation listing, not the verbatim
// header. The leading number on each line is the listing's line number, and
// listing lines 33, 42-44, 49, 61 and 74 are missing. As a result, `ret_type`,
// the first branch's `if` condition, `arena_A` / `arena_B`, and `adjB` are used
// below without their declarations being shown. Restore those lines from the
// upstream header before compiling.
//
// @tparam T1 matrix type of A
// @tparam T2 matrix type of B
// @param A square matrix (checked with check_square)
// @param B right-hand side; must be multiplicable with A (check_multiplicable)
// @return A^{-1} * B as ret_type. NOTE(review): ret_type's declaration is
//   missing here; presumably it comes from promote_var_matrix_t so the result
//   is var<Matrix> or Matrix<var> to match the inputs — confirm upstream.
// NOTE(review): failure behavior of check_square / check_multiplicable is not
//   shown here (presumably they throw) — confirm in their definitions.
29template <typename T1, typename T2, require_all_matrix_t<T1, T2>* = nullptr,
30 require_any_st_var<T1, T2>* = nullptr>
31inline auto mdivide_left(const T1& A, const T2& B) {
// Non-autodiff type of the product of the input values; used to build the
// empty result below.
32 using ret_val_type = plain_type_t<decltype(value_of(A) * value_of(B))>;
// NOTE(review): listing line 33 (the `ret_type` alias used below) is missing.
34
35 check_square("mdivide_left", "A", A);
36 check_multiplicable("mdivide_left", "A", A, "B", B);
37
// Empty A: return a 0 x B.cols() result without registering any
// reverse-pass callback.
38 if (A.size() == 0) {
39 return ret_type(ret_val_type(0, B.cols()));
40 }
41
// Branch where both A and B receive adjoints. NOTE(review): the `if`
// condition and the arena copies `arena_A` / `arena_B` (listing lines 42-44)
// are missing from this extract.
45
// Factor value(A) once; the pointer keeps the QR object alive for the
// reverse pass.
46 auto hqr_A_ptr = make_chainable_ptr(arena_A.val().householderQr());
47 arena_t<ret_type> res = hqr_A_ptr->solve(arena_B.val());
48 reverse_pass_callback([arena_A, arena_B, hqr_A_ptr, res]() mutable {
// With A = Q R, A^{-T} = Q R^{-T}, so this computes adjB = A^{-T} * res.adj()
// using a triangular solve against R^T. NOTE(review): listing line 49 (the
// declaration of `adjB`, whose initializer starts on the next line) is missing.
50 = hqr_A_ptr->householderQ()
51 * hqr_A_ptr->matrixQR()
52 .template triangularView<Eigen::Upper>()
53 .transpose()
54 .solve(res.adj());
// adj(A) -= A^{-T} adj(res) * res^T ; adj(B) += A^{-T} adj(res)
55 arena_A.adj() -= adjB * res.val_op().transpose();
56 arena_B.adj() += adjB;
57 });
58
59 return ret_type(res);
// Branch where B is autodiff; A enters only through its value (presumably A
// is constant here, given the first branch handles both).
// NOTE(review): listing line 61 (the `arena_B` copy) is missing.
60 } else if (!is_constant<T2>::value) {
62
63 auto hqr_A_ptr = make_chainable_ptr(value_of(A).householderQr());
64 arena_t<ret_type> res = hqr_A_ptr->solve(arena_B.val());
65 reverse_pass_callback([arena_B, hqr_A_ptr, res]() mutable {
// adj(B) += A^{-T} * res.adj(), computed as Q * R^{-T} * res.adj().
66 arena_B.adj() += hqr_A_ptr->householderQ()
67 * hqr_A_ptr->matrixQR()
68 .template triangularView<Eigen::Upper>()
69 .transpose()
70 .solve(res.adj());
71 });
72 return ret_type(res);
// Remaining case: only A is autodiff; B enters only through its value.
// NOTE(review): listing line 74 (the `arena_A` copy) is missing.
73 } else {
75
76 auto hqr_A_ptr = make_chainable_ptr(arena_A.val().householderQr());
77 arena_t<ret_type> res = hqr_A_ptr->solve(value_of(B));
78 reverse_pass_callback([arena_A, hqr_A_ptr, res]() mutable {
// adj(A) -= (A^{-T} * res.adj()) * res^T, with A^{-T} = Q R^{-T}.
79 arena_A.adj() -= hqr_A_ptr->householderQ()
80 * hqr_A_ptr->matrixQR()
81 .template triangularView<Eigen::Upper>()
82 .transpose()
83 .solve(res.adj())
84 * res.val_op().transpose();
85 });
86 return ret_type(res);
87 }
88}
89
90} // namespace math
91} // namespace stan
92#endif
void check_square(const char *function, const char *name, const T_y &y)
Check if the specified matrix is square.
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
typename promote_scalar_type< std::decay_t< T >, std::decay_t< S > >::type promote_scalar_t
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
auto make_chainable_ptr(T &&obj)
Store the given object in a chainable_object so it is destructed only when the chainable stack memory is recovered, and return a pointer to the underlying object.
Eigen::Matrix< value_type_t< T1 >, T1::RowsAtCompileTime, T2::ColsAtCompileTime > mdivide_left(const T1 &A, const T2 &b)
std::conditional_t< is_any_var_matrix< ReturnType, Types... >::value, stan::math::var_value< stan::math::promote_scalar_t< double, ReturnType > >, stan::math::promote_scalar_t< stan::math::var_value< double >, ReturnType > > promote_var_matrix_t
Given an Eigen type and several inputs, determine if a matrix should be var<Matrix> or Matrix<var>.
typename plain_type< T >::type plain_type_t
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the C++ `const` sense).