Automatic Differentiation
 
Loading...
Searching...
No Matches
pow.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_POW_HPP
2#define STAN_MATH_REV_FUN_POW_HPP
3
24
25#include <cmath>
26#include <complex>
27#include <type_traits>
28
29namespace stan {
30namespace math {
31
// NOTE(review): this rendering drops some source lines (e.g. line 78); the
// comments below describe only what the visible lines show.
/**
 * Reverse-mode pow for two scalars, at least one of whose base types is a
 * var (require_any_var_t), both satisfying require_all_stan_scalar_t.
 *
 * - One branch (guard on missing line 78) returns
 *   internal::complex_pow(base, exponent) -- presumably for complex
 *   arguments, TODO confirm.
 * - If the exponent type is constant (is_constant_v<Scal2>), the exponents
 *   0.5, 1, 2, -2, -1 and -0.5 are routed to sqrt, identity, square,
 *   inv_square, inv and inv_sqrt respectively.
 * - Otherwise returns a callback var with value
 *   std::pow(value_of(base), value_of(exponent)). Its reverse pass adds
 *     adj * val * exponent / base  to base     (d/db b^e = e * b^(e-1)),
 *     adj * val * log(base)        to exponent (d/de b^e = b^e * log(b)),
 *   and adds nothing when value_of(base) == 0.
 *
 * @tparam Scal1 type of the base
 * @tparam Scal2 type of the exponent
 * @param base base of the power
 * @param exponent exponent of the power
 * @return base raised to the power exponent
 */
74template <typename Scal1, typename Scal2,
 75 require_any_var_t<base_type_t<Scal1>, base_type_t<Scal2>>* = nullptr,
 76 require_all_stan_scalar_t<Scal1, Scal2>* = nullptr>
77inline auto pow(const Scal1& base, const Scal2& exponent) {
// NOTE(review): source line 78 (the condition opening this branch) is
// missing from this rendering.
79 return internal::complex_pow(base, exponent);
80 } else {
// Constant exponents with a dedicated, cheaper kernel skip the generic
// callback below entirely.
81 if constexpr (is_constant_v<Scal2>) {
82 if (exponent == 0.5) {
83 return sqrt(base);
84 } else if (exponent == 1.0) {
85 return base;
86 } else if (exponent == 2.0) {
87 return square(base);
88 } else if (exponent == -2.0) {
89 return inv_square(base);
90 } else if (exponent == -1.0) {
91 return inv(base);
92 } else if (exponent == -0.5) {
93 return inv_sqrt(base);
94 }
95 }
96 return make_callback_var(
97 std::pow(value_of(base), value_of(exponent)),
98 [base, exponent](auto&& vi) mutable {
99 if (value_of(base) == 0.0) {
// NOTE(review): this early return also zeroes the base partial when the
// exponent is a var whose value is 1 (true derivative 1) or below 1
// (unbounded); only a *constant* exponent of 1 is short-circuited above.
100 return; // partials zero, avoids 0 & log(0)
101 }
// vi.val() == base^exponent is a factor shared by both partials.
102 const double vi_mul = vi.adj() * vi.val();
103
104 if constexpr (is_autodiff_v<Scal1>) {
// adj * e * b^e / b == adj * e * b^(e-1)
105 base.adj() += vi_mul * value_of(exponent) / value_of(base);
106 }
107 if constexpr (is_autodiff_v<Scal2>) {
108 exponent.adj() += vi_mul * std::log(value_of(base));
109 }
110 });
111 }
112}
113
/**
 * Element-wise reverse-mode pow of two containers.
 *
 * Checks that base and exponent have consistent sizes, copies both into
 * arena storage (arena_t) as arrays via as_array_or_scalar, computes the
 * forward value value_of(base).pow(value_of(exponent)), and registers a
 * reverse-pass callback that adds
 *   adj * val * exponent / base  to each base adjoint     (if Mat1 is autodiff),
 *   adj * val * log(base)        to each exponent adjoint (if Mat2 is autodiff),
 * writing 0 instead wherever the base value is 0.
 *
 * NOTE(review): this rendering omits source lines 128-130 (template
 * constraints), 136, 139 and 147 (the start of the `ret` declaration).
 * ret_type is not defined in the visible lines -- presumably on line 139,
 * TODO confirm.
 *
 * @tparam Mat1 type of the base container
 * @tparam Mat2 type of the exponent container
 * @param base base container
 * @param exponent exponent container
 * @return element-wise base^exponent converted to ret_type
 */
127template <typename Mat1, typename Mat2,
131inline auto pow(const Mat1& base, const Mat2& exponent) {
132 check_consistent_sizes("pow", "base", base, "exponent", exponent);
// val_type chooses between the evaluated array result and the evaluated
// matrix result; the selecting condition is on missing line 136.
133 using expr_type = decltype(as_array_or_scalar(value_of(base))
134 .pow(as_array_or_scalar(value_of(exponent))));
135 using val_type = std::conditional_t<
137 decltype(std::declval<expr_type>().eval()),
138 decltype(std::declval<expr_type>().matrix().eval())>;
140 using base_t = decltype(as_array_or_scalar(base));
141 using exp_t = decltype(as_array_or_scalar(exponent));
142 using base_arena_t = arena_t<base_t>;
143 using exp_arena_t = arena_t<exp_t>;
144
// Arena copies are captured by the reverse-pass callback below.
145 base_arena_t arena_base = as_array_or_scalar(base);
146 exp_arena_t arena_exponent = as_array_or_scalar(exponent);
// NOTE(review): source line 147 (start of this statement) is missing.
148 = value_of(arena_base).pow(value_of(arena_exponent)).matrix();
149
150 reverse_pass_callback([arena_base, arena_exponent, ret]() mutable {
// Despite its name, are_vals_zero is true where the base value is NONZERO;
// the select() calls keep the gradient there and write 0 elsewhere.
151 const auto& are_vals_zero = to_ref(value_of(arena_base) != 0.0);
// ret.val() == base^exponent is a factor shared by both partials.
152 const auto& ret_mul = to_ref(ret.adj().array() * ret.val().array());
153 if constexpr (is_autodiff_v<Mat1>) {
154 arena_base.adj() += (are_vals_zero)
155 .select(ret_mul * value_of(arena_exponent)
156 / value_of(arena_base),
157 0);
158 }
159 if constexpr (is_autodiff_v<Mat2>) {
160 arena_exponent.adj()
161 += (are_vals_zero).select(ret_mul * value_of(arena_base).log(), 0);
162 }
163 });
164 return ret_type(ret);
165}
166
/**
 * Element-wise reverse-mode pow of a container base and a scalar exponent.
 *
 * Note that Scal1 is the *exponent* type in this overload.
 *
 * If the exponent type is constant (is_constant_v<Scal1>), the exponents
 * 0.5, 1, 2, -2, -1 and -0.5 are routed to sqrt, identity, square,
 * inv_square, inv and inv_sqrt, each converted to ret_type. Otherwise the
 * forward value is value_of(base)^value_of(exponent) element-wise, and the
 * reverse pass adds
 *   adj * val * exponent / base   to each base adjoint  (if Mat1 is autodiff),
 *   sum(adj * val * log(base))    to the exponent adjoint (if Scal1 is autodiff),
 * with elements whose base value is 0 contributing 0.
 *
 * NOTE(review): this rendering omits source lines 179-181 (template
 * constraints), 183 and 202 (the start of the `ret` declaration).
 * ret_type is not defined in the visible lines -- presumably on line 183,
 * TODO confirm.
 *
 * @tparam Mat1 type of the base container
 * @tparam Scal1 type of the scalar exponent
 * @param base base container
 * @param exponent scalar exponent applied to every element
 * @return element-wise base^exponent converted to ret_type
 */
178template <typename Mat1, typename Scal1,
182inline auto pow(const Mat1& base, const Scal1& exponent) {
184
// Dedicated kernels for common constant exponents.
185 if constexpr (is_constant_v<Scal1>) {
186 if (exponent == 0.5) {
187 return ret_type(sqrt(base));
188 } else if (exponent == 1.0) {
189 return ret_type(base);
190 } else if (exponent == 2.0) {
191 return ret_type(square(base));
192 } else if (exponent == -2.0) {
193 return ret_type(inv_square(base));
194 } else if (exponent == -1.0) {
195 return ret_type(inv(base));
196 } else if (exponent == -0.5) {
197 return ret_type(inv_sqrt(base));
198 }
199 }
200
201 arena_t<plain_type_t<Mat1>> arena_base = base;
// NOTE(review): source line 202 (start of this statement) is missing.
203 = value_of(arena_base).array().pow(value_of(exponent)).matrix();
204
205 reverse_pass_callback([arena_base, exponent, ret]() mutable {
// Despite its name, are_vals_zero is true where the base value is NONZERO.
206 const auto& are_vals_zero = to_ref(value_of(arena_base).array() != 0.0);
207 const auto& ret_mul = to_ref(ret.adj().array() * ret.val().array());
208 if constexpr (is_autodiff_v<Mat1>) {
209 arena_base.adj().array()
210 += (are_vals_zero)
211 .select(ret_mul * value_of(exponent)
212 / value_of(arena_base).array(),
213 0);
214 }
215 if constexpr (is_autodiff_v<Scal1>) {
// The single exponent feeds every element, so its partials are summed.
216 exponent.adj()
217 += (are_vals_zero)
218 .select(ret_mul * value_of(arena_base).array().log(), 0)
219 .sum();
220 }
221 });
222
223 return ret_type(ret);
224}
225
/**
 * Element-wise reverse-mode pow of a scalar base and a container exponent.
 *
 * The base is taken by value and captured by copy into the callback. The
 * forward value is Eigen::pow(value_of(base), value_of(exponent).array()).
 * Unless the base value is 0, the reverse pass adds
 *   sum(adj * val * exponent / base)  to the base adjoint   (if Scal1 is autodiff),
 *   adj * val * log(base)             to each exponent adjoint (if Mat1 is autodiff).
 *
 * NOTE(review): this rendering omits source lines 244-246 (template
 * constraints), 248 and 250 (the start of the `ret` declaration).
 * ret_type is not defined in the visible lines -- presumably on line 248,
 * TODO confirm.
 *
 * @tparam Scal1 type of the scalar base
 * @tparam Mat1 type of the exponent container
 * @param base scalar base shared by every element
 * @param exponent exponent container
 * @return element-wise base^exponent converted to ret_type
 */
243template <typename Scal1, typename Mat1,
247inline auto pow(Scal1 base, const Mat1& exponent) {
249 arena_t<Mat1> arena_exponent = exponent;
// NOTE(review): source line 250 (start of this statement) is missing.
251 = Eigen::pow(value_of(base), value_of(arena_exponent).array());
252
253 reverse_pass_callback([base, arena_exponent, ret]() mutable {
254 if (unlikely(value_of(base) == 0.0)) {
255 return; // partials zero, avoids 0 & log(0)
256 }
// ret.val() == base^exponent is a factor shared by both partials.
257 const auto& ret_mul = to_ref(ret.adj().array() * ret.val().array());
258 if constexpr (is_autodiff_v<Scal1>) {
// The single base feeds every element, so its partials are summed.
259 base.adj()
260 += (ret_mul * value_of(arena_exponent).array() / value_of(base))
261 .sum();
262 }
263 if constexpr (is_autodiff_v<Mat1>) {
264 arena_exponent.adj().array() += ret_mul * std::log(value_of(base));
265 }
266 });
267 return ret_type(ret);
268}
269
281template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr,
282 require_all_not_matrix_st<is_var, T1, T2>* = nullptr,
283 require_any_var_t<base_type_t<T1>, base_type_t<T2>>* = nullptr>
284inline auto pow(T1&& a, T2&& b) {
285 return apply_scalar_binary(
286 [](auto&& c, auto&& d) {
287 return stan::math::pow(std::forward<decltype(c)>(c),
288 std::forward<decltype(d)>(d));
289 },
290 std::forward<T1>(a), std::forward<T2>(b));
291}
292
293} // namespace math
294} // namespace stan
295#endif
#define unlikely(x)
require_any_t< container_type_check_base< is_matrix, scalar_type_t, TypeCheck, Check >... > require_any_matrix_st
Require that at least one of the types satisfies is_matrix.
Definition is_matrix.hpp:64
require_all_t< container_type_check_base< is_matrix, scalar_type_t, TypeCheck, Check >... > require_all_matrix_st
Require that all of the types satisfy is_matrix.
Definition is_matrix.hpp:73
select_< as_operation_cl_t< T_condition >, as_operation_cl_t< T_then >, as_operation_cl_t< T_else > > select(T_condition &&condition, T_then &&then, T_else &&els)
Selection operation on kernel generator expressions.
Definition select.hpp:148
require_t< is_stan_scalar< std::decay_t< T > > > require_stan_scalar_t
Require type satisfies is_stan_scalar.
require_all_not_t< is_stan_scalar< std::decay_t< Types > >... > require_all_not_stan_scalar_t
Require none of the types satisfy is_stan_scalar.
require_all_t< is_var_or_arithmetic< scalar_type_t< std::decay_t< Types > > >... > require_all_st_var_or_arithmetic
Require all of the scalar types satisfy is_var_or_arithmetic.
complex_return_t< U, V > complex_pow(const U &x, const V &y)
Return the first argument raised to the power of the second argument.
Definition pow.hpp:27
T as_array_or_scalar(T &&v)
Returns specified input value.
fvar< T > inv_square(const fvar< T > &x)
typename promote_scalar_type< std::decay_t< T >, std::decay_t< S > >::type promote_scalar_t
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback functor.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
auto pow(const T1 &x1, const T2 &x2)
Definition pow.hpp:32
auto apply_scalar_binary(F &&f, T1 &&x, T2 &&y)
Base template function for vectorization of binary scalar functions defined by applying a functor to ...
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
void check_consistent_sizes(const char *)
Trivial no-input case; this function is a no-op.
fvar< T > sqrt(const fvar< T > &x)
Definition sqrt.hpp:18
auto sum(const std::vector< T > &m)
Return the sum of the entries of the specified standard vector.
Definition sum.hpp:23
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:18
fvar< T > inv_sqrt(const fvar< T > &x)
Definition inv_sqrt.hpp:14
fvar< T > inv(const fvar< T > &x)
Definition inv.hpp:13
fvar< T > square(const fvar< T > &x)
Definition square.hpp:12
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
std::conditional_t< is_any_var_matrix< ReturnType, Types... >::value, stan::math::var_value< stan::math::promote_scalar_t< double, plain_type_t< ReturnType > > >, stan::math::promote_scalar_t< stan::math::var_value< double >, plain_type_t< ReturnType > > > return_var_matrix_t
Given an Eigen type and several inputs, determine if a matrix should be var<Matrix> or Matrix<var>.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
If T is a complex type (that is, an instance of std::complex) or a cv-qualified version thereof,...
Check if a type is derived from Eigen::ArrayBase
Definition is_eigen.hpp:264
Extends std::false_type when instantiated with zero or more template parameters, all of which extend ...