// Copyright (C) 2026 Kiyotsugu Arai // SPDX-License-Identifier: LGPL-3.0-or-later // PrimeNtt.hpp // 素数モジュラー NTT + CRT による高速乗算 // 3 つの 64bit NTT-friendly 素数で NTT を実行し、 // 中国剰余定理で結果を復元する。 // Schönhage-Strassen (mod B^F+1) の代替で、 // pointwise が O(1) (64bit mod 乗算) になる。 #pragma once #ifndef SANGI_FORCEINLINE #ifdef _MSC_VER #define SANGI_FORCEINLINE __forceinline #else #define SANGI_FORCEINLINE __attribute__((always_inline)) inline #endif #endif #include #include #include #include #include #include #ifdef _MSC_VER #include #endif #include // AVX2 intrinsics はグローバルスコープで include する必要がある // (namespace 内で include すると GCC でエラーになる) #if defined(__AVX2__) #include #endif namespace sangi { namespace prime_ntt { // キャッシュブロッキングの閾値 (L1 32KB に収まるブロックサイズ) // 512 要素 = 4KB data + ~4KB roots = 8KB, L1 に十分な余裕 static constexpr size_t NTT_CACHE_BLOCK = 512; // ================================================================ // NTT-friendly 素数の定義 // ================================================================ // 条件: p = k × 2^s + 1 (s が大きいほど長い NTT が可能) // p < 2^63 (符号なし加算で 2p < 2^64) // FLINT/GMP で実績のある素数を採用 // p1 = 2^62 - 2^16 + 1 = 4611686018427387905 // = 1 × 2^62 + ... (素数性は別途検証) // → FLINT の n_mulmod_preinv 系で使われる素数を参考 // 実績ある NTT-friendly 素数 (FLINT/NTL 等で使用): // p = k * 2^s + 1 形式で s >= 47 // 検証済み NTT-friendly 素数 (p = k * 2^47 + 1) // NTT 最大長: 2^47 (= 140兆点、実用上無制限) // p1*p2*p3 ≈ 2^163 → n < 2^29 (≈500M limbs ≈ 10G 桁) まで CRT 正確復元 // 全素数で 15|(p-1) → 混合基数 NTT (N = {3,5}×2^k) をサポート // BS テーブルの検証済み素数から選択 constexpr uint64_t P1 = 0x000F'0000'0000'0001ULL; // 30 * 2^47 + 1, g=19 (30=2×3×5) constexpr uint64_t P2 = 0x00AC'8000'0000'0001ULL; // 345 * 2^47 + 1, g=13 (345=3×5×23) constexpr uint64_t P3 = 0x00D2'0000'0000'0001ULL; // 420 * 2^47 + 1, g=17 (420=4×3×5×7) constexpr uint64_t P4 = 0x0032'8000'0000'0001ULL; // 101 * 2^47 + 1, g=3 constexpr uint64_t P5 = 0x0037'8000'0000'0001ULL; // 111 * 2^47 + 1, g=11 constexpr uint64_t G1 = 19; // P1 の原始根 constexpr uint64_t G2 = 13; // P2 の原始根 constexpr uint64_t G3 = 17; // P3 の原始根 constexpr uint64_t G4 = 3; // P4 の原始根 constexpr uint64_t G5 = 11; // P5 の原始根 constexpr int NTT_MAX_S = 47; // 最大 NTT 長 = 2^47 // NTT ドメイン BS 用の素数数 // 64 素数 (M ≈ 2^3582) + 16-bit ワード (B=2^16) → NTT ドメインで深さ 6 まで安全 // 深さ d の係数上界: N^(2^d-1) · B^(2^d) < M // d=6, N=2^22, B=2^16: 22·63 + 16·64 = 2410 < 3582 ✓ // d=7: 22·127 + 16·128 = 4842 > 3582 ✗ constexpr int BS_NUM_PRIMES = 64; // BS 用素数テーブル: p = k * 2^47 + 1 形式、原始根付き struct BsPrimeEntry { uint64_t p; int g; }; static constexpr BsPrimeEntry BS_PRIME_TABLE[BS_NUM_PRIMES] = { { 0x000D800000000001ULL, 5 }, // k=27 { 0x000F000000000001ULL, 19 }, // k=30 { 0x0011800000000001ULL, 3 }, // k=35 { 0x001C000000000001ULL, 6 }, // k=56 { 0x0020800000000001ULL, 3 }, // k=65 { 0x002E000000000001ULL, 3 }, // k=92 { 0x0032800000000001ULL, 3 }, // k=101 { 0x0037800000000001ULL, 11 }, // k=111 { 0x0039000000000001ULL, 5 }, // k=114 { 0x003D000000000001ULL, 3 }, // k=122 { 0x0046000000000001ULL, 3 }, // k=140 { 0x0051000000000001ULL, 5 }, // k=162 { 0x006C000000000001ULL, 11 }, // k=216 { 0x0070000000000001ULL, 3 }, // k=224 { 0x0088000000000001ULL, 3 }, // k=272 { 0x00AC800000000001ULL, 13 }, // k=345 { 0x00B1000000000001ULL, 5 }, // k=354 { 0x00B3800000000001ULL, 3 }, // k=359 { 0x00BB000000000001ULL, 3 }, // k=374 { 0x00D2000000000001ULL, 17 }, // k=420 { 0x00D7800000000001ULL, 3 }, // k=431 { 0x00DC800000000001ULL, 13 }, // k=441 { 0x00EB800000000001ULL, 17 }, // k=471 { 0x00FD000000000001ULL, 3 }, // k=506 { 0x0104800000000001ULL, 3 }, // k=521 { 0x0107800000000001ULL, 3 }, // k=527 { 0x0114000000000001ULL, 5 }, // k=552 { 0x0149800000000001ULL, 3 }, // k=659 { 0x016F000000000001ULL, 3 }, // k=734 { 0x0170800000000001ULL, 3 }, // k=737 { 0x0176800000000001ULL, 3 }, // k=749 { 0x0179800000000001ULL, 6 }, // k=755 { 0x0190000000000001ULL, 3 }, // k=800 { 0x0194800000000001ULL, 3 }, // k=809 { 0x0197800000000001ULL, 3 }, // k=815 { 0x01A2800000000001ULL, 7 }, // k=837 { 0x01A7000000000001ULL, 26 }, // k=846 { 0x01AC800000000001ULL, 3 }, // k=857 { 0x01B5800000000001ULL, 15 }, // k=875 { 0x01C5000000000001ULL, 14 }, // k=906 { 0x01E6000000000001ULL, 5 }, // k=972 { 0x01ED800000000001ULL, 5 }, // k=987 { 0x01EE800000000001ULL, 3 }, // k=989 { 0x01F7800000000001ULL, 3 }, // k=1007 { 0x020C800000000001ULL, 3 }, // k=1049 { 0x021E800000000001ULL, 3 }, // k=1085 { 0x022B000000000001ULL, 14 }, // k=1110 { 0x0235000000000001ULL, 3 }, // k=1130 { 0x0247000000000001ULL, 3 }, // k=1166 { 0x024D800000000001ULL, 5 }, // k=1179 { 0x0258000000000001ULL, 11 }, // k=1200 { 0x025D800000000001ULL, 3 }, // k=1211 { 0x0266800000000001ULL, 3 }, // k=1229 { 0x026D000000000001ULL, 5 }, // k=1242 { 0x027B800000000001ULL, 3 }, // k=1271 { 0x0280000000000001ULL, 6 }, // k=1280 { 0x0289800000000001ULL, 5 }, // k=1299 { 0x0293800000000001ULL, 3 }, // k=1319 { 0x0297000000000001ULL, 7 }, // k=1326 { 0x02A5800000000001ULL, 3 }, // k=1355 { 0x02A9000000000001ULL, 5 }, // k=1362 { 0x02BF000000000001ULL, 3 }, // k=1406 { 0x02C6800000000001ULL, 3 }, // k=1421 { 0x02CD000000000001ULL, 5 }, // k=1434 }; // ================================================================ // mod p 算術プリミティブ // ================================================================ // (a + b) mod p — a, b < p を前提 inline uint64_t mod_add(uint64_t a, uint64_t b, uint64_t p) { uint64_t sum = a + b; // sum >= p (オーバーフロー含む) なら p を引く // a, b < p < 2^63 なので sum < 2^64 (オーバーフローなし) return (sum >= p) ? (sum - p) : sum; } // (a - b) mod p — a, b < p を前提 inline uint64_t mod_sub(uint64_t a, uint64_t b, uint64_t p) { // a < b なら a - b + p (アンダーフロー補正) return (a >= b) ? (a - b) : (a - b + p); } // (a * b) mod p — a, b < p を前提 // 128bit 積 → Barrett reduction inline uint64_t mod_mul(uint64_t a, uint64_t b, uint64_t p) { #if defined(_MSC_VER) && defined(_M_X64) uint64_t hi; uint64_t lo = _umul128(a, b, &hi); uint64_t rem; _udiv128(hi, lo, p, &rem); return rem; #elif defined(__GNUC__) || defined(__clang__) __uint128_t prod = static_cast<__uint128_t>(a) * b; return static_cast(prod % p); #else // フォールバック: 遅い __uint128_t prod = static_cast<__uint128_t>(a) * b; return static_cast(prod % p); #endif } // base^exp mod p (高速冪剰余) inline uint64_t mod_pow(uint64_t base, uint64_t exp, uint64_t p) { uint64_t result = 1; base %= p; while (exp > 0) { if (exp & 1) result = mod_mul(result, base, p); base = mod_mul(base, base, p); exp >>= 1; } return result; } // a^(-1) mod p (Fermat の小定理: a^(p-2) mod p) inline uint64_t mod_inv(uint64_t a, uint64_t p) { return mod_pow(a, p - 2, p); } // ================================================================ // 素数性検証と原始根探索 (初期化時に使用) // ================================================================ // Miller-Rabin 素数判定 (確定的、64bit) inline bool is_prime_64(uint64_t n) { if (n < 2) return false; if (n < 4) return true; if (n % 2 == 0) return false; // n-1 = d * 2^r uint64_t d = n - 1; int r = 0; while ((d & 1) == 0) { d >>= 1; r++; } // 64bit で確定的な witness セット const uint64_t witnesses[] = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}; for (uint64_t a : witnesses) { if (a >= n) continue; uint64_t x = mod_pow(a, d, n); if (x == 1 || x == n - 1) continue; bool composite = true; for (int i = 0; i < r - 1; i++) { x = mod_mul(x, x, n); if (x == n - 1) { composite = false; break; } } if (composite) return false; } return true; } // k * 2^s + 1 形式の素数を探索 // min_s: 最小の s (NTT 長 2^s を保証) // start_k: 探索開始の k (奇数) inline uint64_t find_ntt_prime(int min_s, uint64_t start_k = 1) { for (uint64_t k = start_k; k < (1ULL << (63 - min_s)); k += 2) { uint64_t p = k * (1ULL << min_s) + 1; if (p >= (1ULL << 63)) break; // 2p < 2^64 を保証 if (is_prime_64(p)) return p; } return 0; // 見つからなかった } // 原始根を探索 (p-1 の素因数分解が必要) // p-1 = 2^s * k の場合、p-1 の素因数は 2 と k の因数 inline uint64_t find_primitive_root(uint64_t p) { // p-1 の素因数を列挙 uint64_t pm1 = p - 1; std::vector factors; // 2 は必ず因数 factors.push_back(2); // 奇数部分 k の因数分解 uint64_t k = pm1; while ((k & 1) == 0) k >>= 1; uint64_t temp = k; for (uint64_t f = 3; f * f <= temp; f += 2) { if (temp % f == 0) { factors.push_back(f); while (temp % f == 0) temp /= f; } } if (temp > 1) factors.push_back(temp); // g^((p-1)/q) != 1 (mod p) for all prime factors q of p-1 for (uint64_t g = 2; g < p; ++g) { bool is_root = true; for (uint64_t q : factors) { if (mod_pow(g, pm1 / q, p) == 1) { is_root = false; break; } } if (is_root) return g; } return 0; } // ================================================================ // Montgomery 乗算 (NTT 高速化用) // ================================================================ // R = 2^64, p' = -p^(-1) mod R // REDC(T) = (T + (T_lo * p') * p) >> 64 // 2 × MULX + ADD で完了 (_udiv128 の 35-45 cycles → ~8 cycles) // p^(-1) mod 2^64 を Newton 法で計算 (p は奇数) inline uint64_t compute_p_inv_neg(uint64_t p) { // p * x ≡ 1 (mod 2^64) を Newton 法で求め、符号反転 uint64_t x = 1; for (int i = 0; i < 6; i++) // 6 回で 64bit 収束 x *= 2 - p * x; return ~x + 1; // -p^(-1) mod 2^64 } // Montgomery 乗算: REDC(a * b) = a*b*R^(-1) mod p // a, b は Montgomery 形式 (a*R mod p, b*R mod p) でも通常値でも可 inline uint64_t mont_mul(uint64_t a, uint64_t b, uint64_t p, uint64_t p_inv_neg) { #if defined(_MSC_VER) && defined(_M_X64) uint64_t T_hi; uint64_t T_lo = _umul128(a, b, &T_hi); // m = T_lo * (-p^(-1)) mod 2^64 uint64_t m = T_lo * p_inv_neg; // m * p uint64_t mp_hi; uint64_t mp_lo = _umul128(m, p, &mp_hi); // T + m*p の上位 64bit (下位は構造的に 0) uint64_t carry = (T_lo + mp_lo < T_lo) ? 1 : 0; uint64_t t = T_hi + mp_hi + carry; return (t >= p) ? (t - p) : t; #elif defined(__GNUC__) || defined(__clang__) unsigned __int128 T = (unsigned __int128)a * b; uint64_t T_lo = (uint64_t)T; uint64_t m = T_lo * p_inv_neg; unsigned __int128 mp = (unsigned __int128)m * p; unsigned __int128 sum = T + mp; uint64_t t = (uint64_t)(sum >> 64); return (t >= p) ? (t - p) : t; #endif } // Montgomery 形式への変換: to_mont(a) = a * R mod p = REDC(a * R^2) inline uint64_t to_mont(uint64_t a, uint64_t p, uint64_t p_inv_neg, uint64_t r2_mod_p) { return mont_mul(a, r2_mod_p, p, p_inv_neg); } // Montgomery 形式からの逆変換: from_mont(a_mont) = REDC(a_mont * 1) inline uint64_t from_mont(uint64_t a_mont, uint64_t p, uint64_t p_inv_neg) { return mont_mul(a_mont, 1, p, p_inv_neg); } // ================================================================ // NTT 素数パラメータ構造体 // ================================================================ struct NttPrime { uint64_t p; // 素数 uint64_t g; // 原始根 int s; // p-1 = k * 2^s + ... の s (最大 NTT 長 = 2^s) uint64_t inv_2; // 2^(-1) mod p // Montgomery 定数 uint64_t p_inv_neg; // -p^(-1) mod 2^64 uint64_t r2_mod_p; // R^2 mod p (R = 2^64) // N 点 NTT の原始 N 乗根: omega = g^((p-1)/N) mod p uint64_t nth_root(size_t N) const { return mod_pow(g, (p - 1) / N, p); } // N^(-1) mod p uint64_t inv_n(size_t N) const { return mod_inv(N, p); } }; // ================================================================ // 根テーブル (twiddle factors) のキャッシュ // ================================================================ struct NttRoots { std::vector roots; // forward: omega^i std::vector inv_roots; // inverse: omega^(-i) uint64_t inv_n; // N^(-1) mod p size_t N; // NTT 長 uint64_t p; // 素数 void build(const NttPrime& prime, size_t ntt_len) { N = ntt_len; p = prime.p; inv_n = prime.inv_n(N); uint64_t omega = prime.nth_root(N); uint64_t omega_inv = mod_inv(omega, p); roots.resize(N / 2); inv_roots.resize(N / 2); roots[0] = 1; inv_roots[0] = 1; for (size_t i = 1; i < N / 2; ++i) { roots[i] = mod_mul(roots[i - 1], omega, p); inv_roots[i] = mod_mul(inv_roots[i - 1], omega_inv, p); } } }; // ================================================================ // NTT forward (DIF: Decimation In Frequency) // ================================================================ inline void forward_ntt(uint64_t* data, size_t N, uint64_t p, const uint64_t* roots) { for (size_t len = N; len >= 2; len >>= 1) { size_t half = len >> 1; size_t step = N / len; // root index step for (size_t start = 0; start < N; start += len) { for (size_t j = 0; j < half; ++j) { uint64_t u = data[start + j]; uint64_t v = data[start + j + half]; // DIF butterfly: (u, v) → (u+v, (u-v)*ω) data[start + j] = mod_add(u, v, p); uint64_t diff = mod_sub(u, v, p); data[start + j + half] = mod_mul(diff, roots[j * step], p); } } } } // ================================================================ // NTT inverse (DIT: Decimation In Time) // ================================================================ inline void inverse_ntt(uint64_t* data, size_t N, uint64_t p, const uint64_t* inv_roots, uint64_t inv_n_val) { for (size_t len = 2; len <= N; len <<= 1) { size_t half = len >> 1; size_t step = N / len; for (size_t start = 0; start < N; start += len) { for (size_t j = 0; j < half; ++j) { uint64_t u = data[start + j]; uint64_t v = mod_mul(data[start + j + half], inv_roots[j * step], p); // DIT butterfly: (u, v·ω⁻¹) → (u+v, u-v) data[start + j] = mod_add(u, v, p); data[start + j + half] = mod_sub(u, v, p); } } } // N^(-1) で正規化 for (size_t i = 0; i < N; ++i) { data[i] = mod_mul(data[i], inv_n_val, p); } } // ================================================================ // Pointwise 演算 // ================================================================ inline void pointwise_mul(uint64_t* a, const uint64_t* b, size_t N, uint64_t p) { for (size_t i = 0; i < N; ++i) { a[i] = mod_mul(a[i], b[i], p); } } inline void pointwise_sqr(uint64_t* a, size_t N, uint64_t p) { for (size_t i = 0; i < N; ++i) { a[i] = mod_mul(a[i], a[i], p); } } // ================================================================ // Montgomery 形式の根テーブルと NTT // ================================================================ struct NttRootsMont { std::vector roots; // forward: omega^i (Montgomery 形式) std::vector inv_roots; // inverse: omega^(-i) (Montgomery 形式) uint64_t inv_n; // N^(-1) mod p (通常形式 — from_mont に兼用) size_t N; uint64_t p; uint64_t p_inv_neg; void build(const NttPrime& prime, size_t ntt_len) { N = ntt_len; p = prime.p; p_inv_neg = prime.p_inv_neg; inv_n = prime.inv_n(N); // 通常形式 (逆変換の最終スケーリング用) uint64_t omega = prime.nth_root(N); uint64_t omega_inv = mod_inv(omega, p); // Montgomery 形式に変換 uint64_t omega_mont = to_mont(omega, p, p_inv_neg, prime.r2_mod_p); uint64_t omega_inv_mont = to_mont(omega_inv, p, p_inv_neg, prime.r2_mod_p); roots.resize(N / 2); inv_roots.resize(N / 2); // roots[0] = 1 の Montgomery 形式 = R mod p uint64_t one_mont = to_mont(1, p, p_inv_neg, prime.r2_mod_p); roots[0] = one_mont; inv_roots[0] = one_mont; for (size_t i = 1; i < N / 2; ++i) { roots[i] = mont_mul(roots[i - 1], omega_mont, p, p_inv_neg); inv_roots[i] = mont_mul(inv_roots[i - 1], omega_inv_mont, p, p_inv_neg); } } }; // Montgomery 形式 NTT forward (DIF) // data は Montgomery 形式、roots も Montgomery 形式 inline void forward_ntt_mont(uint64_t* data, size_t N, uint64_t p, uint64_t p_inv_neg, const uint64_t* roots) { for (size_t len = N; len >= 2; len >>= 1) { size_t half = len >> 1; size_t step = N / len; for (size_t start = 0; start < N; start += len) { for (size_t j = 0; j < half; ++j) { uint64_t u = data[start + j]; uint64_t v = data[start + j + half]; data[start + j] = mod_add(u, v, p); uint64_t diff = mod_sub(u, v, p); data[start + j + half] = mont_mul(diff, roots[j * step], p, p_inv_neg); } } } } // Montgomery 形式 NTT inverse (DIT) // 最後に from_mont + inv_n スケーリングを同時に行う // inv_n_normal: N^(-1) mod p (通常形式) inline void inverse_ntt_mont(uint64_t* data, size_t N, uint64_t p, uint64_t p_inv_neg, const uint64_t* inv_roots, uint64_t inv_n_normal) { for (size_t len = 2; len <= N; len <<= 1) { size_t half = len >> 1; size_t step = N / len; for (size_t start = 0; start < N; start += len) { for (size_t j = 0; j < half; ++j) { uint64_t u = data[start + j]; uint64_t v = mont_mul(data[start + j + half], inv_roots[j * step], p, p_inv_neg); data[start + j] = mod_add(u, v, p); data[start + j + half] = mod_sub(u, v, p); } } } // N^(-1) スケーリング + Montgomery → 通常形式の逆変換を同時実行 // mont_mul(data[i]_mont, inv_n_normal) = data_i * inv_n mod p for (size_t i = 0; i < N; ++i) { data[i] = mont_mul(data[i], inv_n_normal, p, p_inv_neg); } } // Montgomery 形式 pointwise 乗算 inline void pointwise_mul_mont(uint64_t* a, const uint64_t* b, size_t N, uint64_t p, uint64_t p_inv_neg) { for (size_t i = 0; i < N; ++i) { a[i] = mont_mul(a[i], b[i], p, p_inv_neg); } } inline void pointwise_sqr_mont(uint64_t* a, size_t N, uint64_t p, uint64_t p_inv_neg) { for (size_t i = 0; i < N; ++i) { a[i] = mont_mul(a[i], a[i], p, p_inv_neg); } } // YC-2: pointwise multiply-accumulate: r[i] += a[i] * b[i] mod p (スカラー) inline void pointwise_mac_mont(uint64_t* r, const uint64_t* a, const uint64_t* b, size_t N, uint64_t p, uint64_t p_inv_neg) { for (size_t i = 0; i < N; ++i) { r[i] = mod_add(r[i], mont_mul(a[i], b[i], p, p_inv_neg), p); } } // ================================================================ // AVX2 高速 NTT (Montgomery 形式) // ================================================================ // _mm256_mul_epu32 による 4-wide Montgomery 乗算で NTT バタフライを高速化。 // データは Montgomery 形式のまま。レイヤーごとの根テーブルにより // ストライドアクセスを排除。 #if defined(__AVX2__) // --- AVX2 64-bit 乗算ヘルパー --- // 4-lane 64×64→128 bit 乗算 (lo と hi を同時取得) // 4 個の _mm256_mul_epu32 で 32-bit 部分積を計算し合成。 // 入力は任意の uint64_t (2^64 未満) で正しい結果を返す。 inline void avx2_mul_full_epu64(__m256i a, __m256i b, __m256i& lo, __m256i& hi) { const __m256i mask32 = _mm256_set1_epi64x(0xFFFFFFFF); __m256i a_hi = _mm256_srli_epi64(a, 32); __m256i b_hi = _mm256_srli_epi64(b, 32); __m256i ll = _mm256_mul_epu32(a, b); // a_lo * b_lo __m256i lh = _mm256_mul_epu32(a, b_hi); // a_lo * b_hi __m256i hl = _mm256_mul_epu32(a_hi, b); // a_hi * b_lo __m256i hh = _mm256_mul_epu32(a_hi, b_hi); // a_hi * b_hi // lo = ll + (lh + hl) << 32 [mod 2^64] __m256i cross = _mm256_add_epi64(lh, hl); lo = _mm256_add_epi64(ll, _mm256_slli_epi64(cross, 32)); // hi: 桁上がりを正確に伝搬 // mid = lh + (ll >> 32), carry の伝搬なし (lh < 2^64, ll>>32 < 2^32) __m256i ll_hi = _mm256_srli_epi64(ll, 32); __m256i mid = _mm256_add_epi64(lh, ll_hi); __m256i mid_lo = _mm256_and_si256(mid, mask32); __m256i mid_hi = _mm256_srli_epi64(mid, 32); // mid_lo + hl: mid_lo < 2^32, hl < 2^64 → sum < 2^64 + 2^32 < 2^64 (∵ hl ≤ (2^32-1)^2 = 2^64-2^33+1, mid_lo ≤ 2^32-1) __m256i t = _mm256_add_epi64(mid_lo, hl); __m256i t_hi = _mm256_srli_epi64(t, 32); hi = _mm256_add_epi64(hh, _mm256_add_epi64(mid_hi, t_hi)); } // low 64 bits of a*b (3 mul_epu32) inline __m256i avx2_mullo_epu64(__m256i a, __m256i b) { __m256i a_hi = _mm256_srli_epi64(a, 32); __m256i b_hi = _mm256_srli_epi64(b, 32); __m256i ll = _mm256_mul_epu32(a, b); __m256i lh = _mm256_mul_epu32(a, b_hi); __m256i hl = _mm256_mul_epu32(a_hi, b); __m256i cross = _mm256_add_epi64(lh, hl); return _mm256_add_epi64(ll, _mm256_slli_epi64(cross, 32)); } // high 64 bits of a*b (4 mul_epu32) inline __m256i avx2_mulhi_epu64(__m256i a, __m256i b) { const __m256i mask32 = _mm256_set1_epi64x(0xFFFFFFFF); __m256i a_hi = _mm256_srli_epi64(a, 32); __m256i b_hi = _mm256_srli_epi64(b, 32); __m256i ll = _mm256_mul_epu32(a, b); __m256i lh = _mm256_mul_epu32(a, b_hi); __m256i hl = _mm256_mul_epu32(a_hi, b); __m256i hh = _mm256_mul_epu32(a_hi, b_hi); __m256i ll_hi = _mm256_srli_epi64(ll, 32); __m256i mid = _mm256_add_epi64(lh, ll_hi); __m256i mid_lo = _mm256_and_si256(mid, mask32); __m256i mid_hi = _mm256_srli_epi64(mid, 32); __m256i t = _mm256_add_epi64(mid_lo, hl); __m256i t_hi = _mm256_srli_epi64(t, 32); return _mm256_add_epi64(hh, _mm256_add_epi64(mid_hi, t_hi)); } // 4-wide Montgomery 乗算: REDC(a * b) mod p // 入力: a, b は Montgomery 形式 (< p) // p_vec = broadcast(p), pin_vec = broadcast(-p^{-1} mod 2^64) // 利用する性質: T_lo + m*p_lo ≡ 0 (mod 2^64) → carry = (T_lo != 0) inline __m256i avx2_mont_mul(__m256i a, __m256i b, __m256i p_vec, __m256i pin_vec) { // Step 1: T = a * b (128-bit), lo と hi を同時取得 __m256i T_lo, T_hi; avx2_mul_full_epu64(a, b, T_lo, T_hi); // 4 VPMULUDQ // Step 2: m = T_lo * (-p^{-1}) mod 2^64 __m256i m = avx2_mullo_epu64(T_lo, pin_vec); // 3 VPMULUDQ // Step 3: mp_hi = mulhi(m, p) __m256i mp_hi = avx2_mulhi_epu64(m, p_vec); // 4 VPMULUDQ // Step 4: carry = (T_lo != 0) ? 1 : 0 __m256i eq_zero = _mm256_cmpeq_epi64(T_lo, _mm256_setzero_si256()); __m256i carry = _mm256_andnot_si256(eq_zero, _mm256_set1_epi64x(1)); // Step 5: t = T_hi + mp_hi + carry __m256i t = _mm256_add_epi64(_mm256_add_epi64(T_hi, mp_hi), carry); // Step 6: if t >= p: t -= p (符号ビットで判定) __m256i t_sub_p = _mm256_sub_epi64(t, p_vec); // t < p → t_sub_p は巨大 (bit 63 = 1, 符号付きで負) __m256i neg_mask = _mm256_cmpgt_epi64(_mm256_setzero_si256(), t_sub_p); return _mm256_blendv_epi8(t_sub_p, t, neg_mask); } // 合計: 11 VPMULUDQ + ~10 add/cmp/blend per 4 elements // --- レイヤーごとの根テーブル --- // 各 NTT レイヤーの根を連続配置し、AVX2 のシーケンシャルロードに対応。 struct NttRootsLayered { std::vector fwd; // forward roots (層ごと連続) std::vector inv; // inverse roots (層ごと連続) std::vector fwd_offset; // fwd 内の各層オフセット std::vector inv_offset; // inv 内の各層オフセット uint64_t inv_n; // N^{-1} mod p (通常形式) size_t N = 0; uint64_t p = 0; uint64_t p_inv_neg = 0; void build(const NttPrime& prime, size_t ntt_len) { N = ntt_len; p = prime.p; p_inv_neg = prime.p_inv_neg; inv_n = prime.inv_n(N); uint64_t omega = prime.nth_root(N); uint64_t omega_inv = mod_inv(omega, p); // Montgomery 形式に変換 uint64_t omega_mont = to_mont(omega, p, p_inv_neg, prime.r2_mod_p); uint64_t omega_inv_mont = to_mont(omega_inv, p, p_inv_neg, prime.r2_mod_p); // 通常の根テーブルを構築 std::vector all_roots(N / 2), all_inv(N / 2); uint64_t one_mont = to_mont(1, p, p_inv_neg, prime.r2_mod_p); all_roots[0] = one_mont; all_inv[0] = one_mont; for (size_t i = 1; i < N / 2; ++i) { all_roots[i] = mont_mul(all_roots[i - 1], omega_mont, p, p_inv_neg); all_inv[i] = mont_mul(all_inv[i - 1], omega_inv_mont, p, p_inv_neg); } // レイヤーごとに連続配置 // Forward DIF: layer l → len = N >> l, half = len/2, step = 1 << l // roots needed: all_roots[j * step] for j = 0..half-1 int num_layers = 0; for (size_t l = N; l >= 2; l >>= 1) ++num_layers; fwd.resize(N - 1); // 各層の合計: N/2 + N/4 + ... + 1 = N - 1 inv.resize(N - 1); fwd_offset.resize(num_layers); inv_offset.resize(num_layers); size_t off = 0; int layer = 0; // Forward (DIF): len = N, N/2, ..., 2 for (size_t len = N; len >= 2; len >>= 1) { size_t half = len >> 1; size_t step = N / len; fwd_offset[layer] = off; for (size_t j = 0; j < half; ++j) fwd[off + j] = all_roots[j * step]; off += half; ++layer; } off = 0; layer = 0; // Inverse (DIT): len = 2, 4, ..., N for (size_t len = 2; len <= N; len <<= 1) { size_t half = len >> 1; size_t step = N / len; inv_offset[layer] = off; for (size_t j = 0; j < half; ++j) inv[off + j] = all_inv[j * step]; off += half; ++layer; } } }; // ================================================================ // 混合基数 NTT (N = 3 × M, M = 2^k) サポート // ================================================================ // 全素数で 3|(p-1) → ω₃ (原始 3 乗根) が存在 // Radix-3 DFT を 2-mul 最適化: // c₊ = (ω₃+ω₃²)·inv2, c₋ = (ω₃-ω₃²)·inv2 // y₀ = x₀ + s, y₁ = x₀ + c₊·s + c₋·d, y₂ = x₀ + c₊·s - c₋·d // where s = x₁+x₂, d = x₁-x₂ struct MixedRadixConstants { uint64_t c_plus; // (ω₃+ω₃²)·inv2 mod p (Montgomery) uint64_t c_minus; // (ω₃-ω₃²)·inv2 mod p (Montgomery) uint64_t inv3; // 3⁻¹ mod p (Montgomery) std::vector twiddle1; // ω_N^j (j=0..M-1, Montgomery) std::vector twiddle2; // ω_N^{2j} (Montgomery) std::vector inv_twiddle1; // ω_N^{-j} (Montgomery) std::vector inv_twiddle2; // ω_N^{-2j} (Montgomery) size_t N = 0; uint64_t p = 0; void build(const NttPrime& prime, size_t ntt_len) { N = ntt_len; p = prime.p; size_t M = N / 3; uint64_t pinv = prime.p_inv_neg; uint64_t r2 = prime.r2_mod_p; // ω₃ = g^((p-1)/3) — primitive cube root of unity uint64_t w3 = mod_pow(prime.g, (prime.p - 1) / 3, prime.p); uint64_t w3sq = mod_mul(w3, w3, prime.p); // c₊, c₋ の計算 (通常形式で計算 → Montgomery 変換) uint64_t sum_w = mod_add(w3, w3sq, prime.p); uint64_t diff_w = mod_sub(w3, w3sq, prime.p); uint64_t cp = mod_mul(sum_w, prime.inv_2, prime.p); uint64_t cm = mod_mul(diff_w, prime.inv_2, prime.p); c_plus = to_mont(cp, p, pinv, r2); c_minus = to_mont(cm, p, pinv, r2); inv3 = to_mont(mod_inv(3, p), p, pinv, r2); // Twiddle factors: ω_N^j (Montgomery) uint64_t wN = mod_pow(prime.g, (prime.p - 1) / N, prime.p); uint64_t wN_inv = mod_inv(wN, prime.p); uint64_t wN_mont = to_mont(wN, p, pinv, r2); uint64_t wN_inv_mont = to_mont(wN_inv, p, pinv, r2); uint64_t wN2_mont = mont_mul(wN_mont, wN_mont, p, pinv); uint64_t wN_inv2_mont = mont_mul(wN_inv_mont, wN_inv_mont, p, pinv); uint64_t one_mont = to_mont(1, p, pinv, r2); twiddle1.resize(M); twiddle2.resize(M); inv_twiddle1.resize(M); inv_twiddle2.resize(M); twiddle1[0] = twiddle2[0] = inv_twiddle1[0] = inv_twiddle2[0] = one_mont; for (size_t j = 1; j < M; ++j) { twiddle1[j] = mont_mul(twiddle1[j-1], wN_mont, p, pinv); twiddle2[j] = mont_mul(twiddle2[j-1], wN2_mont, p, pinv); inv_twiddle1[j] = mont_mul(inv_twiddle1[j-1], wN_inv_mont, p, pinv); inv_twiddle2[j] = mont_mul(inv_twiddle2[j-1], wN_inv2_mont, p, pinv); } } }; // --- Radix-3 forward DFT (M 列に対して 3 点 DFT, 2 mont_mul/列) --- inline void radix3_dft_forward(uint64_t* data, size_t M, uint64_t c_plus, uint64_t c_minus, uint64_t p, uint64_t pinv) { for (size_t j = 0; j < M; ++j) { uint64_t x0 = data[j], x1 = data[j + M], x2 = data[j + 2*M]; uint64_t s = mod_add(x1, x2, p); uint64_t d = mod_sub(x1, x2, p); uint64_t ps = mont_mul(s, c_plus, p, pinv); uint64_t pd = mont_mul(d, c_minus, p, pinv); data[j] = mod_add(x0, s, p); data[j + M] = mod_add(x0, mod_add(ps, pd, p), p); data[j + 2*M] = mod_add(x0, mod_sub(ps, pd, p), p); } } // --- Radix-3 inverse DFT with 1/3 scaling --- // 逆 DFT: c₋ を符号反転 (ω₃ と ω₃² の入替) inline void radix3_dft_inverse_scaled(uint64_t* data, size_t M, uint64_t c_plus, uint64_t c_minus, uint64_t inv3, uint64_t p, uint64_t pinv) { uint64_t c_minus_neg = (c_minus == 0) ? 0 : (p - c_minus); // -c₋ mod p for (size_t j = 0; j < M; ++j) { uint64_t y0 = data[j], y1 = data[j + M], y2 = data[j + 2*M]; uint64_t s = mod_add(y1, y2, p); uint64_t d = mod_sub(y1, y2, p); uint64_t ps = mont_mul(s, c_plus, p, pinv); uint64_t pd = mont_mul(d, c_minus_neg, p, pinv); data[j] = mont_mul(mod_add(y0, s, p), inv3, p, pinv); data[j + M] = mont_mul(mod_add(y0, mod_add(ps, pd, p), p), inv3, p, pinv); data[j + 2*M] = mont_mul(mod_add(y0, mod_sub(ps, pd, p), p), inv3, p, pinv); } } // --- Twiddle factor 適用 --- inline void apply_twiddles(uint64_t* data, size_t M, const uint64_t* tw1, const uint64_t* tw2, uint64_t p, uint64_t pinv) { // Row 0: no twiddle. Row 1: ×ω_N^j. Row 2: ×ω_N^{2j}. for (size_t j = 0; j < M; ++j) { data[M + j] = mont_mul(data[M + j], tw1[j], p, pinv); data[2*M + j] = mont_mul(data[2*M + j], tw2[j], p, pinv); } } #ifdef __AVX2__ // --- AVX2 Radix-3 forward DFT --- inline void radix3_dft_forward_avx2(uint64_t* data, size_t M, uint64_t c_plus_val, uint64_t c_minus_val, uint64_t p, uint64_t pinv) { const __m256i pv = _mm256_set1_epi64x(p); const __m256i pnv = _mm256_set1_epi64x(pinv); const __m256i cpv = _mm256_set1_epi64x(c_plus_val); const __m256i cmv = _mm256_set1_epi64x(c_minus_val); const __m256i zero = _mm256_setzero_si256(); size_t j = 0; for (; j + 4 <= M; j += 4) { __m256i x0 = _mm256_loadu_si256((__m256i*)(data + j)); __m256i x1 = _mm256_loadu_si256((__m256i*)(data + j + M)); __m256i x2 = _mm256_loadu_si256((__m256i*)(data + j + 2*M)); // s = x1 + x2 __m256i s = _mm256_add_epi64(x1, x2); __m256i tmp = _mm256_sub_epi64(s, pv); s = _mm256_blendv_epi8(tmp, s, _mm256_cmpgt_epi64(zero, tmp)); // d = x1 - x2 __m256i d = _mm256_sub_epi64(x1, x2); d = _mm256_blendv_epi8(d, _mm256_add_epi64(d, pv), _mm256_cmpgt_epi64(zero, d)); // ps, pd __m256i ps = avx2_mont_mul(s, cpv, pv, pnv); __m256i pd = avx2_mont_mul(d, cmv, pv, pnv); // y0 = x0 + s __m256i y0 = _mm256_add_epi64(x0, s); tmp = _mm256_sub_epi64(y0, pv); y0 = _mm256_blendv_epi8(tmp, y0, _mm256_cmpgt_epi64(zero, tmp)); // t1 = ps + pd, t2 = ps - pd __m256i t1 = _mm256_add_epi64(ps, pd); tmp = _mm256_sub_epi64(t1, pv); t1 = _mm256_blendv_epi8(tmp, t1, _mm256_cmpgt_epi64(zero, tmp)); __m256i t2 = _mm256_sub_epi64(ps, pd); t2 = _mm256_blendv_epi8(t2, _mm256_add_epi64(t2, pv), _mm256_cmpgt_epi64(zero, t2)); // y1 = x0 + t1, y2 = x0 + t2 __m256i y1 = _mm256_add_epi64(x0, t1); tmp = _mm256_sub_epi64(y1, pv); y1 = _mm256_blendv_epi8(tmp, y1, _mm256_cmpgt_epi64(zero, tmp)); __m256i y2 = _mm256_add_epi64(x0, t2); tmp = _mm256_sub_epi64(y2, pv); y2 = _mm256_blendv_epi8(tmp, y2, _mm256_cmpgt_epi64(zero, tmp)); _mm256_storeu_si256((__m256i*)(data + j), y0); _mm256_storeu_si256((__m256i*)(data + j + M), y1); _mm256_storeu_si256((__m256i*)(data + j + 2*M), y2); } for (; j < M; ++j) { uint64_t x0 = data[j], x1 = data[j+M], x2 = data[j+2*M]; uint64_t sv = mod_add(x1, x2, p), dv = mod_sub(x1, x2, p); uint64_t psv = mont_mul(sv, c_plus_val, p, pinv); uint64_t pdv = mont_mul(dv, c_minus_val, p, pinv); data[j] = mod_add(x0, sv, p); data[j+M] = mod_add(x0, mod_add(psv, pdv, p), p); data[j+2*M] = mod_add(x0, mod_sub(psv, pdv, p), p); } } // --- AVX2 Radix-3 inverse DFT with 1/3 scaling --- inline void radix3_dft_inverse_scaled_avx2(uint64_t* data, size_t M, uint64_t c_plus_val, uint64_t c_minus_val, uint64_t inv3_val, uint64_t p, uint64_t pinv) { uint64_t cm_neg = (c_minus_val == 0) ? 0 : (p - c_minus_val); const __m256i pv = _mm256_set1_epi64x(p); const __m256i pnv = _mm256_set1_epi64x(pinv); const __m256i cpv = _mm256_set1_epi64x(c_plus_val); const __m256i cmv = _mm256_set1_epi64x(cm_neg); const __m256i i3v = _mm256_set1_epi64x(inv3_val); const __m256i zero = _mm256_setzero_si256(); size_t j = 0; for (; j + 4 <= M; j += 4) { __m256i y0 = _mm256_loadu_si256((__m256i*)(data + j)); __m256i y1 = _mm256_loadu_si256((__m256i*)(data + j + M)); __m256i y2 = _mm256_loadu_si256((__m256i*)(data + j + 2*M)); __m256i s = _mm256_add_epi64(y1, y2); __m256i tmp = _mm256_sub_epi64(s, pv); s = _mm256_blendv_epi8(tmp, s, _mm256_cmpgt_epi64(zero, tmp)); __m256i d = _mm256_sub_epi64(y1, y2); d = _mm256_blendv_epi8(d, _mm256_add_epi64(d, pv), _mm256_cmpgt_epi64(zero, d)); __m256i ps = avx2_mont_mul(s, cpv, pv, pnv); __m256i pd = avx2_mont_mul(d, cmv, pv, pnv); // z0 = y0+s, z1 = y0+ps+pd, z2 = y0+ps-pd __m256i z0 = _mm256_add_epi64(y0, s); tmp = _mm256_sub_epi64(z0, pv); z0 = _mm256_blendv_epi8(tmp, z0, _mm256_cmpgt_epi64(zero, tmp)); __m256i t1 = _mm256_add_epi64(ps, pd); tmp = _mm256_sub_epi64(t1, pv); t1 = _mm256_blendv_epi8(tmp, t1, _mm256_cmpgt_epi64(zero, tmp)); __m256i t2 = _mm256_sub_epi64(ps, pd); t2 = _mm256_blendv_epi8(t2, _mm256_add_epi64(t2, pv), _mm256_cmpgt_epi64(zero, t2)); __m256i z1 = _mm256_add_epi64(y0, t1); tmp = _mm256_sub_epi64(z1, pv); z1 = _mm256_blendv_epi8(tmp, z1, _mm256_cmpgt_epi64(zero, tmp)); __m256i z2 = _mm256_add_epi64(y0, t2); tmp = _mm256_sub_epi64(z2, pv); z2 = _mm256_blendv_epi8(tmp, z2, _mm256_cmpgt_epi64(zero, tmp)); // scale by inv3 _mm256_storeu_si256((__m256i*)(data + j), avx2_mont_mul(z0, i3v, pv, pnv)); _mm256_storeu_si256((__m256i*)(data + j + M), avx2_mont_mul(z1, i3v, pv, pnv)); _mm256_storeu_si256((__m256i*)(data + j + 2*M), avx2_mont_mul(z2, i3v, pv, pnv)); } uint64_t cm_neg_s = cm_neg; for (; j < M; ++j) { uint64_t y0v = data[j], y1v = data[j+M], y2v = data[j+2*M]; uint64_t sv = mod_add(y1v, y2v, p), dv = mod_sub(y1v, y2v, p); uint64_t psv = mont_mul(sv, c_plus_val, p, pinv); uint64_t pdv = mont_mul(dv, cm_neg_s, p, pinv); data[j] = mont_mul(mod_add(y0v, sv, p), inv3_val, p, pinv); data[j+M] = mont_mul(mod_add(y0v, mod_add(psv, pdv, p), p), inv3_val, p, pinv); data[j+2*M] = mont_mul(mod_add(y0v, mod_sub(psv, pdv, p), p), inv3_val, p, pinv); } } // --- AVX2 Twiddle 適用 --- inline void apply_twiddles_avx2(uint64_t* data, size_t M, const uint64_t* tw1, const uint64_t* tw2, uint64_t p, uint64_t pinv) { const __m256i pv = _mm256_set1_epi64x(p); const __m256i pnv = _mm256_set1_epi64x(pinv); size_t j = 0; for (; j + 4 <= M; j += 4) { __m256i d1 = _mm256_loadu_si256((__m256i*)(data + M + j)); __m256i w1 = _mm256_loadu_si256((__m256i*)(tw1 + j)); _mm256_storeu_si256((__m256i*)(data + M + j), avx2_mont_mul(d1, w1, pv, pnv)); __m256i d2 = _mm256_loadu_si256((__m256i*)(data + 2*M + j)); __m256i w2 = _mm256_loadu_si256((__m256i*)(tw2 + j)); _mm256_storeu_si256((__m256i*)(data + 2*M + j), avx2_mont_mul(d2, w2, pv, pnv)); } for (; j < M; ++j) { data[M + j] = mont_mul(data[M + j], tw1[j], p, pinv); data[2*M + j] = mont_mul(data[2*M + j], tw2[j], p, pinv); } } // ================================================================ // Radix-5 混合基数 NTT (N = 5M, M = 2^k) // ================================================================ // 5 点 DFT を 8 mont_mul で計算: // a' = (w5+w5⁴)/2, b' = (w5²+w5³)/2 // c' = (w5-w5⁴)/2, e' = (w5²-w5³)/2 struct MixedRadix5Constants { uint64_t a_prime, b_prime, c_prime, e_prime; // Montgomery uint64_t a_plus_b; // (a'+b') mod p — Karatsuba 用 uint64_t c_plus_e; // (c'+e') mod p — Karatsuba 用 uint64_t inv5; // 5⁻¹ mod p (Montgomery) std::vector twiddle[4]; // ω_N^{rj} r=1..4 (Montgomery) std::vector inv_twiddle[4]; // ω_N^{-rj} size_t N = 0; uint64_t p = 0; void build(const NttPrime& prime, size_t ntt_len) { N = ntt_len; p = prime.p; size_t M = N / 5; uint64_t pinv = prime.p_inv_neg; uint64_t r2 = prime.r2_mod_p; uint64_t w5 = mod_pow(prime.g, (prime.p - 1) / 5, prime.p); uint64_t w5_2 = mod_mul(w5, w5, prime.p); uint64_t w5_3 = mod_mul(w5_2, w5, prime.p); uint64_t w5_4 = mod_mul(w5_3, w5, prime.p); uint64_t inv2 = prime.inv_2; a_prime = to_mont(mod_mul(mod_add(w5, w5_4, prime.p), inv2, prime.p), p, pinv, r2); b_prime = to_mont(mod_mul(mod_add(w5_2, w5_3, prime.p), inv2, prime.p), p, pinv, r2); c_prime = to_mont(mod_mul(mod_sub(w5, w5_4, prime.p), inv2, prime.p), p, pinv, r2); e_prime = to_mont(mod_mul(mod_sub(w5_2, w5_3, prime.p), inv2, prime.p), p, pinv, r2); // Karatsuba 定数 (Montgomery 形式の mod_add は正しい — 線形) a_plus_b = mod_add(a_prime, b_prime, p); c_plus_e = mod_add(c_prime, e_prime, p); inv5 = to_mont(mod_inv(5, p), p, pinv, r2); uint64_t wN = mod_pow(prime.g, (prime.p - 1) / N, prime.p); uint64_t wN_inv = mod_inv(wN, prime.p); uint64_t one_mont = to_mont(1, p, pinv, r2); // twiddle[r-1][j] = ω_N^{rj} for r=1..4, j=0..M-1 uint64_t wN_r_mont[4], wN_inv_r_mont[4]; uint64_t wr = wN, wri = wN_inv; for (int r = 0; r < 4; ++r) { wN_r_mont[r] = to_mont(wr, p, pinv, r2); wN_inv_r_mont[r] = to_mont(wri, p, pinv, r2); wr = mod_mul(wr, wN, prime.p); wri = mod_mul(wri, wN_inv, prime.p); twiddle[r].resize(M); inv_twiddle[r].resize(M); twiddle[r][0] = inv_twiddle[r][0] = one_mont; } for (size_t j = 1; j < M; ++j) { for (int r = 0; r < 4; ++r) { twiddle[r][j] = mont_mul(twiddle[r][j-1], wN_r_mont[r], p, pinv); inv_twiddle[r][j] = mont_mul(inv_twiddle[r][j-1], wN_inv_r_mont[r], p, pinv); } } } }; // --- Radix-5 forward DFT (M 列, 6 mont_mul/列 — Karatsuba) --- inline void radix5_dft_forward(uint64_t* data, size_t M, uint64_t ap, uint64_t bp, uint64_t cp, uint64_t ep, uint64_t a_plus_b, uint64_t c_plus_e, uint64_t p, uint64_t pinv) { for (size_t j = 0; j < M; ++j) { uint64_t x0=data[j], x1=data[j+M], x2=data[j+2*M], x3=data[j+3*M], x4=data[j+4*M]; uint64_t s1 = mod_add(x1, x4, p), d1 = mod_sub(x1, x4, p); uint64_t s2 = mod_add(x2, x3, p), d2 = mod_sub(x2, x3, p); // Karatsuba: 6 mul instead of 8 uint64_t P1 = mont_mul(s1, ap, p, pinv); uint64_t P2 = mont_mul(s2, bp, p, pinv); uint64_t P3 = mont_mul(mod_add(s1, s2, p), a_plus_b, p, pinv); uint64_t t14 = mod_add(P1, P2, p); uint64_t t23 = mod_sub(P3, t14, p); // (a+b)(s1+s2) - a·s1 - b·s2 = b·s1 + a·s2 uint64_t Q1 = mont_mul(d1, cp, p, pinv); uint64_t Q2 = mont_mul(d2, ep, p, pinv); uint64_t Q3 = mont_mul(mod_sub(d1, d2, p), c_plus_e, p, pinv); uint64_t u14 = mod_add(Q1, Q2, p); uint64_t u23 = mod_add(mod_sub(Q3, Q1, p), Q2, p); // (c+e)(d1-d2) - c·d1 + e·d2 = e·d1 - c·d2 data[j] = mod_add(x0, mod_add(s1, s2, p), p); data[j+M] = mod_add(x0, mod_add(t14, u14, p), p); data[j+4*M] = mod_add(x0, mod_sub(t14, u14, p), p); data[j+2*M] = mod_add(x0, mod_add(t23, u23, p), p); data[j+3*M] = mod_add(x0, mod_sub(t23, u23, p), p); } } // --- Radix-5 inverse DFT with 1/5 scaling (6 mul + 5 inv5 scaling) --- inline void radix5_dft_inverse_scaled(uint64_t* data, size_t M, uint64_t ap, uint64_t bp, uint64_t cp, uint64_t ep, uint64_t a_plus_b, uint64_t c_plus_e, uint64_t inv5, uint64_t p, uint64_t pinv) { uint64_t cp_neg = cp ? (p - cp) : 0; uint64_t ep_neg = ep ? (p - ep) : 0; uint64_t ce_neg = c_plus_e ? (p - c_plus_e) : 0; // -(c+e) mod p for (size_t j = 0; j < M; ++j) { uint64_t y0=data[j], y1=data[j+M], y2=data[j+2*M], y3=data[j+3*M], y4=data[j+4*M]; uint64_t s1 = mod_add(y1, y4, p), d1 = mod_sub(y1, y4, p); uint64_t s2 = mod_add(y2, y3, p), d2 = mod_sub(y2, y3, p); // t terms: same Karatsuba (a,b unchanged in inverse) uint64_t P1 = mont_mul(s1, ap, p, pinv); uint64_t P2 = mont_mul(s2, bp, p, pinv); uint64_t P3 = mont_mul(mod_add(s1, s2, p), a_plus_b, p, pinv); uint64_t t14 = mod_add(P1, P2, p); uint64_t t23 = mod_sub(P3, t14, p); // u terms: c→-c, e→-e, so (c+e)→-(c+e) uint64_t Q1 = mont_mul(d1, cp_neg, p, pinv); uint64_t Q2 = mont_mul(d2, ep_neg, p, pinv); uint64_t Q3 = mont_mul(mod_sub(d1, d2, p), ce_neg, p, pinv); uint64_t u14 = mod_add(Q1, Q2, p); uint64_t u23 = mod_add(mod_sub(Q3, Q1, p), Q2, p); data[j] = mont_mul(mod_add(y0, mod_add(s1, s2, p), p), inv5, p, pinv); data[j+M] = mont_mul(mod_add(y0, mod_add(t14, u14, p), p), inv5, p, pinv); data[j+4*M] = mont_mul(mod_add(y0, mod_sub(t14, u14, p), p), inv5, p, pinv); data[j+2*M] = mont_mul(mod_add(y0, mod_add(t23, u23, p), p), inv5, p, pinv); data[j+3*M] = mont_mul(mod_add(y0, mod_sub(t23, u23, p), p), inv5, p, pinv); } } // --- Radix-5 twiddle (rows 1-4) --- inline void apply_twiddles5(uint64_t* data, size_t M, const uint64_t* tw[4], uint64_t p, uint64_t pinv) { for (size_t j = 0; j < M; ++j) { for (int r = 0; r < 4; ++r) data[(r+1)*M + j] = mont_mul(data[(r+1)*M + j], tw[r][j], p, pinv); } } // --- AVX2 Radix-5 forward DFT (6 mul/列 — Karatsuba) --- inline void radix5_dft_forward_avx2(uint64_t* data, size_t M, uint64_t ap, uint64_t bp, uint64_t cp, uint64_t ep, uint64_t a_plus_b, uint64_t c_plus_e, uint64_t p, uint64_t pinv) { const __m256i pv = _mm256_set1_epi64x(p), pnv = _mm256_set1_epi64x(pinv); const __m256i av = _mm256_set1_epi64x(ap), bv = _mm256_set1_epi64x(bp); const __m256i cv = _mm256_set1_epi64x(cp), ev = _mm256_set1_epi64x(ep); const __m256i abv = _mm256_set1_epi64x(a_plus_b); const __m256i cev = _mm256_set1_epi64x(c_plus_e); const __m256i zero = _mm256_setzero_si256(); auto madd = [&](__m256i a, __m256i b) -> __m256i { __m256i s = _mm256_add_epi64(a, b); __m256i t = _mm256_sub_epi64(s, pv); return _mm256_blendv_epi8(t, s, _mm256_cmpgt_epi64(zero, t)); }; auto msub = [&](__m256i a, __m256i b) -> __m256i { __m256i d = _mm256_sub_epi64(a, b); return _mm256_blendv_epi8(d, _mm256_add_epi64(d, pv), _mm256_cmpgt_epi64(zero, d)); }; auto mmul = [&](__m256i a, __m256i b) -> __m256i { return avx2_mont_mul(a, b, pv, pnv); }; size_t j = 0; for (; j + 4 <= M; j += 4) { __m256i x0=_mm256_loadu_si256((__m256i*)(data+j)); __m256i x1=_mm256_loadu_si256((__m256i*)(data+j+M)); __m256i x2=_mm256_loadu_si256((__m256i*)(data+j+2*M)); __m256i x3=_mm256_loadu_si256((__m256i*)(data+j+3*M)); __m256i x4=_mm256_loadu_si256((__m256i*)(data+j+4*M)); __m256i s1=madd(x1,x4), d1=msub(x1,x4), s2=madd(x2,x3), d2=msub(x2,x3); // Karatsuba: 6 mmul instead of 8 __m256i P1=mmul(s1,av), P2=mmul(s2,bv), P3=mmul(madd(s1,s2),abv); __m256i t14=madd(P1,P2), t23=msub(P3,t14); __m256i Q1=mmul(d1,cv), Q2=mmul(d2,ev), Q3=mmul(msub(d1,d2),cev); __m256i u14=madd(Q1,Q2), u23=madd(msub(Q3,Q1),Q2); _mm256_storeu_si256((__m256i*)(data+j), madd(x0, madd(s1,s2))); _mm256_storeu_si256((__m256i*)(data+j+M), madd(x0, madd(t14,u14))); _mm256_storeu_si256((__m256i*)(data+j+4*M), madd(x0, msub(t14,u14))); _mm256_storeu_si256((__m256i*)(data+j+2*M), madd(x0, madd(t23,u23))); _mm256_storeu_si256((__m256i*)(data+j+3*M), madd(x0, msub(t23,u23))); } for (; j < M; ++j) { uint64_t x0=data[j],x1=data[j+M],x2=data[j+2*M],x3=data[j+3*M],x4=data[j+4*M]; uint64_t s1=mod_add(x1,x4,p),d1=mod_sub(x1,x4,p),s2=mod_add(x2,x3,p),d2=mod_sub(x2,x3,p); uint64_t P1v=mont_mul(s1,ap,p,pinv), P2v=mont_mul(s2,bp,p,pinv); uint64_t P3v=mont_mul(mod_add(s1,s2,p),a_plus_b,p,pinv); uint64_t t14v=mod_add(P1v,P2v,p), t23v=mod_sub(P3v,t14v,p); uint64_t Q1v=mont_mul(d1,cp,p,pinv), Q2v=mont_mul(d2,ep,p,pinv); uint64_t Q3v=mont_mul(mod_sub(d1,d2,p),c_plus_e,p,pinv); uint64_t u14v=mod_add(Q1v,Q2v,p), u23v=mod_add(mod_sub(Q3v,Q1v,p),Q2v,p); data[j]=mod_add(x0,mod_add(s1,s2,p),p); data[j+M]=mod_add(x0,mod_add(t14v,u14v,p),p); data[j+4*M]=mod_add(x0,mod_sub(t14v,u14v,p),p); data[j+2*M]=mod_add(x0,mod_add(t23v,u23v,p),p); data[j+3*M]=mod_add(x0,mod_sub(t23v,u23v,p),p); } } // --- AVX2 Radix-5 inverse DFT with 1/5 scaling (6 mul + 5 inv5 — Karatsuba) --- inline void radix5_dft_inverse_scaled_avx2(uint64_t* data, size_t M, uint64_t ap, uint64_t bp, uint64_t cp, uint64_t ep, uint64_t a_plus_b, uint64_t c_plus_e, uint64_t inv5v, uint64_t p, uint64_t pinv) { uint64_t cpn = cp ? (p-cp) : 0, epn = ep ? (p-ep) : 0; uint64_t cen = c_plus_e ? (p-c_plus_e) : 0; const __m256i pv=_mm256_set1_epi64x(p), pnv=_mm256_set1_epi64x(pinv); const __m256i av=_mm256_set1_epi64x(ap),bvv=_mm256_set1_epi64x(bp); const __m256i cvv=_mm256_set1_epi64x(cpn),evv=_mm256_set1_epi64x(epn); const __m256i abv=_mm256_set1_epi64x(a_plus_b); const __m256i cenv=_mm256_set1_epi64x(cen); const __m256i i5=_mm256_set1_epi64x(inv5v); const __m256i zero=_mm256_setzero_si256(); auto madd=[&](__m256i a,__m256i b){__m256i s=_mm256_add_epi64(a,b);__m256i t=_mm256_sub_epi64(s,pv);return _mm256_blendv_epi8(t,s,_mm256_cmpgt_epi64(zero,t));}; auto msub=[&](__m256i a,__m256i b){__m256i d=_mm256_sub_epi64(a,b);return _mm256_blendv_epi8(d,_mm256_add_epi64(d,pv),_mm256_cmpgt_epi64(zero,d));}; auto mmul=[&](__m256i a,__m256i b){return avx2_mont_mul(a,b,pv,pnv);}; size_t j=0; for(;j+4<=M;j+=4){ __m256i y0=_mm256_loadu_si256((__m256i*)(data+j)); __m256i y1=_mm256_loadu_si256((__m256i*)(data+j+M)); __m256i y2=_mm256_loadu_si256((__m256i*)(data+j+2*M)); __m256i y3=_mm256_loadu_si256((__m256i*)(data+j+3*M)); __m256i y4=_mm256_loadu_si256((__m256i*)(data+j+4*M)); __m256i s1=madd(y1,y4),d1=msub(y1,y4),s2=madd(y2,y3),d2=msub(y2,y3); // Karatsuba: t terms (a,b unchanged) __m256i P1=mmul(s1,av), P2=mmul(s2,bvv), P3=mmul(madd(s1,s2),abv); __m256i t14=madd(P1,P2), t23=msub(P3,t14); // Karatsuba: u terms (c_neg, e_neg, ce_neg) __m256i Q1=mmul(d1,cvv), Q2=mmul(d2,evv), Q3=mmul(msub(d1,d2),cenv); __m256i u14=madd(Q1,Q2), u23=madd(msub(Q3,Q1),Q2); _mm256_storeu_si256((__m256i*)(data+j),mmul(madd(y0,madd(s1,s2)),i5)); _mm256_storeu_si256((__m256i*)(data+j+M), mmul(madd(y0,madd(t14,u14)),i5)); _mm256_storeu_si256((__m256i*)(data+j+4*M),mmul(madd(y0,msub(t14,u14)),i5)); _mm256_storeu_si256((__m256i*)(data+j+2*M),mmul(madd(y0,madd(t23,u23)),i5)); _mm256_storeu_si256((__m256i*)(data+j+3*M),mmul(madd(y0,msub(t23,u23)),i5)); } for(;j block) — 標準 layer-by-layer int layer = 0; for (size_t len = N; len > block; len >>= 1) { size_t half = len >> 1; const uint64_t* roots = layer_roots + layer_offsets[layer]; for (size_t start = 0; start < N; start += len) dif_butterfly_avx2(data, start, half, roots, p_vec, pin_vec, p, p_inv_neg); ++layer; } // Phase 2: Bottom layers (len <= block) — ブロック単位で L1 内処理 int bottom_start = layer; for (size_t blk = 0; blk < N; blk += block) { int cur_layer = bottom_start; for (size_t len = block; len >= 2; len >>= 1) { size_t half = len >> 1; const uint64_t* roots = layer_roots + layer_offsets[cur_layer]; for (size_t start = blk; start < blk + block; start += len) dif_butterfly_avx2(data, start, half, roots, p_vec, pin_vec, p, p_inv_neg); ++cur_layer; } } } // --- AVX2 DIT バタフライ (1 グループ分) --- SANGI_FORCEINLINE void dit_butterfly_avx2(uint64_t* data, size_t start, size_t half, const uint64_t* roots, __m256i p_vec, __m256i pin_vec, uint64_t p, uint64_t p_inv_neg) { uint64_t* d0 = data + start; uint64_t* d1 = data + start + half; size_t j = 0; for (; j + 4 <= half; j += 4) { __m256i u = _mm256_loadu_si256((__m256i*)(d0 + j)); __m256i v_raw = _mm256_loadu_si256((__m256i*)(d1 + j)); __m256i w = _mm256_loadu_si256((__m256i*)(roots + j)); // DIT: v = mont_mul(v_raw, ω⁻¹), then (u+v, u-v) __m256i v = avx2_mont_mul(v_raw, w, p_vec, pin_vec); __m256i sum = _mm256_add_epi64(u, v); __m256i sum_sub = _mm256_sub_epi64(sum, p_vec); __m256i sm = _mm256_cmpgt_epi64(_mm256_setzero_si256(), sum_sub); sum = _mm256_blendv_epi8(sum_sub, sum, sm); __m256i diff = _mm256_sub_epi64(u, v); __m256i diff_add = _mm256_add_epi64(diff, p_vec); __m256i dm = _mm256_cmpgt_epi64(_mm256_setzero_si256(), diff); diff = _mm256_blendv_epi8(diff, diff_add, dm); _mm256_storeu_si256((__m256i*)(d0 + j), sum); _mm256_storeu_si256((__m256i*)(d1 + j), diff); } for (; j < half; ++j) { uint64_t u = d0[j]; uint64_t v = mont_mul(d1[j], roots[j], p, p_inv_neg); d0[j] = mod_add(u, v, p); d1[j] = mod_sub(u, v, p); } } // --- AVX2 NTT inverse (DIT, Montgomery, cache-blocked) --- inline void inverse_ntt_mont_avx2(uint64_t* data, size_t N, uint64_t p, uint64_t p_inv_neg, const uint64_t* layer_roots, const size_t* layer_offsets, uint64_t inv_n_normal) { const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = _mm256_set1_epi64x(p_inv_neg); size_t block = (N <= NTT_CACHE_BLOCK) ? N : NTT_CACHE_BLOCK; // Phase 1: Bottom layers (len <= block) — ブロック単位で L1 内処理 int num_bottom_layers = 0; for (size_t l = 2; l <= block; l <<= 1) ++num_bottom_layers; for (size_t blk = 0; blk < N; blk += block) { int cur_layer = 0; for (size_t len = 2; len <= block; len <<= 1) { size_t half = len >> 1; const uint64_t* roots = layer_roots + layer_offsets[cur_layer]; for (size_t start = blk; start < blk + block; start += len) dit_butterfly_avx2(data, start, half, roots, p_vec, pin_vec, p, p_inv_neg); ++cur_layer; } } // Phase 2: Top layers (len > block) — 標準 layer-by-layer int layer = num_bottom_layers; for (size_t len = block * 2; len <= N; len <<= 1) { size_t half = len >> 1; const uint64_t* roots = layer_roots + layer_offsets[layer]; for (size_t start = 0; start < N; start += len) dit_butterfly_avx2(data, start, half, roots, p_vec, pin_vec, p, p_inv_neg); ++layer; } // N^{-1} スケーリング + from_mont (AVX2) __m256i inv_n_vec = _mm256_set1_epi64x(inv_n_normal); size_t i = 0; for (; i + 4 <= N; i += 4) { __m256i d = _mm256_loadu_si256((__m256i*)(data + i)); d = avx2_mont_mul(d, inv_n_vec, p_vec, pin_vec); _mm256_storeu_si256((__m256i*)(data + i), d); } for (; i < N; ++i) { data[i] = mont_mul(data[i], inv_n_normal, p, p_inv_neg); } } // --- AVX2 pointwise 乗算 --- inline void pointwise_mul_mont_avx2(uint64_t* a, const uint64_t* b, size_t N, uint64_t p, uint64_t p_inv_neg) { const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = _mm256_set1_epi64x(p_inv_neg); size_t i = 0; for (; i + 4 <= N; i += 4) { __m256i va = _mm256_loadu_si256((__m256i*)(a + i)); __m256i vb = _mm256_loadu_si256((__m256i*)(b + i)); va = avx2_mont_mul(va, vb, p_vec, pin_vec); _mm256_storeu_si256((__m256i*)(a + i), va); } for (; i < N; ++i) { a[i] = mont_mul(a[i], b[i], p, p_inv_neg); } } inline void pointwise_sqr_mont_avx2(uint64_t* a, size_t N, uint64_t p, uint64_t p_inv_neg) { const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = _mm256_set1_epi64x(p_inv_neg); size_t i = 0; for (; i + 4 <= N; i += 4) { __m256i va = _mm256_loadu_si256((__m256i*)(a + i)); va = avx2_mont_mul(va, va, p_vec, pin_vec); _mm256_storeu_si256((__m256i*)(a + i), va); } for (; i < N; ++i) { a[i] = mont_mul(a[i], a[i], p, p_inv_neg); } } // YC-2: pointwise multiply-accumulate: r[i] += a[i] * b[i] mod p (AVX2) inline void pointwise_mac_mont_avx2(uint64_t* r, const uint64_t* a, const uint64_t* b, size_t N, uint64_t p, uint64_t p_inv_neg) { const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = _mm256_set1_epi64x(p_inv_neg); size_t i = 0; for (; i + 4 <= N; i += 4) { __m256i va = _mm256_loadu_si256((__m256i*)(a + i)); __m256i vb = _mm256_loadu_si256((__m256i*)(b + i)); __m256i vr = _mm256_loadu_si256((__m256i*)(r + i)); __m256i prod = avx2_mont_mul(va, vb, p_vec, pin_vec); // mod_add: sum = vr + prod, if sum >= p then sum -= p __m256i sum = _mm256_add_epi64(vr, prod); __m256i sub = _mm256_sub_epi64(sum, p_vec); __m256i mask = _mm256_cmpgt_epi64(_mm256_setzero_si256(), sub); sum = _mm256_blendv_epi8(sub, sum, mask); _mm256_storeu_si256((__m256i*)(r + i), sum); } for (; i < N; ++i) { r[i] = mod_add(r[i], mont_mul(a[i], b[i], p, p_inv_neg), p); } } #endif // __AVX2__ // ================================================================ // CRT 復元 (Garner のアルゴリズム) // ================================================================ // 3 素数の NTT 結果から真の係数を復元し、 // limb 列 (uint64_t) に変換する。 // r1[i] mod p1, r2[i] mod p2, r3[i] mod p3 → result limbs // // Garner: // v1 = r1[i] // v2 = (r2[i] - v1) * inv_p1_mod_p2 mod p2 // v3 = (r3[i] - v1 - v2*p1) * inv_p1p2_mod_p3 mod p3 // c[i] = v1 + v2*p1 + v3*p1*p2 (最大 ~189 bit) struct CrtConstants { uint64_t p1, p2, p3; uint64_t inv_p1_mod_p2; // p1^(-1) mod p2 uint64_t inv_p1p2_mod_p3; // (p1*p2)^(-1) mod p3 uint64_t p1_mod_p3; // p1 mod p3 void init(uint64_t _p1, uint64_t _p2, uint64_t _p3) { p1 = _p1; p2 = _p2; p3 = _p3; inv_p1_mod_p2 = mod_inv(p1 % p2, p2); // p1*p2 mod p3: (p1 mod p3) * (p2 mod p3) mod p3 p1_mod_p3 = p1 % p3; uint64_t p1p2_mod_p3 = mod_mul(p1_mod_p3, p2 % p3, p3); inv_p1p2_mod_p3 = mod_inv(p1p2_mod_p3, p3); } }; // CRT で 1 係数を復元し、3-word (192bit) 値を返す // result[0] = low 64bit, result[1] = mid 64bit, result[2] = high 64bit inline void crt_single(uint64_t result[3], uint64_t r1, uint64_t r2, uint64_t r3, const CrtConstants& crt) { // v1 = r1 uint64_t v1 = r1; // v2 = (r2 - v1) * inv_p1_mod_p2 mod p2 uint64_t t = mod_sub(r2, v1 % crt.p2, crt.p2); uint64_t v2 = mod_mul(t, crt.inv_p1_mod_p2, crt.p2); // v3 = (r3 - v1 - v2*p1_mod_p3) * inv_p1p2_mod_p3 mod p3 uint64_t v2p1_mod_p3 = mod_mul(v2, crt.p1_mod_p3, crt.p3); uint64_t s = mod_sub(r3, v1 % crt.p3, crt.p3); s = mod_sub(s, v2p1_mod_p3, crt.p3); uint64_t v3 = mod_mul(s, crt.inv_p1p2_mod_p3, crt.p3); // c = v1 + v2*p1 + v3*p1*p2 // v2*p1: 最大 63bit × 63bit = 126bit → 2 words #if defined(_MSC_VER) && defined(_M_X64) uint64_t hi1; uint64_t lo1 = _umul128(v2, crt.p1, &hi1); #elif defined(__GNUC__) || defined(__clang__) __uint128_t prod1 = static_cast<__uint128_t>(v2) * crt.p1; uint64_t lo1 = static_cast(prod1); uint64_t hi1 = static_cast(prod1 >> 64); #endif // v1 + v2*p1 (128bit) uint64_t c0 = v1 + lo1; uint64_t c1 = hi1 + (c0 < v1 ? 1 : 0); // v3*p1*p2: v3 (63bit) × p1 (63bit) = 126bit → × p2 = 189bit // ただし v3 < p3 < 2^63, p1 < 2^63, p2 < 2^63 // v3*p1: 126bit → 2 words #if defined(_MSC_VER) && defined(_M_X64) uint64_t hi2; uint64_t lo2 = _umul128(v3, crt.p1, &hi2); // lo2:hi2 = v3*p1 (126bit) // × p2: (lo2:hi2) × p2 = 189bit → 3 words uint64_t hi3; uint64_t lo3 = _umul128(lo2, crt.p2, &hi3); uint64_t hi4; uint64_t lo4 = _umul128(hi2, crt.p2, &hi4); #elif defined(__GNUC__) || defined(__clang__) __uint128_t prod2 = static_cast<__uint128_t>(v3) * crt.p1; uint64_t lo2 = static_cast(prod2); uint64_t hi2 = static_cast(prod2 >> 64); __uint128_t prod3 = static_cast<__uint128_t>(lo2) * crt.p2; uint64_t lo3 = static_cast(prod3); uint64_t hi3 = static_cast(prod3 >> 64); __uint128_t prod4 = static_cast<__uint128_t>(hi2) * crt.p2; uint64_t lo4 = static_cast(prod4); uint64_t hi4 = static_cast(prod4 >> 64); #endif // v3*p1*p2 = lo3 + (hi3 + lo4) << 64 + (hi4) << 128 uint64_t mid = hi3 + lo4; uint64_t carry_mid = (mid < hi3) ? 1 : 0; uint64_t top = hi4 + carry_mid; // c += v3*p1*p2 c0 += lo3; uint64_t carry0 = (c0 < lo3) ? 1 : 0; c1 += mid + carry0; uint64_t carry1 = (c1 < mid || (carry0 && c1 == mid)) ? 1 : 0; uint64_t c2 = top + carry1; result[0] = c0; result[1] = c1; result[2] = c2; } // CRT 復元 + キャリー伝搬で最終結果を生成 inline void crt_recompose(uint64_t* rp, size_t result_limbs, const uint64_t* r1, const uint64_t* r2, const uint64_t* r3, size_t N, const CrtConstants& crt) { // 結果をゼロクリア std::memset(rp, 0, result_limbs * sizeof(uint64_t)); uint64_t carry_lo = 0, carry_hi = 0; for (size_t i = 0; i < N && i < result_limbs; ++i) { uint64_t coeff[3]; crt_single(coeff, r1[i], r2[i], r3[i], crt); // coeff[0:2] + carry を加算 uint64_t sum0 = coeff[0] + carry_lo; uint64_t c0 = (sum0 < coeff[0]) ? 1 : 0; uint64_t sum1 = coeff[1] + carry_hi + c0; uint64_t c1 = (sum1 < coeff[1] || (c0 && sum1 <= coeff[1])) ? 1 : 0; uint64_t sum2 = coeff[2] + c1; rp[i] = sum0; carry_lo = sum1; carry_hi = sum2; } // 残りのキャリーを書き出す if (carry_lo != 0 && N < result_limbs) { rp[N] = carry_lo; if (carry_hi != 0 && N + 1 < result_limbs) { rp[N + 1] = carry_hi; } } } // ================================================================ // グローバル NTT パラメータ (初期化は初回使用時) // ================================================================ struct PrimeNttContext { NttPrime primes[3]; CrtConstants crt; bool initialized = false; void init() { const uint64_t ps[] = { P1, P2, P3 }; const uint64_t gs[] = { G1, G2, G3 }; for (int k = 0; k < 3; ++k) { uint64_t p = ps[k]; uint64_t pin = compute_p_inv_neg(p); // R^2 mod p: R = 2^64, compute via mod_pow(2, 128, p) // 2^128 mod p = (2^64 mod p)^2 mod p uint64_t r_mod_p = mod_pow(2, 64, p); uint64_t r2 = mod_mul(r_mod_p, r_mod_p, p); primes[k] = { p, gs[k], NTT_MAX_S, mod_inv(2, p), pin, r2 }; } crt.init(P1, P2, P3); initialized = true; } }; inline PrimeNttContext& getPrimeNttContext() { static PrimeNttContext ctx; if (!ctx.initialized) { ctx.init(); } return ctx; } // ================================================================ // N 素数 CRT (Garner のアルゴリズム, NTT ドメイン BS 用) // ================================================================ struct CrtConstantsN { uint64_t p[BS_NUM_PRIMES]; // Garner 用逆元: inv[i][j] = p[j]^{-1} mod p[i] (j < i) uint64_t inv[BS_NUM_PRIMES][BS_NUM_PRIMES]; // prefix[i] = p[0]*p[1]*...*p[i-1] (多ワード, i ワード) // prefix[0] = {1}, prefix[1] = {p[0]}, ... uint64_t prefix[BS_NUM_PRIMES][BS_NUM_PRIMES]; int prefix_len[BS_NUM_PRIMES]; void init(const uint64_t ps[BS_NUM_PRIMES]) { for (int i = 0; i < BS_NUM_PRIMES; i++) p[i] = ps[i]; // 逆元テーブル for (int i = 0; i < BS_NUM_PRIMES; i++) for (int j = 0; j < i; j++) inv[i][j] = mod_inv(p[j] % p[i], p[i]); // 前置積 (多ワード) std::memset(prefix, 0, sizeof(prefix)); prefix[0][0] = 1; prefix_len[0] = 1; for (int i = 1; i < BS_NUM_PRIMES; i++) { uint64_t carry = 0; for (int w = 0; w < prefix_len[i - 1]; w++) { #if defined(_MSC_VER) && defined(_M_X64) uint64_t hi; uint64_t lo = _umul128(prefix[i - 1][w], p[i - 1], &hi); #else __uint128_t prod = (__uint128_t)prefix[i - 1][w] * p[i - 1]; uint64_t lo = (uint64_t)prod; uint64_t hi = (uint64_t)(prod >> 64); #endif lo += carry; if (lo < carry) hi++; prefix[i][w] = lo; carry = hi; } if (carry) { prefix[i][prefix_len[i - 1]] = carry; prefix_len[i] = prefix_len[i - 1] + 1; } else { prefix_len[i] = prefix_len[i - 1]; } } } }; // N 素数 CRT: 1 係数を復元 → BS_NUM_PRIMES ワードの値を返す inline void crt_single_n(uint64_t result[BS_NUM_PRIMES], uint64_t r[BS_NUM_PRIMES], const CrtConstantsN& crt) { // Garner の混合基数表現 uint64_t v[BS_NUM_PRIMES]; v[0] = r[0]; for (int i = 1; i < BS_NUM_PRIMES; i++) { uint64_t temp = r[i]; for (int j = 0; j < i; j++) { temp = mod_sub(temp, v[j] % crt.p[i], crt.p[i]); temp = mod_mul(temp, crt.inv[i][j], crt.p[i]); } v[i] = temp; } // 復元: x = v[0] + v[1]*prefix[1] + v[2]*prefix[2] + ... std::memset(result, 0, BS_NUM_PRIMES * sizeof(uint64_t)); result[0] = v[0]; for (int i = 1; i < BS_NUM_PRIMES; i++) { // result += v[i] * prefix[i] uint64_t carry = 0; for (int w = 0; w < crt.prefix_len[i]; w++) { #if defined(_MSC_VER) && defined(_M_X64) uint64_t hi; uint64_t lo = _umul128(v[i], crt.prefix[i][w], &hi); #else __uint128_t prod = (__uint128_t)v[i] * crt.prefix[i][w]; uint64_t lo = (uint64_t)prod; uint64_t hi = (uint64_t)(prod >> 64); #endif lo += carry; if (lo < carry) hi++; lo += result[w]; if (lo < result[w]) hi++; result[w] = lo; carry = hi; } for (int w = crt.prefix_len[i]; w < BS_NUM_PRIMES && carry; w++) { result[w] += carry; carry = (result[w] < carry) ? 1 : 0; } } } // N 素数 CRT 復元 + キャリー伝搬で最終結果を生成 inline void crt_recompose_n(uint64_t* rp, size_t result_limbs, uint64_t* const* residues, size_t N, const CrtConstantsN& crt) { std::memset(rp, 0, result_limbs * sizeof(uint64_t)); uint64_t carry[BS_NUM_PRIMES] = {}; for (size_t i = 0; i < N && i < result_limbs; ++i) { uint64_t res[BS_NUM_PRIMES]; for (int k = 0; k < BS_NUM_PRIMES; k++) res[k] = residues[k][i]; uint64_t coeff[BS_NUM_PRIMES]; crt_single_n(coeff, res, crt); // coeff += carry uint64_t c = 0; for (int w = 0; w < BS_NUM_PRIMES; w++) { uint64_t sum = coeff[w] + carry[w]; uint64_t c1 = (sum < coeff[w]) ? 1 : 0; sum += c; c1 += (sum < c) ? 1 : 0; coeff[w] = sum; c = c1; } rp[i] = coeff[0]; for (int w = 0; w < BS_NUM_PRIMES - 1; w++) carry[w] = coeff[w + 1]; carry[BS_NUM_PRIMES - 1] = c; } // 残りキャリーを書き出す for (int w = 0; w < BS_NUM_PRIMES; w++) { size_t idx = N + w; if (idx < result_limbs && carry[w]) rp[idx] = carry[w]; } } // ================================================================ // NTT ドメイン BS 用コンテキスト (5 素数) // ================================================================ struct NttBsContext { NttPrime primes[BS_NUM_PRIMES]; CrtConstantsN crt; bool initialized = false; void init() { uint64_t ps[BS_NUM_PRIMES]; for (int k = 0; k < BS_NUM_PRIMES; ++k) { uint64_t p = BS_PRIME_TABLE[k].p; int g = BS_PRIME_TABLE[k].g; uint64_t pin = compute_p_inv_neg(p); uint64_t r_mod_p = mod_pow(2, 64, p); uint64_t r2 = mod_mul(r_mod_p, r_mod_p, p); primes[k] = { p, static_cast(g), NTT_MAX_S, mod_inv(2, p), pin, r2 }; ps[k] = p; } crt.init(ps); initialized = true; } }; inline NttBsContext& getNttBsContext() { static NttBsContext ctx; if (!ctx.initialized) { ctx.init(); } return ctx; } // ================================================================ // NTT 長の決定 // ================================================================ inline size_t next_power_of_2(size_t n) { size_t p = 1; while (p < n) p <<= 1; return p; } inline bool is_power_of_2(size_t n) { return n > 0 && (n & (n - 1)) == 0; } // {2^a, 3×2^a, 5×2^a} の最小サイズを返す (5-smooth NTT 用) // 例: rn=2050 → 2560 (5×512), rn=1026 → 1280 (5×256) inline size_t next_smooth_size(size_t n) { if (n <= 1) return 1; size_t best = next_power_of_2(n); // 2^a size_t c3 = 3 * next_power_of_2((n + 2) / 3); // 3 × 2^a if (c3 >= n && c3 < best) best = c3; size_t c5 = 5 * next_power_of_2((n + 4) / 5); // 5 × 2^a if (c5 >= n && c5 < best) best = c5; return best; } // スレッドプール並列化の閾値 (NTT 長がこれ以上で 3 素数並列) static constexpr size_t PRIME_NTT_PARALLEL_THRESHOLD = 4096; // YC-10: 超並列 NTT の閾値 (NTT 長がこれ以上で intra-NTT 多スレッド化) // 128K limbs ≈ 2.5M 桁。これ以上で各素数の NTT を T スレッドに分割 static constexpr size_t MULTI_THREAD_NTT_THRESHOLD = 131072; // ================================================================ // 1 素数分の NTT パイプライン (Montgomery 版, スレッドから呼べる) // ================================================================ // 乗算: da[N], db[N] → 結果は da[N] に上書き (通常形式で返る) inline void ntt_mul_pipeline_mont(uint64_t* da, uint64_t* db, size_t N, uint64_t p, uint64_t p_inv_neg, const uint64_t* fwd_roots, const uint64_t* inv_roots, uint64_t inv_n_normal) { forward_ntt_mont(da, N, p, p_inv_neg, fwd_roots); forward_ntt_mont(db, N, p, p_inv_neg, fwd_roots); pointwise_mul_mont(da, db, N, p, p_inv_neg); inverse_ntt_mont(da, N, p, p_inv_neg, inv_roots, inv_n_normal); } // 自乗: da[N] → 結果は da[N] に上書き (通常形式で返る) inline void ntt_sqr_pipeline_mont(uint64_t* da, size_t N, uint64_t p, uint64_t p_inv_neg, const uint64_t* fwd_roots, const uint64_t* inv_roots, uint64_t inv_n_normal) { forward_ntt_mont(da, N, p, p_inv_neg, fwd_roots); pointwise_sqr_mont(da, N, p, p_inv_neg); inverse_ntt_mont(da, N, p, p_inv_neg, inv_roots, inv_n_normal); } #ifdef __AVX2__ // AVX2 版パイプライン (レイヤー根テーブル使用) inline void ntt_mul_pipeline_avx2(uint64_t* da, uint64_t* db, size_t N, uint64_t p, uint64_t p_inv_neg, const NttRootsLayered& roots) { forward_ntt_mont_avx2(da, N, p, p_inv_neg, roots.fwd.data(), roots.fwd_offset.data()); forward_ntt_mont_avx2(db, N, p, p_inv_neg, roots.fwd.data(), roots.fwd_offset.data()); pointwise_mul_mont_avx2(da, db, N, p, p_inv_neg); inverse_ntt_mont_avx2(da, N, p, p_inv_neg, roots.inv.data(), roots.inv_offset.data(), roots.inv_n); } inline void ntt_sqr_pipeline_avx2(uint64_t* da, size_t N, uint64_t p, uint64_t p_inv_neg, const NttRootsLayered& roots) { forward_ntt_mont_avx2(da, N, p, p_inv_neg, roots.fwd.data(), roots.fwd_offset.data()); pointwise_sqr_mont_avx2(da, N, p, p_inv_neg); inverse_ntt_mont_avx2(da, N, p, p_inv_neg, roots.inv.data(), roots.inv_offset.data(), roots.inv_n); } // ================================================================ // 混合基数 NTT パイプライン (N = 3M, M = 2^k) // ================================================================ // Forward: radix3_dft → twiddle → 3×M-point NTT // Inverse: 3×M-point INTT → inv_twiddle → radix3_idft(×1/3) // M-point NTT は inv_n で 1/M スケーリング → 全体で 1/(3M) = 1/N inline void mr_ntt_mul_pipeline_avx2(uint64_t* da, uint64_t* db, size_t N, uint64_t p, uint64_t pinv, const NttRootsLayered& sub_roots, const MixedRadixConstants& mr) { size_t M = N / 3; const uint64_t* fwd = sub_roots.fwd.data(); const size_t* fwd_off = sub_roots.fwd_offset.data(); const uint64_t* inv = sub_roots.inv.data(); const size_t* inv_off = sub_roots.inv_offset.data(); uint64_t inv_n_M = sub_roots.inv_n; // Forward da: radix3 → twiddle → 3 sub-NTTs radix3_dft_forward_avx2(da, M, mr.c_plus, mr.c_minus, p, pinv); apply_twiddles_avx2(da, M, mr.twiddle1.data(), mr.twiddle2.data(), p, pinv); forward_ntt_mont_avx2(da, M, p, pinv, fwd, fwd_off); forward_ntt_mont_avx2(da + M, M, p, pinv, fwd, fwd_off); forward_ntt_mont_avx2(da + 2*M, M, p, pinv, fwd, fwd_off); // Forward db radix3_dft_forward_avx2(db, M, mr.c_plus, mr.c_minus, p, pinv); apply_twiddles_avx2(db, M, mr.twiddle1.data(), mr.twiddle2.data(), p, pinv); forward_ntt_mont_avx2(db, M, p, pinv, fwd, fwd_off); forward_ntt_mont_avx2(db + M, M, p, pinv, fwd, fwd_off); forward_ntt_mont_avx2(db + 2*M, M, p, pinv, fwd, fwd_off); // Pointwise multiply (N 要素) pointwise_mul_mont_avx2(da, db, N, p, pinv); // Inverse: 3 sub-INTTs → inv_twiddle → radix3_idft(×1/3) inverse_ntt_mont_avx2(da, M, p, pinv, inv, inv_off, inv_n_M); inverse_ntt_mont_avx2(da + M, M, p, pinv, inv, inv_off, inv_n_M); inverse_ntt_mont_avx2(da + 2*M, M, p, pinv, inv, inv_off, inv_n_M); apply_twiddles_avx2(da, M, mr.inv_twiddle1.data(), mr.inv_twiddle2.data(), p, pinv); radix3_dft_inverse_scaled_avx2(da, M, mr.c_plus, mr.c_minus, mr.inv3, p, pinv); } inline void mr_ntt_sqr_pipeline_avx2(uint64_t* da, size_t N, uint64_t p, uint64_t pinv, const NttRootsLayered& sub_roots, const MixedRadixConstants& mr) { size_t M = N / 3; const uint64_t* fwd = sub_roots.fwd.data(); const size_t* fwd_off = sub_roots.fwd_offset.data(); const uint64_t* inv = sub_roots.inv.data(); const size_t* inv_off = sub_roots.inv_offset.data(); uint64_t inv_n_M = sub_roots.inv_n; radix3_dft_forward_avx2(da, M, mr.c_plus, mr.c_minus, p, pinv); apply_twiddles_avx2(da, M, mr.twiddle1.data(), mr.twiddle2.data(), p, pinv); forward_ntt_mont_avx2(da, M, p, pinv, fwd, fwd_off); forward_ntt_mont_avx2(da + M, M, p, pinv, fwd, fwd_off); forward_ntt_mont_avx2(da + 2*M, M, p, pinv, fwd, fwd_off); pointwise_sqr_mont_avx2(da, N, p, pinv); inverse_ntt_mont_avx2(da, M, p, pinv, inv, inv_off, inv_n_M); inverse_ntt_mont_avx2(da + M, M, p, pinv, inv, inv_off, inv_n_M); inverse_ntt_mont_avx2(da + 2*M, M, p, pinv, inv, inv_off, inv_n_M); apply_twiddles_avx2(da, M, mr.inv_twiddle1.data(), mr.inv_twiddle2.data(), p, pinv); radix3_dft_inverse_scaled_avx2(da, M, mr.c_plus, mr.c_minus, mr.inv3, p, pinv); } // ================================================================ // Radix-5 パイプライン (N = 5M, M = 2^k) // ================================================================ inline void mr5_ntt_mul_pipeline_avx2(uint64_t* da, uint64_t* db, size_t N, uint64_t p, uint64_t pinv, const NttRootsLayered& sub_roots, const MixedRadix5Constants& mr) { size_t M = N / 5; const uint64_t* fwd = sub_roots.fwd.data(); const size_t* fwd_off = sub_roots.fwd_offset.data(); const uint64_t* inv = sub_roots.inv.data(); const size_t* inv_off = sub_roots.inv_offset.data(); uint64_t inv_n_M = sub_roots.inv_n; const uint64_t* tw[4] = {mr.twiddle[0].data(), mr.twiddle[1].data(), mr.twiddle[2].data(), mr.twiddle[3].data()}; const uint64_t* itw[4] = {mr.inv_twiddle[0].data(), mr.inv_twiddle[1].data(), mr.inv_twiddle[2].data(), mr.inv_twiddle[3].data()}; // Forward da radix5_dft_forward_avx2(da, M, mr.a_prime, mr.b_prime, mr.c_prime, mr.e_prime, mr.a_plus_b, mr.c_plus_e, p, pinv); apply_twiddles5_avx2(da, M, tw, p, pinv); for (int i = 0; i < 5; ++i) forward_ntt_mont_avx2(da + i*M, M, p, pinv, fwd, fwd_off); // Forward db radix5_dft_forward_avx2(db, M, mr.a_prime, mr.b_prime, mr.c_prime, mr.e_prime, mr.a_plus_b, mr.c_plus_e, p, pinv); apply_twiddles5_avx2(db, M, tw, p, pinv); for (int i = 0; i < 5; ++i) forward_ntt_mont_avx2(db + i*M, M, p, pinv, fwd, fwd_off); // Pointwise pointwise_mul_mont_avx2(da, db, N, p, pinv); // Inverse for (int i = 0; i < 5; ++i) inverse_ntt_mont_avx2(da + i*M, M, p, pinv, inv, inv_off, inv_n_M); apply_twiddles5_avx2(da, M, itw, p, pinv); radix5_dft_inverse_scaled_avx2(da, M, mr.a_prime, mr.b_prime, mr.c_prime, mr.e_prime, mr.a_plus_b, mr.c_plus_e, mr.inv5, p, pinv); } inline void mr5_ntt_sqr_pipeline_avx2(uint64_t* da, size_t N, uint64_t p, uint64_t pinv, const NttRootsLayered& sub_roots, const MixedRadix5Constants& mr) { size_t M = N / 5; const uint64_t* fwd = sub_roots.fwd.data(); const size_t* fwd_off = sub_roots.fwd_offset.data(); const uint64_t* inv = sub_roots.inv.data(); const size_t* inv_off = sub_roots.inv_offset.data(); uint64_t inv_n_M = sub_roots.inv_n; const uint64_t* tw[4] = {mr.twiddle[0].data(), mr.twiddle[1].data(), mr.twiddle[2].data(), mr.twiddle[3].data()}; const uint64_t* itw[4] = {mr.inv_twiddle[0].data(), mr.inv_twiddle[1].data(), mr.inv_twiddle[2].data(), mr.inv_twiddle[3].data()}; radix5_dft_forward_avx2(da, M, mr.a_prime, mr.b_prime, mr.c_prime, mr.e_prime, mr.a_plus_b, mr.c_plus_e, p, pinv); apply_twiddles5_avx2(da, M, tw, p, pinv); for (int i = 0; i < 5; ++i) forward_ntt_mont_avx2(da + i*M, M, p, pinv, fwd, fwd_off); pointwise_sqr_mont_avx2(da, N, p, pinv); for (int i = 0; i < 5; ++i) inverse_ntt_mont_avx2(da + i*M, M, p, pinv, inv, inv_off, inv_n_M); apply_twiddles5_avx2(da, M, itw, p, pinv); radix5_dft_inverse_scaled_avx2(da, M, mr.a_prime, mr.b_prime, mr.c_prime, mr.e_prime, mr.a_plus_b, mr.c_plus_e, mr.inv5, p, pinv); } #endif // ================================================================ // YC-10: 超並列 NTT (intra-NTT multi-threading) // ================================================================ // 各素数の NTT を T スレッドに分割し、多コア CPU を活用する。 // 構成: 3 素数 × T スレッド/素数 = 3T スレッド (64コアで T≈21) // Phase 1 (top layers): 各ステージのバタフライを T 分割 + バリア // Phase 2 (bottom layers): キャッシュブロック単位で T 分割 (バリア 1 回) // 並列 range 実行ヘルパー: body(start, end) を T スレッドで分割実行 template inline void ntt_parallel_range(size_t total, int T, F&& body) { if (T <= 1 || total == 0) { body(static_cast(0), total); return; } auto& pool = sangi::threadPool(); size_t chunk = (total + T - 1) / T; std::future futures[128]; int nf = 0; for (int t = 0; t < T - 1; ++t) { size_t s = static_cast(t) * chunk; size_t e = std::min(s + chunk, total); if (s >= total) break; futures[nf++] = pool.submit([s, e, &body]{ body(s, e); }); } size_t last_s = static_cast(T - 1) * chunk; if (last_s < total) body(last_s, total); for (int i = 0; i < nf; ++i) futures[i].get(); } // --- Scalar multi-threaded NTT --- // Forward NTT (DIF) — T スレッド並列 inline void forward_ntt_mont_mt(uint64_t* data, size_t N, uint64_t p, uint64_t p_inv_neg, const uint64_t* roots, int T) { for (size_t len = N; len >= 2; len >>= 1) { size_t half = len >> 1; size_t step = N / len; size_t num_groups = N / len; if (num_groups >= static_cast(T)) { // グループ数 >= T: グループ単位で分割 ntt_parallel_range(num_groups, T, [=](size_t g_start, size_t g_end) { for (size_t g = g_start; g < g_end; ++g) { size_t base = g * len; for (size_t j = 0; j < half; ++j) { uint64_t u = data[base + j]; uint64_t v = data[base + j + half]; data[base + j] = mod_add(u, v, p); data[base + j + half] = mont_mul( mod_sub(u, v, p), roots[j * step], p, p_inv_neg); } } }); } else { // グループ数 < T: j ループを T 分割 (各グループ独立) ntt_parallel_range(half, T, [=](size_t j_start, size_t j_end) { for (size_t g = 0; g < num_groups; ++g) { size_t base = g * len; for (size_t j = j_start; j < j_end; ++j) { uint64_t u = data[base + j]; uint64_t v = data[base + j + half]; data[base + j] = mod_add(u, v, p); data[base + j + half] = mont_mul( mod_sub(u, v, p), roots[j * step], p, p_inv_neg); } } }); } } } // Inverse NTT (DIT) — T スレッド並列 inline void inverse_ntt_mont_mt(uint64_t* data, size_t N, uint64_t p, uint64_t p_inv_neg, const uint64_t* inv_roots, uint64_t inv_n_normal, int T) { for (size_t len = 2; len <= N; len <<= 1) { size_t half = len >> 1; size_t step = N / len; size_t num_groups = N / len; if (num_groups >= static_cast(T)) { ntt_parallel_range(num_groups, T, [=](size_t g_start, size_t g_end) { for (size_t g = g_start; g < g_end; ++g) { size_t base = g * len; for (size_t j = 0; j < half; ++j) { uint64_t u = data[base + j]; uint64_t v = mont_mul(data[base + j + half], inv_roots[j * step], p, p_inv_neg); data[base + j] = mod_add(u, v, p); data[base + j + half] = mod_sub(u, v, p); } } }); } else { ntt_parallel_range(half, T, [=](size_t j_start, size_t j_end) { for (size_t g = 0; g < num_groups; ++g) { size_t base = g * len; for (size_t j = j_start; j < j_end; ++j) { uint64_t u = data[base + j]; uint64_t v = mont_mul(data[base + j + half], inv_roots[j * step], p, p_inv_neg); data[base + j] = mod_add(u, v, p); data[base + j + half] = mod_sub(u, v, p); } } }); } } // N^{-1} スケーリング (並列) ntt_parallel_range(N, T, [=](size_t start, size_t end) { for (size_t i = start; i < end; ++i) data[i] = mont_mul(data[i], inv_n_normal, p, p_inv_neg); }); } // Multi-threaded pointwise multiply (AVX2/scalar 自動選択) inline void pointwise_mul_mont_mt(uint64_t* a, const uint64_t* b, size_t N, uint64_t p, uint64_t p_inv_neg, int T) { ntt_parallel_range(N, T, [=](size_t start, size_t end) { #ifdef __AVX2__ const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = _mm256_set1_epi64x(p_inv_neg); size_t i = start; for (; i + 4 <= end; i += 4) { __m256i va = _mm256_loadu_si256((__m256i*)(a + i)); __m256i vb = _mm256_loadu_si256((__m256i*)(b + i)); va = avx2_mont_mul(va, vb, p_vec, pin_vec); _mm256_storeu_si256((__m256i*)(a + i), va); } for (; i < end; ++i) a[i] = mont_mul(a[i], b[i], p, p_inv_neg); #else for (size_t i = start; i < end; ++i) a[i] = mont_mul(a[i], b[i], p, p_inv_neg); #endif }); } // Multi-threaded pointwise square inline void pointwise_sqr_mont_mt(uint64_t* a, size_t N, uint64_t p, uint64_t p_inv_neg, int T) { ntt_parallel_range(N, T, [=](size_t start, size_t end) { #ifdef __AVX2__ const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = _mm256_set1_epi64x(p_inv_neg); size_t i = start; for (; i + 4 <= end; i += 4) { __m256i va = _mm256_loadu_si256((__m256i*)(a + i)); va = avx2_mont_mul(va, va, p_vec, pin_vec); _mm256_storeu_si256((__m256i*)(a + i), va); } for (; i < end; ++i) a[i] = mont_mul(a[i], a[i], p, p_inv_neg); #else for (size_t i = start; i < end; ++i) a[i] = mont_mul(a[i], a[i], p, p_inv_neg); #endif }); } // Scalar multi-threaded pipeline: mul inline void ntt_mul_pipeline_mt(uint64_t* da, uint64_t* db, size_t N, uint64_t p, uint64_t p_inv_neg, const uint64_t* fwd_roots, const uint64_t* inv_roots, uint64_t inv_n_normal, int T) { forward_ntt_mont_mt(da, N, p, p_inv_neg, fwd_roots, T); forward_ntt_mont_mt(db, N, p, p_inv_neg, fwd_roots, T); pointwise_mul_mont_mt(da, db, N, p, p_inv_neg, T); inverse_ntt_mont_mt(da, N, p, p_inv_neg, inv_roots, inv_n_normal, T); } // Scalar multi-threaded pipeline: sqr inline void ntt_sqr_pipeline_mt(uint64_t* da, size_t N, uint64_t p, uint64_t p_inv_neg, const uint64_t* fwd_roots, const uint64_t* inv_roots, uint64_t inv_n_normal, int T) { forward_ntt_mont_mt(da, N, p, p_inv_neg, fwd_roots, T); pointwise_sqr_mont_mt(da, N, p, p_inv_neg, T); inverse_ntt_mont_mt(da, N, p, p_inv_neg, inv_roots, inv_n_normal, T); } #ifdef __AVX2__ // --- AVX2 multi-threaded NTT --- // AVX2 Forward NTT (DIF, cache-blocked) — T スレッド並列 inline void forward_ntt_mont_avx2_mt(uint64_t* data, size_t N, uint64_t p, uint64_t p_inv_neg, const uint64_t* layer_roots, const size_t* layer_offsets, int T) { size_t block = (N <= NTT_CACHE_BLOCK) ? N : NTT_CACHE_BLOCK; // Phase 1: Top layers (len > block) — ステージごとにバリア int layer = 0; for (size_t len = N; len > block; len >>= 1) { size_t half = len >> 1; size_t num_groups = N / len; const uint64_t* roots = layer_roots + layer_offsets[layer]; if (num_groups >= static_cast(T)) { // グループ単位で分割 ntt_parallel_range(num_groups, T, [=](size_t g_start, size_t g_end) { const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = _mm256_set1_epi64x(p_inv_neg); for (size_t g = g_start; g < g_end; ++g) dif_butterfly_avx2(data, g * len, half, roots, p_vec, pin_vec, p, p_inv_neg); }); } else { // j ループを分割 (各グループ内の half バタフライを T 分割) ntt_parallel_range(half, T, [=](size_t j_start, size_t j_end) { const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = _mm256_set1_epi64x(p_inv_neg); for (size_t g = 0; g < num_groups; ++g) { size_t start = g * len; uint64_t* d0 = data + start; uint64_t* d1 = data + start + half; size_t j = j_start; for (; j + 4 <= j_end; j += 4) { __m256i u = _mm256_loadu_si256((__m256i*)(d0 + j)); __m256i v = _mm256_loadu_si256((__m256i*)(d1 + j)); __m256i w = _mm256_loadu_si256((__m256i*)(roots + j)); __m256i sum = _mm256_add_epi64(u, v); __m256i sub = _mm256_sub_epi64(sum, p_vec); __m256i sm = _mm256_cmpgt_epi64(_mm256_setzero_si256(), sub); sum = _mm256_blendv_epi8(sub, sum, sm); __m256i diff = _mm256_sub_epi64(u, v); __m256i da = _mm256_add_epi64(diff, p_vec); __m256i dm = _mm256_cmpgt_epi64(_mm256_setzero_si256(), diff); diff = _mm256_blendv_epi8(diff, da, dm); __m256i tw = avx2_mont_mul(diff, w, p_vec, pin_vec); _mm256_storeu_si256((__m256i*)(d0 + j), sum); _mm256_storeu_si256((__m256i*)(d1 + j), tw); } for (; j < j_end; ++j) { uint64_t uv = d0[j], vv = d1[j]; d0[j] = mod_add(uv, vv, p); d1[j] = mont_mul(mod_sub(uv, vv, p), roots[j], p, p_inv_neg); } } }); } ++layer; } // Phase 2: Bottom layers (len <= block) — ブロック並列 (バリア 1 回) int bottom_start = layer; size_t num_blocks = N / block; ntt_parallel_range(num_blocks, T, [=](size_t blk_start, size_t blk_end) { const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = _mm256_set1_epi64x(p_inv_neg); for (size_t bi = blk_start; bi < blk_end; ++bi) { size_t blk = bi * block; int cur_layer = bottom_start; for (size_t len = block; len >= 2; len >>= 1) { size_t half = len >> 1; const uint64_t* roots = layer_roots + layer_offsets[cur_layer]; for (size_t start = blk; start < blk + block; start += len) dif_butterfly_avx2(data, start, half, roots, p_vec, pin_vec, p, p_inv_neg); ++cur_layer; } } }); } // AVX2 Inverse NTT (DIT, cache-blocked) — T スレッド並列 inline void inverse_ntt_mont_avx2_mt(uint64_t* data, size_t N, uint64_t p, uint64_t p_inv_neg, const uint64_t* layer_roots, const size_t* layer_offsets, uint64_t inv_n_normal, int T) { size_t block = (N <= NTT_CACHE_BLOCK) ? N : NTT_CACHE_BLOCK; // Phase 1: Bottom layers (len <= block) — ブロック並列 (バリア 1 回) int num_bottom_layers = 0; for (size_t l = 2; l <= block; l <<= 1) ++num_bottom_layers; size_t num_blocks = N / block; ntt_parallel_range(num_blocks, T, [=](size_t blk_start, size_t blk_end) { const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = _mm256_set1_epi64x(p_inv_neg); for (size_t bi = blk_start; bi < blk_end; ++bi) { size_t blk = bi * block; int cur_layer = 0; for (size_t len = 2; len <= block; len <<= 1) { size_t half = len >> 1; const uint64_t* roots = layer_roots + layer_offsets[cur_layer]; for (size_t start = blk; start < blk + block; start += len) dit_butterfly_avx2(data, start, half, roots, p_vec, pin_vec, p, p_inv_neg); ++cur_layer; } } }); // Phase 2: Top layers (len > block) — ステージごとにバリア int layer = num_bottom_layers; for (size_t len = block * 2; len <= N; len <<= 1) { size_t half = len >> 1; size_t num_groups = N / len; const uint64_t* roots = layer_roots + layer_offsets[layer]; if (num_groups >= static_cast(T)) { ntt_parallel_range(num_groups, T, [=](size_t g_start, size_t g_end) { const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = _mm256_set1_epi64x(p_inv_neg); for (size_t g = g_start; g < g_end; ++g) dit_butterfly_avx2(data, g * len, half, roots, p_vec, pin_vec, p, p_inv_neg); }); } else { ntt_parallel_range(half, T, [=](size_t j_start, size_t j_end) { const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = _mm256_set1_epi64x(p_inv_neg); for (size_t g = 0; g < num_groups; ++g) { size_t start = g * len; uint64_t* d0 = data + start; uint64_t* d1 = data + start + half; size_t j = j_start; for (; j + 4 <= j_end; j += 4) { __m256i u = _mm256_loadu_si256((__m256i*)(d0 + j)); __m256i v_raw = _mm256_loadu_si256((__m256i*)(d1 + j)); __m256i w = _mm256_loadu_si256((__m256i*)(roots + j)); __m256i v = avx2_mont_mul(v_raw, w, p_vec, pin_vec); __m256i sum = _mm256_add_epi64(u, v); __m256i sub = _mm256_sub_epi64(sum, p_vec); __m256i sm = _mm256_cmpgt_epi64(_mm256_setzero_si256(), sub); sum = _mm256_blendv_epi8(sub, sum, sm); __m256i diff = _mm256_sub_epi64(u, v); __m256i da = _mm256_add_epi64(diff, p_vec); __m256i dm = _mm256_cmpgt_epi64(_mm256_setzero_si256(), diff); diff = _mm256_blendv_epi8(diff, da, dm); _mm256_storeu_si256((__m256i*)(d0 + j), sum); _mm256_storeu_si256((__m256i*)(d1 + j), diff); } for (; j < j_end; ++j) { uint64_t uv = d0[j]; uint64_t vv = mont_mul(d1[j], roots[j], p, p_inv_neg); d0[j] = mod_add(uv, vv, p); d1[j] = mod_sub(uv, vv, p); } } }); } ++layer; } // N^{-1} スケーリング (並列, AVX2) ntt_parallel_range(N, T, [=](size_t start, size_t end) { const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = _mm256_set1_epi64x(p_inv_neg); __m256i inv_n_vec = _mm256_set1_epi64x(inv_n_normal); size_t i = start; for (; i + 4 <= end; i += 4) { __m256i d = _mm256_loadu_si256((__m256i*)(data + i)); d = avx2_mont_mul(d, inv_n_vec, p_vec, pin_vec); _mm256_storeu_si256((__m256i*)(data + i), d); } for (; i < end; ++i) data[i] = mont_mul(data[i], inv_n_normal, p, p_inv_neg); }); } // AVX2 multi-threaded pipeline: mul inline void ntt_mul_pipeline_avx2_mt(uint64_t* da, uint64_t* db, size_t N, uint64_t p, uint64_t p_inv_neg, const NttRootsLayered& roots, int T) { forward_ntt_mont_avx2_mt(da, N, p, p_inv_neg, roots.fwd.data(), roots.fwd_offset.data(), T); forward_ntt_mont_avx2_mt(db, N, p, p_inv_neg, roots.fwd.data(), roots.fwd_offset.data(), T); pointwise_mul_mont_mt(da, db, N, p, p_inv_neg, T); inverse_ntt_mont_avx2_mt(da, N, p, p_inv_neg, roots.inv.data(), roots.inv_offset.data(), roots.inv_n, T); } // AVX2 multi-threaded pipeline: sqr inline void ntt_sqr_pipeline_avx2_mt(uint64_t* da, size_t N, uint64_t p, uint64_t p_inv_neg, const NttRootsLayered& roots, int T) { forward_ntt_mont_avx2_mt(da, N, p, p_inv_neg, roots.fwd.data(), roots.fwd_offset.data(), T); pointwise_sqr_mont_mt(da, N, p, p_inv_neg, T); inverse_ntt_mont_avx2_mt(da, N, p, p_inv_neg, roots.inv.data(), roots.inv_offset.data(), roots.inv_n, T); } #endif // ================================================================ // YC-2: Fused Multiply-Add (rp = a*b + c*d) // ================================================================ // 4 入力を forward NTT → pointwise MAC → 1 回の inverse NTT + CRT // 通常の 2 回乗算 + 加算と比べ、inverse NTT 1 回 + CRT 1 回 + 大数加算を節約 // rp[0..rn-1] = ap[0..an-1]*bp[0..bn-1] + cp[0..cn-1]*dp[0..dn-1] inline void mul_add_prime_ntt(uint64_t* rp, size_t rn, const uint64_t* ap, size_t an, const uint64_t* bp, size_t bn, const uint64_t* cp, size_t cn, const uint64_t* dp, size_t dn) { auto& ctx = getPrimeNttContext(); // NTT 長は両積の最大サイズに合わせる size_t prod1_n = an + bn; size_t prod2_n = cn + dn; size_t max_prod = std::max(prod1_n, prod2_n); size_t N = next_power_of_2(max_prod); // 4 入力 × 3 素数 = 12N ワード thread_local std::vector work; size_t total = 12 * N; if (work.size() < total) work.resize(total); uint64_t* da[3]; // a の NTT uint64_t* db[3]; // b の NTT uint64_t* dc[3]; // c の NTT uint64_t* dd[3]; // d の NTT for (int k = 0; k < 3; ++k) { da[k] = work.data() + k * 4 * N; db[k] = da[k] + N; dc[k] = db[k] + N; dd[k] = dc[k] + N; } // 入力分解: mod p → Montgomery 形式 for (int k = 0; k < 3; ++k) { uint64_t p = ctx.primes[k].p; uint64_t pin = ctx.primes[k].p_inv_neg; uint64_t r2 = ctx.primes[k].r2_mod_p; // a for (size_t i = 0; i < an; ++i) da[k][i] = to_mont(ap[i] % p, p, pin, r2); for (size_t i = an; i < N; ++i) da[k][i] = 0; // b for (size_t i = 0; i < bn; ++i) db[k][i] = to_mont(bp[i] % p, p, pin, r2); for (size_t i = bn; i < N; ++i) db[k][i] = 0; // c for (size_t i = 0; i < cn; ++i) dc[k][i] = to_mont(cp[i] % p, p, pin, r2); for (size_t i = cn; i < N; ++i) dc[k][i] = 0; // d for (size_t i = 0; i < dn; ++i) dd[k][i] = to_mont(dp[i] % p, p, pin, r2); for (size_t i = dn; i < N; ++i) dd[k][i] = 0; } #ifdef __AVX2__ thread_local NttRootsLayered roots_avx[3]; for (int k = 0; k < 3; ++k) { if (roots_avx[k].N != N || roots_avx[k].p != ctx.primes[k].p) { roots_avx[k].build(ctx.primes[k], N); } } bool parallel = (N >= PRIME_NTT_PARALLEL_THRESHOLD); if (parallel) { const NttRootsLayered* rptr[3] = {&roots_avx[0], &roots_avx[1], &roots_avx[2]}; // 素数 0, 1 をワーカーで、素数 2 をメインスレッドで処理 auto f0 = sangi::threadPool().submit([&, N, rptr]{ uint64_t p = ctx.primes[0].p, pin = ctx.primes[0].p_inv_neg; const auto& r = *rptr[0]; forward_ntt_mont_avx2(da[0], N, p, pin, r.fwd.data(), r.fwd_offset.data()); forward_ntt_mont_avx2(db[0], N, p, pin, r.fwd.data(), r.fwd_offset.data()); forward_ntt_mont_avx2(dc[0], N, p, pin, r.fwd.data(), r.fwd_offset.data()); forward_ntt_mont_avx2(dd[0], N, p, pin, r.fwd.data(), r.fwd_offset.data()); pointwise_mul_mont_avx2(da[0], db[0], N, p, pin); pointwise_mac_mont_avx2(da[0], dc[0], dd[0], N, p, pin); inverse_ntt_mont_avx2(da[0], N, p, pin, r.inv.data(), r.inv_offset.data(), r.inv_n); }); auto f1 = sangi::threadPool().submit([&, N, rptr]{ uint64_t p = ctx.primes[1].p, pin = ctx.primes[1].p_inv_neg; const auto& r = *rptr[1]; forward_ntt_mont_avx2(da[1], N, p, pin, r.fwd.data(), r.fwd_offset.data()); forward_ntt_mont_avx2(db[1], N, p, pin, r.fwd.data(), r.fwd_offset.data()); forward_ntt_mont_avx2(dc[1], N, p, pin, r.fwd.data(), r.fwd_offset.data()); forward_ntt_mont_avx2(dd[1], N, p, pin, r.fwd.data(), r.fwd_offset.data()); pointwise_mul_mont_avx2(da[1], db[1], N, p, pin); pointwise_mac_mont_avx2(da[1], dc[1], dd[1], N, p, pin); inverse_ntt_mont_avx2(da[1], N, p, pin, r.inv.data(), r.inv_offset.data(), r.inv_n); }); { uint64_t p = ctx.primes[2].p, pin = ctx.primes[2].p_inv_neg; const auto& r = *rptr[2]; forward_ntt_mont_avx2(da[2], N, p, pin, r.fwd.data(), r.fwd_offset.data()); forward_ntt_mont_avx2(db[2], N, p, pin, r.fwd.data(), r.fwd_offset.data()); forward_ntt_mont_avx2(dc[2], N, p, pin, r.fwd.data(), r.fwd_offset.data()); forward_ntt_mont_avx2(dd[2], N, p, pin, r.fwd.data(), r.fwd_offset.data()); pointwise_mul_mont_avx2(da[2], db[2], N, p, pin); pointwise_mac_mont_avx2(da[2], dc[2], dd[2], N, p, pin); inverse_ntt_mont_avx2(da[2], N, p, pin, r.inv.data(), r.inv_offset.data(), r.inv_n); } f0.get(); f1.get(); } else { for (int k = 0; k < 3; ++k) { uint64_t p = ctx.primes[k].p, pin = ctx.primes[k].p_inv_neg; const auto& r = roots_avx[k]; forward_ntt_mont_avx2(da[k], N, p, pin, r.fwd.data(), r.fwd_offset.data()); forward_ntt_mont_avx2(db[k], N, p, pin, r.fwd.data(), r.fwd_offset.data()); forward_ntt_mont_avx2(dc[k], N, p, pin, r.fwd.data(), r.fwd_offset.data()); forward_ntt_mont_avx2(dd[k], N, p, pin, r.fwd.data(), r.fwd_offset.data()); pointwise_mul_mont_avx2(da[k], db[k], N, p, pin); pointwise_mac_mont_avx2(da[k], dc[k], dd[k], N, p, pin); inverse_ntt_mont_avx2(da[k], N, p, pin, r.inv.data(), r.inv_offset.data(), r.inv_n); } } #else thread_local NttRootsMont roots[3]; for (int k = 0; k < 3; ++k) { if (roots[k].N != N || roots[k].p != ctx.primes[k].p) { roots[k].build(ctx.primes[k], N); } } bool parallel = (N >= PRIME_NTT_PARALLEL_THRESHOLD); if (parallel) { auto f0 = sangi::threadPool().submit([&, N]{ uint64_t p = ctx.primes[0].p, pin = ctx.primes[0].p_inv_neg; forward_ntt_mont(da[0], N, p, pin, roots[0].roots.data()); forward_ntt_mont(db[0], N, p, pin, roots[0].roots.data()); forward_ntt_mont(dc[0], N, p, pin, roots[0].roots.data()); forward_ntt_mont(dd[0], N, p, pin, roots[0].roots.data()); pointwise_mul_mont(da[0], db[0], N, p, pin); pointwise_mac_mont(da[0], dc[0], dd[0], N, p, pin); inverse_ntt_mont(da[0], N, p, pin, roots[0].inv_roots.data(), roots[0].inv_n); }); auto f1 = sangi::threadPool().submit([&, N]{ uint64_t p = ctx.primes[1].p, pin = ctx.primes[1].p_inv_neg; forward_ntt_mont(da[1], N, p, pin, roots[1].roots.data()); forward_ntt_mont(db[1], N, p, pin, roots[1].roots.data()); forward_ntt_mont(dc[1], N, p, pin, roots[1].roots.data()); forward_ntt_mont(dd[1], N, p, pin, roots[1].roots.data()); pointwise_mul_mont(da[1], db[1], N, p, pin); pointwise_mac_mont(da[1], dc[1], dd[1], N, p, pin); inverse_ntt_mont(da[1], N, p, pin, roots[1].inv_roots.data(), roots[1].inv_n); }); { uint64_t p = ctx.primes[2].p, pin = ctx.primes[2].p_inv_neg; forward_ntt_mont(da[2], N, p, pin, roots[2].roots.data()); forward_ntt_mont(db[2], N, p, pin, roots[2].roots.data()); forward_ntt_mont(dc[2], N, p, pin, roots[2].roots.data()); forward_ntt_mont(dd[2], N, p, pin, roots[2].roots.data()); pointwise_mul_mont(da[2], db[2], N, p, pin); pointwise_mac_mont(da[2], dc[2], dd[2], N, p, pin); inverse_ntt_mont(da[2], N, p, pin, roots[2].inv_roots.data(), roots[2].inv_n); } f0.get(); f1.get(); } else { for (int k = 0; k < 3; ++k) { uint64_t p = ctx.primes[k].p, pin = ctx.primes[k].p_inv_neg; forward_ntt_mont(da[k], N, p, pin, roots[k].roots.data()); forward_ntt_mont(db[k], N, p, pin, roots[k].roots.data()); forward_ntt_mont(dc[k], N, p, pin, roots[k].roots.data()); forward_ntt_mont(dd[k], N, p, pin, roots[k].roots.data()); pointwise_mul_mont(da[k], db[k], N, p, pin); pointwise_mac_mont(da[k], dc[k], dd[k], N, p, pin); inverse_ntt_mont(da[k], N, p, pin, roots[k].inv_roots.data(), roots[k].inv_n); } } #endif // CRT 復元 crt_recompose(rp, rn, da[0], da[1], da[2], max_prod, ctx.crt); } // ================================================================ // メイン乗算関数 // ================================================================ // rp[0..an+bn-1] = ap[0..an-1] × bp[0..bn-1] inline void mul_prime_ntt(uint64_t* rp, const uint64_t* ap, size_t an, const uint64_t* bp, size_t bn) { if (an < bn) { std::swap(ap, bp); std::swap(an, bn); } auto& ctx = getPrimeNttContext(); size_t rn = an + bn; size_t N = next_smooth_size(rn); // 2^a, 3×2^a, or 5×2^a // radix 判定: 0=power-of-2, 3=radix-3, 5=radix-5 int radix = 0; if (N % 5 == 0 && is_power_of_2(N / 5)) radix = 5; else if (N % 3 == 0 && is_power_of_2(N / 3)) radix = 3; // 3 素数分の NTT データ配列を確保 thread_local std::vector work; size_t total = 6 * N; if (work.size() < total) work.resize(total); uint64_t* da[3]; uint64_t* db[3]; for (int k = 0; k < 3; ++k) { da[k] = work.data() + k * 2 * N; db[k] = da[k] + N; } // 入力分解: a[i] mod p → Montgomery 形式に変換 for (int k = 0; k < 3; ++k) { uint64_t p = ctx.primes[k].p; uint64_t pin = ctx.primes[k].p_inv_neg; uint64_t r2 = ctx.primes[k].r2_mod_p; for (size_t i = 0; i < an; ++i) da[k][i] = to_mont(ap[i] % p, p, pin, r2); for (size_t i = an; i < N; ++i) da[k][i] = 0; for (size_t i = 0; i < bn; ++i) db[k][i] = to_mont(bp[i] % p, p, pin, r2); for (size_t i = bn; i < N; ++i) db[k][i] = 0; } #ifdef __AVX2__ if (radix == 5) { // Radix-5 パス (N = 5M) size_t M = N / 5; thread_local NttRootsLayered sub_roots_avx[3]; thread_local MixedRadix5Constants mr5_const[3]; for (int k = 0; k < 3; ++k) { if (sub_roots_avx[k].N != M || sub_roots_avx[k].p != ctx.primes[k].p) sub_roots_avx[k].build(ctx.primes[k], M); if (mr5_const[k].N != N || mr5_const[k].p != ctx.primes[k].p) mr5_const[k].build(ctx.primes[k], N); } // 常にパラレル化 { const NttRootsLayered* srp[3] = {&sub_roots_avx[0], &sub_roots_avx[1], &sub_roots_avx[2]}; const MixedRadix5Constants* mrp[3] = {&mr5_const[0], &mr5_const[1], &mr5_const[2]}; auto f0 = sangi::threadPool().submit([&, N, srp, mrp]{ mr5_ntt_mul_pipeline_avx2(da[0], db[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, *srp[0], *mrp[0]); }); auto f1 = sangi::threadPool().submit([&, N, srp, mrp]{ mr5_ntt_mul_pipeline_avx2(da[1], db[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, *srp[1], *mrp[1]); }); mr5_ntt_mul_pipeline_avx2(da[2], db[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, *srp[2], *mrp[2]); f0.get(); f1.get(); } } else if (radix == 3) { // Radix-3 パス (N = 3M) size_t M = N / 3; thread_local NttRootsLayered sub_roots_avx[3]; thread_local MixedRadixConstants mr_const[3]; for (int k = 0; k < 3; ++k) { if (sub_roots_avx[k].N != M || sub_roots_avx[k].p != ctx.primes[k].p) sub_roots_avx[k].build(ctx.primes[k], M); if (mr_const[k].N != N || mr_const[k].p != ctx.primes[k].p) mr_const[k].build(ctx.primes[k], N); } // 混合基数: 9 sub-NTTs/prime → 常にパラレル化 (M >= 256 で十分な仕事量) { const NttRootsLayered* srp[3] = {&sub_roots_avx[0], &sub_roots_avx[1], &sub_roots_avx[2]}; const MixedRadixConstants* mrp[3] = {&mr_const[0], &mr_const[1], &mr_const[2]}; auto f0 = sangi::threadPool().submit([&, N, srp, mrp]{ mr_ntt_mul_pipeline_avx2(da[0], db[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, *srp[0], *mrp[0]); }); auto f1 = sangi::threadPool().submit([&, N, srp, mrp]{ mr_ntt_mul_pipeline_avx2(da[1], db[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, *srp[1], *mrp[1]); }); mr_ntt_mul_pipeline_avx2(da[2], db[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, *srp[2], *mrp[2]); f0.get(); f1.get(); } } else { // 既存の power-of-2 パス thread_local NttRootsLayered roots_avx[3]; for (int k = 0; k < 3; ++k) { if (roots_avx[k].N != N || roots_avx[k].p != ctx.primes[k].p) roots_avx[k].build(ctx.primes[k], N); } bool use_mt = (N >= MULTI_THREAD_NTT_THRESHOLD); bool parallel = (N >= PRIME_NTT_PARALLEL_THRESHOLD); if (use_mt) { unsigned hw = std::thread::hardware_concurrency(); int T = std::max(2, static_cast(hw > 3 ? (hw - 1) / 3 : 1)); const NttRootsLayered* rpp[3] = {&roots_avx[0], &roots_avx[1], &roots_avx[2]}; auto f0 = sangi::threadPool().submit([&, N, T, rpp]{ ntt_mul_pipeline_avx2_mt(da[0], db[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, *rpp[0], T); }); auto f1 = sangi::threadPool().submit([&, N, T, rpp]{ ntt_mul_pipeline_avx2_mt(da[1], db[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, *rpp[1], T); }); ntt_mul_pipeline_avx2_mt(da[2], db[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, *rpp[2], T); f0.get(); f1.get(); } else if (parallel) { const NttRootsLayered* rpp[3] = {&roots_avx[0], &roots_avx[1], &roots_avx[2]}; auto f0 = sangi::threadPool().submit([&, N, rpp]{ ntt_mul_pipeline_avx2(da[0], db[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, *rpp[0]); }); auto f1 = sangi::threadPool().submit([&, N, rpp]{ ntt_mul_pipeline_avx2(da[1], db[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, *rpp[1]); }); ntt_mul_pipeline_avx2(da[2], db[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, *rpp[2]); f0.get(); f1.get(); } else { for (int k = 0; k < 3; ++k) { ntt_mul_pipeline_avx2(da[k], db[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, roots_avx[k]); } } } #else // スカラー Montgomery NTT (混合基数は AVX2 のみ → power-of-2 にフォールバック) if (radix != 0) N = next_power_of_2(rn); thread_local NttRootsMont roots[3]; for (int k = 0; k < 3; ++k) { if (roots[k].N != N || roots[k].p != ctx.primes[k].p) roots[k].build(ctx.primes[k], N); } const uint64_t* fwd_r[3], *inv_r[3]; uint64_t inv_n_vals[3]; for (int k = 0; k < 3; ++k) { fwd_r[k] = roots[k].roots.data(); inv_r[k] = roots[k].inv_roots.data(); inv_n_vals[k] = roots[k].inv_n; } bool use_mt = (N >= MULTI_THREAD_NTT_THRESHOLD); bool parallel = (N >= PRIME_NTT_PARALLEL_THRESHOLD); if (use_mt) { unsigned hw = std::thread::hardware_concurrency(); int T = std::max(2, static_cast(hw > 3 ? (hw - 1) / 3 : 1)); auto f0 = sangi::threadPool().submit([&, N, T]{ ntt_mul_pipeline_mt(da[0], db[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, fwd_r[0], inv_r[0], inv_n_vals[0], T); }); auto f1 = sangi::threadPool().submit([&, N, T]{ ntt_mul_pipeline_mt(da[1], db[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, fwd_r[1], inv_r[1], inv_n_vals[1], T); }); ntt_mul_pipeline_mt(da[2], db[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, fwd_r[2], inv_r[2], inv_n_vals[2], T); f0.get(); f1.get(); } else if (parallel) { auto f0 = sangi::threadPool().submit([&, N]{ ntt_mul_pipeline_mont(da[0], db[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, fwd_r[0], inv_r[0], inv_n_vals[0]); }); auto f1 = sangi::threadPool().submit([&, N]{ ntt_mul_pipeline_mont(da[1], db[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, fwd_r[1], inv_r[1], inv_n_vals[1]); }); ntt_mul_pipeline_mont(da[2], db[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, fwd_r[2], inv_r[2], inv_n_vals[2]); f0.get(); f1.get(); } else { for (int k = 0; k < 3; ++k) { ntt_mul_pipeline_mont(da[k], db[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, fwd_r[k], inv_r[k], inv_n_vals[k]); } } #endif // CRT 復元 + キャリー伝搬 crt_recompose(rp, rn, da[0], da[1], da[2], rn, ctx.crt); } // ================================================================ // YC-1d: Middle Product (中間積) // ================================================================ // // ap[0..an-1] × bp[0..bn-1] のフル積 c[0..an+bn-2] のうち、 // 中間部分 c[bn-1..an-1] (計 an-bn+1 リム) を rp に書き出す。 // // 数学的根拠: // NTT 長 N = next_pow2(an) で循環畳み込みを行うと、位置 k での wrap-around は // i+j = k+N (i < an, j < bn) を満たす項。k >= bn-1 のとき k+N >= bn-1+an > an+bn-2 // より i+j <= an-1+bn-1 = an+bn-2 < k+N なので wrap-around 項は存在しない。 // 従って位置 bn-1..an-1 は正確に得られる。 // // 用途: Newton 除算の残差計算 — bx ≈ 1 の residual = 1 - bx を // フル積 (NTT長 an+bn) の代わりに短い NTT 長 (an) で計算。 // // 前提: an >= bn >= 1 // 出力: rp[0..an-bn] = c[bn-1..an-1] (an-bn+1 リム) inline void middle_product_prime_ntt(uint64_t* rp, size_t rn, const uint64_t* ap, size_t an, const uint64_t* bp, size_t bn) { // rn = an - bn + 1 (呼び出し元で保証) auto& ctx = getPrimeNttContext(); // NTT 長: next_pow2(an) — フル積なら next_pow2(an+bn) が必要だが、 // middle product では an で十分 (上記の数学的根拠) size_t N = next_power_of_2(an); // 3 素数分のワーク: da[N] + db[N] × 3 = 6N thread_local std::vector work_mp; size_t total = 6 * N; if (work_mp.size() < total) work_mp.resize(total); uint64_t* da[3]; uint64_t* db[3]; for (int k = 0; k < 3; ++k) { da[k] = work_mp.data() + k * 2 * N; db[k] = da[k] + N; } // 入力分解: a → Montgomery 形式 (そのまま) // b → 反転して Montgomery 形式に変換 // b_rev[j] = b[bn-1-j] により、循環畳み込みの位置 k が // フル積の位置 k + (bn-1) に対応する for (int k = 0; k < 3; ++k) { uint64_t p = ctx.primes[k].p; uint64_t pin = ctx.primes[k].p_inv_neg; uint64_t r2 = ctx.primes[k].r2_mod_p; for (size_t i = 0; i < an; ++i) da[k][i] = to_mont(ap[i] % p, p, pin, r2); for (size_t i = an; i < N; ++i) da[k][i] = 0; // b を反転して配置 for (size_t i = 0; i < bn; ++i) db[k][i] = to_mont(bp[bn - 1 - i] % p, p, pin, r2); for (size_t i = bn; i < N; ++i) db[k][i] = 0; } #ifdef __AVX2__ thread_local NttRootsLayered roots_avx_mp[3]; for (int k = 0; k < 3; ++k) { if (roots_avx_mp[k].N != N || roots_avx_mp[k].p != ctx.primes[k].p) { roots_avx_mp[k].build(ctx.primes[k], N); } } bool parallel = (N >= PRIME_NTT_PARALLEL_THRESHOLD); if (parallel) { const NttRootsLayered* rp_roots[3] = {&roots_avx_mp[0], &roots_avx_mp[1], &roots_avx_mp[2]}; auto f0 = sangi::threadPool().submit([&, N, rp_roots]{ ntt_mul_pipeline_avx2(da[0], db[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, *rp_roots[0]); }); auto f1 = sangi::threadPool().submit([&, N, rp_roots]{ ntt_mul_pipeline_avx2(da[1], db[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, *rp_roots[1]); }); ntt_mul_pipeline_avx2(da[2], db[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, *rp_roots[2]); f0.get(); f1.get(); } else { for (int k = 0; k < 3; ++k) { ntt_mul_pipeline_avx2(da[k], db[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, roots_avx_mp[k]); } } #else thread_local NttRootsMont roots_mp[3]; for (int k = 0; k < 3; ++k) { if (roots_mp[k].N != N || roots_mp[k].p != ctx.primes[k].p) { roots_mp[k].build(ctx.primes[k], N); } } const uint64_t* fwd_r[3], *inv_r[3]; uint64_t inv_n_vals[3]; for (int k = 0; k < 3; ++k) { fwd_r[k] = roots_mp[k].roots.data(); inv_r[k] = roots_mp[k].inv_roots.data(); inv_n_vals[k] = roots_mp[k].inv_n; } bool parallel = (N >= PRIME_NTT_PARALLEL_THRESHOLD); if (parallel) { auto f0 = sangi::threadPool().submit([&, N]{ ntt_mul_pipeline_mont(da[0], db[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, fwd_r[0], inv_r[0], inv_n_vals[0]); }); auto f1 = sangi::threadPool().submit([&, N]{ ntt_mul_pipeline_mont(da[1], db[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, fwd_r[1], inv_r[1], inv_n_vals[1]); }); ntt_mul_pipeline_mont(da[2], db[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, fwd_r[2], inv_r[2], inv_n_vals[2]); f0.get(); f1.get(); } else { for (int k = 0; k < 3; ++k) { ntt_mul_pipeline_mont(da[k], db[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, fwd_r[k], inv_r[k], inv_n_vals[k]); } } #endif // CRT 復元: NTT 長 N 分を一時バッファに復元し、位置 0..rn-1 を抽出 // (循環畳み込みの位置 0 がフル積の位置 bn-1 に対応) // N が大きい場合でもフル積長 (an+bn) は不要で N 分で十分 thread_local std::vector crt_buf; if (crt_buf.size() < N) crt_buf.resize(N); crt_recompose(crt_buf.data(), N, da[0], da[1], da[2], N, ctx.crt); // 位置 0..rn-1 を出力にコピー (= フル積の c[bn-1..an-1]) std::memcpy(rp, crt_buf.data(), rn * sizeof(uint64_t)); } // ================================================================ // 自乗関数 // ================================================================ // rp[0..2n-1] = ap[0..n-1]² inline void sqr_prime_ntt(uint64_t* rp, const uint64_t* ap, size_t an) { auto& ctx = getPrimeNttContext(); size_t rn = 2 * an; size_t N = next_smooth_size(rn); int radix = 0; if (N % 5 == 0 && is_power_of_2(N / 5)) radix = 5; else if (N % 3 == 0 && is_power_of_2(N / 3)) radix = 3; thread_local std::vector work; size_t total = 3 * N; if (work.size() < total) work.resize(total); uint64_t* da[3]; for (int k = 0; k < 3; ++k) { da[k] = work.data() + k * N; } // 入力分解 → Montgomery 形式 for (int k = 0; k < 3; ++k) { uint64_t p = ctx.primes[k].p; uint64_t pin = ctx.primes[k].p_inv_neg; uint64_t r2 = ctx.primes[k].r2_mod_p; for (size_t i = 0; i < an; ++i) da[k][i] = to_mont(ap[i] % p, p, pin, r2); for (size_t i = an; i < N; ++i) da[k][i] = 0; } #ifdef __AVX2__ if (radix == 5) { size_t M = N / 5; thread_local NttRootsLayered sub_roots_avx[3]; thread_local MixedRadix5Constants mr5_const[3]; for (int k = 0; k < 3; ++k) { if (sub_roots_avx[k].N != M || sub_roots_avx[k].p != ctx.primes[k].p) sub_roots_avx[k].build(ctx.primes[k], M); if (mr5_const[k].N != N || mr5_const[k].p != ctx.primes[k].p) mr5_const[k].build(ctx.primes[k], N); } { const NttRootsLayered* srp[3] = {&sub_roots_avx[0], &sub_roots_avx[1], &sub_roots_avx[2]}; const MixedRadix5Constants* mrp[3] = {&mr5_const[0], &mr5_const[1], &mr5_const[2]}; auto f0 = sangi::threadPool().submit([&, N, srp, mrp]{ mr5_ntt_sqr_pipeline_avx2(da[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, *srp[0], *mrp[0]); }); auto f1 = sangi::threadPool().submit([&, N, srp, mrp]{ mr5_ntt_sqr_pipeline_avx2(da[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, *srp[1], *mrp[1]); }); mr5_ntt_sqr_pipeline_avx2(da[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, *srp[2], *mrp[2]); f0.get(); f1.get(); } } else if (radix == 3) { size_t M = N / 3; thread_local NttRootsLayered sub_roots_avx[3]; thread_local MixedRadixConstants mr_const[3]; for (int k = 0; k < 3; ++k) { if (sub_roots_avx[k].N != M || sub_roots_avx[k].p != ctx.primes[k].p) sub_roots_avx[k].build(ctx.primes[k], M); if (mr_const[k].N != N || mr_const[k].p != ctx.primes[k].p) mr_const[k].build(ctx.primes[k], N); } { const NttRootsLayered* srp[3] = {&sub_roots_avx[0], &sub_roots_avx[1], &sub_roots_avx[2]}; const MixedRadixConstants* mrp[3] = {&mr_const[0], &mr_const[1], &mr_const[2]}; auto f0 = sangi::threadPool().submit([&, N, srp, mrp]{ mr_ntt_sqr_pipeline_avx2(da[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, *srp[0], *mrp[0]); }); auto f1 = sangi::threadPool().submit([&, N, srp, mrp]{ mr_ntt_sqr_pipeline_avx2(da[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, *srp[1], *mrp[1]); }); mr_ntt_sqr_pipeline_avx2(da[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, *srp[2], *mrp[2]); f0.get(); f1.get(); } } else { thread_local NttRootsLayered roots_avx[3]; for (int k = 0; k < 3; ++k) { if (roots_avx[k].N != N || roots_avx[k].p != ctx.primes[k].p) roots_avx[k].build(ctx.primes[k], N); } bool use_mt = (N >= MULTI_THREAD_NTT_THRESHOLD); bool parallel = (N >= PRIME_NTT_PARALLEL_THRESHOLD); if (use_mt) { unsigned hw = std::thread::hardware_concurrency(); int T = std::max(2, static_cast(hw > 3 ? (hw - 1) / 3 : 1)); const NttRootsLayered* rpp[3] = {&roots_avx[0], &roots_avx[1], &roots_avx[2]}; auto f0 = sangi::threadPool().submit([&, N, T, rpp]{ ntt_sqr_pipeline_avx2_mt(da[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, *rpp[0], T); }); auto f1 = sangi::threadPool().submit([&, N, T, rpp]{ ntt_sqr_pipeline_avx2_mt(da[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, *rpp[1], T); }); ntt_sqr_pipeline_avx2_mt(da[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, *rpp[2], T); f0.get(); f1.get(); } else if (parallel) { const NttRootsLayered* rpp[3] = {&roots_avx[0], &roots_avx[1], &roots_avx[2]}; auto f0 = sangi::threadPool().submit([&, N, rpp]{ ntt_sqr_pipeline_avx2(da[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, *rpp[0]); }); auto f1 = sangi::threadPool().submit([&, N, rpp]{ ntt_sqr_pipeline_avx2(da[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, *rpp[1]); }); ntt_sqr_pipeline_avx2(da[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, *rpp[2]); f0.get(); f1.get(); } else { for (int k = 0; k < 3; ++k) ntt_sqr_pipeline_avx2(da[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, roots_avx[k]); } } #else // スカラー: 混合基数未サポート → power-of-2 にフォールバック if (radix != 0) N = next_power_of_2(rn); thread_local NttRootsMont roots[3]; for (int k = 0; k < 3; ++k) { if (roots[k].N != N || roots[k].p != ctx.primes[k].p) roots[k].build(ctx.primes[k], N); } const uint64_t* fwd_r[3], *inv_r[3]; uint64_t inv_n_vals[3]; for (int k = 0; k < 3; ++k) { fwd_r[k] = roots[k].roots.data(); inv_r[k] = roots[k].inv_roots.data(); inv_n_vals[k] = roots[k].inv_n; } bool use_mt = (N >= MULTI_THREAD_NTT_THRESHOLD); bool parallel = (N >= PRIME_NTT_PARALLEL_THRESHOLD); if (use_mt) { unsigned hw = std::thread::hardware_concurrency(); int T = std::max(2, static_cast(hw > 3 ? (hw - 1) / 3 : 1)); auto f0 = sangi::threadPool().submit([&, N, T]{ ntt_sqr_pipeline_mt(da[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, fwd_r[0], inv_r[0], inv_n_vals[0], T); }); auto f1 = sangi::threadPool().submit([&, N, T]{ ntt_sqr_pipeline_mt(da[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, fwd_r[1], inv_r[1], inv_n_vals[1], T); }); ntt_sqr_pipeline_mt(da[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, fwd_r[2], inv_r[2], inv_n_vals[2], T); f0.get(); f1.get(); } else if (parallel) { auto f0 = sangi::threadPool().submit([&, N]{ ntt_sqr_pipeline_mont(da[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, fwd_r[0], inv_r[0], inv_n_vals[0]); }); auto f1 = sangi::threadPool().submit([&, N]{ ntt_sqr_pipeline_mont(da[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, fwd_r[1], inv_r[1], inv_n_vals[1]); }); ntt_sqr_pipeline_mont(da[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, fwd_r[2], inv_r[2], inv_n_vals[2]); f0.get(); f1.get(); } else { for (int k = 0; k < 3; ++k) ntt_sqr_pipeline_mont(da[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, fwd_r[k], inv_r[k], inv_n_vals[k]); } #endif // CRT 復元 + キャリー伝搬 crt_recompose(rp, rn, da[0], da[1], da[2], rn, ctx.crt); } // ================================================================ // NTT キャッシュ: 定数オペランドの forward NTT を保存 // ================================================================ // 同じ値で繰り返し乗算する場合、forward NTT を 1 回だけ計算して // キャッシュし、以降は pointwise_mul + INTT のみで乗算する。 // 通常の乗算では NTT(A) + NTT(B) + pointwise + INTT の 4 ステップが必要だが、 // キャッシュ使用時は NTT(A) + pointwise + INTT の 3 ステップで済む。 struct NttCache { std::vector storage_; uint64_t* data[3] = {}; // 3 素数分の forward NTT 結果 (Montgomery 形式) size_t ntt_len = 0; // キャッシュされた NTT 長 size_t orig_limbs = 0; // 元のリム数 void invalidate() { ntt_len = 0; } bool valid(size_t N) const { return ntt_len == N && N > 0; } // bp[0..bn-1] の forward NTT を計算してキャッシュ // roots は呼び出し元の thread_local を共有 (roots 再構築を回避) #ifdef __AVX2__ void build(const uint64_t* bp, size_t bn, size_t N, NttRootsLayered (&roots_ext)[3]) { #else void build(const uint64_t* bp, size_t bn, size_t N, NttRootsMont (&roots_ext)[3]) { #endif auto& ctx = getPrimeNttContext(); ntt_len = N; orig_limbs = bn; storage_.resize(3 * N); for (int k = 0; k < 3; ++k) data[k] = storage_.data() + k * N; // 入力を Montgomery 形式に変換 for (int k = 0; k < 3; ++k) { uint64_t p = ctx.primes[k].p; uint64_t pin = ctx.primes[k].p_inv_neg; uint64_t r2 = ctx.primes[k].r2_mod_p; for (size_t i = 0; i < bn; ++i) data[k][i] = to_mont(bp[i] % p, p, pin, r2); for (size_t i = bn; i < N; ++i) data[k][i] = 0; } // forward NTT (3 素数並列) — 呼び出し元の roots を使用 bool parallel = (N >= PRIME_NTT_PARALLEL_THRESHOLD); #ifdef __AVX2__ if (parallel) { const NttRootsLayered* rp[3] = {&roots_ext[0], &roots_ext[1], &roots_ext[2]}; auto f0 = sangi::threadPool().submit([&, N, rp]{ forward_ntt_mont_avx2(data[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, rp[0]->fwd.data(), rp[0]->fwd_offset.data()); }); auto f1 = sangi::threadPool().submit([&, N, rp]{ forward_ntt_mont_avx2(data[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, rp[1]->fwd.data(), rp[1]->fwd_offset.data()); }); forward_ntt_mont_avx2(data[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, roots_ext[2].fwd.data(), roots_ext[2].fwd_offset.data()); f0.get(); f1.get(); } else { for (int k = 0; k < 3; ++k) forward_ntt_mont_avx2(data[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, roots_ext[k].fwd.data(), roots_ext[k].fwd_offset.data()); } #else const uint64_t* fwd_r[3]; for (int k = 0; k < 3; ++k) fwd_r[k] = roots_ext[k].roots.data(); if (parallel) { auto f0 = sangi::threadPool().submit([&, N]{ forward_ntt_mont(data[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, fwd_r[0]); }); auto f1 = sangi::threadPool().submit([&, N]{ forward_ntt_mont(data[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, fwd_r[1]); }); forward_ntt_mont(data[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, fwd_r[2]); f0.get(); f1.get(); } else { for (int k = 0; k < 3; ++k) forward_ntt_mont(data[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, fwd_r[k]); } #endif } }; // キャッシュ付き乗算: rp = ap × bp (bp の forward NTT はキャッシュから取得) // cache が無効または NTT 長が合わない場合は自動で build する inline void mul_prime_ntt_cached(uint64_t* rp, const uint64_t* ap, size_t an, const uint64_t* bp, size_t bn, NttCache& cache) { auto& ctx = getPrimeNttContext(); size_t rn = an + bn; size_t N = next_power_of_2(rn); // A の NTT データ配列を確保 (3 素数分) thread_local std::vector cached_work; size_t total = 3 * N; if (cached_work.size() < total) cached_work.resize(total); uint64_t* da[3]; for (int k = 0; k < 3; ++k) da[k] = cached_work.data() + k * N; // A を Montgomery 形式に変換 for (int k = 0; k < 3; ++k) { uint64_t p = ctx.primes[k].p; uint64_t pin = ctx.primes[k].p_inv_neg; uint64_t r2 = ctx.primes[k].r2_mod_p; for (size_t i = 0; i < an; ++i) da[k][i] = to_mont(ap[i] % p, p, pin, r2); for (size_t i = an; i < N; ++i) da[k][i] = 0; } // パイプライン: NTT(A) + pointwise_mul(A, cached_B) + INTT(A) bool parallel = (N >= PRIME_NTT_PARALLEL_THRESHOLD); #ifdef __AVX2__ thread_local NttRootsLayered cached_roots_avx[3]; for (int k = 0; k < 3; ++k) { if (cached_roots_avx[k].N != N || cached_roots_avx[k].p != ctx.primes[k].p) cached_roots_avx[k].build(ctx.primes[k], N); } // キャッシュ構築 (初回または NTT 長変更時) — roots を共有 if (!cache.valid(N)) { cache.build(bp, bn, N, cached_roots_avx); } // thread_local → raw pointer 抽出 (ワーカースレッド安全) const uint64_t* fwd_d[3], *inv_d[3]; const size_t* fwd_o[3], *inv_o[3]; uint64_t inv_n_vals[3]; for (int k = 0; k < 3; ++k) { fwd_d[k] = cached_roots_avx[k].fwd.data(); fwd_o[k] = cached_roots_avx[k].fwd_offset.data(); inv_d[k] = cached_roots_avx[k].inv.data(); inv_o[k] = cached_roots_avx[k].inv_offset.data(); inv_n_vals[k] = cached_roots_avx[k].inv_n; } if (parallel) { auto f0 = sangi::threadPool().submit([&, N]{ forward_ntt_mont_avx2(da[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, fwd_d[0], fwd_o[0]); pointwise_mul_mont_avx2(da[0], cache.data[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg); inverse_ntt_mont_avx2(da[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, inv_d[0], inv_o[0], inv_n_vals[0]); }); auto f1 = sangi::threadPool().submit([&, N]{ forward_ntt_mont_avx2(da[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, fwd_d[1], fwd_o[1]); pointwise_mul_mont_avx2(da[1], cache.data[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg); inverse_ntt_mont_avx2(da[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, inv_d[1], inv_o[1], inv_n_vals[1]); }); forward_ntt_mont_avx2(da[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, fwd_d[2], fwd_o[2]); pointwise_mul_mont_avx2(da[2], cache.data[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg); inverse_ntt_mont_avx2(da[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, inv_d[2], inv_o[2], inv_n_vals[2]); f0.get(); f1.get(); } else { for (int k = 0; k < 3; ++k) { forward_ntt_mont_avx2(da[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, fwd_d[k], fwd_o[k]); pointwise_mul_mont_avx2(da[k], cache.data[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg); inverse_ntt_mont_avx2(da[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, inv_d[k], inv_o[k], inv_n_vals[k]); } } #else thread_local NttRootsMont cached_roots[3]; for (int k = 0; k < 3; ++k) { if (cached_roots[k].N != N || cached_roots[k].p != ctx.primes[k].p) cached_roots[k].build(ctx.primes[k], N); } // キャッシュ構築 (初回または NTT 長変更時) — roots を共有 if (!cache.valid(N)) { cache.build(bp, bn, N, cached_roots); } const uint64_t* fwd_r[3], *inv_r[3]; uint64_t inv_n_vals[3]; for (int k = 0; k < 3; ++k) { fwd_r[k] = cached_roots[k].roots.data(); inv_r[k] = cached_roots[k].inv_roots.data(); inv_n_vals[k] = cached_roots[k].inv_n; } if (parallel) { auto f0 = sangi::threadPool().submit([&, N]{ forward_ntt_mont(da[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, fwd_r[0]); pointwise_mul_mont(da[0], cache.data[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg); inverse_ntt_mont(da[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, inv_r[0], inv_n_vals[0]); }); auto f1 = sangi::threadPool().submit([&, N]{ forward_ntt_mont(da[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, fwd_r[1]); pointwise_mul_mont(da[1], cache.data[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg); inverse_ntt_mont(da[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, inv_r[1], inv_n_vals[1]); }); forward_ntt_mont(da[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, fwd_r[2]); pointwise_mul_mont(da[2], cache.data[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg); inverse_ntt_mont(da[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, inv_r[2], inv_n_vals[2]); f0.get(); f1.get(); } else { for (int k = 0; k < 3; ++k) { forward_ntt_mont(da[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, fwd_r[k]); pointwise_mul_mont(da[k], cache.data[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg); inverse_ntt_mont(da[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, inv_r[k], inv_n_vals[k]); } } #endif // CRT 復元 crt_recompose(rp, rn, da[0], da[1], da[2], rn, ctx.crt); } // ================================================================ // NttInt: NTT ドメインでの整数表現 // ================================================================ // BS_NUM_PRIMES 素数 × N 係数の評価値を保持。 // BS merge ステップを NTT ドメインで実行し、 // 中間段階の NTT/INTT 変換を省略して高速化する。 // NTT ドメイン整数: BS_NUM_PRIMES 素数 × ntt_len の評価値 (Montgomery 形式) struct NttInt { std::vector storage_; uint64_t* data[BS_NUM_PRIMES] = {}; size_t ntt_len = 0; size_t orig_limbs = 0; // 元のリム数 (CRT 復元時の出力サイズ) NttInt() = default; void alloc(size_t N) { if (ntt_len == N && !storage_.empty()) return; ntt_len = N; storage_.resize(BS_NUM_PRIMES * N); for (int k = 0; k < BS_NUM_PRIMES; ++k) data[k] = storage_.data() + k * N; } bool empty() const { return ntt_len == 0; } }; // ================================================================ // NttInt の変換関数 (BS_NUM_PRIMES 素数) // ================================================================ // リム列 → NTT ドメイン (forward NTT × BS_NUM_PRIMES 素数, Montgomery 形式) inline void ntt_forward(NttInt& r, const uint64_t* src, size_t n, size_t ntt_len) { auto& ctx = getNttBsContext(); r.alloc(ntt_len); r.orig_limbs = n; // Montgomery 形式の根テーブル (thread_local キャッシュ) thread_local NttRootsMont roots[BS_NUM_PRIMES]; for (int k = 0; k < BS_NUM_PRIMES; ++k) { if (roots[k].N != ntt_len || roots[k].p != ctx.primes[k].p) roots[k].build(ctx.primes[k], ntt_len); } // 入力を mod p → Montgomery 変換 + ゼロパディング for (int k = 0; k < BS_NUM_PRIMES; ++k) { uint64_t p = ctx.primes[k].p; uint64_t pin = ctx.primes[k].p_inv_neg; uint64_t r2 = ctx.primes[k].r2_mod_p; for (size_t i = 0; i < n; ++i) r.data[k][i] = to_mont(src[i] % p, p, pin, r2); std::memset(r.data[k] + n, 0, (ntt_len - n) * sizeof(uint64_t)); } // 根テーブルのポインタを事前取得 (ワーカーから thread_local を参照しない) const uint64_t* fwd_r[BS_NUM_PRIMES]; for (int k = 0; k < BS_NUM_PRIMES; ++k) fwd_r[k] = roots[k].roots.data(); // Forward NTT (BS_NUM_PRIMES 素数 — 並列実行) bool parallel = (ntt_len >= PRIME_NTT_PARALLEL_THRESHOLD); if (parallel) { std::future futures[BS_NUM_PRIMES - 1]; for (int k = 0; k < BS_NUM_PRIMES - 1; ++k) { futures[k] = sangi::threadPool().submit([&, k, ntt_len]{ forward_ntt_mont(r.data[k], ntt_len, ctx.primes[k].p, ctx.primes[k].p_inv_neg, fwd_r[k]); }); } forward_ntt_mont(r.data[BS_NUM_PRIMES - 1], ntt_len, ctx.primes[BS_NUM_PRIMES - 1].p, ctx.primes[BS_NUM_PRIMES - 1].p_inv_neg, fwd_r[BS_NUM_PRIMES - 1]); for (int k = 0; k < BS_NUM_PRIMES - 1; ++k) futures[k].get(); } else { for (int k = 0; k < BS_NUM_PRIMES; ++k) forward_ntt_mont(r.data[k], ntt_len, ctx.primes[k].p, ctx.primes[k].p_inv_neg, fwd_r[k]); } } // NTT ドメイン → リム列 (inverse NTT × BS_NUM_PRIMES 素数 + CRT) inline void ntt_inverse(uint64_t* dst, size_t dst_limbs, const NttInt& src) { auto& ctx = getNttBsContext(); size_t N = src.ntt_len; // 根テーブル thread_local NttRootsMont roots[BS_NUM_PRIMES]; for (int k = 0; k < BS_NUM_PRIMES; ++k) { if (roots[k].N != N || roots[k].p != ctx.primes[k].p) roots[k].build(ctx.primes[k], N); } // INTT 用の作業コピー (src を破壊しないため) std::vector work_buf(BS_NUM_PRIMES * N); uint64_t* work[BS_NUM_PRIMES]; for (int k = 0; k < BS_NUM_PRIMES; ++k) { work[k] = work_buf.data() + k * N; std::memcpy(work[k], src.data[k], N * sizeof(uint64_t)); } // 根テーブルのポインタを事前取得 const uint64_t* inv_r[BS_NUM_PRIMES]; uint64_t inv_n_vals[BS_NUM_PRIMES]; for (int k = 0; k < BS_NUM_PRIMES; ++k) { inv_r[k] = roots[k].inv_roots.data(); inv_n_vals[k] = roots[k].inv_n; } // Inverse NTT (BS_NUM_PRIMES 素数 — 並列実行) bool parallel = (N >= PRIME_NTT_PARALLEL_THRESHOLD); if (parallel) { std::future futures[BS_NUM_PRIMES - 1]; for (int k = 0; k < BS_NUM_PRIMES - 1; ++k) { futures[k] = sangi::threadPool().submit([&, k, N]{ inverse_ntt_mont(work[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, inv_r[k], inv_n_vals[k]); }); } inverse_ntt_mont(work[BS_NUM_PRIMES - 1], N, ctx.primes[BS_NUM_PRIMES - 1].p, ctx.primes[BS_NUM_PRIMES - 1].p_inv_neg, inv_r[BS_NUM_PRIMES - 1], inv_n_vals[BS_NUM_PRIMES - 1]); for (int k = 0; k < BS_NUM_PRIMES - 1; ++k) futures[k].get(); } else { for (int k = 0; k < BS_NUM_PRIMES; ++k) inverse_ntt_mont(work[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, inv_r[k], inv_n_vals[k]); } // CRT 復元 (5 素数) crt_recompose_n(dst, dst_limbs, work, N, ctx.crt); } // ================================================================ // 16-bit 分割版 NTT forward / inverse (保持、ただし現在未使用) // 5 素数化により 64-bit リムのまま深さ 2 まで安全。 // ================================================================ // 16-bit 分割版 forward NTT inline void ntt_forward_16(NttInt& r, const uint64_t* src, size_t n, size_t ntt_len) { auto& ctx = getNttBsContext(); r.alloc(ntt_len); r.orig_limbs = n; size_t n16 = n * 4; thread_local NttRootsMont roots[BS_NUM_PRIMES]; for (int k = 0; k < BS_NUM_PRIMES; ++k) { if (roots[k].N != ntt_len || roots[k].p != ctx.primes[k].p) roots[k].build(ctx.primes[k], ntt_len); } for (int k = 0; k < BS_NUM_PRIMES; ++k) { uint64_t p = ctx.primes[k].p; uint64_t pin = ctx.primes[k].p_inv_neg; uint64_t r2 = ctx.primes[k].r2_mod_p; for (size_t i = 0; i < n; ++i) { uint64_t w = src[i]; r.data[k][4*i ] = to_mont( w & 0xFFFF, p, pin, r2); r.data[k][4*i + 1] = to_mont((w >> 16) & 0xFFFF, p, pin, r2); r.data[k][4*i + 2] = to_mont((w >> 32) & 0xFFFF, p, pin, r2); r.data[k][4*i + 3] = to_mont((w >> 48) & 0xFFFF, p, pin, r2); } std::memset(r.data[k] + n16, 0, (ntt_len - n16) * sizeof(uint64_t)); } const uint64_t* fwd_r[BS_NUM_PRIMES]; for (int k = 0; k < BS_NUM_PRIMES; ++k) fwd_r[k] = roots[k].roots.data(); bool parallel = (ntt_len >= PRIME_NTT_PARALLEL_THRESHOLD); if (parallel) { std::future futures[BS_NUM_PRIMES - 1]; for (int k = 0; k < BS_NUM_PRIMES - 1; ++k) { futures[k] = sangi::threadPool().submit([&, k, ntt_len]{ forward_ntt_mont(r.data[k], ntt_len, ctx.primes[k].p, ctx.primes[k].p_inv_neg, fwd_r[k]); }); } forward_ntt_mont(r.data[BS_NUM_PRIMES - 1], ntt_len, ctx.primes[BS_NUM_PRIMES - 1].p, ctx.primes[BS_NUM_PRIMES - 1].p_inv_neg, fwd_r[BS_NUM_PRIMES - 1]); for (int k = 0; k < BS_NUM_PRIMES - 1; ++k) futures[k].get(); } else { for (int k = 0; k < BS_NUM_PRIMES; ++k) forward_ntt_mont(r.data[k], ntt_len, ctx.primes[k].p, ctx.primes[k].p_inv_neg, fwd_r[k]); } } // 16-bit 分割版 inverse NTT inline void ntt_inverse_16(uint64_t* dst, size_t dst_limbs, const NttInt& src) { // 5 素数版: 通常の ntt_inverse を使用して CRT 復元後、 // base-2^16 パック処理が必要な場合のみこちらを使う。 // 現在は未使用 (64-bit リムのまま深さ 2 まで安全)。 ntt_inverse(dst, dst_limbs, src); } // ================================================================ // NttInt の pointwise 演算 (BS_NUM_PRIMES 素数) // ================================================================ // r = a ⊗ b (pointwise Montgomery 乗算) inline void ntt_pmul(NttInt& r, const NttInt& a, const NttInt& b) { auto& ctx = getNttBsContext(); size_t N = a.ntt_len; bool parallel = (N >= PRIME_NTT_PARALLEL_THRESHOLD); if (parallel) { std::future futures[BS_NUM_PRIMES - 1]; for (int k = 0; k < BS_NUM_PRIMES - 1; ++k) { futures[k] = sangi::threadPool().submit([&, k]{ uint64_t p = ctx.primes[k].p; uint64_t pin = ctx.primes[k].p_inv_neg; for (size_t i = 0; i < N; ++i) r.data[k][i] = mont_mul(a.data[k][i], b.data[k][i], p, pin); }); } { constexpr int k = BS_NUM_PRIMES - 1; uint64_t p = ctx.primes[k].p; uint64_t pin = ctx.primes[k].p_inv_neg; for (size_t i = 0; i < N; ++i) r.data[k][i] = mont_mul(a.data[k][i], b.data[k][i], p, pin); } for (int k = 0; k < BS_NUM_PRIMES - 1; ++k) futures[k].get(); } else { for (int k = 0; k < BS_NUM_PRIMES; ++k) { uint64_t p = ctx.primes[k].p; uint64_t pin = ctx.primes[k].p_inv_neg; for (size_t i = 0; i < N; ++i) r.data[k][i] = mont_mul(a.data[k][i], b.data[k][i], p, pin); } } } // r = a + b (pointwise mod add) inline void ntt_padd(NttInt& r, const NttInt& a, const NttInt& b) { auto& ctx = getNttBsContext(); size_t N = a.ntt_len; for (int k = 0; k < BS_NUM_PRIMES; ++k) { uint64_t p = ctx.primes[k].p; for (size_t i = 0; i < N; ++i) r.data[k][i] = mod_add(a.data[k][i], b.data[k][i], p); } } // r = a - b (pointwise mod sub) inline void ntt_psub(NttInt& r, const NttInt& a, const NttInt& b) { auto& ctx = getNttBsContext(); size_t N = a.ntt_len; for (int k = 0; k < BS_NUM_PRIMES; ++k) { uint64_t p = ctx.primes[k].p; for (size_t i = 0; i < N; ++i) r.data[k][i] = mod_sub(a.data[k][i], b.data[k][i], p); } } // r += a ⊗ b (pointwise multiply-accumulate) inline void ntt_pmac(NttInt& r, const NttInt& a, const NttInt& b) { auto& ctx = getNttBsContext(); size_t N = a.ntt_len; bool parallel = (N >= PRIME_NTT_PARALLEL_THRESHOLD); if (parallel) { std::future futures[BS_NUM_PRIMES - 1]; for (int k = 0; k < BS_NUM_PRIMES - 1; ++k) { futures[k] = sangi::threadPool().submit([&, k]{ uint64_t p = ctx.primes[k].p; uint64_t pin = ctx.primes[k].p_inv_neg; for (size_t i = 0; i < N; ++i) r.data[k][i] = mod_add(r.data[k][i], mont_mul(a.data[k][i], b.data[k][i], p, pin), p); }); } { constexpr int k = BS_NUM_PRIMES - 1; uint64_t p = ctx.primes[k].p; uint64_t pin = ctx.primes[k].p_inv_neg; for (size_t i = 0; i < N; ++i) r.data[k][i] = mod_add(r.data[k][i], mont_mul(a.data[k][i], b.data[k][i], p, pin), p); } for (int k = 0; k < BS_NUM_PRIMES - 1; ++k) futures[k].get(); } else { for (int k = 0; k < BS_NUM_PRIMES; ++k) { uint64_t p = ctx.primes[k].p; uint64_t pin = ctx.primes[k].p_inv_neg; for (size_t i = 0; i < N; ++i) r.data[k][i] = mod_add(r.data[k][i], mont_mul(a.data[k][i], b.data[k][i], p, pin), p); } } } // r の符号反転 (pointwise negate: p - r[i]) inline void ntt_pneg(NttInt& r) { auto& ctx = getNttBsContext(); size_t N = r.ntt_len; for (int k = 0; k < BS_NUM_PRIMES; ++k) { uint64_t p = ctx.primes[k].p; for (size_t i = 0; i < N; ++i) r.data[k][i] = (r.data[k][i] == 0) ? 0 : (p - r.data[k][i]); } } // ================================================================ // NTT ドメイン BS merge 関数 // ================================================================ // Chudnovsky / atanh 用: // T = TL ⊗ QR + PL ⊗ TR // Q = QL ⊗ QR // P = PL ⊗ PR (need_P のとき) inline void bs_merge2_ntt(NttInt& P, NttInt& Q, NttInt& T, NttInt& PL, NttInt& QL, NttInt& TL, NttInt& PR, NttInt& QR, NttInt& TR, bool need_P, size_t ntt_len) { // T = TL ⊗ QR T.alloc(ntt_len); ntt_pmul(T, TL, QR); // T += PL ⊗ TR ntt_pmac(T, PL, TR); // Q = QL ⊗ QR Q.alloc(ntt_len); ntt_pmul(Q, QL, QR); // P = PL ⊗ PR if (need_P) { P.alloc(ntt_len); ntt_pmul(P, PL, PR); } } // factorial 用: // T = TL ⊗ QR + TR (P=1 なので PL*TR = TR) // Q = QL ⊗ QR inline void fac_merge2_ntt(NttInt& Q, NttInt& T, NttInt& QL, NttInt& TL, NttInt& QR, NttInt& TR, size_t ntt_len) { // T = TL ⊗ QR + TR T.alloc(ntt_len); ntt_pmul(T, TL, QR); ntt_padd(T, T, TR); // Q = QL ⊗ QR Q.alloc(ntt_len); ntt_pmul(Q, QL, QR); } // Euler-Mascheroni 用: // P = PL ⊗ PR, Q = QL ⊗ QR, D = DL ⊗ DR // B = BL ⊗ DR + DL ⊗ BR // T = TL ⊗ QR + PL ⊗ TR // S = SL ⊗ (QR ⊗ DR) + PL ⊗ (SR ⊗ DL + BL ⊗ (TR ⊗ DR)) inline void euler_merge2_ntt(NttInt& P, NttInt& Q, NttInt& D, NttInt& B, NttInt& T, NttInt& S, NttInt& PL, NttInt& QL, NttInt& DL, NttInt& BL, NttInt& TL, NttInt& SL, NttInt& PR, NttInt& QR, NttInt& DR, NttInt& BR, NttInt& TR, NttInt& SR, size_t ntt_len) { P.alloc(ntt_len); ntt_pmul(P, PL, PR); Q.alloc(ntt_len); ntt_pmul(Q, QL, QR); D.alloc(ntt_len); ntt_pmul(D, DL, DR); // B = BL ⊗ DR + DL ⊗ BR B.alloc(ntt_len); ntt_pmul(B, BL, DR); ntt_pmac(B, DL, BR); // T = TL ⊗ QR + PL ⊗ TR T.alloc(ntt_len); ntt_pmul(T, TL, QR); ntt_pmac(T, PL, TR); // S = SL ⊗ QR_DR + PL ⊗ (SR ⊗ DL + BL ⊗ TR_DR) // QR_DR = QR ⊗ DR (一時変数) NttInt QR_DR; QR_DR.alloc(ntt_len); ntt_pmul(QR_DR, QR, DR); // TR_DR = TR ⊗ DR (一時変数) NttInt TR_DR; TR_DR.alloc(ntt_len); ntt_pmul(TR_DR, TR, DR); // inner = SR ⊗ DL + BL ⊗ TR_DR NttInt inner; inner.alloc(ntt_len); ntt_pmul(inner, SR, DL); ntt_pmac(inner, BL, TR_DR); // S = SL ⊗ QR_DR + PL ⊗ inner S.alloc(ntt_len); ntt_pmul(S, SL, QR_DR); ntt_pmac(S, PL, inner); } } // namespace prime_ntt } // namespace sangi