Automatic Differentiation
 
Loading...
Searching...
No Matches
sqrt.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_SQRT_HPP
2#define STAN_MATH_REV_FUN_SQRT_HPP
3
11#include <cmath>
12#include <complex>
13
14namespace stan {
15namespace math {
16
45inline var sqrt(const var& a) {
46 return make_callback_var(std::sqrt(a.val()), [a](auto& vi) mutable {
47 if (vi.val() != 0.0) {
48 a.adj() += vi.adj() / (2.0 * vi.val());
49 }
50 });
51}
52
60template <typename T, require_var_matrix_t<T>* = nullptr>
61inline auto sqrt(const T& a) {
62 return make_callback_var(
63 a.val().array().sqrt().matrix(), [a](auto& vi) mutable {
64 a.adj().array()
65 += (vi.val_op().array() == 0.0)
66 .select(0.0, vi.adj().array() / (2.0 * vi.val_op().array()));
67 });
68}
69
76inline std::complex<var> sqrt(const std::complex<var>& z) {
77 return internal::complex_sqrt(z);
78}
79
80} // namespace math
81} // namespace stan
82#endif
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback functor.
fvar< T > sqrt(const fvar< T > &x)
Definition sqrt.hpp:18
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...