// Copyright (C) 2026 Kiyotsugu Arai
// SPDX-License-Identifier: LGPL-3.0-or-later
// fft_simd.hpp
//
// SIMD-optimized radix-2 FFT (float / double).
//
// Optimization points:
//   1. Precomputed twiddle table (cos/sin lookup; no per-stage recomputation).
//   2. 256-bit AVX butterflies: float 4 complex (8-wide), double 2 complex (4-wide).
//   3. SSE2 butterflies: float 2 complex (4-wide), double 1 complex (2-wide).
//   4. Compile-time SIMD level selection (#if dispatch).
//
// Usage:
//   std::vector<std::complex<float>> data(1024);
//   calx::simd_fft::fft(data.data(), 1024);   // forward transform
//   calx::simd_fft::ifft(data.data(), 1024);  // inverse transform
//
//   std::vector<std::complex<double>> data_d(1024);
//   calx::simd_fft::fft(data_d.data(), 1024);
//
// Sizes must be powers of two. Any other size > 1 is rejected with
// std::invalid_argument (it would otherwise overrun the caller's buffer).
// Real types other than float/double (e.g. long double) use the scalar path.

#ifndef CALX_FFT_SIMD_HPP
#define CALX_FFT_SIMD_HPP

#include <cmath>
#include <complex>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <vector>

// SIMD detection (same pattern as simd_backend.hpp).
#if defined(__AVX2__) || (defined(_MSC_VER) && defined(__AVX2__))
#define CALX_FFT_HAS_AVX2 1
#endif
#if defined(__AVX__) || (defined(_MSC_VER) && defined(__AVX__))
#define CALX_FFT_HAS_AVX 1
#endif
#if defined(__SSE2__) || defined(_M_X64) || defined(_M_AMD64) || (defined(_M_IX86_FP) && _M_IX86_FP >= 2)
#define CALX_FFT_HAS_SSE2 1
#endif

#if defined(CALX_FFT_HAS_AVX2) || defined(CALX_FFT_HAS_AVX) || defined(CALX_FFT_HAS_SSE2)
#include <immintrin.h>
#endif

namespace calx {
namespace simd_fft {

namespace detail {

/// pi at long double precision. A literal instead of std::numbers keeps the
/// header usable from C++17 translation units as well.
inline constexpr long double kPi = 3.141592653589793238462643383279502884L;

/// floor(log2(n)) for n >= 1; 0 for n <= 1. Equals the radix-2 stage count.
constexpr int log2_floor(int n) noexcept {
    int stages = 0;
    for (; n > 1; n >>= 1) ++stages;
    return stages;
}

/// Throws std::invalid_argument when n > 1 is not a power of two.
/// Sizes <= 1 are accepted because the transforms treat them as no-ops.
inline void require_power_of_two(int n) {
    if (n > 1 && (n & (n - 1)) != 0) {
        throw std::invalid_argument("calx::simd_fft: n must be a power of two");
    }
}

/// Scalar butterflies for j in [j, half_len) of one block.
/// `block` points at the first element of a 2*half_len span; elements j and
/// j + half_len are combined with twiddle tw[j]. Also used as the tail loop
/// of the SIMD kernels.
template <typename Real>
inline void butterfly_tail_scalar(std::complex<Real>* block, int half_len,
                                  const std::complex<Real>* tw, int j) {
    for (; j < half_len; ++j) {
        const std::complex<Real> u = block[j];
        const std::complex<Real> v = block[j + half_len] * tw[j];
        block[j] = u + v;
        block[j + half_len] = u - v;
    }
}

} // namespace detail

// =============================================================================
// Bit-reversal permutation
// =============================================================================

/// In-place bit-reversal permutation of `n` complex values (n: power of two).
/// Templated so every Real accepted by fft() also works here.
template <typename Real>
inline void bit_reversal(std::complex<Real>* data, int n) {
    for (int i = 0, j = 0; i < n; ++i) {
        if (i < j) std::swap(data[i], data[j]);
        // Advance j to its bit-reversed successor: a binary increment that
        // propagates its carry from the most significant bit downward.
        int mask = n;
        do {
            mask >>= 1;
            j ^= mask;
        } while (mask && (j & mask) == 0);
    }
}

// =============================================================================
// Twiddle table (precomputed cos/sin)
// =============================================================================

/// Per-stage twiddle factors for a radix-2 decimation-in-time FFT.
template <typename Real>
struct TwiddleTable {
    // twiddles[s] = twiddle factors of stage s (butterfly span len = 2^(s+1)):
    //   twiddles[s][j] = exp(-2*pi*i * j / 2^(s+1)),  j = 0 .. 2^s - 1
    // A stage's table depends only on s, so a table built for a larger size
    // also serves every smaller power-of-two size.
    std::vector<std::vector<std::complex<Real>>> twiddles;

    /// Builds the tables for all floor(log2(n)) stages of an n-point FFT.
    void build(int n) {
        const int stages = detail::log2_floor(n);
        twiddles.resize(stages);
        for (int s = 0; s < stages; ++s) {
            const int len = 2 << s;     // 2, 4, 8, ...
            const int half = len / 2;   // 1, 2, 4, ...
            twiddles[s].resize(half);
            // Evaluate in long double and round once to Real. This keeps float
            // twiddles accurate for large n, where a float angle loses digits.
            const long double step = -2.0L * detail::kPi / static_cast<long double>(len);
            for (int j = 0; j < half; ++j) {
                const long double a = step * static_cast<long double>(j);
                twiddles[s][j] = std::complex<Real>(static_cast<Real>(std::cos(a)),
                                                    static_cast<Real>(std::sin(a)));
            }
        }
    }

    /// Inverse-transform twiddle: the complex conjugate of twiddles[stage][j].
    std::complex<Real> get_inv(int stage, int j) const {
        return std::conj(twiddles[stage][j]);
    }
};

// =============================================================================
// Scalar butterfly (fallback)
// =============================================================================

/// One radix-2 stage over all blocks of 2*half_len elements, scalar arithmetic.
template <typename Real>
inline void butterfly_stage_scalar(
    std::complex<Real>* data, int n, int half_len, const std::complex<Real>* tw)
{
    for (int i = 0; i < n; i += half_len * 2) {
        detail::butterfly_tail_scalar(data + i, half_len, tw, 0);
    }
}

// =============================================================================
// SSE2 butterflies
// =============================================================================
#if defined(CALX_FFT_HAS_SSE2)

namespace detail {

/// Complex product d*w of one complex<double> packed as [re, im].
inline __m128d cmul_sse2_pd(__m128d d, __m128d w) {
    const __m128d d_re = _mm_shuffle_pd(d, d, 0);    // [dre, dre]
    const __m128d d_im = _mm_shuffle_pd(d, d, 3);    // [dim, dim]
    const __m128d w_flip = _mm_shuffle_pd(w, w, 1);  // [wim, wre]
    const __m128d p1 = _mm_mul_pd(d_re, w);          // [dre*wre, dre*wim]
    const __m128d p2 = _mm_mul_pd(d_im, w_flip);     // [dim*wim, dim*wre]
    // SSE3 offers _mm_addsub_pd; plain SSE2 negates lane 0 of p2 with a sign
    // vector instead (_mm_set_pd lists the high lane first -> [-1, +1]).
    // Result: [dre*wre - dim*wim, dre*wim + dim*wre].
    return _mm_add_pd(p1, _mm_mul_pd(p2, _mm_set_pd(1.0, -1.0)));
}

/// Complex products of two complex<float> packed as [re0, im0, re1, im1].
inline __m128 cmul_sse2_ps(__m128 d, __m128 w) {
    const __m128 d_re = _mm_shuffle_ps(d, d, 0xA0);    // [dre0, dre0, dre1, dre1]
    const __m128 d_im = _mm_shuffle_ps(d, d, 0xF5);    // [dim0, dim0, dim1, dim1]
    const __m128 w_flip = _mm_shuffle_ps(w, w, 0xB1);  // [wim0, wre0, wim1, wre1]
    const __m128 p1 = _mm_mul_ps(d_re, w);
    const __m128 p2 = _mm_mul_ps(d_im, w_flip);
    // Sign vector lanes (low to high): [-1, +1, -1, +1].
    return _mm_add_ps(p1, _mm_mul_ps(p2, _mm_set_ps(1.0f, -1.0f, 1.0f, -1.0f)));
}

/// Processes pairs of complex<float> butterflies of one block, starting at j.
/// Returns the first j that was not processed (at most one element remains).
inline int butterfly_run_sse2_ps(std::complex<float>* block, int half_len,
                                 const std::complex<float>* tw, int j) {
    for (; j + 1 < half_len; j += 2) {
        const __m128 u = _mm_loadu_ps(reinterpret_cast<const float*>(block + j));
        const __m128 d = _mm_loadu_ps(reinterpret_cast<const float*>(block + j + half_len));
        const __m128 w = _mm_loadu_ps(reinterpret_cast<const float*>(tw + j));
        const __m128 v = cmul_sse2_ps(d, w);
        _mm_storeu_ps(reinterpret_cast<float*>(block + j), _mm_add_ps(u, v));
        _mm_storeu_ps(reinterpret_cast<float*>(block + j + half_len), _mm_sub_ps(u, v));
    }
    return j;
}

} // namespace detail

/// SSE2 stage for double: one complex (2 doubles = one __m128d) per butterfly.
inline void butterfly_stage_sse2_double(
    std::complex<double>* data, int n, int half_len, const std::complex<double>* tw)
{
    for (int i = 0; i < n; i += half_len * 2) {
        std::complex<double>* const block = data + i;
        for (int j = 0; j < half_len; ++j) {
            const __m128d u = _mm_loadu_pd(reinterpret_cast<const double*>(block + j));
            const __m128d d = _mm_loadu_pd(reinterpret_cast<const double*>(block + j + half_len));
            const __m128d w = _mm_loadu_pd(reinterpret_cast<const double*>(tw + j));
            const __m128d v = detail::cmul_sse2_pd(d, w);
            _mm_storeu_pd(reinterpret_cast<double*>(block + j), _mm_add_pd(u, v));
            _mm_storeu_pd(reinterpret_cast<double*>(block + j + half_len), _mm_sub_pd(u, v));
        }
    }
}

/// SSE2 stage for float: two complex (4 floats = one __m128) per iteration,
/// with a scalar tail when half_len is odd.
inline void butterfly_stage_sse2_float(
    std::complex<float>* data, int n, int half_len, const std::complex<float>* tw)
{
    for (int i = 0; i < n; i += half_len * 2) {
        std::complex<float>* const block = data + i;
        const int j = detail::butterfly_run_sse2_ps(block, half_len, tw, 0);
        detail::butterfly_tail_scalar(block, half_len, tw, j);
    }
}

#endif // CALX_FFT_HAS_SSE2

// =============================================================================
// 256-bit AVX butterflies
// The kernels keep their historical "avx2" names, but they only use AVX
// instructions, so they are enabled for plain AVX builds as well.
// =============================================================================
#if defined(CALX_FFT_HAS_AVX2) || defined(CALX_FFT_HAS_AVX)

namespace detail {

/// Complex products of two complex<double> packed as [re0, im0, re1, im1].
inline __m256d cmul_avx_pd(__m256d d, __m256d w) {
    const __m256d d_re = _mm256_shuffle_pd(d, d, 0x0);    // [dre0, dre0, dre1, dre1]
    const __m256d d_im = _mm256_shuffle_pd(d, d, 0xF);    // [dim0, dim0, dim1, dim1]
    const __m256d w_flip = _mm256_shuffle_pd(w, w, 0x5);  // [wim0, wre0, wim1, wre1]
    const __m256d p1 = _mm256_mul_pd(d_re, w);
    const __m256d p2 = _mm256_mul_pd(d_im, w_flip);
    return _mm256_add_pd(p1, _mm256_mul_pd(p2, _mm256_set_pd(1.0, -1.0, 1.0, -1.0)));
}

/// Complex products of four complex<float> packed as [re0, im0, ..., re3, im3].
inline __m256 cmul_avx_ps(__m256 d, __m256 w) {
    const __m256 d_re = _mm256_shuffle_ps(d, d, 0xA0);    // [re, re] per complex
    const __m256 d_im = _mm256_shuffle_ps(d, d, 0xF5);    // [im, im] per complex
    const __m256 w_flip = _mm256_shuffle_ps(w, w, 0xB1);  // [wim, wre] per complex
    const __m256 p1 = _mm256_mul_ps(d_re, w);
    const __m256 p2 = _mm256_mul_ps(d_im, w_flip);
    const __m256 sign = _mm256_set_ps(1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f);
    return _mm256_add_ps(p1, _mm256_mul_ps(p2, sign));
}

} // namespace detail

/// AVX stage for double: two complex (4 doubles = one __m256d) per iteration.
inline void butterfly_stage_avx2_double(
    std::complex<double>* data, int n, int half_len, const std::complex<double>* tw)
{
    for (int i = 0; i < n; i += half_len * 2) {
        std::complex<double>* const block = data + i;
        int j = 0;
        for (; j + 1 < half_len; j += 2) {
            const __m256d u = _mm256_loadu_pd(reinterpret_cast<const double*>(block + j));
            const __m256d d = _mm256_loadu_pd(reinterpret_cast<const double*>(block + j + half_len));
            const __m256d w = _mm256_loadu_pd(reinterpret_cast<const double*>(tw + j));
            const __m256d v = detail::cmul_avx_pd(d, w);
            _mm256_storeu_pd(reinterpret_cast<double*>(block + j), _mm256_add_pd(u, v));
            _mm256_storeu_pd(reinterpret_cast<double*>(block + j + half_len), _mm256_sub_pd(u, v));
        }
        detail::butterfly_tail_scalar(block, half_len, tw, j);
    }
}

/// AVX stage for float: four complex (8 floats = one __m256) per iteration,
/// then SSE2 pairs, then a scalar tail.
inline void butterfly_stage_avx2_float(
    std::complex<float>* data, int n, int half_len, const std::complex<float>* tw)
{
    for (int i = 0; i < n; i += half_len * 2) {
        std::complex<float>* const block = data + i;
        int j = 0;
        for (; j + 3 < half_len; j += 4) {
            const __m256 u = _mm256_loadu_ps(reinterpret_cast<const float*>(block + j));
            const __m256 d = _mm256_loadu_ps(reinterpret_cast<const float*>(block + j + half_len));
            const __m256 w = _mm256_loadu_ps(reinterpret_cast<const float*>(tw + j));
            const __m256 v = detail::cmul_avx_ps(d, w);
            _mm256_storeu_ps(reinterpret_cast<float*>(block + j), _mm256_add_ps(u, v));
            _mm256_storeu_ps(reinterpret_cast<float*>(block + j + half_len), _mm256_sub_ps(u, v));
        }
#if defined(CALX_FFT_HAS_SSE2)
        j = detail::butterfly_run_sse2_ps(block, half_len, tw, j);
#endif
        detail::butterfly_tail_scalar(block, half_len, tw, j);
    }
}

#endif // CALX_FFT_HAS_AVX2 || CALX_FFT_HAS_AVX

// =============================================================================
// Dispatch: pick the widest kernel usable for this stage
// =============================================================================

/// Runs one radix-2 stage (butterfly half-span `half_len`) with the best
/// available kernel. The checks cascade, so a stage too narrow for the wider
/// kernel still gets the narrower SIMD kernel before the scalar fallback.
template <typename Real>
inline void butterfly_stage(
    std::complex<Real>* data, int n, int half_len, const std::complex<Real>* tw)
{
    if constexpr (std::is_same_v<Real, double>) {
#if defined(CALX_FFT_HAS_AVX2) || defined(CALX_FFT_HAS_AVX)
        if (half_len >= 2) { butterfly_stage_avx2_double(data, n, half_len, tw); return; }
#endif
#if defined(CALX_FFT_HAS_SSE2)
        if (half_len >= 1) { butterfly_stage_sse2_double(data, n, half_len, tw); return; }
#endif
        butterfly_stage_scalar(data, n, half_len, tw);
    } else if constexpr (std::is_same_v<Real, float>) {
#if defined(CALX_FFT_HAS_AVX2) || defined(CALX_FFT_HAS_AVX)
        if (half_len >= 4) { butterfly_stage_avx2_float(data, n, half_len, tw); return; }
#endif
#if defined(CALX_FFT_HAS_SSE2)
        if (half_len >= 2) { butterfly_stage_sse2_float(data, n, half_len, tw); return; }
#endif
        butterfly_stage_scalar(data, n, half_len, tw);
    } else {
        butterfly_stage_scalar(data, n, half_len, tw);
    }
}

// =============================================================================
// Main FFT / IFFT (radix-2, powers of two only)
// =============================================================================

/// Forward transform, in place: X[k] = sum_t x[t] * exp(-2*pi*i*k*t/n).
/// @param data  n complex values; overwritten with the spectrum.
/// @param n     transform size; n <= 1 is a no-op.
/// @throws std::invalid_argument if n > 1 is not a power of two.
template <typename Real>
inline void fft(std::complex<Real>* data, int n) {
    if (n <= 1) return;
    detail::require_power_of_two(n);

    // Per-thread, per-Real twiddle cache. Stage tables do not depend on n,
    // so the cache is rebuilt only when a larger size is requested.
    thread_local TwiddleTable<Real> table;
    const int stages = detail::log2_floor(n);
    if (static_cast<int>(table.twiddles.size()) < stages) table.build(n);

    bit_reversal(data, n);
    for (int s = 0; s < stages; ++s) {
        butterfly_stage(data, n, 1 << s, table.twiddles[s].data());
    }
}

/// Inverse transform, in place, with 1/n normalization.
/// Uses ifft(x) = conj(fft(conj(x))) / n.
/// @throws std::invalid_argument if n > 1 is not a power of two. The size is
///         validated before `data` is touched.
template <typename Real>
inline void ifft(std::complex<Real>* data, int n) {
    if (n <= 1) return;
    detail::require_power_of_two(n);

    for (int i = 0; i < n; ++i) data[i] = std::conj(data[i]);
    fft(data, n);
    const Real inv = Real(1) / Real(n);
    for (int i = 0; i < n; ++i) data[i] = std::conj(data[i]) * inv;
}

/// Magnitudes only (analogue of JUCE's performFrequencyOnlyForwardTransform).
/// Input:  n real samples (float* / double*). The buffer must provide room for
///         2n elements (JUCE-compatible convention), because it is reused as
///         n complex values. This relies on std::complex<Real> having the
///         layout of Real[2].
/// Output: |X[k]| for k < n/2 in the first n/2 slots; the rest is unspecified.
/// @throws std::invalid_argument if n > 1 is not a power of two. The size is
///         checked before the buffer is modified.
template <typename Real>
inline void magnitude_spectrum(Real* realData, int n) {
    detail::require_power_of_two(n);

    auto* cdata = reinterpret_cast<std::complex<Real>*>(realData);
    // Widen back-to-front. cdata[i] occupies realData[2i..2i+1], which never
    // overlaps a not-yet-read realData[i'] with i' < i.
    for (int i = n - 1; i >= 0; --i) {
        cdata[i] = std::complex<Real>(realData[i], Real(0));
    }

    fft(cdata, n);

    // Front-to-back: realData[i] lies below cdata[i]'s storage (2i), so every
    // bin is read before it can be overwritten.
    const int half = n / 2;
    for (int i = 0; i < half; ++i) {
        realData[i] = std::abs(cdata[i]);
    }
}

// =============================================================================
// Compile-time SIMD level information
// =============================================================================

/// Name of the widest SIMD instruction set this header was compiled for.
constexpr const char* simd_level_name() {
#if defined(CALX_FFT_HAS_AVX2)
    return "AVX2";
#elif defined(CALX_FFT_HAS_AVX)
    return "AVX";
#elif defined(CALX_FFT_HAS_SSE2)
    return "SSE2";
#else
    return "Scalar";
#endif
}

} // namespace simd_fft
} // namespace calx

#endif // CALX_FFT_SIMD_HPP