// Copyright (C) 2026 Kiyotsugu Arai // SPDX-License-Identifier: LGPL-3.0-or-later // PrimeNtt.hpp // 素数モジュラー NTT + CRT による高速乗算 // 3 つの 64bit NTT-friendly 素数で NTT を実行し、 // 中国剰余定理で結果を復元する。 // Schönhage-Strassen (mod B^F+1) の代替で、 // pointwise が O(1) (64bit mod 乗算) になる。 #pragma once #ifndef CALX_FORCEINLINE #ifdef _MSC_VER #define CALX_FORCEINLINE __forceinline #else #define CALX_FORCEINLINE __attribute__((always_inline)) inline #endif #endif #include #include #include #include #include #include #ifdef _MSC_VER #include #endif #include namespace calx { namespace prime_ntt { // キャッシュブロッキングの閾値 (L1 32KB に収まるブロックサイズ) // 512 要素 = 4KB data + ~4KB roots = 8KB, L1 に十分な余裕 static constexpr size_t NTT_CACHE_BLOCK = 512; // ================================================================ // NTT-friendly 素数の定義 // ================================================================ // 条件: p = k × 2^s + 1 (s が大きいほど長い NTT が可能) // p < 2^63 (符号なし加算で 2p < 2^64) // FLINT/GMP で実績のある素数を採用 // p1 = 2^62 - 2^16 + 1 = 4611686018427387905 // = 1 × 2^62 + ... 
// (primality verified separately; cf. the primes used by FLINT's
// n_mulmod_preinv machinery and by NTL)
//
// Verified NTT-friendly primes of the form p = k * 2^47 + 1.
// Max NTT length: 2^47 (~140 trillion points — unlimited in practice).
// p1*p2*p3 ≈ 2^157, so CRT recovery is exact for n < 2^29
// (≈ 500M limbs ≈ 10G decimal digits).
constexpr uint64_t P1 = 0x000D'8000'0000'0001ULL; // 27  * 2^47 + 1, g=5
constexpr uint64_t P2 = 0x0011'8000'0000'0001ULL; // 35  * 2^47 + 1, g=3
constexpr uint64_t P3 = 0x0020'8000'0000'0001ULL; // 65  * 2^47 + 1, g=3
constexpr uint64_t P4 = 0x0032'8000'0000'0001ULL; // 101 * 2^47 + 1, g=3
constexpr uint64_t P5 = 0x0037'8000'0000'0001ULL; // 111 * 2^47 + 1, g=11

constexpr uint64_t G1 = 5;  // primitive root of P1
constexpr uint64_t G2 = 3;  // primitive root of P2
constexpr uint64_t G3 = 3;  // primitive root of P3
constexpr uint64_t G4 = 3;  // primitive root of P4
constexpr uint64_t G5 = 11; // primitive root of P5

constexpr int NTT_MAX_S = 47; // max NTT length = 2^47

// Number of primes for NTT-domain basecase splitting (BS).
// 64 primes (M ≈ 2^3582) with 16-bit words (B = 2^16) keep NTT-domain
// recursion safe down to depth 6. Coefficient bound at depth d:
//   N^(2^d - 1) * B^(2^d) < M
//   d=6, N=2^22, B=2^16: 22*63 + 16*64 = 2410 < 3582  OK
//   d=7:                 22*127 + 16*128 = 4842 > 3582 NG
constexpr int BS_NUM_PRIMES = 64;

// BS prime table: p = k * 2^47 + 1, with a primitive root for each.
struct BsPrimeEntry {
    uint64_t p;
    int g;
};

static constexpr BsPrimeEntry BS_PRIME_TABLE[BS_NUM_PRIMES] = {
    { 0x000D800000000001ULL, 5 },  // k=27
    { 0x000F000000000001ULL, 19 }, // k=30
    { 0x0011800000000001ULL, 3 },  // k=35
    { 0x001C000000000001ULL, 6 },  // k=56
    { 0x0020800000000001ULL, 3 },  // k=65
    { 0x002E000000000001ULL, 3 },  // k=92
    { 0x0032800000000001ULL, 3 },  // k=101
    { 0x0037800000000001ULL, 11 }, // k=111
    { 0x0039000000000001ULL, 5 },  // k=114
    { 0x003D000000000001ULL, 3 },  // k=122
    { 0x0046000000000001ULL, 3 },  // k=140
    { 0x0051000000000001ULL, 5 },  // k=162
    { 0x006C000000000001ULL, 11 }, // k=216
    { 0x0070000000000001ULL, 3 },  // k=224
    { 0x0088000000000001ULL, 3 },  // k=272
    { 0x00AC800000000001ULL, 13 }, // k=345
    { 0x00B1000000000001ULL, 5 },  // k=354
    { 0x00B3800000000001ULL, 3 },  // k=359
    { 0x00BB000000000001ULL, 3 },  // k=374
    { 0x00D2000000000001ULL, 17 }, // k=420
    { 0x00D7800000000001ULL, 3 },  // k=431
    { 0x00DC800000000001ULL, 13 }, // k=441
    { 0x00EB800000000001ULL, 17 }, // k=471
    { 0x00FD000000000001ULL, 3 },  // k=506
    { 0x0104800000000001ULL, 3 },  // k=521
    { 0x0107800000000001ULL, 3 },  // k=527
    { 0x0114000000000001ULL, 5 },  // k=552
    { 0x0149800000000001ULL, 3 },  // k=659
    { 0x016F000000000001ULL, 3 },  // k=734
    { 0x0170800000000001ULL, 3 },  // k=737
    { 0x0176800000000001ULL, 3 },  // k=749
    { 0x0179800000000001ULL, 6 },  // k=755
    { 0x0190000000000001ULL, 3 },  // k=800
    { 0x0194800000000001ULL, 3 },  // k=809
    { 0x0197800000000001ULL, 3 },  // k=815
    { 0x01A2800000000001ULL, 7 },  // k=837
    { 0x01A7000000000001ULL, 26 }, // k=846
    { 0x01AC800000000001ULL, 3 },  // k=857
    { 0x01B5800000000001ULL, 15 }, // k=875
    { 0x01C5000000000001ULL, 14 }, // k=906
    { 0x01E6000000000001ULL, 5 },  // k=972
    { 0x01ED800000000001ULL, 5 },  // k=987
    { 0x01EE800000000001ULL, 3 },  // k=989
    { 0x01F7800000000001ULL, 3 },  // k=1007
    { 0x020C800000000001ULL, 3 },  // k=1049
    { 0x021E800000000001ULL, 3 },  // k=1085
    { 0x022B000000000001ULL, 14 }, // k=1110
    { 0x0235000000000001ULL, 3 },  // k=1130
    { 0x0247000000000001ULL, 3 },  // k=1166
    { 0x024D800000000001ULL, 5 },  // k=1179
    { 0x0258000000000001ULL, 11 }, // k=1200
    { 0x025D800000000001ULL, 3 },  // k=1211
    { 0x0266800000000001ULL, 3 },  // k=1229
    { 0x026D000000000001ULL, 5 },  // k=1242
    { 0x027B800000000001ULL, 3 },  // k=1271
    { 0x0280000000000001ULL, 6 },  // k=1280
    { 0x0289800000000001ULL, 5 },  // k=1299
    { 0x0293800000000001ULL, 3 },  // k=1319
    { 0x0297000000000001ULL, 7 },  // k=1326
    { 0x02A5800000000001ULL, 3 },  // k=1355
    { 0x02A9000000000001ULL, 5 },  // k=1362
    { 0x02BF000000000001ULL, 3 },  // k=1406
    { 0x02C6800000000001ULL, 3 },  // k=1421
    { 0x02CD000000000001ULL, 5 },  // k=1434
};

// ================================================================
// mod p arithmetic primitives
// ================================================================
// (a + b) mod p — requires a, b < p.
inline uint64_t mod_add(uint64_t a, uint64_t b, uint64_t p) {
    // a, b < p < 2^63, so a + b < 2^64: the unsigned add cannot wrap.
    uint64_t sum = a + b;
    return (sum >= p) ? (sum - p) : sum;
}

// (a - b) mod p — requires a, b < p.
inline uint64_t mod_sub(uint64_t a, uint64_t b, uint64_t p) {
    // If a < b, correct the wraparound by adding p back.
    return (a >= b) ? (a - b) : (a - b + p);
}

// (a * b) mod p — requires a, b < p. 128-bit product, then reduce.
inline uint64_t mod_mul(uint64_t a, uint64_t b, uint64_t p) {
#if defined(_MSC_VER) && defined(_M_X64)
    uint64_t hi;
    uint64_t lo = _umul128(a, b, &hi);
    uint64_t rem;
    // Safe: a, b < p implies hi = (a*b) >> 64 < p, so no divide overflow.
    _udiv128(hi, lo, p, &rem);
    return rem;
#elif defined(__GNUC__) || defined(__clang__)
    // FIX: the original had `static_cast(prod % p)` with the target type
    // stripped — restored to static_cast<uint64_t>.
    __uint128_t prod = static_cast<__uint128_t>(a) * b;
    return static_cast<uint64_t>(prod % p);
#else
    // Fallback: slow
    __uint128_t prod = static_cast<__uint128_t>(a) * b;
    return static_cast<uint64_t>(prod % p);
#endif
}

// base^exp mod p (binary exponentiation).
inline uint64_t mod_pow(uint64_t base, uint64_t exp, uint64_t p) {
    uint64_t result = 1;
    base %= p;
    while (exp > 0) {
        if (exp & 1) result = mod_mul(result, base, p);
        base = mod_mul(base, base, p);
        exp >>= 1;
    }
    return result;
}

// a^(-1) mod p via Fermat's little theorem: a^(p-2) mod p (p prime).
inline uint64_t mod_inv(uint64_t a, uint64_t p) {
    return mod_pow(a, p - 2, p);
}

// ================================================================
// Primality testing and primitive-root search (used at init time)
// ================================================================

// Deterministic Miller-Rabin for 64-bit n.
inline bool is_prime_64(uint64_t n) {
    if (n < 2) return false;
    if (n < 4) return true;
    if (n % 2 == 0) return false;
    // Write n-1 = d * 2^r with d odd.
    uint64_t d = n - 1;
    int r = 0;
    while ((d & 1) == 0) { d >>= 1; r++; }
    // This witness set is deterministic for all 64-bit n.
    const uint64_t witnesses[] = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37};
    for (uint64_t a : witnesses) {
        if (a >= n) continue;
        uint64_t x = mod_pow(a, d, n);
        if (x == 1 || x == n - 1) continue;
        bool composite = true;
        for (int i = 0; i < r - 1; i++) {
            x = mod_mul(x, x, n);
            if (x == n - 1) { composite = false; break; }
        }
        if (composite) return false;
    }
    return true;
}

// Search for a prime of the form k * 2^min_s + 1.
// min_s:   minimum s (guarantees NTT length 2^min_s)
// start_k: starting k for the search (odd)
// Returns 0 if no prime was found below 2^63.
inline uint64_t find_ntt_prime(int min_s, uint64_t start_k = 1) {
    for (uint64_t k = start_k; k < (1ULL << (63 - min_s)); k += 2) {
        uint64_t p = k * (1ULL << min_s) + 1;
        if (p >= (1ULL << 63)) break; // keep 2p < 2^64
        if (is_prime_64(p)) return p;
    }
    return 0;
}

// Find a primitive root of prime p (needs the factorization of p-1).
// With p-1 = 2^s * k, the prime factors of p-1 are 2 and the factors of k.
inline uint64_t find_primitive_root(uint64_t p) {
    uint64_t pm1 = p - 1;
    // FIX: the original had `std::vector factors;` with the element type
    // stripped — restored to std::vector<uint64_t>.
    std::vector<uint64_t> factors;
    factors.push_back(2); // 2 always divides p-1 (p odd prime > 2)
    // Trial-divide the odd part of p-1.
    uint64_t k = pm1;
    while ((k & 1) == 0) k >>= 1;
    uint64_t temp = k;
    for (uint64_t f = 3; f * f <= temp; f += 2) {
        if (temp % f == 0) {
            factors.push_back(f);
            while (temp % f == 0) temp /= f;
        }
    }
    if (temp > 1) factors.push_back(temp);
    // g is a primitive root iff g^((p-1)/q) != 1 (mod p) for every prime q | p-1.
    for (uint64_t g = 2; g < p; ++g) {
        bool is_root = true;
        for (uint64_t q : factors) {
            if (mod_pow(g, pm1 / q, p) == 1) { is_root = false; break; }
        }
        if (is_root) return g;
    }
    return 0;
}

// ================================================================
// Montgomery multiplication (NTT speedup)
// ================================================================
// R = 2^64, p' = -p^(-1) mod R
// REDC(T) = (T + (T_lo * p') * p) >> 64
// Two MULX + ADD instead of _udiv128 (35-45 cycles -> ~8 cycles).

// p^(-1) mod 2^64 by Newton iteration (p must be odd), then negate.
inline uint64_t compute_p_inv_neg(uint64_t p) {
    uint64_t x = 1;
    // Newton doubles the number of correct low bits each step;
    // 6 iterations reach 64 bits.
    for (int i = 0; i < 6; i++)
        x *= 2 - p * x;
    return ~x + 1; // -p^(-1) mod 2^64
}
// Montgomery multiplication: REDC(a * b) = a*b*R^(-1) mod p.
// Inputs may be in Montgomery form (a*R mod p, b*R mod p) or plain.
inline uint64_t mont_mul(uint64_t a, uint64_t b, uint64_t p, uint64_t p_inv_neg) {
#if defined(_MSC_VER) && defined(_M_X64)
    uint64_t T_hi;
    uint64_t T_lo = _umul128(a, b, &T_hi);
    // m = T_lo * (-p^(-1)) mod 2^64
    uint64_t m = T_lo * p_inv_neg;
    // m * p
    uint64_t mp_hi;
    uint64_t mp_lo = _umul128(m, p, &mp_hi);
    // High 64 bits of T + m*p (the low half is 0 by construction).
    uint64_t carry = (T_lo + mp_lo < T_lo) ? 1 : 0;
    uint64_t t = T_hi + mp_hi + carry;
    return (t >= p) ? (t - p) : t;
#elif defined(__GNUC__) || defined(__clang__)
    unsigned __int128 T = (unsigned __int128)a * b;
    uint64_t T_lo = (uint64_t)T;
    uint64_t m = T_lo * p_inv_neg;
    unsigned __int128 mp = (unsigned __int128)m * p;
    unsigned __int128 sum = T + mp;
    uint64_t t = (uint64_t)(sum >> 64);
    return (t >= p) ? (t - p) : t;
#endif
}

// Convert into Montgomery form: to_mont(a) = a * R mod p = REDC(a * R^2).
inline uint64_t to_mont(uint64_t a, uint64_t p, uint64_t p_inv_neg,
                        uint64_t r2_mod_p) {
    return mont_mul(a, r2_mod_p, p, p_inv_neg);
}

// Convert out of Montgomery form: from_mont(a_mont) = REDC(a_mont * 1).
inline uint64_t from_mont(uint64_t a_mont, uint64_t p, uint64_t p_inv_neg) {
    return mont_mul(a_mont, 1, p, p_inv_neg);
}

// ================================================================
// NTT prime parameter struct
// ================================================================
の s (最大 NTT 長 = 2^s) uint64_t inv_2; // 2^(-1) mod p // Montgomery 定数 uint64_t p_inv_neg; // -p^(-1) mod 2^64 uint64_t r2_mod_p; // R^2 mod p (R = 2^64) // N 点 NTT の原始 N 乗根: omega = g^((p-1)/N) mod p uint64_t nth_root(size_t N) const { return mod_pow(g, (p - 1) / N, p); } // N^(-1) mod p uint64_t inv_n(size_t N) const { return mod_inv(N, p); } }; // ================================================================ // 根テーブル (twiddle factors) のキャッシュ // ================================================================ struct NttRoots { std::vector roots; // forward: omega^i std::vector inv_roots; // inverse: omega^(-i) uint64_t inv_n; // N^(-1) mod p size_t N; // NTT 長 uint64_t p; // 素数 void build(const NttPrime& prime, size_t ntt_len) { N = ntt_len; p = prime.p; inv_n = prime.inv_n(N); uint64_t omega = prime.nth_root(N); uint64_t omega_inv = mod_inv(omega, p); roots.resize(N / 2); inv_roots.resize(N / 2); roots[0] = 1; inv_roots[0] = 1; for (size_t i = 1; i < N / 2; ++i) { roots[i] = mod_mul(roots[i - 1], omega, p); inv_roots[i] = mod_mul(inv_roots[i - 1], omega_inv, p); } } }; // ================================================================ // NTT forward (DIF: Decimation In Frequency) // ================================================================ inline void forward_ntt(uint64_t* data, size_t N, uint64_t p, const uint64_t* roots) { for (size_t len = N; len >= 2; len >>= 1) { size_t half = len >> 1; size_t step = N / len; // root index step for (size_t start = 0; start < N; start += len) { for (size_t j = 0; j < half; ++j) { uint64_t u = data[start + j]; uint64_t v = data[start + j + half]; // DIF butterfly: (u, v) → (u+v, (u-v)*ω) data[start + j] = mod_add(u, v, p); uint64_t diff = mod_sub(u, v, p); data[start + j + half] = mod_mul(diff, roots[j * step], p); } } } } // ================================================================ // NTT inverse (DIT: Decimation In Time) // ================================================================ inline void 
inverse_ntt(uint64_t* data, size_t N, uint64_t p, const uint64_t* inv_roots, uint64_t inv_n_val) { for (size_t len = 2; len <= N; len <<= 1) { size_t half = len >> 1; size_t step = N / len; for (size_t start = 0; start < N; start += len) { for (size_t j = 0; j < half; ++j) { uint64_t u = data[start + j]; uint64_t v = mod_mul(data[start + j + half], inv_roots[j * step], p); // DIT butterfly: (u, v·ω⁻¹) → (u+v, u-v) data[start + j] = mod_add(u, v, p); data[start + j + half] = mod_sub(u, v, p); } } } // N^(-1) で正規化 for (size_t i = 0; i < N; ++i) { data[i] = mod_mul(data[i], inv_n_val, p); } } // ================================================================ // Pointwise 演算 // ================================================================ inline void pointwise_mul(uint64_t* a, const uint64_t* b, size_t N, uint64_t p) { for (size_t i = 0; i < N; ++i) { a[i] = mod_mul(a[i], b[i], p); } } inline void pointwise_sqr(uint64_t* a, size_t N, uint64_t p) { for (size_t i = 0; i < N; ++i) { a[i] = mod_mul(a[i], a[i], p); } } // ================================================================ // Montgomery 形式の根テーブルと NTT // ================================================================ struct NttRootsMont { std::vector roots; // forward: omega^i (Montgomery 形式) std::vector inv_roots; // inverse: omega^(-i) (Montgomery 形式) uint64_t inv_n; // N^(-1) mod p (通常形式 — from_mont に兼用) size_t N; uint64_t p; uint64_t p_inv_neg; void build(const NttPrime& prime, size_t ntt_len) { N = ntt_len; p = prime.p; p_inv_neg = prime.p_inv_neg; inv_n = prime.inv_n(N); // 通常形式 (逆変換の最終スケーリング用) uint64_t omega = prime.nth_root(N); uint64_t omega_inv = mod_inv(omega, p); // Montgomery 形式に変換 uint64_t omega_mont = to_mont(omega, p, p_inv_neg, prime.r2_mod_p); uint64_t omega_inv_mont = to_mont(omega_inv, p, p_inv_neg, prime.r2_mod_p); roots.resize(N / 2); inv_roots.resize(N / 2); // roots[0] = 1 の Montgomery 形式 = R mod p uint64_t one_mont = to_mont(1, p, p_inv_neg, prime.r2_mod_p); roots[0] = one_mont; 
inv_roots[0] = one_mont; for (size_t i = 1; i < N / 2; ++i) { roots[i] = mont_mul(roots[i - 1], omega_mont, p, p_inv_neg); inv_roots[i] = mont_mul(inv_roots[i - 1], omega_inv_mont, p, p_inv_neg); } } }; // Montgomery 形式 NTT forward (DIF) // data は Montgomery 形式、roots も Montgomery 形式 inline void forward_ntt_mont(uint64_t* data, size_t N, uint64_t p, uint64_t p_inv_neg, const uint64_t* roots) { for (size_t len = N; len >= 2; len >>= 1) { size_t half = len >> 1; size_t step = N / len; for (size_t start = 0; start < N; start += len) { for (size_t j = 0; j < half; ++j) { uint64_t u = data[start + j]; uint64_t v = data[start + j + half]; data[start + j] = mod_add(u, v, p); uint64_t diff = mod_sub(u, v, p); data[start + j + half] = mont_mul(diff, roots[j * step], p, p_inv_neg); } } } } // Montgomery 形式 NTT inverse (DIT) // 最後に from_mont + inv_n スケーリングを同時に行う // inv_n_normal: N^(-1) mod p (通常形式) inline void inverse_ntt_mont(uint64_t* data, size_t N, uint64_t p, uint64_t p_inv_neg, const uint64_t* inv_roots, uint64_t inv_n_normal) { for (size_t len = 2; len <= N; len <<= 1) { size_t half = len >> 1; size_t step = N / len; for (size_t start = 0; start < N; start += len) { for (size_t j = 0; j < half; ++j) { uint64_t u = data[start + j]; uint64_t v = mont_mul(data[start + j + half], inv_roots[j * step], p, p_inv_neg); data[start + j] = mod_add(u, v, p); data[start + j + half] = mod_sub(u, v, p); } } } // N^(-1) スケーリング + Montgomery → 通常形式の逆変換を同時実行 // mont_mul(data[i]_mont, inv_n_normal) = data_i * inv_n mod p for (size_t i = 0; i < N; ++i) { data[i] = mont_mul(data[i], inv_n_normal, p, p_inv_neg); } } // Montgomery 形式 pointwise 乗算 inline void pointwise_mul_mont(uint64_t* a, const uint64_t* b, size_t N, uint64_t p, uint64_t p_inv_neg) { for (size_t i = 0; i < N; ++i) { a[i] = mont_mul(a[i], b[i], p, p_inv_neg); } } inline void pointwise_sqr_mont(uint64_t* a, size_t N, uint64_t p, uint64_t p_inv_neg) { for (size_t i = 0; i < N; ++i) { a[i] = mont_mul(a[i], a[i], p, p_inv_neg); } 
} // YC-2: pointwise multiply-accumulate: r[i] += a[i] * b[i] mod p (スカラー) inline void pointwise_mac_mont(uint64_t* r, const uint64_t* a, const uint64_t* b, size_t N, uint64_t p, uint64_t p_inv_neg) { for (size_t i = 0; i < N; ++i) { r[i] = mod_add(r[i], mont_mul(a[i], b[i], p, p_inv_neg), p); } } // ================================================================ // AVX2 高速 NTT (Montgomery 形式) // ================================================================ // _mm256_mul_epu32 による 4-wide Montgomery 乗算で NTT バタフライを高速化。 // データは Montgomery 形式のまま。レイヤーごとの根テーブルにより // ストライドアクセスを排除。 #if defined(__AVX2__) } } // namespace prime_ntt, calx (immintrin.h はグローバルスコープで) #include namespace calx { namespace prime_ntt { // --- AVX2 64-bit 乗算ヘルパー --- // 4-lane 64×64→128 bit 乗算 (lo と hi を同時取得) // 4 個の _mm256_mul_epu32 で 32-bit 部分積を計算し合成。 // 入力は任意の uint64_t (2^64 未満) で正しい結果を返す。 inline void avx2_mul_full_epu64(__m256i a, __m256i b, __m256i& lo, __m256i& hi) { const __m256i mask32 = _mm256_set1_epi64x(0xFFFFFFFF); __m256i a_hi = _mm256_srli_epi64(a, 32); __m256i b_hi = _mm256_srli_epi64(b, 32); __m256i ll = _mm256_mul_epu32(a, b); // a_lo * b_lo __m256i lh = _mm256_mul_epu32(a, b_hi); // a_lo * b_hi __m256i hl = _mm256_mul_epu32(a_hi, b); // a_hi * b_lo __m256i hh = _mm256_mul_epu32(a_hi, b_hi); // a_hi * b_hi // lo = ll + (lh + hl) << 32 [mod 2^64] __m256i cross = _mm256_add_epi64(lh, hl); lo = _mm256_add_epi64(ll, _mm256_slli_epi64(cross, 32)); // hi: 桁上がりを正確に伝搬 // mid = lh + (ll >> 32), carry の伝搬なし (lh < 2^64, ll>>32 < 2^32) __m256i ll_hi = _mm256_srli_epi64(ll, 32); __m256i mid = _mm256_add_epi64(lh, ll_hi); __m256i mid_lo = _mm256_and_si256(mid, mask32); __m256i mid_hi = _mm256_srli_epi64(mid, 32); // mid_lo + hl: mid_lo < 2^32, hl < 2^64 → sum < 2^64 + 2^32 < 2^64 (∵ hl ≤ (2^32-1)^2 = 2^64-2^33+1, mid_lo ≤ 2^32-1) __m256i t = _mm256_add_epi64(mid_lo, hl); __m256i t_hi = _mm256_srli_epi64(t, 32); hi = _mm256_add_epi64(hh, _mm256_add_epi64(mid_hi, t_hi)); } // low 64 bits of a*b (3 
mul_epu32) inline __m256i avx2_mullo_epu64(__m256i a, __m256i b) { __m256i a_hi = _mm256_srli_epi64(a, 32); __m256i b_hi = _mm256_srli_epi64(b, 32); __m256i ll = _mm256_mul_epu32(a, b); __m256i lh = _mm256_mul_epu32(a, b_hi); __m256i hl = _mm256_mul_epu32(a_hi, b); __m256i cross = _mm256_add_epi64(lh, hl); return _mm256_add_epi64(ll, _mm256_slli_epi64(cross, 32)); } // high 64 bits of a*b (4 mul_epu32) inline __m256i avx2_mulhi_epu64(__m256i a, __m256i b) { const __m256i mask32 = _mm256_set1_epi64x(0xFFFFFFFF); __m256i a_hi = _mm256_srli_epi64(a, 32); __m256i b_hi = _mm256_srli_epi64(b, 32); __m256i ll = _mm256_mul_epu32(a, b); __m256i lh = _mm256_mul_epu32(a, b_hi); __m256i hl = _mm256_mul_epu32(a_hi, b); __m256i hh = _mm256_mul_epu32(a_hi, b_hi); __m256i ll_hi = _mm256_srli_epi64(ll, 32); __m256i mid = _mm256_add_epi64(lh, ll_hi); __m256i mid_lo = _mm256_and_si256(mid, mask32); __m256i mid_hi = _mm256_srli_epi64(mid, 32); __m256i t = _mm256_add_epi64(mid_lo, hl); __m256i t_hi = _mm256_srli_epi64(t, 32); return _mm256_add_epi64(hh, _mm256_add_epi64(mid_hi, t_hi)); } // 4-wide Montgomery 乗算: REDC(a * b) mod p // 入力: a, b は Montgomery 形式 (< p) // p_vec = broadcast(p), pin_vec = broadcast(-p^{-1} mod 2^64) // 利用する性質: T_lo + m*p_lo ≡ 0 (mod 2^64) → carry = (T_lo != 0) inline __m256i avx2_mont_mul(__m256i a, __m256i b, __m256i p_vec, __m256i pin_vec) { // Step 1: T = a * b (128-bit), lo と hi を同時取得 __m256i T_lo, T_hi; avx2_mul_full_epu64(a, b, T_lo, T_hi); // 4 VPMULUDQ // Step 2: m = T_lo * (-p^{-1}) mod 2^64 __m256i m = avx2_mullo_epu64(T_lo, pin_vec); // 3 VPMULUDQ // Step 3: mp_hi = mulhi(m, p) __m256i mp_hi = avx2_mulhi_epu64(m, p_vec); // 4 VPMULUDQ // Step 4: carry = (T_lo != 0) ? 
1 : 0 __m256i eq_zero = _mm256_cmpeq_epi64(T_lo, _mm256_setzero_si256()); __m256i carry = _mm256_andnot_si256(eq_zero, _mm256_set1_epi64x(1)); // Step 5: t = T_hi + mp_hi + carry __m256i t = _mm256_add_epi64(_mm256_add_epi64(T_hi, mp_hi), carry); // Step 6: if t >= p: t -= p (符号ビットで判定) __m256i t_sub_p = _mm256_sub_epi64(t, p_vec); // t < p → t_sub_p は巨大 (bit 63 = 1, 符号付きで負) __m256i neg_mask = _mm256_cmpgt_epi64(_mm256_setzero_si256(), t_sub_p); return _mm256_blendv_epi8(t_sub_p, t, neg_mask); } // 合計: 11 VPMULUDQ + ~10 add/cmp/blend per 4 elements // --- レイヤーごとの根テーブル --- // 各 NTT レイヤーの根を連続配置し、AVX2 のシーケンシャルロードに対応。 struct NttRootsLayered { std::vector fwd; // forward roots (層ごと連続) std::vector inv; // inverse roots (層ごと連続) std::vector fwd_offset; // fwd 内の各層オフセット std::vector inv_offset; // inv 内の各層オフセット uint64_t inv_n; // N^{-1} mod p (通常形式) size_t N = 0; uint64_t p = 0; uint64_t p_inv_neg = 0; void build(const NttPrime& prime, size_t ntt_len) { N = ntt_len; p = prime.p; p_inv_neg = prime.p_inv_neg; inv_n = prime.inv_n(N); uint64_t omega = prime.nth_root(N); uint64_t omega_inv = mod_inv(omega, p); // Montgomery 形式に変換 uint64_t omega_mont = to_mont(omega, p, p_inv_neg, prime.r2_mod_p); uint64_t omega_inv_mont = to_mont(omega_inv, p, p_inv_neg, prime.r2_mod_p); // 通常の根テーブルを構築 std::vector all_roots(N / 2), all_inv(N / 2); uint64_t one_mont = to_mont(1, p, p_inv_neg, prime.r2_mod_p); all_roots[0] = one_mont; all_inv[0] = one_mont; for (size_t i = 1; i < N / 2; ++i) { all_roots[i] = mont_mul(all_roots[i - 1], omega_mont, p, p_inv_neg); all_inv[i] = mont_mul(all_inv[i - 1], omega_inv_mont, p, p_inv_neg); } // レイヤーごとに連続配置 // Forward DIF: layer l → len = N >> l, half = len/2, step = 1 << l // roots needed: all_roots[j * step] for j = 0..half-1 int num_layers = 0; for (size_t l = N; l >= 2; l >>= 1) ++num_layers; fwd.resize(N - 1); // 各層の合計: N/2 + N/4 + ... 
+ 1 = N - 1 inv.resize(N - 1); fwd_offset.resize(num_layers); inv_offset.resize(num_layers); size_t off = 0; int layer = 0; // Forward (DIF): len = N, N/2, ..., 2 for (size_t len = N; len >= 2; len >>= 1) { size_t half = len >> 1; size_t step = N / len; fwd_offset[layer] = off; for (size_t j = 0; j < half; ++j) fwd[off + j] = all_roots[j * step]; off += half; ++layer; } off = 0; layer = 0; // Inverse (DIT): len = 2, 4, ..., N for (size_t len = 2; len <= N; len <<= 1) { size_t half = len >> 1; size_t step = N / len; inv_offset[layer] = off; for (size_t j = 0; j < half; ++j) inv[off + j] = all_inv[j * step]; off += half; ++layer; } } }; // --- AVX2 DIF バタフライ (1 グループ分) --- CALX_FORCEINLINE void dif_butterfly_avx2(uint64_t* data, size_t start, size_t half, const uint64_t* roots, __m256i p_vec, __m256i pin_vec, uint64_t p, uint64_t p_inv_neg) { uint64_t* d0 = data + start; uint64_t* d1 = data + start + half; size_t j = 0; for (; j + 4 <= half; j += 4) { __m256i u = _mm256_loadu_si256((__m256i*)(d0 + j)); __m256i v = _mm256_loadu_si256((__m256i*)(d1 + j)); __m256i w = _mm256_loadu_si256((__m256i*)(roots + j)); // DIF: (u, v) → (u+v, (u-v)*ω) __m256i sum = _mm256_add_epi64(u, v); __m256i sum_sub = _mm256_sub_epi64(sum, p_vec); __m256i sm = _mm256_cmpgt_epi64(_mm256_setzero_si256(), sum_sub); sum = _mm256_blendv_epi8(sum_sub, sum, sm); __m256i diff = _mm256_sub_epi64(u, v); __m256i diff_add = _mm256_add_epi64(diff, p_vec); __m256i dm = _mm256_cmpgt_epi64(_mm256_setzero_si256(), diff); diff = _mm256_blendv_epi8(diff, diff_add, dm); __m256i tw = avx2_mont_mul(diff, w, p_vec, pin_vec); _mm256_storeu_si256((__m256i*)(d0 + j), sum); _mm256_storeu_si256((__m256i*)(d1 + j), tw); } for (; j < half; ++j) { uint64_t u = d0[j]; uint64_t v = d1[j]; d0[j] = mod_add(u, v, p); uint64_t diff = mod_sub(u, v, p); d1[j] = mont_mul(diff, roots[j], p, p_inv_neg); } } // --- AVX2 NTT forward (DIF, Montgomery, cache-blocked) --- inline void forward_ntt_mont_avx2(uint64_t* data, size_t N, 
uint64_t p, uint64_t p_inv_neg, const uint64_t* layer_roots, const size_t* layer_offsets) { const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = _mm256_set1_epi64x(p_inv_neg); size_t block = (N <= NTT_CACHE_BLOCK) ? N : NTT_CACHE_BLOCK; // Phase 1: Top layers (len > block) — 標準 layer-by-layer int layer = 0; for (size_t len = N; len > block; len >>= 1) { size_t half = len >> 1; const uint64_t* roots = layer_roots + layer_offsets[layer]; for (size_t start = 0; start < N; start += len) dif_butterfly_avx2(data, start, half, roots, p_vec, pin_vec, p, p_inv_neg); ++layer; } // Phase 2: Bottom layers (len <= block) — ブロック単位で L1 内処理 int bottom_start = layer; for (size_t blk = 0; blk < N; blk += block) { int cur_layer = bottom_start; for (size_t len = block; len >= 2; len >>= 1) { size_t half = len >> 1; const uint64_t* roots = layer_roots + layer_offsets[cur_layer]; for (size_t start = blk; start < blk + block; start += len) dif_butterfly_avx2(data, start, half, roots, p_vec, pin_vec, p, p_inv_neg); ++cur_layer; } } } // --- AVX2 DIT バタフライ (1 グループ分) --- CALX_FORCEINLINE void dit_butterfly_avx2(uint64_t* data, size_t start, size_t half, const uint64_t* roots, __m256i p_vec, __m256i pin_vec, uint64_t p, uint64_t p_inv_neg) { uint64_t* d0 = data + start; uint64_t* d1 = data + start + half; size_t j = 0; for (; j + 4 <= half; j += 4) { __m256i u = _mm256_loadu_si256((__m256i*)(d0 + j)); __m256i v_raw = _mm256_loadu_si256((__m256i*)(d1 + j)); __m256i w = _mm256_loadu_si256((__m256i*)(roots + j)); // DIT: v = mont_mul(v_raw, ω⁻¹), then (u+v, u-v) __m256i v = avx2_mont_mul(v_raw, w, p_vec, pin_vec); __m256i sum = _mm256_add_epi64(u, v); __m256i sum_sub = _mm256_sub_epi64(sum, p_vec); __m256i sm = _mm256_cmpgt_epi64(_mm256_setzero_si256(), sum_sub); sum = _mm256_blendv_epi8(sum_sub, sum, sm); __m256i diff = _mm256_sub_epi64(u, v); __m256i diff_add = _mm256_add_epi64(diff, p_vec); __m256i dm = _mm256_cmpgt_epi64(_mm256_setzero_si256(), diff); diff = 
_mm256_blendv_epi8(diff, diff_add, dm); _mm256_storeu_si256((__m256i*)(d0 + j), sum); _mm256_storeu_si256((__m256i*)(d1 + j), diff); } for (; j < half; ++j) { uint64_t u = d0[j]; uint64_t v = mont_mul(d1[j], roots[j], p, p_inv_neg); d0[j] = mod_add(u, v, p); d1[j] = mod_sub(u, v, p); } } // --- AVX2 NTT inverse (DIT, Montgomery, cache-blocked) --- inline void inverse_ntt_mont_avx2(uint64_t* data, size_t N, uint64_t p, uint64_t p_inv_neg, const uint64_t* layer_roots, const size_t* layer_offsets, uint64_t inv_n_normal) { const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = _mm256_set1_epi64x(p_inv_neg); size_t block = (N <= NTT_CACHE_BLOCK) ? N : NTT_CACHE_BLOCK; // Phase 1: Bottom layers (len <= block) — ブロック単位で L1 内処理 int num_bottom_layers = 0; for (size_t l = 2; l <= block; l <<= 1) ++num_bottom_layers; for (size_t blk = 0; blk < N; blk += block) { int cur_layer = 0; for (size_t len = 2; len <= block; len <<= 1) { size_t half = len >> 1; const uint64_t* roots = layer_roots + layer_offsets[cur_layer]; for (size_t start = blk; start < blk + block; start += len) dit_butterfly_avx2(data, start, half, roots, p_vec, pin_vec, p, p_inv_neg); ++cur_layer; } } // Phase 2: Top layers (len > block) — 標準 layer-by-layer int layer = num_bottom_layers; for (size_t len = block * 2; len <= N; len <<= 1) { size_t half = len >> 1; const uint64_t* roots = layer_roots + layer_offsets[layer]; for (size_t start = 0; start < N; start += len) dit_butterfly_avx2(data, start, half, roots, p_vec, pin_vec, p, p_inv_neg); ++layer; } // N^{-1} スケーリング + from_mont (AVX2) __m256i inv_n_vec = _mm256_set1_epi64x(inv_n_normal); size_t i = 0; for (; i + 4 <= N; i += 4) { __m256i d = _mm256_loadu_si256((__m256i*)(data + i)); d = avx2_mont_mul(d, inv_n_vec, p_vec, pin_vec); _mm256_storeu_si256((__m256i*)(data + i), d); } for (; i < N; ++i) { data[i] = mont_mul(data[i], inv_n_normal, p, p_inv_neg); } } // --- AVX2 pointwise 乗算 --- inline void pointwise_mul_mont_avx2(uint64_t* a, const 
uint64_t* b, size_t N, uint64_t p, uint64_t p_inv_neg) { const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = _mm256_set1_epi64x(p_inv_neg); size_t i = 0; for (; i + 4 <= N; i += 4) { __m256i va = _mm256_loadu_si256((__m256i*)(a + i)); __m256i vb = _mm256_loadu_si256((__m256i*)(b + i)); va = avx2_mont_mul(va, vb, p_vec, pin_vec); _mm256_storeu_si256((__m256i*)(a + i), va); } for (; i < N; ++i) { a[i] = mont_mul(a[i], b[i], p, p_inv_neg); } } inline void pointwise_sqr_mont_avx2(uint64_t* a, size_t N, uint64_t p, uint64_t p_inv_neg) { const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = _mm256_set1_epi64x(p_inv_neg); size_t i = 0; for (; i + 4 <= N; i += 4) { __m256i va = _mm256_loadu_si256((__m256i*)(a + i)); va = avx2_mont_mul(va, va, p_vec, pin_vec); _mm256_storeu_si256((__m256i*)(a + i), va); } for (; i < N; ++i) { a[i] = mont_mul(a[i], a[i], p, p_inv_neg); } } // YC-2: pointwise multiply-accumulate: r[i] += a[i] * b[i] mod p (AVX2) inline void pointwise_mac_mont_avx2(uint64_t* r, const uint64_t* a, const uint64_t* b, size_t N, uint64_t p, uint64_t p_inv_neg) { const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = _mm256_set1_epi64x(p_inv_neg); size_t i = 0; for (; i + 4 <= N; i += 4) { __m256i va = _mm256_loadu_si256((__m256i*)(a + i)); __m256i vb = _mm256_loadu_si256((__m256i*)(b + i)); __m256i vr = _mm256_loadu_si256((__m256i*)(r + i)); __m256i prod = avx2_mont_mul(va, vb, p_vec, pin_vec); // mod_add: sum = vr + prod, if sum >= p then sum -= p __m256i sum = _mm256_add_epi64(vr, prod); __m256i sub = _mm256_sub_epi64(sum, p_vec); __m256i mask = _mm256_cmpgt_epi64(_mm256_setzero_si256(), sub); sum = _mm256_blendv_epi8(sub, sum, mask); _mm256_storeu_si256((__m256i*)(r + i), sum); } for (; i < N; ++i) { r[i] = mod_add(r[i], mont_mul(a[i], b[i], p, p_inv_neg), p); } } #endif // __AVX2__ // ================================================================ // CRT 復元 (Garner のアルゴリズム) // 
================================================================ // 3 素数の NTT 結果から真の係数を復元し、 // limb 列 (uint64_t) に変換する。 // r1[i] mod p1, r2[i] mod p2, r3[i] mod p3 → result limbs // // Garner: // v1 = r1[i] // v2 = (r2[i] - v1) * inv_p1_mod_p2 mod p2 // v3 = (r3[i] - v1 - v2*p1) * inv_p1p2_mod_p3 mod p3 // c[i] = v1 + v2*p1 + v3*p1*p2 (最大 ~189 bit) struct CrtConstants { uint64_t p1, p2, p3; uint64_t inv_p1_mod_p2; // p1^(-1) mod p2 uint64_t inv_p1p2_mod_p3; // (p1*p2)^(-1) mod p3 uint64_t p1_mod_p3; // p1 mod p3 void init(uint64_t _p1, uint64_t _p2, uint64_t _p3) { p1 = _p1; p2 = _p2; p3 = _p3; inv_p1_mod_p2 = mod_inv(p1 % p2, p2); // p1*p2 mod p3: (p1 mod p3) * (p2 mod p3) mod p3 p1_mod_p3 = p1 % p3; uint64_t p1p2_mod_p3 = mod_mul(p1_mod_p3, p2 % p3, p3); inv_p1p2_mod_p3 = mod_inv(p1p2_mod_p3, p3); } }; // CRT で 1 係数を復元し、3-word (192bit) 値を返す // result[0] = low 64bit, result[1] = mid 64bit, result[2] = high 64bit inline void crt_single(uint64_t result[3], uint64_t r1, uint64_t r2, uint64_t r3, const CrtConstants& crt) { // v1 = r1 uint64_t v1 = r1; // v2 = (r2 - v1) * inv_p1_mod_p2 mod p2 uint64_t t = mod_sub(r2, v1 % crt.p2, crt.p2); uint64_t v2 = mod_mul(t, crt.inv_p1_mod_p2, crt.p2); // v3 = (r3 - v1 - v2*p1_mod_p3) * inv_p1p2_mod_p3 mod p3 uint64_t v2p1_mod_p3 = mod_mul(v2, crt.p1_mod_p3, crt.p3); uint64_t s = mod_sub(r3, v1 % crt.p3, crt.p3); s = mod_sub(s, v2p1_mod_p3, crt.p3); uint64_t v3 = mod_mul(s, crt.inv_p1p2_mod_p3, crt.p3); // c = v1 + v2*p1 + v3*p1*p2 // v2*p1: 最大 63bit × 63bit = 126bit → 2 words #if defined(_MSC_VER) && defined(_M_X64) uint64_t hi1; uint64_t lo1 = _umul128(v2, crt.p1, &hi1); #elif defined(__GNUC__) || defined(__clang__) __uint128_t prod1 = static_cast<__uint128_t>(v2) * crt.p1; uint64_t lo1 = static_cast(prod1); uint64_t hi1 = static_cast(prod1 >> 64); #endif // v1 + v2*p1 (128bit) uint64_t c0 = v1 + lo1; uint64_t c1 = hi1 + (c0 < v1 ? 
1 : 0); // v3*p1*p2: v3 (63bit) × p1 (63bit) = 126bit → × p2 = 189bit // ただし v3 < p3 < 2^63, p1 < 2^63, p2 < 2^63 // v3*p1: 126bit → 2 words #if defined(_MSC_VER) && defined(_M_X64) uint64_t hi2; uint64_t lo2 = _umul128(v3, crt.p1, &hi2); // lo2:hi2 = v3*p1 (126bit) // × p2: (lo2:hi2) × p2 = 189bit → 3 words uint64_t hi3; uint64_t lo3 = _umul128(lo2, crt.p2, &hi3); uint64_t hi4; uint64_t lo4 = _umul128(hi2, crt.p2, &hi4); #elif defined(__GNUC__) || defined(__clang__) __uint128_t prod2 = static_cast<__uint128_t>(v3) * crt.p1; uint64_t lo2 = static_cast(prod2); uint64_t hi2 = static_cast(prod2 >> 64); __uint128_t prod3 = static_cast<__uint128_t>(lo2) * crt.p2; uint64_t lo3 = static_cast(prod3); uint64_t hi3 = static_cast(prod3 >> 64); __uint128_t prod4 = static_cast<__uint128_t>(hi2) * crt.p2; uint64_t lo4 = static_cast(prod4); uint64_t hi4 = static_cast(prod4 >> 64); #endif // v3*p1*p2 = lo3 + (hi3 + lo4) << 64 + (hi4) << 128 uint64_t mid = hi3 + lo4; uint64_t carry_mid = (mid < hi3) ? 1 : 0; uint64_t top = hi4 + carry_mid; // c += v3*p1*p2 c0 += lo3; uint64_t carry0 = (c0 < lo3) ? 1 : 0; c1 += mid + carry0; uint64_t carry1 = (c1 < mid || (carry0 && c1 == mid)) ? 1 : 0; uint64_t c2 = top + carry1; result[0] = c0; result[1] = c1; result[2] = c2; } // CRT 復元 + キャリー伝搬で最終結果を生成 inline void crt_recompose(uint64_t* rp, size_t result_limbs, const uint64_t* r1, const uint64_t* r2, const uint64_t* r3, size_t N, const CrtConstants& crt) { // 結果をゼロクリア std::memset(rp, 0, result_limbs * sizeof(uint64_t)); uint64_t carry_lo = 0, carry_hi = 0; for (size_t i = 0; i < N && i < result_limbs; ++i) { uint64_t coeff[3]; crt_single(coeff, r1[i], r2[i], r3[i], crt); // coeff[0:2] + carry を加算 uint64_t sum0 = coeff[0] + carry_lo; uint64_t c0 = (sum0 < coeff[0]) ? 1 : 0; uint64_t sum1 = coeff[1] + carry_hi + c0; uint64_t c1 = (sum1 < coeff[1] || (c0 && sum1 <= coeff[1])) ? 
1 : 0; uint64_t sum2 = coeff[2] + c1; rp[i] = sum0; carry_lo = sum1; carry_hi = sum2; } // 残りのキャリーを書き出す if (carry_lo != 0 && N < result_limbs) { rp[N] = carry_lo; if (carry_hi != 0 && N + 1 < result_limbs) { rp[N + 1] = carry_hi; } } } // ================================================================ // グローバル NTT パラメータ (初期化は初回使用時) // ================================================================ struct PrimeNttContext { NttPrime primes[3]; CrtConstants crt; bool initialized = false; void init() { const uint64_t ps[] = { P1, P2, P3 }; const uint64_t gs[] = { G1, G2, G3 }; for (int k = 0; k < 3; ++k) { uint64_t p = ps[k]; uint64_t pin = compute_p_inv_neg(p); // R^2 mod p: R = 2^64, compute via mod_pow(2, 128, p) // 2^128 mod p = (2^64 mod p)^2 mod p uint64_t r_mod_p = mod_pow(2, 64, p); uint64_t r2 = mod_mul(r_mod_p, r_mod_p, p); primes[k] = { p, gs[k], NTT_MAX_S, mod_inv(2, p), pin, r2 }; } crt.init(P1, P2, P3); initialized = true; } }; inline PrimeNttContext& getPrimeNttContext() { static PrimeNttContext ctx; if (!ctx.initialized) { ctx.init(); } return ctx; } // ================================================================ // N 素数 CRT (Garner のアルゴリズム, NTT ドメイン BS 用) // ================================================================ struct CrtConstantsN { uint64_t p[BS_NUM_PRIMES]; // Garner 用逆元: inv[i][j] = p[j]^{-1} mod p[i] (j < i) uint64_t inv[BS_NUM_PRIMES][BS_NUM_PRIMES]; // prefix[i] = p[0]*p[1]*...*p[i-1] (多ワード, i ワード) // prefix[0] = {1}, prefix[1] = {p[0]}, ... 
uint64_t prefix[BS_NUM_PRIMES][BS_NUM_PRIMES]; int prefix_len[BS_NUM_PRIMES]; void init(const uint64_t ps[BS_NUM_PRIMES]) { for (int i = 0; i < BS_NUM_PRIMES; i++) p[i] = ps[i]; // 逆元テーブル for (int i = 0; i < BS_NUM_PRIMES; i++) for (int j = 0; j < i; j++) inv[i][j] = mod_inv(p[j] % p[i], p[i]); // 前置積 (多ワード) std::memset(prefix, 0, sizeof(prefix)); prefix[0][0] = 1; prefix_len[0] = 1; for (int i = 1; i < BS_NUM_PRIMES; i++) { uint64_t carry = 0; for (int w = 0; w < prefix_len[i - 1]; w++) { #if defined(_MSC_VER) && defined(_M_X64) uint64_t hi; uint64_t lo = _umul128(prefix[i - 1][w], p[i - 1], &hi); #else __uint128_t prod = (__uint128_t)prefix[i - 1][w] * p[i - 1]; uint64_t lo = (uint64_t)prod; uint64_t hi = (uint64_t)(prod >> 64); #endif lo += carry; if (lo < carry) hi++; prefix[i][w] = lo; carry = hi; } if (carry) { prefix[i][prefix_len[i - 1]] = carry; prefix_len[i] = prefix_len[i - 1] + 1; } else { prefix_len[i] = prefix_len[i - 1]; } } } }; // N 素数 CRT: 1 係数を復元 → BS_NUM_PRIMES ワードの値を返す inline void crt_single_n(uint64_t result[BS_NUM_PRIMES], uint64_t r[BS_NUM_PRIMES], const CrtConstantsN& crt) { // Garner の混合基数表現 uint64_t v[BS_NUM_PRIMES]; v[0] = r[0]; for (int i = 1; i < BS_NUM_PRIMES; i++) { uint64_t temp = r[i]; for (int j = 0; j < i; j++) { temp = mod_sub(temp, v[j] % crt.p[i], crt.p[i]); temp = mod_mul(temp, crt.inv[i][j], crt.p[i]); } v[i] = temp; } // 復元: x = v[0] + v[1]*prefix[1] + v[2]*prefix[2] + ... 
std::memset(result, 0, BS_NUM_PRIMES * sizeof(uint64_t)); result[0] = v[0]; for (int i = 1; i < BS_NUM_PRIMES; i++) { // result += v[i] * prefix[i] uint64_t carry = 0; for (int w = 0; w < crt.prefix_len[i]; w++) { #if defined(_MSC_VER) && defined(_M_X64) uint64_t hi; uint64_t lo = _umul128(v[i], crt.prefix[i][w], &hi); #else __uint128_t prod = (__uint128_t)v[i] * crt.prefix[i][w]; uint64_t lo = (uint64_t)prod; uint64_t hi = (uint64_t)(prod >> 64); #endif lo += carry; if (lo < carry) hi++; lo += result[w]; if (lo < result[w]) hi++; result[w] = lo; carry = hi; } for (int w = crt.prefix_len[i]; w < BS_NUM_PRIMES && carry; w++) { result[w] += carry; carry = (result[w] < carry) ? 1 : 0; } } } // N 素数 CRT 復元 + キャリー伝搬で最終結果を生成 inline void crt_recompose_n(uint64_t* rp, size_t result_limbs, uint64_t* const* residues, size_t N, const CrtConstantsN& crt) { std::memset(rp, 0, result_limbs * sizeof(uint64_t)); uint64_t carry[BS_NUM_PRIMES] = {}; for (size_t i = 0; i < N && i < result_limbs; ++i) { uint64_t res[BS_NUM_PRIMES]; for (int k = 0; k < BS_NUM_PRIMES; k++) res[k] = residues[k][i]; uint64_t coeff[BS_NUM_PRIMES]; crt_single_n(coeff, res, crt); // coeff += carry uint64_t c = 0; for (int w = 0; w < BS_NUM_PRIMES; w++) { uint64_t sum = coeff[w] + carry[w]; uint64_t c1 = (sum < coeff[w]) ? 1 : 0; sum += c; c1 += (sum < c) ? 
1 : 0; coeff[w] = sum; c = c1; } rp[i] = coeff[0]; for (int w = 0; w < BS_NUM_PRIMES - 1; w++) carry[w] = coeff[w + 1]; carry[BS_NUM_PRIMES - 1] = c; } // 残りキャリーを書き出す for (int w = 0; w < BS_NUM_PRIMES; w++) { size_t idx = N + w; if (idx < result_limbs && carry[w]) rp[idx] = carry[w]; } } // ================================================================ // NTT ドメイン BS 用コンテキスト (5 素数) // ================================================================ struct NttBsContext { NttPrime primes[BS_NUM_PRIMES]; CrtConstantsN crt; bool initialized = false; void init() { uint64_t ps[BS_NUM_PRIMES]; for (int k = 0; k < BS_NUM_PRIMES; ++k) { uint64_t p = BS_PRIME_TABLE[k].p; int g = BS_PRIME_TABLE[k].g; uint64_t pin = compute_p_inv_neg(p); uint64_t r_mod_p = mod_pow(2, 64, p); uint64_t r2 = mod_mul(r_mod_p, r_mod_p, p); primes[k] = { p, static_cast(g), NTT_MAX_S, mod_inv(2, p), pin, r2 }; ps[k] = p; } crt.init(ps); initialized = true; } }; inline NttBsContext& getNttBsContext() { static NttBsContext ctx; if (!ctx.initialized) { ctx.init(); } return ctx; } // ================================================================ // NTT 長の決定 // ================================================================ inline size_t next_power_of_2(size_t n) { size_t p = 1; while (p < n) p <<= 1; return p; } // スレッドプール並列化の閾値 (NTT 長がこれ以上で 3 素数並列) static constexpr size_t PRIME_NTT_PARALLEL_THRESHOLD = 4096; // YC-10: 超並列 NTT の閾値 (NTT 長がこれ以上で intra-NTT 多スレッド化) // 128K limbs ≈ 2.5M 桁。これ以上で各素数の NTT を T スレッドに分割 static constexpr size_t MULTI_THREAD_NTT_THRESHOLD = 131072; // ================================================================ // 1 素数分の NTT パイプライン (Montgomery 版, スレッドから呼べる) // ================================================================ // 乗算: da[N], db[N] → 結果は da[N] に上書き (通常形式で返る) inline void ntt_mul_pipeline_mont(uint64_t* da, uint64_t* db, size_t N, uint64_t p, uint64_t p_inv_neg, const uint64_t* fwd_roots, const uint64_t* inv_roots, uint64_t inv_n_normal) { forward_ntt_mont(da, N, 
p, p_inv_neg, fwd_roots); forward_ntt_mont(db, N, p, p_inv_neg, fwd_roots); pointwise_mul_mont(da, db, N, p, p_inv_neg); inverse_ntt_mont(da, N, p, p_inv_neg, inv_roots, inv_n_normal); } // 自乗: da[N] → 結果は da[N] に上書き (通常形式で返る) inline void ntt_sqr_pipeline_mont(uint64_t* da, size_t N, uint64_t p, uint64_t p_inv_neg, const uint64_t* fwd_roots, const uint64_t* inv_roots, uint64_t inv_n_normal) { forward_ntt_mont(da, N, p, p_inv_neg, fwd_roots); pointwise_sqr_mont(da, N, p, p_inv_neg); inverse_ntt_mont(da, N, p, p_inv_neg, inv_roots, inv_n_normal); } #ifdef __AVX2__ // AVX2 版パイプライン (レイヤー根テーブル使用) inline void ntt_mul_pipeline_avx2(uint64_t* da, uint64_t* db, size_t N, uint64_t p, uint64_t p_inv_neg, const NttRootsLayered& roots) { forward_ntt_mont_avx2(da, N, p, p_inv_neg, roots.fwd.data(), roots.fwd_offset.data()); forward_ntt_mont_avx2(db, N, p, p_inv_neg, roots.fwd.data(), roots.fwd_offset.data()); pointwise_mul_mont_avx2(da, db, N, p, p_inv_neg); inverse_ntt_mont_avx2(da, N, p, p_inv_neg, roots.inv.data(), roots.inv_offset.data(), roots.inv_n); } inline void ntt_sqr_pipeline_avx2(uint64_t* da, size_t N, uint64_t p, uint64_t p_inv_neg, const NttRootsLayered& roots) { forward_ntt_mont_avx2(da, N, p, p_inv_neg, roots.fwd.data(), roots.fwd_offset.data()); pointwise_sqr_mont_avx2(da, N, p, p_inv_neg); inverse_ntt_mont_avx2(da, N, p, p_inv_neg, roots.inv.data(), roots.inv_offset.data(), roots.inv_n); } #endif // ================================================================ // YC-10: 超並列 NTT (intra-NTT multi-threading) // ================================================================ // 各素数の NTT を T スレッドに分割し、多コア CPU を活用する。 // 構成: 3 素数 × T スレッド/素数 = 3T スレッド (64コアで T≈21) // Phase 1 (top layers): 各ステージのバタフライを T 分割 + バリア // Phase 2 (bottom layers): キャッシュブロック単位で T 分割 (バリア 1 回) // 並列 range 実行ヘルパー: body(start, end) を T スレッドで分割実行 template inline void ntt_parallel_range(size_t total, int T, F&& body) { if (T <= 1 || total == 0) { body(static_cast(0), total); return; } auto& 
pool = calx::threadPool(); size_t chunk = (total + T - 1) / T; std::future futures[128]; int nf = 0; for (int t = 0; t < T - 1; ++t) { size_t s = static_cast(t) * chunk; size_t e = std::min(s + chunk, total); if (s >= total) break; futures[nf++] = pool.submit([s, e, &body]{ body(s, e); }); } size_t last_s = static_cast(T - 1) * chunk; if (last_s < total) body(last_s, total); for (int i = 0; i < nf; ++i) futures[i].get(); } // --- Scalar multi-threaded NTT --- // Forward NTT (DIF) — T スレッド並列 inline void forward_ntt_mont_mt(uint64_t* data, size_t N, uint64_t p, uint64_t p_inv_neg, const uint64_t* roots, int T) { for (size_t len = N; len >= 2; len >>= 1) { size_t half = len >> 1; size_t step = N / len; size_t num_groups = N / len; if (num_groups >= static_cast(T)) { // グループ数 >= T: グループ単位で分割 ntt_parallel_range(num_groups, T, [=](size_t g_start, size_t g_end) { for (size_t g = g_start; g < g_end; ++g) { size_t base = g * len; for (size_t j = 0; j < half; ++j) { uint64_t u = data[base + j]; uint64_t v = data[base + j + half]; data[base + j] = mod_add(u, v, p); data[base + j + half] = mont_mul( mod_sub(u, v, p), roots[j * step], p, p_inv_neg); } } }); } else { // グループ数 < T: j ループを T 分割 (各グループ独立) ntt_parallel_range(half, T, [=](size_t j_start, size_t j_end) { for (size_t g = 0; g < num_groups; ++g) { size_t base = g * len; for (size_t j = j_start; j < j_end; ++j) { uint64_t u = data[base + j]; uint64_t v = data[base + j + half]; data[base + j] = mod_add(u, v, p); data[base + j + half] = mont_mul( mod_sub(u, v, p), roots[j * step], p, p_inv_neg); } } }); } } } // Inverse NTT (DIT) — T スレッド並列 inline void inverse_ntt_mont_mt(uint64_t* data, size_t N, uint64_t p, uint64_t p_inv_neg, const uint64_t* inv_roots, uint64_t inv_n_normal, int T) { for (size_t len = 2; len <= N; len <<= 1) { size_t half = len >> 1; size_t step = N / len; size_t num_groups = N / len; if (num_groups >= static_cast(T)) { ntt_parallel_range(num_groups, T, [=](size_t g_start, size_t g_end) { for (size_t g 
= g_start; g < g_end; ++g) { size_t base = g * len; for (size_t j = 0; j < half; ++j) { uint64_t u = data[base + j]; uint64_t v = mont_mul(data[base + j + half], inv_roots[j * step], p, p_inv_neg); data[base + j] = mod_add(u, v, p); data[base + j + half] = mod_sub(u, v, p); } } }); } else { ntt_parallel_range(half, T, [=](size_t j_start, size_t j_end) { for (size_t g = 0; g < num_groups; ++g) { size_t base = g * len; for (size_t j = j_start; j < j_end; ++j) { uint64_t u = data[base + j]; uint64_t v = mont_mul(data[base + j + half], inv_roots[j * step], p, p_inv_neg); data[base + j] = mod_add(u, v, p); data[base + j + half] = mod_sub(u, v, p); } } }); } } // N^{-1} スケーリング (並列) ntt_parallel_range(N, T, [=](size_t start, size_t end) { for (size_t i = start; i < end; ++i) data[i] = mont_mul(data[i], inv_n_normal, p, p_inv_neg); }); } // Multi-threaded pointwise multiply (AVX2/scalar 自動選択) inline void pointwise_mul_mont_mt(uint64_t* a, const uint64_t* b, size_t N, uint64_t p, uint64_t p_inv_neg, int T) { ntt_parallel_range(N, T, [=](size_t start, size_t end) { #ifdef __AVX2__ const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = _mm256_set1_epi64x(p_inv_neg); size_t i = start; for (; i + 4 <= end; i += 4) { __m256i va = _mm256_loadu_si256((__m256i*)(a + i)); __m256i vb = _mm256_loadu_si256((__m256i*)(b + i)); va = avx2_mont_mul(va, vb, p_vec, pin_vec); _mm256_storeu_si256((__m256i*)(a + i), va); } for (; i < end; ++i) a[i] = mont_mul(a[i], b[i], p, p_inv_neg); #else for (size_t i = start; i < end; ++i) a[i] = mont_mul(a[i], b[i], p, p_inv_neg); #endif }); } // Multi-threaded pointwise square inline void pointwise_sqr_mont_mt(uint64_t* a, size_t N, uint64_t p, uint64_t p_inv_neg, int T) { ntt_parallel_range(N, T, [=](size_t start, size_t end) { #ifdef __AVX2__ const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = _mm256_set1_epi64x(p_inv_neg); size_t i = start; for (; i + 4 <= end; i += 4) { __m256i va = _mm256_loadu_si256((__m256i*)(a + 
i)); va = avx2_mont_mul(va, va, p_vec, pin_vec); _mm256_storeu_si256((__m256i*)(a + i), va); } for (; i < end; ++i) a[i] = mont_mul(a[i], a[i], p, p_inv_neg); #else for (size_t i = start; i < end; ++i) a[i] = mont_mul(a[i], a[i], p, p_inv_neg); #endif }); } // Scalar multi-threaded pipeline: mul inline void ntt_mul_pipeline_mt(uint64_t* da, uint64_t* db, size_t N, uint64_t p, uint64_t p_inv_neg, const uint64_t* fwd_roots, const uint64_t* inv_roots, uint64_t inv_n_normal, int T) { forward_ntt_mont_mt(da, N, p, p_inv_neg, fwd_roots, T); forward_ntt_mont_mt(db, N, p, p_inv_neg, fwd_roots, T); pointwise_mul_mont_mt(da, db, N, p, p_inv_neg, T); inverse_ntt_mont_mt(da, N, p, p_inv_neg, inv_roots, inv_n_normal, T); } // Scalar multi-threaded pipeline: sqr inline void ntt_sqr_pipeline_mt(uint64_t* da, size_t N, uint64_t p, uint64_t p_inv_neg, const uint64_t* fwd_roots, const uint64_t* inv_roots, uint64_t inv_n_normal, int T) { forward_ntt_mont_mt(da, N, p, p_inv_neg, fwd_roots, T); pointwise_sqr_mont_mt(da, N, p, p_inv_neg, T); inverse_ntt_mont_mt(da, N, p, p_inv_neg, inv_roots, inv_n_normal, T); } #ifdef __AVX2__ // --- AVX2 multi-threaded NTT --- // AVX2 Forward NTT (DIF, cache-blocked) — T スレッド並列 inline void forward_ntt_mont_avx2_mt(uint64_t* data, size_t N, uint64_t p, uint64_t p_inv_neg, const uint64_t* layer_roots, const size_t* layer_offsets, int T) { size_t block = (N <= NTT_CACHE_BLOCK) ? 
N : NTT_CACHE_BLOCK; // Phase 1: Top layers (len > block) — ステージごとにバリア int layer = 0; for (size_t len = N; len > block; len >>= 1) { size_t half = len >> 1; size_t num_groups = N / len; const uint64_t* roots = layer_roots + layer_offsets[layer]; if (num_groups >= static_cast(T)) { // グループ単位で分割 ntt_parallel_range(num_groups, T, [=](size_t g_start, size_t g_end) { const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = _mm256_set1_epi64x(p_inv_neg); for (size_t g = g_start; g < g_end; ++g) dif_butterfly_avx2(data, g * len, half, roots, p_vec, pin_vec, p, p_inv_neg); }); } else { // j ループを分割 (各グループ内の half バタフライを T 分割) ntt_parallel_range(half, T, [=](size_t j_start, size_t j_end) { const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = _mm256_set1_epi64x(p_inv_neg); for (size_t g = 0; g < num_groups; ++g) { size_t start = g * len; uint64_t* d0 = data + start; uint64_t* d1 = data + start + half; size_t j = j_start; for (; j + 4 <= j_end; j += 4) { __m256i u = _mm256_loadu_si256((__m256i*)(d0 + j)); __m256i v = _mm256_loadu_si256((__m256i*)(d1 + j)); __m256i w = _mm256_loadu_si256((__m256i*)(roots + j)); __m256i sum = _mm256_add_epi64(u, v); __m256i sub = _mm256_sub_epi64(sum, p_vec); __m256i sm = _mm256_cmpgt_epi64(_mm256_setzero_si256(), sub); sum = _mm256_blendv_epi8(sub, sum, sm); __m256i diff = _mm256_sub_epi64(u, v); __m256i da = _mm256_add_epi64(diff, p_vec); __m256i dm = _mm256_cmpgt_epi64(_mm256_setzero_si256(), diff); diff = _mm256_blendv_epi8(diff, da, dm); __m256i tw = avx2_mont_mul(diff, w, p_vec, pin_vec); _mm256_storeu_si256((__m256i*)(d0 + j), sum); _mm256_storeu_si256((__m256i*)(d1 + j), tw); } for (; j < j_end; ++j) { uint64_t uv = d0[j], vv = d1[j]; d0[j] = mod_add(uv, vv, p); d1[j] = mont_mul(mod_sub(uv, vv, p), roots[j], p, p_inv_neg); } } }); } ++layer; } // Phase 2: Bottom layers (len <= block) — ブロック並列 (バリア 1 回) int bottom_start = layer; size_t num_blocks = N / block; ntt_parallel_range(num_blocks, T, [=](size_t 
blk_start, size_t blk_end) { const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = _mm256_set1_epi64x(p_inv_neg); for (size_t bi = blk_start; bi < blk_end; ++bi) { size_t blk = bi * block; int cur_layer = bottom_start; for (size_t len = block; len >= 2; len >>= 1) { size_t half = len >> 1; const uint64_t* roots = layer_roots + layer_offsets[cur_layer]; for (size_t start = blk; start < blk + block; start += len) dif_butterfly_avx2(data, start, half, roots, p_vec, pin_vec, p, p_inv_neg); ++cur_layer; } } }); } // AVX2 Inverse NTT (DIT, cache-blocked) — T スレッド並列 inline void inverse_ntt_mont_avx2_mt(uint64_t* data, size_t N, uint64_t p, uint64_t p_inv_neg, const uint64_t* layer_roots, const size_t* layer_offsets, uint64_t inv_n_normal, int T) { size_t block = (N <= NTT_CACHE_BLOCK) ? N : NTT_CACHE_BLOCK; // Phase 1: Bottom layers (len <= block) — ブロック並列 (バリア 1 回) int num_bottom_layers = 0; for (size_t l = 2; l <= block; l <<= 1) ++num_bottom_layers; size_t num_blocks = N / block; ntt_parallel_range(num_blocks, T, [=](size_t blk_start, size_t blk_end) { const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = _mm256_set1_epi64x(p_inv_neg); for (size_t bi = blk_start; bi < blk_end; ++bi) { size_t blk = bi * block; int cur_layer = 0; for (size_t len = 2; len <= block; len <<= 1) { size_t half = len >> 1; const uint64_t* roots = layer_roots + layer_offsets[cur_layer]; for (size_t start = blk; start < blk + block; start += len) dit_butterfly_avx2(data, start, half, roots, p_vec, pin_vec, p, p_inv_neg); ++cur_layer; } } }); // Phase 2: Top layers (len > block) — ステージごとにバリア int layer = num_bottom_layers; for (size_t len = block * 2; len <= N; len <<= 1) { size_t half = len >> 1; size_t num_groups = N / len; const uint64_t* roots = layer_roots + layer_offsets[layer]; if (num_groups >= static_cast(T)) { ntt_parallel_range(num_groups, T, [=](size_t g_start, size_t g_end) { const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = 
_mm256_set1_epi64x(p_inv_neg); for (size_t g = g_start; g < g_end; ++g) dit_butterfly_avx2(data, g * len, half, roots, p_vec, pin_vec, p, p_inv_neg); }); } else { ntt_parallel_range(half, T, [=](size_t j_start, size_t j_end) { const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = _mm256_set1_epi64x(p_inv_neg); for (size_t g = 0; g < num_groups; ++g) { size_t start = g * len; uint64_t* d0 = data + start; uint64_t* d1 = data + start + half; size_t j = j_start; for (; j + 4 <= j_end; j += 4) { __m256i u = _mm256_loadu_si256((__m256i*)(d0 + j)); __m256i v_raw = _mm256_loadu_si256((__m256i*)(d1 + j)); __m256i w = _mm256_loadu_si256((__m256i*)(roots + j)); __m256i v = avx2_mont_mul(v_raw, w, p_vec, pin_vec); __m256i sum = _mm256_add_epi64(u, v); __m256i sub = _mm256_sub_epi64(sum, p_vec); __m256i sm = _mm256_cmpgt_epi64(_mm256_setzero_si256(), sub); sum = _mm256_blendv_epi8(sub, sum, sm); __m256i diff = _mm256_sub_epi64(u, v); __m256i da = _mm256_add_epi64(diff, p_vec); __m256i dm = _mm256_cmpgt_epi64(_mm256_setzero_si256(), diff); diff = _mm256_blendv_epi8(diff, da, dm); _mm256_storeu_si256((__m256i*)(d0 + j), sum); _mm256_storeu_si256((__m256i*)(d1 + j), diff); } for (; j < j_end; ++j) { uint64_t uv = d0[j]; uint64_t vv = mont_mul(d1[j], roots[j], p, p_inv_neg); d0[j] = mod_add(uv, vv, p); d1[j] = mod_sub(uv, vv, p); } } }); } ++layer; } // N^{-1} スケーリング (並列, AVX2) ntt_parallel_range(N, T, [=](size_t start, size_t end) { const __m256i p_vec = _mm256_set1_epi64x(p); const __m256i pin_vec = _mm256_set1_epi64x(p_inv_neg); __m256i inv_n_vec = _mm256_set1_epi64x(inv_n_normal); size_t i = start; for (; i + 4 <= end; i += 4) { __m256i d = _mm256_loadu_si256((__m256i*)(data + i)); d = avx2_mont_mul(d, inv_n_vec, p_vec, pin_vec); _mm256_storeu_si256((__m256i*)(data + i), d); } for (; i < end; ++i) data[i] = mont_mul(data[i], inv_n_normal, p, p_inv_neg); }); } // AVX2 multi-threaded pipeline: mul inline void ntt_mul_pipeline_avx2_mt(uint64_t* da, uint64_t* db, 
size_t N, uint64_t p, uint64_t p_inv_neg, const NttRootsLayered& roots, int T) { forward_ntt_mont_avx2_mt(da, N, p, p_inv_neg, roots.fwd.data(), roots.fwd_offset.data(), T); forward_ntt_mont_avx2_mt(db, N, p, p_inv_neg, roots.fwd.data(), roots.fwd_offset.data(), T); pointwise_mul_mont_mt(da, db, N, p, p_inv_neg, T); inverse_ntt_mont_avx2_mt(da, N, p, p_inv_neg, roots.inv.data(), roots.inv_offset.data(), roots.inv_n, T); } // AVX2 multi-threaded pipeline: sqr inline void ntt_sqr_pipeline_avx2_mt(uint64_t* da, size_t N, uint64_t p, uint64_t p_inv_neg, const NttRootsLayered& roots, int T) { forward_ntt_mont_avx2_mt(da, N, p, p_inv_neg, roots.fwd.data(), roots.fwd_offset.data(), T); pointwise_sqr_mont_mt(da, N, p, p_inv_neg, T); inverse_ntt_mont_avx2_mt(da, N, p, p_inv_neg, roots.inv.data(), roots.inv_offset.data(), roots.inv_n, T); } #endif // ================================================================ // YC-2: Fused Multiply-Add (rp = a*b + c*d) // ================================================================ // 4 入力を forward NTT → pointwise MAC → 1 回の inverse NTT + CRT // 通常の 2 回乗算 + 加算と比べ、inverse NTT 1 回 + CRT 1 回 + 大数加算を節約 // rp[0..rn-1] = ap[0..an-1]*bp[0..bn-1] + cp[0..cn-1]*dp[0..dn-1] inline void mul_add_prime_ntt(uint64_t* rp, size_t rn, const uint64_t* ap, size_t an, const uint64_t* bp, size_t bn, const uint64_t* cp, size_t cn, const uint64_t* dp, size_t dn) { auto& ctx = getPrimeNttContext(); // NTT 長は両積の最大サイズに合わせる size_t prod1_n = an + bn; size_t prod2_n = cn + dn; size_t max_prod = std::max(prod1_n, prod2_n); size_t N = next_power_of_2(max_prod); // 4 入力 × 3 素数 = 12N ワード thread_local std::vector work; size_t total = 12 * N; if (work.size() < total) work.resize(total); uint64_t* da[3]; // a の NTT uint64_t* db[3]; // b の NTT uint64_t* dc[3]; // c の NTT uint64_t* dd[3]; // d の NTT for (int k = 0; k < 3; ++k) { da[k] = work.data() + k * 4 * N; db[k] = da[k] + N; dc[k] = db[k] + N; dd[k] = dc[k] + N; } // 入力分解: mod p → Montgomery 形式 for (int k = 0; k < 
3; ++k) { uint64_t p = ctx.primes[k].p; uint64_t pin = ctx.primes[k].p_inv_neg; uint64_t r2 = ctx.primes[k].r2_mod_p; // a for (size_t i = 0; i < an; ++i) da[k][i] = to_mont(ap[i] % p, p, pin, r2); for (size_t i = an; i < N; ++i) da[k][i] = 0; // b for (size_t i = 0; i < bn; ++i) db[k][i] = to_mont(bp[i] % p, p, pin, r2); for (size_t i = bn; i < N; ++i) db[k][i] = 0; // c for (size_t i = 0; i < cn; ++i) dc[k][i] = to_mont(cp[i] % p, p, pin, r2); for (size_t i = cn; i < N; ++i) dc[k][i] = 0; // d for (size_t i = 0; i < dn; ++i) dd[k][i] = to_mont(dp[i] % p, p, pin, r2); for (size_t i = dn; i < N; ++i) dd[k][i] = 0; } #ifdef __AVX2__ thread_local NttRootsLayered roots_avx[3]; for (int k = 0; k < 3; ++k) { if (roots_avx[k].N != N || roots_avx[k].p != ctx.primes[k].p) { roots_avx[k].build(ctx.primes[k], N); } } bool parallel = (N >= PRIME_NTT_PARALLEL_THRESHOLD); if (parallel) { const NttRootsLayered* rptr[3] = {&roots_avx[0], &roots_avx[1], &roots_avx[2]}; // 素数 0, 1 をワーカーで、素数 2 をメインスレッドで処理 auto f0 = calx::threadPool().submit([&, N, rptr]{ uint64_t p = ctx.primes[0].p, pin = ctx.primes[0].p_inv_neg; const auto& r = *rptr[0]; forward_ntt_mont_avx2(da[0], N, p, pin, r.fwd.data(), r.fwd_offset.data()); forward_ntt_mont_avx2(db[0], N, p, pin, r.fwd.data(), r.fwd_offset.data()); forward_ntt_mont_avx2(dc[0], N, p, pin, r.fwd.data(), r.fwd_offset.data()); forward_ntt_mont_avx2(dd[0], N, p, pin, r.fwd.data(), r.fwd_offset.data()); pointwise_mul_mont_avx2(da[0], db[0], N, p, pin); pointwise_mac_mont_avx2(da[0], dc[0], dd[0], N, p, pin); inverse_ntt_mont_avx2(da[0], N, p, pin, r.inv.data(), r.inv_offset.data(), r.inv_n); }); auto f1 = calx::threadPool().submit([&, N, rptr]{ uint64_t p = ctx.primes[1].p, pin = ctx.primes[1].p_inv_neg; const auto& r = *rptr[1]; forward_ntt_mont_avx2(da[1], N, p, pin, r.fwd.data(), r.fwd_offset.data()); forward_ntt_mont_avx2(db[1], N, p, pin, r.fwd.data(), r.fwd_offset.data()); forward_ntt_mont_avx2(dc[1], N, p, pin, r.fwd.data(), 
r.fwd_offset.data()); forward_ntt_mont_avx2(dd[1], N, p, pin, r.fwd.data(), r.fwd_offset.data()); pointwise_mul_mont_avx2(da[1], db[1], N, p, pin); pointwise_mac_mont_avx2(da[1], dc[1], dd[1], N, p, pin); inverse_ntt_mont_avx2(da[1], N, p, pin, r.inv.data(), r.inv_offset.data(), r.inv_n); }); { uint64_t p = ctx.primes[2].p, pin = ctx.primes[2].p_inv_neg; const auto& r = *rptr[2]; forward_ntt_mont_avx2(da[2], N, p, pin, r.fwd.data(), r.fwd_offset.data()); forward_ntt_mont_avx2(db[2], N, p, pin, r.fwd.data(), r.fwd_offset.data()); forward_ntt_mont_avx2(dc[2], N, p, pin, r.fwd.data(), r.fwd_offset.data()); forward_ntt_mont_avx2(dd[2], N, p, pin, r.fwd.data(), r.fwd_offset.data()); pointwise_mul_mont_avx2(da[2], db[2], N, p, pin); pointwise_mac_mont_avx2(da[2], dc[2], dd[2], N, p, pin); inverse_ntt_mont_avx2(da[2], N, p, pin, r.inv.data(), r.inv_offset.data(), r.inv_n); } f0.get(); f1.get(); } else { for (int k = 0; k < 3; ++k) { uint64_t p = ctx.primes[k].p, pin = ctx.primes[k].p_inv_neg; const auto& r = roots_avx[k]; forward_ntt_mont_avx2(da[k], N, p, pin, r.fwd.data(), r.fwd_offset.data()); forward_ntt_mont_avx2(db[k], N, p, pin, r.fwd.data(), r.fwd_offset.data()); forward_ntt_mont_avx2(dc[k], N, p, pin, r.fwd.data(), r.fwd_offset.data()); forward_ntt_mont_avx2(dd[k], N, p, pin, r.fwd.data(), r.fwd_offset.data()); pointwise_mul_mont_avx2(da[k], db[k], N, p, pin); pointwise_mac_mont_avx2(da[k], dc[k], dd[k], N, p, pin); inverse_ntt_mont_avx2(da[k], N, p, pin, r.inv.data(), r.inv_offset.data(), r.inv_n); } } #else thread_local NttRootsMont roots[3]; for (int k = 0; k < 3; ++k) { if (roots[k].N != N || roots[k].p != ctx.primes[k].p) { roots[k].build(ctx.primes[k], N); } } bool parallel = (N >= PRIME_NTT_PARALLEL_THRESHOLD); if (parallel) { auto f0 = calx::threadPool().submit([&, N]{ uint64_t p = ctx.primes[0].p, pin = ctx.primes[0].p_inv_neg; forward_ntt_mont(da[0], N, p, pin, roots[0].roots.data()); forward_ntt_mont(db[0], N, p, pin, roots[0].roots.data()); 
forward_ntt_mont(dc[0], N, p, pin, roots[0].roots.data()); forward_ntt_mont(dd[0], N, p, pin, roots[0].roots.data()); pointwise_mul_mont(da[0], db[0], N, p, pin); pointwise_mac_mont(da[0], dc[0], dd[0], N, p, pin); inverse_ntt_mont(da[0], N, p, pin, roots[0].inv_roots.data(), roots[0].inv_n); }); auto f1 = calx::threadPool().submit([&, N]{ uint64_t p = ctx.primes[1].p, pin = ctx.primes[1].p_inv_neg; forward_ntt_mont(da[1], N, p, pin, roots[1].roots.data()); forward_ntt_mont(db[1], N, p, pin, roots[1].roots.data()); forward_ntt_mont(dc[1], N, p, pin, roots[1].roots.data()); forward_ntt_mont(dd[1], N, p, pin, roots[1].roots.data()); pointwise_mul_mont(da[1], db[1], N, p, pin); pointwise_mac_mont(da[1], dc[1], dd[1], N, p, pin); inverse_ntt_mont(da[1], N, p, pin, roots[1].inv_roots.data(), roots[1].inv_n); }); { uint64_t p = ctx.primes[2].p, pin = ctx.primes[2].p_inv_neg; forward_ntt_mont(da[2], N, p, pin, roots[2].roots.data()); forward_ntt_mont(db[2], N, p, pin, roots[2].roots.data()); forward_ntt_mont(dc[2], N, p, pin, roots[2].roots.data()); forward_ntt_mont(dd[2], N, p, pin, roots[2].roots.data()); pointwise_mul_mont(da[2], db[2], N, p, pin); pointwise_mac_mont(da[2], dc[2], dd[2], N, p, pin); inverse_ntt_mont(da[2], N, p, pin, roots[2].inv_roots.data(), roots[2].inv_n); } f0.get(); f1.get(); } else { for (int k = 0; k < 3; ++k) { uint64_t p = ctx.primes[k].p, pin = ctx.primes[k].p_inv_neg; forward_ntt_mont(da[k], N, p, pin, roots[k].roots.data()); forward_ntt_mont(db[k], N, p, pin, roots[k].roots.data()); forward_ntt_mont(dc[k], N, p, pin, roots[k].roots.data()); forward_ntt_mont(dd[k], N, p, pin, roots[k].roots.data()); pointwise_mul_mont(da[k], db[k], N, p, pin); pointwise_mac_mont(da[k], dc[k], dd[k], N, p, pin); inverse_ntt_mont(da[k], N, p, pin, roots[k].inv_roots.data(), roots[k].inv_n); } } #endif // CRT 復元 crt_recompose(rp, rn, da[0], da[1], da[2], max_prod, ctx.crt); } // ================================================================ // メイン乗算関数 // 
================================================================ // rp[0..an+bn-1] = ap[0..an-1] × bp[0..bn-1] inline void mul_prime_ntt(uint64_t* rp, const uint64_t* ap, size_t an, const uint64_t* bp, size_t bn) { if (an < bn) { std::swap(ap, bp); std::swap(an, bn); } auto& ctx = getPrimeNttContext(); size_t rn = an + bn; size_t N = next_power_of_2(rn); // 3 素数分の NTT データ配列を確保 // 各素数: data_a[N] + data_b[N] = 2N ワード、計 6N ワード thread_local std::vector work; size_t total = 6 * N; if (work.size() < total) work.resize(total); uint64_t* da[3]; uint64_t* db[3]; for (int k = 0; k < 3; ++k) { da[k] = work.data() + k * 2 * N; db[k] = da[k] + N; } // 入力分解: a[i] mod p → Montgomery 形式に変換 for (int k = 0; k < 3; ++k) { uint64_t p = ctx.primes[k].p; uint64_t pin = ctx.primes[k].p_inv_neg; uint64_t r2 = ctx.primes[k].r2_mod_p; for (size_t i = 0; i < an; ++i) da[k][i] = to_mont(ap[i] % p, p, pin, r2); for (size_t i = an; i < N; ++i) da[k][i] = 0; for (size_t i = 0; i < bn; ++i) db[k][i] = to_mont(bp[i] % p, p, pin, r2); for (size_t i = bn; i < N; ++i) db[k][i] = 0; } #ifdef __AVX2__ // AVX2: レイヤー根テーブル (thread_local キャッシュ) thread_local NttRootsLayered roots_avx[3]; for (int k = 0; k < 3; ++k) { if (roots_avx[k].N != N || roots_avx[k].p != ctx.primes[k].p) { roots_avx[k].build(ctx.primes[k], N); } } // NTT パイプライン (AVX2): 3 素数は完全独立 → 並列実行可能 bool use_mt = (N >= MULTI_THREAD_NTT_THRESHOLD); bool parallel = (N >= PRIME_NTT_PARALLEL_THRESHOLD); if (use_mt) { // YC-10: 素数間並列 + 素数内マルチスレッド unsigned hw = std::thread::hardware_concurrency(); int T = std::max(2, static_cast(hw > 3 ? 
(hw - 1) / 3 : 1)); const NttRootsLayered* rpp[3] = {&roots_avx[0], &roots_avx[1], &roots_avx[2]}; auto f0 = calx::threadPool().submit([&, N, T, rpp]{ ntt_mul_pipeline_avx2_mt(da[0], db[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, *rpp[0], T); }); auto f1 = calx::threadPool().submit([&, N, T, rpp]{ ntt_mul_pipeline_avx2_mt(da[1], db[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, *rpp[1], T); }); ntt_mul_pipeline_avx2_mt(da[2], db[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, *rpp[2], T); f0.get(); f1.get(); } else if (parallel) { // ワーカーから thread_local を直接参照しないため構造体コピー // (NttRootsLayered は vector で内部バッファを持つ → ポインタで渡す) const NttRootsLayered* rpp[3] = {&roots_avx[0], &roots_avx[1], &roots_avx[2]}; auto f0 = calx::threadPool().submit([&, N, rpp]{ ntt_mul_pipeline_avx2(da[0], db[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, *rpp[0]); }); auto f1 = calx::threadPool().submit([&, N, rpp]{ ntt_mul_pipeline_avx2(da[1], db[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, *rpp[1]); }); ntt_mul_pipeline_avx2(da[2], db[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, *rpp[2]); f0.get(); f1.get(); } else { for (int k = 0; k < 3; ++k) { ntt_mul_pipeline_avx2(da[k], db[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, roots_avx[k]); } } #else // スカラー Montgomery NTT thread_local NttRootsMont roots[3]; for (int k = 0; k < 3; ++k) { if (roots[k].N != N || roots[k].p != ctx.primes[k].p) { roots[k].build(ctx.primes[k], N); } } const uint64_t* fwd_r[3], *inv_r[3]; uint64_t inv_n_vals[3]; for (int k = 0; k < 3; ++k) { fwd_r[k] = roots[k].roots.data(); inv_r[k] = roots[k].inv_roots.data(); inv_n_vals[k] = roots[k].inv_n; } bool use_mt = (N >= MULTI_THREAD_NTT_THRESHOLD); bool parallel = (N >= PRIME_NTT_PARALLEL_THRESHOLD); if (use_mt) { unsigned hw = std::thread::hardware_concurrency(); int T = std::max(2, static_cast(hw > 3 ? 
(hw - 1) / 3 : 1)); auto f0 = calx::threadPool().submit([&, N, T]{ ntt_mul_pipeline_mt(da[0], db[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, fwd_r[0], inv_r[0], inv_n_vals[0], T); }); auto f1 = calx::threadPool().submit([&, N, T]{ ntt_mul_pipeline_mt(da[1], db[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, fwd_r[1], inv_r[1], inv_n_vals[1], T); }); ntt_mul_pipeline_mt(da[2], db[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, fwd_r[2], inv_r[2], inv_n_vals[2], T); f0.get(); f1.get(); } else if (parallel) { auto f0 = calx::threadPool().submit([&, N]{ ntt_mul_pipeline_mont(da[0], db[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, fwd_r[0], inv_r[0], inv_n_vals[0]); }); auto f1 = calx::threadPool().submit([&, N]{ ntt_mul_pipeline_mont(da[1], db[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, fwd_r[1], inv_r[1], inv_n_vals[1]); }); ntt_mul_pipeline_mont(da[2], db[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, fwd_r[2], inv_r[2], inv_n_vals[2]); f0.get(); f1.get(); } else { for (int k = 0; k < 3; ++k) { ntt_mul_pipeline_mont(da[k], db[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, fwd_r[k], inv_r[k], inv_n_vals[k]); } } #endif // CRT 復元 + キャリー伝搬 crt_recompose(rp, rn, da[0], da[1], da[2], rn, ctx.crt); } // ================================================================ // YC-1d: Middle Product (中間積) // ================================================================ // // ap[0..an-1] × bp[0..bn-1] のフル積 c[0..an+bn-2] のうち、 // 中間部分 c[bn-1..an-1] (計 an-bn+1 リム) を rp に書き出す。 // // 数学的根拠: // NTT 長 N = next_pow2(an) で循環畳み込みを行うと、位置 k での wrap-around は // i+j = k+N (i < an, j < bn) を満たす項。k >= bn-1 のとき k+N >= bn-1+an > an+bn-2 // より i+j <= an-1+bn-1 = an+bn-2 < k+N なので wrap-around 項は存在しない。 // 従って位置 bn-1..an-1 は正確に得られる。 // // 用途: Newton 除算の残差計算 — bx ≈ 1 の residual = 1 - bx を // フル積 (NTT長 an+bn) の代わりに短い NTT 長 (an) で計算。 // // 前提: an >= bn >= 1 // 出力: rp[0..an-bn] = c[bn-1..an-1] (an-bn+1 リム) inline void middle_product_prime_ntt(uint64_t* rp, size_t rn, const uint64_t* 
ap, size_t an, const uint64_t* bp, size_t bn) { // rn = an - bn + 1 (呼び出し元で保証) auto& ctx = getPrimeNttContext(); // NTT 長: next_pow2(an) — フル積なら next_pow2(an+bn) が必要だが、 // middle product では an で十分 (上記の数学的根拠) size_t N = next_power_of_2(an); // 3 素数分のワーク: da[N] + db[N] × 3 = 6N thread_local std::vector work_mp; size_t total = 6 * N; if (work_mp.size() < total) work_mp.resize(total); uint64_t* da[3]; uint64_t* db[3]; for (int k = 0; k < 3; ++k) { da[k] = work_mp.data() + k * 2 * N; db[k] = da[k] + N; } // 入力分解: a → Montgomery 形式 (そのまま) // b → 反転して Montgomery 形式に変換 // b_rev[j] = b[bn-1-j] により、循環畳み込みの位置 k が // フル積の位置 k + (bn-1) に対応する for (int k = 0; k < 3; ++k) { uint64_t p = ctx.primes[k].p; uint64_t pin = ctx.primes[k].p_inv_neg; uint64_t r2 = ctx.primes[k].r2_mod_p; for (size_t i = 0; i < an; ++i) da[k][i] = to_mont(ap[i] % p, p, pin, r2); for (size_t i = an; i < N; ++i) da[k][i] = 0; // b を反転して配置 for (size_t i = 0; i < bn; ++i) db[k][i] = to_mont(bp[bn - 1 - i] % p, p, pin, r2); for (size_t i = bn; i < N; ++i) db[k][i] = 0; } #ifdef __AVX2__ thread_local NttRootsLayered roots_avx_mp[3]; for (int k = 0; k < 3; ++k) { if (roots_avx_mp[k].N != N || roots_avx_mp[k].p != ctx.primes[k].p) { roots_avx_mp[k].build(ctx.primes[k], N); } } bool parallel = (N >= PRIME_NTT_PARALLEL_THRESHOLD); if (parallel) { const NttRootsLayered* rp_roots[3] = {&roots_avx_mp[0], &roots_avx_mp[1], &roots_avx_mp[2]}; auto f0 = calx::threadPool().submit([&, N, rp_roots]{ ntt_mul_pipeline_avx2(da[0], db[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, *rp_roots[0]); }); auto f1 = calx::threadPool().submit([&, N, rp_roots]{ ntt_mul_pipeline_avx2(da[1], db[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, *rp_roots[1]); }); ntt_mul_pipeline_avx2(da[2], db[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, *rp_roots[2]); f0.get(); f1.get(); } else { for (int k = 0; k < 3; ++k) { ntt_mul_pipeline_avx2(da[k], db[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, roots_avx_mp[k]); } } #else thread_local 
NttRootsMont roots_mp[3]; for (int k = 0; k < 3; ++k) { if (roots_mp[k].N != N || roots_mp[k].p != ctx.primes[k].p) { roots_mp[k].build(ctx.primes[k], N); } } const uint64_t* fwd_r[3], *inv_r[3]; uint64_t inv_n_vals[3]; for (int k = 0; k < 3; ++k) { fwd_r[k] = roots_mp[k].roots.data(); inv_r[k] = roots_mp[k].inv_roots.data(); inv_n_vals[k] = roots_mp[k].inv_n; } bool parallel = (N >= PRIME_NTT_PARALLEL_THRESHOLD); if (parallel) { auto f0 = calx::threadPool().submit([&, N]{ ntt_mul_pipeline_mont(da[0], db[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, fwd_r[0], inv_r[0], inv_n_vals[0]); }); auto f1 = calx::threadPool().submit([&, N]{ ntt_mul_pipeline_mont(da[1], db[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, fwd_r[1], inv_r[1], inv_n_vals[1]); }); ntt_mul_pipeline_mont(da[2], db[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, fwd_r[2], inv_r[2], inv_n_vals[2]); f0.get(); f1.get(); } else { for (int k = 0; k < 3; ++k) { ntt_mul_pipeline_mont(da[k], db[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, fwd_r[k], inv_r[k], inv_n_vals[k]); } } #endif // CRT 復元: NTT 長 N 分を一時バッファに復元し、位置 0..rn-1 を抽出 // (循環畳み込みの位置 0 がフル積の位置 bn-1 に対応) // N が大きい場合でもフル積長 (an+bn) は不要で N 分で十分 thread_local std::vector crt_buf; if (crt_buf.size() < N) crt_buf.resize(N); crt_recompose(crt_buf.data(), N, da[0], da[1], da[2], N, ctx.crt); // 位置 0..rn-1 を出力にコピー (= フル積の c[bn-1..an-1]) std::memcpy(rp, crt_buf.data(), rn * sizeof(uint64_t)); } // ================================================================ // 自乗関数 // ================================================================ // rp[0..2n-1] = ap[0..n-1]² inline void sqr_prime_ntt(uint64_t* rp, const uint64_t* ap, size_t an) { auto& ctx = getPrimeNttContext(); size_t rn = 2 * an; size_t N = next_power_of_2(rn); thread_local std::vector work; size_t total = 3 * N; if (work.size() < total) work.resize(total); uint64_t* da[3]; for (int k = 0; k < 3; ++k) { da[k] = work.data() + k * N; } // 入力分解 → Montgomery 形式 for (int k = 0; k < 3; ++k) { 
uint64_t p = ctx.primes[k].p; uint64_t pin = ctx.primes[k].p_inv_neg; uint64_t r2 = ctx.primes[k].r2_mod_p; for (size_t i = 0; i < an; ++i) da[k][i] = to_mont(ap[i] % p, p, pin, r2); for (size_t i = an; i < N; ++i) da[k][i] = 0; } #ifdef __AVX2__ thread_local NttRootsLayered roots_avx[3]; for (int k = 0; k < 3; ++k) { if (roots_avx[k].N != N || roots_avx[k].p != ctx.primes[k].p) roots_avx[k].build(ctx.primes[k], N); } bool use_mt = (N >= MULTI_THREAD_NTT_THRESHOLD); bool parallel = (N >= PRIME_NTT_PARALLEL_THRESHOLD); if (use_mt) { unsigned hw = std::thread::hardware_concurrency(); int T = std::max(2, static_cast(hw > 3 ? (hw - 1) / 3 : 1)); const NttRootsLayered* rpp[3] = {&roots_avx[0], &roots_avx[1], &roots_avx[2]}; auto f0 = calx::threadPool().submit([&, N, T, rpp]{ ntt_sqr_pipeline_avx2_mt(da[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, *rpp[0], T); }); auto f1 = calx::threadPool().submit([&, N, T, rpp]{ ntt_sqr_pipeline_avx2_mt(da[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, *rpp[1], T); }); ntt_sqr_pipeline_avx2_mt(da[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, *rpp[2], T); f0.get(); f1.get(); } else if (parallel) { const NttRootsLayered* rpp[3] = {&roots_avx[0], &roots_avx[1], &roots_avx[2]}; auto f0 = calx::threadPool().submit([&, N, rpp]{ ntt_sqr_pipeline_avx2(da[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, *rpp[0]); }); auto f1 = calx::threadPool().submit([&, N, rpp]{ ntt_sqr_pipeline_avx2(da[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, *rpp[1]); }); ntt_sqr_pipeline_avx2(da[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, *rpp[2]); f0.get(); f1.get(); } else { for (int k = 0; k < 3; ++k) ntt_sqr_pipeline_avx2(da[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, roots_avx[k]); } #else thread_local NttRootsMont roots[3]; for (int k = 0; k < 3; ++k) { if (roots[k].N != N || roots[k].p != ctx.primes[k].p) roots[k].build(ctx.primes[k], N); } const uint64_t* fwd_r[3], *inv_r[3]; uint64_t inv_n_vals[3]; for (int k = 0; k < 3; ++k) { 
fwd_r[k] = roots[k].roots.data(); inv_r[k] = roots[k].inv_roots.data(); inv_n_vals[k] = roots[k].inv_n; } bool use_mt = (N >= MULTI_THREAD_NTT_THRESHOLD); bool parallel = (N >= PRIME_NTT_PARALLEL_THRESHOLD); if (use_mt) { unsigned hw = std::thread::hardware_concurrency(); int T = std::max(2, static_cast(hw > 3 ? (hw - 1) / 3 : 1)); auto f0 = calx::threadPool().submit([&, N, T]{ ntt_sqr_pipeline_mt(da[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, fwd_r[0], inv_r[0], inv_n_vals[0], T); }); auto f1 = calx::threadPool().submit([&, N, T]{ ntt_sqr_pipeline_mt(da[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, fwd_r[1], inv_r[1], inv_n_vals[1], T); }); ntt_sqr_pipeline_mt(da[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, fwd_r[2], inv_r[2], inv_n_vals[2], T); f0.get(); f1.get(); } else if (parallel) { auto f0 = calx::threadPool().submit([&, N]{ ntt_sqr_pipeline_mont(da[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, fwd_r[0], inv_r[0], inv_n_vals[0]); }); auto f1 = calx::threadPool().submit([&, N]{ ntt_sqr_pipeline_mont(da[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, fwd_r[1], inv_r[1], inv_n_vals[1]); }); ntt_sqr_pipeline_mont(da[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, fwd_r[2], inv_r[2], inv_n_vals[2]); f0.get(); f1.get(); } else { for (int k = 0; k < 3; ++k) ntt_sqr_pipeline_mont(da[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, fwd_r[k], inv_r[k], inv_n_vals[k]); } #endif // CRT 復元 + キャリー伝搬 crt_recompose(rp, rn, da[0], da[1], da[2], rn, ctx.crt); } // ================================================================ // NTT キャッシュ: 定数オペランドの forward NTT を保存 // ================================================================ // 同じ値で繰り返し乗算する場合、forward NTT を 1 回だけ計算して // キャッシュし、以降は pointwise_mul + INTT のみで乗算する。 // 通常の乗算では NTT(A) + NTT(B) + pointwise + INTT の 4 ステップが必要だが、 // キャッシュ使用時は NTT(A) + pointwise + INTT の 3 ステップで済む。 struct NttCache { std::vector storage_; uint64_t* data[3] = {}; // 3 素数分の forward NTT 結果 (Montgomery 形式) size_t ntt_len = 0; // 
キャッシュされた NTT 長 size_t orig_limbs = 0; // 元のリム数 void invalidate() { ntt_len = 0; } bool valid(size_t N) const { return ntt_len == N && N > 0; } // bp[0..bn-1] の forward NTT を計算してキャッシュ // roots は呼び出し元の thread_local を共有 (roots 再構築を回避) #ifdef __AVX2__ void build(const uint64_t* bp, size_t bn, size_t N, NttRootsLayered (&roots_ext)[3]) { #else void build(const uint64_t* bp, size_t bn, size_t N, NttRootsMont (&roots_ext)[3]) { #endif auto& ctx = getPrimeNttContext(); ntt_len = N; orig_limbs = bn; storage_.resize(3 * N); for (int k = 0; k < 3; ++k) data[k] = storage_.data() + k * N; // 入力を Montgomery 形式に変換 for (int k = 0; k < 3; ++k) { uint64_t p = ctx.primes[k].p; uint64_t pin = ctx.primes[k].p_inv_neg; uint64_t r2 = ctx.primes[k].r2_mod_p; for (size_t i = 0; i < bn; ++i) data[k][i] = to_mont(bp[i] % p, p, pin, r2); for (size_t i = bn; i < N; ++i) data[k][i] = 0; } // forward NTT (3 素数並列) — 呼び出し元の roots を使用 bool parallel = (N >= PRIME_NTT_PARALLEL_THRESHOLD); #ifdef __AVX2__ if (parallel) { const NttRootsLayered* rp[3] = {&roots_ext[0], &roots_ext[1], &roots_ext[2]}; auto f0 = calx::threadPool().submit([&, N, rp]{ forward_ntt_mont_avx2(data[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, rp[0]->fwd.data(), rp[0]->fwd_offset.data()); }); auto f1 = calx::threadPool().submit([&, N, rp]{ forward_ntt_mont_avx2(data[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, rp[1]->fwd.data(), rp[1]->fwd_offset.data()); }); forward_ntt_mont_avx2(data[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, roots_ext[2].fwd.data(), roots_ext[2].fwd_offset.data()); f0.get(); f1.get(); } else { for (int k = 0; k < 3; ++k) forward_ntt_mont_avx2(data[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, roots_ext[k].fwd.data(), roots_ext[k].fwd_offset.data()); } #else const uint64_t* fwd_r[3]; for (int k = 0; k < 3; ++k) fwd_r[k] = roots_ext[k].roots.data(); if (parallel) { auto f0 = calx::threadPool().submit([&, N]{ forward_ntt_mont(data[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, fwd_r[0]); }); 
auto f1 = calx::threadPool().submit([&, N]{ forward_ntt_mont(data[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, fwd_r[1]); }); forward_ntt_mont(data[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, fwd_r[2]); f0.get(); f1.get(); } else { for (int k = 0; k < 3; ++k) forward_ntt_mont(data[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, fwd_r[k]); } #endif } }; // キャッシュ付き乗算: rp = ap × bp (bp の forward NTT はキャッシュから取得) // cache が無効または NTT 長が合わない場合は自動で build する inline void mul_prime_ntt_cached(uint64_t* rp, const uint64_t* ap, size_t an, const uint64_t* bp, size_t bn, NttCache& cache) { auto& ctx = getPrimeNttContext(); size_t rn = an + bn; size_t N = next_power_of_2(rn); // A の NTT データ配列を確保 (3 素数分) thread_local std::vector cached_work; size_t total = 3 * N; if (cached_work.size() < total) cached_work.resize(total); uint64_t* da[3]; for (int k = 0; k < 3; ++k) da[k] = cached_work.data() + k * N; // A を Montgomery 形式に変換 for (int k = 0; k < 3; ++k) { uint64_t p = ctx.primes[k].p; uint64_t pin = ctx.primes[k].p_inv_neg; uint64_t r2 = ctx.primes[k].r2_mod_p; for (size_t i = 0; i < an; ++i) da[k][i] = to_mont(ap[i] % p, p, pin, r2); for (size_t i = an; i < N; ++i) da[k][i] = 0; } // パイプライン: NTT(A) + pointwise_mul(A, cached_B) + INTT(A) bool parallel = (N >= PRIME_NTT_PARALLEL_THRESHOLD); #ifdef __AVX2__ thread_local NttRootsLayered cached_roots_avx[3]; for (int k = 0; k < 3; ++k) { if (cached_roots_avx[k].N != N || cached_roots_avx[k].p != ctx.primes[k].p) cached_roots_avx[k].build(ctx.primes[k], N); } // キャッシュ構築 (初回または NTT 長変更時) — roots を共有 if (!cache.valid(N)) { cache.build(bp, bn, N, cached_roots_avx); } // thread_local → raw pointer 抽出 (ワーカースレッド安全) const uint64_t* fwd_d[3], *inv_d[3]; const size_t* fwd_o[3], *inv_o[3]; uint64_t inv_n_vals[3]; for (int k = 0; k < 3; ++k) { fwd_d[k] = cached_roots_avx[k].fwd.data(); fwd_o[k] = cached_roots_avx[k].fwd_offset.data(); inv_d[k] = cached_roots_avx[k].inv.data(); inv_o[k] = cached_roots_avx[k].inv_offset.data(); inv_n_vals[k] 
= cached_roots_avx[k].inv_n; } if (parallel) { auto f0 = calx::threadPool().submit([&, N]{ forward_ntt_mont_avx2(da[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, fwd_d[0], fwd_o[0]); pointwise_mul_mont_avx2(da[0], cache.data[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg); inverse_ntt_mont_avx2(da[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, inv_d[0], inv_o[0], inv_n_vals[0]); }); auto f1 = calx::threadPool().submit([&, N]{ forward_ntt_mont_avx2(da[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, fwd_d[1], fwd_o[1]); pointwise_mul_mont_avx2(da[1], cache.data[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg); inverse_ntt_mont_avx2(da[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, inv_d[1], inv_o[1], inv_n_vals[1]); }); forward_ntt_mont_avx2(da[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, fwd_d[2], fwd_o[2]); pointwise_mul_mont_avx2(da[2], cache.data[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg); inverse_ntt_mont_avx2(da[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, inv_d[2], inv_o[2], inv_n_vals[2]); f0.get(); f1.get(); } else { for (int k = 0; k < 3; ++k) { forward_ntt_mont_avx2(da[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, fwd_d[k], fwd_o[k]); pointwise_mul_mont_avx2(da[k], cache.data[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg); inverse_ntt_mont_avx2(da[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, inv_d[k], inv_o[k], inv_n_vals[k]); } } #else thread_local NttRootsMont cached_roots[3]; for (int k = 0; k < 3; ++k) { if (cached_roots[k].N != N || cached_roots[k].p != ctx.primes[k].p) cached_roots[k].build(ctx.primes[k], N); } // キャッシュ構築 (初回または NTT 長変更時) — roots を共有 if (!cache.valid(N)) { cache.build(bp, bn, N, cached_roots); } const uint64_t* fwd_r[3], *inv_r[3]; uint64_t inv_n_vals[3]; for (int k = 0; k < 3; ++k) { fwd_r[k] = cached_roots[k].roots.data(); inv_r[k] = cached_roots[k].inv_roots.data(); inv_n_vals[k] = cached_roots[k].inv_n; } if (parallel) { auto f0 = calx::threadPool().submit([&, N]{ forward_ntt_mont(da[0], N, 
ctx.primes[0].p, ctx.primes[0].p_inv_neg, fwd_r[0]); pointwise_mul_mont(da[0], cache.data[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg); inverse_ntt_mont(da[0], N, ctx.primes[0].p, ctx.primes[0].p_inv_neg, inv_r[0], inv_n_vals[0]); }); auto f1 = calx::threadPool().submit([&, N]{ forward_ntt_mont(da[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, fwd_r[1]); pointwise_mul_mont(da[1], cache.data[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg); inverse_ntt_mont(da[1], N, ctx.primes[1].p, ctx.primes[1].p_inv_neg, inv_r[1], inv_n_vals[1]); }); forward_ntt_mont(da[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, fwd_r[2]); pointwise_mul_mont(da[2], cache.data[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg); inverse_ntt_mont(da[2], N, ctx.primes[2].p, ctx.primes[2].p_inv_neg, inv_r[2], inv_n_vals[2]); f0.get(); f1.get(); } else { for (int k = 0; k < 3; ++k) { forward_ntt_mont(da[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, fwd_r[k]); pointwise_mul_mont(da[k], cache.data[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg); inverse_ntt_mont(da[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, inv_r[k], inv_n_vals[k]); } } #endif // CRT 復元 crt_recompose(rp, rn, da[0], da[1], da[2], rn, ctx.crt); } // ================================================================ // NttInt: NTT ドメインでの整数表現 // ================================================================ // BS_NUM_PRIMES 素数 × N 係数の評価値を保持。 // BS merge ステップを NTT ドメインで実行し、 // 中間段階の NTT/INTT 変換を省略して高速化する。 // NTT ドメイン整数: BS_NUM_PRIMES 素数 × ntt_len の評価値 (Montgomery 形式) struct NttInt { std::vector storage_; uint64_t* data[BS_NUM_PRIMES] = {}; size_t ntt_len = 0; size_t orig_limbs = 0; // 元のリム数 (CRT 復元時の出力サイズ) NttInt() = default; void alloc(size_t N) { if (ntt_len == N && !storage_.empty()) return; ntt_len = N; storage_.resize(BS_NUM_PRIMES * N); for (int k = 0; k < BS_NUM_PRIMES; ++k) data[k] = storage_.data() + k * N; } bool empty() const { return ntt_len == 0; } }; // 
================================================================ // NttInt の変換関数 (BS_NUM_PRIMES 素数) // ================================================================ // リム列 → NTT ドメイン (forward NTT × BS_NUM_PRIMES 素数, Montgomery 形式) inline void ntt_forward(NttInt& r, const uint64_t* src, size_t n, size_t ntt_len) { auto& ctx = getNttBsContext(); r.alloc(ntt_len); r.orig_limbs = n; // Montgomery 形式の根テーブル (thread_local キャッシュ) thread_local NttRootsMont roots[BS_NUM_PRIMES]; for (int k = 0; k < BS_NUM_PRIMES; ++k) { if (roots[k].N != ntt_len || roots[k].p != ctx.primes[k].p) roots[k].build(ctx.primes[k], ntt_len); } // 入力を mod p → Montgomery 変換 + ゼロパディング for (int k = 0; k < BS_NUM_PRIMES; ++k) { uint64_t p = ctx.primes[k].p; uint64_t pin = ctx.primes[k].p_inv_neg; uint64_t r2 = ctx.primes[k].r2_mod_p; for (size_t i = 0; i < n; ++i) r.data[k][i] = to_mont(src[i] % p, p, pin, r2); std::memset(r.data[k] + n, 0, (ntt_len - n) * sizeof(uint64_t)); } // 根テーブルのポインタを事前取得 (ワーカーから thread_local を参照しない) const uint64_t* fwd_r[BS_NUM_PRIMES]; for (int k = 0; k < BS_NUM_PRIMES; ++k) fwd_r[k] = roots[k].roots.data(); // Forward NTT (BS_NUM_PRIMES 素数 — 並列実行) bool parallel = (ntt_len >= PRIME_NTT_PARALLEL_THRESHOLD); if (parallel) { std::future futures[BS_NUM_PRIMES - 1]; for (int k = 0; k < BS_NUM_PRIMES - 1; ++k) { futures[k] = calx::threadPool().submit([&, k, ntt_len]{ forward_ntt_mont(r.data[k], ntt_len, ctx.primes[k].p, ctx.primes[k].p_inv_neg, fwd_r[k]); }); } forward_ntt_mont(r.data[BS_NUM_PRIMES - 1], ntt_len, ctx.primes[BS_NUM_PRIMES - 1].p, ctx.primes[BS_NUM_PRIMES - 1].p_inv_neg, fwd_r[BS_NUM_PRIMES - 1]); for (int k = 0; k < BS_NUM_PRIMES - 1; ++k) futures[k].get(); } else { for (int k = 0; k < BS_NUM_PRIMES; ++k) forward_ntt_mont(r.data[k], ntt_len, ctx.primes[k].p, ctx.primes[k].p_inv_neg, fwd_r[k]); } } // NTT ドメイン → リム列 (inverse NTT × BS_NUM_PRIMES 素数 + CRT) inline void ntt_inverse(uint64_t* dst, size_t dst_limbs, const NttInt& src) { auto& ctx = getNttBsContext(); 
size_t N = src.ntt_len; // 根テーブル thread_local NttRootsMont roots[BS_NUM_PRIMES]; for (int k = 0; k < BS_NUM_PRIMES; ++k) { if (roots[k].N != N || roots[k].p != ctx.primes[k].p) roots[k].build(ctx.primes[k], N); } // INTT 用の作業コピー (src を破壊しないため) std::vector work_buf(BS_NUM_PRIMES * N); uint64_t* work[BS_NUM_PRIMES]; for (int k = 0; k < BS_NUM_PRIMES; ++k) { work[k] = work_buf.data() + k * N; std::memcpy(work[k], src.data[k], N * sizeof(uint64_t)); } // 根テーブルのポインタを事前取得 const uint64_t* inv_r[BS_NUM_PRIMES]; uint64_t inv_n_vals[BS_NUM_PRIMES]; for (int k = 0; k < BS_NUM_PRIMES; ++k) { inv_r[k] = roots[k].inv_roots.data(); inv_n_vals[k] = roots[k].inv_n; } // Inverse NTT (BS_NUM_PRIMES 素数 — 並列実行) bool parallel = (N >= PRIME_NTT_PARALLEL_THRESHOLD); if (parallel) { std::future futures[BS_NUM_PRIMES - 1]; for (int k = 0; k < BS_NUM_PRIMES - 1; ++k) { futures[k] = calx::threadPool().submit([&, k, N]{ inverse_ntt_mont(work[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, inv_r[k], inv_n_vals[k]); }); } inverse_ntt_mont(work[BS_NUM_PRIMES - 1], N, ctx.primes[BS_NUM_PRIMES - 1].p, ctx.primes[BS_NUM_PRIMES - 1].p_inv_neg, inv_r[BS_NUM_PRIMES - 1], inv_n_vals[BS_NUM_PRIMES - 1]); for (int k = 0; k < BS_NUM_PRIMES - 1; ++k) futures[k].get(); } else { for (int k = 0; k < BS_NUM_PRIMES; ++k) inverse_ntt_mont(work[k], N, ctx.primes[k].p, ctx.primes[k].p_inv_neg, inv_r[k], inv_n_vals[k]); } // CRT 復元 (5 素数) crt_recompose_n(dst, dst_limbs, work, N, ctx.crt); } // ================================================================ // 16-bit 分割版 NTT forward / inverse (保持、ただし現在未使用) // 5 素数化により 64-bit リムのまま深さ 2 まで安全。 // ================================================================ // 16-bit 分割版 forward NTT inline void ntt_forward_16(NttInt& r, const uint64_t* src, size_t n, size_t ntt_len) { auto& ctx = getNttBsContext(); r.alloc(ntt_len); r.orig_limbs = n; size_t n16 = n * 4; thread_local NttRootsMont roots[BS_NUM_PRIMES]; for (int k = 0; k < BS_NUM_PRIMES; ++k) { if (roots[k].N != 
ntt_len || roots[k].p != ctx.primes[k].p) roots[k].build(ctx.primes[k], ntt_len); } for (int k = 0; k < BS_NUM_PRIMES; ++k) { uint64_t p = ctx.primes[k].p; uint64_t pin = ctx.primes[k].p_inv_neg; uint64_t r2 = ctx.primes[k].r2_mod_p; for (size_t i = 0; i < n; ++i) { uint64_t w = src[i]; r.data[k][4*i ] = to_mont( w & 0xFFFF, p, pin, r2); r.data[k][4*i + 1] = to_mont((w >> 16) & 0xFFFF, p, pin, r2); r.data[k][4*i + 2] = to_mont((w >> 32) & 0xFFFF, p, pin, r2); r.data[k][4*i + 3] = to_mont((w >> 48) & 0xFFFF, p, pin, r2); } std::memset(r.data[k] + n16, 0, (ntt_len - n16) * sizeof(uint64_t)); } const uint64_t* fwd_r[BS_NUM_PRIMES]; for (int k = 0; k < BS_NUM_PRIMES; ++k) fwd_r[k] = roots[k].roots.data(); bool parallel = (ntt_len >= PRIME_NTT_PARALLEL_THRESHOLD); if (parallel) { std::future futures[BS_NUM_PRIMES - 1]; for (int k = 0; k < BS_NUM_PRIMES - 1; ++k) { futures[k] = calx::threadPool().submit([&, k, ntt_len]{ forward_ntt_mont(r.data[k], ntt_len, ctx.primes[k].p, ctx.primes[k].p_inv_neg, fwd_r[k]); }); } forward_ntt_mont(r.data[BS_NUM_PRIMES - 1], ntt_len, ctx.primes[BS_NUM_PRIMES - 1].p, ctx.primes[BS_NUM_PRIMES - 1].p_inv_neg, fwd_r[BS_NUM_PRIMES - 1]); for (int k = 0; k < BS_NUM_PRIMES - 1; ++k) futures[k].get(); } else { for (int k = 0; k < BS_NUM_PRIMES; ++k) forward_ntt_mont(r.data[k], ntt_len, ctx.primes[k].p, ctx.primes[k].p_inv_neg, fwd_r[k]); } } // 16-bit 分割版 inverse NTT inline void ntt_inverse_16(uint64_t* dst, size_t dst_limbs, const NttInt& src) { // 5 素数版: 通常の ntt_inverse を使用して CRT 復元後、 // base-2^16 パック処理が必要な場合のみこちらを使う。 // 現在は未使用 (64-bit リムのまま深さ 2 まで安全)。 ntt_inverse(dst, dst_limbs, src); } // ================================================================ // NttInt の pointwise 演算 (BS_NUM_PRIMES 素数) // ================================================================ // r = a ⊗ b (pointwise Montgomery 乗算) inline void ntt_pmul(NttInt& r, const NttInt& a, const NttInt& b) { auto& ctx = getNttBsContext(); size_t N = a.ntt_len; bool parallel = (N >= 
PRIME_NTT_PARALLEL_THRESHOLD); if (parallel) { std::future futures[BS_NUM_PRIMES - 1]; for (int k = 0; k < BS_NUM_PRIMES - 1; ++k) { futures[k] = calx::threadPool().submit([&, k]{ uint64_t p = ctx.primes[k].p; uint64_t pin = ctx.primes[k].p_inv_neg; for (size_t i = 0; i < N; ++i) r.data[k][i] = mont_mul(a.data[k][i], b.data[k][i], p, pin); }); } { constexpr int k = BS_NUM_PRIMES - 1; uint64_t p = ctx.primes[k].p; uint64_t pin = ctx.primes[k].p_inv_neg; for (size_t i = 0; i < N; ++i) r.data[k][i] = mont_mul(a.data[k][i], b.data[k][i], p, pin); } for (int k = 0; k < BS_NUM_PRIMES - 1; ++k) futures[k].get(); } else { for (int k = 0; k < BS_NUM_PRIMES; ++k) { uint64_t p = ctx.primes[k].p; uint64_t pin = ctx.primes[k].p_inv_neg; for (size_t i = 0; i < N; ++i) r.data[k][i] = mont_mul(a.data[k][i], b.data[k][i], p, pin); } } } // r = a + b (pointwise mod add) inline void ntt_padd(NttInt& r, const NttInt& a, const NttInt& b) { auto& ctx = getNttBsContext(); size_t N = a.ntt_len; for (int k = 0; k < BS_NUM_PRIMES; ++k) { uint64_t p = ctx.primes[k].p; for (size_t i = 0; i < N; ++i) r.data[k][i] = mod_add(a.data[k][i], b.data[k][i], p); } } // r = a - b (pointwise mod sub) inline void ntt_psub(NttInt& r, const NttInt& a, const NttInt& b) { auto& ctx = getNttBsContext(); size_t N = a.ntt_len; for (int k = 0; k < BS_NUM_PRIMES; ++k) { uint64_t p = ctx.primes[k].p; for (size_t i = 0; i < N; ++i) r.data[k][i] = mod_sub(a.data[k][i], b.data[k][i], p); } } // r += a ⊗ b (pointwise multiply-accumulate) inline void ntt_pmac(NttInt& r, const NttInt& a, const NttInt& b) { auto& ctx = getNttBsContext(); size_t N = a.ntt_len; bool parallel = (N >= PRIME_NTT_PARALLEL_THRESHOLD); if (parallel) { std::future futures[BS_NUM_PRIMES - 1]; for (int k = 0; k < BS_NUM_PRIMES - 1; ++k) { futures[k] = calx::threadPool().submit([&, k]{ uint64_t p = ctx.primes[k].p; uint64_t pin = ctx.primes[k].p_inv_neg; for (size_t i = 0; i < N; ++i) r.data[k][i] = mod_add(r.data[k][i], mont_mul(a.data[k][i], 
b.data[k][i], p, pin), p); });
        }
        {
            // Last prime runs on the calling thread.
            constexpr int k = BS_NUM_PRIMES - 1;
            uint64_t p = ctx.primes[k].p;
            uint64_t pin = ctx.primes[k].p_inv_neg;
            for (size_t i = 0; i < N; ++i)
                r.data[k][i] = mod_add(r.data[k][i], mont_mul(a.data[k][i], b.data[k][i], p, pin), p);
        }
        for (int k = 0; k < BS_NUM_PRIMES - 1; ++k) futures[k].get();
    } else {
        for (int k = 0; k < BS_NUM_PRIMES; ++k) {
            uint64_t p = ctx.primes[k].p;
            uint64_t pin = ctx.primes[k].p_inv_neg;
            for (size_t i = 0; i < N; ++i)
                r.data[k][i] = mod_add(r.data[k][i], mont_mul(a.data[k][i], b.data[k][i], p, pin), p);
        }
    }
}

// Pointwise negate of r: r[i] <- p - r[i] (0 stays 0).
inline void ntt_pneg(NttInt& r) {
    auto& ctx = getNttBsContext();
    size_t N = r.ntt_len;
    for (int k = 0; k < BS_NUM_PRIMES; ++k) {
        uint64_t p = ctx.primes[k].p;
        for (size_t i = 0; i < N; ++i)
            r.data[k][i] = (r.data[k][i] == 0) ? 0 : (p - r.data[k][i]);
    }
}

// ================================================================
// NTT-domain binary-splitting merge functions
// ================================================================

// For Chudnovsky / atanh series:
//   T = TL (x) QR + PL (x) TR
//   Q = QL (x) QR
//   P = PL (x) PR   (only when need_P)
inline void bs_merge2_ntt(NttInt& P, NttInt& Q, NttInt& T,
                          NttInt& PL, NttInt& QL, NttInt& TL,
                          NttInt& PR, NttInt& QR, NttInt& TR,
                          bool need_P, size_t ntt_len) {
    // T = TL (x) QR
    T.alloc(ntt_len);
    ntt_pmul(T, TL, QR);
    // T += PL (x) TR
    ntt_pmac(T, PL, TR);
    // Q = QL (x) QR
    Q.alloc(ntt_len);
    ntt_pmul(Q, QL, QR);
    // P = PL (x) PR
    if (need_P) {
        P.alloc(ntt_len);
        ntt_pmul(P, PL, PR);
    }
}

// For factorials:
//   T = TL (x) QR + TR   (P = 1, so PL*TR = TR)
//   Q = QL (x) QR
inline void fac_merge2_ntt(NttInt& Q, NttInt& T,
                           NttInt& QL, NttInt& TL,
                           NttInt& QR, NttInt& TR, size_t ntt_len) {
    // T = TL (x) QR + TR
    T.alloc(ntt_len);
    ntt_pmul(T, TL, QR);
    ntt_padd(T, T, TR);
    // Q = QL (x) QR
    Q.alloc(ntt_len);
    ntt_pmul(Q, QL, QR);
}

// For Euler-Mascheroni:
//   P = PL (x) PR,  Q = QL (x) QR,  D = DL (x) DR
//   B = BL (x) DR + DL (x) BR
//   T = TL (x) QR + PL (x) TR
//   S = SL (x) (QR (x) DR) + PL (x) (SR (x) DL + BL (x) (TR (x) DR))
inline void
euler_merge2_ntt(NttInt& P, NttInt& Q, NttInt& D, NttInt& B, NttInt& T, NttInt& S,
                 NttInt& PL, NttInt& QL, NttInt& DL, NttInt& BL, NttInt& TL, NttInt& SL,
                 NttInt& PR, NttInt& QR, NttInt& DR, NttInt& BR, NttInt& TR, NttInt& SR,
                 size_t ntt_len) {
    P.alloc(ntt_len);
    ntt_pmul(P, PL, PR);
    Q.alloc(ntt_len);
    ntt_pmul(Q, QL, QR);
    D.alloc(ntt_len);
    ntt_pmul(D, DL, DR);
    // B = BL (x) DR + DL (x) BR
    B.alloc(ntt_len);
    ntt_pmul(B, BL, DR);
    ntt_pmac(B, DL, BR);
    // T = TL (x) QR + PL (x) TR
    T.alloc(ntt_len);
    ntt_pmul(T, TL, QR);
    ntt_pmac(T, PL, TR);
    // S = SL (x) QR_DR + PL (x) (SR (x) DL + BL (x) TR_DR)
    // QR_DR = QR (x) DR  (temporary)
    NttInt QR_DR;
    QR_DR.alloc(ntt_len);
    ntt_pmul(QR_DR, QR, DR);
    // TR_DR = TR (x) DR  (temporary)
    NttInt TR_DR;
    TR_DR.alloc(ntt_len);
    ntt_pmul(TR_DR, TR, DR);
    // inner = SR (x) DL + BL (x) TR_DR
    NttInt inner;
    inner.alloc(ntt_len);
    ntt_pmul(inner, SR, DL);
    ntt_pmac(inner, BL, TR_DR);
    // S = SL (x) QR_DR + PL (x) inner
    S.alloc(ntt_len);
    ntt_pmul(S, SL, QR_DR);
    ntt_pmac(S, PL, inner);
}

} // namespace prime_ntt
} // namespace calx