Automatic Differentiation
 
Loading...
Searching...
No Matches
softmax.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_PRIM_SOFTMAX_HPP
2#define STAN_MATH_OPENCL_PRIM_SOFTMAX_HPP
3#ifdef STAN_OPENCL
11
12namespace stan {
13namespace math {
14
/**
 * Softmax of a vector-shaped kernel generator expression, evaluated on the
 * OpenCL device.
 *
 * The maximum coefficient of `a` is subtracted before exponentiating, which
 * keeps `exp` from overflowing without changing the softmax result. The
 * exponentiated values are then divided elementwise by their sum.
 *
 * NOTE(review): this listing elides original lines 29-30 and 35. Presumably
 * they declare `theta` and `a_max`, and hold the condition that chooses
 * between the two branches below. Confirm against the real header.
 *
 * @tparam T non-scalar kernel generator expression type
 * @param a input; `check_vector` rejects anything that is not a row or
 * column vector (the exception type is not shown here)
 * @return softmax of `a` as a `matrix_cl<double>`. An empty input yields an
 * empty matrix with the same rows/cols as `a`.
 */
22template <typename T,
23 require_all_kernel_expressions_and_none_scalar_t<T>* = nullptr>
24inline matrix_cl<double> softmax(const T& a) {
25 check_vector("softmax (OpenCL)", "a", a);
26 if (a.size() == 0) {
 // Empty input: nothing to normalize, so return an empty result of the
 // same shape.
27 return matrix_cl<double>(a.rows(), a.cols());
28 }
 // NOTE(review): original lines 29-30 are elided here. They presumably
 // declare `theta` and open the first branch; TODO confirm.
31 matrix_cl<double> a_max = max_2d(a);
 // The max is copied to the host and reduced to a scalar with maxCoeff().
 // It is then subtracted from every element so that exp() cannot overflow.
32 theta = exp(a - from_matrix_cl(a_max).maxCoeff());
33 } else {
34 matrix_cl<double> a_eval;
 // NOTE(review): original line 35 is elided. It presumably declares
 // `a_max` for this branch.
 // Evaluates `a` into a_eval and computes its max in the same
 // results/expressions assignment (presumably one kernel launch), so `a`
 // is not evaluated twice.
36 results(a_eval, a_max) = expressions(a, max_2d(a));
37 theta = exp(a_eval - from_matrix_cl(a_max).maxCoeff());
38 }
 // Normalize so that the entries sum to one.
39 return elt_divide(theta, sum(theta));
40}
41
42} // namespace math
43} // namespace stan
44
45#endif
46#endif
Represents an arithmetic matrix on the OpenCL device.
Definition matrix_cl.hpp:47
results_cl< T_results... > results(T_results &&... results)
Deduces types for constructing a results_cl object.
elt_divide_< as_operation_cl_t< T_a >, as_operation_cl_t< T_b > > elt_divide(T_a &&a, T_b &&b)
expressions_cl< T_expressions... > expressions(T_expressions &&... expressions)
Deduces types for constructing an expressions_cl object.
auto max_2d(T &&a)
Two-dimensional max-reduction of a kernel generator expression.
auto from_matrix_cl(const T &src)
Copies the source matrix that is stored on the OpenCL device to the destination Eigen matrix.
Definition copy.hpp:61
void check_vector(const char *function, const char *name, const Mat &x)
Checks that the input is a row vector, a column vector, or a matrix with a single row or column.
auto softmax(T &&x)
Return the softmax of each vector in a container of fvar values.
Definition softmax.hpp:23
auto sum(const std::vector< T > &m)
Return the sum of the entries of the specified standard vector.
Definition sum.hpp:23
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:15
The lgamma implementation in stan-math is based on either the reentrant-safe lgamma_r implementation ...