1#ifndef STAN_MATH_OPENCL_PRIM_CATEGORICAL_LOGIT_GLM_LPMF_HPP
2#define STAN_MATH_OPENCL_PRIM_CATEGORICAL_LOGIT_GLM_LPMF_HPP
45template <
bool propto,
typename T_y,
typename T_x,
typename T_alpha,
50 const T_y& y,
const T_x& x,
const T_alpha& alpha,
const T_beta&
beta) {
57 const size_t N_instances = x.rows();
58 const size_t N_attributes = x.cols();
59 const size_t N_classes =
beta.cols();
61 static constexpr const char* function =
"categorical_logit_glm_lpmf";
71 if (N_instances == 0 || N_classes <= 1) {
89 const int wgs = (N_instances + local_size - 1) / local_size;
98 need_alpha_derivative || need_beta_derivative ? N_instances : 0,
101 need_alpha_derivative ? wgs : 0);
105 cl::NDRange(local_size * wgs), cl::NDRange(local_size), logp_cl,
106 exp_lin_cl, inv_sum_exp_lin_cl, neg_softmax_lin_cl, alpha_derivative_cl,
107 y_val_cl, x_beta_cl, alpha_val, N_instances, N_attributes, N_classes,
108 is_y_vector, need_alpha_derivative, need_beta_derivative);
109 }
catch (
const cl::Error&
e) {
114 if (!std::isfinite(logp)) {
116 "between 0 and cols of beta"),
117 check_cl(function,
"Intercept", alpha_val,
"finite"))
118 =
expressions(y_val >= 0 && y_val <=
static_cast<int>(N_classes),
121 check_cl(function,
"Weight vector", beta_val,
"finite")
128 partials<0>(ops_partials)
134 partials<0>(ops_partials)
136 forward_as<int>(y_val) - 1)
143 partials<1>(ops_partials) = std::move(alpha_derivative_cl);
145 partials<1>(ops_partials) =
rowwise_sum(alpha_derivative_cl);
149 partials<2>(ops_partials) =
transpose(x_val) * neg_softmax_lin_cl;
153 cl::NDRange(local_size * N_attributes), cl::NDRange(local_size),
155 y_val_cl, x_val, N_instances, N_attributes, N_classes, is_y_vector);
156 }
catch (
const cl::Error&
e) {
160 return ops_partials.build(logp);
A variant of matrix_cl that schedules its destructor to be called, so it can be used on the AD stack.
Represents operation that determines column index.
Represents an arithmetic matrix on the OpenCL device.
void check_opencl_error(const char *function, const cl::Error &e)
Throws a domain error specifying the OpenCL error that occurred.
elt_multiply_< as_operation_cl_t< T_a >, as_operation_cl_t< T_b > > elt_multiply(T_a &&a, T_b &&b)
isfinite_< as_operation_cl_t< T > > isfinite(T &&a)
auto check_cl(const char *function, const char *var_name, T &&y, const char *must_be)
Constructs a check on an OpenCL matrix or expression.
results_cl< T_results... > results(T_results &&... results)
Deduces types for constructing results_cl object.
auto transpose(Arg &&a)
Transposes a kernel generator expression.
auto rowwise_broadcast(T &&a)
Broadcast an expression in the rowwise dimension.
auto rowwise_sum(T &&a)
Rowwise sum reduction of a kernel generator expression.
auto indexing(T_mat &&mat, T_row_index &&row_index, T_col_index &&col_index)
Index a kernel generator expression using two expressions for indices.
expressions_cl< T_expressions... > expressions(T_expressions &&... expressions)
Deduces types for constructing expressions_cl object.
const kernel_cl< out_buffer, out_buffer, out_buffer, out_buffer, out_buffer, in_buffer, in_buffer, in_buffer, int, int, int, int, int, int > categorical_logit_glm("categorical_logit_glm", {categorical_logit_glm_kernel_code}, {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}})
See the docs for categorical_logit_glm() .
const kernel_cl< in_out_buffer, in_out_buffer, in_buffer, in_buffer, int, int, int, int > categorical_logit_glm_beta_derivative("categorical_logit_glm_beta_derivative", {categorical_logit_glm_beta_derivative_kernel_code})
See the docs for categorical_logit_glm_beta_derivative() .
matrix_cl< scalar_type_t< T > > to_matrix_cl(T &&src)
Copies the source Eigen matrix, std::vector or scalar to the destination matrix that is stored on the...
auto from_matrix_cl(const T &src)
Copies the source matrix that is stored on the OpenCL device to the destination Eigen matrix.
return_type_t< T_x, T_alpha, T_beta > categorical_logit_glm_lpmf(const T_y &y, const T_x &x, const T_alpha &alpha, const T_beta &beta)
Returns the log PMF of the Generalized Linear Model (GLM) with categorical distribution and logit (so...
require_all_t< is_prim_or_rev_kernel_expression< std::decay_t< Types > >... > require_all_prim_or_rev_kernel_expression_t
Require type satisfies is_prim_or_rev_kernel_expression.
T_actual && forward_as(T_actual &&a)
Assume which type we get.
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
int64_t size(const T &m)
Returns the size (number of the elements) of a matrix_cl or var_value<matrix_cl<T>>.
static constexpr double e()
Return the base of the natural logarithm.
T eval(T &&arg)
Inputs whose plain_type equals their own type are forwarded unmodified (for Eigen expressions ...
T value_of(const fvar< T > &v)
Return the value of the specified variable.
auto sum(const std::vector< T > &m)
Return the sum of the entries of the specified standard vector.
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
fvar< T > beta(const fvar< T > &x1, const fvar< T > &x2)
Return fvar with the beta function applied to the specified arguments and its gradient.
auto make_partials_propagator(Ops &&... ops)
Construct a partials_propagator.
typename partials_return_type< Args... >::type partials_return_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Checks if decayed type is a var, fvar, or arithmetic.
Extends std::true_type when instantiated with zero or more template parameters, all of which extend t...
Template metaprogram to calculate whether a summand needs to be included in a proportional (log) prob...