// Copyright (C) 2026 Kiyotsugu Arai
// SPDX-License-Identifier: LGPL-3.0-or-later
//
// SmallPrimeNtt.hpp
// Multi-precision multiplication using five 32-bit NTT-friendly primes and
// Garner CRT reconstruction.
//
// Every prime fits in 32 bits, so a Montgomery multiplication needs only a
// light 32x32->64 multiply, which is cheaper than 64-bit Montgomery (two MULX
// per product). With AVX2, VPMULUDQ could compute 4 products in parallel and
// VPADDD/VPSUBD 8 additions/subtractions in parallel (a possible future
// optimization; the code below is scalar only).
//
// Exactness:
//   The CRT modulus is M = SP1*SP2*SP3*SP4*SP5 ~= 2^146.4. A coefficient of
//   the convolution of two 64-bit limb vectors is at most
//   min(an, bn) * (2^64 - 1)^2 < min(an, bn) * 2^128, and it must be < M to be
//   recovered uniquely, so a single NTT pass is exact only for
//   min(an, bn) < 2^18.4 (~350K limbs). SP_MAX_OPERAND_LIMBS = 2^18 (262144
//   limbs, ~5.05M decimal digits) is the conservative bound used here. The
//   NTT length is additionally limited to 2^SP_MAX_S = 2^23 points.
//   Products outside either limit are computed exactly by splitting the
//   operands into chunks that satisfy both limits.
//
// ================================================================
// Performance (vs 64-bit 3-prime NTT with AVX2, Ryzen 9 5950X)
// ================================================================
// The 32-bit 5-prime NTT is 4-11x slower than PrimeNTT (64-bit 3-prime AVX2).
//
// balanced mul:
//   n(limbs)   SmallPrime   PrimeNTT   ratio
//   500        445 us       87 us      5.1x
//   1000       989 us       181 us     5.5x
//   2000       2.11 ms      200 us     10.6x
//   5000       9.32 ms      886 us     10.5x
//   10000      20.6 ms      1.88 ms    11.0x
//   50000      89.1 ms      21.9 ms    4.1x
//
// sqr:
//   n(limbs)   SmallPrime   PrimeNTT   ratio
//   1000       660 us       147 us     4.5x
//   5000       5.43 ms      613 us     8.9x
//   20000      27.1 ms      3.01 ms    9.0x
//
// Main reasons for the gap:
//   1. Prime count: 5 vs 3 -> 5/3 = 1.67x more NTT transforms.
//   2. CRT reconstruction: Garner CRT is O(n*k^2); k=5 vs k=3 -> 2.78x heavier.
//   3. Input reduction: every 64-bit limb is reduced modulo each 32-bit prime
//      (5 reductions vs 3).
//   4. No AVX2: PrimeNTT is heavily optimized (AVX2 Montgomery + layered root
//      tables); SmallPrime is a scalar implementation only.
//   5. Same NTT length: even with 32-bit primes the NTT length is 2n (the
//      same), so the cheaper 32-bit Montgomery cannot offset the extra primes.
//
// In theory AVX2 VPMULLD (8x32) could speed up each per-prime NTT, but on
// Zen 3 VPMULLD is heavy (4 cycles, 1 port; VPMULUDQ: 3 cycles, 2 ports), and
// the 5-prime CRT overhead dominates, so beating the 64-bit 3-prime NTT is
// not realistic.

#pragma once

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <utility>
#include <vector>

#ifdef _MSC_VER
#include <intrin.h>
#endif

namespace calx {
namespace small_prime_ntt {

// ================================================================
// 32-bit NTT-friendly primes (p = k * 2^s + 1)
// ================================================================
// Well-established primes from competitive programming / numerical work;
// all are < 2^31.
// min(s) = 23 -> maximum NTT length 2^23 = 8M points.
static constexpr int NUM_PRIMES = 5;

//                                      k     s   g (primitive root)
constexpr uint32_t SP1 = 469762049;  //   7    26   3
constexpr uint32_t SP2 = 998244353;  // 119    23   3
constexpr uint32_t SP3 = 167772161;  //   5    25   3
constexpr uint32_t SP4 = 754974721;  //  45    24  11
constexpr uint32_t SP5 = 2013265921; //  15    27  31

constexpr uint32_t SP_G[NUM_PRIMES]   = { 3, 3, 3, 11, 31 };
constexpr uint32_t SP_ALL[NUM_PRIMES] = { SP1, SP2, SP3, SP4, SP5 };
constexpr int SP_MAX_S = 23; // min(26, 23, 25, 24, 27) = 23

// Largest min(an, bn) for which one NTT pass is exact:
// 2^18 * (2^64 - 1)^2 < 2^146 < M = SP1*SP2*SP3*SP4*SP5 (~2^146.4).
constexpr size_t SP_MAX_OPERAND_LIMBS = size_t(1) << 18;

// ================================================================
// 32-bit modular arithmetic
// ================================================================

// (a + b) mod p for a, b < p. Since p < 2^31, a + b cannot wrap.
inline uint32_t mod_add_32(uint32_t a, uint32_t b, uint32_t p) {
    uint32_t sum = a + b;
    return (sum >= p) ? (sum - p) : sum;
}

// (a - b) mod p for a, b < p (wrap-around is corrected by adding p).
inline uint32_t mod_sub_32(uint32_t a, uint32_t b, uint32_t p) {
    return (a >= b) ? (a - b) : (a - b + p);
}

// 32-bit Montgomery multiplication with R = 2^32: returns a*b*R^(-1) mod p.
//   REDC(T) = (T + (T_lo * p_inv_neg mod 2^32) * p) >> 32
//   T = a*b < p^2 < 2^62; p_inv_neg = -p^(-1) mod 2^32.
// The pre-subtraction result is < 2p, so one conditional subtract suffices.
inline uint32_t mont_mul_32(uint32_t a, uint32_t b, uint32_t p, uint32_t p_inv_neg) {
    uint64_t T = (uint64_t)a * b;
    uint32_t m = (uint32_t)T * p_inv_neg; // mod 2^32
    uint64_t mp = (uint64_t)m * p;
    uint64_t t = T + mp;                  // low 32 bits are zero by construction
    uint32_t result = (uint32_t)(t >> 32);
    return (result >= p) ? (result - p) : result;
}

// Computes -p^(-1) mod 2^32 for odd p via Newton iteration.
// Starting from x = 1 (correct to 1 bit), each step doubles the number of
// correct low bits: 1 -> 2 -> 4 -> 8 -> 16 -> 32 after five steps.
inline uint32_t compute_p_inv_neg_32(uint32_t p) {
    uint32_t x = 1;
    for (int i = 0; i < 5; i++)
        x *= 2 - p * x;
    return ~x + 1;
}

// Plain (non-Montgomery) a*b mod p; used for table setup.
inline uint32_t mod_mul_32(uint32_t a, uint32_t b, uint32_t p) {
    return (uint32_t)((uint64_t)a * b % p);
}

// base^exp mod p by square-and-multiply.
inline uint32_t mod_pow_32(uint32_t base, uint32_t exp, uint32_t p) {
    uint64_t result = 1;
    uint64_t b = base % p;
    while (exp > 0) {
        if (exp & 1) result = result * b % p;
        b = b * b % p;
        exp >>= 1;
    }
    return (uint32_t)result;
}

// a^(-1) mod p via Fermat's little theorem (p prime, a not divisible by p).
inline uint32_t mod_inv_32(uint32_t a, uint32_t p) {
    return mod_pow_32(a, p - 2, p);
}

// Converts a into Montgomery form: a*R mod p, computed as REDC(a * R^2).
inline uint32_t to_mont_32(uint32_t a, uint32_t p, uint32_t p_inv_neg, uint32_t r2_mod_p) {
    return mont_mul_32(a, r2_mod_p, p, p_inv_neg);
}

// Converts from Montgomery form back to a plain residue: REDC(a_mont * 1).
inline uint32_t from_mont_32(uint32_t a_mont, uint32_t p, uint32_t p_inv_neg) {
    return mont_mul_32(a_mont, 1, p, p_inv_neg);
}

// ================================================================
// 64x64 -> 128-bit multiply
// ================================================================

// Portable 64x64 -> 128 multiply built from four 32x32 -> 64 partial
// products. Returns the low 64 bits and stores the high 64 bits in *hi.
inline uint64_t mul_64x64_128_portable(uint64_t a, uint64_t b, uint64_t* hi) {
    const uint64_t a_lo = a & 0xFFFFFFFFu;
    const uint64_t a_hi = a >> 32;
    const uint64_t b_lo = b & 0xFFFFFFFFu;
    const uint64_t b_hi = b >> 32;
    const uint64_t ll = a_lo * b_lo;
    const uint64_t lh = a_lo * b_hi;
    const uint64_t hl = a_hi * b_lo;
    const uint64_t hh = a_hi * b_hi;
    // Middle column sum is at most 3 * (2^32 - 1), so it cannot overflow.
    const uint64_t mid = (ll >> 32) + (lh & 0xFFFFFFFFu) + (hl & 0xFFFFFFFFu);
    *hi = hh + (lh >> 32) + (hl >> 32) + (mid >> 32);
    return (mid << 32) | (ll & 0xFFFFFFFFu);
}

// Returns the low 64 bits of a*b and stores the high 64 bits in *hi.
// Uses _umul128 on MSVC x64, unsigned __int128 where the compiler provides
// it, and the portable version everywhere else (e.g. MSVC ARM64/x86,
// 32-bit GCC).
inline uint64_t mul_64x64_128(uint64_t a, uint64_t b, uint64_t* hi) {
#if defined(_MSC_VER) && defined(_M_X64)
    return _umul128(a, b, hi);
#elif defined(__SIZEOF_INT128__)
    const unsigned __int128 prod = (unsigned __int128)a * b;
    *hi = (uint64_t)(prod >> 64);
    return (uint64_t)prod;
#else
    return mul_64x64_128_portable(a, b, hi);
#endif
}

// ================================================================
// Per-prime NTT parameters
// ================================================================
struct SmallNttPrime {
    uint32_t p;
    uint32_t g;         // primitive root mod p
    uint32_t p_inv_neg; // -p^(-1) mod 2^32
    uint32_t r2_mod_p;  // R^2 mod p (R = 2^32)

    // Primitive N-th root of unity g^((p-1)/N); N must divide p - 1.
    uint32_t nth_root(size_t N) const {
        return mod_pow_32(g, (p - 1) / (uint32_t)N, p);
    }
};

// ================================================================
// NTT root tables (Montgomery form)
// ================================================================
struct SmallNttRoots {
    size_t N = 0;
    uint32_t p = 0;
    std::vector<uint32_t> roots;     // omega^i, i in [0, N), Montgomery form
    std::vector<uint32_t> inv_roots; // omega^(-i), i in [0, N), Montgomery form
    uint32_t inv_n = 0;              // N^(-1) mod p, Montgomery form

    // Fills the tables for NTT length ntt_size (a power of two <= 2^s).
    void build(const SmallNttPrime& prime, size_t ntt_size) {
        N = ntt_size;
        p = prime.p;
        roots.resize(N);
        inv_roots.resize(N);

        uint32_t omega = prime.nth_root(N);
        uint32_t omega_inv = mod_inv_32(omega, p);
        uint32_t inv_n_plain = mod_inv_32((uint32_t)N, p);
        inv_n = to_mont_32(inv_n_plain, p, prime.p_inv_neg, prime.r2_mod_p);

        // Store the powers in Montgomery form.
        uint32_t w = 1;
        for (size_t i = 0; i < N; ++i) {
            roots[i] = to_mont_32(w, p, prime.p_inv_neg, prime.r2_mod_p);
            w = mod_mul_32(w, omega, p);
        }
        w = 1;
        for (size_t i = 0; i < N; ++i) {
            inv_roots[i] = to_mont_32(w, p, prime.p_inv_neg, prime.r2_mod_p);
            w = mod_mul_32(w, omega_inv, p);
        }
    }
};

// ================================================================
// NTT transforms (DIF forward / DIT inverse, 32-bit Montgomery)
// ================================================================

// In-place forward NTT (decimation in frequency, no permutation pass): the
// output is left in bit-reversed order, which inverse_ntt_32 expects.
inline void forward_ntt_32(uint32_t* data, size_t N, uint32_t p, uint32_t p_inv_neg,
                           const uint32_t* roots) {
    for (size_t len = N; len >= 2; len >>= 1) {
        size_t half = len >> 1;
        size_t step = N / len;
        for (size_t i = 0; i < N; i += len) {
            for (size_t j = 0; j < half; ++j) {
                uint32_t u = data[i + j];
                uint32_t v = data[i + j + half];
                data[i + j] = mod_add_32(u, v, p);
                data[i + j + half] = mont_mul_32(
                    mod_sub_32(u, v, p), roots[j * step], p, p_inv_neg);
            }
        }
    }
}

// In-place inverse NTT (decimation in time) taking bit-reversed input and
// producing natural-order output, including the final 1/N scaling.
inline void inverse_ntt_32(uint32_t* data, size_t N, uint32_t p, uint32_t p_inv_neg,
                           const uint32_t* inv_roots, uint32_t inv_n) {
    for (size_t len = 2; len <= N; len <<= 1) {
        size_t half = len >> 1;
        size_t step = N / len;
        for (size_t i = 0; i < N; i += len) {
            for (size_t j = 0; j < half; ++j) {
                uint32_t u = data[i + j];
                uint32_t v = mont_mul_32(data[i + j + half],
                                         inv_roots[j * step], p, p_inv_neg);
                data[i + j] = mod_add_32(u, v, p);
                data[i + j + half] = mod_sub_32(u, v, p);
            }
        }
    }
    // 1/N scaling
    for (size_t i = 0; i < N; ++i)
        data[i] = mont_mul_32(data[i], inv_n, p, p_inv_neg);
}

// a[i] = a[i] * b[i] (Montgomery product) for i in [0, N).
inline void pointwise_mul_32(uint32_t* a, const uint32_t* b, size_t N,
                             uint32_t p, uint32_t p_inv_neg) {
    for (size_t i = 0; i < N; ++i)
        a[i] = mont_mul_32(a[i], b[i], p, p_inv_neg);
}

// ================================================================
// 5-prime Garner CRT
// ================================================================
struct SmallCrtConstants {
    uint32_t p[NUM_PRIMES];
    // Garner inverses: inv[i][j] = p[j]^(-1) mod p[i] (j < i)
    uint32_t inv[NUM_PRIMES][NUM_PRIMES];
    // prefix[i] = p[0]*p[1]*...*p[i-1] as little-endian 64-bit words.
    // The largest, prefix[4], is ~2^115.5; 3 words leave headroom.
    static constexpr int MAX_WORDS = 3;
    uint64_t prefix[NUM_PRIMES][MAX_WORDS];
    int prefix_len[NUM_PRIMES];

    void init() {
        for (int i = 0; i < NUM_PRIMES; i++)
            p[i] = SP_ALL[i];

        for (int i = 0; i < NUM_PRIMES; i++)
            for (int j = 0; j < i; j++)
                inv[i][j] = mod_inv_32(p[j] % p[i], p[i]);

        // Prefix products: prefix[i] = prefix[i-1] * p[i-1].
        std::memset(prefix, 0, sizeof(prefix));
        prefix[0][0] = 1;
        prefix_len[0] = 1;
        for (int i = 1; i < NUM_PRIMES; i++) {
            uint64_t carry = 0;
            for (int w = 0; w < prefix_len[i - 1]; w++) {
                uint64_t hi;
                uint64_t lo = mul_64x64_128(prefix[i - 1][w], (uint64_t)p[i - 1], &hi);
                lo += carry;
                if (lo < carry) hi++;
                prefix[i][w] = lo;
                carry = hi;
            }
            if (carry) {
                prefix[i][prefix_len[i - 1]] = carry;
                prefix_len[i] = prefix_len[i - 1] + 1;
            } else {
                prefix_len[i] = prefix_len[i - 1];
            }
        }
    }
};

// Garner CRT: reconstructs x in [0, M) from its 5 residues r[k] = x mod p[k]
// (each r[k] < p[k]) and writes it to result[0..3) (little-endian words).
inline void crt_single_5(uint64_t result[3], const uint32_t r[NUM_PRIMES],
                         const SmallCrtConstants& crt) {
    // Mixed-radix digits: x = v[0] + v[1]*p0 + v[2]*p0*p1 + ...
    uint32_t v[NUM_PRIMES];
    v[0] = r[0];
    for (int i = 1; i < NUM_PRIMES; i++) {
        uint64_t temp = r[i];
        for (int j = 0; j < i; j++) {
            temp = (temp + crt.p[i] - (v[j] % crt.p[i])) % crt.p[i];
            temp = temp * crt.inv[i][j] % crt.p[i];
        }
        v[i] = (uint32_t)temp;
    }

    // Recompose: x = v[0] + v[1]*prefix[1] + v[2]*prefix[2] + ...
    result[0] = v[0];
    result[1] = 0;
    result[2] = 0;
    for (int i = 1; i < NUM_PRIMES; i++) {
        uint64_t carry = 0;
        for (int w = 0; w < crt.prefix_len[i]; w++) {
            uint64_t hi;
            uint64_t lo = mul_64x64_128(v[i], crt.prefix[i][w], &hi);
            // v[i] < 2^31, so hi < 2^31 and the two carries below cannot
            // overflow it.
            lo += carry;
            if (lo < carry) hi++;
            lo += result[w];
            if (lo < result[w]) hi++;
            result[w] = lo;
            carry = hi;
        }
        for (int w = crt.prefix_len[i]; w < 3 && carry; w++) {
            result[w] += carry;
            carry = (result[w] < carry) ? 1 : 0;
        }
    }
}

// CRT reconstruction of every coefficient plus carry propagation.
// da[k][i] holds coefficient i modulo prime k (plain form, not Montgomery).
// Writes result_limbs limbs to rp.
inline void crt_recompose_5(uint64_t* rp, size_t result_limbs,
                            uint32_t* const da[NUM_PRIMES], size_t N,
                            const SmallCrtConstants& crt) {
    std::memset(rp, 0, result_limbs * sizeof(uint64_t));

    // Running carry (coefficient sum >> 64) as a 128-bit pair.
    uint64_t carry_lo = 0, carry_hi = 0;
    for (size_t i = 0; i < N && i < result_limbs; ++i) {
        uint32_t residues[NUM_PRIMES];
        for (int k = 0; k < NUM_PRIMES; k++)
            residues[k] = da[k][i];

        uint64_t coeff[3];
        crt_single_5(coeff, residues, crt);

        // coeff[0..2] + carry
        uint64_t sum0 = coeff[0] + carry_lo;
        uint64_t c0 = (sum0 < coeff[0]) ? 1 : 0;
        uint64_t sum1 = coeff[1] + carry_hi + c0;
        // With c0 == 1, equality also means a wrap (carry_hi + 1 == 2^64).
        uint64_t c1 = (sum1 < coeff[1] || (c0 && sum1 <= coeff[1])) ? 1 : 0;
        uint64_t sum2 = coeff[2] + c1;

        rp[i] = sum0;
        carry_lo = sum1;
        carry_hi = sum2;
    }
    if (carry_lo != 0 && N < result_limbs) {
        rp[N] = carry_lo;
        if (carry_hi != 0 && N + 1 < result_limbs)
            rp[N + 1] = carry_hi;
    }
}

// ================================================================
// Utilities
// ================================================================

// Smallest power of two >= n (1 for n <= 1).
inline size_t next_pow2_sp(size_t n) {
    size_t p = 1;
    while (p < n) p <<= 1;
    return p;
}

// ================================================================
// Context (one-time initialization)
// ================================================================
struct SmallPrimeNttContext {
    SmallNttPrime primes[NUM_PRIMES];
    SmallCrtConstants crt;
    bool initialized = false;

    void init() {
        for (int k = 0; k < NUM_PRIMES; k++) {
            uint32_t p = SP_ALL[k];
            uint32_t pin = compute_p_inv_neg_32(p);
            // R^2 mod p with R = 2^32
            uint32_t r_mod_p = mod_pow_32(2, 32, p);
            uint32_t r2 = mod_mul_32(r_mod_p, r_mod_p, p);
            primes[k] = { p, SP_G[k], pin, r2 };
        }
        crt.init();
        initialized = true;
    }
};

// Returns the process-wide context. Initialization of a function-local
// static is thread-safe (C++11), unlike a check-then-init on a plain bool,
// which races when two threads make their first call concurrently.
inline SmallPrimeNttContext& getSmallPrimeNttContext() {
    static SmallPrimeNttContext ctx = [] {
        SmallPrimeNttContext c;
        c.init();
        return c;
    }();
    return ctx;
}

// True when an an x bn product can be computed exactly by one NTT pass:
// min(an, bn) <= SP_MAX_OPERAND_LIMBS (CRT range) and an + bn <= 2^SP_MAX_S
// (NTT length).
inline bool small_prime_ntt_fits(size_t an, size_t bn) {
    const size_t max_n = size_t(1) << SP_MAX_S;
    const size_t smaller = an < bn ? an : bn;
    return smaller <= SP_MAX_OPERAND_LIMBS && an <= max_n && bn <= max_n - an;
}

namespace detail {

// Writes src[0..n) reduced mod prime.p in Montgomery form to dst[0..n) and
// zero-pads dst[n..N).
inline void load_montgomery(uint32_t* dst, const uint64_t* src, size_t n,
                            size_t N, const SmallNttPrime& prime) {
    for (size_t i = 0; i < n; i++)
        dst[i] = to_mont_32((uint32_t)(src[i] % prime.p), prime.p,
                            prime.p_inv_neg, prime.r2_mod_p);
    for (size_t i = n; i < N; i++)
        dst[i] = 0;
}

// Per-thread root tables for all primes; rebuilt only when N changes.
inline SmallNttRoots* roots_for_size(const SmallPrimeNttContext& ctx, size_t N) {
    thread_local SmallNttRoots roots[NUM_PRIMES];
    for (int k = 0; k < NUM_PRIMES; k++) {
        if (roots[k].N != N || roots[k].p != ctx.primes[k].p)
            roots[k].build(ctx.primes[k], N);
    }
    return roots;
}

// rp[0..an+bn) = ap * bp in a single NTT pass.
// Precondition: small_prime_ntt_fits(an, bn).
inline void mul_single_pass(uint64_t* rp, const uint64_t* ap, size_t an,
                            const uint64_t* bp, size_t bn) {
    const SmallPrimeNttContext& ctx = getSmallPrimeNttContext();
    const size_t rn = an + bn;
    const size_t N = next_pow2_sp(rn);

    // Per prime: data_a[N] + data_b[N] = 2N uint32; 10N in total.
    thread_local std::vector<uint32_t> work;
    const size_t total = 10 * N;
    if (work.size() < total) work.resize(total);

    uint32_t* da[NUM_PRIMES];
    uint32_t* db[NUM_PRIMES];
    for (int k = 0; k < NUM_PRIMES; k++) {
        da[k] = work.data() + k * 2 * N;
        db[k] = da[k] + N;
    }

    // Reduce the inputs modulo each prime and convert to Montgomery form.
    for (int k = 0; k < NUM_PRIMES; k++) {
        load_montgomery(da[k], ap, an, N, ctx.primes[k]);
        load_montgomery(db[k], bp, bn, N, ctx.primes[k]);
    }

    // NTT pipeline, independent per prime.
    SmallNttRoots* roots = roots_for_size(ctx, N);
    for (int k = 0; k < NUM_PRIMES; k++) {
        const uint32_t p = ctx.primes[k].p;
        const uint32_t pin = ctx.primes[k].p_inv_neg;
        forward_ntt_32(da[k], N, p, pin, roots[k].roots.data());
        forward_ntt_32(db[k], N, p, pin, roots[k].roots.data());
        pointwise_mul_32(da[k], db[k], N, p, pin);
        inverse_ntt_32(da[k], N, p, pin, roots[k].inv_roots.data(), roots[k].inv_n);
        // Montgomery form -> plain residues for the CRT.
        for (size_t i = 0; i < N; i++)
            da[k][i] = from_mont_32(da[k][i], p, pin);
    }

    crt_recompose_5(rp, rn, da, N, ctx.crt);
}

// rp[0..2*an) = ap^2 in a single NTT pass.
// Precondition: small_prime_ntt_fits(an, an).
inline void sqr_single_pass(uint64_t* rp, const uint64_t* ap, size_t an) {
    const SmallPrimeNttContext& ctx = getSmallPrimeNttContext();
    const size_t rn = 2 * an;
    const size_t N = next_pow2_sp(rn);

    thread_local std::vector<uint32_t> work;
    const size_t total = 5 * N; // one buffer per prime
    if (work.size() < total) work.resize(total);

    uint32_t* da[NUM_PRIMES];
    for (int k = 0; k < NUM_PRIMES; k++)
        da[k] = work.data() + k * N;

    for (int k = 0; k < NUM_PRIMES; k++)
        load_montgomery(da[k], ap, an, N, ctx.primes[k]);

    SmallNttRoots* roots = roots_for_size(ctx, N);
    for (int k = 0; k < NUM_PRIMES; k++) {
        const uint32_t p = ctx.primes[k].p;
        const uint32_t pin = ctx.primes[k].p_inv_neg;
        forward_ntt_32(da[k], N, p, pin, roots[k].roots.data());
        for (size_t i = 0; i < N; i++)
            da[k][i] = mont_mul_32(da[k][i], da[k][i], p, pin);
        inverse_ntt_32(da[k], N, p, pin, roots[k].inv_roots.data(), roots[k].inv_n);
        for (size_t i = 0; i < N; i++)
            da[k][i] = from_mont_32(da[k][i], p, pin);
    }

    crt_recompose_5(rp, rn, da, N, ctx.crt);
}

// acc[off .. rn) += t[0 .. tn), propagating the carry upward.
// Requires off + tn <= rn.
inline void add_shifted(uint64_t* acc, size_t rn, size_t off,
                        const uint64_t* t, size_t tn) {
    uint64_t carry = 0;
    for (size_t k = 0; k < tn; ++k) {
        const uint64_t x = acc[off + k];
        const uint64_t s = x + t[k];
        const uint64_t c1 = (s < x) ? 1 : 0;
        const uint64_t s2 = s + carry;
        const uint64_t c2 = (s2 < s) ? 1 : 0;
        acc[off + k] = s2;
        carry = c1 + c2; // at most 1: if c1 == 1 then s <= 2^64 - 2
    }
    for (size_t k = off + tn; carry != 0 && k < rn; ++k) {
        acc[k] += 1;
        carry = (acc[k] == 0) ? 1 : 0;
    }
}

// rp[0..an+bn) = ap * bp for operands of any size. Both operands are split
// into chunks of at most `chunk` limbs (clamped to [1, SP_MAX_OPERAND_LIMBS]),
// every chunk pair is multiplied exactly with mul_single_pass, and the
// shifted partial products are accumulated. The result is staged in a
// temporary buffer, so rp may alias ap or bp.
inline void mul_chunked(uint64_t* rp, const uint64_t* ap, size_t an,
                        const uint64_t* bp, size_t bn, size_t chunk) {
    if (chunk == 0) chunk = 1;
    if (chunk > SP_MAX_OPERAND_LIMBS) chunk = SP_MAX_OPERAND_LIMBS;

    const size_t rn = an + bn;
    std::vector<uint64_t> acc(rn, 0);
    std::vector<uint64_t> part(2 * chunk);
    for (size_t i = 0; i < an; i += chunk) {
        const size_t ai = std::min(chunk, an - i);
        for (size_t j = 0; j < bn; j += chunk) {
            const size_t bj = std::min(chunk, bn - j);
            mul_single_pass(part.data(), ap + i, ai, bp + j, bj);
            add_shifted(acc.data(), rn, i + j, part.data(), ai + bj);
        }
    }
    std::copy(acc.begin(), acc.end(), rp);
}

} // namespace detail

// ================================================================
// Public multiplication entry points
// ================================================================

// rp[0..an+bn-1] = ap[0..an-1] * bp[0..bn-1]
// Exact for all sizes: a single NTT pass when small_prime_ntt_fits(an, bn),
// otherwise a chunked computation (previously such inputs were silently
// left unwritten or produced wrong results).
inline void mul_small_prime_ntt(uint64_t* rp,
                                const uint64_t* ap, size_t an,
                                const uint64_t* bp, size_t bn) {
    if (an < bn) {
        std::swap(ap, bp);
        std::swap(an, bn);
    }
    if (small_prime_ntt_fits(an, bn))
        detail::mul_single_pass(rp, ap, an, bp, bn);
    else
        detail::mul_chunked(rp, ap, an, bp, bn, SP_MAX_OPERAND_LIMBS);
}

// Squaring: rp[0..2*an-1] = ap[0..an-1]^2
// Exact for all sizes (chunked general multiply when one pass cannot be).
inline void sqr_small_prime_ntt(uint64_t* rp, const uint64_t* ap, size_t an) {
    if (small_prime_ntt_fits(an, an))
        detail::sqr_single_pass(rp, ap, an);
    else
        detail::mul_chunked(rp, ap, an, ap, an, SP_MAX_OPERAND_LIMBS);
}

} // namespace small_prime_ntt
} // namespace calx