Automatic Differentiation
 
Loading...
Searching...
No Matches
squared_distance.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_SQUARED_DISTANCE_HPP
2#define STAN_MATH_REV_FUN_SQUARED_DISTANCE_HPP
3
13#include <vector>
14
15namespace stan {
16namespace math {
17
21inline var squared_distance(const var& a, const var& b) {
22 check_finite("squared_distance", "a", a);
23 check_finite("squared_distance", "b", b);
24 return make_callback_vari(std::pow(a.val() - b.val(), 2),
25 [a, b](const auto& vi) mutable {
26 const double diff = 2.0 * (a.val() - b.val());
27 a.adj() += vi.adj_ * diff;
28 b.adj() -= vi.adj_ * diff;
29 });
30}
31
35inline var squared_distance(const var& a, double b) {
36 check_finite("squared_distance", "a", a);
37 check_finite("squared_distance", "b", b);
38 return make_callback_vari(std::pow(a.val() - b, 2),
39 [a, b](const auto& vi) mutable {
40 a.adj() += vi.adj_ * 2.0 * (a.val() - b);
41 });
42}
43
47inline var squared_distance(double a, const var& b) {
48 return squared_distance(b, a);
49}
50
51namespace internal {
52
// Members of the reverse-mode node whose constructor appears below as
// squared_distance_vv_vari (the class header is not visible in this view).
// v1_ and v2_ (declarations not visible here) are used as arrays of length_
// vari pointers, one per element of each input vector.
54 protected:
57 size_t length_;
58
59 public:
// Constructor: records v1.size() as length_ and copies the vari pointers of
// both inputs (v1.vi(), v2.vi()) into the v1_ / v2_ buffers. The first
// initializer ends in .squaredNorm(); the rest of it, and the allocation
// calls that complete the reinterpret_cast lines, are not visible here
// (presumably the AD stack allocator's alloc — TODO confirm).
60 template <
61 typename EigVecVar1, typename EigVecVar2,
63 squared_distance_vv_vari(const EigVecVar1& v1, const EigVecVar2& v2)
66 .squaredNorm()),
67 length_(v1.size()) {
68 v1_ = reinterpret_cast<vari**>(
70 v2_ = reinterpret_cast<vari**>(
72 Eigen::Map<vector_vi>(v1_, length_) = v1.vi();
73 Eigen::Map<vector_vi>(v2_, length_) = v2.vi();
74 }
75
// Reverse pass: with d = v1 - v2 (element values), adds 2 * adj_ * d to the
// adjoints of v1 and subtracts the same vector from the adjoints of v2,
// i.e. the gradient of ||v1 - v2||^2 scaled by this node's adjoint.
76 virtual void chain() {
77 Eigen::Map<vector_vi> v1_map(v1_, length_);
78 Eigen::Map<vector_vi> v2_map(v2_, length_);
79 vector_d di = 2 * adj_ * (v1_map.val() - v2_map.val());
80 v1_map.adj() += di;
81 v2_map.adj() -= di;
82 }
83};
84
// Members of the node whose constructor appears below as
// squared_distance_vd_vari: v1 holds vars, v2 holds plain doubles. v2_
// keeps a copy of v2's values so the reverse pass can read them.
86 protected:
88 double* v2_;
89 size_t length_;
90
91 public:
// Constructor: records v1.size() as length_, copies v1's vari pointers into
// v1_ and v2's values into v2_. Parts of the initializer and allocation
// expressions are not visible in this view.
92 template <typename EigVecVar, typename EigVecArith,
95 squared_distance_vd_vari(const EigVecVar& v1, const EigVecArith& v2)
98 .squaredNorm()),
99 length_(v1.size()) {
100 v1_ = reinterpret_cast<vari**>(
102 v2_ = reinterpret_cast<double*>(
104 Eigen::Map<vector_vi>(v1_, length_) = v1.vi();
105 Eigen::Map<vector_d>(v2_, length_) = v2;
106 }
107
// Reverse pass: adds 2 * adj_ * (v1 - v2) to the adjoints of v1 only; v2 is
// stored as doubles and has no adjoints.
108 virtual void chain() {
109 Eigen::Map<vector_vi> v1_map(v1_, length_);
110 v1_map.adj()
111 += 2 * adj_ * (v1_map.val() - Eigen::Map<vector_d>(v2_, length_));
112 }
113};
114} // namespace internal
115
// Squared distance between two vectors, both handled by
// internal::squared_distance_vv_vari (presumably both hold vars; the
// require_ constraint line is not visible in this view — TODO confirm).
// Sizes must match; to_ref evaluates any Eigen expression arguments first.
116template <
117 typename EigVecVar1, typename EigVecVar2,
119inline var squared_distance(const EigVecVar1& v1, const EigVecVar2& v2) {
120 check_matching_sizes("squared_distance", "v1", v1, "v2", v2);
121 return {new internal::squared_distance_vv_vari(to_ref(v1), to_ref(v2))};
122}
123
// Squared distance between a var vector (v1) and an arithmetic vector (v2),
// handled by internal::squared_distance_vd_vari. The require_ constraint
// lines are not visible in this view. Sizes must match; to_ref evaluates
// any Eigen expression arguments first.
124template <typename EigVecVar, typename EigVecArith,
127inline var squared_distance(const EigVecVar& v1, const EigVecArith& v2) {
128 check_matching_sizes("squared_distance", "v1", v1, "v2", v2);
129 return {new internal::squared_distance_vd_vari(to_ref(v1), to_ref(v2))};
130}
131
// Squared distance between an arithmetic vector (v1) and a var vector (v2).
// The arguments are swapped when building the node because
// squared_distance_vd_vari takes the var vector first, and the squared
// distance is symmetric in its arguments.
132template <typename EigVecArith, typename EigVecVar,
135inline var squared_distance(const EigVecArith& v1, const EigVecVar& v2) {
136 check_matching_sizes("squared_distance", "v1", v1, "v2", v2);
137 return {new internal::squared_distance_vd_vari(to_ref(v2), to_ref(v1))};
138}
139
// Squared Euclidean distance for vector arguments where at least one side
// satisfies the var-vector requirement in the template constraints below.
// Empty inputs short-circuit to a constant var(0.0). Otherwise each branch
// computes the element differences in one forward pass, caches them in
// res_diff, and reuses them in the reverse-pass callback.
// Some lines of this definition are not visible in this view (the end of
// the empty-input branch and the setup of arena_A / arena_B and the first
// branch condition) — TODO confirm against the full file.
// NOTE(review): the loops compare size_t i against arena_A.size(), which
// for Eigen types is the signed Eigen::Index; consider Eigen::Index for i to
// avoid signed/unsigned comparison warnings.
155template <typename T1, typename T2, require_all_vector_t<T1, T2>* = nullptr,
156 require_any_var_vector_t<T1, T2>* = nullptr>
157inline var squared_distance(const T1& A, const T2& B) {
158 check_matching_sizes("squared_distance", "A", A.val(), "B", B.val());
159 if (unlikely(A.size() == 0)) {
160 return var(0.0);
// Branch in which both A and B are used through .val()/.adj(), so the
// callback updates both sets of adjoints (B's with the opposite sign).
164 arena_t<Eigen::VectorXd> res_diff(arena_A.size());
165 double res_val = 0.0;
166 for (size_t i = 0; i < arena_A.size(); ++i) {
167 const double diff = arena_A.val().coeff(i) - arena_B.val().coeff(i);
168 res_diff.coeffRef(i) = diff;
169 res_val += diff * diff;
170 }
171 return var(make_callback_vari(
172 res_val, [arena_A, arena_B, res_diff](const auto& res) mutable {
173 const double res_adj = 2.0 * res.adj();
174 for (size_t i = 0; i < arena_A.size(); ++i) {
175 const double diff = res_adj * res_diff.coeff(i);
176 arena_A.adj().coeffRef(i) += diff;
177 arena_B.adj().coeffRef(i) -= diff;
178 }
179 }));
180 } else if (!is_constant<T1>::value) {
// Only A holds vars here: B is read through plain coeff(), and only A's
// adjoints receive 2 * adj * (A - B).
183 arena_t<Eigen::VectorXd> res_diff(arena_A.size());
184 double res_val = 0.0;
185 for (size_t i = 0; i < arena_A.size(); ++i) {
186 const double diff = arena_A.val().coeff(i) - arena_B.coeff(i);
187 res_diff.coeffRef(i) = diff;
188 res_val += diff * diff;
189 }
190 return var(make_callback_vari(
191 res_val, [arena_A, arena_B, res_diff](const auto& res) mutable {
192 arena_A.adj() += 2.0 * res.adj() * res_diff;
193 }));
194 } else {
// Only B holds vars here: A is read through plain coeff(), and B's
// adjoints have 2 * adj * (A - B) subtracted, since d/dB (A - B)^2 is
// -2 (A - B).
197 arena_t<Eigen::VectorXd> res_diff(arena_A.size());
198 double res_val = 0.0;
199 for (size_t i = 0; i < arena_A.size(); ++i) {
200 const double diff = arena_A.coeff(i) - arena_B.val().coeff(i);
201 res_diff.coeffRef(i) = diff;
202 res_val += diff * diff;
203 }
204 return var(make_callback_vari(
205 res_val, [arena_A, arena_B, res_diff](const auto& res) mutable {
206 arena_B.adj() -= 2.0 * res.adj() * res_diff;
207 }));
208 }
209}
210
211} // namespace math
212} // namespace stan
213#endif
squared_distance_vd_vari(const EigVecVar &v1, const EigVecArith &v2)
squared_distance_vv_vari(const EigVecVar1 &v1, const EigVecVar2 &v2)
void * alloc(size_t len)
Return a newly allocated block of memory of the appropriate size managed by the stack allocator.
#define unlikely(x)
require_all_t< container_type_check_base< is_eigen_vector, value_type_t, TypeCheck, Check >... > require_all_eigen_vector_vt
Require all of the types satisfy is_eigen_vector.
require_t< container_type_check_base< is_eigen_vector, value_type_t, TypeCheck, Check... > > require_eigen_vector_vt
Require type satisfies is_eigen_vector.
auto as_column_vector_or_scalar(T &&a)
as_column_vector_or_scalar of a kernel generator expression.
size_t size(const T &m)
Returns the size (number of elements) of a matrix_cl or var_value<matrix_cl<T>>.
Definition size.hpp:18
Eigen::Matrix< double, Eigen::Dynamic, 1 > vector_d
Type for (column) vector of double values.
Definition typedefs.hpp:24
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
void check_matching_sizes(const char *function, const char *name1, const T_y1 &y1, const char *name2, const T_y2 &y2)
Check if two structures are the same size.
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:17
void check_finite(const char *function, const char *name, const T_y &y)
Check that all values in y are finite.
var_value< double > var
Definition var.hpp:1187
auto squared_distance(const T_a &a, const T_b &b)
Returns the squared distance.
internal::callback_vari< plain_type_t< T >, F > * make_callback_vari(T &&value, F &&functor)
Creates a new vari with given value and a callback that implements the reverse pass (chain).
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the ...
static thread_local AutodiffStackStorage * instance_