// Copyright (C) 2026 Kiyotsugu Arai // SPDX-License-Identifier: LGPL-3.0-or-later // fft_batch.hpp // // バッチ FFT — 同一サイズの複数信号を一括処理 // // 最適化ポイント: // 1. real-pair packing: 2 実数信号 → 1 複素 FFT (stereoFFT) // 2. SIMD バッチ: 4 複素 FFT を AVX2 で同時処理 (batchFFT) // 3. 回転子テーブル共有: N 回の FFT で 1 回のみ計算 // 4. L1 キャッシュ最適化: 1024 点 × 4 信号 = 32KB // // 使い方: // // ステレオ WAV のインターリーブを直接 FFT (転置不要) // calx::simd_fft::stereoFFT(interleaved, 1024, hannWindow, // specL, specR); // // // N 個の複素信号を一括 FFT // std::complex* signals[8] = { ... }; // calx::simd_fft::batchFFT(signals, 8, 1024); #ifndef CALX_FFT_BATCH_HPP #define CALX_FFT_BATCH_HPP #include #include #include #include #include // ASM カーネル宣言 (MASM, MSVC x64) #if defined(_MSC_VER) && defined(_M_X64) && defined(CALX_FFT_HAS_ASM) extern "C" void batch4_butterfly_avx2_float( float* data, int n, int half_len, const float* tw); #define CALX_BATCH_FFT_HAS_ASM 1 #endif namespace calx { namespace simd_fft { // ===================================================================== // 内部: バッチバタフライ (C++ intrinsics fallback) // ===================================================================== namespace detail { // 4 信号インターリーブデータに対する 1 ステージバタフライ (AVX2) // data レイアウト: position k → data[k*8 .. k*8+7] // = {s0_re, s0_im, s1_re, s1_im, s2_re, s2_im, s3_re, s3_im} inline void batch4_butterfly_float_cpp( float* data, int n, int half_len, const float* tw) { #if defined(CALX_FFT_HAS_AVX2) const __m256 sign_mask = _mm256_set_ps( 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f); int len = half_len * 2; for (int i = 0; i < n; i += len) { for (int j = 0; j < half_len; ++j) { int k_upper = (i + j) * 8; int k_lower = (i + j + half_len) * 8; __m256 u = _mm256_loadu_ps(&data[k_upper]); __m256 d = _mm256_loadu_ps(&data[k_lower]); // broadcast tw[j] = {re, im} → {re,im,re,im,re,im,re,im} __m256 w = _mm256_castpd_ps( _mm256_broadcast_sd(reinterpret_cast(&tw[j * 2]))); __m256 d_re = _mm256_shuffle_ps(d, d, 0xA0); __m256 d_im = _mm256_shuffle_ps(d, d, 0xF5); __m256 w_flip = _mm256_shuffle_ps(w, w, 0xB1); __m256 p1 = _mm256_mul_ps(d_re, w); __m256 p2 = _mm256_mul_ps(d_im, w_flip); __m256 v = _mm256_add_ps(p1, _mm256_mul_ps(p2, sign_mask)); _mm256_storeu_ps(&data[k_upper], _mm256_add_ps(u, v)); _mm256_storeu_ps(&data[k_lower], _mm256_sub_ps(u, v)); } } #else // スカラーフォールバック int len = half_len * 2; for (int i = 0; i < n; i += len) { for (int j = 0; j < half_len; ++j) { float w_re = tw[j * 2]; float w_im = tw[j * 2 + 1]; for (int s = 0; s < 4; ++s) { int u_idx = (i + j) * 8 + s * 2; int d_idx = (i + j + half_len) * 8 + s * 2; float d_re = data[d_idx]; float d_im = data[d_idx + 1]; float v_re = d_re * w_re - d_im * w_im; float v_im = d_re * w_im + d_im * w_re; float u_re = data[u_idx]; float u_im = data[u_idx + 1]; data[u_idx] = u_re + v_re; data[u_idx + 1] = u_im + v_im; data[d_idx] = u_re - v_re; data[d_idx + 1] = u_im - v_im; } } } #endif } // AoS → インターリーブ転置 // 入力: signals[0..3], 各 n complex // 出力: interleaved[pos*8 + sig*2 + {re,im}] inline void transpose_to_interleaved( const std::complex* const* signals, int count, float* interleaved, int n) { for (int k = 0; k < n; ++k) { int base = k * 8; for (int s = 0; s < count; ++s) { interleaved[base + s * 2] = signals[s][k].real(); interleaved[base + s * 2 + 1] = signals[s][k].imag(); } // ゼロパディング (count < 4 の場合) for (int s = count; s < 4; ++s) { interleaved[base + s * 2] = 0.0f; interleaved[base + s * 2 + 1] = 0.0f; } } } // インターリーブ → AoS 逆転置 inline void transpose_from_interleaved( const float* interleaved, int n, std::complex** signals, int count) { for (int k = 0; k < n; ++k) { int base = k * 8; for (int s = 0; s < count; ++s) { signals[s][k] = std::complex( interleaved[base + s * 2], interleaved[base + s * 2 + 1]); } } } // インターリーブデータに対するビット反転置換 // 4 信号分の position を一括で入れ替え (各 position = 32 bytes) inline void batch_bit_reversal(float* data, int n) { for (int i = 1, j = 0; i < n; ++i) { int bit = n >> 1; while (j & bit) { j ^= bit; bit >>= 1; } j ^= bit; if (i < j) { // swap position i and j (32 bytes each) float* pi = &data[i * 8]; float* pj = &data[j * 8]; #if defined(CALX_FFT_HAS_AVX2) __m256 a = _mm256_loadu_ps(pi); __m256 b = _mm256_loadu_ps(pj); _mm256_storeu_ps(pi, b); _mm256_storeu_ps(pj, a); #else for (int k = 0; k < 8; ++k) std::swap(pi[k], pj[k]); #endif } } } } // namespace detail // ===================================================================== // batchFFT: N 個の複素 FFT を一括処理 (in-place) // ===================================================================== /// 4 信号単位でバッチ処理。端数は通常の simd_fft::fft() で処理。 /// signals: count 個の complex* (各 fftSize 個, in-place で上書き) inline void batchFFT(std::complex** signals, int count, int fftSize) { if (fftSize <= 1 || count <= 0) return; // twiddle テーブル (通常 FFT と共有) thread_local TwiddleTable table; thread_local int cached_n = 0; if (cached_n != fftSize) { table.build(fftSize); cached_n = fftSize; } // インターリーブバッファを1回だけ確保 (ループ外) // N=1024 → 32KB, N=4096 → 128KB thread_local std::vector interleaved; if ((int)interleaved.size() < fftSize * 8) interleaved.resize(fftSize * 8); // 4 信号ずつバッチ処理 int batchStart = 0; for (; batchStart + 3 < count; batchStart += 4) { const std::complex* src[4] = { signals[batchStart], signals[batchStart + 1], signals[batchStart + 2], signals[batchStart + 3] }; detail::transpose_to_interleaved(src, 4, interleaved.data(), fftSize); // ビット反転置換 detail::batch_bit_reversal(interleaved.data(), fftSize); // 全ステージのバタフライ int stage = 0; for (int len = 2; len <= fftSize; len <<= 1) { int half = len / 2; const float* tw = reinterpret_cast( table.twiddles[stage].data()); #if defined(CALX_BATCH_FFT_HAS_ASM) batch4_butterfly_avx2_float(interleaved.data(), fftSize, half, tw); #else detail::batch4_butterfly_float_cpp(interleaved.data(), fftSize, half, tw); #endif ++stage; } // 逆転置 std::complex* dst[4] = { signals[batchStart], signals[batchStart + 1], signals[batchStart + 2], signals[batchStart + 3] }; detail::transpose_from_interleaved(interleaved.data(), fftSize, dst, 4); } // 端数は通常 FFT for (int i = batchStart; i < count; ++i) { fft(signals[i], fftSize); } } // ===================================================================== // batchIFFT: N 個の複素逆 FFT を一括処理 (in-place) // ===================================================================== inline void batchIFFT(std::complex** signals, int count, int fftSize) { // 共役 → batchFFT → 共役 + 1/N for (int i = 0; i < count; ++i) { for (int k = 0; k < fftSize; ++k) signals[i][k] = std::conj(signals[i][k]); } batchFFT(signals, count, fftSize); float inv = 1.0f / static_cast(fftSize); for (int i = 0; i < count; ++i) { for (int k = 0; k < fftSize; ++k) signals[i][k] = std::conj(signals[i][k]) * inv; } } // ===================================================================== // stereoFFT: ステレオインターリーブ WAV → L/R スペクトル // ===================================================================== /// interleaved: {L0,R0,L1,R1,...} の fftSize*2 個の float /// window: Hann 窓等 (fftSize 個の float), nullptr なら窓なし /// outL, outR: 各 fftSize/2+1 個の complex (Hermitian 半分) inline void stereoFFT( const float* interleaved, int fftSize, const float* window, std::complex* outL, std::complex* outR) { // interleaved を complex として窓適用 std::vector> z(fftSize); if (window) { for (int k = 0; k < fftSize; ++k) { z[k] = std::complex( interleaved[k * 2] * window[k], // L * window interleaved[k * 2 + 1] * window[k]); // R * window } } else { for (int k = 0; k < fftSize; ++k) { z[k] = std::complex( interleaved[k * 2], interleaved[k * 2 + 1]); } } // 1 回の複素 FFT fft(z.data(), fftSize); // L/R 分離 (Hermitian 対称性を利用) int half = fftSize / 2; // DC bin (k=0): Z[0] のみ, Z[N] は Z[0] と同じ outL[0] = std::complex(z[0].real(), 0.0f); outR[0] = std::complex(z[0].imag(), 0.0f); // Nyquist bin (k=N/2) outL[half] = std::complex(z[half].real(), 0.0f); outR[half] = std::complex(z[half].imag(), 0.0f); // k = 1..N/2-1 for (int k = 1; k < half; ++k) { auto zk = z[k]; auto znk = std::conj(z[fftSize - k]); // L[k] = (Z[k] + conj(Z[N-k])) / 2 outL[k] = (zk + znk) * 0.5f; // R[k] = (Z[k] - conj(Z[N-k])) / (2i) // = (Z[k] - conj(Z[N-k])) * (-i/2) auto diff = zk - znk; outR[k] = std::complex(diff.imag() * 0.5f, -diff.real() * 0.5f); } } // ===================================================================== // batchRealFFT: N 個の実数信号を一括 FFT (ペアパッキング + バッチ) // ===================================================================== /// realSignals: count 個の float* (各 fftSize 個) /// outSpectra: count 個の complex* (各 fftSize/2+1 個) /// window: 窓関数 (fftSize 個), nullptr なら窓なし inline void batchRealFFT( const float* const* realSignals, int count, int fftSize, std::complex** outSpectra, const float* window = nullptr) { int half = fftSize / 2; int pairs = count / 2; // 2 信号ずつペアにして複素 FFT std::vector> z(fftSize); std::vector*> complexSignals; for (int p = 0; p < pairs; ++p) { int i0 = p * 2; int i1 = p * 2 + 1; // 2 つの実数信号を 1 つの複素信号にパック if (window) { for (int k = 0; k < fftSize; ++k) { z[k] = std::complex( realSignals[i0][k] * window[k], realSignals[i1][k] * window[k]); } } else { for (int k = 0; k < fftSize; ++k) { z[k] = std::complex(realSignals[i0][k], realSignals[i1][k]); } } fft(z.data(), fftSize); // 分離 outSpectra[i0][0] = std::complex(z[0].real(), 0.0f); outSpectra[i1][0] = std::complex(z[0].imag(), 0.0f); outSpectra[i0][half] = std::complex(z[half].real(), 0.0f); outSpectra[i1][half] = std::complex(z[half].imag(), 0.0f); for (int k = 1; k < half; ++k) { auto zk = z[k]; auto znk = std::conj(z[fftSize - k]); outSpectra[i0][k] = (zk + znk) * 0.5f; auto diff = zk - znk; outSpectra[i1][k] = std::complex(diff.imag() * 0.5f, -diff.real() * 0.5f); } } // 奇数の場合、最後の 1 つは通常の実数 FFT if (count & 1) { int last = count - 1; if (window) { for (int k = 0; k < fftSize; ++k) z[k] = std::complex(realSignals[last][k] * window[k], 0.0f); } else { for (int k = 0; k < fftSize; ++k) z[k] = std::complex(realSignals[last][k], 0.0f); } fft(z.data(), fftSize); for (int k = 0; k <= half; ++k) outSpectra[last][k] = z[k]; } } } // namespace simd_fft } // namespace calx #endif // CALX_FFT_BATCH_HPP