1#ifndef STAN_MATH_REV_FUN_LMULTPLY_HPP
2#define STAN_MATH_REV_FUN_LMULTPLY_HPP
102template <
typename T1,
typename T2, require_all_matrix_t<T1, T2>* =
nullptr,
103 require_any_var_matrix_t<T1, T2>* =
nullptr>
104inline auto lmultiply(
const T1& a,
const T2& b) {
112 [arena_a, arena_b](
const auto& res)
mutable {
113 arena_a.adj().array()
114 += res.adj().array() * arena_b.val().array().log();
115 arena_b.adj().array() += res.adj().array() * arena_a.val().array()
116 / arena_b.val().array();
118 }
else if (!is_constant<T1>::value) {
119 arena_t<promote_scalar_t<var, T1>> arena_a = a;
120 arena_t<promote_scalar_t<double, T2>> arena_b =
value_of(b);
123 [arena_a, arena_b](
const auto& res)
mutable {
124 arena_a.adj().array()
126 * arena_b.val().array().log();
129 arena_t<promote_scalar_t<double, T1>> arena_a =
value_of(a);
130 arena_t<promote_scalar_t<var, T2>> arena_b = b;
133 [arena_a, arena_b](
const auto& res)
mutable {
134 arena_b.adj().array() += res.adj().array()
135 * arena_a.val().array()
136 / arena_b.val().array();
150template <
typename T1,
typename T2, require_var_matrix_t<T1>* =
nullptr,
151 require_stan_scalar_t<T2>* =
nullptr>
152inline auto lmultiply(
const T1& a,
const T2& b) {
155 if (!is_constant<T1>::value && !is_constant<T2>::value) {
156 arena_t<promote_scalar_t<var, T1>> arena_a = a;
161 [arena_a, arena_b](
const auto& res)
mutable {
162 arena_a.adj().array() += res.adj().array() * log(arena_b.val());
163 arena_b.adj() += (res.adj().array() * arena_a.val().array()).sum()
166 }
else if (!is_constant<T1>::value) {
167 arena_t<promote_scalar_t<var, T1>> arena_a = a;
170 [arena_a, b](
const auto& res)
mutable {
171 arena_a.adj().array()
175 arena_t<promote_scalar_t<double, T1>> arena_a =
value_of(a);
180 [arena_a, arena_b](
const auto& res)
mutable {
182 += (res.adj().array() * arena_a.array()).sum() / arena_b.val();
196template <
typename T1,
typename T2, require_stan_scalar_t<T1>* =
nullptr,
197 require_var_matrix_t<T2>* =
nullptr>
198inline auto lmultiply(
const T1& a,
const T2& b) {
199 if (!is_constant<T1>::value && !is_constant<T2>::value) {
201 arena_t<promote_scalar_t<var, T2>> arena_b = b;
205 [arena_a, arena_b](
const auto& res)
mutable {
207 += (res.adj().array() * arena_b.val().array().log()).sum();
208 arena_b.adj().array()
209 += arena_a.val() * res.adj().array() / arena_b.val().array();
211 }
else if (!is_constant<T1>::value) {
213 arena_t<promote_scalar_t<double, T2>> arena_b =
value_of(b);
217 [arena_a, arena_b](
const auto& res)
mutable {
219 += (res.adj().array() * arena_b.val().array().log()).sum();
222 arena_t<promote_scalar_t<var, T2>> arena_b = b;
225 [a, arena_b](
const auto& res)
mutable {
226 arena_b.adj().array() += value_of(a)
228 / arena_b.val().array();
lmultiply_dv_vari(double a, vari *bvi)
lmultiply_vd_vari(vari *avi, double b)
lmultiply_vv_vari(vari *avi, vari *bvi)
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback funct...
T value_of(const fvar< T > &v)
Return the value of the specified variable.
fvar< T > log(const fvar< T > &x)
fvar< T > lmultiply(const fvar< T > &x1, const fvar< T > &x2)
void check_matching_dims(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the two containers have the same dimensions.
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the ...