Automatic Differentiation
 
Loading...
Searching...
No Matches
lub_constrain.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_PRIM_LUB_CONSTRAIN_HPP
2#define STAN_MATH_OPENCL_PRIM_LUB_CONSTRAIN_HPP
3#ifdef STAN_OPENCL
4
9
10namespace stan {
11namespace math {
12
32template <typename T, typename L, typename U,
33 require_all_kernel_expressions_and_none_scalar_t<T>* = nullptr,
34 require_all_kernel_expressions_t<L, U>* = nullptr>
35inline matrix_cl<double> lub_constrain(const T& x, const L& lb, const U& ub) {
36 auto diff = ub - lb;
37 auto lb_inf = lb == NEGATIVE_INFTY;
38 auto ub_inf = ub == INFTY;
39
40 auto check
41 = check_cl("lub_constrain (OpenCL)", "(ub - lb)", diff, "positive");
43
44 results(check, res) = expressions(
45 diff > 0.0, select(lb_inf, select(ub_inf, x, ub - exp(x)),
46 select(ub_inf, exp(x) + lb,
47 elt_multiply(diff, inv_logit(x)) + lb)));
48 return res;
49}
50
// lub_constrain(x, lb, ub, lp): maps the unconstrained x elementwise into the
// bounds given by lb and ub, and adds the sum of the per-element log absolute
// Jacobian terms of that mapping to lp.
// NOTE(review): this listing is incomplete -- source lines 72-73 (presumably
// the require_* template constraints), 75 (presumably the declaration of the
// `lp` parameter used below) and 83 (presumably `matrix_cl<double> res;`)
// are missing from the page; confirm against the header before relying on
// the exact signature.
71template <typename T, typename L, typename U,
74inline auto lub_constrain(const T& x, const L& lb, const U& ub,
76 auto diff = ub - lb;
77 auto lb_inf = lb == NEGATIVE_INFTY;
78 auto ub_inf = ub == INFTY;
// |x| lets the finite-bounds Jacobian term use log1p_exp(-|x|), whose
// argument is never positive, so the exp inside it cannot overflow.
79 auto abs_x = fabs(x);
// Bound check (condition diff > 0.0) evaluated in the results() assignment
// below.
80 auto check
81 = check_cl("lub_constrain (OpenCL)", "(ub - lb)", diff, "positive");
82
// Receives the sum_2d reduction of lp_inc_expr. It is summed again on the
// host below (presumably sum_2d leaves partial sums -- confirm in sum_2d.hpp).
84 matrix_cl<double> lp_inc;
85
// Per-element log|d res / d x| for each bound case:
//   both infinite:  0 (identity)
//   lb infinite:    x   (from ub - exp(x))
//   ub infinite:    x   (from exp(x) + lb)
//   both finite:    log(diff) + log(inv_logit(x)) + log(1 - inv_logit(x)),
//                   written as log(diff) - |x| - 2 * log1p_exp(-|x|)
86 auto lp_inc_expr = sum_2d(
87 select(lb_inf, select(ub_inf, 0.0, x),
88 select(ub_inf, x, log(diff) - abs_x - 2.0 * log1p_exp(-abs_x))));
// Constrained values, same per-case transform as the Jacobian terms above.
89 auto res_expr = select(
90 lb_inf, select(ub_inf, x, ub - exp(x)),
91 select(ub_inf, exp(x) + lb, elt_multiply(diff, inv_logit(x)) + lb));
92
// A single assignment evaluates the check, the result and the reduction.
93 results(check, res, lp_inc) = expressions(diff > 0.0, res_expr, lp_inc_expr);
94
// Copy the reduction to the host (from_matrix_cl) and add its total to lp.
95 lp += sum(from_matrix_cl(lp_inc));
96
97 return res;
98}
99
100} // namespace math
101} // namespace stan
102#endif
103#endif
Represents an arithmetic matrix on the OpenCL device.
Definition matrix_cl.hpp:47
elt_multiply_< as_operation_cl_t< T_a >, as_operation_cl_t< T_b > > elt_multiply(T_a &&a, T_b &&b)
auto sum_2d(T &&a)
Two dimensional sum - reduction of a kernel generator expression.
select_< as_operation_cl_t< T_condition >, as_operation_cl_t< T_then >, as_operation_cl_t< T_else > > select(T_condition &&condition, T_then &&then, T_else &&els)
Selection operation on kernel generator expressions.
Definition select.hpp:148
auto check_cl(const char *function, const char *var_name, T &&y, const char *must_be)
Constructs a check on opencl matrix or expression.
Definition check_cl.hpp:219
results_cl< T_results... > results(T_results &&... results)
Deduces types for constructing a results_cl object.
require_all_t< is_kernel_expression< Types >... > require_all_kernel_expressions_t
Enables a template if all given types are valid kernel generator expressions.
expressions_cl< T_expressions... > expressions(T_expressions &&... expressions)
Deduces types for constructing an expressions_cl object.
require_all_t< is_kernel_expression_and_not_scalar< Types >... > require_all_kernel_expressions_and_none_scalar_t
Enables a template if all given types are non-scalar types that are valid kernel generator expressions.
auto from_matrix_cl(const T &src)
Copies the source matrix that is stored on the OpenCL device to the destination Eigen matrix.
Definition copy.hpp:61
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
fvar< T > log(const fvar< T > &x)
Definition log.hpp:15
static constexpr double NEGATIVE_INFTY
Negative infinity.
Definition constants.hpp:51
fvar< T > log1p_exp(const fvar< T > &x)
Definition log1p_exp.hpp:13
fvar< T > sum(const std::vector< fvar< T > > &m)
Return the sum of the entries of the specified standard vector.
Definition sum.hpp:22
matrix_cl< double > lub_constrain(const T &x, const L &lb, const U &ub)
Return the lower and upper-bounded matrix derived by transforming the specified free matrix given the...
fvar< T > inv_logit(const fvar< T > &x)
Returns the inverse logit function applied to the argument.
Definition inv_logit.hpp:20
static constexpr double INFTY
Positive infinity.
Definition constants.hpp:46
fvar< T > fabs(const fvar< T > &x)
Definition fabs.hpp:15
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:13
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9