// Copyright (C) 2026 Kiyotsugu Arai
// SPDX-License-Identifier: LGPL-3.0-or-later
// DoubleFft.hpp
// split-complex radix-4 FFT (double)
// YC-8a: scalar baseline + right-angle convolution (negacyclic DWT)
// YC-8d: error detection (mod 2^61-1 hash + coefficient bound check) + NTT fallback
#pragma once

#include <cstdint>

#ifndef CALX_FORCEINLINE
#ifdef _MSC_VER
#define CALX_FORCEINLINE __forceinline
#else
#define CALX_FORCEINLINE __attribute__((always_inline)) inline
#endif
#endif

#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstring>
#include <utility>
#include <vector>

#ifdef _MSC_VER
#include <intrin.h> // _umul128
#endif

namespace calx {
namespace double_fft {

static constexpr double FFT_PI = 3.141592653589793238462643383279502884;

// --- split-complex multiply ---
// (a_re + i*a_im) * (b_re + i*b_im)
inline void cmul(double a_re, double a_im, double b_re, double b_im,
                 double& out_re, double& out_im) {
    out_re = a_re * b_re - a_im * b_im;
    out_im = a_re * b_im + a_im * b_re;
}

// --- twiddle + weight table (layer-ordered) ---
struct FftRootsDouble {
    // radix-4 twiddle (w1 only; w2, w3 are computed at runtime)
    std::vector<double> fwd_cos;   // forward: cos(-2*pi*j/len)
    std::vector<double> fwd_sin;   // forward: sin(-2*pi*j/len)
    std::vector<double> inv_cos;   // inverse: cos(+2*pi*j/len)
    std::vector<double> inv_sin;   // inverse: sin(+2*pi*j/len)
    std::vector<size_t> r4_offset; // offset of each radix-4 stage
    size_t num_r4_stages = 0;
    bool has_r2_final = false;     // log2(N) odd => final radix-2
    size_t N = 0;
    // right-angle convolution weights: e^{i*pi*j/N}
    std::vector<double> wt_re; // cos(pi*j/N)
    std::vector<double> wt_im; // sin(pi*j/N)

    // Precompute all twiddles and weights for FFT length fft_len
    // (must be a power of two, >= 4).
    void build(size_t fft_len) {
        N = fft_len;
        assert(N >= 4 && (N & (N - 1)) == 0);
        int log2N = 0;
        for (size_t t = N; t > 1; t >>= 1) ++log2N;
        has_r2_final = (log2N & 1) != 0;
        num_r4_stages = static_cast<size_t>(log2N / 2);
        // --- radix-4 roots ---
        size_t total = 0;
        r4_offset.resize(num_r4_stages);
        {
            size_t len = N;
            for (size_t s = 0; s < num_r4_stages; ++s) {
                r4_offset[s] = total;
                total += len / 4;
                len /= 4;
            }
        }
        fwd_cos.resize(total); fwd_sin.resize(total);
        inv_cos.resize(total); inv_sin.resize(total);
        {
            size_t len = N;
            for (size_t s = 0; s < num_r4_stages; ++s) {
                size_t quarter = len / 4;
                size_t off = r4_offset[s];
                for (size_t j = 0; j < quarter; ++j) {
                    double angle = -2.0 * FFT_PI * static_cast<double>(j) /
                                   static_cast<double>(len);
                    fwd_cos[off + j] = std::cos(angle);
                    fwd_sin[off + j] = std::sin(angle);
                    inv_cos[off + j] = std::cos(-angle);
                    inv_sin[off + j] = std::sin(-angle);
                }
                len /= 4;
            }
        }
        // --- right-angle weights ---
        wt_re.resize(N); wt_im.resize(N);
        for (size_t j = 0; j < N; ++j) {
            double angle = FFT_PI * static_cast<double>(j) / static_cast<double>(N);
            wt_re[j] = std::cos(angle);
            wt_im[j] = std::sin(angle);
        }
    }
};

// ================================================================
// Forward FFT (radix-4 DIF, split-complex, in-place)
// ================================================================
//
// DIF stage at group size `len`:
//   a = x[j], b = x[j+Q], c = x[j+2Q], d = x[j+3Q]   (Q = len/4)
//   u = a+c, v = a-c, s = b+d, t = b-d
//   w1 = e^{-2*pi*i*j/len}, w2 = w1^2, w3 = w1^3
//
//   x[j]    = u + s
//   x[j+Q]  = (v - i*t) * w1
//   x[j+2Q] = (u - s)   * w2
//   x[j+3Q] = (v + i*t) * w3
//
// Output is in digit-reversed order; the matching inverse undoes it.
inline void forward_fft_double(double* re, double* im, size_t N,
                               const FftRootsDouble& roots) {
    assert(N == roots.N);
    // --- radix-4 DIF stages (len = N, N/4, ...) ---
    size_t len = N;
    for (size_t stage = 0; stage < roots.num_r4_stages; ++stage) {
        const size_t quarter = len / 4;
        const double* wc = &roots.fwd_cos[roots.r4_offset[stage]];
        const double* ws = &roots.fwd_sin[roots.r4_offset[stage]];
        for (size_t grp = 0; grp < N; grp += len) {
            for (size_t j = 0; j < quarter; ++j) {
                const size_t i0 = grp + j;
                const size_t i1 = i0 + quarter;
                const size_t i2 = i0 + 2 * quarter;
                const size_t i3 = i0 + 3 * quarter;
                const double ar = re[i0], ai = im[i0];
                const double br = re[i1], bi = im[i1];
                const double cr = re[i2], ci = im[i2];
                const double dr = re[i3], di = im[i3];
                const double ur = ar + cr, ui = ai + ci;
                const double vr = ar - cr, vi = ai - ci;
                const double sr = br + dr, si = bi + di;
                const double tr = br - dr, ti = bi - di;
                // output 0: u + s
                re[i0] = ur + sr; im[i0] = ui + si;
                // output 1: (v - i*t) * w1    [-i*t = (t_im, -t_re)]
                const double e1r = vr + ti;
                const double e1i = vi - tr;
                const double w1r = wc[j], w1i = ws[j];
                cmul(e1r, e1i, w1r, w1i, re[i1], im[i1]);
                // w2 = w1^2
                double w2r, w2i;
                cmul(w1r, w1i, w1r, w1i, w2r, w2i);
                // output 2: (u - s) * w2
                const double e2r = ur - sr;
                const double e2i = ui - si;
                cmul(e2r, e2i, w2r, w2i, re[i2], im[i2]);
                // w3 = w1 * w2
                double w3r, w3i;
                cmul(w1r, w1i, w2r, w2i, w3r, w3i);
                // output 3: (v + i*t) * w3    [i*t = (-t_im, t_re)]
                const double e3r = vr - ti;
                const double e3i = vi + tr;
                cmul(e3r, e3i, w3r, w3i, re[i3], im[i3]);
            }
        }
        len /= 4;
    }
    // --- final radix-2 (len=2) if log2(N) is odd ---
    if (roots.has_r2_final) {
        for (size_t i = 0; i < N; i += 2) {
            const double ar = re[i], ai = im[i];
            const double br = re[i + 1], bi = im[i + 1];
            re[i]     = ar + br; im[i]     = ai + bi;
            re[i + 1] = ar - br; im[i + 1] = ai - bi;
        }
    }
}

// ================================================================
// Inverse FFT (radix-4 DIT, split-complex, in-place, includes 1/N)
// ================================================================
//
// DIT stage at group size `len`:
//   w1 = e^{+2*pi*i*j/len}, w2 = w1^2, w3 = w1^3
//   b' = x[j+Q]*w1, c' = x[j+2Q]*w2, d' = x[j+3Q]*w3
//
//   u = a+c', v = a-c', s = b'+d', t = b'-d'
//   x[j]    = u + s
//   x[j+Q]  = v + i*t
//   x[j+2Q] = u - s
//   x[j+3Q] = v - i*t
//
// Consumes the digit-reversed order produced by forward_fft_double.
inline void inverse_fft_double(double* re, double* im, size_t N,
                               const FftRootsDouble& roots) {
    assert(N == roots.N);
    // --- initial radix-2 (len=2) if log2(N) is odd ---
    size_t start_len;
    if (roots.has_r2_final) {
        for (size_t i = 0; i < N; i += 2) {
            const double ar = re[i], ai = im[i];
            const double br = re[i + 1], bi = im[i + 1];
            re[i]     = ar + br; im[i]     = ai + bi;
            re[i + 1] = ar - br; im[i + 1] = ai - bi;
        }
        start_len = 8;
    } else {
        start_len = 4;
    }
    // --- radix-4 DIT stages (small to large) ---
    size_t len = start_len;
    for (int si = static_cast<int>(roots.num_r4_stages) - 1; si >= 0;
         --si, len *= 4) {
        const size_t quarter = len / 4;
        const double* wc = &roots.inv_cos[roots.r4_offset[si]];
        const double* ws = &roots.inv_sin[roots.r4_offset[si]];
        for (size_t grp = 0; grp < N; grp += len) {
            for (size_t j = 0; j < quarter; ++j) {
                const size_t i0 = grp + j;
                const size_t i1 = i0 + quarter;
                const size_t i2 = i0 + 2 * quarter;
                const size_t i3 = i0 + 3 * quarter;
                // twiddle on inputs (DIT)
                const double ar = re[i0], ai = im[i0];
                const double w1r = wc[j], w1i = ws[j];
                double br, bi;
                cmul(re[i1], im[i1], w1r, w1i, br, bi);
                double w2r, w2i;
                cmul(w1r, w1i, w1r, w1i, w2r, w2i);
                double cr, ci;
                cmul(re[i2], im[i2], w2r, w2i, cr, ci);
                double w3r, w3i;
                cmul(w1r, w1i, w2r, w2i, w3r, w3i);
                double dr, di;
                cmul(re[i3], im[i3], w3r, w3i, dr, di);
                const double ur = ar + cr, ui = ai + ci;
                const double vr = ar - cr, vi = ai - ci;
                const double sr = br + dr, si2 = bi + di;
                const double tr = br - dr, ti = bi - di;
                re[i0] = ur + sr; im[i0] = ui + si2;
                // v + i*t: i*t = (-t_im, t_re)
                re[i1] = vr - ti; im[i1] = vi + tr;
                re[i2] = ur - sr; im[i2] = ui - si2;
                // v - i*t: -i*t = (t_im, -t_re)
                re[i3] = vr + ti; im[i3] = vi - tr;
            }
        }
    }
    // --- 1/N scaling ---
    const double inv_N = 1.0 / static_cast<double>(N);
    for (size_t i = 0; i < N; ++i) {
        re[i] *= inv_N;
        im[i] *= inv_N;
    }
}

// ================================================================
// AVX2 FMA radix-4 FFT (YC-8c)
// ================================================================
#ifdef __AVX2__
} } // namespace double_fft, calx

#include <immintrin.h>

namespace calx {
namespace double_fft {

static constexpr size_t FFT_CACHE_BLOCK = 1024;

// --- 4-wide complex multiply via FMA ---
CALX_FORCEINLINE void cmul_avx2(__m256d ar, __m256d ai, __m256d br, __m256d bi,
                                __m256d& or_, __m256d& oi) {
    // or = ar*br - ai*bi, oi = ar*bi + ai*br
    or_ = _mm256_fmsub_pd(ar, br, _mm256_mul_pd(ai, bi));
    oi  = _mm256_fmadd_pd(ar, bi, _mm256_mul_pd(ai, br));
}

// --- DIF radix-4 butterfly: 4 consecutive j at once ---
CALX_FORCEINLINE void dif_r4_avx2(double* re, double* im, size_t base,
                                  size_t quarter, __m256d w1r, __m256d w1i) {
    const size_t p1 = base + quarter;
    const size_t p2 = base + 2 * quarter;
    const size_t p3 = base + 3 * quarter;
    __m256d ar = _mm256_loadu_pd(re + base); __m256d ai = _mm256_loadu_pd(im + base);
    __m256d br = _mm256_loadu_pd(re + p1);   __m256d bi = _mm256_loadu_pd(im + p1);
    __m256d cr = _mm256_loadu_pd(re + p2);   __m256d ci = _mm256_loadu_pd(im + p2);
    __m256d dr = _mm256_loadu_pd(re + p3);   __m256d di = _mm256_loadu_pd(im + p3);
    __m256d ur = _mm256_add_pd(ar, cr), ui = _mm256_add_pd(ai, ci);
    __m256d vr = _mm256_sub_pd(ar, cr), vi = _mm256_sub_pd(ai, ci);
    __m256d sr = _mm256_add_pd(br, dr), si = _mm256_add_pd(bi, di);
    __m256d tr = _mm256_sub_pd(br, dr), ti = _mm256_sub_pd(bi, di);
    // out0 = u + s
    _mm256_storeu_pd(re + base, _mm256_add_pd(ur, sr));
    _mm256_storeu_pd(im + base, _mm256_add_pd(ui, si));
    // out1 = (v - i*t) * w1    [-i*t: re=ti, im=-tr]
    __m256d e1r = _mm256_add_pd(vr, ti);
    __m256d e1i = _mm256_sub_pd(vi, tr);
    __m256d o1r, o1i;
    cmul_avx2(e1r, e1i, w1r, w1i, o1r, o1i);
    _mm256_storeu_pd(re + p1, o1r);
    _mm256_storeu_pd(im + p1, o1i);
    // w2 = w1^2
    __m256d w2r, w2i;
    cmul_avx2(w1r, w1i, w1r, w1i, w2r, w2i);
    // out2 = (u - s) * w2
    __m256d e2r = _mm256_sub_pd(ur, sr);
    __m256d e2i = _mm256_sub_pd(ui, si);
    __m256d o2r, o2i;
    cmul_avx2(e2r, e2i, w2r, w2i, o2r, o2i);
    _mm256_storeu_pd(re + p2, o2r);
    _mm256_storeu_pd(im + p2, o2i);
    // w3 = w1 * w2
    __m256d w3r, w3i;
    cmul_avx2(w1r, w1i, w2r, w2i, w3r, w3i);
    // out3 = (v + i*t) * w3    [i*t: re=-ti, im=tr]
    __m256d e3r = _mm256_sub_pd(vr, ti);
    __m256d e3i = _mm256_add_pd(vi, tr);
    __m256d o3r, o3i;
    cmul_avx2(e3r, e3i, w3r, w3i, o3r, o3i);
    _mm256_storeu_pd(re + p3, o3r);
    _mm256_storeu_pd(im + p3, o3i);
}

// --- DIT radix-4 butterfly: 4 consecutive j at once ---
CALX_FORCEINLINE void dit_r4_avx2(double* re, double* im, size_t base,
                                  size_t quarter, __m256d w1r, __m256d w1i) {
    const size_t p1 = base + quarter;
    const size_t p2 = base + 2 * quarter;
    const size_t p3 = base + 3 * quarter;
    __m256d ar = _mm256_loadu_pd(re + base);
    __m256d ai = _mm256_loadu_pd(im + base);
    // twiddle on inputs
    __m256d br, bi;
    cmul_avx2(_mm256_loadu_pd(re + p1), _mm256_loadu_pd(im + p1), w1r, w1i, br, bi);
    __m256d w2r, w2i;
    cmul_avx2(w1r, w1i, w1r, w1i, w2r, w2i);
    __m256d cr, ci;
    cmul_avx2(_mm256_loadu_pd(re + p2), _mm256_loadu_pd(im + p2), w2r, w2i, cr, ci);
    __m256d w3r, w3i;
    cmul_avx2(w1r, w1i, w2r, w2i, w3r, w3i);
    __m256d dr, di;
    cmul_avx2(_mm256_loadu_pd(re + p3), _mm256_loadu_pd(im + p3), w3r, w3i, dr, di);
    __m256d ur = _mm256_add_pd(ar, cr), ui = _mm256_add_pd(ai, ci);
    __m256d vr = _mm256_sub_pd(ar, cr), vi = _mm256_sub_pd(ai, ci);
    __m256d sr = _mm256_add_pd(br, dr), si = _mm256_add_pd(bi, di);
    __m256d tr = _mm256_sub_pd(br, dr), ti = _mm256_sub_pd(bi, di);
    _mm256_storeu_pd(re + base, _mm256_add_pd(ur, sr));
    _mm256_storeu_pd(im + base, _mm256_add_pd(ui, si));
    // v + i*t: [i*t: re=-ti, im=tr]
    _mm256_storeu_pd(re + p1, _mm256_sub_pd(vr, ti));
    _mm256_storeu_pd(im + p1, _mm256_add_pd(vi, tr));
    _mm256_storeu_pd(re + p2, _mm256_sub_pd(ur, sr));
    _mm256_storeu_pd(im + p2, _mm256_sub_pd(ui, si));
    // v - i*t: [-i*t: re=ti, im=-tr]
    _mm256_storeu_pd(re + p3, _mm256_add_pd(vr, ti));
    _mm256_storeu_pd(im + p3, _mm256_sub_pd(vi, tr));
}

// --- one radix-4 DIF pass over the whole array using AVX2 + scalar tail ---
inline void dif_r4_pass_avx2(double* re, double* im, size_t N, size_t len,
                             const double* wc, const double* ws) {
    const size_t quarter = len / 4;
    for (size_t grp = 0; grp < N; grp += len) {
        size_t j = 0;
        for (; j + 4 <= quarter; j += 4) {
            __m256d w1r = _mm256_loadu_pd(wc + j);
            __m256d w1i = _mm256_loadu_pd(ws + j);
            dif_r4_avx2(re, im, grp + j, quarter, w1r, w1i);
        }
        // scalar tail
        for (; j < quarter; ++j) {
            const size_t i0 = grp + j;
            const size_t i1 = i0 + quarter;
            const size_t i2 = i0 + 2 * quarter;
            const size_t i3 = i0 + 3 * quarter;
            const double ar = re[i0], ai = im[i0];
            const double br = re[i1], bi = im[i1];
            const double cr = re[i2], ci = im[i2];
            const double dr = re[i3], di = im[i3];
            const double ur = ar + cr, ui = ai + ci;
            const double vr = ar - cr, vi = ai - ci;
            const double sr = br + dr, si = bi + di;
            const double tr = br - dr, ti = bi - di;
            re[i0] = ur + sr; im[i0] = ui + si;
            double e1r = vr + ti, e1i = vi - tr;
            double w1r_ = wc[j], w1i_ = ws[j];
            cmul(e1r, e1i, w1r_, w1i_, re[i1], im[i1]);
            double w2r_, w2i_;
            cmul(w1r_, w1i_, w1r_, w1i_, w2r_, w2i_);
            double e2r = ur - sr, e2i = ui - si;
            cmul(e2r, e2i, w2r_, w2i_, re[i2], im[i2]);
            double w3r_, w3i_;
            cmul(w1r_, w1i_, w2r_, w2i_, w3r_, w3i_);
            double e3r = vr - ti, e3i = vi + tr;
            cmul(e3r, e3i, w3r_, w3i_, re[i3], im[i3]);
        }
    }
}

// --- one radix-4 DIT pass over the whole array using AVX2 + scalar tail ---
inline void dit_r4_pass_avx2(double* re, double* im, size_t N, size_t len,
                             const double* wc, const double* ws) {
    const size_t quarter = len / 4;
    for (size_t grp = 0; grp < N; grp += len) {
        size_t j = 0;
        for (; j + 4 <= quarter; j += 4) {
            __m256d w1r = _mm256_loadu_pd(wc + j);
            __m256d w1i = _mm256_loadu_pd(ws + j);
            dit_r4_avx2(re, im, grp + j, quarter, w1r, w1i);
        }
        // scalar tail
        for (; j < quarter; ++j) {
            const size_t i0 = grp + j;
            const size_t i1 = i0 + quarter;
            const size_t i2 = i0 + 2 * quarter;
            const size_t i3 = i0 + 3 * quarter;
            const double ar = re[i0], ai = im[i0];
            double w1r_ = wc[j], w1i_ = ws[j];
            double br, bi;
            cmul(re[i1], im[i1], w1r_, w1i_, br, bi);
            double w2r_, w2i_;
            cmul(w1r_, w1i_, w1r_, w1i_, w2r_, w2i_);
            double cr, ci;
            cmul(re[i2], im[i2], w2r_, w2i_, cr, ci);
            double w3r_, w3i_;
            cmul(w1r_, w1i_, w2r_, w2i_, w3r_, w3i_);
            double dr, di;
            cmul(re[i3], im[i3], w3r_, w3i_, dr, di);
            const double ur = ar + cr, ui = ai + ci;
            const double vr = ar - cr, vi = ai - ci;
            const double sr = br + dr, si = bi + di;
            const double tr = br - dr, ti = bi - di;
            re[i0] = ur + sr; im[i0] = ui + si;
            re[i1] = vr - ti; im[i1] = vi + tr;
            re[i2] = ur - sr; im[i2] = ui - si;
            re[i3] = vr + ti; im[i3] = vi - tr;
        }
    }
}

// --- Forward FFT (AVX2, cache-blocked) ---
inline void forward_fft_double_avx2(double* re, double* im, size_t N,
                                    const FftRootsDouble& roots) {
    assert(N == roots.N);
    const size_t block = (N <= FFT_CACHE_BLOCK) ? N : FFT_CACHE_BLOCK;
    // Phase 1: large stages (len > block) — full passes
    size_t stage = 0;
    size_t len = N;
    for (; stage < roots.num_r4_stages && len > block; ++stage) {
        dif_r4_pass_avx2(re, im, N, len,
                         &roots.fwd_cos[roots.r4_offset[stage]],
                         &roots.fwd_sin[roots.r4_offset[stage]]);
        len /= 4;
    }
    // Phase 2: small stages (len <= block) — block by block
    const size_t bottom_start = stage;
    for (size_t blk = 0; blk < N; blk += block) {
        size_t s2 = bottom_start;
        size_t len2 = len;
        for (; s2 < roots.num_r4_stages; ++s2) {
            const size_t quarter = len2 / 4;
            const double* wc = &roots.fwd_cos[roots.r4_offset[s2]];
            const double* ws = &roots.fwd_sin[roots.r4_offset[s2]];
            for (size_t grp = blk; grp < blk + block; grp += len2) {
                size_t j = 0;
                for (; j + 4 <= quarter; j += 4) {
                    __m256d w1r = _mm256_loadu_pd(wc + j);
                    __m256d w1i = _mm256_loadu_pd(ws + j);
                    dif_r4_avx2(re, im, grp + j, quarter, w1r, w1i);
                }
                for (; j < quarter; ++j) {
                    const size_t i0 = grp + j;
                    const size_t i1 = i0 + quarter;
                    const size_t i2 = i0 + 2 * quarter;
                    const size_t i3 = i0 + 3 * quarter;
                    const double ar = re[i0], ai = im[i0], br = re[i1], bi = im[i1];
                    const double cr = re[i2], ci = im[i2], dr = re[i3], di = im[i3];
                    const double ur = ar + cr, ui = ai + ci, vr = ar - cr, vi = ai - ci;
                    const double sr = br + dr, si = bi + di, tr = br - dr, ti = bi - di;
                    re[i0] = ur + sr; im[i0] = ui + si;
                    double w1r_ = wc[j], w1i_ = ws[j];
                    double e1r = vr + ti, e1i = vi - tr;
                    cmul(e1r, e1i, w1r_, w1i_, re[i1], im[i1]);
                    double w2r_, w2i_;
                    cmul(w1r_, w1i_, w1r_, w1i_, w2r_, w2i_);
                    cmul(ur - sr, ui - si, w2r_, w2i_, re[i2], im[i2]);
                    double w3r_, w3i_;
                    cmul(w1r_, w1i_, w2r_, w2i_, w3r_, w3i_);
                    cmul(vr - ti, vi + tr, w3r_, w3i_, re[i3], im[i3]);
                }
            }
            len2 /= 4;
        }
        // optional radix-2 within block
        if (roots.has_r2_final) {
            for (size_t i = blk; i < blk + block; i += 2) {
                double ar = re[i], ai = im[i];
                double br = re[i + 1], bi = im[i + 1];
                re[i]     = ar + br; im[i]     = ai + bi;
                re[i + 1] = ar - br; im[i + 1] = ai - bi;
            }
        }
    }
}

// --- Inverse FFT (AVX2, cache-blocked) ---
inline void inverse_fft_double_avx2(double* re, double* im, size_t N,
                                    const FftRootsDouble& roots) {
    assert(N == roots.N);
    const size_t block = (N <= FFT_CACHE_BLOCK) ? N : FFT_CACHE_BLOCK;
    // Determine which stages are "bottom" (len <= block).
    // Inverse DIT: stages go small -> large.
    // stage_idx in r4_offset: 0 = len N, 1 = len N/4, ...
    // DIT processes from num_r4_stages-1 down to 0.
    // Bottom = those with small len, i.e., large stage_idx.
    // Compute start_len for DIT (smallest radix-4 group).
    size_t start_len = roots.has_r2_final ? 8 : 4;
    // Find the boundary: bottom stages have len <= block.
    // stage_idx s corresponds to len = N / 4^s;
    // len <= block <=> s >= log4(N/block)
    int bottom_start_idx = static_cast<int>(roots.num_r4_stages) - 1;
    {
        size_t l = start_len;
        for (int s = static_cast<int>(roots.num_r4_stages) - 1; s >= 0; --s) {
            if (l > block) { bottom_start_idx = s + 1; break; }
            if (s == 0)    { bottom_start_idx = 0; break; }
            l *= 4;
        }
    }
    // Phase 1: bottom stages (small len, within blocks)
    for (size_t blk = 0; blk < N; blk += block) {
        // radix-2 within block first (if needed)
        if (roots.has_r2_final) {
            for (size_t i = blk; i < blk + block; i += 2) {
                double ar = re[i], ai = im[i];
                double br = re[i + 1], bi = im[i + 1];
                re[i]     = ar + br; im[i]     = ai + bi;
                re[i + 1] = ar - br; im[i + 1] = ai - bi;
            }
        }
        // radix-4 DIT within block
        size_t len2 = start_len;
        for (int s2 = static_cast<int>(roots.num_r4_stages) - 1;
             s2 >= bottom_start_idx; --s2) {
            const size_t quarter = len2 / 4;
            const double* wc = &roots.inv_cos[roots.r4_offset[s2]];
            const double* ws = &roots.inv_sin[roots.r4_offset[s2]];
            for (size_t grp = blk; grp < blk + block; grp += len2) {
                size_t j = 0;
                for (; j + 4 <= quarter; j += 4) {
                    __m256d w1r = _mm256_loadu_pd(wc + j);
                    __m256d w1i = _mm256_loadu_pd(ws + j);
                    dit_r4_avx2(re, im, grp + j, quarter, w1r, w1i);
                }
                for (; j < quarter; ++j) {
                    const size_t i0 = grp + j, i1 = i0 + quarter;
                    const size_t i2 = i0 + 2 * quarter, i3 = i0 + 3 * quarter;
                    double ar = re[i0], ai = im[i0];
                    double w1r_ = wc[j], w1i_ = ws[j];
                    double br, bi;
                    cmul(re[i1], im[i1], w1r_, w1i_, br, bi);
                    double w2r_, w2i_;
                    cmul(w1r_, w1i_, w1r_, w1i_, w2r_, w2i_);
                    double cr, ci;
                    cmul(re[i2], im[i2], w2r_, w2i_, cr, ci);
                    double w3r_, w3i_;
                    cmul(w1r_, w1i_, w2r_, w2i_, w3r_, w3i_);
                    double dr, di;
                    cmul(re[i3], im[i3], w3r_, w3i_, dr, di);
                    double ur = ar + cr, ui = ai + ci, vr = ar - cr, vi = ai - ci;
                    double sr = br + dr, si = bi + di, tr = br - dr, ti = bi - di;
                    re[i0] = ur + sr; im[i0] = ui + si;
                    re[i1] = vr - ti; im[i1] = vi + tr;
                    re[i2] = ur - sr; im[i2] = ui - si;
                    re[i3] = vr + ti; im[i3] = vi - tr;
                }
            }
            len2 *= 4;
        }
    }
    // Phase 2: top stages (large len, full passes)
    size_t len_top = start_len;
    for (int s = static_cast<int>(roots.num_r4_stages) - 1; s >= 0; --s)
        if (s < bottom_start_idx) { break; } else { len_top *= 4; }
    for (int s = bottom_start_idx - 1; s >= 0; --s) {
        dit_r4_pass_avx2(re, im, N, len_top,
                         &roots.inv_cos[roots.r4_offset[s]],
                         &roots.inv_sin[roots.r4_offset[s]]);
        len_top *= 4;
    }
    // 1/N scaling (AVX2)
    const __m256d inv_N = _mm256_set1_pd(1.0 / static_cast<double>(N));
    size_t i = 0;
    for (; i + 4 <= N; i += 4) {
        _mm256_storeu_pd(re + i, _mm256_mul_pd(_mm256_loadu_pd(re + i), inv_N));
        _mm256_storeu_pd(im + i, _mm256_mul_pd(_mm256_loadu_pd(im + i), inv_N));
    }
    for (; i < N; ++i) {
        re[i] /= static_cast<double>(N);
        im[i] /= static_cast<double>(N);
    }
}
#endif // __AVX2__

// ================================================================
// Right-angle convolution: weight / unweight
// ================================================================
// weight[j]   = e^{i*pi*j/N}
// unweight[j] = conj(weight[j])
inline void apply_weight(double* re, double* im, size_t N,
                         const FftRootsDouble& roots) {
    assert(N == roots.N);
#ifdef __AVX2__
    size_t j = 0;
    for (; j + 4 <= N; j += 4) {
        __m256d xr = _mm256_loadu_pd(re + j);
        __m256d xi = _mm256_loadu_pd(im + j);
        __m256d wr = _mm256_loadu_pd(roots.wt_re.data() + j);
        __m256d wi = _mm256_loadu_pd(roots.wt_im.data() + j);
        __m256d or_, oi;
        cmul_avx2(xr, xi, wr, wi, or_, oi);
        _mm256_storeu_pd(re + j, or_);
        _mm256_storeu_pd(im + j, oi);
    }
    for (; j < N; ++j) {
        double xr = re[j], xi = im[j];
        cmul(xr, xi, roots.wt_re[j], roots.wt_im[j], re[j], im[j]);
    }
#else
    for (size_t j = 0; j < N; ++j) {
        double xr = re[j], xi = im[j];
        cmul(xr, xi, roots.wt_re[j], roots.wt_im[j], re[j], im[j]);
    }
#endif
}

inline void apply_unweight(double* re, double* im, size_t N,
                           const FftRootsDouble& roots) {
    assert(N == roots.N);
#ifdef __AVX2__
    size_t j = 0;
    for (; j + 4 <= N; j += 4) {
        __m256d xr = _mm256_loadu_pd(re + j);
        __m256d xi = _mm256_loadu_pd(im + j);
        __m256d wr = _mm256_loadu_pd(roots.wt_re.data() + j);
        // conj(w) = (wt_re, -wt_im): flip the sign bit of the imaginary part
        __m256d wi = _mm256_xor_pd(_mm256_loadu_pd(roots.wt_im.data() + j),
                                   _mm256_set1_pd(-0.0));
        __m256d or_, oi;
        cmul_avx2(xr, xi, wr, wi, or_, oi);
        _mm256_storeu_pd(re + j, or_);
        _mm256_storeu_pd(im + j, oi);
    }
    for (; j < N; ++j) {
        double xr = re[j], xi = im[j];
        cmul(xr, xi, roots.wt_re[j], -roots.wt_im[j], re[j], im[j]);
    }
#else
    for (size_t j = 0; j < N; ++j) {
        double xr = re[j], xi = im[j];
        // conj(w) = (wt_re, -wt_im)
        cmul(xr, xi, roots.wt_re[j], -roots.wt_im[j], re[j], im[j]);
    }
#endif
}

// ================================================================
// Pointwise operations
// ================================================================
inline void pointwise_mul(const double* a_re, const double* a_im,
                          const double* b_re, const double* b_im,
                          double* c_re, double* c_im, size_t N) {
#ifdef __AVX2__
    size_t i = 0;
    for (; i + 4 <= N; i += 4) {
        __m256d or_, oi;
        cmul_avx2(_mm256_loadu_pd(a_re + i), _mm256_loadu_pd(a_im + i),
                  _mm256_loadu_pd(b_re + i), _mm256_loadu_pd(b_im + i), or_, oi);
        _mm256_storeu_pd(c_re + i, or_);
        _mm256_storeu_pd(c_im + i, oi);
    }
    for (; i < N; ++i)
        cmul(a_re[i], a_im[i], b_re[i], b_im[i], c_re[i], c_im[i]);
#else
    for (size_t i = 0; i < N; ++i)
        cmul(a_re[i], a_im[i], b_re[i], b_im[i], c_re[i], c_im[i]);
#endif
}

inline void pointwise_sqr(const double* a_re, const double* a_im,
                          double* c_re, double* c_im, size_t N) {
#ifdef __AVX2__
    size_t i = 0;
    for (; i + 4 <= N; i += 4) {
        __m256d ar = _mm256_loadu_pd(a_re + i);
        __m256d ai = _mm256_loadu_pd(a_im + i);
        __m256d or_, oi;
        cmul_avx2(ar, ai, ar, ai, or_, oi);
        _mm256_storeu_pd(c_re + i, or_);
        _mm256_storeu_pd(c_im + i, oi);
    }
    for (; i < N; ++i)
        cmul(a_re[i], a_im[i], a_re[i], a_im[i], c_re[i], c_im[i]);
#else
    for (size_t i = 0; i < N; ++i)
        cmul(a_re[i], a_im[i], a_re[i], a_im[i], c_re[i], c_im[i]);
#endif
}

// ================================================================
// Negacyclic convolution (full pipeline)
// ================================================================
// c = negacyclic_conv(a, b) of length N.
// If a, b each have <= N/2 non-zero coefficients, the result equals
// the linear convolution (no wrap-around).
inline void negacyclic_convolution(
    const double* a_re, const double* a_im,
    const double* b_re, const double* b_im,
    double* c_re, double* c_im,
    size_t N, const FftRootsDouble& roots) {
    assert(N == roots.N);
    // scratch for weighted + transformed copies
    std::vector<double> wa_re(N), wa_im(N), wb_re(N), wb_im(N);
    for (size_t i = 0; i < N; ++i) {
        wa_re[i] = a_re[i]; wa_im[i] = a_im[i];
        wb_re[i] = b_re[i]; wb_im[i] = b_im[i];
    }
    // 1. weight
    apply_weight(wa_re.data(), wa_im.data(), N, roots);
    apply_weight(wb_re.data(), wb_im.data(), N, roots);
    // 2. forward FFT (scalar path)
    forward_fft_double(wa_re.data(), wa_im.data(), N, roots);
    forward_fft_double(wb_re.data(), wb_im.data(), N, roots);
    // 3. pointwise multiply
    pointwise_mul(wa_re.data(), wa_im.data(), wb_re.data(), wb_im.data(),
                  c_re, c_im, N);
    // 4. inverse FFT (includes 1/N)
    inverse_fft_double(c_re, c_im, N, roots);
    // 5. unweight
    apply_unweight(c_re, c_im, N, roots);
}

// ================================================================
// YC-8b: coefficient splitting / reconstruction / multiplication
// ================================================================

// next power of 2 >= n (clamped to the minimum FFT size of 4)
inline size_t fft_next_pow2(size_t n) {
    size_t p = 1;
    while (p < n) p <<= 1;
    return (p >= 4) ? p : 4; // minimum FFT size = 4
}

// B-bit coefficient count for n limbs (ceil(n*64 / B))
inline size_t coeff_count(size_t n_limbs, int B) {
    return (n_limbs * 64 + static_cast<size_t>(B) - 1) / static_cast<size_t>(B);
}

// Choose optimal B (bits per coefficient) for given input sizes.
// Constraint: 2B + log2(N) + C_safety <= 53 (C_safety = 3)
// Returns 0 if double FFT cannot handle the size.
inline int choose_fft_bits(size_t an, size_t bn) {
    for (int B = 18; B >= 10; --B) {
        size_t ca = coeff_count(an, B);
        size_t cb = coeff_count(bn, B);
        size_t N = fft_next_pow2(ca + cb);
        int log2N = 0;
        for (size_t t = N; t > 1; t >>= 1) ++log2N;
        if (2 * B + log2N + 3 <= 53) return B;
    }
    return 0;
}

// Split 64-bit limb array into B-bit double coefficients (real only).
// out_re[0..N-1]: coefficients (zero-padded), out_im[0..N-1]: all zeros.
inline void limb_to_fft_coeffs(const uint64_t* limbs, size_t n, int B,
                               double* out_re, double* out_im, size_t N) {
    const uint64_t mask = (1ULL << B) - 1;
    const size_t total_bits = n * 64;
    size_t bit_pos = 0;
    for (size_t i = 0; i < N; ++i) {
        if (bit_pos >= total_bits) {
            out_re[i] = 0.0;
            out_im[i] = 0.0;
            continue;
        }
        const size_t word = bit_pos / 64;
        const int shift = static_cast<int>(bit_pos & 63);
        uint64_t val = limbs[word] >> shift;
        // B-bit window straddles a limb boundary: pull in the high part
        if (shift + B > 64 && word + 1 < n) {
            val |= limbs[word + 1] << (64 - shift);
        }
        out_re[i] = static_cast<double>(val & mask);
        out_im[i] = 0.0;
        bit_pos += static_cast<size_t>(B);
    }
}

// Reconstruct 64-bit limbs from rounded FFT coefficients.
// re[0..num_coeffs-1]: convolution output (will be rounded to int64_t).
// Carry propagation in base 2^B, then pack into out[0..out_n-1].
inline void fft_coeffs_to_limbs(const double* re, size_t num_coeffs, int B,
                                uint64_t* out, size_t out_n) {
    std::memset(out, 0, out_n * sizeof(uint64_t));
    const int64_t base = static_cast<int64_t>(1ULL << B);
    const uint64_t bmask = static_cast<uint64_t>(base - 1);
    int64_t carry = 0;
    size_t bit_pos = 0;
    for (size_t i = 0; i < num_coeffs; ++i) {
        int64_t val = static_cast<int64_t>(std::round(re[i])) + carry;
        uint64_t digit;
        if (val >= 0) {
            digit = static_cast<uint64_t>(val) & bmask;
            carry = val >> B;
        } else {
            // defensive: negacyclic wrap or rounding artifact
            int64_t m = val % base;
            if (m < 0) m += base;
            digit = static_cast<uint64_t>(m);
            carry = (val - static_cast<int64_t>(digit)) / base;
        }
        const size_t word = bit_pos / 64;
        const int shift = static_cast<int>(bit_pos & 63);
        if (word < out_n) {
            out[word] |= digit << shift;
            if (shift + B > 64 && word + 1 < out_n)
                out[word + 1] |= digit >> (64 - shift);
        }
        bit_pos += static_cast<size_t>(B);
    }
    // flush remaining carry
    while (carry > 0 && bit_pos / 64 < out_n) {
        uint64_t digit = static_cast<uint64_t>(carry) & bmask;
        carry >>= B;
        const size_t word = bit_pos / 64;
        const int shift = static_cast<int>(bit_pos & 63);
        if (word < out_n) {
            out[word] |= digit << shift;
            if (shift + B > 64 && word + 1 < out_n)
                out[word + 1] |= digit >> (64 - shift);
        }
        bit_pos += static_cast<size_t>(B);
    }
}

// ================================================================
// YC-8d: error detection (mod 2^61-1 hash + coefficient bound check)
// ================================================================

// Mersenne prime P = 2^61 - 1
static constexpr uint64_t MERSENNE61 = (1ULL << 61) - 1;

// Safety margin: 2B + log2(N) + C_SAFETY <= 53
static constexpr int C_SAFETY_DEFAULT = 3;    // normal mode
static constexpr int C_SAFETY_AGGRESSIVE = 2; // aggressive (faster, higher error risk)

// Rounding-error threshold: each post-IFFT coefficient must be this
// close to an integer.
static constexpr double COEFF_ERROR_THRESHOLD = 0.25;

// Compute number mod (2^61-1) from a limb array.
// A = sum(limbs[i] * 2^(64*i)), and 2^64 == 8 (mod P), so Horner from the top.
inline uint64_t mod_mersenne61(const uint64_t* limbs, size_t n) {
    constexpr uint64_t P = MERSENNE61;
    uint64_t h = 0;
    for (size_t i = n; i-- > 0; ) {
        // h = (h * 8 + limbs[i]) mod P
        uint64_t h8 = h << 3; // h < P < 2^61 => h8 < 2^64
        h8 = (h8 >> 61) + (h8 & P);
        if (h8 >= P) h8 -= P;
        uint64_t li = (limbs[i] >> 61) + (limbs[i] & P);
        if (li >= P) li -= P;
        h = h8 + li;
        if (h >= P) h -= P;
    }
    return h;
}

// (a * b) mod (2^61-1), a,b < P.
// Portable: _umul128 is MSVC-only, use unsigned __int128 elsewhere.
inline uint64_t mulmod_mersenne61(uint64_t a, uint64_t b) {
    constexpr uint64_t P = MERSENNE61;
#ifdef _MSC_VER
    uint64_t hi;
    uint64_t lo = _umul128(a, b, &hi);
#else
    const unsigned __int128 prod =
        static_cast<unsigned __int128>(a) * static_cast<unsigned __int128>(b);
    const uint64_t lo = static_cast<uint64_t>(prod);
    const uint64_t hi = static_cast<uint64_t>(prod >> 64);
#endif
    // product = hi * 2^64 + lo, 2^64 == 8 mod P
    // hi < 2^58 (since a,b < 2^61), so hi*8 < 2^61
    uint64_t h8 = hi << 3;
    uint64_t lo_mod = (lo >> 61) + (lo & P);
    if (lo_mod >= P) lo_mod -= P;
    uint64_t r = h8 + lo_mod;
    if (r >= P) r -= P;
    return r;
}

// Validate post-IFFT coefficients: rounding error within threshold?
// im[] stays near 0 for a plain linear convolution (no weighting);
// with the negacyclic weighting the im side is used after unweight,
// but only c_re is checked here.
inline bool check_coeff_error(const double* c_re, size_t N) {
    for (size_t i = 0; i < N; ++i) {
        double err = std::abs(c_re[i] - std::round(c_re[i]));
        if (err > COEFF_ERROR_THRESHOLD) return false;
    }
    return true;
}

// choose_fft_bits with a configurable safety margin.
// Search B over a wider range: larger B => fewer coefficients
// => smaller N => faster.
inline int choose_fft_bits_ex(size_t an, size_t bn, int c_safety) {
    for (int B = 24; B >= 10; --B) {
        size_t ca = coeff_count(an, B);
        size_t cb = coeff_count(bn, B);
        size_t N = fft_next_pow2(ca + cb);
        int log2N = 0;
        for (size_t t = N; t > 1; t >>= 1) ++log2N;
        if (2 * B + log2N + c_safety <= 53) return B;
    }
    return 0;
}

// ================================================================
// mul_double_fft: main entry point
// ================================================================
// rp[0..an+bn-1] = ap[0..an-1] * bp[0..bn-1]
// Returns true on success, false if an error is detected
// (the caller should fall back to NTT).
// Interface matches mul_prime_ntt (except for the return value).

// Internal implementation of mul_double_fft (computes with the given B, N).
inline bool mul_double_fft_core(uint64_t* rp,
                                const uint64_t* ap, size_t an,
                                const uint64_t* bp, size_t bn,
                                int B, size_t N, uint64_t hash_expect) {
    const size_t rn = an + bn;
    // thread_local caches
    thread_local FftRootsDouble cached_roots;
    if (cached_roots.N != N) cached_roots.build(N);
    thread_local std::vector<double> buf;
    if (buf.size() < 6 * N) buf.resize(6 * N);
    double* a_re = buf.data();
    double* a_im = a_re + N;
    double* b_re = a_im + N;
    double* b_im = b_re + N;
    double* c_re = b_im + N;
    double* c_im = c_re + N;
    // 1. split
    limb_to_fft_coeffs(ap, an, B, a_re, a_im, N);
    limb_to_fft_coeffs(bp, bn, B, b_re, b_im, N);
    // 2. forward FFT
#ifdef __AVX2__
    forward_fft_double_avx2(a_re, a_im, N, cached_roots);
    forward_fft_double_avx2(b_re, b_im, N, cached_roots);
#else
    forward_fft_double(a_re, a_im, N, cached_roots);
    forward_fft_double(b_re, b_im, N, cached_roots);
#endif
    // 3. pointwise multiply
    pointwise_mul(a_re, a_im, b_re, b_im, c_re, c_im, N);
    // 4. inverse FFT (includes 1/N scaling)
#ifdef __AVX2__
    inverse_fft_double_avx2(c_re, c_im, N, cached_roots);
#else
    inverse_fft_double(c_re, c_im, N, cached_roots);
#endif
    // YC-8d: coefficient bound check (rounding error within threshold?)
    if (!check_coeff_error(c_re, N)) return false;
    // 5. reconstruct limbs
    fft_coeffs_to_limbs(c_re, N, B, rp, rn);
    // YC-8d: output hash verification
    const uint64_t hash_r = mod_mersenne61(rp, rn);
    if (hash_r != hash_expect) return false;
    return true;
}

inline bool mul_double_fft(uint64_t* rp,
                           const uint64_t* ap, size_t an,
                           const uint64_t* bp, size_t bn) {
    if (an < bn) { std::swap(ap, bp); std::swap(an, bn); }
    // YC-8d: input hashes (for the fallback decision)
    const uint64_t hash_a = mod_mersenne61(ap, an);
    const uint64_t hash_b = mod_mersenne61(bp, bn);
    const uint64_t hash_expect = mulmod_mersenne61(hash_a, hash_b);
    // YC-8e: try aggressive mode first (C_SAFETY=2, wider B)
    {
        const int B_agg = choose_fft_bits_ex(an, bn, C_SAFETY_AGGRESSIVE);
        if (B_agg > 0) {
            const size_t ca = coeff_count(an, B_agg);
            const size_t cb = coeff_count(bn, B_agg);
            const size_t N = fft_next_pow2(ca + cb);
            if (mul_double_fft_core(rp, ap, an, bp, bn, B_agg, N, hash_expect))
                return true;
            // aggressive failed -> retry in conservative mode
        }
    }
    // conservative mode (C_SAFETY=3)
    const int B = choose_fft_bits_ex(an, bn, C_SAFETY_DEFAULT);
    if (B == 0) return false; // size not handleable by double FFT
    const size_t ca = coeff_count(an, B);
    const size_t cb = coeff_count(bn, B);
    const size_t N = fft_next_pow2(ca + cb);
    return mul_double_fft_core(rp, ap, an, bp, bn, B, N, hash_expect);
}

// ================================================================
// sqr_double_fft: squaring specialization (single FFT)
// ================================================================

// Internal implementation of sqr_double_fft.
inline bool sqr_double_fft_core(uint64_t* rp, const uint64_t* ap, size_t an,
                                int B, size_t N, uint64_t hash_expect) {
    thread_local FftRootsDouble cached_roots;
    if (cached_roots.N != N) cached_roots.build(N);
    thread_local std::vector<double> buf;
    if (buf.size() < 4 * N) buf.resize(4 * N);
    double* a_re = buf.data();
    double* a_im = a_re + N;
    double* c_re = a_im + N;
    double* c_im = c_re + N;
    limb_to_fft_coeffs(ap, an, B, a_re, a_im, N);
#ifdef __AVX2__
    forward_fft_double_avx2(a_re, a_im, N, cached_roots);
    pointwise_sqr(a_re, a_im, c_re, c_im, N);
    inverse_fft_double_avx2(c_re, c_im, N, cached_roots);
#else
    forward_fft_double(a_re, a_im, N, cached_roots);
    pointwise_sqr(a_re, a_im, c_re, c_im, N);
    inverse_fft_double(c_re, c_im, N, cached_roots);
#endif
    // YC-8d: coefficient bound check
    if (!check_coeff_error(c_re, N)) return false;
    fft_coeffs_to_limbs(c_re, N, B, rp, 2 * an);
    // YC-8d: output hash verification
    const uint64_t hash_r = mod_mersenne61(rp, 2 * an);
    if (hash_r != hash_expect) return false;
    return true;
}

inline bool sqr_double_fft(uint64_t* rp, const uint64_t* ap, size_t an) {
    // YC-8d: input hash
    const uint64_t hash_a = mod_mersenne61(ap, an);
    const uint64_t hash_expect = mulmod_mersenne61(hash_a, hash_a);
    // YC-8e: try aggressive mode first (C_SAFETY=2)
    {
        const int B_agg = choose_fft_bits_ex(an, an, C_SAFETY_AGGRESSIVE);
        if (B_agg > 0) {
            const size_t ca = coeff_count(an, B_agg);
            const size_t N = fft_next_pow2(2 * ca);
            if (sqr_double_fft_core(rp, ap, an, B_agg, N, hash_expect))
                return true;
        }
    }
    // conservative mode (C_SAFETY=3)
    const int B = choose_fft_bits_ex(an, an, C_SAFETY_DEFAULT);
    if (B == 0) return false;
    const size_t ca = coeff_count(an, B);
    const size_t N = fft_next_pow2(2 * ca);
    return sqr_double_fft_core(rp, ap, an, B, N, hash_expect);
}

} // namespace double_fft
} // namespace calx