Automatic Differentiation
 
Loading...
Searching...
No Matches
categorical_logit_glm_lpmf.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_PROB_CATEGORICAL_LOGIT_GLM_LPMF_HPP
2#define STAN_MATH_PRIM_PROB_CATEGORICAL_LOGIT_GLM_LPMF_HPP
3
18#include <cmath>
19
20namespace stan {
21namespace math {
22
43template <bool propto, typename T_y, typename T_x, typename T_alpha,
44 typename T_beta, require_matrix_t<T_x>* = nullptr,
45 require_col_vector_t<T_alpha>* = nullptr,
46 require_matrix_t<T_beta>* = nullptr>
48 const T_y& y, const T_x& x, const T_alpha& alpha, const T_beta& beta) {
49 using T_partials_return = partials_return_t<T_x, T_alpha, T_beta>;
50 using Eigen::Array;
51 using Eigen::Dynamic;
52 using Eigen::Matrix;
53 using std::exp;
54 using std::isfinite;
55 using std::log;
56 using T_y_ref = ref_type_t<T_y>;
57 using T_x_ref = ref_type_if_not_constant_t<T_x>;
58 using T_alpha_ref = ref_type_if_not_constant_t<T_alpha>;
59 using T_beta_ref = ref_type_if_not_constant_t<T_beta>;
60 using T_beta_partials = partials_type_t<scalar_type_t<T_beta>>;
61 constexpr int T_x_rows = T_x::RowsAtCompileTime;
62
63 const size_t N_instances = T_x_rows == 1 ? stan::math::size(y) : x.rows();
64 const size_t N_attributes = x.cols();
65 const size_t N_classes = beta.cols();
66
67 static constexpr const char* function = "categorical_logit_glm_lpmf";
68 check_consistent_size(function, "Vector of dependent variables", y,
69 N_instances);
70 check_consistent_size(function, "Intercept vector", alpha, N_classes);
71 check_size_match(function, "x.cols()", N_attributes, "beta.rows()",
72 beta.rows());
73 if (size_zero(y) || N_classes == 1) {
74 return 0;
75 }
76 T_y_ref y_ref = y;
77 check_bounded(function, "categorical outcome out of support", y_ref, 1,
78 N_classes);
79
81 return 0;
82 }
83
84 T_x_ref x_ref = x;
85 T_alpha_ref alpha_ref = alpha;
86 T_beta_ref beta_ref = beta;
87
88 const auto& x_val = to_ref_if<!is_constant<T_beta>::value>(value_of(x_ref));
89 const auto& alpha_val = value_of(alpha_ref);
90 const auto& beta_val
91 = to_ref_if<!is_constant<T_x>::value>(value_of(beta_ref));
92
93 const auto& alpha_val_vec = as_column_vector_or_scalar(alpha_val).transpose();
94
95 Array<T_partials_return, T_x_rows, Dynamic> lin
96 = (x_val * beta_val).rowwise() + alpha_val_vec;
97 Array<T_partials_return, T_x_rows, 1> lin_max
98 = lin.rowwise().maxCoeff(); // This is used to prevent overflow when
99 // calculating softmax/log_sum_exp and
100 // similar expressions
101 Array<T_partials_return, T_x_rows, Dynamic> exp_lin
102 = exp(lin.colwise() - lin_max);
103 Array<T_partials_return, T_x_rows, 1> inv_sum_exp_lin
104 = 1 / exp_lin.rowwise().sum();
105
106 T_partials_return logp = log(inv_sum_exp_lin).sum() - lin_max.sum();
107 if (T_x_rows == 1) {
108 logp *= N_instances;
109 }
110 scalar_seq_view<T_y_ref> y_seq(y_ref);
111 for (int i = 0; i < N_instances; i++) {
112 if (T_x_rows == 1) {
113 logp += lin(0, y_seq[i] - 1);
114 } else {
115 logp += lin(i, y_seq[i] - 1);
116 }
117 }
118 // TODO(Tadej) maybe we can replace previous block with the following line
119 // when we have newer Eigen T_partials_return logp =
120 // lin(Eigen::all,y-1).sum() + log(inv_sum_exp_lin).sum() - lin_max.sum();
121
122 if (!isfinite(logp)) {
123 check_finite(function, "Weight vector", beta_ref);
124 check_finite(function, "Intercept", alpha_ref);
125 check_finite(function, "Matrix of independent variables", x_ref);
126 }
127
128 // Compute the derivatives.
129 auto ops_partials = make_partials_propagator(x_ref, alpha_ref, beta_ref);
130
132 if (T_x_rows == 1) {
133 Array<T_beta_partials, 1, Dynamic> beta_y = beta_val.col(y_seq[0] - 1);
134 for (int i = 1; i < N_instances; i++) {
135 beta_y += beta_val.col(y_seq[i] - 1).array();
136 }
137 edge<0>(ops_partials).partials_
138 = beta_y
139 - (exp_lin.matrix() * beta_val.transpose()).array().colwise()
140 * inv_sum_exp_lin * N_instances;
141 } else {
142 Array<T_beta_partials, Dynamic, Dynamic> beta_y(N_instances,
143 N_attributes);
144 for (int i = 0; i < N_instances; i++) {
145 beta_y.row(i) = beta_val.col(y_seq[i] - 1);
146 }
147 edge<0>(ops_partials).partials_
148 = beta_y
149 - (exp_lin.matrix() * beta_val.transpose()).array().colwise()
150 * inv_sum_exp_lin;
151 // TODO(Tadej) maybe we can replace previous block with the following
152 // line when we have newer Eigen partials<0>(ops_partials) = beta_val(y
153 // - 1, all) - (exp_lin.matrix() * beta.transpose()).colwise() *
154 // inv_sum_exp_lin;
155 }
156 }
158 Array<T_partials_return, T_x_rows, Dynamic> neg_softmax_lin
159 = exp_lin.colwise() * -inv_sum_exp_lin;
161 if (T_x_rows == 1) {
162 edge<1>(ops_partials).partials_
163 = neg_softmax_lin.colwise().sum() * N_instances;
164 } else {
165 partials<1>(ops_partials) = neg_softmax_lin.colwise().sum();
166 }
167 for (int i = 0; i < N_instances; i++) {
168 partials<1>(ops_partials)[y_seq[i] - 1] += 1;
169 }
170 }
172 Matrix<T_partials_return, Dynamic, Dynamic> beta_derivative
173 = x_val.transpose().template cast<T_partials_return>()
174 * neg_softmax_lin.matrix();
175 if (T_x_rows == 1) {
176 beta_derivative *= N_instances;
177 }
178
179 for (int i = 0; i < N_instances; i++) {
180 if (T_x_rows == 1) {
181 beta_derivative.col(y_seq[i] - 1) += x_val;
182 } else {
183 beta_derivative.col(y_seq[i] - 1) += x_val.row(i);
184 }
185 }
186 // TODO(Tadej) maybe we can replace previous loop with the following
187 // line when we have newer Eigen partials<2>(ops_partials)(Eigen::all,
188 // y
189 // - 1) += x_val.colwise.sum().transpose();
190
191 partials<2>(ops_partials) = std::move(beta_derivative);
192 }
193 }
194 return ops_partials.build(logp);
195}
196
197template <typename T_y, typename T_x, typename T_alpha, typename T_beta>
199 const T_y& y, const T_x& x, const T_alpha& alpha, const T_beta& beta) {
200 return categorical_logit_glm_lpmf<false>(y, x, alpha, beta);
201}
202
203} // namespace math
204} // namespace stan
205
206#endif
scalar_seq_view provides a uniform sequence-like wrapper around either a scalar or a sequence of scal...
isfinite_< as_operation_cl_t< T > > isfinite(T &&a)
auto as_column_vector_or_scalar(T &&a)
as_column_vector_or_scalar of a kernel generator expression.
return_type_t< T_x, T_alpha, T_beta > categorical_logit_glm_lpmf(const T_y &y, const T_x &x, const T_alpha &alpha, const T_beta &beta)
Returns the log PMF of the Generalized Linear Model (GLM) with categorical distribution and logit (so...
typename partials_type< T >::type partials_type_t
Helper alias for accessing the partial type.
size_t size(const T &m)
Returns the size (number of elements) of a matrix_cl or var_value<matrix_cl<T>>.
Definition size.hpp:18
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
bool size_zero(const T &x)
Returns true if the input has length 0, false otherwise.
Definition size_zero.hpp:19
void check_bounded(const char *function, const char *name, const T_y &y, const T_low &low, const T_high &high)
Check if the value is between the low and high values, inclusively.
void check_consistent_size(const char *function, const char *name, const T &x, size_t expected_size)
Check if x is consistent with size expected_size.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
fvar< T > log(const fvar< T > &x)
Definition log.hpp:15
void check_finite(const char *function, const char *name, const T_y &y)
Check if all values in y are finite.
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
fvar< T > beta(const fvar< T > &x1, const fvar< T > &x2)
Return fvar with the beta function applied to the specified arguments and its gradient.
Definition beta.hpp:51
auto make_partials_propagator(Ops &&... ops)
Construct a partials_propagator.
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:13
typename ref_type_if<!is_constant< T >::value, T >::type ref_type_if_not_constant_t
Definition ref_type.hpp:62
typename ref_type_if< true, T >::type ref_type_t
Definition ref_type.hpp:55
typename partials_return_type< Args... >::type partials_return_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
Extends std::true_type when instantiated with zero or more template parameters, all of which extend t...
Template metaprogram to calculate whether a summand needs to be included in a proportional (log) prob...