1#ifndef STAN_MATH_REV_FUN_POW_HPP
2#define STAN_MATH_REV_FUN_POW_HPP
72template <
typename Scal1,
typename Scal2,
73 require_any_var_t<base_type_t<Scal1>, base_type_t<Scal2>>* =
nullptr,
74 require_all_stan_scalar_t<Scal1, Scal2>* =
nullptr>
75inline auto pow(
const Scal1& base,
const Scal2& exponent) {
80 if (exponent == 0.5) {
82 }
else if (exponent == 1.0) {
84 }
else if (exponent == 2.0) {
86 }
else if (exponent == -2.0) {
88 }
else if (exponent == -1.0) {
90 }
else if (exponent == -0.5) {
95 [base, exponent](
auto&& vi)
mutable {
99 const double vi_mul = vi.adj() * vi.val();
102 forward_as<var>(base).adj()
107 forward_as<var>(exponent).adj()
108 += vi_mul * std::log(
value_of(base));
127template <
typename Mat1,
typename Mat2,
131inline auto pow(
const Mat1& base,
const Mat2& exponent) {
135 using val_type = std::conditional_t<
137 decltype(std::declval<expr_type>().eval()),
138 decltype(std::declval<expr_type>().matrix().eval())>;
152 const auto& ret_mul =
to_ref(ret.adj().array() * ret.val().array());
155 forward_as<base_var_arena_t>(arena_base).adj()
163 forward_as<exp_var_arena_t>(arena_exponent).adj()
164 += (are_vals_zero).
select(ret_mul *
value_of(arena_base).log(), 0);
167 return ret_type(ret);
181template <
typename Mat1,
typename Scal1,
185inline auto pow(
const Mat1& base,
const Scal1& exponent) {
189 if (exponent == 0.5) {
190 return ret_type(
sqrt(base));
191 }
else if (exponent == 1.0) {
192 return ret_type(base);
193 }
else if (exponent == 2.0) {
194 return ret_type(
square(base));
195 }
else if (exponent == -2.0) {
197 }
else if (exponent == -1.0) {
198 return ret_type(
inv(base));
199 }
else if (exponent == -0.5) {
209 const auto& are_vals_zero =
to_ref(
value_of(arena_base).array() != 0.0);
210 const auto& ret_mul =
to_ref(ret.adj().array() * ret.val().array());
212 forward_as<ret_type>(arena_base).adj().array()
219 forward_as<var>(exponent).adj()
226 return ret_type(ret);
246template <
typename Scal1,
typename Mat1,
250inline auto pow(Scal1 base,
const Mat1& exponent) {
260 const auto& ret_mul =
to_ref(ret.adj().array() * ret.val().array());
262 forward_as<var>(base).adj()
267 forward_as<ret_type>(arena_exponent).adj().array()
268 += ret_mul * std::log(
value_of(base));
271 return ret_type(ret);
285template <
typename T1,
typename T2, require_any_container_t<T1, T2>* =
nullptr,
286 require_all_not_matrix_st<is_var, T1, T2>* =
nullptr,
287 require_any_var_t<base_type_t<T1>, base_type_t<T2>>* =
nullptr>
288inline auto pow(
const T1& a,
const T2& b) {
290 a, b, [](
const auto& c,
const auto& d) {
return stan::math::pow(c, d); });
require_any_t< container_type_check_base< is_matrix, scalar_type_t, TypeCheck, Check >... > require_any_matrix_st
Require any of the types satisfy is_matrix.
require_all_t< container_type_check_base< is_matrix, scalar_type_t, TypeCheck, Check >... > require_all_matrix_st
Require all of the types to satisfy is_matrix.
select_< as_operation_cl_t< T_condition >, as_operation_cl_t< T_then >, as_operation_cl_t< T_else > > select(T_condition &&condition, T_then &&then, T_else &&els)
Selection operation on kernel generator expressions.
require_t< is_stan_scalar< std::decay_t< T > > > require_stan_scalar_t
Require type satisfies is_stan_scalar.
require_all_not_t< is_stan_scalar< std::decay_t< Types > >... > require_all_not_stan_scalar_t
Require none of the types satisfy is_stan_scalar.
require_all_t< is_var_or_arithmetic< scalar_type_t< std::decay_t< Types > > >... > require_all_st_var_or_arithmetic
Require all of the scalar types satisfy is_var_or_arithmetic.
complex_return_t< U, V > complex_pow(const U &x, const V &y)
Return the first argument raised to the power of the second argument.
T as_array_or_scalar(T &&v)
Returns specified input value.
fvar< T > inv_square(const fvar< T > &x)
typename promote_scalar_type< std::decay_t< T >, std::decay_t< S > >::type promote_scalar_t
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback funct...
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
auto pow(const T1 &x1, const T2 &x2)
T value_of(const fvar< T > &v)
Return the value of the specified variable.
void check_consistent_sizes(const char *)
Trivial no input case, this function is a no-op.
fvar< T > sqrt(const fvar< T > &x)
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
auto sum(const std::vector< T > &m)
Return the sum of the entries of the specified standard vector.
auto apply_scalar_binary(const T1 &x, const T2 &y, const F &f)
Base template function for vectorization of binary scalar functions defined by applying a functor to ...
fvar< T > inv_sqrt(const fvar< T > &x)
fvar< T > inv(const fvar< T > &x)
fvar< T > square(const fvar< T > &x)
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
std::conditional_t< is_any_var_matrix< ReturnType, Types... >::value, stan::math::var_value< stan::math::promote_scalar_t< double, plain_type_t< ReturnType > > >, stan::math::promote_scalar_t< stan::math::var_value< double >, plain_type_t< ReturnType > > > return_var_matrix_t
Given an Eigen type and several inputs, determine if a matrix should be var<Matrix> or Matrix<var>.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
If T is a complex type (that is, an instance of std::complex) or a cv-qualified version thereof,...
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the ...
Check if a type is derived from Eigen::ArrayBase
Extends std::false_type when instantiated with zero or more template parameters, all of which extend ...