Automatic Differentiation
 
Loading...
Searching...
No Matches
fft.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_FFT_HPP
2#define STAN_MATH_REV_FUN_FFT_HPP
3
9#include <complex>
10#include <type_traits>
11#include <vector>
12
13namespace stan {
14namespace math {
15
39template <typename V, require_eigen_vector_vt<is_complex, V>* = nullptr,
40 require_var_t<base_type_t<value_type_t<V>>>* = nullptr>
41inline plain_type_t<V> fft(const V& x) {
42 if (unlikely(x.size() <= 1)) {
43 return plain_type_t<V>(x);
44 }
45
46 arena_t<V> arena_v = x;
47 arena_t<V> res = fft(to_complex(arena_v.real().val(), arena_v.imag().val()));
48
49 reverse_pass_callback([arena_v, res]() mutable {
50 auto adj_inv_fft = inv_fft(to_complex(res.real().adj(), res.imag().adj()));
51 adj_inv_fft *= res.size();
52 arena_v.real().adj() += adj_inv_fft.real();
53 arena_v.imag().adj() += adj_inv_fft.imag();
54 });
55
56 return plain_type_t<V>(res);
57}
58
84template <typename V, require_eigen_vector_vt<is_complex, V>* = nullptr,
85 require_var_t<base_type_t<value_type_t<V>>>* = nullptr>
86inline plain_type_t<V> inv_fft(const V& y) {
87 if (unlikely(y.size() <= 1)) {
88 return plain_type_t<V>(y);
89 }
90
91 arena_t<V> arena_v = y;
92 arena_t<V> res
93 = inv_fft(to_complex(arena_v.real().val(), arena_v.imag().val()));
94
95 reverse_pass_callback([arena_v, res]() mutable {
96 auto adj_fft = fft(to_complex(res.real().adj(), res.imag().adj()));
97 adj_fft /= res.size();
98
99 arena_v.real().adj() += adj_fft.real();
100 arena_v.imag().adj() += adj_fft.imag();
101 });
102 return plain_type_t<V>(res);
103}
104
120template <typename M, require_eigen_dense_dynamic_vt<is_complex, M>* = nullptr,
121 require_var_t<base_type_t<value_type_t<M>>>* = nullptr>
122inline plain_type_t<M> fft2(const M& x) {
123 arena_t<M> arena_v = x;
124 arena_t<M> res = fft2(to_complex(arena_v.real().val(), arena_v.imag().val()));
125
126 reverse_pass_callback([arena_v, res]() mutable {
127 auto adj_inv_fft = inv_fft2(to_complex(res.real().adj(), res.imag().adj()));
128 adj_inv_fft *= res.size();
129 arena_v.real().adj() += adj_inv_fft.real();
130 arena_v.imag().adj() += adj_inv_fft.imag();
131 });
132
133 return plain_type_t<M>(res);
134}
135
152template <typename M, require_eigen_dense_dynamic_vt<is_complex, M>* = nullptr,
153 require_var_t<base_type_t<value_type_t<M>>>* = nullptr>
154inline plain_type_t<M> inv_fft2(const M& y) {
155 arena_t<M> arena_v = y;
156 arena_t<M> res
157 = inv_fft2(to_complex(arena_v.real().val(), arena_v.imag().val()));
158
159 reverse_pass_callback([arena_v, res]() mutable {
160 auto adj_fft = fft2(to_complex(res.real().adj(), res.imag().adj()));
161 adj_fft /= res.size();
162
163 arena_v.real().adj() += adj_fft.real();
164 arena_v.imag().adj() += adj_fft.imag();
165 });
166 return plain_type_t<M>(res);
167}
168
169} // namespace math
170} // namespace stan
171#endif
#define unlikely(x)
Eigen::Matrix< scalar_type_t< M >, -1, -1 > fft2(const M &x)
Return the two-dimensional discrete Fourier transform of the specified complex matrix.
Definition fft.hpp:87
constexpr std::complex< stan::real_return_t< T, S > > to_complex(const T &re=0, const S &im=0)
Return a complex value from a real component and an imaginary component.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
Eigen::Matrix< scalar_type_t< V >, -1, 1 > fft(const V &x)
Return the discrete Fourier transform of the specified complex vector.
Definition fft.hpp:35
Eigen::Matrix< scalar_type_t< V >, -1, 1 > inv_fft(const V &y)
Return the inverse discrete Fourier transform of the specified complex vector.
Definition fft.hpp:66
Eigen::Matrix< scalar_type_t< M >, -1, -1 > inv_fft2(const M &y)
Return the two-dimensional inverse discrete Fourier transform of the specified complex matrix.
Definition fft.hpp:109
typename plain_type< T >::type plain_type_t
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...