Automatic Differentiation
 
Loading...
Searching...
No Matches
lb_constrain.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_CONSTRAINT_LB_CONSTRAIN_HPP
2#define STAN_MATH_REV_CONSTRAINT_LB_CONSTRAIN_HPP
3
16#include <cmath>
17
18namespace stan {
19namespace math {
20
40template <typename T, typename L, require_all_stan_scalar_t<T, L>* = nullptr,
41 require_any_var_t<T, L>* = nullptr>
42inline auto lb_constrain(const T& x, const L& lb) {
43 const auto lb_val = value_of(lb);
44 if (unlikely(lb_val == NEGATIVE_INFTY)) {
45 return identity_constrain(x, lb);
46 } else {
47 if (!is_constant<T>::value && !is_constant<L>::value) {
48 auto exp_x = std::exp(value_of(x));
49 return make_callback_var(
50 exp_x + lb_val,
51 [arena_x = var(x), arena_lb = var(lb), exp_x](auto& vi) mutable {
52 arena_x.adj() += vi.adj() * exp_x;
53 arena_lb.adj() += vi.adj();
54 });
55 } else if (!is_constant<T>::value) {
56 auto exp_x = std::exp(value_of(x));
57 return make_callback_var(exp_x + lb_val,
58 [arena_x = var(x), exp_x](auto& vi) mutable {
59 arena_x.adj() += vi.adj() * exp_x;
60 });
61 } else {
62 return make_callback_var(std::exp(value_of(x)) + lb_val,
63 [arena_lb = var(lb)](auto& vi) mutable {
64 arena_lb.adj() += vi.adj();
65 });
66 }
67 }
68}
69
90template <typename T, typename L, require_all_stan_scalar_t<T, L>* = nullptr,
91 require_any_var_t<T, L>* = nullptr>
92inline auto lb_constrain(const T& x, const L& lb, var& lp) {
93 const auto lb_val = value_of(lb);
94 if (unlikely(lb_val == NEGATIVE_INFTY)) {
95 return identity_constrain(x, lb);
96 } else {
97 lp += value_of(x);
99 auto exp_x = std::exp(value_of(x));
100 return make_callback_var(
101 exp_x + lb_val,
102 [lp, arena_x = var(x), arena_lb = var(lb), exp_x](auto& vi) mutable {
103 arena_x.adj() += vi.adj() * exp_x + lp.adj();
104 arena_lb.adj() += vi.adj();
105 });
106 } else if (!is_constant<T>::value) {
107 auto exp_x = std::exp(value_of(x));
108 return make_callback_var(exp_x + lb_val,
109 [lp, arena_x = var(x), exp_x](auto& vi) mutable {
110 arena_x.adj() += vi.adj() * exp_x + lp.adj();
111 });
112 } else {
113 return make_callback_var(std::exp(value_of(x)) + lb_val,
114 [arena_lb = var(lb)](auto& vi) mutable {
115 arena_lb.adj() += vi.adj();
116 });
117 }
118 }
119}
120
132template <typename T, typename L, require_matrix_t<T>* = nullptr,
133 require_stan_scalar_t<L>* = nullptr,
134 require_any_st_var<T, L>* = nullptr>
135inline auto lb_constrain(const T& x, const L& lb) {
136 using ret_type = return_var_matrix_t<T, T, L>;
137 const auto lb_val = value_of(lb);
138 if (unlikely(lb_val == NEGATIVE_INFTY)) {
139 return ret_type(identity_constrain(x, lb));
140 } else {
143 auto exp_x = to_arena(arena_x.val().array().exp());
144 arena_t<ret_type> ret = exp_x + lb_val;
146 [arena_x, ret, exp_x, arena_lb = var(lb)]() mutable {
147 arena_x.adj().array() += ret.adj().array() * exp_x;
148 arena_lb.adj() += ret.adj().sum();
149 });
150 return ret_type(ret);
151 } else if (!is_constant<T>::value) {
153 auto exp_x = to_arena(arena_x.val().array().exp());
154 arena_t<ret_type> ret = exp_x + lb_val;
155 reverse_pass_callback([arena_x, ret, exp_x]() mutable {
156 arena_x.adj().array() += ret.adj().array() * exp_x;
157 });
158 return ret_type(ret);
159 } else {
160 arena_t<ret_type> ret = value_of(x).array().exp() + lb_val;
161 reverse_pass_callback([ret, arena_lb = var(lb)]() mutable {
162 arena_lb.adj() += ret.adj().sum();
163 });
164 return ret_type(ret);
165 }
166 }
167}
168
181template <typename T, typename L, require_matrix_t<T>* = nullptr,
182 require_stan_scalar_t<L>* = nullptr,
183 require_any_st_var<T, L>* = nullptr>
184inline auto lb_constrain(const T& x, const L& lb, return_type_t<T, L>& lp) {
185 using ret_type = return_var_matrix_t<T, T, L>;
186 const auto lb_val = value_of(lb);
187 if (unlikely(lb_val == NEGATIVE_INFTY)) {
188 return ret_type(identity_constrain(x, lb));
189 } else {
190 if (!is_constant<T>::value && !is_constant<L>::value) {
191 arena_t<promote_scalar_t<var, T>> arena_x = x;
192 auto exp_x = to_arena(arena_x.val().array().exp());
193 arena_t<ret_type> ret = exp_x + lb_val;
194 lp += arena_x.val().sum();
196 [arena_x, ret, lp, arena_lb = var(lb), exp_x]() mutable {
197 arena_x.adj().array() += ret.adj().array() * exp_x + lp.adj();
198 arena_lb.adj() += ret.adj().sum();
199 });
200 return ret_type(ret);
201 } else if (!is_constant<T>::value) {
202 arena_t<promote_scalar_t<var, T>> arena_x = x;
203 auto exp_x = to_arena(arena_x.val().array().exp());
204 arena_t<ret_type> ret = exp_x + lb_val;
205 lp += arena_x.val().sum();
206 reverse_pass_callback([arena_x, ret, exp_x, lp]() mutable {
207 arena_x.adj().array() += ret.adj().array() * exp_x + lp.adj();
208 });
209 return ret_type(ret);
210 } else {
211 const auto& x_ref = to_ref(x);
212 lp += sum(x_ref);
213 arena_t<ret_type> ret = value_of(x_ref).array().exp() + lb_val;
214 reverse_pass_callback([ret, arena_lb = var(lb)]() mutable {
215 arena_lb.adj() += ret.adj().sum();
216 });
217 return ret_type(ret);
218 }
219 }
220}
221
234template <typename T, typename L, require_all_matrix_t<T, L>* = nullptr,
235 require_any_st_var<T, L>* = nullptr>
236inline auto lb_constrain(const T& x, const L& lb) {
237 check_matching_dims("lb_constrain", "x", x, "lb", lb);
238 using ret_type = return_var_matrix_t<T, T, L>;
239 if (!is_constant<T>::value && !is_constant<L>::value) {
240 arena_t<promote_scalar_t<var, T>> arena_x = x;
241 arena_t<promote_scalar_t<var, L>> arena_lb = lb;
242 auto is_not_inf_lb = to_arena((arena_lb.val().array() != NEGATIVE_INFTY));
243 auto precomp_x_exp = to_arena((arena_x.val().array()).exp());
244 arena_t<ret_type> ret = (is_not_inf_lb)
245 .select(precomp_x_exp + arena_lb.val().array(),
246 arena_x.val().array());
247 reverse_pass_callback([arena_x, arena_lb, ret, is_not_inf_lb,
248 precomp_x_exp]() mutable {
249 arena_x.adj().array()
250 += (is_not_inf_lb)
251 .select(ret.adj().array() * precomp_x_exp, ret.adj().array());
252 arena_lb.adj().array() += (is_not_inf_lb).select(ret.adj().array(), 0);
253 });
254 return ret_type(ret);
255 } else if (!is_constant<T>::value) {
256 arena_t<promote_scalar_t<var, T>> arena_x = x;
257 auto lb_ref = to_ref(value_of(lb));
258 auto is_not_inf_lb = to_arena((lb_ref.array() != NEGATIVE_INFTY));
259 auto precomp_x_exp = to_arena((arena_x.val().array()).exp());
260 arena_t<ret_type> ret
261 = (is_not_inf_lb)
262 .select(precomp_x_exp + lb_ref.array(), arena_x.val().array());
263 reverse_pass_callback([arena_x, ret, is_not_inf_lb,
264 precomp_x_exp]() mutable {
265 arena_x.adj().array()
266 += (is_not_inf_lb)
267 .select(ret.adj().array() * precomp_x_exp, ret.adj().array());
268 });
269 return ret_type(ret);
270 } else {
271 arena_t<promote_scalar_t<var, L>> arena_lb = lb;
272 const auto x_ref = to_ref(value_of(x));
273 auto is_not_inf_lb = to_arena((arena_lb.val().array() != NEGATIVE_INFTY));
274 arena_t<ret_type> ret
275 = (is_not_inf_lb)
276 .select(x_ref.array().exp() + arena_lb.val().array(),
277 x_ref.array());
278 reverse_pass_callback([arena_lb, ret, is_not_inf_lb]() mutable {
279 arena_lb.adj().array() += (is_not_inf_lb).select(ret.adj().array(), 0);
280 });
281 return ret_type(ret);
282 }
283}
284
298template <typename T, typename L, require_all_matrix_t<T, L>* = nullptr,
299 require_any_st_var<T, L>* = nullptr>
300inline auto lb_constrain(const T& x, const L& lb, return_type_t<T, L>& lp) {
301 check_matching_dims("lb_constrain", "x", x, "lb", lb);
302 using ret_type = return_var_matrix_t<T, T, L>;
303 if (!is_constant<T>::value && !is_constant<L>::value) {
304 arena_t<promote_scalar_t<var, T>> arena_x = x;
305 arena_t<promote_scalar_t<var, L>> arena_lb = lb;
306 auto is_not_inf_lb = to_arena((arena_lb.val().array() != NEGATIVE_INFTY));
307 auto exp_x = to_arena(arena_x.val().array().exp());
308 arena_t<ret_type> ret
309 = (is_not_inf_lb)
310 .select(exp_x + arena_lb.val().array(), arena_x.val().array());
311 lp += (is_not_inf_lb).select(arena_x.val(), 0).sum();
313 [arena_x, arena_lb, ret, lp, exp_x, is_not_inf_lb]() mutable {
314 const auto lp_adj = lp.adj();
315 for (size_t j = 0; j < arena_x.cols(); ++j) {
316 for (size_t i = 0; i < arena_x.rows(); ++i) {
317 double ret_adj = ret.adj().coeff(i, j);
318 if (likely(is_not_inf_lb.coeff(i, j))) {
319 arena_x.adj().coeffRef(i, j)
320 += ret_adj * exp_x.coeff(i, j) + lp_adj;
321 arena_lb.adj().coeffRef(i, j) += ret_adj;
322 } else {
323 arena_x.adj().coeffRef(i, j) += ret_adj;
324 }
325 }
326 }
327 });
328 return ret_type(ret);
329 } else if (!is_constant<T>::value) {
330 arena_t<promote_scalar_t<var, T>> arena_x = x;
331 auto lb_val = value_of(lb).array();
332 auto is_not_inf_lb = to_arena((lb_val != NEGATIVE_INFTY));
333 auto exp_x = to_arena(arena_x.val().array().exp());
334 arena_t<ret_type> ret
335 = (is_not_inf_lb).select(exp_x + lb_val, arena_x.val().array());
336 lp += (is_not_inf_lb).select(arena_x.val(), 0).sum();
337 reverse_pass_callback([arena_x, ret, exp_x, lp, is_not_inf_lb]() mutable {
338 const auto lp_adj = lp.adj();
339 for (size_t j = 0; j < arena_x.cols(); ++j) {
340 for (size_t i = 0; i < arena_x.rows(); ++i) {
341 if (likely(is_not_inf_lb.coeff(i, j))) {
342 const double ret_adj = ret.adj().coeff(i, j);
343 arena_x.adj().coeffRef(i, j)
344 += ret_adj * exp_x.coeff(i, j) + lp_adj;
345 } else {
346 arena_x.adj().coeffRef(i, j) += ret.adj().coeff(i, j);
347 }
348 }
349 }
350 });
351 return ret_type(ret);
352 } else {
353 auto x_val = to_ref(value_of(x)).array();
354 arena_t<promote_scalar_t<var, L>> arena_lb = lb;
355 auto is_not_inf_lb = to_arena((arena_lb.val().array() != NEGATIVE_INFTY));
356 arena_t<ret_type> ret
357 = (is_not_inf_lb).select(x_val.exp() + arena_lb.val().array(), x_val);
358 lp += (is_not_inf_lb).select(x_val, 0).sum();
359 reverse_pass_callback([arena_lb, ret, is_not_inf_lb]() mutable {
360 arena_lb.adj().array()
361 += ret.adj().array() * is_not_inf_lb.template cast<double>();
362 });
363
364 return ret_type(ret);
365 }
366}
367
368} // namespace math
369} // namespace stan
370
371#endif
#define likely(x)
#define unlikely(x)
select_< as_operation_cl_t< T_condition >, as_operation_cl_t< T_then >, as_operation_cl_t< T_else > > select(T_condition &&condition, T_then &&then, T_else &&els)
Selection operation on kernel generator expressions.
Definition select.hpp:148
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback functor.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
static constexpr double NEGATIVE_INFTY
Negative infinity.
Definition constants.hpp:51
void check_matching_dims(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the two containers have the same dimensions.
arena_t< T > to_arena(const T &a)
Converts given argument into a type that either has any dynamic allocation on AD stack or schedules its destructor to be called when AD stack memory is recovered.
Definition to_arena.hpp:25
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:17
auto lb_constrain(T &&x, L &&lb)
Return the lower-bounded value for the specified unconstrained input and specified lower bound.
auto sum(const std::vector< T > &m)
Return the sum of the entries of the specified standard vector.
Definition sum.hpp:23
var_value< double > var
Definition var.hpp:1187
auto identity_constrain(T &&x, Types &&...)
Returns the result of applying the identity constraint transform to the input.
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
std::conditional_t< is_any_var_matrix< ReturnType, Types... >::value, stan::math::var_value< stan::math::promote_scalar_t< double, plain_type_t< ReturnType > > >, stan::math::promote_scalar_t< stan::math::var_value< double >, plain_type_t< ReturnType > > > return_var_matrix_t
Given an Eigen type and several inputs, determine if a matrix should be var<Matrix> or Matrix<var>.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation from C or the boost::math::lgamma implementation.
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the C++ const sense).