// Copyright (C) 2026 Kiyotsugu Arai
// SPDX-License-Identifier: LGPL-3.0-or-later
// GoldilocksNtt.hpp
// Goldilocks prime (p = 2^64 - 2^32 + 1) NTT
//   A single-prime NTT that needs no CRT. Modular reduction is done with
//   shifts and additions/subtractions only.
//   Inputs are split into 16-bit words (4x expansion) so that convolution
//   coefficients cannot overflow.
//
//   Constraints: NTT size <= 2^32 (the power-of-two factor of p-1 is 2^32)
//   Maximum coefficient: N * (2^16 - 1)^2 < p  (always holds for N < 2^32)
//
// ----------------------------------------------------------------
// Performance comparison (2026-03-16, AMD Ryzen Threadripper PRO 5995WX, Release x64)
//
//   The Goldilocks NTT is 4-10x slower than the Prime NTT (3 primes + CRT, AVX2).
//
//   n(limbs)   Goldilocks    PrimeNTT    NTT-size       ratio
//       500       399 us       97 us     4K /   1K       4.1x
//      1000       872 us      193 us     8K /   2K       4.5x
//      3000      3.93 ms      381 us    32K /   8K      10.3x
//     10000     18.54 ms     1.83 ms   128K /  32K      10.2x
//     50000     99.31 ms    21.16 ms   512K / 128K       4.7x
//
//   Causes:
//   - The 4x expansion of the 16-bit split dominates. The NTT size becomes
//     4x larger, so the O(N log N) work grows by 4-5x.
//   - Not needing CRT (3 NTTs -> 1 NTT) does not offset the 4x expansion.
//   - Goldilocks reduction (shift+add, 1xMULX) is lighter than Montgomery
//     (2xMULX), but the Prime NTT is already 4-way parallel with AVX2 while
//     Goldilocks is scalar only, so the effective gain is small.
//
//   Room for improvement:
//   - 21-bit split (3x expansion, safe for N < 2^22) cuts the NTT size by 25%
//   - AVX2 Goldilocks butterfly (4 VPMULUDQ vs 11 VPMULUDQ for Montgomery)
//   - It is unclear, however, whether this would beat the heavily optimized
//     Prime NTT (AVX2 + MT + cache blocking).
// ----------------------------------------------------------------

#pragma once

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <utility>
#include <vector>

#ifdef _MSC_VER
#include <intrin.h>
#endif

namespace calx {
namespace goldilocks_ntt {

// ================================================================
// Goldilocks prime constants
// ================================================================

// p = 2^64 - 2^32 + 1
constexpr uint64_t GOLD_P = 0xFFFFFFFF00000001ULL;

// p-1 = 2^32 * (2^32 - 1) = 2^32 * 3 * 5 * 17 * 257 * 65537
constexpr int GOLD_MAX_S = 32;  // maximum NTT length = 2^32

// Primitive root g = 7 (g^(p-1) = 1 mod p, g^((p-1)/q) != 1 for all prime q | p-1)
constexpr uint64_t GOLD_G = 7;

// Largest product length (in 64-bit limbs) the multiply routines accept:
// a result of rn limbs needs 4*rn 16-bit digits, which must fit in 2^GOLD_MAX_S.
constexpr size_t GOLD_MAX_RESULT_LIMBS = size_t{1} << (GOLD_MAX_S - 2);

// ================================================================
// Goldilocks fast modular arithmetic
// ================================================================

// (a + b) mod p, for a, b < p.
// a + b < 2p < 2^65. On 64-bit overflow, add 2^64 mod p = 2^32 - 1.
// Proof: with carry = 1 the wrapped sum is < 2p - 2^64 = 2^64 - 2^33 + 2,
//        so sum + (2^32 - 1) < 2^64 - 2^32 + 1 = p and cannot overflow again.
inline uint64_t gold_add(uint64_t a, uint64_t b) {
    const uint64_t sum = a + b;
    if (sum < a) {
        // carry = 1: the true value is sum + 2^64 == sum + (2^32 - 1) (mod p).
        // sum + (2^32 - 1) < p, so no final normalization is needed.
        return sum + 0xFFFFFFFFULL;
    }
    return (sum >= GOLD_P) ? (sum - GOLD_P) : sum;
}

// (a - b) mod p, for a, b < p.
inline uint64_t gold_sub(uint64_t a, uint64_t b) {
    if (a >= b) return a - b;
    return a - b + GOLD_P;  // wraps modulo 2^64 back into [0, p)
}

// Goldilocks reduction: (hi : lo) mod p.
// p = 2^64 - 2^32 + 1, hence 2^64 == 2^32 - 1 (mod p) and
// (hi * 2^64 + lo) mod p = (hi * (2^32 - 1) + lo) mod p.
//
// Two-stage reduction:
//   Step 1: compute hi*(2^32-1) = (hi<<32) - hi in 96 bits and add lo -> {r1, r0}
//   Step 2: r1 <= 2^32, so r1*(2^32-1) < 2^64. Compute r0 + r1*(2^32-1);
//           on overflow correct with +0xFFFFFFFF (cannot overflow again).
inline uint64_t gold_reduce(uint64_t lo, uint64_t hi) {
    // Step 1: {h1, h0} = hi << 32 (96-bit value)
    const uint64_t h0 = hi << 32;  // low 64 bits
    const uint64_t h1 = hi >> 32;  // high 32 bits (< 2^32)

    // {t1, t0} = {h1, h0} - hi
    // h1 >= borrow always holds: h1 == 0 implies hi < 2^32, so
    // h0 = hi << 32 >= hi and therefore borrow == 0.
    const uint64_t borrow = (h0 < hi) ? 1ULL : 0ULL;
    const uint64_t t0 = h0 - hi;
    const uint64_t t1 = h1 - borrow;

    // {r1, r0} = {t1, t0} + lo
    const uint64_t r0 = t0 + lo;
    const uint64_t carry = (r0 < t0) ? 1ULL : 0ULL;
    const uint64_t r1 = t1 + carry;  // r1 <= 2^32 (t1 < 2^32, carry <= 1)

    // Step 2: result = r0 + r1 * (2^32 - 1) (mod p)
    // r1 * (2^32-1) <= 2^32 * (2^32-1) = 2^64 - 2^32, which fits in uint64_t.
    const uint64_t adj = r1 * 0xFFFFFFFFULL;
    uint64_t result = r0 + adj;
    if (result < r0) {
        // Carry out: fold 2^64 back in as 2^32 - 1. No second overflow:
        //   wrapped <= (2^64-1) + (2^64-2^32) - 2^64 = 2^64 - 2^32 - 1
        //   wrapped + (2^32-1) <= 2^64 - 2 < 2^64
        result += 0xFFFFFFFFULL;
    }

    // Final normalization into [0, p)
    if (result >= GOLD_P) result -= GOLD_P;
    return result;
}

// (a * b) mod p using the Goldilocks fast reduction.
// One 64x64->128 multiply plus shifts/adds, versus two multiplies for Montgomery.
inline uint64_t gold_mul(uint64_t a, uint64_t b) {
#if defined(_MSC_VER) && defined(_M_X64)
    uint64_t hi;
    const uint64_t lo = _umul128(a, b, &hi);
    return gold_reduce(lo, hi);
#elif defined(__GNUC__) || defined(__clang__)
    const unsigned __int128 prod = static_cast<unsigned __int128>(a) * b;
    return gold_reduce(static_cast<uint64_t>(prod), static_cast<uint64_t>(prod >> 64));
#else
#error "goldilocks_ntt::gold_mul requires a 64x64->128 bit multiply (MSVC x64, GCC or Clang)"
#endif
}

// base^exp mod p (binary exponentiation). base may be any 64-bit value.
inline uint64_t gold_pow(uint64_t base, uint64_t exp) {
    uint64_t result = 1;
    base %= GOLD_P;
    while (exp > 0) {
        if (exp & 1) result = gold_mul(result, base);
        base = gold_mul(base, base);
        exp >>= 1;
    }
    return result;
}

// a^(-1) mod p via Fermat's little theorem (a^(p-2)). Returns 0 for a == 0 (mod p).
inline uint64_t gold_inv(uint64_t a) {
    return gold_pow(a, GOLD_P - 2);
}

// ================================================================
// NTT root tables
// ================================================================

struct GoldRoots {
    size_t N = 0;                     // transform length the tables were built for (0 = empty)
    std::vector<uint64_t> fwd_roots;  // omega^0, omega^1, ..., omega^(N-1)
    std::vector<uint64_t> inv_roots;  // omega_inv^0, ..., omega_inv^(N-1)
    uint64_t inv_n = 0;               // N^(-1) mod p

    // Builds the tables for a transform of length ntt_size.
    // ntt_size is expected to be a power of two <= 2^GOLD_MAX_S so that it
    // divides p-1; ntt_size == 0 resets the object to the empty state.
    void build(size_t ntt_size) {
        // Invalidate first: N is used by callers as the cache key, so it must
        // not claim a size whose tables are not completely built (e.g. if a
        // resize below throws).
        N = 0;
        inv_n = 0;
        if (ntt_size == 0) {
            fwd_roots.clear();
            inv_roots.clear();
            return;
        }

        fwd_roots.resize(ntt_size);
        inv_roots.resize(ntt_size);

        // omega = g^((p-1)/N) mod p
        const uint64_t omega = gold_pow(GOLD_G, (GOLD_P - 1) / ntt_size);
        const uint64_t omega_inv = gold_inv(omega);

        fwd_roots[0] = 1;
        inv_roots[0] = 1;
        for (size_t i = 1; i < ntt_size; ++i) {
            fwd_roots[i] = gold_mul(fwd_roots[i - 1], omega);
            inv_roots[i] = gold_mul(inv_roots[i - 1], omega_inv);
        }

        inv_n = gold_inv(ntt_size);
        N = ntt_size;  // publish only after everything is consistent
    }
};

// ================================================================
// NTT transforms (DIF / DIT)
// ================================================================

// Forward NTT (decimation in frequency), in place.
// Stages run from len = N down to 2 with no reordering pass; roots[k] must
// hold omega^k for a transform of length N (only indices < N/2 are read).
inline void forward_ntt_gold(uint64_t* data, size_t N, const uint64_t* roots) {
    for (size_t len = N; len >= 2; len >>= 1) {
        const size_t half = len >> 1;
        const size_t step = N / len;
        for (size_t i = 0; i < N; i += len) {
            for (size_t j = 0; j < half; ++j) {
                const uint64_t u = data[i + j];
                const uint64_t v = data[i + j + half];
                data[i + j] = gold_add(u, v);
                data[i + j + half] = gold_mul(gold_sub(u, v), roots[j * step]);
            }
        }
    }
}

// Inverse NTT (decimation in time), in place, including the 1/N scaling.
// Stages run from len = 2 up to N, mirroring forward_ntt_gold, so the pair
// forward -> pointwise -> inverse needs no explicit reordering pass.
inline void inverse_ntt_gold(uint64_t* data, size_t N,
                             const uint64_t* inv_roots, uint64_t inv_n) {
    for (size_t len = 2; len <= N; len <<= 1) {
        const size_t half = len >> 1;
        const size_t step = N / len;
        for (size_t i = 0; i < N; i += len) {
            for (size_t j = 0; j < half; ++j) {
                const uint64_t u = data[i + j];
                const uint64_t v = gold_mul(data[i + j + half], inv_roots[j * step]);
                data[i + j] = gold_add(u, v);
                data[i + j + half] = gold_sub(u, v);
            }
        }
    }
    // 1/N scaling
    for (size_t i = 0; i < N; ++i) data[i] = gold_mul(data[i], inv_n);
}

// Pointwise multiplication: a[i] = a[i] * b[i] mod p.
inline void pointwise_mul_gold(uint64_t* a, const uint64_t* b, size_t N) {
    for (size_t i = 0; i < N; ++i) a[i] = gold_mul(a[i], b[i]);
}

// ================================================================
// 16-bit word split / reassembly
// ================================================================

// Splits n 64-bit words into 4*n 16-bit chunks (least significant first) and
// zero-fills the rest of dst up to N entries. Requires 4*n <= N.
inline void pack_16bit(const uint64_t* src, size_t n, uint64_t* dst, size_t N) {
    size_t out = 0;
    for (size_t i = 0; i < n; ++i) {
        const uint64_t w = src[i];
        dst[out++] = w & 0xFFFF;
        dst[out++] = (w >> 16) & 0xFFFF;
        dst[out++] = (w >> 32) & 0xFFFF;
        dst[out++] = (w >> 48) & 0xFFFF;
    }
    for (size_t i = out; i < N; ++i) dst[i] = 0;
}

// Reassembles a base-2^16 convolution result into rn 64-bit words, propagating
// carries. Digits that would land beyond rp[rn-1] are dropped.
inline void unpack_16bit(const uint64_t* conv, size_t conv_len,
                         uint64_t* rp, size_t rn) {
    std::fill_n(rp, rn, uint64_t{0});

    uint64_t carry = 0;
    for (size_t k = 0; k < conv_len; ++k) {
        // 65-bit sum conv[k] + carry, kept as {sum_hi, sum_lo}; sum_hi is the
        // carry-out of the 64-bit addition (0 or 1).
        const uint64_t sum_lo = conv[k] + carry;
        const uint64_t sum_hi = (sum_lo < carry) ? 1ULL : 0ULL;

        // Emit the low 16 bits as digit k.
        const size_t word = k / 4;
        const size_t shift = (k % 4) * 16;
        if (word < rn) {
            rp[word] |= (sum_lo & 0xFFFF) << shift;
        }
        carry = (sum_lo >> 16) | (sum_hi << 48);
    }

    // Flush the remaining carry into the digits after the last coefficient.
    size_t word = conv_len / 4;
    size_t shift = (conv_len % 4) * 16;
    while (carry > 0 && word < rn) {
        rp[word] |= (carry & 0xFFFF) << shift;
        carry >>= 16;
        shift += 16;
        if (shift >= 64) {
            shift = 0;
            ++word;
        }
    }
}

// ================================================================
// Utilities
// ================================================================

// Smallest power of two >= n (1 for n == 0).
// Returns 0 when that power of two is not representable in size_t.
inline size_t next_pow2_gold(size_t n) {
    size_t p = 1;
    while (p < n) {
        p <<= 1;
        if (p == 0) return 0;  // shifted out the top bit: would loop forever
    }
    return p;
}

// ================================================================
// Main multiplication routines
// ================================================================

// rp[0..an+bn-1] = ap[0..an-1] * bp[0..bn-1]
// If an + bn exceeds GOLD_MAX_RESULT_LIMBS the transform would be longer than
// 2^GOLD_MAX_S; the function then returns without touching rp.
// Root tables and the work buffer are cached per thread (thread_local).
inline void mul_goldilocks_ntt(uint64_t* rp,
                               const uint64_t* ap, size_t an,
                               const uint64_t* bp, size_t bn) {
    if (an < bn) {
        std::swap(ap, bp);
        std::swap(an, bn);
    }

    // Overflow-safe form of "4 * (an + bn) <= 2^GOLD_MAX_S".
    if (an > GOLD_MAX_RESULT_LIMBS || bn > GOLD_MAX_RESULT_LIMBS - an) return;

    const size_t rn = an + bn;
    if (bn == 0) {
        // Empty operand: the product is zero. Without this, 4*rn - 1 would
        // underflow for an == bn == 0 and the unpack loop would run off the
        // work buffer.
        std::fill_n(rp, rn, uint64_t{0});
        return;
    }

    const size_t conv_len = 4 * rn - 1;               // digits in the linear convolution
    const size_t N = next_pow2_gold(conv_len + 1);    // transform length

    // Root tables (thread_local cache)
    thread_local GoldRoots roots;
    if (roots.N != N) {
        roots.build(N);
    }

    // Working buffer (thread_local): two transforms of length N back to back
    thread_local std::vector<uint64_t> work;
    const size_t total = 2 * N;
    if (work.size() < total) work.resize(total);

    uint64_t* da = work.data();
    uint64_t* db = da + N;

    pack_16bit(ap, an, da, N);
    pack_16bit(bp, bn, db, N);

    forward_ntt_gold(da, N, roots.fwd_roots.data());
    forward_ntt_gold(db, N, roots.fwd_roots.data());
    pointwise_mul_gold(da, db, N);
    inverse_ntt_gold(da, N, roots.inv_roots.data(), roots.inv_n);

    unpack_16bit(da, conv_len, rp, rn);
}

// Squaring: rp[0..2*an-1] = ap[0..an-1]^2
// If 2*an exceeds GOLD_MAX_RESULT_LIMBS the function returns without touching rp.
inline void sqr_goldilocks_ntt(uint64_t* rp, const uint64_t* ap, size_t an) {
    if (an > GOLD_MAX_RESULT_LIMBS / 2) return;
    if (an == 0) return;  // zero-length result; avoids the 2*0 - 1 underflow

    const size_t rn = 2 * an;
    const size_t conv_len = 4 * rn - 1;
    const size_t N = next_pow2_gold(conv_len + 1);

    thread_local GoldRoots roots;
    if (roots.N != N) {
        roots.build(N);
    }

    thread_local std::vector<uint64_t> work;
    if (work.size() < N) work.resize(N);

    uint64_t* da = work.data();

    pack_16bit(ap, an, da, N);
    forward_ntt_gold(da, N, roots.fwd_roots.data());
    for (size_t i = 0; i < N; ++i) da[i] = gold_mul(da[i], da[i]);
    inverse_ntt_gold(da, N, roots.inv_roots.data(), roots.inv_n);

    unpack_16bit(da, conv_len, rp, rn);
}

}  // namespace goldilocks_ntt
}  // namespace calx