Automatic Differentiation
 
Loading...
Searching...
No Matches
mdivide_left_tri_low.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_FWD_FUN_MDIVIDE_LEFT_TRI_LOW_HPP
2#define STAN_MATH_FWD_FUN_MDIVIDE_LEFT_TRI_LOW_HPP
3
14namespace stan {
15namespace math {
16
17template <typename T1, typename T2,
18 require_all_eigen_vt<is_fvar, T1, T2>* = nullptr,
19 require_vt_same<T1, T2>* = nullptr>
20inline Eigen::Matrix<value_type_t<T1>, std::decay_t<T1>::RowsAtCompileTime,
21 std::decay_t<T2>::ColsAtCompileTime>
22mdivide_left_tri_low(T1&& A, T2&& b) {
23 constexpr int S1 = std::decay_t<T1>::RowsAtCompileTime;
24 constexpr int C2 = std::decay_t<T2>::ColsAtCompileTime;
25
26 check_square("mdivide_left_tri_low", "A", A);
27 check_multiplicable("mdivide_left_tri_low", "A", A, "b", b);
28 if (A.size() == 0) {
29 return {0, b.cols()};
30 }
31 decltype(auto) b_ref = to_ref(std::forward<T2>(b));
32 decltype(auto) A_ref = to_ref(std::forward<T1>(A));
33 auto inv_A_mult_b
34 = eval(mdivide_left_tri<Eigen::Lower>(A_ref.val(), b_ref.val()));
35 return to_fvar(
36 inv_A_mult_b,
37 subtract(mdivide_left_tri<Eigen::Lower>(A_ref.val(), b_ref.d()),
38 multiply(mdivide_left_tri<Eigen::Lower>(
39 A_ref.val(),
40 A_ref.d().template triangularView<Eigen::Lower>()),
41 inv_A_mult_b)));
42}
43
44template <typename T1, typename T2, require_eigen_t<T1>* = nullptr,
45 require_vt_same<double, T1>* = nullptr,
46 require_eigen_vt<is_fvar, T2>* = nullptr>
47inline Eigen::Matrix<value_type_t<T2>, std::decay_t<T1>::RowsAtCompileTime,
48 std::decay_t<T2>::ColsAtCompileTime>
49mdivide_left_tri_low(T1&& A, T2&& b) {
50 constexpr int S1 = std::decay_t<T1>::RowsAtCompileTime;
51 check_square("mdivide_left_tri_low", "A", A);
52 check_multiplicable("mdivide_left_tri_low", "A", A, "b", b);
53 if (A.size() == 0) {
54 return {0, b.cols()};
55 }
56 decltype(auto) A_ref = to_ref(std::forward<T1>(A));
57 decltype(auto) b_ref = to_ref(std::forward<T2>(b));
58 return to_fvar(mdivide_left_tri<Eigen::Lower>(A_ref, b_ref.val()),
59 mdivide_left_tri<Eigen::Lower>(A_ref, b_ref.d()));
60}
61
62template <typename T1, typename T2, require_eigen_vt<is_fvar, T1>* = nullptr,
63 require_eigen_vt<std::is_floating_point, T2>* = nullptr>
64inline Eigen::Matrix<value_type_t<T1>, std::decay_t<T1>::RowsAtCompileTime,
65 std::decay_t<T2>::ColsAtCompileTime>
66mdivide_left_tri_low(T1&& A, T2&& b) {
67 constexpr int S1 = std::decay_t<T1>::RowsAtCompileTime;
68 constexpr int C2 = std::decay_t<T2>::ColsAtCompileTime;
69 check_square("mdivide_left_tri_low", "A", A);
70 check_multiplicable("mdivide_left_tri_low", "A", A, "b", b);
71 if (A.size() == 0) {
72 return {0, b.cols()};
73 }
74 decltype(auto) A_ref = to_ref(std::forward<T1>(A));
75 auto inv_A_mult_b
76 = eval(mdivide_left_tri<Eigen::Lower>(A_ref.val(), std::forward<T2>(b)));
77 return to_fvar(
78 inv_A_mult_b,
79 -multiply(
80 mdivide_left_tri<Eigen::Lower>(
81 A_ref.val(), A_ref.d().template triangularView<Eigen::Lower>()),
82 inv_A_mult_b));
83}
84
85} // namespace math
86} // namespace stan
87#endif
subtraction_< as_operation_cl_t< T_a >, as_operation_cl_t< T_b > > subtract(T_a &&a, T_b &&b)
void check_square(const char *function, const char *name, const T_y &y)
Check if the specified matrix is square.
auto multiply(Mat1 &&m1, Mat2 &&m2)
Return the product of the specified matrices.
Definition multiply.hpp:20
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
T eval(T &&arg)
Inputs whose plain_type equals their own type are forwarded unmodified (for Eigen expressions ...
Definition eval.hpp:20
fvar< T > to_fvar(const T &x)
Definition to_fvar.hpp:15
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:18
Eigen::Matrix< value_type_t< T1 >, std::decay_t< T1 >::RowsAtCompileTime, std::decay_t< T2 >::ColsAtCompileTime > mdivide_left_tri_low(T1 &&A, T2 &&b)
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...