Automatic Differentiation
 
Loading...
Searching...
No Matches
ub_constrain.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_CONSTRAINT_UB_CONSTRAIN_HPP
2#define STAN_MATH_REV_CONSTRAINT_UB_CONSTRAIN_HPP
3
8#include <cmath>
9
10namespace stan {
11namespace math {
12
29template <typename T, typename U, require_all_stan_scalar_t<T, U>* = nullptr,
30 require_any_var_t<T, U>* = nullptr>
31inline auto ub_constrain(const T& x, const U& ub) {
32 const auto ub_val = value_of(ub);
33 if (unlikely(ub_val == INFTY)) {
34 return identity_constrain(x, ub);
35 } else {
36 if (!is_constant<T>::value && !is_constant<U>::value) {
37 auto neg_exp_x = -std::exp(value_of(x));
38 return make_callback_var(
39 ub_val + neg_exp_x,
40 [arena_x = var(x), arena_ub = var(ub), neg_exp_x](auto& vi) mutable {
41 const auto vi_adj = vi.adj();
42 arena_x.adj() += vi_adj * neg_exp_x;
43 arena_ub.adj() += vi_adj;
44 });
45 } else if (!is_constant<T>::value) {
46 auto neg_exp_x = -std::exp(value_of(x));
47 return make_callback_var(ub_val + neg_exp_x,
48 [arena_x = var(x), neg_exp_x](auto& vi) mutable {
49 arena_x.adj() += vi.adj() * neg_exp_x;
50 });
51 } else {
52 return make_callback_var(ub_val - std::exp(value_of(x)),
53 [arena_ub = var(ub)](auto& vi) mutable {
54 arena_ub.adj() += vi.adj();
55 });
56 }
57 }
58}
59
80template <typename T, typename U, require_all_stan_scalar_t<T, U>* = nullptr,
81 require_any_var_t<T, U>* = nullptr>
82inline auto ub_constrain(const T& x, const U& ub, return_type_t<T, U>& lp) {
83 const auto ub_val = value_of(ub);
84 const bool is_ub_inf = ub_val == INFTY;
85 if (!is_constant<T>::value && !is_constant<U>::value) {
86 if (unlikely(is_ub_inf)) {
87 return identity_constrain(x, ub);
88 } else {
89 lp += value_of(x);
90 auto neg_exp_x = -std::exp(value_of(x));
91 return make_callback_var(value_of(ub) + neg_exp_x,
92 [lp, arena_x = var(x), arena_ub = var(ub),
93 neg_exp_x](auto& vi) mutable {
94 const auto vi_adj = vi.adj();
95 arena_x.adj() += vi_adj * neg_exp_x + lp.adj();
96 arena_ub.adj() += vi_adj;
97 });
98 }
99 } else if (!is_constant<T>::value) {
100 if (unlikely(is_ub_inf)) {
101 return identity_constrain(x, ub);
102 } else {
103 lp += value_of(x);
104 auto neg_exp_x = -std::exp(value_of(x));
105 return make_callback_var(
106 value_of(ub) + neg_exp_x,
107 [lp, arena_x = var(x), neg_exp_x](auto& vi) mutable {
108 arena_x.adj() += vi.adj() * neg_exp_x + lp.adj();
109 });
110 }
111 } else {
112 if (unlikely(is_ub_inf)) {
113 return identity_constrain(x, ub);
114 } else {
115 lp += value_of(x);
116 return make_callback_var(value_of(ub) - std::exp(value_of(x)),
117 [arena_ub = var(ub)](auto& vi) mutable {
118 arena_ub.adj() += vi.adj();
119 });
120 }
121 }
122}
123
135template <typename T, typename U, require_matrix_t<T>* = nullptr,
136 require_stan_scalar_t<U>* = nullptr,
137 require_any_st_var<T, U>* = nullptr>
138inline auto ub_constrain(const T& x, const U& ub) {
139 using ret_type = return_var_matrix_t<T, T, U>;
140 const auto ub_val = value_of(ub);
141 if (unlikely(ub_val == INFTY)) {
142 return ret_type(identity_constrain(x, ub));
143 } else {
144 if (!is_constant<T>::value && !is_constant<U>::value) {
145 arena_t<promote_scalar_t<var, T>> arena_x = x;
146 auto arena_neg_exp_x = to_arena(-arena_x.val().array().exp());
147 arena_t<ret_type> ret = ub_val + arena_neg_exp_x;
149 [arena_x, arena_neg_exp_x, ret, arena_ub = var(ub)]() mutable {
150 arena_x.adj().array() += ret.adj().array() * arena_neg_exp_x;
151 arena_ub.adj() += ret.adj().sum();
152 });
153 return ret_type(ret);
154 } else if (!is_constant<T>::value) {
155 arena_t<promote_scalar_t<var, T>> arena_x = x;
156 auto arena_neg_exp_x = to_arena(-arena_x.val().array().exp());
157 arena_t<ret_type> ret = ub_val + arena_neg_exp_x;
158 reverse_pass_callback([arena_x, arena_neg_exp_x, ret]() mutable {
159 arena_x.adj().array() += ret.adj().array() * arena_neg_exp_x;
160 });
161 return ret_type(ret);
162 } else {
163 arena_t<ret_type> ret = ub_val - value_of(x).array().exp();
164 reverse_pass_callback([ret, arena_ub = var(ub)]() mutable {
165 arena_ub.adj() += ret.adj().sum();
166 });
167 return ret_type(ret);
168 }
169 }
170}
171
184template <typename T, typename U, require_matrix_t<T>* = nullptr,
185 require_stan_scalar_t<U>* = nullptr,
186 require_any_st_var<T, U>* = nullptr>
187inline auto ub_constrain(const T& x, const U& ub, return_type_t<T, U>& lp) {
188 using ret_type = return_var_matrix_t<T, T, U>;
189 const auto ub_val = value_of(ub);
190 if (unlikely(ub_val == INFTY)) {
191 return ret_type(identity_constrain(x, ub));
192 } else {
195 auto arena_neg_exp_x = to_arena(-arena_x.val().array().exp());
196 arena_t<ret_type> ret = ub_val + arena_neg_exp_x;
197 lp += arena_x.val().sum();
198 reverse_pass_callback([arena_x, arena_neg_exp_x, ret, lp,
199 arena_ub = var(ub)]() mutable {
200 arena_x.adj().array() += ret.adj().array() * arena_neg_exp_x + lp.adj();
201 arena_ub.adj() += ret.adj().sum();
202 });
203 return ret_type(ret);
204 } else if (!is_constant<T>::value) {
206 auto arena_neg_exp_x = to_arena(-arena_x.val().array().exp());
207 arena_t<ret_type> ret = ub_val + arena_neg_exp_x;
208 lp += arena_x.val().sum();
209 reverse_pass_callback([arena_x, arena_neg_exp_x, ret, lp]() mutable {
210 arena_x.adj().array() += ret.adj().array() * arena_neg_exp_x + lp.adj();
211 });
212 return ret_type(ret);
213 } else {
214 auto x_ref = to_ref(value_of(x));
215 arena_t<ret_type> ret = ub_val - x_ref.array().exp();
216 lp += x_ref.sum();
217 reverse_pass_callback([ret, arena_ub = var(ub)]() mutable {
218 arena_ub.adj() += ret.adj().sum();
219 });
220 return ret_type(ret);
221 }
222 }
223}
224
237template <typename T, typename U, require_all_matrix_t<T, U>* = nullptr,
238 require_any_st_var<T, U>* = nullptr>
239inline auto ub_constrain(const T& x, const U& ub) {
240 check_matching_dims("ub_constrain", "x", x, "ub", ub);
241 using ret_type = return_var_matrix_t<T, T, U>;
245 auto ub_val = to_ref(arena_ub.val());
246 auto is_not_inf_ub = to_arena((ub_val.array() != INFTY));
247 auto neg_exp_x = to_arena(-arena_x.val().array().exp());
249 = (is_not_inf_ub)
250 .select(ub_val.array() + neg_exp_x, arena_x.val().array());
251 reverse_pass_callback([arena_x, neg_exp_x, arena_ub, ret,
252 is_not_inf_ub]() mutable {
253 arena_x.adj().array()
254 += (is_not_inf_ub)
255 .select(ret.adj().array() * neg_exp_x, ret.adj().array());
256 arena_ub.adj().array() += (is_not_inf_ub).select(ret.adj().array(), 0.0);
257 });
258 return ret_type(ret);
259 } else if (!is_constant<T>::value) {
260 arena_t<promote_scalar_t<var, T>> arena_x = x;
261 auto ub_val = to_ref(value_of(ub));
262 auto is_not_inf_ub = to_arena((ub_val.array() != INFTY));
263 auto neg_exp_x = to_arena(-arena_x.val().array().exp());
264 arena_t<ret_type> ret
265 = (is_not_inf_ub)
266 .select(ub_val.array() + neg_exp_x, arena_x.val().array());
267 reverse_pass_callback([arena_x, neg_exp_x, ret, is_not_inf_ub]() mutable {
268 arena_x.adj().array()
269 += (is_not_inf_ub)
270 .select(ret.adj().array() * neg_exp_x, ret.adj().array());
271 });
272 return ret_type(ret);
273 } else {
274 arena_t<promote_scalar_t<var, U>> arena_ub = to_arena(ub);
275 auto is_not_inf_ub
276 = to_arena((arena_ub.val().array() != INFTY).template cast<double>());
277 auto&& x_ref = to_ref(value_of(x).array());
278 arena_t<ret_type> ret
279 = (is_not_inf_ub).select(arena_ub.val().array() - x_ref.exp(), x_ref);
280 reverse_pass_callback([arena_ub, ret, is_not_inf_ub]() mutable {
281 arena_ub.adj().array() += ret.adj().array() * is_not_inf_ub;
282 });
283 return ret_type(ret);
284 }
285}
286
300template <typename T, typename U, require_all_matrix_t<T, U>* = nullptr,
301 require_any_st_var<T, U>* = nullptr>
302inline auto ub_constrain(const T& x, const U& ub, return_type_t<T, U>& lp) {
303 check_matching_dims("ub_constrain", "x", x, "ub", ub);
304 using ret_type = return_var_matrix_t<T, T, U>;
305 if (!is_constant<T>::value && !is_constant<U>::value) {
306 arena_t<promote_scalar_t<var, T>> arena_x = x;
307 arena_t<promote_scalar_t<var, U>> arena_ub = ub;
308 auto ub_val = to_ref(arena_ub.val());
309 auto is_not_inf_ub = to_arena((ub_val.array() != INFTY));
310 auto neg_exp_x = to_arena(-arena_x.val().array().exp());
311 arena_t<ret_type> ret
312 = (is_not_inf_ub)
313 .select(ub_val.array() + neg_exp_x, arena_x.val().array());
314 lp += (is_not_inf_ub).select(arena_x.val().array(), 0).sum();
315 reverse_pass_callback([arena_x, neg_exp_x, arena_ub, ret, lp,
316 is_not_inf_ub]() mutable {
317 arena_x.adj().array()
318 += (is_not_inf_ub)
319 .select(ret.adj().array() * neg_exp_x + lp.adj(),
320 ret.adj().array());
321 arena_ub.adj().array() += (is_not_inf_ub).select(ret.adj().array(), 0.0);
322 });
323 return ret_type(ret);
324 } else if (!is_constant<T>::value) {
325 arena_t<promote_scalar_t<var, T>> arena_x = x;
326 auto ub_val = to_ref(value_of(ub));
327 auto is_not_inf_ub = to_arena((ub_val.array() != INFTY));
328 auto neg_exp_x = to_arena(-arena_x.val().array().exp());
329 arena_t<ret_type> ret
330 = (is_not_inf_ub)
331 .select(ub_val.array() + neg_exp_x, arena_x.val().array());
332 lp += (is_not_inf_ub).select(arena_x.val().array(), 0).sum();
334 [arena_x, neg_exp_x, ret, lp, is_not_inf_ub]() mutable {
335 arena_x.adj().array()
336 += (is_not_inf_ub)
337 .select(ret.adj().array() * neg_exp_x + lp.adj(),
338 ret.adj().array());
339 });
340 return ret_type(ret);
341 } else {
342 arena_t<promote_scalar_t<var, U>> arena_ub = to_arena(ub);
343 auto is_not_inf_ub
344 = to_arena((arena_ub.val().array() != INFTY).template cast<double>());
345 auto&& x_ref = to_ref(value_of(x).array());
346 arena_t<ret_type> ret
347 = (is_not_inf_ub).select(arena_ub.val().array() - x_ref.exp(), x_ref);
348 lp += (is_not_inf_ub).select(x_ref, 0).sum();
349 reverse_pass_callback([arena_ub, ret, is_not_inf_ub]() mutable {
350 arena_ub.adj().array() += ret.adj().array() * is_not_inf_ub;
351 });
352 return ret_type(ret);
353 }
354}
355
356} // namespace math
357} // namespace stan
358
359#endif
#define unlikely(x)
select_< as_operation_cl_t< T_condition >, as_operation_cl_t< T_then >, as_operation_cl_t< T_else > > select(T_condition &&condition, T_then &&then, T_else &&els)
Selection operation on kernel generator expressions.
Definition select.hpp:148
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback functor.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
void check_matching_dims(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the two containers have the same dimensions.
arena_t< T > to_arena(const T &a)
Converts the given argument into a type that either has any dynamic allocation on the AD stack or schedules its destructor to be called when AD stack memory is recovered.
Definition to_arena.hpp:25
fvar< T > sum(const std::vector< fvar< T > > &m)
Return the sum of the entries of the specified standard vector.
Definition sum.hpp:22
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:17
auto ub_constrain(T &&x, U &&ub)
Return the upper-bounded value for the specified unconstrained matrix and upper bound.
var_value< double > var
Definition var.hpp:1187
auto identity_constrain(T &&x, Types &&...)
Returns the result of applying the identity constraint transform to the input.
static constexpr double INFTY
Positive infinity.
Definition constants.hpp:46
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
std::conditional_t< is_any_var_matrix< ReturnType, Types... >::value, stan::math::var_value< stan::math::promote_scalar_t< double, plain_type_t< ReturnType > > >, stan::math::promote_scalar_t< stan::math::var_value< double >, plain_type_t< ReturnType > > > return_var_matrix_t
Given an Eigen type and several inputs, determine if a matrix should be var<Matrix> or Matrix<var>.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the C++ const sense).