1#ifndef STAN_MATH_REV_CORE_EIGEN_NUMTRAITS_HPP
2#define STAN_MATH_REV_CORE_EIGEN_NUMTRAITS_HPP
35 return NumTraits<double>::dummy_precision();
38 static inline Real epsilon() {
return NumTraits<double>::epsilon(); }
40 static inline Real highest() {
return NumTraits<double>::highest(); }
41 static inline Real lowest() {
return NumTraits<double>::lowest(); }
62 RequireInitialization = 0,
67 ReadCost = 2 * NumTraits<double>::ReadCost,
73 AddCost = NumTraits<double>::AddCost,
79 MulCost = NumTraits<double>::MulCost
87 static int digits10() {
return std::numeric_limits<double>::digits10; }
97template <
typename BinaryOp>
98struct ScalarBinaryOpTraits<
stan::math::var, double, BinaryOp> {
109template <
typename BinaryOp>
110struct ScalarBinaryOpTraits<double,
stan::math::var, BinaryOp> {
121template <
typename BinaryOp>
122struct ScalarBinaryOpTraits<
stan::math::var, int, BinaryOp> {
133template <
typename BinaryOp>
134struct ScalarBinaryOpTraits<int,
stan::math::var, BinaryOp> {
146template <
typename BinaryOp>
158template <
typename BinaryOp>
159struct ScalarBinaryOpTraits<double,
std::complex<stan::math::var>, BinaryOp> {
170template <
typename BinaryOp>
171struct ScalarBinaryOpTraits<
std::complex<stan::math::var>, double, BinaryOp> {
182template <
typename BinaryOp>
183struct ScalarBinaryOpTraits<int,
std::complex<stan::math::var>, BinaryOp> {
194template <
typename BinaryOp>
195struct ScalarBinaryOpTraits<
std::complex<stan::math::var>, int, BinaryOp> {
206template <
typename BinaryOp>
207struct ScalarBinaryOpTraits<
stan::math::var, std::complex<double>, BinaryOp> {
218template <
typename BinaryOp>
230template <
typename BinaryOp>
231struct ScalarBinaryOpTraits<
std::complex<double>, std::complex<stan::math::var>,
243template <
typename BinaryOp>
244struct ScalarBinaryOpTraits<
std::complex<stan::math::var>, std::complex<double>,
249template <
typename BinaryOp>
250struct ScalarBinaryOpTraits<
stan::math::var, std::complex<stan::math::var>,
255template <
typename BinaryOp>
261template <
typename BinaryOp>
262struct ScalarBinaryOpTraits<
std::complex<stan::math::var>,
263 std::complex<stan::math::var>, BinaryOp> {
272template <
typename EigVar,
typename EigVari,
typename EigDbl>
273struct functor_has_linear_access<
274 stan::math::vi_val_adj_functor<EigVar, EigVari, EigDbl>> {
281template <
typename EigVar,
typename EigDbl>
282struct functor_has_linear_access<
stan::math::val_adj_functor<EigVar, EigDbl>> {
289template <
typename EigVar,
typename EigVari>
290struct functor_has_linear_access<
stan::math::vi_val_functor<EigVar, EigVari>> {
297template <
typename EigVar,
typename EigVari>
298struct functor_has_linear_access<
stan::math::vi_adj_functor<EigVar, EigVari>> {
307struct remove_all<
stan::math::vari*> {
322template <
typename Index,
typename LhsMapper,
bool ConjugateLhs,
323 bool ConjugateRhs,
typename RhsMapper,
int Version>
324struct general_matrix_vector_product<Index,
stan::math::var, LhsMapper,
326 RhsMapper, ConjugateRhs, Version> {
330 enum { LhsStorageOrder = ColMajor };
332 EIGEN_DONT_INLINE
static void run(Index rows, Index cols,
333 const LhsMapper& lhsMapper,
334 const RhsMapper& rhsMapper,
ResScalar* res,
337 const Index lhsStride = lhsMapper.stride();
339 const Index rhsIncr = rhsMapper.stride();
340 run(rows, cols, lhs, lhsStride, rhs, rhsIncr, res, resIncr, alpha);
343 EIGEN_DONT_INLINE
static void run(Index rows, Index cols,
350 for (Index i = 0; i < rows; ++i) {
351 res[i * resIncr] += var(
352 new gevv_vvv_vari(&alpha, &lhs[i], lhsStride, rhs, rhsIncr, cols));
357template <
typename Index,
typename LhsMapper,
bool ConjugateLhs,
358 bool ConjugateRhs,
typename RhsMapper,
int Version>
359struct general_matrix_vector_product<Index,
stan::math::var, LhsMapper,
361 RhsMapper, ConjugateRhs, Version> {
365 enum { LhsStorageOrder = RowMajor };
367 EIGEN_DONT_INLINE
static void run(Index rows, Index cols,
368 const LhsMapper& lhsMapper,
369 const RhsMapper& rhsMapper,
ResScalar* res,
372 const Index lhsStride = lhsMapper.stride();
374 const Index rhsIncr = rhsMapper.stride();
375 run(rows, cols, lhs, lhsStride, rhs, rhsIncr, res, resIncr, alpha);
378 EIGEN_DONT_INLINE
static void run(Index rows, Index cols,
383 for (Index i = 0; i < rows; i++) {
386 (
static_cast<int>(LhsStorageOrder) ==
static_cast<int>(ColMajor))
388 : (&lhs[i * lhsStride]),
389 (
static_cast<int>(LhsStorageOrder) ==
static_cast<int>(ColMajor))
392 rhs, rhsIncr, cols));
397#if EIGEN_VERSION_AT_LEAST(3, 3, 8)
398template <
typename Index,
int LhsStorageOrder,
bool ConjugateLhs,
399 int RhsStorageOrder,
bool ConjugateRhs,
int ResInnerStride>
400struct general_matrix_matrix_product<
402 RhsStorageOrder, ConjugateRhs, ColMajor, ResInnerStride> {
404template <
typename Index,
int LhsStorageOrder,
bool ConjugateLhs,
405 int RhsStorageOrder,
bool ConjugateRhs>
406struct general_matrix_matrix_product<Index,
stan::math::var, LhsStorageOrder,
408 RhsStorageOrder, ConjugateRhs, ColMajor> {
414 using Traits = gebp_traits<RhsScalar, LhsScalar>;
417 = const_blas_data_mapper<stan::math::var, Index, LhsStorageOrder>;
419 = const_blas_data_mapper<stan::math::var, Index, RhsStorageOrder>;
422#if EIGEN_VERSION_AT_LEAST(3, 3, 8)
423 static void run(Index rows, Index cols, Index depth,
const LhsScalar* lhs,
424 Index lhsStride,
const RhsScalar* rhs, Index rhsStride,
425 ResScalar* res, Index resIncr, Index resStride,
427 level3_blocking<LhsScalar, RhsScalar>& ,
428 GemmParallelInfo<Index>* )
430 static void run(Index rows, Index cols, Index depth,
const LhsScalar* lhs,
431 Index lhsStride,
const RhsScalar* rhs, Index rhsStride,
433 level3_blocking<LhsScalar, RhsScalar>& ,
434 GemmParallelInfo<Index>* )
437 for (Index i = 0; i < cols; i++) {
438 general_matrix_vector_product<
441 ConjugateRhs>::run(rows, depth, lhs, lhsStride,
442 &rhs[
static_cast<int>(RhsStorageOrder)
443 ==
static_cast<int>(ColMajor)
446 static_cast<int>(RhsStorageOrder)
447 ==
static_cast<int>(ColMajor)
450 &res[i * resStride], 1, alpha);
455 static void run(Index rows, Index cols, Index depth,
458 level3_blocking<LhsScalar, RhsScalar>& blocking,
459 GemmParallelInfo<Index>* info = 0) {
461 const Index lhsStride = lhsMapper.stride();
463 const Index rhsStride = rhsMapper.stride();
465 run(rows, cols, depth, lhs, lhsStride, rhs, rhsStride, res, resStride,
466 alpha, blocking, info);
Specialization of the standard library complex number type for reverse-mode autodiff type stan::math:...
(Expert) Numerical traits for algorithmic differentiation variables.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
static Real dummy_precision()
Return the precision for stan::math::var delegates to precision for double.
static int digits10()
Return the number of decimal digits that can be represented without change.
const_blas_data_mapper< stan::math::var, Index, LhsStorageOrder > LhsMapper
gebp_traits< RhsScalar, LhsScalar > Traits
const_blas_data_mapper< stan::math::var, Index, RhsStorageOrder > RhsMapper
static void run(Index rows, Index cols, Index depth, const LhsScalar *lhs, Index lhsStride, const RhsScalar *rhs, Index rhsStride, ResScalar *res, Index resIncr, Index resStride, const ResScalar &alpha, level3_blocking< LhsScalar, RhsScalar > &, GemmParallelInfo< Index > *)
static void run(Index rows, Index cols, Index depth, const LhsMapper &lhsMapper, const RhsMapper &rhsMapper, ResScalar *res, Index resStride, const ResScalar &alpha, level3_blocking< LhsScalar, RhsScalar > &blocking, GemmParallelInfo< Index > *info=0)
static void run(Index rows, Index cols, const LhsScalar *lhs, Index lhsStride, const RhsScalar *rhs, Index rhsIncr, ResScalar *res, Index resIncr, const RhsScalar &alpha)
static void run(Index rows, Index cols, const LhsMapper &lhsMapper, const RhsMapper &rhsMapper, ResScalar *res, Index resIncr, const RhsScalar &alpha)
static void run(Index rows, Index cols, const LhsMapper &lhsMapper, const RhsMapper &rhsMapper, ResScalar *res, Index resIncr, const ResScalar &alpha)
static void run(Index rows, Index cols, const LhsScalar *lhs, Index lhsStride, const RhsScalar *rhs, Index rhsIncr, ResScalar *res, Index resIncr, const ResScalar &alpha)