Automatic Differentiation
 
Loading...
Searching...
No Matches
pow.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_POW_HPP
2#define STAN_MATH_REV_FUN_POW_HPP
3
23#include <cmath>
24#include <complex>
25#include <type_traits>
26
27namespace stan {
28namespace math {
29
/**
 * Return `base` raised to the power `exponent` for scalar arguments, at least
 * one of which has a `var` scalar type (per the `require_any_st_var` and
 * `require_all_stan_scalar_t` constraints).
 *
 * The exponent values 0.5, 1, 2, -2, -1 and -0.5 are forwarded to `sqrt`,
 * the identity, `square`, `inv_square`, `inv` and `inv_sqrt` respectively.
 * Otherwise the value is `std::pow(value_of(base), value_of(exponent))` and a
 * reverse-pass callback adds, scaled by the result's adjoint,
 *   d/d(base)     = result * exponent / base
 *   d/d(exponent) = result * log(base)
 * If the base value is exactly 0, the callback returns without touching any
 * adjoint.
 *
 * NOTE(review): this listing omits source lines 72, 95 and 99. The extra `}`
 * on line 86 implies that line 72 opens a condition around the special cases.
 * The `}` on lines 98 and 101 imply that lines 95 and 99 open guards around
 * the two adjoint updates. Presumably these guards check which argument is a
 * `var`; confirm against upstream.
 *
 * NOTE(review): with a zero base every partial is left at 0. Analytically,
 * d/d(base) of base^exponent at base == 0 is 1 when exponent == 1 (the
 * `exponent == 1.0` shortcut only applies where line 72's condition holds).
 * Verify that this is intended.
 *
 * @tparam Scal1 type of the base
 * @tparam Scal2 type of the exponent
 * @param base base
 * @param exponent exponent
 * @return `base` raised to the power `exponent`
 */
68template <typename Scal1, typename Scal2,
69 require_any_st_var<Scal1, Scal2>* = nullptr,
70 require_all_stan_scalar_t<Scal1, Scal2>* = nullptr>
71inline var pow(const Scal1& base, const Scal2& exponent) {
// Special-cased exponents dispatch to dedicated functions.
// NOTE(review): source line 72 (the opener matched by line 86) is omitted.
73 if (exponent == 0.5) {
74 return sqrt(base);
75 } else if (exponent == 1.0) {
76 return base;
77 } else if (exponent == 2.0) {
78 return square(base);
79 } else if (exponent == -2.0) {
80 return inv_square(base);
81 } else if (exponent == -1.0) {
82 return inv(base);
83 } else if (exponent == -0.5) {
84 return inv_sqrt(base);
85 }
86 }
// General case: the value comes from std::pow, the gradient from the callback.
87 return make_callback_var(
88 std::pow(value_of(base), value_of(exponent)),
89 [base, exponent](auto&& vi) mutable {
90 if (value_of(base) == 0.0) {
91 return; // partials zero, avoids 0 & log(0)
92 }
// adjoint(result) * result: the factor shared by both partials below.
93 const double vi_mul = vi.adj() * vi.val();
94
// d/d(base) = result * exponent / base (guard opened on omitted line 95).
96 forward_as<var>(base).adj()
97 += vi_mul * value_of(exponent) / value_of(base);
98 }
// d/d(exponent) = result * log(base) (guard opened on omitted line 99).
100 forward_as<var>(exponent).adj() += vi_mul * std::log(value_of(base));
101 }
102 });
103}
104
/**
 * Elementwise power of two containers: element i of the result is
 * `base[i]` raised to `exponent[i]`.
 *
 * `check_consistent_sizes` validates the sizes of `base` and `exponent`.
 * Both inputs are copied into arena storage through `as_array_or_scalar`.
 * The value is `value_of(arena_base).pow(value_of(arena_exponent))`, returned
 * as a matrix. A reverse-pass callback then adds, for each element whose base
 * value is non-zero,
 *   base.adj     += adj * result * exponent / base
 *   exponent.adj += adj * result * log(base)
 * Elements whose base value is exactly 0 receive no contribution.
 *
 * NOTE(review): this listing omits source lines 119-121 (the remaining
 * template parameters), 127 (the `std::conditional_t` selector), 130, 138
 * (the start of the declaration of `ret`), and 144/152 (the openers of the
 * blocks closed on lines 151/156, presumably per-argument `var` guards).
 * `ret_type` is defined on one of the omitted lines. Confirm against
 * upstream before editing.
 *
 * @tparam Mat1 type of the base container
 * @tparam Mat2 type of the exponent container
 * @param base base
 * @param exponent exponent
 * @return elementwise power of `base` and `exponent`
 */
118template <typename Mat1, typename Mat2,
122inline auto pow(const Mat1& base, const Mat2& exponent) {
123 check_consistent_sizes("pow", "base", base, "exponent", exponent);
124 using expr_type = decltype(as_array_or_scalar(value_of(base))
125 .pow(as_array_or_scalar(value_of(exponent))));
126 using val_type = std::conditional_t<
128 decltype(std::declval<expr_type>().eval()),
129 decltype(std::declval<expr_type>().matrix().eval())>;
131 using base_t = decltype(as_array_or_scalar(base));
132 using exp_t = decltype(as_array_or_scalar(exponent));
133 using base_arena_t = arena_t<base_t>;
134 using exp_arena_t = arena_t<exp_t>;
135
136 base_arena_t arena_base = as_array_or_scalar(base);
137 exp_arena_t arena_exponent = as_array_or_scalar(exponent);
139 = value_of(arena_base).pow(value_of(arena_exponent)).matrix();
140
141 reverse_pass_callback([arena_base, arena_exponent, ret]() mutable {
// NOTE(review): despite its name, `are_vals_zero` is true where the base
// value is NON-zero. `select(mask, expr, 0)` therefore drops the contribution
// of zero-base elements (avoiding division by 0 and log(0)).
142 const auto& are_vals_zero = to_ref(value_of(arena_base) != 0.0);
// adjoint(result) * result, elementwise.
143 const auto& ret_mul = to_ref(ret.adj().array() * ret.val().array());
// Base adjoint: result * exponent / base.
145 using base_var_arena_t = arena_t<promote_scalar_t<var, base_arena_t>>;
146 forward_as<base_var_arena_t>(arena_base).adj()
147 += (are_vals_zero)
148 .select(
149 ret_mul * value_of(arena_exponent) / value_of(arena_base),
150 0);
151 }
// Exponent adjoint: result * log(base).
153 using exp_var_arena_t = arena_t<promote_scalar_t<var, exp_arena_t>>;
154 forward_as<exp_var_arena_t>(arena_exponent).adj()
155 += (are_vals_zero).select(ret_mul * value_of(arena_base).log(), 0);
156 }
157 });
158 return ret_type(ret);
159}
160
/**
 * Elementwise power of a container `base` raised to a scalar `exponent`.
 *
 * The exponent values 0.5, 1, 2, -2, -1 and -0.5 are forwarded to `sqrt`,
 * the identity, `square`, `inv_square`, `inv` and `inv_sqrt`, each result
 * wrapped in `ret_type`. Otherwise `base` is copied into arena storage, the
 * value is computed with the array `pow`, and a reverse-pass callback adds,
 * for each element whose base value is non-zero,
 *   base.adj     += adj * result * exponent / base
 *   exponent.adj += sum over elements of (adj * result * log(base))
 * Elements whose base value is exactly 0 contribute nothing.
 *
 * NOTE(review): this listing omits source lines 172-174 (the remaining
 * template parameters), 176, 178 (the opener matched by the `}` on line 192),
 * 195 (the start of the declaration of `ret`), and 201/208 (the openers of
 * the blocks closed on lines 207/213, presumably per-argument `var` guards).
 * `ret_type` is defined on one of the omitted lines. Confirm against
 * upstream.
 *
 * @tparam Mat1 type of the base container
 * @tparam Scal1 type of the scalar exponent
 * @param base base
 * @param exponent exponent
 * @return `base` raised elementwise to the power `exponent`
 */
171template <typename Mat1, typename Scal1,
175inline auto pow(const Mat1& base, const Scal1& exponent) {
177
// Special-cased exponents dispatch to dedicated functions.
179 if (exponent == 0.5) {
180 return ret_type(sqrt(base));
181 } else if (exponent == 1.0) {
182 return ret_type(base);
183 } else if (exponent == 2.0) {
184 return ret_type(square(base));
185 } else if (exponent == -2.0) {
186 return ret_type(inv_square(base));
187 } else if (exponent == -1.0) {
188 return ret_type(inv(base));
189 } else if (exponent == -0.5) {
190 return ret_type(inv_sqrt(base));
191 }
192 }
193
194 arena_t<plain_type_t<Mat1>> arena_base = base;
196 = value_of(arena_base).array().pow(value_of(exponent)).matrix();
197
198 reverse_pass_callback([arena_base, exponent, ret]() mutable {
// NOTE(review): despite its name, this mask is true where the base value is
// NON-zero. Zero-base elements are dropped by `select(..., 0)` below.
199 const auto& are_vals_zero = to_ref(value_of(arena_base).array() != 0.0);
// adjoint(result) * result, elementwise.
200 const auto& ret_mul = to_ref(ret.adj().array() * ret.val().array());
// Base adjoint: result * exponent / base.
202 forward_as<ret_type>(arena_base).adj().array()
203 += (are_vals_zero)
204 .select(ret_mul * value_of(exponent)
205 / value_of(arena_base).array(),
206 0);
207 }
// Scalar exponent adjoint: sum of result * log(base) over all elements.
209 forward_as<var>(exponent).adj()
210 += (are_vals_zero)
211 .select(ret_mul * value_of(arena_base).array().log(), 0)
212 .sum();
213 }
214 });
215
216 return ret_type(ret);
217}
218
/**
 * Elementwise power of a scalar `base` raised to each element of the
 * container `exponent`. `base` is taken by value.
 *
 * `exponent` is copied into arena storage and the value is
 * `Eigen::pow(value_of(base), value_of(arena_exponent).array())`. If the base
 * value is exactly 0, the reverse-pass callback returns without touching any
 * adjoint. Otherwise it adds
 *   base.adj     += sum over elements of (adj * result * exponent / base)
 *   exponent.adj += adj * result * log(base)
 *
 * NOTE(review): this listing omits source lines 237-239 (the remaining
 * template parameters), 241, 243 (the start of the declaration of `ret`), and
 * 251/256 (the openers of the blocks closed on lines 255/259, presumably
 * per-argument `var` guards). `ret_type` is defined on one of the omitted
 * lines. The exponent is stored as `arena_t<Mat1>`, whereas the
 * container/scalar overload uses `arena_t<plain_type_t<Mat1>>`; confirm that
 * the difference is intended.
 *
 * @tparam Scal1 type of the scalar base
 * @tparam Mat1 type of the exponent container
 * @param base base
 * @param exponent exponent
 * @return `base` raised to each element of `exponent`
 */
236template <typename Scal1, typename Mat1,
240inline auto pow(Scal1 base, const Mat1& exponent) {
242 arena_t<Mat1> arena_exponent = exponent;
244 = Eigen::pow(value_of(base), value_of(arena_exponent).array());
245
246 reverse_pass_callback([base, arena_exponent, ret]() mutable {
247 if (unlikely(value_of(base) == 0.0)) {
248 return; // partials zero, avoids 0 & log(0)
249 }
// adjoint(result) * result, elementwise.
250 const auto& ret_mul = to_ref(ret.adj().array() * ret.val().array());
// Scalar base adjoint: sum of result * exponent / base over all elements.
252 forward_as<var>(base).adj()
253 += (ret_mul * value_of(arena_exponent).array() / value_of(base))
254 .sum();
255 }
// Exponent adjoint: result * log(base), elementwise.
257 forward_as<ret_type>(arena_exponent).adj().array()
258 += ret_mul * std::log(value_of(base));
259 }
260 });
261 return ret_type(ret);
262}
263
264// must uniquely match all pairs of { complex<var>, complex<T>, var, T }
265// with at least one var and at least one complex, where T is arithmetic:
266// 1) complex<var>, complex<var>
267// 2) complex<var>, complex<T>
268// 3) complex<var>, var
269// 4) complex<var>, T
270// 5) complex<T>, complex<var>
271// 6) complex<T>, var
272// 7) var, complex<var>
273// 8) var, complex<T>
274// 9) T, complex<var>
275
/**
 * Return complex `x` raised to the complex power `y`, both with `var`
 * components (case 1 of the overload table above). Delegates to
 * `internal::complex_pow`.
 *
 * @param x base
 * @param y exponent
 * @return `x` raised to the power `y`
 */
283inline std::complex<var> pow(const std::complex<var>& x,
284 const std::complex<var>& y) {
285 return internal::complex_pow(x, y);
286}
287
/**
 * Return complex `var` base `x` raised to the complex arithmetic power `y`
 * (case 2 of the overload table above). Delegates to `internal::complex_pow`.
 * `y` is taken by value.
 *
 * @tparam T arithmetic component type of the exponent
 * @param x base
 * @param y exponent
 * @return `x` raised to the power `y`
 */
296template <typename T, typename = require_arithmetic_t<T>>
297inline std::complex<var> pow(const std::complex<var>& x,
298 const std::complex<T> y) {
299 return internal::complex_pow(x, y);
300}
301
/**
 * Return complex `var` base `x` raised to the real `var` power `y`
 * (case 3 of the overload table above). Delegates to `internal::complex_pow`.
 *
 * @param x base
 * @param y exponent
 * @return `x` raised to the power `y`
 */
309inline std::complex<var> pow(const std::complex<var>& x, const var& y) {
310 return internal::complex_pow(x, y);
311}
312
/**
 * Return complex `var` base `x` raised to the real arithmetic power `y`
 * (case 4 of the overload table above). Delegates to `internal::complex_pow`.
 *
 * @tparam T arithmetic type of the exponent
 * @param x base
 * @param y exponent
 * @return `x` raised to the power `y`
 */
321template <typename T, typename = require_arithmetic_t<T>>
322inline std::complex<var> pow(const std::complex<var>& x, T y) {
323 return internal::complex_pow(x, y);
324}
325
/**
 * Return complex arithmetic base `x` raised to the complex `var` power `y`
 * (case 5 of the overload table above). Delegates to `internal::complex_pow`.
 *
 * @tparam T arithmetic component type of the base
 * @param x base
 * @param y exponent
 * @return `x` raised to the power `y`
 */
334template <typename T, typename = require_arithmetic_t<T>>
335inline std::complex<var> pow(std::complex<T> x, const std::complex<var>& y) {
336 return internal::complex_pow(x, y);
337}
338
/**
 * Return complex arithmetic base `x` raised to the real `var` power `y`
 * (case 6 of the overload table above). Delegates to `internal::complex_pow`.
 *
 * @tparam T arithmetic component type of the base
 * @param x base
 * @param y exponent
 * @return `x` raised to the power `y`
 */
347template <typename T, typename = require_arithmetic_t<T>>
348inline std::complex<var> pow(std::complex<T> x, const var& y) {
349 return internal::complex_pow(x, y);
350}
351
/**
 * Return real `var` base `x` raised to the complex `var` power `y`
 * (case 7 of the overload table above). Delegates to `internal::complex_pow`.
 *
 * @param x base
 * @param y exponent
 * @return `x` raised to the power `y`
 */
359inline std::complex<var> pow(const var& x, const std::complex<var>& y) {
360 return internal::complex_pow(x, y);
361}
362
/**
 * Return real `var` base `x` raised to the complex arithmetic power `y`
 * (case 8 of the overload table above). Delegates to `internal::complex_pow`.
 *
 * @tparam T arithmetic component type of the exponent
 * @param x base
 * @param y exponent
 * @return `x` raised to the power `y`
 */
371template <typename T, typename = require_arithmetic_t<T>>
372inline std::complex<var> pow(const var& x, std::complex<T> y) {
373 return internal::complex_pow(x, y);
374}
375
/**
 * Return real arithmetic base `x` raised to the complex `var` power `y`
 * (case 9 of the overload table above). Delegates to `internal::complex_pow`.
 *
 * @tparam T arithmetic type of the base
 * @param x base
 * @param y exponent
 * @return `x` raised to the power `y`
 */
384template <typename T, typename = require_arithmetic_t<T>>
385inline std::complex<var> pow(T x, const std::complex<var>& y) {
386 return internal::complex_pow(x, y);
387}
388
/**
 * Return complex `var` base `x` raised to the integer power `y`. Delegates
 * to `internal::complex_pow`.
 *
 * NOTE(review): this case is not listed in the overload table above. An
 * `int` exponent also satisfies the arithmetic-`T` overload for a
 * `complex<var>` base, so confirm why a dedicated `int` overload is needed
 * (its doc comment, source lines 389-399, is omitted from this listing).
 *
 * @param x base
 * @param y integer exponent
 * @return `x` raised to the power `y`
 */
400inline std::complex<var> pow(const std::complex<var>& x, int y) {
401 return internal::complex_pow(x, y);
402}
403
404} // namespace math
405} // namespace stan
406#endif
#define unlikely(x)
require_any_t< container_type_check_base< is_matrix, scalar_type_t, TypeCheck, Check >... > require_any_matrix_st
Require that at least one of the types satisfies is_matrix.
Definition is_matrix.hpp:64
require_all_t< container_type_check_base< is_matrix, scalar_type_t, TypeCheck, Check >... > require_all_matrix_st
Require that all of the types satisfy is_matrix.
Definition is_matrix.hpp:73
select_< as_operation_cl_t< T_condition >, as_operation_cl_t< T_then >, as_operation_cl_t< T_else > > select(T_condition &&condition, T_then &&then, T_else &&els)
Selection operation on kernel generator expressions.
Definition select.hpp:148
require_t< is_stan_scalar< std::decay_t< T > > > require_stan_scalar_t
Require that the type satisfies is_stan_scalar.
require_all_not_t< is_stan_scalar< std::decay_t< Types > >... > require_all_not_stan_scalar_t
Require none of the types satisfy is_stan_scalar.
require_all_t< is_var_or_arithmetic< scalar_type_t< std::decay_t< Types > > >... > require_all_st_var_or_arithmetic
Require all of the scalar types satisfy is_var_or_arithmetic.
complex_return_t< U, V > complex_pow(const U &x, const V &y)
Return the first argument raised to the power of the second argument.
Definition pow.hpp:26
T as_array_or_scalar(T &&v)
Returns the specified input value.
fvar< T > inv_square(const fvar< T > &x)
typename promote_scalar_type< std::decay_t< T >, std::decay_t< S > >::type promote_scalar_t
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback functor.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
void check_consistent_sizes(const char *)
Trivial no-input case; this function is a no-op.
fvar< T > sqrt(const fvar< T > &x)
Definition sqrt.hpp:17
fvar< T > sum(const std::vector< fvar< T > > &m)
Return the sum of the entries of the specified standard vector.
Definition sum.hpp:22
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:17
fvar< T > pow(const fvar< T > &x1, const fvar< T > &x2)
Definition pow.hpp:19
fvar< T > inv_sqrt(const fvar< T > &x)
Definition inv_sqrt.hpp:12
fvar< T > inv(const fvar< T > &x)
Definition inv.hpp:12
fvar< T > square(const fvar< T > &x)
Definition square.hpp:12
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
std::conditional_t< is_any_var_matrix< ReturnType, Types... >::value, stan::math::var_value< stan::math::promote_scalar_t< double, plain_type_t< ReturnType > > >, stan::math::promote_scalar_t< stan::math::var_value< double >, plain_type_t< ReturnType > > > return_var_matrix_t
Given an Eigen type and several inputs, determine if a matrix should be var<Matrix> or Matrix<var>.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the ...
Check if a type is derived from Eigen::ArrayBase
Definition is_eigen.hpp:206
Extends std::false_type when instantiated with zero or more template parameters, all of which extend ...