Automatic Differentiation
 
Loading...
Searching...
No Matches
sqrt.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_SQRT_HPP
2#define STAN_MATH_REV_FUN_SQRT_HPP
3
10#include <cmath>
11#include <complex>
12
13namespace stan {
14namespace math {
15
44inline var sqrt(const var& a) {
45 return make_callback_var(std::sqrt(a.val()), [a](auto& vi) mutable {
46 if (vi.val() != 0.0) {
47 a.adj() += vi.adj() / (2.0 * vi.val());
48 }
49 });
50}
51
59template <typename T, require_var_matrix_t<T>* = nullptr>
60inline auto sqrt(const T& a) {
61 return make_callback_var(
62 a.val().array().sqrt().matrix(), [a](auto& vi) mutable {
63 a.adj().array()
64 += (vi.val_op().array() == 0.0)
65 .select(0.0, vi.adj().array() / (2.0 * vi.val_op().array()));
66 });
67}
68
75inline std::complex<var> sqrt(const std::complex<var>& z) {
76 return internal::complex_sqrt(z);
77}
78
79} // namespace math
80} // namespace stan
81#endif
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback functor.
fvar< T > sqrt(const fvar< T > &x)
Definition sqrt.hpp:17
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...