Automatic Differentiation
 
Loading...
Searching...
No Matches
fft.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_FFT_HPP
2#define STAN_MATH_REV_FUN_FFT_HPP
3
9#include <Eigen/Dense>
10#include <complex>
11#include <type_traits>
12#include <vector>
13
14namespace stan {
15namespace math {
16
40template <typename V, require_eigen_vector_vt<is_complex, V>* = nullptr,
41 require_var_t<base_type_t<value_type_t<V>>>* = nullptr>
42inline plain_type_t<V> fft(const V& x) {
43 if (unlikely(x.size() <= 1)) {
44 return plain_type_t<V>(x);
45 }
46
47 arena_t<V> arena_v = x;
48 arena_t<V> res = fft(to_complex(arena_v.real().val(), arena_v.imag().val()));
49
50 reverse_pass_callback([arena_v, res]() mutable {
51 auto adj_inv_fft = inv_fft(to_complex(res.real().adj(), res.imag().adj()));
52 adj_inv_fft *= res.size();
53 arena_v.real().adj() += adj_inv_fft.real();
54 arena_v.imag().adj() += adj_inv_fft.imag();
55 });
56
57 return plain_type_t<V>(res);
58}
59
85template <typename V, require_eigen_vector_vt<is_complex, V>* = nullptr,
86 require_var_t<base_type_t<value_type_t<V>>>* = nullptr>
87inline plain_type_t<V> inv_fft(const V& y) {
88 if (unlikely(y.size() <= 1)) {
89 return plain_type_t<V>(y);
90 }
91
92 arena_t<V> arena_v = y;
93 arena_t<V> res
94 = inv_fft(to_complex(arena_v.real().val(), arena_v.imag().val()));
95
96 reverse_pass_callback([arena_v, res]() mutable {
97 auto adj_fft = fft(to_complex(res.real().adj(), res.imag().adj()));
98 adj_fft /= res.size();
99
100 arena_v.real().adj() += adj_fft.real();
101 arena_v.imag().adj() += adj_fft.imag();
102 });
103 return plain_type_t<V>(res);
104}
105
121template <typename M, require_eigen_dense_dynamic_vt<is_complex, M>* = nullptr,
122 require_var_t<base_type_t<value_type_t<M>>>* = nullptr>
123inline plain_type_t<M> fft2(const M& x) {
124 arena_t<M> arena_v = x;
125 arena_t<M> res = fft2(to_complex(arena_v.real().val(), arena_v.imag().val()));
126
127 reverse_pass_callback([arena_v, res]() mutable {
128 auto adj_inv_fft = inv_fft2(to_complex(res.real().adj(), res.imag().adj()));
129 adj_inv_fft *= res.size();
130 arena_v.real().adj() += adj_inv_fft.real();
131 arena_v.imag().adj() += adj_inv_fft.imag();
132 });
133
134 return plain_type_t<M>(res);
135}
136
153template <typename M, require_eigen_dense_dynamic_vt<is_complex, M>* = nullptr,
154 require_var_t<base_type_t<value_type_t<M>>>* = nullptr>
155inline plain_type_t<M> inv_fft2(const M& y) {
156 arena_t<M> arena_v = y;
157 arena_t<M> res
158 = inv_fft2(to_complex(arena_v.real().val(), arena_v.imag().val()));
159
160 reverse_pass_callback([arena_v, res]() mutable {
161 auto adj_fft = fft2(to_complex(res.real().adj(), res.imag().adj()));
162 adj_fft /= res.size();
163
164 arena_v.real().adj() += adj_fft.real();
165 arena_v.imag().adj() += adj_fft.imag();
166 });
167 return plain_type_t<M>(res);
168}
169
170} // namespace math
171} // namespace stan
172#endif
#define unlikely(x)
Eigen::Matrix< scalar_type_t< M >, -1, -1 > fft2(const M &x)
Return the two-dimensional discrete Fourier transform of the specified complex matrix.
Definition fft.hpp:87
constexpr std::complex< stan::real_return_t< T, S > > to_complex(const T &re=0, const S &im=0)
Return a complex value from a real component and an imaginary component.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
Eigen::Matrix< scalar_type_t< V >, -1, 1 > fft(const V &x)
Return the discrete Fourier transform of the specified complex vector.
Definition fft.hpp:35
Eigen::Matrix< scalar_type_t< V >, -1, 1 > inv_fft(const V &y)
Return the inverse discrete Fourier transform of the specified complex vector.
Definition fft.hpp:66
Eigen::Matrix< scalar_type_t< M >, -1, -1 > inv_fft2(const M &y)
Return the two-dimensional inverse discrete Fourier transform of the specified complex matrix.
Definition fft.hpp:109
typename plain_type< T >::type plain_type_t
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...