// Copyright (C) 2026 Kiyotsugu Arai // SPDX-License-Identifier: LGPL-3.0-or-later // fft.hpp #ifndef CALX_FFT_HPP #define CALX_FFT_HPP #include #include #include #include #include #include #include #include #include #include namespace calx { // FFT 正規化モード enum class FFTDivMode { None, // 正規化なし Forward, // 順変換で正規化 (1/N) Inverse, // 逆変換で正規化 (1/N) — デフォルト Both // 両方で正規化 (1/sqrt(N)) }; /** * @brief 高速フーリエ変換 (FFT) のクラス * * このクラスは混合基数(radix 2, 3, 5)FFTを提供します。 * * @tparam T 値の型(通常は複素数型、あるいはModularInt) * @tparam Real 実数型(T=std::complexの場合のReal型) */ template class FFT { private: // 型特性 using complex_type = std::conditional_t>, T, std::complex>; // ダミーコンポーネント(コンパイル時にSFINAEを使用する場合) template static constexpr bool is_complex = false; template static constexpr bool is_complex> = true; // 単位円上の点を計算 static complex_type omega(int k, int n) { const Real theta = 2.0 * std::numbers::pi_v * static_cast(k) / static_cast(n); return complex_type(std::cos(theta), std::sin(theta)); } // 素因数分解 static std::vector factorize(int n) { std::vector factors; // 5の因数を抽出 while (n % 5 == 0) { factors.push_back(5); n /= 5; } // 3の因数を抽出 while (n % 3 == 0) { factors.push_back(3); n /= 3; } // 2の因数を抽出 while (n % 2 == 0) { factors.push_back(2); n /= 2; } // 残りの素因数(ある場合) if (n > 1) { factors.push_back(n); } // 因数を小さい順にソート std::sort(factors.begin(), factors.end()); return factors; } // Radix-2 バタフライ演算 static void butterfly_radix2(std::span data, int stride, int offset, int m) { for (int i = 0; i < m; ++i) { int idx1 = offset + i; int idx2 = idx1 + m; T temp = data[idx1]; data[idx1] = temp + data[idx2]; data[idx2] = temp - data[idx2]; } } // Radix-3 バタフライ演算 static void butterfly_radix3(std::span data, int stride, int offset, int m) { const complex_type w1 = omega(1, 3); const complex_type w2 = omega(2, 3); for (int i = 0; i < m; ++i) { int idx1 = offset + i; int idx2 = idx1 + m; int idx3 = idx2 + m; T x0 = data[idx1]; T x1 = data[idx2]; T x2 = data[idx3]; // DFT計算 data[idx1] = x0 + x1 + x2; data[idx2] = x0 + w1 * x1 + w2 * x2; data[idx3] = x0 + w2 * x1 + w1 * x2; } } // Radix-5 バタフライ演算 static void butterfly_radix5(std::span data, int stride, int offset, int m) { const complex_type w1 = omega(1, 5); const complex_type w2 = omega(2, 5); const complex_type w3 = omega(3, 5); const complex_type w4 = omega(4, 5); for (int i = 0; i < m; ++i) { int idx1 = offset + i; int idx2 = idx1 + m; int idx3 = idx2 + m; int idx4 = idx3 + m; int idx5 = idx4 + m; T x0 = data[idx1]; T x1 = data[idx2]; T x2 = data[idx3]; T x3 = data[idx4]; T x4 = data[idx5]; // DFT計算 data[idx1] = x0 + x1 + x2 + x3 + x4; data[idx2] = x0 + w1 * x1 + w2 * x2 + w3 * x3 + w4 * x4; data[idx3] = x0 + w2 * x1 + w4 * x2 + w1 * x3 + w3 * x4; data[idx4] = x0 + w3 * x1 + w1 * x2 + w4 * x3 + w2 * x4; data[idx5] = x0 + w4 * x1 + w3 * x2 + w2 * x3 + w1 * x4; } } // ビット反転順序を適用 template static void apply_bit_reversal(std::span data) { int n = static_cast(data.size()); for (int i = 0, j = 0; i < n; ++i) { if (i < j) { std::swap(data[i], data[j]); } // ビット反転のインデックス更新 int mask = n; do { mask >>= 1; j ^= mask; } while (mask && (j & mask) == 0); } } // Bluestein FFT: 任意長 N の DFT を radix-2 畳み込みで計算 (Chirp-Z 変換) // // nk = (n² + k² - (k-n)²) / 2 の恒等式を利用して // DFT を長さ M ≥ 2N-1 (2のべき乗) の巡回畳み込みに帰着する。 static void bluestein_fft(std::span data, bool inverse = false) { const int N = static_cast(data.size()); if (N <= 1) return; // M ≥ 2N-1 の最小 2 のべき乗 int M = 1; while (M < 2 * N - 1) M <<= 1; const Real pi = std::numbers::pi_v; // 順変換: exp(-πi·n²/N), 逆変換: exp(+πi·n²/N) const Real sign = inverse ? Real(1) : Real(-1); // チャープ列: chirp[n] = exp(sign·πi·n²/N) std::vector chirp(N); for (int n = 0; n < N; ++n) { Real angle = sign * pi * static_cast(static_cast(n) * n) / static_cast(N); chirp[n] = complex_type(std::cos(angle), std::sin(angle)); } // a[n] = x[n] · chirp[n], ゼロパディングして長さ M std::vector a(M, complex_type(0, 0)); for (int n = 0; n < N; ++n) { if constexpr (is_complex) { a[n] = static_cast(data[n]) * chirp[n]; } else { a[n] = complex_type(static_cast(data[n]), Real(0)) * chirp[n]; } } // b: conj(chirp) を巡回配置 // b[0] = conj(chirp[0]) // b[n] = conj(chirp[n]) (n = 1..N-1) // b[M-n] = conj(chirp[n]) (n = 1..N-1, 負インデックスのラップ) std::vector b(M, complex_type(0, 0)); b[0] = std::conj(chirp[0]); for (int n = 1; n < N; ++n) { b[n] = std::conj(chirp[n]); b[M - n] = std::conj(chirp[n]); } // M は 2 のべき乗 → radix-2 パスで FFT FFT::fft(std::span(a)); FFT::fft(std::span(b)); for (int i = 0; i < M; ++i) a[i] *= b[i]; FFT::ifft(std::span(a)); // X[k] = chirp[k] · (a ★ b)[k] for (int k = 0; k < N; ++k) { complex_type val = chirp[k] * a[k]; if constexpr (is_complex) { data[k] = val; } else { data[k] = static_cast(val.real()); } } // 逆変換: 1/N 正規化 if (inverse) { Real inv_n = Real(1) / static_cast(N); for (int k = 0; k < N; ++k) { if constexpr (is_complex) { data[k] *= inv_n; } else { data[k] = static_cast(static_cast(data[k]) * inv_n); } } } } // 混合基数FFTの実装 static void mixed_radix_fft_impl(std::span data, const std::vector& factors, bool inverse = false) { const int n = static_cast(data.size()); if (n <= 1) return; // べき乗として表現できるか確認 int product = 1; for (int factor : factors) { product *= factor; } if (product != n) { throw MathError("Input size must be a product of supported radices (2, 3, 5)"); } // ビット反転または同等の並べ替え // ... (効率的な実装のために複雑なので略) // 各段階でのFFT int m = 1; for (int factor : factors) { int stride = n / (m * factor); for (int i = 0; i < n; i += m * factor) { for (int j = 0; j < m; ++j) { int offset = i + j; // 基数に応じたバタフライ演算 if (factor == 2) { butterfly_radix2(data, stride, offset, m); } else if (factor == 3) { butterfly_radix3(data, stride, offset, m); } else if (factor == 5) { butterfly_radix5(data, stride, offset, m); } else { // 一般的なDFT(基数が実装されていない場合) throw MathError("Unsupported radix factor"); } } } m *= factor; } // 逆変換の場合は結果をNで割る if (inverse) { Real inv_n = 1.0 / static_cast(n); for (int i = 0; i < n; ++i) { data[i] = data[i] * inv_n; } } } public: /** * @brief FFT計算が可能なサイズかを確認 * @param n 入力サイズ * @return 計算可能な場合はtrue */ static bool is_valid_size(int n) { if (n <= 0) return false; // 2, 3, 5の累乗の積に分解できるか確認 while (n % 2 == 0) n /= 2; while (n % 3 == 0) n /= 3; while (n % 5 == 0) n /= 5; return n == 1; } /** * @brief 有効なFFTサイズに切り上げ * @param n 必要最小サイズ * @return 最小の有効なFFTサイズ */ static int next_valid_size(int n) { while (!is_valid_size(n)) { ++n; } return n; } /** * @brief 順変換 (FFT) * @param data 入力/出力データ配列 */ static void fft(std::span data) { const int n = static_cast(data.size()); if (n <= 1) return; // サイズが2の累乗の場合は特殊最適化 if ((n & (n - 1)) == 0) { // float/double の場合は SIMD 最適化パスにディスパッチ if constexpr (std::is_same_v> || std::is_same_v>) { simd_fft::fft(data.data(), n); } else { // 非 float/double (ModularInt 等): スカラー radix-2 apply_bit_reversal(data); for (int len = 2; len <= n; len <<= 1) { const int half_len = len / 2; const Real angle = 2.0 * std::numbers::pi_v / static_cast(len); for (int i = 0; i < n; i += len) { complex_type w(1.0, 0.0); complex_type wn(std::cos(angle), std::sin(angle)); for (int j = 0; j < half_len; ++j) { T u = data[i + j]; T v = data[i + j + half_len] * w; data[i + j] = u + v; data[i + j + half_len] = u - v; w *= wn; } } } } } else if (is_valid_size(n)) { // 混合基数FFT (2,3,5) std::vector factors = factorize(n); mixed_radix_fft_impl(data, factors, false); } else { // 任意長: Bluestein/Chirp-Z bluestein_fft(data, false); } } /** * @brief 逆変換 (IFFT) * @param data 入力/出力データ配列 */ static void ifft(std::span data) { const int n = static_cast(data.size()); if (n <= 1) return; // 一般的なアプローチ: 複素共役を取り、FFTを実行し、Nで割る if constexpr (is_complex) { // 複素数型の場合 for (int i = 0; i < n; ++i) { data[i] = std::conj(data[i]); } fft(data); Real inv_n = 1.0 / static_cast(n); for (int i = 0; i < n; ++i) { data[i] = std::conj(data[i]) * inv_n; } } else { // 非複素数型の場合(ModularIntなど) // サイズが2の累乗の場合は特殊最適化 if ((n & (n - 1)) == 0) { // ビット反転は通常のFFTと同じ apply_bit_reversal(data); for (int len = 2; len <= n; len <<= 1) { const int half_len = len / 2; const Real angle = -2.0 * std::numbers::pi_v / static_cast(len); // 負の角度 for (int i = 0; i < n; i += len) { complex_type w(1.0, 0.0); complex_type wn(std::cos(angle), std::sin(angle)); for (int j = 0; j < half_len; ++j) { T u = data[i + j]; T v = data[i + j + half_len] * w; data[i + j] = u + v; data[i + j + half_len] = u - v; w *= wn; } } } // Nで割る T inv_n = static_cast(1) / static_cast(n); for (int i = 0; i < n; ++i) { data[i] = data[i] * inv_n; } } else if (is_valid_size(n)) { // 混合基数FFT(逆変換フラグ付き) std::vector factors = factorize(n); mixed_radix_fft_impl(data, factors, true); } else { // 任意長: Bluestein/Chirp-Z (逆変換) bluestein_fft(data, true); } } } /** * @brief 多項式乗算(畳み込み) * @param a 入力多項式1 * @param b 入力多項式2 * @return 乗算結果の多項式 */ static std::vector convolve(std::span a, std::span b) { const int na = static_cast(a.size()); const int nb = static_cast(b.size()); const int n = na + nb - 1; // FFTに適したサイズを計算 int fft_size = 1; while (fft_size < n) { fft_size *= 2; } // 配列を拡張 std::vector fa(fft_size, static_cast(0)); std::vector fb(fft_size, static_cast(0)); for (int i = 0; i < na; ++i) fa[i] = a[i]; for (int i = 0; i < nb; ++i) fb[i] = b[i]; // FFTを実行 fft(std::span(fa)); fft(std::span(fb)); // スペクトル領域で乗算 for (int i = 0; i < fft_size; ++i) { fa[i] = fa[i] * fb[i]; } // 逆FFT ifft(std::span(fa)); // 結果を切り詰める fa.resize(n); return fa; } /** * @brief 高速数論変換 (NTT) - ModularInt用のFFT * @param data 入力/出力データ * @param primitive_root 原始根 */ template static void ntt(std::span> data, ModularInt

primitive_root) { const int n = static_cast(data.size()); if (n <= 1) return; if ((n & (n - 1)) != 0) { throw MathError("NTT only supports power of 2 sizes"); } // ビット反転順序に並べ替え apply_bit_reversal(data); for (int len = 2; len <= n; len <<= 1) { const int half_len = len / 2; const ModularInt

wn = primitive_root.pow((P - 1) / len); for (int i = 0; i < n; i += len) { ModularInt

w(1); for (int j = 0; j < half_len; ++j) { ModularInt

u = data[i + j]; ModularInt

v = data[i + j + half_len] * w; data[i + j] = u + v; data[i + j + half_len] = u - v; w = w * wn; } } } } /** * @brief 逆高速数論変換 (INTT) - ModularInt用の逆FFT * @param data 入力/出力データ * @param primitive_root 原始根 */ template static void intt(std::span> data, ModularInt

primitive_root) { const int n = static_cast(data.size()); if (n <= 1) return; if ((n & (n - 1)) != 0) { throw MathError("INTT only supports power of 2 sizes"); } // ビット反転順序に並べ替え apply_bit_reversal(data); for (int len = 2; len <= n; len <<= 1) { const int half_len = len / 2; const ModularInt

wn = primitive_root.pow((P - 1) / len).inverse(); for (int i = 0; i < n; i += len) { ModularInt

w(1); for (int j = 0; j < half_len; ++j) { ModularInt

u = data[i + j]; ModularInt

v = data[i + j + half_len] * w; data[i + j] = u + v; data[i + j + half_len] = u - v; w = w * wn; } } } // Nで割る ModularInt

inv_n = ModularInt

(n).inverse(); for (int i = 0; i < n; ++i) { data[i] = data[i] * inv_n; } } /** * @brief 多項式乗算をNTTで実行(ModularInt用) * @param a 入力多項式1 * @param b 入力多項式2 * @param primitive_root 原始根 * @return 乗算結果の多項式 */ template static std::vector> convolve_ntt( std::span> a, std::span> b, ModularInt

primitive_root) { const int na = static_cast(a.size()); const int nb = static_cast(b.size()); const int n = na + nb - 1; // NTTに適したサイズを計算(2の累乗) int ntt_size = 1; while (ntt_size < n) { ntt_size *= 2; } // 配列を拡張 std::vector> fa(ntt_size, ModularInt

(0)); std::vector> fb(ntt_size, ModularInt

(0)); for (int i = 0; i < na; ++i) fa[i] = a[i]; for (int i = 0; i < nb; ++i) fb[i] = b[i]; // NTTを実行 ntt(std::span>(fa), primitive_root); ntt(std::span>(fb), primitive_root); // スペクトル領域で乗算 for (int i = 0; i < ntt_size; ++i) { fa[i] = fa[i] * fb[i]; } // 逆NTT intt(std::span>(fa), primitive_root); // 結果を切り詰める fa.resize(n); return fa; } }; // 実数型向けの特殊化(実FFT) template class RealFFT { private: using complex_type = std::complex; public: /** * @brief 実数データに対する高速フーリエ変換 * @param data 実数データ * @return 複素数スペクトル(サイズはdata.size()/2+1) */ static std::vector> fft(std::span data) { const int n = static_cast(data.size()); // 複素数配列に変換 std::vector complex_data(n); for (int i = 0; i < n; ++i) { complex_data[i] = complex_type(data[i], 0); } // 通常のFFTを実行 FFT::fft(std::span(complex_data)); // 共役対称性を利用して結果を半分に圧縮 std::vector result(n / 2 + 1); for (int i = 0; i <= n / 2; ++i) { result[i] = complex_data[i]; } return result; } /** * @brief 複素スペクトルから実数データへの逆変換 * @param spectrum 複素スペクトル(サイズは出力サイズの半分+1) * @param output_size 出力実数データのサイズ * @return 実数データ */ static std::vector ifft(std::span> spectrum, int output_size) { const int half_n_plus_1 = static_cast(spectrum.size()); const int n = output_size; if (half_n_plus_1 != n / 2 + 1) { throw MathError("Spectrum size must be output_size/2+1"); } // 完全な複素スペクトルを再構築(共役対称性を利用) std::vector complex_data(n); for (int i = 0; i < half_n_plus_1; ++i) { complex_data[i] = spectrum[i]; } for (int i = 1; i < n - half_n_plus_1 + 1; ++i) { complex_data[n - i] = std::conj(spectrum[i]); } // 逆FFTを実行 FFT::ifft(std::span(complex_data)); // 実部のみを抽出 std::vector result(n); for (int i = 0; i < n; ++i) { result[i] = complex_data[i].real(); } return result; } /** * @brief 実数データに対する畳み込み * @param a 入力配列1 * @param b 入力配列2 * @return 畳み込み結果 */ static std::vector convolve(std::span a, std::span b) { const int na = static_cast(a.size()); const int nb = static_cast(b.size()); const int n = na + nb - 1; // FFTに適したサイズを計算 int fft_size = 1; while (fft_size < n) { fft_size *= 2; } // 実数配列を拡張 std::vector fa(fft_size, 0); std::vector fb(fft_size, 0); for (int i = 0; i < na; ++i) fa[i] = a[i]; for (int i = 0; i < nb; ++i) fb[i] = b[i]; // 実FFTを実行 auto fa_spectrum = fft(std::span(fa)); auto fb_spectrum = fft(std::span(fb)); // スペクトル領域で乗算 std::vector product_spectrum(fa_spectrum.size()); for (size_t i = 0; i < fa_spectrum.size(); ++i) { product_spectrum[i] = fa_spectrum[i] * fb_spectrum[i]; } // 逆FFT std::vector result = ifft(std::span(product_spectrum), fft_size); // 結果を切り詰める result.resize(n); return result; } }; /** * @brief インスタンスベースの FFT ラッパー * * 旧 FFT.hpp 互換の API を提供する。 * 内部で FFT および RealFFT の static API に委譲する。 * * @tparam Real 実数型 (double, float) */ template class FFTEngine { using complex_type = std::complex; int m_nfft; FFTDivMode m_divMode; // 正規化を適用 // 注: 下層の FFT::ifft は既に 1/N 正規化を含む。 // FFTEngine はその上で追加の補正を行う。 // // Forward (順変換): FFT::fft は非正規化 → 必要に応じてスケーリング // Inverse (逆変換): FFT::ifft は 1/N 正規化済み → DivMode に応じて補正 // // Forward補正 Inverse補正 // None 1 N (1/N を打ち消す) // Forward 1/N N (1/N を打ち消す) // Inverse 1 1 (組み込み 1/N をそのまま) // Both 1/sqrt(N) sqrt(N) (1/N → 1/sqrt(N) に補正) template void applyNormalization(Container& data, bool isForward) const { Real scale = Real(1); if (isForward) { if (m_divMode == FFTDivMode::Forward) { scale = Real(1) / static_cast(m_nfft); } else if (m_divMode == FFTDivMode::Both) { scale = Real(1) / std::sqrt(static_cast(m_nfft)); } } else { if (m_divMode == FFTDivMode::None || m_divMode == FFTDivMode::Forward) { scale = static_cast(m_nfft); } else if (m_divMode == FFTDivMode::Both) { scale = std::sqrt(static_cast(m_nfft)); } // Inverse: scale = 1 (組み込み 1/N がそのまま) } if (scale != Real(1)) { for (auto& x : data) { x *= scale; } } } public: /** * @brief コンストラクタ * @param nfft FFT サイズ (2のべき乗) * @param mode 正規化モード */ explicit FFTEngine(int nfft, FFTDivMode mode = FFTDivMode::Forward) : m_nfft(nfft), m_divMode(mode) { if (nfft <= 0 || (nfft & (nfft - 1)) != 0) { throw std::invalid_argument("FFT size must be a power of 2"); } } int size() const { return m_nfft; } /** @brief 複素 FFT 順変換 */ std::vector transform(const std::vector& input) { if (static_cast(input.size()) != m_nfft) { throw std::invalid_argument("Input size must match FFT size"); } std::vector output = input; FFT::fft(std::span(output)); applyNormalization(output, true); return output; } /** @brief 複素 FFT 逆変換 */ std::vector inverse(const std::vector& input) { if (static_cast(input.size()) != m_nfft) { throw std::invalid_argument("Input size must match FFT size"); } std::vector output = input; FFT::ifft(std::span(output)); applyNormalization(output, false); return output; } /** @brief 実数 FFT 順変換 (N → N/2+1) */ std::vector real_transform(const std::vector& input) { if (static_cast(input.size()) != m_nfft) { throw std::invalid_argument("Input size must match FFT size"); } auto output = RealFFT::fft(std::span(input)); applyNormalization(output, true); return output; } /** @brief 実数 FFT 逆変換 (N/2+1 → N) */ std::vector real_inverse(const std::vector& input) { if (static_cast(input.size()) != m_nfft / 2 + 1) { throw std::invalid_argument("Input size must be N/2+1 for real IFFT"); } auto output = RealFFT::ifft(std::span(input), m_nfft); applyNormalization(output, false); return output; } }; } // namespace calx #endif // CALX_FFT_HPP