1#ifndef STAN_MATH_REV_CONSTRAINT_LB_CONSTRAIN_HPP
2#define STAN_MATH_REV_CONSTRAINT_LB_CONSTRAIN_HPP
40template <
typename T,
typename L, require_all_stan_scalar_t<T, L>* =
nullptr,
41 require_any_var_t<T, L>* =
nullptr>
47 if (!is_constant<T>::value && !is_constant<L>::value) {
51 [arena_x =
var(x), arena_lb =
var(lb), exp_x](
auto& vi)
mutable {
52 arena_x.adj() += vi.adj() * exp_x;
53 arena_lb.adj() += vi.adj();
55 }
else if (!is_constant<T>::value) {
58 [arena_x =
var(x), exp_x](
auto& vi)
mutable {
59 arena_x.adj() += vi.adj() * exp_x;
63 [arena_lb =
var(lb)](
auto& vi)
mutable {
64 arena_lb.adj() += vi.adj();
90template <
typename T,
typename L, require_all_stan_scalar_t<T, L>* =
nullptr,
91 require_any_var_t<T, L>* =
nullptr>
102 [lp, arena_x =
var(x), arena_lb =
var(lb), exp_x](
auto& vi)
mutable {
103 arena_x.adj() += vi.adj() * exp_x + lp.adj();
104 arena_lb.adj() += vi.adj();
109 [lp, arena_x =
var(x), exp_x](
auto& vi)
mutable {
110 arena_x.adj() += vi.adj() * exp_x + lp.adj();
114 [arena_lb =
var(lb)](
auto& vi)
mutable {
115 arena_lb.adj() += vi.adj();
132template <
typename T,
typename L, require_matrix_t<T>* =
nullptr,
133 require_stan_scalar_t<L>* =
nullptr,
134 require_any_st_var<T, L>* =
nullptr>
143 auto exp_x =
to_arena(arena_x.val().array().exp());
146 [arena_x, ret, exp_x, arena_lb =
var(lb)]()
mutable {
147 arena_x.adj().array() += ret.adj().array() * exp_x;
148 arena_lb.adj() += ret.adj().sum();
150 return ret_type(ret);
153 auto exp_x =
to_arena(arena_x.val().array().exp());
156 arena_x.adj().array() += ret.adj().array() * exp_x;
158 return ret_type(ret);
162 arena_lb.adj() += ret.adj().sum();
164 return ret_type(ret);
181template <
typename T,
typename L, require_matrix_t<T>* =
nullptr,
182 require_stan_scalar_t<L>* =
nullptr,
183 require_any_st_var<T, L>* =
nullptr>
190 if (!is_constant<T>::value && !is_constant<L>::value) {
191 arena_t<promote_scalar_t<var, T>> arena_x = x;
192 auto exp_x =
to_arena(arena_x.val().array().exp());
193 arena_t<ret_type> ret = exp_x + lb_val;
194 lp += arena_x.val().sum();
196 [arena_x, ret, lp, arena_lb =
var(lb), exp_x]()
mutable {
197 arena_x.adj().array() += ret.adj().array() * exp_x + lp.adj();
198 arena_lb.adj() += ret.adj().sum();
200 return ret_type(ret);
201 }
else if (!is_constant<T>::value) {
202 arena_t<promote_scalar_t<var, T>> arena_x = x;
203 auto exp_x =
to_arena(arena_x.val().array().exp());
204 arena_t<ret_type> ret = exp_x + lb_val;
205 lp += arena_x.val().sum();
207 arena_x.adj().array() += ret.adj().array() * exp_x + lp.adj();
209 return ret_type(ret);
211 const auto& x_ref =
to_ref(x);
213 arena_t<ret_type> ret =
value_of(x_ref).array().exp() + lb_val;
215 arena_lb.adj() += ret.adj().sum();
217 return ret_type(ret);
234template <
typename T,
typename L, require_all_matrix_t<T, L>* =
nullptr,
235 require_any_st_var<T, L>* =
nullptr>
238 using ret_type = return_var_matrix_t<T, T, L>;
239 if (!is_constant<T>::value && !is_constant<L>::value) {
240 arena_t<promote_scalar_t<var, T>> arena_x = x;
241 arena_t<promote_scalar_t<var, L>> arena_lb = lb;
243 auto precomp_x_exp =
to_arena((arena_x.val().array()).exp());
244 arena_t<ret_type> ret = (is_not_inf_lb)
245 .
select(precomp_x_exp + arena_lb.val().array(),
246 arena_x.val().array());
248 precomp_x_exp]()
mutable {
249 arena_x.adj().array()
251 .
select(ret.adj().array() * precomp_x_exp, ret.adj().array());
252 arena_lb.adj().array() += (is_not_inf_lb).
select(ret.adj().array(), 0);
254 return ret_type(ret);
255 }
else if (!is_constant<T>::value) {
256 arena_t<promote_scalar_t<var, T>> arena_x = x;
259 auto precomp_x_exp =
to_arena((arena_x.val().array()).exp());
260 arena_t<ret_type> ret
262 .
select(precomp_x_exp + lb_ref.array(), arena_x.val().array());
264 precomp_x_exp]()
mutable {
265 arena_x.adj().array()
267 .
select(ret.adj().array() * precomp_x_exp, ret.adj().array());
269 return ret_type(ret);
271 arena_t<promote_scalar_t<var, L>> arena_lb = lb;
274 arena_t<ret_type> ret
276 .
select(x_ref.array().exp() + arena_lb.val().array(),
279 arena_lb.adj().array() += (is_not_inf_lb).
select(ret.adj().array(), 0);
281 return ret_type(ret);
298template <
typename T,
typename L, require_all_matrix_t<T, L>* =
nullptr,
299 require_any_st_var<T, L>* =
nullptr>
300inline auto lb_constrain(
const T& x,
const L& lb, return_type_t<T, L>& lp) {
302 using ret_type = return_var_matrix_t<T, T, L>;
303 if (!is_constant<T>::value && !is_constant<L>::value) {
304 arena_t<promote_scalar_t<var, T>> arena_x = x;
305 arena_t<promote_scalar_t<var, L>> arena_lb = lb;
307 auto exp_x =
to_arena(arena_x.val().array().exp());
308 arena_t<ret_type> ret
310 .
select(exp_x + arena_lb.val().array(), arena_x.val().array());
311 lp += (is_not_inf_lb).
select(arena_x.val(), 0).
sum();
313 [arena_x, arena_lb, ret, lp, exp_x, is_not_inf_lb]()
mutable {
314 const auto lp_adj = lp.adj();
315 for (
size_t j = 0; j < arena_x.cols(); ++j) {
316 for (
size_t i = 0; i < arena_x.rows(); ++i) {
317 double ret_adj = ret.adj().coeff(i, j);
318 if (
likely(is_not_inf_lb.coeff(i, j))) {
319 arena_x.adj().coeffRef(i, j)
320 += ret_adj * exp_x.coeff(i, j) + lp_adj;
321 arena_lb.adj().coeffRef(i, j) += ret_adj;
323 arena_x.adj().coeffRef(i, j) += ret_adj;
328 return ret_type(ret);
329 }
else if (!is_constant<T>::value) {
330 arena_t<promote_scalar_t<var, T>> arena_x = x;
333 auto exp_x =
to_arena(arena_x.val().array().exp());
334 arena_t<ret_type> ret
335 = (is_not_inf_lb).
select(exp_x + lb_val, arena_x.val().array());
336 lp += (is_not_inf_lb).
select(arena_x.val(), 0).
sum();
338 const auto lp_adj = lp.adj();
339 for (
size_t j = 0; j < arena_x.cols(); ++j) {
340 for (
size_t i = 0; i < arena_x.rows(); ++i) {
341 if (
likely(is_not_inf_lb.coeff(i, j))) {
342 const double ret_adj = ret.adj().coeff(i, j);
343 arena_x.adj().coeffRef(i, j)
344 += ret_adj * exp_x.coeff(i, j) + lp_adj;
346 arena_x.adj().coeffRef(i, j) += ret.adj().coeff(i, j);
351 return ret_type(ret);
354 arena_t<promote_scalar_t<var, L>> arena_lb = lb;
356 arena_t<ret_type> ret
357 = (is_not_inf_lb).
select(x_val.exp() + arena_lb.val().array(), x_val);
358 lp += (is_not_inf_lb).
select(x_val, 0).sum();
360 arena_lb.adj().array()
361 += ret.adj().array() * is_not_inf_lb.template cast<double>();
364 return ret_type(ret);
select_< as_operation_cl_t< T_condition >, as_operation_cl_t< T_then >, as_operation_cl_t< T_else > > select(T_condition &&condition, T_then &&then, T_else &&els)
Selection operation on kernel generator expressions.
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback funct...
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
static constexpr double NEGATIVE_INFTY
Negative infinity.
void check_matching_dims(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the two containers have the same dimensions.
arena_t< T > to_arena(const T &a)
Converts given argument into a type that either has any dynamic allocation on AD stack or schedules i...
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
auto lb_constrain(T &&x, L &&lb)
Return the lower-bounded value for the specified unconstrained input and specified lower bound.
auto sum(const std::vector< T > &m)
Return the sum of the entries of the specified standard vector.
auto identity_constrain(T &&x, Types &&...)
Returns the result of applying the identity constraint transform to the input.
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
std::conditional_t< is_any_var_matrix< ReturnType, Types... >::value, stan::math::var_value< stan::math::promote_scalar_t< double, plain_type_t< ReturnType > > >, stan::math::promote_scalar_t< stan::math::var_value< double >, plain_type_t< ReturnType > > > return_var_matrix_t
Given an Eigen type and several inputs, determine if a matrix should be var<Matrix> or Matrix<var>.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the ...