Automatic Differentiation
 
Loading...
Searching...
No Matches
categorical_logit_glm_lpmf.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_PROB_CATEGORICAL_LOGIT_GLM_LPMF_HPP
2#define STAN_MATH_PRIM_PROB_CATEGORICAL_LOGIT_GLM_LPMF_HPP
3
18#include <cmath>
19
20namespace stan {
21namespace math {
22
43template <bool propto, typename T_y, typename T_x, typename T_alpha,
44 typename T_beta, require_matrix_t<T_x>* = nullptr,
45 require_col_vector_t<T_alpha>* = nullptr,
46 require_matrix_t<T_beta>* = nullptr>
48 const T_y& y, const T_x& x, const T_alpha& alpha, const T_beta& beta) {
49 using T_partials_return = partials_return_t<T_x, T_alpha, T_beta>;
50 using Eigen::Array;
51 using Eigen::Dynamic;
52 using Eigen::Matrix;
53 using std::exp;
54 using std::isfinite;
55 using std::log;
56 using T_y_ref = ref_type_t<T_y>;
57 using T_x_ref = ref_type_if_not_constant_t<T_x>;
58 using T_alpha_ref = ref_type_if_not_constant_t<T_alpha>;
59 using T_beta_ref = ref_type_if_not_constant_t<T_beta>;
60 using T_beta_partials = partials_type_t<scalar_type_t<T_beta>>;
61 constexpr int T_x_rows = T_x::RowsAtCompileTime;
62
63 const size_t N_instances = T_x_rows == 1 ? stan::math::size(y) : x.rows();
64 const size_t N_attributes = x.cols();
65 const size_t N_classes = beta.cols();
66
67 static constexpr const char* function = "categorical_logit_glm_lpmf";
68 check_consistent_size(function, "Vector of dependent variables", y,
69 N_instances);
70 check_consistent_size(function, "Intercept vector", alpha, N_classes);
71 check_size_match(function, "x.cols()", N_attributes, "beta.rows()",
72 beta.rows());
73 if (size_zero(y) || N_classes == 1) {
74 return 0;
75 }
76 T_y_ref y_ref = y;
77 check_bounded(function, "categorical outcome out of support", y_ref, 1,
78 N_classes);
79
81 return 0;
82 }
83
84 T_x_ref x_ref = x;
85 T_alpha_ref alpha_ref = alpha;
86 T_beta_ref beta_ref = beta;
87
88 const auto& x_val = to_ref_if<is_autodiff_v<T_beta>>(value_of(x_ref));
89 const auto& alpha_val = value_of(alpha_ref);
90 const auto& beta_val = to_ref_if<is_autodiff_v<T_x>>(value_of(beta_ref));
91
92 const auto& alpha_val_vec = as_column_vector_or_scalar(alpha_val).transpose();
93
94 Array<T_partials_return, T_x_rows, Dynamic> lin
95 = (x_val * beta_val).rowwise() + alpha_val_vec;
96 Array<T_partials_return, T_x_rows, 1> lin_max
97 = lin.rowwise().maxCoeff(); // This is used to prevent overflow when
98 // calculating softmax/log_sum_exp and
99 // similar expressions
100 Array<T_partials_return, T_x_rows, Dynamic> exp_lin
101 = exp(lin.colwise() - lin_max);
102 Array<T_partials_return, T_x_rows, 1> inv_sum_exp_lin
103 = 1 / exp_lin.rowwise().sum();
104
105 T_partials_return logp = log(inv_sum_exp_lin).sum() - lin_max.sum();
106 if constexpr (T_x_rows == 1) {
107 logp *= N_instances;
108 }
109 scalar_seq_view<T_y_ref> y_seq(y_ref);
110 for (int i = 0; i < N_instances; i++) {
111 if constexpr (T_x_rows == 1) {
112 logp += lin(0, y_seq[i] - 1);
113 } else {
114 logp += lin(i, y_seq[i] - 1);
115 }
116 }
117 // TODO(Tadej) maybe we can replace previous block with the following line
118 // when we have newer Eigen T_partials_return logp =
119 // lin(Eigen::all,y-1).sum() + log(inv_sum_exp_lin).sum() - lin_max.sum();
120
121 if (!isfinite(logp)) {
122 check_finite(function, "Weight vector", beta_ref);
123 check_finite(function, "Intercept", alpha_ref);
124 check_finite(function, "Matrix of independent variables", x_ref);
125 }
126
127 // Compute the derivatives.
128 auto ops_partials = make_partials_propagator(x_ref, alpha_ref, beta_ref);
129
130 if constexpr (is_autodiff_v<T_x>) {
131 if constexpr (T_x_rows == 1) {
132 Array<T_beta_partials, 1, Dynamic> beta_y = beta_val.col(y_seq[0] - 1);
133 for (int i = 1; i < N_instances; i++) {
134 beta_y += beta_val.col(y_seq[i] - 1).array();
135 }
136 edge<0>(ops_partials).partials_
137 = beta_y
138 - (exp_lin.matrix() * beta_val.transpose()).array().colwise()
139 * inv_sum_exp_lin * N_instances;
140 } else {
141 Array<T_beta_partials, Dynamic, Dynamic> beta_y(N_instances,
142 N_attributes);
143 for (int i = 0; i < N_instances; i++) {
144 beta_y.row(i) = beta_val.col(y_seq[i] - 1);
145 }
146 edge<0>(ops_partials).partials_
147 = beta_y
148 - (exp_lin.matrix() * beta_val.transpose()).array().colwise()
149 * inv_sum_exp_lin;
150 // TODO(Tadej) maybe we can replace previous block with the following
151 // line when we have newer Eigen partials<0>(ops_partials) = beta_val(y
152 // - 1, all) - (exp_lin.matrix() * beta.transpose()).colwise() *
153 // inv_sum_exp_lin;
154 }
155 }
156 if constexpr (is_any_autodiff_v<T_alpha, T_beta>) {
157 Array<T_partials_return, T_x_rows, Dynamic> neg_softmax_lin
158 = exp_lin.colwise() * -inv_sum_exp_lin;
159 if constexpr (is_autodiff_v<T_alpha>) {
160 if constexpr (T_x_rows == 1) {
161 edge<1>(ops_partials).partials_
162 = neg_softmax_lin.colwise().sum() * N_instances;
163 } else {
164 partials<1>(ops_partials) = neg_softmax_lin.colwise().sum();
165 }
166 for (int i = 0; i < N_instances; i++) {
167 partials<1>(ops_partials)[y_seq[i] - 1] += 1;
168 }
169 }
170 if constexpr (is_autodiff_v<T_beta>) {
171 Matrix<T_partials_return, Dynamic, Dynamic> beta_derivative
172 = x_val.transpose().template cast<T_partials_return>()
173 * neg_softmax_lin.matrix();
174 if constexpr (T_x_rows == 1) {
175 beta_derivative *= N_instances;
176 }
177
178 for (int i = 0; i < N_instances; i++) {
179 if constexpr (T_x_rows == 1) {
180 beta_derivative.col(y_seq[i] - 1) += x_val;
181 } else {
182 beta_derivative.col(y_seq[i] - 1) += x_val.row(i);
183 }
184 }
185 // TODO(Tadej) maybe we can replace previous loop with the following
186 // line when we have newer Eigen partials<2>(ops_partials)(Eigen::all,
187 // y
188 // - 1) += x_val.colwise.sum().transpose();
189
190 partials<2>(ops_partials) = std::move(beta_derivative);
191 }
192 }
193 return ops_partials.build(logp);
194}
195
196template <typename T_y, typename T_x, typename T_alpha, typename T_beta>
198 const T_y& y, const T_x& x, const T_alpha& alpha, const T_beta& beta) {
199 return categorical_logit_glm_lpmf<false>(y, x, alpha, beta);
200}
201
202} // namespace math
203} // namespace stan
204
205#endif
scalar_seq_view provides a uniform sequence-like wrapper around either a scalar or a sequence of scal...
isfinite_< as_operation_cl_t< T > > isfinite(T &&a)
auto as_column_vector_or_scalar(T &&a)
as_column_vector_or_scalar of a kernel generator expression.
return_type_t< T_x, T_alpha, T_beta > categorical_logit_glm_lpmf(const T_y &y, const T_x &x, const T_alpha &alpha, const T_beta &beta)
Returns the log PMF of the Generalized Linear Model (GLM) with categorical distribution and logit (so...
typename partials_type< T >::type partials_type_t
Helper alias for accessing the partial type.
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
int64_t size(const T &m)
Returns the size (number of the elements) of a matrix_cl or var_value<matrix_cl<T>>.
Definition size.hpp:19
bool size_zero(const T &x)
Returns true if the input has length 0, false otherwise.
Definition size_zero.hpp:19
void check_bounded(const char *function, const char *name, const T_y &y, const T_low &low, const T_high &high)
Check if the value is between the low and high values, inclusively.
void check_consistent_size(const char *function, const char *name, const T &x, size_t expected_size)
Check if x is consistent with size expected_size.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
fvar< T > log(const fvar< T > &x)
Definition log.hpp:18
void check_finite(const char *function, const char *name, const T_y &y)
Check that all values in y are finite; throws a domain error otherwise.
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
fvar< T > beta(const fvar< T > &x1, const fvar< T > &x2)
Return fvar with the beta function applied to the specified arguments and its gradient.
Definition beta.hpp:51
auto make_partials_propagator(Ops &&... ops)
Construct an partials_propagator.
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:15
typename ref_type_if< is_autodiff_v< T >, T >::type ref_type_if_not_constant_t
Definition ref_type.hpp:63
typename ref_type_if< true, T >::type ref_type_t
Definition ref_type.hpp:56
typename partials_return_type< Args... >::type partials_return_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Template metaprogram to calculate whether a summand needs to be included in a proportional (log) prob...