1#ifndef STAN_MATH_PRIM_PROB_CATEGORICAL_LOGIT_GLM_LPMF_HPP
2#define STAN_MATH_PRIM_PROB_CATEGORICAL_LOGIT_GLM_LPMF_HPP
// NOTE(review): this listing appears to be an extracted excerpt. Each
// statement carries a leading number (apparently the original line number)
// and those numbers skip, so several lines are not visible here, including
// the return-type/function-name line, the declarations of y_ref, x_ref,
// y_seq, N_instances, alpha_val_vec and ops_partials, the branch headers and
// the closing braces. The comments below describe only what is shown.
//
// Log probability mass of a categorical GLM: the linear predictor
// x * beta + alpha is turned into per-row class log-probabilities and the
// entries selected by the outcomes y are summed. Partial derivatives are
// stored in ops_partials and the result is produced by
// ops_partials.build(logp).
//
// Template parameters:
//   propto  - boolean flag; its use is not visible in this excerpt
//             (presumably selects dropping of constant terms - TODO confirm).
//   T_y     - type of the outcomes.
//   T_x     - type of the design matrix; must satisfy require_matrix_t.
//   T_alpha - type of the intercepts; must satisfy require_col_vector_t.
//   T_beta  - type of the weights; must satisfy require_matrix_t.
//
// Parameters:
//   y     - categorical outcomes; used 1-based (indexed as y_seq[i] - 1).
//   x     - matrix of independent variables; x.cols() is N_attributes.
//   alpha - intercept vector.
//   beta  - weight matrix; beta.cols() is N_classes.
43template <
bool propto,
typename T_y,
typename T_x,
typename T_alpha,
44 typename T_beta, require_matrix_t<T_x>* =
nullptr,
45 require_col_vector_t<T_alpha>* =
nullptr,
46 require_matrix_t<T_beta>* =
nullptr>
48 const T_y& y,
const T_x& x,
const T_alpha& alpha,
const T_beta&
beta) {
// Compile-time row count of x; reused below to size the work arrays.
61 constexpr int T_x_rows = T_x::RowsAtCompileTime;
// Problem dimensions read from the column counts of x and beta.
64 const size_t N_attributes = x.cols();
65 const size_t N_classes =
beta.cols();
// Function name passed to the argument checks.
67 static constexpr const char* function =
"categorical_logit_glm_lpmf";
// Outcomes must lie in a bounded range whose lower bound is 1; the upper
// bound argument is on a line not shown (presumably N_classes - TODO confirm).
77 check_bounded(function,
"categorical outcome out of support", y_ref, 1,
// Local reference-typed copies of alpha and beta.
85 T_alpha_ref alpha_ref = alpha;
86 T_beta_ref beta_ref =
beta;
// Plain values of the inputs. x_val is passed through to_ref_if when beta is
// not constant, since x_val is then reused in the beta derivative.
88 const auto& x_val = to_ref_if<!is_constant<T_beta>::value>(
value_of(x_ref));
89 const auto& alpha_val =
value_of(alpha_ref);
// Left-hand side is on a line not shown (presumably beta_val, which is used
// below). Passed through to_ref_if when x is not constant.
91 = to_ref_if<!is_constant<T_x>::value>(
value_of(beta_ref));
// Linear predictor: x * beta with alpha_val_vec added to every row.
95 Array<T_partials_return, T_x_rows, Dynamic> lin
96 = (x_val * beta_val).rowwise() + alpha_val_vec;
// Per-row maximum of the predictor, subtracted before exponentiating
// (presumably to keep exp from overflowing).
97 Array<T_partials_return, T_x_rows, 1> lin_max
98 = lin.rowwise().maxCoeff();
// Shifted exponentials and the reciprocal of their per-row sums.
101 Array<T_partials_return, T_x_rows, Dynamic> exp_lin
102 =
exp(lin.colwise() - lin_max);
103 Array<T_partials_return, T_x_rows, 1> inv_sum_exp_lin
104 = 1 / exp_lin.rowwise().sum();
// Normalising part of the log probability: sum of log(1 / row sum) minus the
// row maxima that were subtracted above.
106 T_partials_return logp =
log(inv_sum_exp_lin).sum() - lin_max.sum();
// Add the predictor entry of the observed class for each instance. Two
// variants are visible, indexing row 0 and row i; the condition choosing
// between them is not shown (presumably single-row x versus one row per
// instance - TODO confirm).
111 for (
int i = 0; i < N_instances; i++) {
113 logp += lin(0, y_seq[i] - 1);
115 logp += lin(i, y_seq[i] - 1);
// x must contain only finite values.
125 check_finite(function,
"Matrix of independent variables", x_ref);
// Partials for operand 0 (presumably x - TODO confirm), first visible
// variant: beta_y accumulates the beta columns of the observed classes over
// all instances into a single row, and the softmax-weighted term is scaled by
// N_instances. The leading term of the assignment is on a line not shown.
133 Array<T_beta_partials, 1, Dynamic> beta_y = beta_val.col(y_seq[0] - 1);
134 for (
int i = 1; i < N_instances; i++) {
135 beta_y += beta_val.col(y_seq[i] - 1).array();
137 edge<0>(ops_partials).partials_
139 - (exp_lin.matrix() * beta_val.transpose()).array().colwise()
140 * inv_sum_exp_lin * N_instances;
// Second visible variant: beta_y holds one row per instance, each set to the
// beta column of that instance's observed class. The remaining constructor
// argument and the tail of the assignment are on lines not shown.
142 Array<T_beta_partials, Dynamic, Dynamic> beta_y(N_instances,
144 for (
int i = 0; i < N_instances; i++) {
145 beta_y.row(i) = beta_val.col(y_seq[i] - 1);
147 edge<0>(ops_partials).partials_
149 - (exp_lin.matrix() * beta_val.transpose()).array().colwise()
// Negated normalised exponentials (exp_lin times -inv_sum_exp_lin per row),
// shared by the operand 1 and operand 2 partials below.
158 Array<T_partials_return, T_x_rows, Dynamic> neg_softmax_lin
159 = exp_lin.colwise() * -inv_sum_exp_lin;
// Partials for operand 1 (presumably alpha - TODO confirm): column sums of
// neg_softmax_lin, scaled by N_instances in the first variant and unscaled in
// the second; then 1 is added at the observed class of each instance.
162 edge<1>(ops_partials).partials_
163 = neg_softmax_lin.colwise().sum() * N_instances;
165 partials<1>(ops_partials) = neg_softmax_lin.colwise().sum();
167 for (
int i = 0; i < N_instances; i++) {
168 partials<1>(ops_partials)[y_seq[i] - 1] += 1;
// Partials for operand 2 (presumably beta - TODO confirm): x transposed times
// neg_softmax_lin, with x_val cast to T_partials_return first.
172 Matrix<T_partials_return, Dynamic, Dynamic> beta_derivative
173 = x_val.transpose().template cast<T_partials_return>()
174 * neg_softmax_lin.matrix();
// Scaling by N_instances; the condition guarding it is not shown.
176 beta_derivative *= N_instances;
// Add x (whole x_val in one variant, row i in the other) to the column of the
// observed class for each instance.
179 for (
int i = 0; i < N_instances; i++) {
181 beta_derivative.col(y_seq[i] - 1) += x_val;
183 beta_derivative.col(y_seq[i] - 1) += x_val.row(i);
// Hand the finished matrix to the propagator without copying.
191 partials<2>(ops_partials) = std::move(beta_derivative);
// Combine the value with the stored partials into the return value.
194 return ops_partials.build(logp);
// Overload without the propto template parameter: forwards all four
// arguments unchanged to categorical_logit_glm_lpmf<false>, i.e. with
// propto fixed to false, and returns its result.
//
// NOTE(review): the return-type/function-name line and the closing brace are
// not visible in this excerpt (the leading numbers skip from 197 to 199).
//
// Parameters:
//   y     - outcomes, passed through as-is.
//   x     - design matrix, passed through as-is.
//   alpha - intercepts, passed through as-is.
//   beta  - weights, passed through as-is.
197template <
typename T_y,
typename T_x,
typename T_alpha,
typename T_beta>
199 const T_y& y,
const T_x& x,
const T_alpha& alpha,
const T_beta&
beta) {
200 return categorical_logit_glm_lpmf<false>(y, x, alpha,
beta);
scalar_seq_view provides a uniform sequence-like wrapper around either a scalar or a sequence of scalars.
isfinite_< as_operation_cl_t< T > > isfinite(T &&a)
auto as_column_vector_or_scalar(T &&a)
as_column_vector_or_scalar of a kernel generator expression.
return_type_t< T_x, T_alpha, T_beta > categorical_logit_glm_lpmf(const T_y &y, const T_x &x, const T_alpha &alpha, const T_beta &beta)
Returns the log PMF of the Generalized Linear Model (GLM) with categorical distribution and logit (softmax) link function.
typename partials_type< T >::type partials_type_t
Helper alias for accessing the partial type.
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
int64_t size(const T &m)
Returns the size (number of the elements) of a matrix_cl or var_value<matrix_cl<T>>.
bool size_zero(const T &x)
Returns 1 if input is of length 0, returns 0 otherwise.
void check_bounded(const char *function, const char *name, const T_y &y, const T_low &low, const T_high &high)
Check if the value is between the low and high values, inclusively.
void check_consistent_size(const char *function, const char *name, const T &x, size_t expected_size)
Check if x is consistent with size expected_size.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
fvar< T > log(const fvar< T > &x)
void check_finite(const char *function, const char *name, const T_y &y)
Check that all values in y are finite.
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
fvar< T > beta(const fvar< T > &x1, const fvar< T > &x2)
Return fvar with the beta function applied to the specified arguments and its gradient.
auto make_partials_propagator(Ops &&... ops)
Construct a partials_propagator.
fvar< T > exp(const fvar< T > &x)
typename ref_type_if<!is_constant< T >::value, T >::type ref_type_if_not_constant_t
typename ref_type_if< true, T >::type ref_type_t
typename partials_return_type< Args... >::type partials_return_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation from C or the boost::math::lgamma implementation.
Extends std::true_type when instantiated with zero or more template parameters, all of which extend the std::true_type.
Template metaprogram to calculate whether a summand needs to be included in a proportional (log) probability calculation.