Automatic Differentiation
 
Loading...
Searching...
No Matches
categorical_logit_glm_lpmf.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_PRIM_CATEGORICAL_LOGIT_GLM_LPMF_HPP
2#define STAN_MATH_OPENCL_PRIM_CATEGORICAL_LOGIT_GLM_LPMF_HPP
3#ifdef STAN_OPENCL
4
13
21
22namespace stan {
23namespace math {
24
/**
 * Returns the log PMF of a generalized linear model with categorical
 * distribution and logit link, evaluated with OpenCL kernels; partials are
 * propagated to whichever of x, alpha and beta are autodiff types.
 *
 * NOTE(review): this text is a Doxygen source listing. Every code line is
 * prefixed with its listing line number, and listing lines 47, 49, 74, 104,
 * 126, 130, 141 and 152 were dropped by the extraction. Restore them from
 * the upstream header before treating this copy as compilable source.
 *
 * @tparam propto presumably selects dropping terms that are constant in all
 *   arguments (used by the early-return check whose condition, listing line
 *   74, is missing) -- confirm
 * @param y dependent variable(s): either a scalar (presumably applied to every
 *   row of x) or a vector whose size must equal x.rows(); the x partials index
 *   beta with y - 1, so class labels appear to be 1-based
 * @param x design matrix with N_instances rows and N_attributes columns
 * @param alpha intercept; its size must equal beta.cols()
 * @param beta weight matrix; beta.rows() must equal x.cols() and beta.cols()
 *   is the number of classes
 * @return the log probability (0 for empty x or at most one class), built via
 *   make_partials_propagator so gradients reach the autodiff arguments
 * @throw if any size check fails, if an OpenCL error occurs (reported through
 *   check_opencl_error), or if the result is not finite and an input fails
 *   its check_cl validation
 */
45template <bool propto, typename T_y, typename T_x, typename T_alpha,
46 typename T_beta,
48 T_beta>* = nullptr>
50 const T_y& y, const T_x& x, const T_alpha& alpha, const T_beta& beta) {
51 using T_partials_return = partials_return_t<T_x, T_alpha, T_beta>;
52 constexpr bool is_y_vector = !is_stan_scalar<T_y>::value;
53 using Eigen::Array;
54 using Eigen::Dynamic;
55 using Eigen::Matrix;
56
// Problem dimensions: x is N_instances x N_attributes and beta is
// N_attributes x N_classes (enforced by the size checks below).
57 const size_t N_instances = x.rows();
58 const size_t N_attributes = x.cols();
59 const size_t N_classes = beta.cols();
60
61 static constexpr const char* function = "categorical_logit_glm_lpmf";
// A scalar y carries no per-row size to check.
62 if (is_y_vector) {
63 check_size_match(function, "Rows of ", "x", N_instances, "size of ", "y",
64 math::size(y));
65 }
66 check_size_match(function, "Columns of ", "beta", N_classes, "size of ",
67 "alpha", math::size(alpha));
// NOTE(review): the "Rows of" label below lacks the trailing space the other
// labels in this function have; check how check_size_match formats it.
68 check_size_match(function, "Columns of ", "x", N_attributes, "Rows of",
69 "beta", beta.rows());
70
// No instances, or a single class (whose softmax probability is always 1),
// contribute log(1) = 0.
71 if (N_instances == 0 || N_classes <= 1) {
72 return 0;
73 }
// NOTE(review): listing line 74, the condition for this early return, is
// missing. Upstream it is presumably
// `if (!include_summand<propto, T_x, T_alpha, T_beta>::value) {` -- confirm.
75 return 0;
76 }
77
// Strip autodiff wrappers and materialize any expressions into plain values.
78 const auto& y_val = eval(value_of(y));
79 const auto& x_val = eval(value_of(x));
80 const auto& alpha_val = eval(value_of(alpha));
81 const auto& beta_val = eval(value_of(beta));
82
83 const auto& y_val_cl = to_matrix_cl(y_val);
84
// Linear predictor without the intercept, held in a device matrix. alpha_val
// is passed to the kernel separately (presumably added there).
85 matrix_cl<double> x_beta_cl = x_val * beta_val;
86
// Launch configuration: one work item per instance, rounded up to a whole
// number (wgs) of work groups of size LOCAL_SIZE_.
87 const int local_size
88 = opencl_kernels::categorical_logit_glm.get_option("LOCAL_SIZE_");
89 const int wgs = (N_instances + local_size - 1) / local_size;
90
91 bool need_alpha_derivative = !is_constant_all<T_alpha>::value;
92 bool need_beta_derivative = !is_constant_all<T_beta>::value;
93
// Kernel outputs. logp_cl has one entry per work group (summed on the host
// below). Gradient buffers get zero rows/columns when the corresponding
// derivative is not needed, so no device memory is spent on them.
94 matrix_cl<double> logp_cl(wgs, 1);
95 matrix_cl<double> exp_lin_cl(N_instances, N_classes);
96 matrix_cl<double> inv_sum_exp_lin_cl(N_instances, 1);
97 matrix_cl<double> neg_softmax_lin_cl(
98 need_alpha_derivative || need_beta_derivative ? N_instances : 0,
99 N_classes);
100 matrix_cl<double> alpha_derivative_cl(N_classes,
101 need_alpha_derivative ? wgs : 0);
102
103 try {
// NOTE(review): listing line 104, the kernel call head (presumably
// `opencl_kernels::categorical_logit_glm(`), is missing.
105 cl::NDRange(local_size * wgs), cl::NDRange(local_size), logp_cl,
106 exp_lin_cl, inv_sum_exp_lin_cl, neg_softmax_lin_cl, alpha_derivative_cl,
107 y_val_cl, x_beta_cl, alpha_val, N_instances, N_attributes, N_classes,
108 is_y_vector, need_alpha_derivative, need_beta_derivative);
109 } catch (const cl::Error& e) {
110 check_opencl_error(function, e);
111 }
// Copy the per-work-group contributions to the host and add them up.
112 T_partials_return logp = sum(from_matrix_cl(logp_cl));
113
// Inputs are validated only when the result is not finite, so the common
// path skips these extra device checks.
// NOTE(review): the y check accepts y == 0, but the x partials below index
// beta with y - 1, which suggests valid classes are 1..N_classes. Confirm
// against the kernel before tightening this to y_val >= 1.
114 if (!std::isfinite(logp)) {
115 results(check_cl(function, "Vector of dependent variables", y_val,
116 "between 0 and cols of beta"),
117 check_cl(function, "Intercept", alpha_val, "finite"))
118 = expressions(y_val >= 0 && y_val <= static_cast<int>(N_classes),
119 isfinite(alpha_val));
120 check_cl(function, "Design matrix", x_val, "finite") = isfinite(x_val);
121 check_cl(function, "Weight vector", beta_val, "finite")
122 = isfinite(beta_val);
123 }
124
125 auto ops_partials = make_partials_propagator(x, alpha, beta);
// NOTE(review): listing line 126, the `if` that opens the x-partials block
// (presumably on !is_constant_all<T_x>), is missing. Its closing brace is
// listing line 140.
// d/dx[i, j] = beta[j, y_i - 1] minus the softmax-weighted average of row j of
// beta (exp_lin scaled row-wise by inv_sum_exp_lin).
127 if (is_y_vector) {
128 partials<0>(ops_partials)
129 = indexing(beta_val, col_index(x.rows(), x.cols()),
// NOTE(review): listing line 130, the per-row column index for vector y
// (presumably the device copy of y minus 1), is missing.
131 - elt_multiply(exp_lin_cl * transpose(beta_val),
132 rowwise_broadcast(inv_sum_exp_lin_cl));
133 } else {
134 partials<0>(ops_partials)
135 = indexing(beta_val, col_index(x.rows(), x.cols()),
136 forward_as<int>(y_val) - 1)
137 - elt_multiply(exp_lin_cl * transpose(beta_val),
138 rowwise_broadcast(inv_sum_exp_lin_cl));
139 }
140 }
// NOTE(review): listing line 141, the `if` that opens the alpha-partials block
// (presumably on !is_constant_all<T_alpha>), is missing.
// With one work group the kernel output already is the alpha gradient and is
// moved in. Otherwise its wgs columns are summed row-wise.
142 if (wgs == 1) {
143 partials<1>(ops_partials) = std::move(alpha_derivative_cl);
144 } else {
145 partials<1>(ops_partials) = rowwise_sum(alpha_derivative_cl);
146 }
147 }
// d/dbeta starts as x^T * neg_softmax_lin. The beta-derivative kernel then
// updates it in place (in_out_buffer), using temp as scratch. The block is
// skipped when x has no columns, because beta then has no rows.
148 if (!is_constant_all<T_beta>::value && N_attributes != 0) {
149 partials<2>(ops_partials) = transpose(x_val) * neg_softmax_lin_cl;
150 matrix_cl<double> temp(N_classes, local_size * N_attributes);
151 try {
// NOTE(review): listing line 152, the kernel call head (presumably
// `opencl_kernels::categorical_logit_glm_beta_derivative(`), is missing.
153 cl::NDRange(local_size * N_attributes), cl::NDRange(local_size),
154 forward_as<arena_matrix_cl<double>>(partials<2>(ops_partials)), temp,
155 y_val_cl, x_val, N_instances, N_attributes, N_classes, is_y_vector);
156 } catch (const cl::Error& e) {
157 check_opencl_error(function, e);
158 }
159 }
160 return ops_partials.build(logp);
161}
162
163} // namespace math
164} // namespace stan
165
166#endif
167#endif
A variant of matrix_cl that schedules its destructor to be called, so it can be used on the AD stack.
Represents an operation that determines the column index.
Definition index.hpp:80
Represents an arithmetic matrix on the OpenCL device.
Definition matrix_cl.hpp:47
void check_opencl_error(const char *function, const cl::Error &e)
Throws a domain error specifying the OpenCL error that occurred.
elt_multiply_< as_operation_cl_t< T_a >, as_operation_cl_t< T_b > > elt_multiply(T_a &&a, T_b &&b)
isfinite_< as_operation_cl_t< T > > isfinite(T &&a)
auto check_cl(const char *function, const char *var_name, T &&y, const char *must_be)
Constructs a check on opencl matrix or expression.
Definition check_cl.hpp:219
results_cl< T_results... > results(T_results &&... results)
Deduces types for constructing a results_cl object.
auto transpose(Arg &&a)
Transposes a kernel generator expression.
auto rowwise_broadcast(T &&a)
Broadcasts an expression in the rowwise dimension.
auto rowwise_sum(T &&a)
Rowwise sum reduction of a kernel generator expression.
auto indexing(T_mat &&mat, T_row_index &&row_index, T_col_index &&col_index)
Index a kernel generator expression using two expressions for indices.
Definition indexing.hpp:304
expressions_cl< T_expressions... > expressions(T_expressions &&... expressions)
Deduces types for constructing an expressions_cl object.
const kernel_cl< out_buffer, out_buffer, out_buffer, out_buffer, out_buffer, in_buffer, in_buffer, in_buffer, int, int, int, int, int, int > categorical_logit_glm("categorical_logit_glm", {categorical_logit_glm_kernel_code}, {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}})
See the docs for categorical_logit_glm() .
const kernel_cl< in_out_buffer, in_out_buffer, in_buffer, in_buffer, int, int, int, int > categorical_logit_glm_beta_derivative("categorical_logit_glm_beta_derivative", {categorical_logit_glm_beta_derivative_kernel_code})
See the docs for categorical_logit_glm_beta_derivative() .
matrix_cl< scalar_type_t< T > > to_matrix_cl(T &&src)
Copies the source Eigen matrix, std::vector or scalar to the destination matrix that is stored on the...
Definition copy.hpp:45
auto from_matrix_cl(const T &src)
Copies the source matrix that is stored on the OpenCL device to the destination Eigen matrix.
Definition copy.hpp:61
return_type_t< T_x, T_alpha, T_beta > categorical_logit_glm_lpmf(const T_y &y, const T_x &x, const T_alpha &alpha, const T_beta &beta)
Returns the log PMF of the Generalized Linear Model (GLM) with categorical distribution and logit (so...
require_all_t< is_prim_or_rev_kernel_expression< std::decay_t< Types > >... > require_all_prim_or_rev_kernel_expression_t
Requires that every type satisfies is_prim_or_rev_kernel_expression.
size_t size(const T &m)
Returns the size (number of the elements) of a matrix_cl or var_value<matrix_cl<T>>.
Definition size.hpp:18
T_actual && forward_as(T_actual &&a)
Forwards the argument as the type it is assumed to have.
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
static constexpr double e()
Return the base of the natural logarithm.
Definition constants.hpp:20
T eval(T &&arg)
Inputs whose plain_type equals their own type are forwarded unmodified (for Eigen expressions ...
Definition eval.hpp:20
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
fvar< T > sum(const std::vector< fvar< T > > &m)
Return the sum of the entries of the specified standard vector.
Definition sum.hpp:22
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
fvar< T > beta(const fvar< T > &x1, const fvar< T > &x2)
Return fvar with the beta function applied to the specified arguments and its gradient.
Definition beta.hpp:51
auto make_partials_propagator(Ops &&... ops)
Constructs a partials_propagator.
typename partials_return_type< Args... >::type partials_return_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
Checks if decayed type is a var, fvar, or arithmetic.
Extends std::true_type when instantiated with zero or more template parameters, all of which extend t...
Template metaprogram to calculate whether a summand needs to be included in a proportional (log) prob...