Automatic Differentiation
 
Loading...
Searching...
No Matches
Eigen_NumTraits.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_CORE_EIGEN_NUMTRAITS_HPP
2#define STAN_MATH_REV_CORE_EIGEN_NUMTRAITS_HPP
3
12#include <limits>
13
14namespace Eigen {
15
23template <>
24struct NumTraits<stan::math::var> : GenericNumTraits<stan::math::var> {
35 static inline Real dummy_precision() {
36 return NumTraits<double>::dummy_precision();
37 }
38
39 static inline Real epsilon() { return NumTraits<double>::epsilon(); }
40
41 static inline Real highest() { return NumTraits<double>::highest(); }
42 static inline Real lowest() { return NumTraits<double>::lowest(); }
43
44 enum {
48 IsComplex = 0,
49
53 IsInteger = 0,
54
58 IsSigned = 1,
59
65 RequireInitialization = 0,
66
70 ReadCost = 2 * NumTraits<double>::ReadCost,
71
76 AddCost = NumTraits<double>::AddCost,
77
82 MulCost = NumTraits<double>::MulCost
83 };
84
90 static int digits10() { return std::numeric_limits<double>::digits10; }
91};
92
100template <typename BinaryOp>
101struct ScalarBinaryOpTraits<stan::math::var, double, BinaryOp> {
103};
104
112template <typename BinaryOp>
113struct ScalarBinaryOpTraits<double, stan::math::var, BinaryOp> {
115};
116
124template <typename BinaryOp>
125struct ScalarBinaryOpTraits<stan::math::var, int, BinaryOp> {
127};
128
136template <typename BinaryOp>
137struct ScalarBinaryOpTraits<int, stan::math::var, BinaryOp> {
139};
140
149template <typename BinaryOp>
150struct ScalarBinaryOpTraits<stan::math::var, stan::math::var, BinaryOp> {
152};
153
161template <typename BinaryOp>
162struct ScalarBinaryOpTraits<double, std::complex<stan::math::var>, BinaryOp> {
163 using ReturnType = std::complex<stan::math::var>;
164};
165
173template <typename BinaryOp>
174struct ScalarBinaryOpTraits<std::complex<stan::math::var>, double, BinaryOp> {
175 using ReturnType = std::complex<stan::math::var>;
176};
177
185template <typename BinaryOp>
186struct ScalarBinaryOpTraits<int, std::complex<stan::math::var>, BinaryOp> {
187 using ReturnType = std::complex<stan::math::var>;
188};
189
197template <typename BinaryOp>
198struct ScalarBinaryOpTraits<std::complex<stan::math::var>, int, BinaryOp> {
199 using ReturnType = std::complex<stan::math::var>;
200};
201
209template <typename BinaryOp>
210struct ScalarBinaryOpTraits<stan::math::var, std::complex<double>, BinaryOp> {
211 using ReturnType = std::complex<stan::math::var>;
212};
213
221template <typename BinaryOp>
222struct ScalarBinaryOpTraits<std::complex<double>, stan::math::var, BinaryOp> {
223 using ReturnType = std::complex<stan::math::var>;
224};
225
233template <typename BinaryOp>
234struct ScalarBinaryOpTraits<std::complex<double>, std::complex<stan::math::var>,
235 BinaryOp> {
236 using ReturnType = std::complex<stan::math::var>;
237};
238
246template <typename BinaryOp>
247struct ScalarBinaryOpTraits<std::complex<stan::math::var>, std::complex<double>,
248 BinaryOp> {
249 using ReturnType = std::complex<stan::math::var>;
250};
251
252template <typename BinaryOp>
253struct ScalarBinaryOpTraits<stan::math::var, std::complex<stan::math::var>,
254 BinaryOp> {
255 using ReturnType = std::complex<stan::math::var>;
256};
257
258template <typename BinaryOp>
259struct ScalarBinaryOpTraits<std::complex<stan::math::var>, stan::math::var,
260 BinaryOp> {
261 using ReturnType = std::complex<stan::math::var>;
262};
263
264template <typename BinaryOp>
265struct ScalarBinaryOpTraits<std::complex<stan::math::var>,
266 std::complex<stan::math::var>, BinaryOp> {
267 using ReturnType = std::complex<stan::math::var>;
268};
269
270namespace internal {
271
/**
 * Evaluator specialization for Eigen .select() expressions
 * (CwiseTernaryOp over scalar_boolean_select_op) whose two branch
 * scalars are stan::math::var, using index-based evaluation for the
 * operands. Per the generated documentation for this evaluator, it
 * evaluates only the branch chosen by the condition. Constructors are
 * inherited from Base (`using Base::Base`).
 */
// NOTE(review): this listing is incomplete. The line that opens the
// base-class clause (listing line 293) and the line that begins the
// `Base` alias (listing line 297) are absent, so the specialization is
// not compilable as shown; restore them from the original header.
287template <typename CondScalar, typename Arg1, typename Arg2, typename Arg3>
288struct ternary_evaluator<
289 CwiseTernaryOp<
290 scalar_boolean_select_op<stan::math::var, stan::math::var, CondScalar>,
291 Arg1, Arg2, Arg3>,
292 IndexBased, IndexBased>
// NOTE(review): base-class clause head (listing line 293) missing here.
294 scalar_boolean_select_op<stan::math::var, stan::math::var,
295 CondScalar>,
296 Arg1, Arg2, Arg3> {
// NOTE(review): `Base` alias head (listing line 297) missing here.
298 scalar_boolean_select_op<stan::math::var, stan::math::var, CondScalar>,
299 Arg1, Arg2, Arg3>;
300 using Base::Base;
301};
302
306template <typename EigVar, typename EigVari, typename EigDbl>
307struct functor_has_linear_access<
308 stan::math::vi_val_adj_functor<EigVar, EigVari, EigDbl>> {
309 enum { ret = 1 };
310};
311
315template <typename EigVar, typename EigDbl>
316struct functor_has_linear_access<stan::math::val_adj_functor<EigVar, EigDbl>> {
317 enum { ret = 1 };
318};
319
323template <typename EigVar, typename EigVari>
324struct functor_has_linear_access<stan::math::vi_val_functor<EigVar, EigVari>> {
325 enum { ret = 1 };
326};
327
331template <typename EigVar, typename EigVari>
332struct functor_has_linear_access<stan::math::vi_adj_functor<EigVar, EigVari>> {
333 enum { ret = 1 };
334};
335
340template <>
341struct remove_all<stan::math::vari*> {
343};
344
356template <typename Index, typename LhsMapper, bool ConjugateLhs,
357 bool ConjugateRhs, typename RhsMapper, int Version>
358struct general_matrix_vector_product<Index, stan::math::var, LhsMapper,
359 ColMajor, ConjugateLhs, stan::math::var,
360 RhsMapper, ConjugateRhs, Version> {
364 enum { LhsStorageOrder = ColMajor };
365
366 EIGEN_DONT_INLINE static void run(Index rows, Index cols,
367 const LhsMapper& lhsMapper,
368 const RhsMapper& rhsMapper, ResScalar* res,
369 Index resIncr, const ResScalar& alpha) {
370 const LhsScalar* lhs = lhsMapper.data();
371 const Index lhsStride = lhsMapper.stride();
372 const RhsScalar* rhs = rhsMapper.data();
373 const Index rhsIncr = rhsMapper.stride();
374 run(rows, cols, lhs, lhsStride, rhs, rhsIncr, res, resIncr, alpha);
375 }
376
377 EIGEN_DONT_INLINE static void run(Index rows, Index cols,
378 const LhsScalar* lhs, Index lhsStride,
379 const RhsScalar* rhs, Index rhsIncr,
380 ResScalar* res, Index resIncr,
381 const ResScalar& alpha) {
383 using stan::math::var;
384 for (Index i = 0; i < rows; ++i) {
385 res[i * resIncr] += var(
386 new gevv_vvv_vari(&alpha, &lhs[i], lhsStride, rhs, rhsIncr, cols));
387 }
388 }
389};
390
391template <typename Index, typename LhsMapper, bool ConjugateLhs,
392 bool ConjugateRhs, typename RhsMapper, int Version>
393struct general_matrix_vector_product<Index, stan::math::var, LhsMapper,
394 RowMajor, ConjugateLhs, stan::math::var,
395 RhsMapper, ConjugateRhs, Version> {
399 enum { LhsStorageOrder = RowMajor };
400
401 EIGEN_DONT_INLINE static void run(Index rows, Index cols,
402 const LhsMapper& lhsMapper,
403 const RhsMapper& rhsMapper, ResScalar* res,
404 Index resIncr, const RhsScalar& alpha) {
405 const LhsScalar* lhs = lhsMapper.data();
406 const Index lhsStride = lhsMapper.stride();
407 const RhsScalar* rhs = rhsMapper.data();
408 const Index rhsIncr = rhsMapper.stride();
409 run(rows, cols, lhs, lhsStride, rhs, rhsIncr, res, resIncr, alpha);
410 }
411
412 EIGEN_DONT_INLINE static void run(Index rows, Index cols,
413 const LhsScalar* lhs, Index lhsStride,
414 const RhsScalar* rhs, Index rhsIncr,
415 ResScalar* res, Index resIncr,
416 const RhsScalar& alpha) {
417 for (Index i = 0; i < rows; i++) {
418 res[i * resIncr] += stan::math::var(new stan::math::gevv_vvv_vari(
419 &alpha,
420 (static_cast<int>(LhsStorageOrder) == static_cast<int>(ColMajor))
421 ? (&lhs[i])
422 : (&lhs[i * lhsStride]),
423 (static_cast<int>(LhsStorageOrder) == static_cast<int>(ColMajor))
424 ? (lhsStride)
425 : (1),
426 rhs, rhsIncr, cols));
427 }
428 }
429};
430
431#if EIGEN_VERSION_AT_LEAST(3, 3, 8)
432template <typename Index, int LhsStorageOrder, bool ConjugateLhs,
433 int RhsStorageOrder, bool ConjugateRhs, int ResInnerStride>
434struct general_matrix_matrix_product<
435 Index, stan::math::var, LhsStorageOrder, ConjugateLhs, stan::math::var,
436 RhsStorageOrder, ConjugateRhs, ColMajor, ResInnerStride> {
437#else
438template <typename Index, int LhsStorageOrder, bool ConjugateLhs,
439 int RhsStorageOrder, bool ConjugateRhs>
440struct general_matrix_matrix_product<Index, stan::math::var, LhsStorageOrder,
441 ConjugateLhs, stan::math::var,
442 RhsStorageOrder, ConjugateRhs, ColMajor> {
443#endif
447
448 using Traits = gebp_traits<RhsScalar, LhsScalar>;
449
451 = const_blas_data_mapper<stan::math::var, Index, LhsStorageOrder>;
453 = const_blas_data_mapper<stan::math::var, Index, RhsStorageOrder>;
454
455 EIGEN_DONT_INLINE
456#if EIGEN_VERSION_AT_LEAST(3, 3, 8)
457 static void run(Index rows, Index cols, Index depth, const LhsScalar* lhs,
458 Index lhsStride, const RhsScalar* rhs, Index rhsStride,
459 ResScalar* res, Index resIncr, Index resStride,
460 const ResScalar& alpha,
461 level3_blocking<LhsScalar, RhsScalar>& /* blocking */,
462 GemmParallelInfo<Index>* /* info = 0 */)
463#else
464 static void run(Index rows, Index cols, Index depth, const LhsScalar* lhs,
465 Index lhsStride, const RhsScalar* rhs, Index rhsStride,
466 ResScalar* res, Index resStride, const ResScalar& alpha,
467 level3_blocking<LhsScalar, RhsScalar>& /* blocking */,
468 GemmParallelInfo<Index>* /* info = 0 */)
469#endif
470 {
471 for (Index i = 0; i < cols; i++) {
472 general_matrix_vector_product<
473 Index, LhsScalar, LhsMapper, LhsStorageOrder, ConjugateLhs, RhsScalar,
474 RhsMapper,
475 ConjugateRhs>::run(rows, depth, lhs, lhsStride,
476 &rhs[static_cast<int>(RhsStorageOrder)
477 == static_cast<int>(ColMajor)
478 ? i * rhsStride
479 : i],
480 static_cast<int>(RhsStorageOrder)
481 == static_cast<int>(ColMajor)
482 ? 1
483 : rhsStride,
484 &res[i * resStride], 1, alpha);
485 }
486 } // namespace internal
487
488 EIGEN_DONT_INLINE
489 static void run(Index rows, Index cols, Index depth,
490 const LhsMapper& lhsMapper, const RhsMapper& rhsMapper,
491 ResScalar* res, Index resStride, const ResScalar& alpha,
492 level3_blocking<LhsScalar, RhsScalar>& blocking,
493 GemmParallelInfo<Index>* info = 0) {
494 const LhsScalar* lhs = lhsMapper.data();
495 const Index lhsStride = lhsMapper.stride();
496 const RhsScalar* rhs = rhsMapper.data();
497 const Index rhsStride = rhsMapper.stride();
498
499 run(rows, cols, depth, lhs, lhsStride, rhs, rhsStride, res, resStride,
500 alpha, blocking, info);
501 }
502};
503} // namespace internal
504} // namespace Eigen
505#endif
Specialization of the standard library complex number type for reverse-mode autodiff type stan::math:...
(Expert) Numerical traits for algorithmic differentiation variables.
var_value< double > var
Definition var.hpp:1187
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
STL namespace.
static Real dummy_precision()
Return the precision for stan::math::var; it delegates to the precision for double.
static int digits10()
Return the number of decimal digits that can be represented without change.
static void run(Index rows, Index cols, Index depth, const LhsScalar *lhs, Index lhsStride, const RhsScalar *rhs, Index rhsStride, ResScalar *res, Index resIncr, Index resStride, const ResScalar &alpha, level3_blocking< LhsScalar, RhsScalar > &, GemmParallelInfo< Index > *)
static void run(Index rows, Index cols, Index depth, const LhsMapper &lhsMapper, const RhsMapper &rhsMapper, ResScalar *res, Index resStride, const ResScalar &alpha, level3_blocking< LhsScalar, RhsScalar > &blocking, GemmParallelInfo< Index > *info=0)
static void run(Index rows, Index cols, const LhsScalar *lhs, Index lhsStride, const RhsScalar *rhs, Index rhsIncr, ResScalar *res, Index resIncr, const RhsScalar &alpha)
static void run(Index rows, Index cols, const LhsMapper &lhsMapper, const RhsMapper &rhsMapper, ResScalar *res, Index resIncr, const RhsScalar &alpha)
static void run(Index rows, Index cols, const LhsMapper &lhsMapper, const RhsMapper &rhsMapper, ResScalar *res, Index resIncr, const ResScalar &alpha)
static void run(Index rows, Index cols, const LhsScalar *lhs, Index lhsStride, const RhsScalar *rhs, Index rhsIncr, ResScalar *res, Index resIncr, const ResScalar &alpha)
Evaluator for .select() expressions that evaluates only the branch chosen by the condition,...