Automatic Differentiation
 
Loading...
Searching...
No Matches
Eigen_NumTraits.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_CORE_EIGEN_NUMTRAITS_HPP
2#define STAN_MATH_REV_CORE_EIGEN_NUMTRAITS_HPP
3
11#include <limits>
12
13namespace Eigen {
14
22template <>
23struct NumTraits<stan::math::var> : GenericNumTraits<stan::math::var> {
34 static inline Real dummy_precision() {
35 return NumTraits<double>::dummy_precision();
36 }
37
38 static inline Real epsilon() { return NumTraits<double>::epsilon(); }
39
40 static inline Real highest() { return NumTraits<double>::highest(); }
41 static inline Real lowest() { return NumTraits<double>::lowest(); }
42
43 enum {
47 IsComplex = 0,
48
52 IsInteger = 0,
53
57 IsSigned = 1,
58
62 RequireInitialization = 0,
63
67 ReadCost = 2 * NumTraits<double>::ReadCost,
68
73 AddCost = NumTraits<double>::AddCost,
74
79 MulCost = NumTraits<double>::MulCost
80 };
81
87 static int digits10() { return std::numeric_limits<double>::digits10; }
88};
89
97template <typename BinaryOp>
98struct ScalarBinaryOpTraits<stan::math::var, double, BinaryOp> {
100};
101
109template <typename BinaryOp>
110struct ScalarBinaryOpTraits<double, stan::math::var, BinaryOp> {
112};
113
121template <typename BinaryOp>
122struct ScalarBinaryOpTraits<stan::math::var, int, BinaryOp> {
124};
125
133template <typename BinaryOp>
134struct ScalarBinaryOpTraits<int, stan::math::var, BinaryOp> {
136};
137
146template <typename BinaryOp>
147struct ScalarBinaryOpTraits<stan::math::var, stan::math::var, BinaryOp> {
149};
150
158template <typename BinaryOp>
159struct ScalarBinaryOpTraits<double, std::complex<stan::math::var>, BinaryOp> {
160 using ReturnType = std::complex<stan::math::var>;
161};
162
170template <typename BinaryOp>
171struct ScalarBinaryOpTraits<std::complex<stan::math::var>, double, BinaryOp> {
172 using ReturnType = std::complex<stan::math::var>;
173};
174
182template <typename BinaryOp>
183struct ScalarBinaryOpTraits<int, std::complex<stan::math::var>, BinaryOp> {
184 using ReturnType = std::complex<stan::math::var>;
185};
186
194template <typename BinaryOp>
195struct ScalarBinaryOpTraits<std::complex<stan::math::var>, int, BinaryOp> {
196 using ReturnType = std::complex<stan::math::var>;
197};
198
206template <typename BinaryOp>
207struct ScalarBinaryOpTraits<stan::math::var, std::complex<double>, BinaryOp> {
208 using ReturnType = std::complex<stan::math::var>;
209};
210
218template <typename BinaryOp>
219struct ScalarBinaryOpTraits<std::complex<double>, stan::math::var, BinaryOp> {
220 using ReturnType = std::complex<stan::math::var>;
221};
222
230template <typename BinaryOp>
231struct ScalarBinaryOpTraits<std::complex<double>, std::complex<stan::math::var>,
232 BinaryOp> {
233 using ReturnType = std::complex<stan::math::var>;
234};
235
243template <typename BinaryOp>
244struct ScalarBinaryOpTraits<std::complex<stan::math::var>, std::complex<double>,
245 BinaryOp> {
246 using ReturnType = std::complex<stan::math::var>;
247};
248
249template <typename BinaryOp>
250struct ScalarBinaryOpTraits<stan::math::var, std::complex<stan::math::var>,
251 BinaryOp> {
252 using ReturnType = std::complex<stan::math::var>;
253};
254
255template <typename BinaryOp>
256struct ScalarBinaryOpTraits<std::complex<stan::math::var>, stan::math::var,
257 BinaryOp> {
258 using ReturnType = std::complex<stan::math::var>;
259};
260
261template <typename BinaryOp>
262struct ScalarBinaryOpTraits<std::complex<stan::math::var>,
263 std::complex<stan::math::var>, BinaryOp> {
264 using ReturnType = std::complex<stan::math::var>;
265};
266
267namespace internal {
268
272template <typename EigVar, typename EigVari, typename EigDbl>
273struct functor_has_linear_access<
274 stan::math::vi_val_adj_functor<EigVar, EigVari, EigDbl>> {
275 enum { ret = 1 };
276};
277
281template <typename EigVar, typename EigDbl>
282struct functor_has_linear_access<stan::math::val_adj_functor<EigVar, EigDbl>> {
283 enum { ret = 1 };
284};
285
289template <typename EigVar, typename EigVari>
290struct functor_has_linear_access<stan::math::vi_val_functor<EigVar, EigVari>> {
291 enum { ret = 1 };
292};
293
297template <typename EigVar, typename EigVari>
298struct functor_has_linear_access<stan::math::vi_adj_functor<EigVar, EigVari>> {
299 enum { ret = 1 };
300};
301
306template <>
307struct remove_all<stan::math::vari*> {
309};
310
322template <typename Index, typename LhsMapper, bool ConjugateLhs,
323 bool ConjugateRhs, typename RhsMapper, int Version>
324struct general_matrix_vector_product<Index, stan::math::var, LhsMapper,
325 ColMajor, ConjugateLhs, stan::math::var,
326 RhsMapper, ConjugateRhs, Version> {
330 enum { LhsStorageOrder = ColMajor };
331
332 EIGEN_DONT_INLINE static void run(Index rows, Index cols,
333 const LhsMapper& lhsMapper,
334 const RhsMapper& rhsMapper, ResScalar* res,
335 Index resIncr, const ResScalar& alpha) {
336 const LhsScalar* lhs = lhsMapper.data();
337 const Index lhsStride = lhsMapper.stride();
338 const RhsScalar* rhs = rhsMapper.data();
339 const Index rhsIncr = rhsMapper.stride();
340 run(rows, cols, lhs, lhsStride, rhs, rhsIncr, res, resIncr, alpha);
341 }
342
343 EIGEN_DONT_INLINE static void run(Index rows, Index cols,
344 const LhsScalar* lhs, Index lhsStride,
345 const RhsScalar* rhs, Index rhsIncr,
346 ResScalar* res, Index resIncr,
347 const ResScalar& alpha) {
349 using stan::math::var;
350 for (Index i = 0; i < rows; ++i) {
351 res[i * resIncr] += var(
352 new gevv_vvv_vari(&alpha, &lhs[i], lhsStride, rhs, rhsIncr, cols));
353 }
354 }
355};
356
357template <typename Index, typename LhsMapper, bool ConjugateLhs,
358 bool ConjugateRhs, typename RhsMapper, int Version>
359struct general_matrix_vector_product<Index, stan::math::var, LhsMapper,
360 RowMajor, ConjugateLhs, stan::math::var,
361 RhsMapper, ConjugateRhs, Version> {
365 enum { LhsStorageOrder = RowMajor };
366
367 EIGEN_DONT_INLINE static void run(Index rows, Index cols,
368 const LhsMapper& lhsMapper,
369 const RhsMapper& rhsMapper, ResScalar* res,
370 Index resIncr, const RhsScalar& alpha) {
371 const LhsScalar* lhs = lhsMapper.data();
372 const Index lhsStride = lhsMapper.stride();
373 const RhsScalar* rhs = rhsMapper.data();
374 const Index rhsIncr = rhsMapper.stride();
375 run(rows, cols, lhs, lhsStride, rhs, rhsIncr, res, resIncr, alpha);
376 }
377
378 EIGEN_DONT_INLINE static void run(Index rows, Index cols,
379 const LhsScalar* lhs, Index lhsStride,
380 const RhsScalar* rhs, Index rhsIncr,
381 ResScalar* res, Index resIncr,
382 const RhsScalar& alpha) {
383 for (Index i = 0; i < rows; i++) {
384 res[i * resIncr] += stan::math::var(new stan::math::gevv_vvv_vari(
385 &alpha,
386 (static_cast<int>(LhsStorageOrder) == static_cast<int>(ColMajor))
387 ? (&lhs[i])
388 : (&lhs[i * lhsStride]),
389 (static_cast<int>(LhsStorageOrder) == static_cast<int>(ColMajor))
390 ? (lhsStride)
391 : (1),
392 rhs, rhsIncr, cols));
393 }
394 }
395};
396
397#if EIGEN_VERSION_AT_LEAST(3, 3, 8)
398template <typename Index, int LhsStorageOrder, bool ConjugateLhs,
399 int RhsStorageOrder, bool ConjugateRhs, int ResInnerStride>
400struct general_matrix_matrix_product<
401 Index, stan::math::var, LhsStorageOrder, ConjugateLhs, stan::math::var,
402 RhsStorageOrder, ConjugateRhs, ColMajor, ResInnerStride> {
403#else
404template <typename Index, int LhsStorageOrder, bool ConjugateLhs,
405 int RhsStorageOrder, bool ConjugateRhs>
406struct general_matrix_matrix_product<Index, stan::math::var, LhsStorageOrder,
407 ConjugateLhs, stan::math::var,
408 RhsStorageOrder, ConjugateRhs, ColMajor> {
409#endif
413
414 using Traits = gebp_traits<RhsScalar, LhsScalar>;
415
417 = const_blas_data_mapper<stan::math::var, Index, LhsStorageOrder>;
419 = const_blas_data_mapper<stan::math::var, Index, RhsStorageOrder>;
420
421 EIGEN_DONT_INLINE
422#if EIGEN_VERSION_AT_LEAST(3, 3, 8)
423 static void run(Index rows, Index cols, Index depth, const LhsScalar* lhs,
424 Index lhsStride, const RhsScalar* rhs, Index rhsStride,
425 ResScalar* res, Index resIncr, Index resStride,
426 const ResScalar& alpha,
427 level3_blocking<LhsScalar, RhsScalar>& /* blocking */,
428 GemmParallelInfo<Index>* /* info = 0 */)
429#else
430 static void run(Index rows, Index cols, Index depth, const LhsScalar* lhs,
431 Index lhsStride, const RhsScalar* rhs, Index rhsStride,
432 ResScalar* res, Index resStride, const ResScalar& alpha,
433 level3_blocking<LhsScalar, RhsScalar>& /* blocking */,
434 GemmParallelInfo<Index>* /* info = 0 */)
435#endif
436 {
437 for (Index i = 0; i < cols; i++) {
438 general_matrix_vector_product<
439 Index, LhsScalar, LhsMapper, LhsStorageOrder, ConjugateLhs, RhsScalar,
440 RhsMapper,
441 ConjugateRhs>::run(rows, depth, lhs, lhsStride,
442 &rhs[static_cast<int>(RhsStorageOrder)
443 == static_cast<int>(ColMajor)
444 ? i * rhsStride
445 : i],
446 static_cast<int>(RhsStorageOrder)
447 == static_cast<int>(ColMajor)
448 ? 1
449 : rhsStride,
450 &res[i * resStride], 1, alpha);
451 }
452 } // namespace internal
453
454 EIGEN_DONT_INLINE
455 static void run(Index rows, Index cols, Index depth,
456 const LhsMapper& lhsMapper, const RhsMapper& rhsMapper,
457 ResScalar* res, Index resStride, const ResScalar& alpha,
458 level3_blocking<LhsScalar, RhsScalar>& blocking,
459 GemmParallelInfo<Index>* info = 0) {
460 const LhsScalar* lhs = lhsMapper.data();
461 const Index lhsStride = lhsMapper.stride();
462 const RhsScalar* rhs = rhsMapper.data();
463 const Index rhsStride = rhsMapper.stride();
464
465 run(rows, cols, depth, lhs, lhsStride, rhs, rhsStride, res, resStride,
466 alpha, blocking, info);
467 }
468};
469} // namespace internal
470} // namespace Eigen
471#endif
Specialization of the standard library complex number type for reverse-mode autodiff type stan::math:...
(Expert) Numerical traits for algorithmic differentiation variables.
var_value< double > var
Definition var.hpp:1187
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
STL namespace.
static Real dummy_precision()
Return the precision for stan::math::var; delegates to the precision for double.
static int digits10()
Return the number of decimal digits that can be represented without change.
static void run(Index rows, Index cols, Index depth, const LhsScalar *lhs, Index lhsStride, const RhsScalar *rhs, Index rhsStride, ResScalar *res, Index resIncr, Index resStride, const ResScalar &alpha, level3_blocking< LhsScalar, RhsScalar > &, GemmParallelInfo< Index > *)
static void run(Index rows, Index cols, Index depth, const LhsMapper &lhsMapper, const RhsMapper &rhsMapper, ResScalar *res, Index resStride, const ResScalar &alpha, level3_blocking< LhsScalar, RhsScalar > &blocking, GemmParallelInfo< Index > *info=0)
static void run(Index rows, Index cols, const LhsScalar *lhs, Index lhsStride, const RhsScalar *rhs, Index rhsIncr, ResScalar *res, Index resIncr, const RhsScalar &alpha)
static void run(Index rows, Index cols, const LhsMapper &lhsMapper, const RhsMapper &rhsMapper, ResScalar *res, Index resIncr, const RhsScalar &alpha)
static void run(Index rows, Index cols, const LhsMapper &lhsMapper, const RhsMapper &rhsMapper, ResScalar *res, Index resIncr, const ResScalar &alpha)
static void run(Index rows, Index cols, const LhsScalar *lhs, Index lhsStride, const RhsScalar *rhs, Index rhsIncr, ResScalar *res, Index resIncr, const ResScalar &alpha)