// Copyright (C) 2026 Kiyotsugu Arai // SPDX-License-Identifier: LGPL-3.0-or-later // MpnOps.hpp // 生 limb 配列上の低レベル演算関数群 // GMP の mpn_* に相当。Int オブジェクトを一切使わない。 #pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include namespace calx { namespace mpn { // ================================================================ // BMI2 + ADX アセンブリ最適化 (MSVC x64 のみ) // ================================================================ // CALX_INT_HAS_ASM が定義されている場合、MULX+ADCX/ADOX を使用した // 手書きアセンブリ版の addmul_1/mul_1 を呼び出す。 // CPU が BMI2+ADX をサポートしていない場合は intrinsics 版にフォールバック。 #ifdef CALX_INT_HAS_ASM extern "C" uint64_t mpn_addmul_1_mulx(uint64_t* rp, const uint64_t* ap, size_t n, uint64_t b); extern "C" uint64_t mpn_mul_1_mulx(uint64_t* rp, const uint64_t* ap, size_t n, uint64_t b); extern "C" uint64_t mpn_submul_1_mulx(uint64_t* rp, const uint64_t* ap, size_t n, uint64_t b); extern "C" void mpn_mul_basecase_mulx(uint64_t* rp, const uint64_t* ap, size_t an, const uint64_t* bp, size_t bn); // sqr_basecase: 対称性利用, 全行インライン (push/pop 除去) extern "C" void mpn_sqr_basecase_mulx(uint64_t* rp, const uint64_t* ap, size_t n); // add_n / sub_n: BMI2/ADX 不要、基本 x86-64 の ADC/SBB のみ使用 extern "C" uint64_t mpn_add_n_asm(uint64_t* rp, const uint64_t* ap, const uint64_t* bp, size_t n); extern "C" uint64_t mpn_sub_n_asm(uint64_t* rp, const uint64_t* ap, const uint64_t* bp, size_t n); // add_n / sub_n 小サイズ特化 (n=1..4): ループなし、push/pop なし extern "C" uint64_t mpn_add_n_small_asm(uint64_t* rp, const uint64_t* ap, const uint64_t* bp, size_t n); extern "C" uint64_t mpn_sub_n_small_asm(uint64_t* rp, const uint64_t* ap, const uint64_t* bp, size_t n); // mul_basecase 小サイズ特化 (n×n, n=1..4): フルアンロール MULX extern "C" void mpn_mul_small_asm(uint64_t* rp, const uint64_t* ap, const uint64_t* bp, size_t n); // addmul_1 小サイズ特化 (n=1..4): push/pop なし extern "C" uint64_t mpn_addmul_1_small_asm(uint64_t* rp, const uint64_t* ap, size_t n, uint64_t b); // div_basecase 3-by-2 ASM (submul インライン融合) extern "C" void mpn_sbpi1_div_qr_asm(uint64_t* qp, uint64_t* ap, size_t qn, const uint64_t* dp, size_t dn, uint64_t dinv3); // lshift / rshift: SHLD/SHRD 命令による高速シフト extern "C" uint64_t mpn_lshift_asm(uint64_t* rp, const uint64_t* ap, size_t n, unsigned shift); extern "C" uint64_t mpn_rshift_asm(uint64_t* rp, const uint64_t* ap, size_t n, unsigned shift); // Montgomery CIOS 乗算 / REDC (addmul_1 インライン化, push/pop 除去) extern "C" void mpn_mont_mul_mulx(uint64_t* rp, const uint64_t* ap, const uint64_t* bp, size_t n, const uint64_t* mp, uint64_t m_inv, uint64_t* scratch); extern "C" void mpn_mont_redc_mulx(uint64_t* rp, uint64_t* tp, size_t tn, const uint64_t* mp, size_t n, uint64_t m_inv); namespace detail { // CPUID で BMI2 (bit 8) + ADX (bit 19) の両方をチェック inline bool detect_bmi2_adx() { int info[4]; __cpuidex(info, 7, 0); bool bmi2 = (info[1] >> 8) & 1; // EBX bit 8 bool adx = (info[1] >> 19) & 1; // EBX bit 19 return bmi2 && adx; } // スレッドセーフな一度きりの検出 (C++11 magic statics) inline bool has_bmi2_adx() { static const bool result = detect_bmi2_adx(); return result; } } // namespace detail #endif // CALX_INT_HAS_ASM // Karatsuba の閾値 (limb 数) // これ以下は basecase を使う constexpr size_t KARATSUBA_THRESHOLD = 22; // 自乗用 Karatsuba 閾値 (sqr_basecase → sqr_karatsuba) // sqr_basecase は対称性利用で MUL より有利なため、閾値が高い constexpr size_t SQR_KARATSUBA_THRESHOLD = 74; // ================================================================ // ユーティリティ // ================================================================ // 上位ゼロを除いた実際のサイズを返す inline size_t normalized_size(const uint64_t* a, size_t n) { while (n > 0 && a[n - 1] == 0) --n; return n; } // 比較: a > b なら正、a < b なら負、等しいなら 0 // 前提: a, b は正規化済み (先頭ゼロなし) inline int cmp(const uint64_t* a, size_t an, const uint64_t* b, size_t bn) { if (an != bn) return (an > bn) ? 1 : -1; for (size_t i = an; i > 0; --i) { if (a[i - 1] != b[i - 1]) return (a[i - 1] > b[i - 1]) ? 1 : -1; } return 0; } // ================================================================ // 加算・減算 // ================================================================ // 注: add/sub/add_1/sub_1 は意図的に手動キャリー検出を使用している。 // MSVC の _addcarry_u64/_subborrow_u64 は毎反復 CF↔汎用レジスタ変換を // 生成し、純粋な加減算ループでは手動コードより遅い。 // addmul_1/mul_1/submul_1 では MUL 命令の節約が変換コストを上回るため // intrinsics を使用している。(PERF-6 ベンチマーク 2026-02 で検証済み) // r = a + b, returns carry (0 or 1) // 前提: an >= bn, r のサイズ >= an // r は a, b と同じポインタでも可 (in-place) inline uint64_t add(uint64_t* r, const uint64_t* a, size_t an, const uint64_t* b, size_t bn) { #ifdef CALX_INT_HAS_ASM // ASM パス: 小サイズ特化 (n=1..4) + 汎用 8x アンロール (n>4) uint64_t carry; if (bn == 0) { carry = 0; } else if (bn <= 4) { carry = mpn_add_n_small_asm(r, a, b, bn); } else { carry = mpn_add_n_asm(r, a, b, bn); } // 残り (an > bn) のキャリー伝播 for (size_t i = bn; i < an; i++) { uint64_t sum = a[i] + carry; carry = (sum < a[i]) ? 1ULL : 0ULL; r[i] = sum; if (!carry) { if (r != a) { for (i++; i < an; i++) r[i] = a[i]; } return 0; } } return carry; #else // 非 ASM パス: C++ 手動キャリー検出 if (bn <= 2) { uint64_t carry = 0; for (size_t i = 0; i < bn; ++i) { uint64_t sum = a[i] + b[i]; uint64_t c1 = (sum < a[i]) ? 1ULL : 0ULL; uint64_t sum2 = sum + carry; uint64_t c2 = (sum2 < sum) ? 1ULL : 0ULL; r[i] = sum2; carry = c1 + c2; } for (size_t i = bn; i < an; ++i) { uint64_t sum = a[i] + carry; carry = (sum < a[i]) ? 1ULL : 0ULL; r[i] = sum; if (!carry) { if (r != a) { for (++i; i < an; ++i) r[i] = a[i]; } return 0; } } return carry; } #if defined(_MSC_VER) && defined(_M_X64) unsigned char carry = 0; size_t i = 0; for (; i < bn; i++) { carry = _addcarry_u64(carry, a[i], b[i], &r[i]); } for (; i < an; i++) { carry = _addcarry_u64(carry, a[i], 0, &r[i]); if (!carry) { if (r != a) { for (i++; i < an; i++) r[i] = a[i]; } return 0; } } return carry; #else uint64_t carry = 0; size_t i = 0; for (; i < bn; i++) { uint64_t sum = a[i] + b[i]; uint64_t c1 = (sum < a[i]) ? 1ULL : 0ULL; uint64_t sum2 = sum + carry; uint64_t c2 = (sum2 < sum) ? 1ULL : 0ULL; r[i] = sum2; carry = c1 + c2; } for (; i < an; i++) { uint64_t sum = a[i] + carry; carry = (sum < a[i]) ? 1ULL : 0ULL; r[i] = sum; } return carry; #endif // _MSC_VER #endif // CALX_INT_HAS_ASM } // r = a - b, returns borrow (0 or 1) // 前提: an >= bn, a >= b (絶対値), r のサイズ >= an inline uint64_t sub(uint64_t* r, const uint64_t* a, size_t an, const uint64_t* b, size_t bn) { #ifdef CALX_INT_HAS_ASM // ASM パス: 小サイズ特化 (n=1..4) + 汎用 8x アンロール (n>4) uint64_t borrow; if (bn == 0) { borrow = 0; } else if (bn <= 4) { borrow = mpn_sub_n_small_asm(r, a, b, bn); } else { borrow = mpn_sub_n_asm(r, a, b, bn); } // 残り (an > bn) のボロー伝播 for (size_t i = bn; i < an; i++) { uint64_t diff = a[i] - borrow; borrow = (a[i] < borrow) ? 1ULL : 0ULL; r[i] = diff; if (!borrow) { if (r != a) { for (i++; i < an; i++) r[i] = a[i]; } return 0; } } return borrow; #else // 非 ASM パス: C++ 手動ボロー検出 if (bn <= 2) { uint64_t borrow = 0; for (size_t i = 0; i < bn; ++i) { uint64_t diff = a[i] - b[i]; uint64_t b1 = (a[i] < b[i]) ? 1ULL : 0ULL; uint64_t diff2 = diff - borrow; uint64_t b2 = (diff < borrow) ? 1ULL : 0ULL; r[i] = diff2; borrow = b1 + b2; } for (size_t i = bn; i < an; ++i) { uint64_t diff = a[i] - borrow; borrow = (a[i] < borrow) ? 1ULL : 0ULL; r[i] = diff; if (!borrow) { if (r != a) { for (++i; i < an; ++i) r[i] = a[i]; } return 0; } } return borrow; } #if defined(_MSC_VER) && defined(_M_X64) unsigned char borrow = 0; size_t i = 0; for (; i < bn; i++) { borrow = _subborrow_u64(borrow, a[i], b[i], &r[i]); } for (; i < an; i++) { borrow = _subborrow_u64(borrow, a[i], 0, &r[i]); if (!borrow) { if (r != a) { for (i++; i < an; i++) r[i] = a[i]; } return 0; } } return borrow; #else uint64_t borrow = 0; size_t i = 0; for (; i < bn; i++) { uint64_t diff = a[i] - b[i]; uint64_t b1 = (a[i] < b[i]) ? 1ULL : 0ULL; uint64_t diff2 = diff - borrow; uint64_t b2 = (diff < borrow) ? 1ULL : 0ULL; r[i] = diff2; borrow = b1 + b2; } for (; i < an; i++) { uint64_t diff = a[i] - borrow; borrow = (a[i] < borrow) ? 1ULL : 0ULL; r[i] = diff; } return borrow; #endif // _MSC_VER #endif // CALX_INT_HAS_ASM } // r[0..n-1] += b (single limb), returns carry inline uint64_t add_1(uint64_t* r, size_t n, uint64_t b) { if (n == 0) return b; #if defined(_MSC_VER) && defined(_M_X64) unsigned char carry = _addcarry_u64(0, r[0], b, &r[0]); for (size_t i = 1; i < n && carry; i++) { carry = _addcarry_u64(carry, r[i], 0, &r[i]); } return carry; #else for (size_t i = 0; i < n; i++) { uint64_t old = r[i]; r[i] += b; b = (r[i] < old) ? 1ULL : 0ULL; if (b == 0) break; } return b; #endif } // r[0..n-1] -= b (single limb), returns borrow inline uint64_t sub_1(uint64_t* r, size_t n, uint64_t b) { if (n == 0) return b; #if defined(_MSC_VER) && defined(_M_X64) unsigned char borrow = _subborrow_u64(0, r[0], b, &r[0]); for (size_t i = 1; i < n && borrow; i++) { borrow = _subborrow_u64(borrow, r[i], 0, &r[i]); } return borrow; #else for (size_t i = 0; i < n; i++) { uint64_t old = r[i]; r[i] -= b; b = (old < b) ? 1ULL : 0ULL; if (b == 0) break; } return b; #endif } // r = a - 2*b, returns total borrow (0, 1, or 2) // 1パス: b[i] を左シフトしながら減算 inline uint64_t sublsh1_n(uint64_t* r, const uint64_t* a, const uint64_t* b, size_t n) { #if defined(_MSC_VER) && defined(_M_X64) // MSVC intrinsics: _subborrow_u64 で 1パス unsigned char borrow = 0; for (size_t i = 0; i < n; i++) { uint64_t b2 = b[i] << 1; // 2*b[i] の下位 uint64_t b2_hi = b[i] >> 63; // 2*b[i] のキャリー // r[i] = a[i] - b2 - borrow borrow = _subborrow_u64(borrow, a[i], b2, &r[i]); // b2_hi (b[i] の MSB からのキャリー) を追加ボロー borrow += static_cast(b2_hi); } return borrow; #else uint64_t borrow = 0; for (size_t i = 0; i < n; i++) { uint64_t b2 = b[i] << 1; uint64_t b2_hi = b[i] >> 63; uint64_t diff = a[i] - b2; uint64_t bw1 = (a[i] < b2) ? 1ULL : 0ULL; uint64_t diff2 = diff - borrow; uint64_t bw2 = (diff < borrow) ? 1ULL : 0ULL; r[i] = diff2; borrow = bw1 + bw2 + b2_hi; } return borrow; #endif } // r = 2*b - a, returns sign: -1 (negative), 0 (exact), +1 (carry) // 1パス: b[i] を左シフトしながら減算 inline int64_t rsblsh1_n(uint64_t* r, const uint64_t* a, const uint64_t* b, size_t n) { #if defined(_MSC_VER) && defined(_M_X64) // 1パス: 2*b[i] - a[i] を計算 unsigned char borrow = 0; uint64_t shift_carry = 0; for (size_t i = 0; i < n; i++) { uint64_t b2 = (b[i] << 1) | shift_carry; shift_carry = b[i] >> 63; borrow = _subborrow_u64(borrow, b2, a[i], &r[i]); } // shift_carry は 2*b の最上位キャリー (0 or 1) // borrow は減算のボロー (0 or 1) return static_cast(shift_carry) - static_cast(borrow); #else uint64_t borrow = 0; uint64_t shift_carry = 0; for (size_t i = 0; i < n; i++) { uint64_t b2 = (b[i] << 1) | shift_carry; shift_carry = b[i] >> 63; uint64_t diff = b2 - a[i]; uint64_t bw1 = (b2 < a[i]) ? 1ULL : 0ULL; uint64_t diff2 = diff - borrow; uint64_t bw2 = (diff < borrow) ? 1ULL : 0ULL; r[i] = diff2; borrow = bw1 + bw2; } return static_cast(shift_carry) - static_cast(borrow); #endif } // r += a * b (single limb), returns carry // addmul_1: basecase 乗算の内部ループ inline uint64_t addmul_1(uint64_t* r, const uint64_t* a, size_t an, uint64_t b) { #ifdef CALX_INT_HAS_ASM if (detail::has_bmi2_adx()) { if (an <= 4) return mpn_addmul_1_small_asm(r, a, an, b); return mpn_addmul_1_mulx(r, a, an, b); } #endif uint64_t carry = 0; #if defined(_MSC_VER) && defined(_M_X64) // MSVC x64: _umul128 + _addcarry_u64 intrinsics // 4x ループアンロールで MUL のパイプライン並列実行を促進 size_t i = 0; for (; i + 4 <= an; i += 4) { uint64_t hi0, hi1, hi2, hi3; uint64_t lo0 = _umul128(a[i + 0], b, &hi0); uint64_t lo1 = _umul128(a[i + 1], b, &hi1); uint64_t lo2 = _umul128(a[i + 2], b, &hi2); uint64_t lo3 = _umul128(a[i + 3], b, &hi3); // c1, c2 は独立。チェーンしてはいけない (c1 を r[i] に混入させない) unsigned char c1, c2; c1 = _addcarry_u64(0, lo0, carry, &lo0); c2 = _addcarry_u64(0, lo0, r[i + 0], &r[i + 0]); carry = hi0 + c1 + c2; c1 = _addcarry_u64(0, lo1, carry, &lo1); c2 = _addcarry_u64(0, lo1, r[i + 1], &r[i + 1]); carry = hi1 + c1 + c2; c1 = _addcarry_u64(0, lo2, carry, &lo2); c2 = _addcarry_u64(0, lo2, r[i + 2], &r[i + 2]); carry = hi2 + c1 + c2; c1 = _addcarry_u64(0, lo3, carry, &lo3); c2 = _addcarry_u64(0, lo3, r[i + 3], &r[i + 3]); carry = hi3 + c1 + c2; } for (; i < an; i++) { uint64_t hi; uint64_t lo = _umul128(a[i], b, &hi); unsigned char c1 = _addcarry_u64(0, lo, carry, &lo); unsigned char c2 = _addcarry_u64(0, lo, r[i], &r[i]); carry = hi + c1 + c2; } #elif defined(__SIZEOF_INT128__) // GCC/Clang: __uint128_t でコンパイラに最適命令を生成させる __uint128_t cy = carry; size_t i = 0; for (; i + 4 <= an; i += 4) { cy += (__uint128_t)a[i+0] * b + r[i+0]; r[i+0] = (uint64_t)cy; cy >>= 64; cy += (__uint128_t)a[i+1] * b + r[i+1]; r[i+1] = (uint64_t)cy; cy >>= 64; cy += (__uint128_t)a[i+2] * b + r[i+2]; r[i+2] = (uint64_t)cy; cy >>= 64; cy += (__uint128_t)a[i+3] * b + r[i+3]; r[i+3] = (uint64_t)cy; cy >>= 64; } for (; i < an; i++) { cy += (__uint128_t)a[i] * b + r[i]; r[i] = (uint64_t)cy; cy >>= 64; } return (uint64_t)cy; #else for (size_t i = 0; i < an; i++) { UInt128 prod = UInt128::multiply(a[i], b); uint64_t lo = prod.low + carry; uint64_t c1 = (lo < prod.low) ? 1ULL : 0ULL; uint64_t lo2 = lo + r[i]; uint64_t c2 = (lo2 < lo) ? 1ULL : 0ULL; r[i] = lo2; carry = prod.high + c1 + c2; } #endif return carry; } // r -= a * b (single limb), returns borrow // submul_1: schoolbook 除算の内部ループ // ADCX/ADOX の双キャリーチェーンは使えない (SUB が OF を破壊するため) // 4x ループアンロールで MUL のパイプライン並列実行を促進 inline uint64_t submul_1(uint64_t* r, const uint64_t* a, size_t n, uint64_t b) { #ifdef CALX_INT_HAS_ASM if (detail::has_bmi2_adx()) { return mpn_submul_1_mulx(r, a, n, b); } #endif uint64_t carry = 0; #if defined(_MSC_VER) && defined(_M_X64) size_t i = 0; for (; i + 4 <= n; i += 4) { uint64_t hi0, hi1, hi2, hi3; uint64_t lo0 = _umul128(a[i + 0], b, &hi0); uint64_t lo1 = _umul128(a[i + 1], b, &hi1); uint64_t lo2 = _umul128(a[i + 2], b, &hi2); uint64_t lo3 = _umul128(a[i + 3], b, &hi3); unsigned char c; c = _addcarry_u64(0, lo0, carry, &lo0); carry = hi0 + c; c = _subborrow_u64(0, r[i + 0], lo0, &r[i + 0]); carry += c; c = _addcarry_u64(0, lo1, carry, &lo1); carry = hi1 + c; c = _subborrow_u64(0, r[i + 1], lo1, &r[i + 1]); carry += c; c = _addcarry_u64(0, lo2, carry, &lo2); carry = hi2 + c; c = _subborrow_u64(0, r[i + 2], lo2, &r[i + 2]); carry += c; c = _addcarry_u64(0, lo3, carry, &lo3); carry = hi3 + c; c = _subborrow_u64(0, r[i + 3], lo3, &r[i + 3]); carry += c; } for (; i < n; i++) { uint64_t hi; uint64_t lo = _umul128(a[i], b, &hi); unsigned char c; c = _addcarry_u64(0, lo, carry, &lo); carry = hi + c; c = _subborrow_u64(0, r[i], lo, &r[i]); carry += c; } #elif defined(__SIZEOF_INT128__) __uint128_t cy = 0; for (size_t i = 0; i < n; i++) { cy += (__uint128_t)a[i] * b; uint64_t pl = (uint64_t)cy; uint64_t prev = r[i]; r[i] = prev - pl; cy >>= 64; cy += (prev < pl) ? 1ULL : 0ULL; } return (uint64_t)cy; #else for (size_t i = 0; i < n; i++) { UInt128 prod = UInt128::multiply(a[i], b); uint64_t pl = prod.low + carry; uint64_t c1 = (pl < prod.low) ? 1ULL : 0ULL; uint64_t prev = r[i]; r[i] = prev - pl; uint64_t c2 = (prev < pl) ? 1ULL : 0ULL; carry = prod.high + c1 + c2; } #endif return carry; } // ================================================================ // 乗算 // ================================================================ // 前方宣言 (mul_basecase が mul_1 を使用) inline uint64_t mul_1(uint64_t* r, const uint64_t* a, size_t n, uint64_t b); // Basecase 乗算: r[0..an+bn-1] = a[0..an-1] * b[0..bn-1] // r は a, b と重なってはならない inline void mul_basecase(uint64_t* r, const uint64_t* a, size_t an, const uint64_t* b, size_t bn) { if (an < bn) { std::swap(a, b); std::swap(an, bn); } #ifdef CALX_INT_HAS_ASM if (detail::has_bmi2_adx()) { // n×n 小サイズ特化: フルアンロール MULX (push/pop 最小) if (an == bn && bn <= 4) { mpn_mul_small_asm(r, a, b, bn); return; } // 汎用 asm 版: 全行を 1 回の push/pop で処理 mpn_mul_basecase_mulx(r, a, an, b, bn); return; } #endif // 非 ASM パス: 小サイズ特化 — GCC/MSVC 共通 // ★ Codex 指摘 #2: SBO 9 limb の実運用で最頻出領域 #if defined(__GNUC__) && defined(__SIZEOF_INT128__) // GCC/Clang: __uint128_t で 1×1, 2×2 を展開 typedef unsigned __int128 u128; if (bn == 1 && an == 1) { u128 p = (u128)a[0] * b[0]; r[0] = (uint64_t)p; r[1] = (uint64_t)(p >> 64); return; } if (bn == 2 && an == 2) { u128 p00 = (u128)a[0] * b[0]; u128 p01 = (u128)a[0] * b[1]; u128 p10 = (u128)a[1] * b[0]; u128 p11 = (u128)a[1] * b[1]; r[0] = (uint64_t)p00; u128 mid = (p00 >> 64) + (uint64_t)p01 + (uint64_t)p10; r[1] = (uint64_t)mid; u128 hi = (mid >> 64) + (p01 >> 64) + (p10 >> 64) + (uint64_t)p11; r[2] = (uint64_t)hi; r[3] = (uint64_t)(hi >> 64) + (uint64_t)(p11 >> 64); return; } #elif defined(_MSC_VER) && defined(_M_X64) if (bn == 1) { if (an == 1) { uint64_t hi; r[0] = _umul128(a[0], b[0], &hi); r[1] = hi; return; } if (an == 2) { uint64_t h0, h1; r[0] = _umul128(a[0], b[0], &h0); r[1] = _umul128(a[1], b[0], &h1); unsigned char c = _addcarry_u64(0, r[1], h0, &r[1]); r[2] = h1 + c; return; } } #endif // 最初の行: mul_1 で直接書き込み (memset + addmul_1 より軽い) r[an] = mul_1(r, a, an, b[0]); // 残りの行: addmul_1 で累加 for (size_t j = 1; j < bn; j++) { if (b[j] == 0) { r[an + j] = 0; continue; } r[an + j] = addmul_1(r + j, a, an, b[j]); } } // Short multiplication: 積の上位 rn word のみ計算 // rp[0..rn-1] = top rn words of a[0..an-1] * b[0..bn-1] // rn は呼び出し側が 2 ガード word を含めた値を渡す // 要件: an >= 1, bn >= 1, rn >= 1, rn <= an+bn inline void mulhigh_basecase(uint64_t* rp, const uint64_t* ap, size_t an, const uint64_t* bp, size_t bn, size_t rn) { if (an < bn) { std::swap(ap, bp); std::swap(an, bn); } size_t total = an + bn; // rn >= total なら全積と同じ → 通常版にフォールバック if (rn >= total) { mul_basecase(rp, ap, an, bp, bn); return; } size_t lo = total - rn; // スキップする最下位列数 // 出力バッファをゼロ初期化 std::memset(rp, 0, rn * sizeof(uint64_t)); // 最初の寄与行を mul_1 で直接書き込み (addmul_1 より軽い) bool first_row = true; for (size_t j = 0; j < bn; j++) { if (bp[j] == 0) continue; // 行 j は列 j..j+an-1 に寄与 // lo 以上の列のみ計算 → a の開始位置を調整 size_t i_start = (j >= lo) ? 0 : (lo - j); if (i_start >= an) continue; // この行は寄与なし size_t count = an - i_start; size_t r_pos = j + i_start - lo; // 出力バッファ内の位置 uint64_t cy; if (first_row) { cy = mul_1(rp + r_pos, ap + i_start, count, bp[j]); first_row = false; } else { cy = addmul_1(rp + r_pos, ap + i_start, count, bp[j]); } // carry を伝播 size_t cy_pos = r_pos + count; if (cy != 0 && cy_pos < rn) { add_1(rp + cy_pos, rn - cy_pos, cy); } } } // ================================================================ // mulhigh_n: 上位 n limbs の高速計算 (積の下位半分をスキップ) // ================================================================ // mulhigh_n 用 scratch サイズ // multiply が前方宣言なので、この段階では保守的に確保 inline size_t mulhigh_n_scratch_size(size_t n); // mulhigh_n: a[0..n-1] * b[0..n-1] の上位 n limbs を計算 // rp[0..n-1] = floor(a * b / B^n) (近似値、最大 O(1) の誤差) // Newton 除算の商推定で使用。誤差は補正ループで吸収される。 // // アルゴリズム: Karatsuba 分解 a=a1*B^h+a0, b=b1*B^h+b0 で // a0*b0 をスキップし、a1*b1 と cross=(a0+a1)*(b0+b1)-a1*b1 の // 上位部分のみ計算。乗算コスト ≈ 2/3 * M(n)。 inline void mulhigh_n(uint64_t* rp, const uint64_t* ap, const uint64_t* bp, size_t n, uint64_t* scratch); // Karatsuba に必要な scratch サイズを計算 inline size_t mul_karatsuba_scratch_size(size_t n) { return (n < 16) ? 128 : 8 * n + 64; } // Karatsuba 乗算: r[0..an+bn-1] = a[0..an-1] * b[0..bn-1] // r は a, b と重なってはならない // scratch: 一時バッファ (サイズ >= mul_karatsuba_scratch_size(max(an,bn))) inline void mul_karatsuba(uint64_t* r, const uint64_t* a, size_t an, const uint64_t* b, size_t bn, uint64_t* scratch) { if (an < bn) { std::swap(a, b); std::swap(an, bn); } if (bn < KARATSUBA_THRESHOLD) { mul_basecase(r, a, an, b, bn); return; } size_t half = (an + 1) / 2; const uint64_t* a0 = a; size_t a0n = std::min(half, an); const uint64_t* a1 = a + half; size_t a1n = (an > half) ? an - half : 0; const uint64_t* b0 = b; size_t b0n = std::min(half, bn); const uint64_t* b1 = b + half; size_t b1n = (bn > half) ? bn - half : 0; a0n = normalized_size(a0, a0n); a1n = normalized_size(a1, a1n); b0n = normalized_size(b0, b0n); b1n = normalized_size(b1, b1n); size_t rn = an + bn; uint64_t* s = scratch; uint64_t* t = scratch + (half + 1); uint64_t* middle = scratch + 2 * (half + 1); size_t middle_max = 2 * (half + 2); uint64_t* rec_scratch = scratch + 2 * (half + 1) + middle_max; size_t v0n = 0, vinfn = 0; if (a0n > 0 && b0n > 0) { mul_karatsuba(r, a0, a0n, b0, b0n, rec_scratch); v0n = a0n + b0n; } if (a1n > 0 && b1n > 0) { mul_karatsuba(r + 2 * half, a1, a1n, b1, b1n, rec_scratch); vinfn = a1n + b1n; } // ギャップ領域のゼロクリア (v0 と vinf の間、vinf の後) if (v0n < 2 * half) std::memset(r + v0n, 0, (2 * half - v0n) * sizeof(uint64_t)); if (2 * half + vinfn < rn) std::memset(r + 2 * half + vinfn, 0, (rn - 2 * half - vinfn) * sizeof(uint64_t)); size_t sn, tn; if (a0n >= a1n) { if (a1n > 0) { uint64_t carry = add(s, a0, a0n, a1, a1n); sn = a0n; if (carry) { s[sn] = carry; sn++; } } else { std::memcpy(s, a0, a0n * sizeof(uint64_t)); sn = a0n; } } else { uint64_t carry = add(s, a1, a1n, a0, a0n); sn = a1n; if (carry) { s[sn] = carry; sn++; } } if (b0n >= b1n) { if (b1n > 0) { uint64_t carry = add(t, b0, b0n, b1, b1n); tn = b0n; if (carry) { t[tn] = carry; tn++; } } else { std::memcpy(t, b0, b0n * sizeof(uint64_t)); tn = b0n; } } else { uint64_t carry = add(t, b1, b1n, b0, b0n); tn = b1n; if (carry) { t[tn] = carry; tn++; } } size_t mn = 0; if (sn > 0 && tn > 0) { mul_karatsuba(middle, s, sn, t, tn, rec_scratch); mn = normalized_size(middle, sn + tn); } { size_t t0n = normalized_size(r, a0n + b0n); if (t0n > 0 && mn > 0) { sub(middle, middle, mn, r, t0n); mn = normalized_size(middle, mn); } } { // vinf の正規化サイズを取得。a1n + b1n が出力バッファ末尾 (rn - 2*half) を // 超える場合がある (bn < half のとき b1n=0 だが a1n > rn-2*half)。 // rn を超えた領域は未初期化なので読んではならない。 size_t vinf_max = std::min(a1n + b1n, rn - 2 * half); size_t t2n = normalized_size(r + 2 * half, vinf_max); if (t2n > 0 && mn > 0) { sub(middle, middle, mn, r + 2 * half, t2n); mn = normalized_size(middle, mn); } } if (mn > 0) { uint64_t carry = add(r + half, r + half, rn - half, middle, mn); (void)carry; } } // ================================================================ // Exact Division (生 limb 版) // ================================================================ // 3 での exact division: r[0..n-1] = a[0..n-1] / 3 inline size_t divexact_by3(uint64_t* r, const uint64_t* a, size_t n) { static constexpr uint64_t INV3 = 0xAAAAAAAAAAAAAAABULL; uint64_t carry = 0; for (size_t i = 0; i < n; i++) { uint64_t ai_minus_carry = a[i] - carry; uint64_t borrow = (a[i] < carry) ? 1ULL : 0ULL; uint64_t yi = ai_minus_carry * INV3; r[i] = yi; UInt128 prod = UInt128::multiply(yi, 3); carry = prod.high + borrow; } return normalized_size(r, n); } // 右シフト 3 ビット inline size_t shift_right_3(uint64_t* r, const uint64_t* a, size_t n) { if (n == 0) return 0; for (size_t i = 0; i < n - 1; i++) { r[i] = (a[i] >> 3) | (a[i + 1] << 61); } r[n - 1] = a[n - 1] >> 3; return normalized_size(r, n); } // 5 での exact division: r[0..n-1] = a[0..n-1] / 5 inline size_t divexact_by5(uint64_t* r, const uint64_t* a, size_t n) { static constexpr uint64_t INV5 = 0xCCCCCCCCCCCCCCCDULL; uint64_t carry = 0; for (size_t i = 0; i < n; i++) { uint64_t ai_minus_carry = a[i] - carry; uint64_t borrow = (a[i] < carry) ? 1ULL : 0ULL; uint64_t yi = ai_minus_carry * INV5; r[i] = yi; UInt128 prod = UInt128::multiply(yi, 5); carry = prod.high + borrow; } return normalized_size(r, n); } // 7 での exact division: r[0..n-1] = a[0..n-1] / 7 inline size_t divexact_by7(uint64_t* r, const uint64_t* a, size_t n) { static constexpr uint64_t INV7 = 0x6DB6DB6DB6DB6DB7ULL; uint64_t carry = 0; for (size_t i = 0; i < n; i++) { uint64_t ai_minus_carry = a[i] - carry; uint64_t borrow = (a[i] < carry) ? 1ULL : 0ULL; uint64_t yi = ai_minus_carry * INV7; r[i] = yi; UInt128 prod = UInt128::multiply(yi, 7); carry = prod.high + borrow; } return normalized_size(r, n); } // 24 での exact division: r = a / 24 = (a >> 3) / 3 inline size_t divexact_by24(uint64_t* r, const uint64_t* a, size_t n) { size_t sn = shift_right_3(r, a, n); return divexact_by3(r, r, sn); } // 右シフト 2 ビット inline size_t shift_right_2(uint64_t* r, const uint64_t* a, size_t n) { if (n == 0) return 0; for (size_t i = 0; i < n - 1; i++) { r[i] = (a[i] >> 2) | (a[i + 1] << 62); } r[n - 1] = a[n - 1] >> 2; return normalized_size(r, n); } // 12 での exact division: r = a / 12 = (a >> 2) / 3 inline size_t divexact_by12(uint64_t* r, const uint64_t* a, size_t n) { size_t sn = shift_right_2(r, a, n); return divexact_by3(r, r, sn); } // 120 での exact division: r = a / 120 = (a >> 3) / 3 / 5 inline size_t divexact_by120(uint64_t* r, const uint64_t* a, size_t n) { size_t sn = shift_right_3(r, a, n); sn = divexact_by3(r, r, sn); return divexact_by5(r, r, sn); } // ================================================================ // 左シフト (小ビット数) // ================================================================ // r = a << shift, returns overflow word inline uint64_t lshift(uint64_t* r, const uint64_t* a, size_t n, unsigned shift) { if (n == 0 || shift == 0) return 0; #ifdef CALX_INT_HAS_ASM if (shift == 1) { // shift=1 は a+a で最適化 (ADC チェーン) return mpn_add_n_asm(r, a, a, n); } // SHLD 命令による高速シフト (shift 2-63) return mpn_lshift_asm(r, a, n, shift); #elif defined(_MSC_VER) && defined(_M_X64) if (shift == 1) { unsigned char c = 0; for (size_t i = 0; i < n; i++) { c = _addcarry_u64(c, a[i], a[i], &r[i]); } return c; } uint64_t carry = 0; unsigned rs = 64 - shift; for (size_t i = 0; i < n; i++) { uint64_t v = a[i]; r[i] = (v << shift) | carry; carry = v >> rs; } return carry; #else uint64_t carry = 0; unsigned rs = 64 - shift; for (size_t i = 0; i < n; i++) { uint64_t v = a[i]; r[i] = (v << shift) | carry; carry = v >> rs; } return carry; #endif } // ================================================================ // 右シフト 1 ビット (除算 /2 用) // ================================================================ // r = a >> 1, returns normalized size // r と a は同じポインタでも可 (in-place) inline size_t rshift_1(uint64_t* r, const uint64_t* a, size_t n) { if (n == 0) return 0; #ifdef CALX_INT_HAS_ASM mpn_rshift_asm(r, a, n, 1); #else for (size_t i = 0; i < n - 1; i++) { r[i] = (a[i] >> 1) | (a[i + 1] << 63); } r[n - 1] = a[n - 1] >> 1; #endif return normalized_size(r, n); } // ================================================================ // 単一 limb 乗算 // ================================================================ // r = a * b (single limb), returns carry inline uint64_t mul_1(uint64_t* r, const uint64_t* a, size_t n, uint64_t b) { #ifdef CALX_INT_HAS_ASM if (detail::has_bmi2_adx()) { return mpn_mul_1_mulx(r, a, n, b); } #endif uint64_t carry = 0; #if defined(_MSC_VER) && defined(_M_X64) for (size_t i = 0; i < n; i++) { uint64_t hi; uint64_t lo = _umul128(a[i], b, &hi); unsigned char c = _addcarry_u64(0, lo, carry, &r[i]); carry = hi + c; } #elif defined(__SIZEOF_INT128__) __uint128_t cy = 0; for (size_t i = 0; i < n; i++) { cy += (__uint128_t)a[i] * b; r[i] = (uint64_t)cy; cy >>= 64; } return (uint64_t)cy; #else for (size_t i = 0; i < n; i++) { UInt128 prod = UInt128::multiply(a[i], b); uint64_t lo = prod.low + carry; uint64_t c = (lo < prod.low) ? 1ULL : 0ULL; r[i] = lo; carry = prod.high + c; } #endif return carry; } // ================================================================ // Toom-Cook-3 用ヘルパー // ================================================================ // r = a + b (サイズ順序不問) // r は max(an,bn)+1 limbs の領域が必要 // 正規化されたサイズを返す inline size_t add_any(uint64_t* r, const uint64_t* a, size_t an, const uint64_t* b, size_t bn) { if (an == 0) { if (bn > 0) std::memcpy(r, b, bn * sizeof(uint64_t)); return bn; } if (bn == 0) { if (an > 0) std::memcpy(r, a, an * sizeof(uint64_t)); return an; } if (an < bn) { std::swap(a, b); std::swap(an, bn); } uint64_t carry = add(r, a, an, b, bn); if (carry) { r[an] = carry; return an + 1; } return an; } // r = |a - b|, sign を設定 (+1: a>=b, -1: a 0) { sign = 1; sub(r, a, an, b, bn); return normalized_size(r, an); } else { sign = -1; sub(r, b, bn, a, an); return normalized_size(r, bn); } } // ================================================================ // Toom-Cook-3 乗算 // ================================================================ constexpr size_t TOOMCOOK3_THRESHOLD = 140; inline size_t mul_toomcook3_scratch_size(size_t n) { if (n < TOOMCOOK3_THRESHOLD) return mul_karatsuba_scratch_size(n); return 30 * n + 256; } // Toom-Cook-3 乗算: r[0..an+bn-1] = a[0..an-1] * b[0..bn-1] // 評価点 {0, 1, -1, 2, ∞} (GMP 方式) // 補間は /2 (シフト) と /3 (divexact_by3) のみ (/24 不要) // r は a, b と重なってはならない inline void mul_toomcook3(uint64_t* r, const uint64_t* a, size_t an, const uint64_t* b, size_t bn, uint64_t* scratch) { if (an < bn) { std::swap(a, b); std::swap(an, bn); } if (bn < TOOMCOOK3_THRESHOLD) { mul_karatsuba(r, a, an, b, bn, scratch); return; } size_t k = (an + 2) / 3; size_t rn = an + bn; // --- 3 分割 (ポインタ演算のみ、コピーなし) --- const uint64_t* a0 = a; size_t a0n = normalized_size(a0, std::min(k, an)); const uint64_t* a1 = a + k; size_t a1n = (an > k) ? normalized_size(a1, std::min(k, an - k)) : 0; const uint64_t* a2 = a + 2 * k; size_t a2n = (an > 2 * k) ? normalized_size(a2, an - 2 * k) : 0; const uint64_t* b0 = b; size_t b0n = normalized_size(b0, std::min(k, bn)); const uint64_t* b1 = b + k; size_t b1n = (bn > k) ? normalized_size(b1, std::min(k, bn - k)) : 0; const uint64_t* b2 = b + 2 * k; size_t b2n = (bn > 2 * k) ? normalized_size(b2, bn - 2 * k) : 0; // --- Scratch レイアウト --- size_t blk = 2 * (k + 4); // 各バッファの最大サイズ uint64_t* v1_buf = scratch; // v(1) の積 uint64_t* vm1_buf = scratch + blk; // v(-1) の積 uint64_t* v2_buf = scratch + 2 * blk; // v(2) の積 uint64_t* tmp1 = scratch + 3 * blk; // 評価テンポラリ 1 uint64_t* tmp2 = scratch + 4 * blk; // 評価テンポラリ 2 uint64_t* interp_buf = scratch + 5 * blk; // 補間ワークスペース uint64_t* rec_scratch = scratch + 6 * blk; // 再帰用 scratch size_t v1n = 0, vm1n = 0, v2n = 0; int vm1_sign = 0; // ============================================================ // 点ごとの乗算 (5 回の再帰呼び出し) // ============================================================ // Point 0: v0 = a0 * b0 → r[0..] size_t v0n = 0; if (a0n > 0 && b0n > 0) { mul_toomcook3(r, a0, a0n, b0, b0n, rec_scratch); v0n = normalized_size(r, std::min(a0n + b0n, rn)); } // Point ∞: vinf = a2 * b2 → r[4k..] size_t vinfn = 0; size_t vinf_off = 4 * k; if (a2n > 0 && b2n > 0 && vinf_off < rn) { mul_toomcook3(r + vinf_off, a2, a2n, b2, b2n, rec_scratch); vinfn = normalized_size(r + vinf_off, std::min(a2n + b2n, rn - vinf_off)); } // ギャップ領域のゼロクリア (v0 と vinf の間、vinf の後) { size_t gap_start = v0n; size_t gap_end = std::min(vinf_off, rn); if (gap_start < gap_end) std::memset(r + gap_start, 0, (gap_end - gap_start) * sizeof(uint64_t)); size_t tail_start = std::min(vinf_off + vinfn, rn); if (tail_start < rn) std::memset(r + tail_start, 0, (rn - tail_start) * sizeof(uint64_t)); } // Point 1: v1 = (a0+a1+a2) * (b0+b1+b2) { size_t ean = add_any(tmp1, a0, a0n, a1, a1n); ean = add_any(tmp1, tmp1, ean, a2, a2n); size_t ebn = add_any(tmp2, b0, b0n, b1, b1n); ebn = add_any(tmp2, tmp2, ebn, b2, b2n); if (ean > 0 && ebn > 0) { mul_toomcook3(v1_buf, tmp1, ean, tmp2, ebn, rec_scratch); v1n = normalized_size(v1_buf, ean + ebn); } } // Point -1: vm1 = (a0-a1+a2) * (b0-b1+b2) [符号付き] { int ea_sign = 1, eb_sign = 1; // ea = (a0 + a2) - a1 size_t t_n = add_any(tmp1, a0, a0n, a2, a2n); size_t ean = abs_sub(tmp1, ea_sign, tmp1, t_n, a1, a1n); // eb = (b0 + b2) - b1 t_n = add_any(tmp2, b0, b0n, b2, b2n); size_t ebn = abs_sub(tmp2, eb_sign, tmp2, t_n, b1, b1n); vm1_sign = ea_sign * eb_sign; if (ean > 0 && ebn > 0) { mul_toomcook3(vm1_buf, tmp1, ean, tmp2, ebn, rec_scratch); vm1n = normalized_size(vm1_buf, ean + ebn); } if (vm1n == 0) vm1_sign = 0; } // Point 2: v2 = (a0+2*a1+4*a2) * (b0+2*b1+4*b2) { // ea2 = a0 + 2*a1 + 4*a2 if (a0n > 0) std::memcpy(tmp1, a0, a0n * sizeof(uint64_t)); size_t ean = a0n; if (a1n > 0) { uint64_t ov = lshift(interp_buf, a1, a1n, 1); size_t tn = a1n; if (ov) { interp_buf[tn] = ov; tn++; } ean = add_any(tmp1, tmp1, ean, interp_buf, tn); } if (a2n > 0) { uint64_t ov = lshift(interp_buf, a2, a2n, 2); size_t tn = a2n; if (ov) { interp_buf[tn] = ov; tn++; } ean = add_any(tmp1, tmp1, ean, interp_buf, tn); } // eb2 = b0 + 2*b1 + 4*b2 if (b0n > 0) std::memcpy(tmp2, b0, b0n * sizeof(uint64_t)); size_t ebn = b0n; if (b1n > 0) { uint64_t ov = lshift(interp_buf, b1, b1n, 1); size_t tn = b1n; if (ov) { interp_buf[tn] = ov; tn++; } ebn = add_any(tmp2, tmp2, ebn, interp_buf, tn); } if (b2n > 0) { uint64_t ov = lshift(interp_buf, b2, b2n, 2); size_t tn = b2n; if (ov) { interp_buf[tn] = ov; tn++; } ebn = add_any(tmp2, tmp2, ebn, interp_buf, tn); } if (ean > 0 && ebn > 0) { mul_toomcook3(v2_buf, tmp1, ean, tmp2, ebn, rec_scratch); v2n = normalized_size(v2_buf, ean + ebn); } } // ============================================================ // 補間 (GMP 方式: /2 と /3 のみ) // ============================================================ // // v0 = c0, vinf = c4 (直接 r に配置済み) // v1 = c0+c1+c2+c3+c4 // vm1 = c0-c1+c2-c3+c4 (符号付き) // v2 = c0+2c1+4c2+8c3+16c4 // // Step 1: A = v1 + vm1 = 2(c0+c2+c4) [常に非負] // Step 2: B = v1 - vm1 = 2(c1+c3) [常に非負] // Step 3: c2 = A/2 - v0 - vinf // Step 4: C = B/2 = c1+c3 // Step 5: D = v2 - v0 - 4*c2 - 16*vinf = 2(c1+4c3) // Step 6: E = D/2 = c1+4c3 // Step 7: c3 = (E - C) / 3 // Step 8: c1 = C - c3 // Step 1: A = v1 ± vm1 → tmp1 size_t An; if (vm1_sign >= 0) { An = add_any(tmp1, v1_buf, v1n, vm1_buf, vm1n); } else { // vm1 < 0 → A = v1 - |vm1|, 保証: v1 >= |vm1| std::memcpy(tmp1, v1_buf, v1n * sizeof(uint64_t)); An = v1n; if (vm1n > 0) { sub(tmp1, tmp1, An, vm1_buf, vm1n); An = normalized_size(tmp1, An); } } // Step 2: B = v1 ∓ vm1 → tmp2 size_t Bn; if (vm1_sign >= 0) { // B = v1 - vm1, 保証: v1 >= vm1 std::memcpy(tmp2, v1_buf, v1n * sizeof(uint64_t)); Bn = v1n; if (vm1n > 0) { sub(tmp2, tmp2, Bn, vm1_buf, vm1n); Bn = normalized_size(tmp2, Bn); } } else { // vm1 < 0 → B = v1 + |vm1| Bn = add_any(tmp2, v1_buf, v1n, vm1_buf, vm1n); } // Step 3: c2 = A/2 - v0 - vinf → tmp1 if (An > 0) An = rshift_1(tmp1, tmp1, An); if (v0n > 0 && An > 0) { sub(tmp1, tmp1, An, r, v0n); An = normalized_size(tmp1, An); } if (vinfn > 0 && An > 0) { sub(tmp1, tmp1, An, r + 4 * k, vinfn); An = normalized_size(tmp1, An); } size_t c2n = An; uint64_t* c2_ptr = tmp1; // Step 4: C = B/2 → tmp2 if (Bn > 0) Bn = rshift_1(tmp2, tmp2, Bn); size_t Cn = Bn; // Step 5: D = v2 - v0 - 4*c2 - 16*vinf → interp_buf if (v2n > 0) std::memcpy(interp_buf, v2_buf, v2n * sizeof(uint64_t)); size_t Dn = v2n; // D -= v0 if (v0n > 0 && Dn > 0) { sub(interp_buf, interp_buf, Dn, r, v0n); Dn = normalized_size(interp_buf, Dn); } // D -= 4*c2 if (c2n > 0 && Dn > 0) { uint64_t ov = lshift(v1_buf, c2_ptr, c2n, 2); size_t tn = c2n; if (ov) { v1_buf[tn] = ov; tn++; } sub(interp_buf, interp_buf, Dn, v1_buf, tn); Dn = normalized_size(interp_buf, Dn); } // D -= 16*vinf if (vinfn > 0 && Dn > 0) { uint64_t ov = lshift(v1_buf, r + 4 * k, vinfn, 4); size_t tn = vinfn; if (ov) { v1_buf[tn] = ov; tn++; } sub(interp_buf, interp_buf, Dn, v1_buf, tn); Dn = normalized_size(interp_buf, Dn); } // Step 6: E = D/2 → interp_buf if (Dn > 0) Dn = rshift_1(interp_buf, interp_buf, Dn); // Step 7: c3 = (E - C) / 3 → vm1_buf size_t c3n = 0; if (Dn > 0 || Cn > 0) { if (Dn >= Cn) { if (Cn > 0) sub(v1_buf, interp_buf, Dn, tmp2, Cn); else std::memcpy(v1_buf, interp_buf, Dn * sizeof(uint64_t)); } else { // Dn < Cn: interp_buf を Cn limb に拡張(上位はゼロ) // sub の前提条件 an >= bn を満たすために Cn を使用 for (size_t i = Dn; i < Cn; ++i) interp_buf[i] = 0; sub(v1_buf, interp_buf, Cn, tmp2, Cn); } size_t Fn = normalized_size(v1_buf, std::max(Dn, Cn)); if (Fn > 0) c3n = divexact_by3(vm1_buf, v1_buf, Fn); } uint64_t* c3_ptr = vm1_buf; // Step 8: c1 = C - c3 → v2_buf size_t c1n = 0; if (Cn > 0) { if (c3n > 0) { sub(v2_buf, tmp2, Cn, c3_ptr, c3n); c1n = normalized_size(v2_buf, Cn); } else { std::memcpy(v2_buf, tmp2, Cn * sizeof(uint64_t)); c1n = Cn; } } uint64_t* c1_ptr = v2_buf; // ============================================================ // 組み立て: r += c1*B^k + c2*B^(2k) + c3*B^(3k) // ============================================================ // c0 は r[0..] に、c4 は r[4k..] に配置済み if (c1n > 0 && k < rn) { size_t space = rn - k; uint64_t carry = add(r + k, r + k, space, c1_ptr, std::min(c1n, space)); (void)carry; } if (c2n > 0 && 2 * k < rn) { size_t space = rn - 2 * k; uint64_t carry = add(r + 2 * k, r + 2 * k, space, c2_ptr, std::min(c2n, space)); (void)carry; } if (c3n > 0 && 3 * k < rn) { size_t space = rn - 3 * k; uint64_t carry = add(r + 3 * k, r + 3 * k, space, c3_ptr, std::min(c3n, space)); (void)carry; } } // ================================================================ // Toom-Cook-4 乗算 // ================================================================ // Toom-4: 200+ limbs で Toom-3 と同等〜10% 高速 (2026-03-11 sweep, 200-1400 limbs)。 // 600 limbs 付近で最大 10% の優位。200 未満は Toom-3 の方が軽い。 constexpr size_t TOOMCOOK4_THRESHOLD = 200; inline size_t mul_toomcook4_scratch_size(size_t n) { if (n < TOOMCOOK4_THRESHOLD) return mul_toomcook3_scratch_size(n); return 50 * n + 512; } // Toom-Cook-4 乗算: r[0..an+bn-1] = a[0..an-1] * b[0..bn-1] // 評価点 {0, 1, -1, 2, -2, 3, ∞} // 補間は /2 (シフト), /3 (divexact_by3), /5 (divexact_by5), /12, /120 のみ // r は a, b と重なってはならない // // 最適化: 固定 k-limb 評価 + addmul_1 で一時バッファと normalized_size を削減 inline void mul_toomcook4(uint64_t* r, const uint64_t* a, size_t an, const uint64_t* b, size_t bn, uint64_t* scratch) { if (an < bn) { std::swap(a, b); std::swap(an, bn); } if (bn < TOOMCOOK4_THRESHOLD) { mul_toomcook3(r, a, an, b, bn, scratch); return; } size_t k = (an + 3) / 4; size_t rn = an + bn; // --- 係数を k limbs にゼロパディング --- // addmul_1 ベースの固定サイズ評価を可能にする uint64_t* pa = scratch; // 4*k limbs (a の係数) uint64_t* pb = scratch + 4 * k; // 4*k limbs (b の係数) auto pad_coeff = [k](uint64_t* dst, const uint64_t* src, size_t src_len, size_t offset) { size_t actual = (src_len > offset) ? std::min(k, src_len - offset) : 0; if (actual > 0) std::memcpy(dst, src + offset, actual * sizeof(uint64_t)); if (actual < k) std::memset(dst + actual, 0, (k - actual) * sizeof(uint64_t)); }; for (size_t i = 0; i < 4; i++) pad_coeff(pa + i * k, a, an, i * k); for (size_t i = 0; i < 4; i++) pad_coeff(pb + i * k, b, bn, i * k); uint64_t* pa0 = pa; uint64_t* pa1 = pa + k; uint64_t* pa2 = pa + 2 * k; uint64_t* pa3 = pa + 3 * k; uint64_t* pb0 = pb; uint64_t* pb1 = pb + k; uint64_t* pb2 = pb + 2 * k; uint64_t* pb3 = pb + 3 * k; // --- Scratch レイアウト (パディング後) --- size_t blk = 2 * k + 4; // 各積バッファ (最大 2k+2 limbs + 余裕) uint64_t* w1_buf = scratch + 8 * k; uint64_t* wm1_buf = w1_buf + blk; uint64_t* w2_buf = wm1_buf + blk; uint64_t* wm2_buf = w2_buf + blk; uint64_t* w3_buf = wm2_buf + blk; uint64_t* ea_even = w3_buf + blk; // k+2 limbs (a 偶数項) uint64_t* ea_odd = ea_even + k + 2; // k+2 limbs (a 奇数項) uint64_t* eb_even = ea_odd + k + 2; // k+2 limbs (b 偶数項) uint64_t* eb_odd = eb_even + k + 2; // k+2 limbs (b 奇数項) uint64_t* eval_tmp = eb_odd + k + 2; // k+2 limbs (評価結果) uint64_t* rec_scratch = eval_tmp + k + 2; size_t w1n = 0, wm1n = 0, w2n = 0, wm2n = 0, w3n = 0; int wm1_sign = 0, wm2_sign = 0; // ============================================================ // 点ごとの乗算 (7 回の再帰呼び出し) // 評価: addmul_1 で固定 k-limb 一括処理 // ============================================================ // Point 0: w0 = a0 * b0 → r[0..] (実サイズで呼び出し) size_t a0n_real = std::min(k, an); size_t b0n_real = std::min(k, bn); size_t w0n = 0; if (a0n_real > 0 && b0n_real > 0) { mul_toomcook4(r, a, a0n_real, b, b0n_real, rec_scratch); w0n = normalized_size(r, std::min(a0n_real + b0n_real, rn)); } // Point ∞: winf = a3 * b3 → r[6k..] (実サイズで呼び出し) size_t a3n_real = (an > 3 * k) ? an - 3 * k : 0; size_t b3n_real = (bn > 3 * k) ? bn - 3 * k : 0; size_t winfn = 0; size_t winf_off = 6 * k; if (a3n_real > 0 && b3n_real > 0 && winf_off < rn) { mul_toomcook4(r + winf_off, a + 3 * k, a3n_real, b + 3 * k, b3n_real, rec_scratch); winfn = normalized_size(r + winf_off, std::min(a3n_real + b3n_real, rn - winf_off)); } // ギャップ領域のゼロクリア { size_t gap_start = w0n; size_t gap_end = std::min(winf_off, rn); if (gap_start < gap_end) std::memset(r + gap_start, 0, (gap_end - gap_start) * sizeof(uint64_t)); size_t tail_start = std::min(winf_off + winfn, rn); if (tail_start < rn) std::memset(r + tail_start, 0, (rn - tail_start) * sizeof(uint64_t)); } // ±1 共有: even1 = a0+a2, odd1 = a1+a3 // a(1) = even1 + odd1, a(-1) = |even1 - odd1| { // a 側: even1/odd1 std::memcpy(ea_even, pa0, k * sizeof(uint64_t)); ea_even[k] = add(ea_even, ea_even, k, pa2, k); std::memcpy(ea_odd, pa1, k * sizeof(uint64_t)); ea_odd[k] = add(ea_odd, ea_odd, k, pa3, k); // b 側: even1/odd1 std::memcpy(eb_even, pb0, k * sizeof(uint64_t)); eb_even[k] = add(eb_even, eb_even, k, pb2, k); std::memcpy(eb_odd, pb1, k * sizeof(uint64_t)); eb_odd[k] = add(eb_odd, eb_odd, k, pb3, k); } // Point 1: w1 = a(1) * b(1) = (even1+odd1) * (even1+odd1) { // a(1) → eval_tmp std::memcpy(eval_tmp, ea_even, (k + 1) * sizeof(uint64_t)); eval_tmp[k + 1] = 0; uint64_t cy = add(eval_tmp, eval_tmp, k + 1, ea_odd, k + 1); if (cy) eval_tmp[k + 1] = cy; size_t ean = (eval_tmp[k + 1] ? k + 2 : (eval_tmp[k] ? k + 1 : normalized_size(eval_tmp, k))); // b(1) → wm1_buf (一時利用、Point -1 で上書きされる) std::memcpy(wm1_buf, eb_even, (k + 1) * sizeof(uint64_t)); wm1_buf[k + 1] = 0; cy = add(wm1_buf, wm1_buf, k + 1, eb_odd, k + 1); if (cy) wm1_buf[k + 1] = cy; size_t ebn = (wm1_buf[k + 1] ? k + 2 : (wm1_buf[k] ? k + 1 : normalized_size(wm1_buf, k))); if (ean > 0 && ebn > 0) { mul_toomcook4(w1_buf, eval_tmp, ean, wm1_buf, ebn, rec_scratch); w1n = normalized_size(w1_buf, ean + ebn); } } // Point -1: wm1 = a(-1) * b(-1) = |even1-odd1| * |even1-odd1| { int ea_sign = 1, eb_sign = 1; // a(-1) = |ea_even - ea_odd| → eval_tmp int c = cmp(ea_even, k + 1, ea_odd, k + 1); if (c >= 0) { sub(eval_tmp, ea_even, k + 1, ea_odd, k + 1); } else { sub(eval_tmp, ea_odd, k + 1, ea_even, k + 1); ea_sign = -1; } size_t ean = normalized_size(eval_tmp, k + 1); // b(-1) = |eb_even - eb_odd| → w2_buf (一時利用) c = cmp(eb_even, k + 1, eb_odd, k + 1); if (c >= 0) { sub(w2_buf, eb_even, k + 1, eb_odd, k + 1); } else { sub(w2_buf, eb_odd, k + 1, eb_even, k + 1); eb_sign = -1; } size_t ebn = normalized_size(w2_buf, k + 1); wm1_sign = ea_sign * eb_sign; if (ean > 0 && ebn > 0) { mul_toomcook4(wm1_buf, eval_tmp, ean, w2_buf, ebn, rec_scratch); wm1n = normalized_size(wm1_buf, ean + ebn); } if (wm1n == 0) wm1_sign = 0; } // ±2 共有: even2 = a0+4*a2, odd2 = 2*a1+8*a3 // a(2) = even2 + odd2, a(-2) = |even2 - odd2| { // a 側: even2/odd2 → ea_even/ea_odd を再利用 std::memcpy(ea_even, pa0, k * sizeof(uint64_t)); ea_even[k] = addmul_1(ea_even, pa2, k, 4); ea_odd[k] = mul_1(ea_odd, pa1, k, 2); ea_odd[k] += addmul_1(ea_odd, pa3, k, 8); // b 側: even2/odd2 → eb_even/eb_odd を再利用 std::memcpy(eb_even, pb0, k * sizeof(uint64_t)); eb_even[k] = addmul_1(eb_even, pb2, k, 4); eb_odd[k] = mul_1(eb_odd, pb1, k, 2); eb_odd[k] += addmul_1(eb_odd, pb3, k, 8); } // Point 2: w2 = a(2) * b(2) = (even2+odd2) * (even2+odd2) { // a(2) → eval_tmp std::memcpy(eval_tmp, ea_even, (k + 1) * sizeof(uint64_t)); eval_tmp[k + 1] = 0; uint64_t cy = add(eval_tmp, eval_tmp, k + 1, ea_odd, k + 1); if (cy) eval_tmp[k + 1] = cy; size_t ean = (eval_tmp[k + 1] ? k + 2 : (eval_tmp[k] ? k + 1 : normalized_size(eval_tmp, k))); // b(2) → wm2_buf (一時利用、Point -2 で上書きされる) std::memcpy(wm2_buf, eb_even, (k + 1) * sizeof(uint64_t)); wm2_buf[k + 1] = 0; cy = add(wm2_buf, wm2_buf, k + 1, eb_odd, k + 1); if (cy) wm2_buf[k + 1] = cy; size_t ebn = (wm2_buf[k + 1] ? k + 2 : (wm2_buf[k] ? k + 1 : normalized_size(wm2_buf, k))); if (ean > 0 && ebn > 0) { mul_toomcook4(w2_buf, eval_tmp, ean, wm2_buf, ebn, rec_scratch); w2n = normalized_size(w2_buf, ean + ebn); } } // Point -2: wm2 = a(-2) * b(-2) = |even2-odd2| * |even2-odd2| { int ea_sign = 1, eb_sign = 1; // a(-2) = |ea_even - ea_odd| → eval_tmp int c = cmp(ea_even, k + 1, ea_odd, k + 1); if (c >= 0) { sub(eval_tmp, ea_even, k + 1, ea_odd, k + 1); } else { sub(eval_tmp, ea_odd, k + 1, ea_even, k + 1); ea_sign = -1; } size_t ean = normalized_size(eval_tmp, k + 1); // b(-2) = |eb_even - eb_odd| → w3_buf (一時利用) c = cmp(eb_even, k + 1, eb_odd, k + 1); if (c >= 0) { sub(w3_buf, eb_even, k + 1, eb_odd, k + 1); } else { sub(w3_buf, eb_odd, k + 1, eb_even, k + 1); eb_sign = -1; } size_t ebn = normalized_size(w3_buf, k + 1); wm2_sign = ea_sign * eb_sign; if (ean > 0 && ebn > 0) { mul_toomcook4(wm2_buf, eval_tmp, ean, w3_buf, ebn, rec_scratch); wm2n = normalized_size(wm2_buf, ean + ebn); } if (wm2n == 0) wm2_sign = 0; } // Point 3: w3 = a(3) * b(3) // a(3) = a0 + 3*a1 + 9*a2 + 27*a3 [addmul_1] { std::memcpy(eval_tmp, pa0, k * sizeof(uint64_t)); eval_tmp[k] = 0; eval_tmp[k] += addmul_1(eval_tmp, pa1, k, 3); eval_tmp[k] += addmul_1(eval_tmp, pa2, k, 9); eval_tmp[k] += addmul_1(eval_tmp, pa3, k, 27); size_t ean = k + (eval_tmp[k] ? 1 : 0); // b(3) → ea_even (一時利用、以降 ea_even は不要) std::memcpy(ea_even, pb0, k * sizeof(uint64_t)); ea_even[k] = 0; ea_even[k] += addmul_1(ea_even, pb1, k, 3); ea_even[k] += addmul_1(ea_even, pb2, k, 9); ea_even[k] += addmul_1(ea_even, pb3, k, 27); size_t ebn = k + (ea_even[k] ? 1 : 0); if (ean > 0 && ebn > 0) { mul_toomcook4(w3_buf, eval_tmp, ean, ea_even, ebn, rec_scratch); w3n = normalized_size(w3_buf, ean + ebn); } } // ============================================================ // 補間 // ============================================================ // 補間用テンポラリ: rec_scratch 領域を再利用 (評価の再帰完了後は不要) // 各バッファは最大 2k+4 limbs 必要 uint64_t* tmp1 = rec_scratch; uint64_t* tmp2 = rec_scratch + blk; uint64_t* interp_buf = rec_scratch + 2 * blk; uint64_t* interp_buf2 = rec_scratch + 3 * blk; // // c0 = w0 (r[0..] に配置済み) // c6 = winf (r[6k..] に配置済み) // // Step 1-4: ±1, ±2 の対称性を利用して偶奇分離 // t1 = (w1 + wm1) / 2 = c0 + c2 + c4 + c6 // t2 = (w1 - wm1) / 2 = c1 + c3 + c5 // t3 = (w2 + wm2) / 2 = c0 + 4c2 + 16c4 + 64c6 // t4 = (w2 - wm2) / 2 = 2c1 + 8c3 + 32c5 // Step 1: t1 = w1 ± wm1 → tmp1 size_t t1n; if (wm1_sign >= 0) { t1n = add_any(tmp1, w1_buf, w1n, wm1_buf, wm1n); } else { std::memcpy(tmp1, w1_buf, w1n * sizeof(uint64_t)); t1n = w1n; if (wm1n > 0) { sub(tmp1, tmp1, t1n, wm1_buf, wm1n); t1n = normalized_size(tmp1, t1n); } } // t1 /= 2 if (t1n > 0) t1n = rshift_1(tmp1, tmp1, t1n); // Step 2: t2 = w1 ∓ wm1 → tmp2 size_t t2n; if (wm1_sign >= 0) { std::memcpy(tmp2, w1_buf, w1n * sizeof(uint64_t)); t2n = w1n; if (wm1n > 0) { sub(tmp2, tmp2, t2n, wm1_buf, wm1n); t2n = normalized_size(tmp2, t2n); } } else { t2n = add_any(tmp2, w1_buf, w1n, wm1_buf, wm1n); } // t2 /= 2 if (t2n > 0) t2n = rshift_1(tmp2, tmp2, t2n); // Step 3: t3 = w2 ± wm2 → interp_buf size_t t3n; if (wm2_sign >= 0) { t3n = add_any(interp_buf, w2_buf, w2n, wm2_buf, wm2n); } else { std::memcpy(interp_buf, w2_buf, w2n * sizeof(uint64_t)); t3n = w2n; if (wm2n > 0) { sub(interp_buf, interp_buf, t3n, wm2_buf, wm2n); t3n = normalized_size(interp_buf, t3n); } } // t3 /= 2 if (t3n > 0) t3n = rshift_1(interp_buf, interp_buf, t3n); // Step 4: t4 = w2 ∓ wm2 → interp_buf2 size_t t4n; if (wm2_sign >= 0) { std::memcpy(interp_buf2, w2_buf, w2n * sizeof(uint64_t)); t4n = w2n; if (wm2n > 0) { sub(interp_buf2, interp_buf2, t4n, wm2_buf, wm2n); t4n = normalized_size(interp_buf2, t4n); } } else { t4n = add_any(interp_buf2, w2_buf, w2n, wm2_buf, wm2n); } // t4 /= 2 if (t4n > 0) t4n = rshift_1(interp_buf2, interp_buf2, t4n); // ここまでの変数配置: // tmp1 = t1 (c0+c2+c4+c6) // tmp2 = t2 (c1+c3+c5) // interp_buf = t3 (c0+4c2+16c4+64c6) // interp_buf2= t4 (2c1+8c3+32c5) // w1_buf, wm1_buf, w2_buf, wm2_buf は再利用可能 // Step 9: t5 = t1 - w0 - winf = c2 + c4 → w1_buf if (t1n > 0) std::memcpy(w1_buf, tmp1, t1n * sizeof(uint64_t)); size_t t5n = t1n; if (w0n > 0 && t5n > 0) { sub(w1_buf, w1_buf, t5n, r, w0n); t5n = normalized_size(w1_buf, t5n); } if (winfn > 0 && t5n > 0) { sub(w1_buf, w1_buf, t5n, r + 6 * k, winfn); t5n = normalized_size(w1_buf, t5n); } // Step 10: t6 = t3 - w0 - 64*winf = 4c2 + 16c4 → wm1_buf if (t3n > 0) std::memcpy(wm1_buf, interp_buf, t3n * sizeof(uint64_t)); size_t t6n = t3n; if (w0n > 0 && t6n > 0) { sub(wm1_buf, wm1_buf, t6n, r, w0n); t6n = normalized_size(wm1_buf, t6n); } if (winfn > 0 && t6n > 0) { uint64_t bw = submul_1(wm1_buf, r + winf_off, winfn, 64); if (bw > 0 && t6n > winfn) sub_1(wm1_buf + winfn, t6n - winfn, bw); t6n = normalized_size(wm1_buf, t6n); } // Step 11: t6 = t6 - 4*t5 = 12*c4 → wm1_buf if (t5n > 0 && t6n > 0) { uint64_t bw = submul_1(wm1_buf, w1_buf, t5n, 4); if (bw > 0 && t6n > t5n) sub_1(wm1_buf + t5n, t6n - t5n, bw); t6n = normalized_size(wm1_buf, t6n); } // Step 12: c4 = t6 / 12 → wm2_buf size_t c4n = 0; if (t6n > 0) { c4n = divexact_by12(wm2_buf, wm1_buf, t6n); } uint64_t* c4_ptr = wm2_buf; // Step 13: c2 = t5 - c4 → w1_buf (t5 already there) size_t c2n = t5n; if (c4n > 0 && c2n > 0) { sub(w1_buf, w1_buf, c2n, c4_ptr, c4n); c2n = normalized_size(w1_buf, c2n); } uint64_t* c2_ptr = w1_buf; // Step 14: t8 = t4 - 2*t2 = 6c3+30c5 → wm1_buf if (t4n > 0) std::memcpy(wm1_buf, interp_buf2, t4n * sizeof(uint64_t)); size_t t8n = t4n; if (t2n > 0 && t8n > 0) { uint64_t bw = submul_1(wm1_buf, tmp2, t2n, 2); if (bw > 0 && t8n > t2n) sub_1(wm1_buf + t2n, t8n - t2n, bw); t8n = normalized_size(wm1_buf, t8n); } // Step 15: t9 = w3 - w0 - 9*c2 - 81*c4 - 729*winf → interp_buf if (w3n > 0) std::memcpy(interp_buf, w3_buf, w3n * sizeof(uint64_t)); size_t t9n = w3n; // t9 -= w0 if (w0n > 0 && t9n > 0) { sub(interp_buf, interp_buf, t9n, r, w0n); t9n = normalized_size(interp_buf, t9n); } // t9 -= 9*c2 [submul_1: 1パス] if (c2n > 0 && t9n > 0) { uint64_t bw = submul_1(interp_buf, c2_ptr, c2n, 9); if (bw > 0 && t9n > c2n) sub_1(interp_buf + c2n, t9n - c2n, bw); t9n = normalized_size(interp_buf, t9n); } // t9 -= 81*c4 [submul_1: 1パス] if (c4n > 0 && t9n > 0) { uint64_t bw = submul_1(interp_buf, c4_ptr, c4n, 81); if (bw > 0 && t9n > c4n) sub_1(interp_buf + c4n, t9n - c4n, bw); t9n = normalized_size(interp_buf, t9n); } // t9 -= 729*winf [submul_1: 1パス] if (winfn > 0 && t9n > 0) { uint64_t bw = submul_1(interp_buf, r + winf_off, winfn, 729); if (bw > 0 && t9n > winfn) sub_1(interp_buf + winfn, t9n - winfn, bw); t9n = normalized_size(interp_buf, t9n); } // Step 16: t9 -= 3*t2 [submul_1: 1パス] if (t2n > 0 && t9n > 0) { uint64_t bw = submul_1(interp_buf, tmp2, t2n, 3); if (bw > 0 && t9n > t2n) sub_1(interp_buf + t2n, t9n - t2n, bw); t9n = normalized_size(interp_buf, t9n); } // Step 17: t9 -= 4*t8 [submul_1: 1パス] if (t8n > 0 && t9n > 0) { uint64_t bw = submul_1(interp_buf, wm1_buf, t8n, 4); if (bw > 0 && t9n > t8n) sub_1(interp_buf + t8n, t9n - t8n, bw); t9n = normalized_size(interp_buf, t9n); } // Step 18: c5 = t9 / 120 → w3_buf size_t c5n = 0; if (t9n > 0) { c5n = divexact_by120(w3_buf, interp_buf, t9n); } uint64_t* c5_ptr = w3_buf; // Step 19: t8 = t8 - 30*c5 = 6*c3 → wm1_buf [submul_1: 1パス] if (c5n > 0 && t8n > 0) { uint64_t bw = submul_1(wm1_buf, c5_ptr, c5n, 30); if (bw > 0 && t8n > c5n) sub_1(wm1_buf + c5n, t8n - c5n, bw); t8n = normalized_size(wm1_buf, t8n); } // Step 20: c3 = t8 / 6 = (t8 >> 1) / 3 → interp_buf size_t c3n = 0; if (t8n > 0) { size_t sn = rshift_1(wm1_buf, wm1_buf, t8n); // /2 in-place c3n = divexact_by3(interp_buf, wm1_buf, sn); // /3 } uint64_t* c3_ptr = interp_buf; // Step 21: c1 = t2 - c3 - c5 → interp_buf2 size_t c1n = t2n; if (c1n > 0) { std::memcpy(interp_buf2, tmp2, t2n * sizeof(uint64_t)); if (c3n > 0) { sub(interp_buf2, interp_buf2, c1n, c3_ptr, c3n); c1n = normalized_size(interp_buf2, c1n); } if (c5n > 0 && c1n > 0) { sub(interp_buf2, interp_buf2, c1n, c5_ptr, c5n); c1n = normalized_size(interp_buf2, c1n); } } uint64_t* c1_ptr = interp_buf2; // ============================================================ // 組み立て: r += c1*B^k + c2*B^(2k) + c3*B^(3k) + c4*B^(4k) + c5*B^(5k) // ============================================================ // c0 は r[0..] に、c6 は r[6k..] に配置済み if (c1n > 0 && k < rn) { size_t space = rn - k; uint64_t carry = add(r + k, r + k, space, c1_ptr, std::min(c1n, space)); (void)carry; } if (c2n > 0 && 2 * k < rn) { size_t space = rn - 2 * k; uint64_t carry = add(r + 2 * k, r + 2 * k, space, c2_ptr, std::min(c2n, space)); (void)carry; } if (c3n > 0 && 3 * k < rn) { size_t space = rn - 3 * k; uint64_t carry = add(r + 3 * k, r + 3 * k, space, c3_ptr, std::min(c3n, space)); (void)carry; } if (c4n > 0 && 4 * k < rn) { size_t space = rn - 4 * k; uint64_t carry = add(r + 4 * k, r + 4 * k, space, c4_ptr, std::min(c4n, space)); (void)carry; } if (c5n > 0 && 5 * k < rn) { size_t space = rn - 5 * k; uint64_t carry = add(r + 5 * k, r + 5 * k, space, c5_ptr, std::min(c5n, space)); (void)carry; } } // ================================================================ // Toom-Cook-6 乗算 // ================================================================ // 6 分割、評価点 {0, ±1, ±2, ±3, ±4, ±5, ∞}、12 回の再帰乗算 // 積の次数 10 → 11 係数 (c0..c10) // ± 対称性で偶奇分離し、偶 4x4 + 奇 5x5 の Vandermonde 補間 // Toom-6: Toom-4 より 200-2500 limbs で 1-7% 高速 (Zen 3 実測) // ★ 以前 TOOM4==TOOM6==200 で Toom-4 分岐が死んでいた (Codex 指摘) constexpr size_t TOOMCOOK6_THRESHOLD = 600; // Toom-4: 200-599, Toom-6: 600+ inline size_t mul_toomcook6_scratch_size(size_t n) { if (n < TOOMCOOK6_THRESHOLD) return mul_toomcook4_scratch_size(n); return 80 * n + 1024; } inline void mul_toomcook6(uint64_t* r, const uint64_t* a, size_t an, const uint64_t* b, size_t bn, uint64_t* scratch) { if (an < bn) { std::swap(a, b); std::swap(an, bn); } if (bn < TOOMCOOK6_THRESHOLD) { mul_toomcook4(r, a, an, b, bn, scratch); return; } size_t k = (an + 5) / 6; size_t rn = an + bn; // --- 6 分割 --- const uint64_t* a0 = a; size_t a0n = normalized_size(a0, std::min(k, an)); const uint64_t* a1 = a + k; size_t a1n = (an > k) ? normalized_size(a1, std::min(k, an - k)) : 0; const uint64_t* a2 = a + 2 * k; size_t a2n = (an > 2 * k) ? normalized_size(a2, std::min(k, an - 2 * k)) : 0; const uint64_t* a3 = a + 3 * k; size_t a3n = (an > 3 * k) ? normalized_size(a3, std::min(k, an - 3 * k)) : 0; const uint64_t* a4 = a + 4 * k; size_t a4n = (an > 4 * k) ? normalized_size(a4, std::min(k, an - 4 * k)) : 0; const uint64_t* a5 = a + 5 * k; size_t a5n = (an > 5 * k) ? normalized_size(a5, an - 5 * k) : 0; const uint64_t* b0 = b; size_t b0n = normalized_size(b0, std::min(k, bn)); const uint64_t* b1 = b + k; size_t b1n = (bn > k) ? normalized_size(b1, std::min(k, bn - k)) : 0; const uint64_t* b2 = b + 2 * k; size_t b2n = (bn > 2 * k) ? normalized_size(b2, std::min(k, bn - 2 * k)) : 0; const uint64_t* b3 = b + 3 * k; size_t b3n = (bn > 3 * k) ? normalized_size(b3, std::min(k, bn - 3 * k)) : 0; const uint64_t* b4 = b + 4 * k; size_t b4n = (bn > 4 * k) ? normalized_size(b4, std::min(k, bn - 4 * k)) : 0; const uint64_t* b5 = b + 5 * k; size_t b5n = (bn > 5 * k) ? normalized_size(b5, bn - 5 * k) : 0; // --- Scratch レイアウト --- // 各 ±k の積バッファ + 評価テンポラリ + 補間ワーク + 再帰 scratch size_t blk = 2 * (k + 10); uint64_t* w1_buf = scratch; // v(1) の積 uint64_t* wm1_buf = scratch + blk; // v(-1) の積 uint64_t* w2_buf = scratch + 2 * blk; // v(2) の積 uint64_t* wm2_buf = scratch + 3 * blk; // v(-2) の積 uint64_t* w3_buf = scratch + 4 * blk; // v(3) の積 uint64_t* wm3_buf = scratch + 5 * blk; // v(-3) の積 uint64_t* w4_buf = scratch + 6 * blk; // v(4) の積 uint64_t* wm4_buf = scratch + 7 * blk; // v(-4) の積 uint64_t* w5_buf = scratch + 8 * blk; // v(5) の積 uint64_t* wm5_buf = scratch + 9 * blk; // v(-5) の積 uint64_t* tmp1 = scratch + 10 * blk; // 評価テンポラリ 1 uint64_t* tmp2 = scratch + 11 * blk; // 評価テンポラリ 2 uint64_t* tmp3 = scratch + 12 * blk; // 補間ワーク uint64_t* tmp4 = scratch + 13 * blk; // 補間ワーク 2 uint64_t* tmp5 = scratch + 14 * blk; // 補間ワーク 3 uint64_t* rec_scratch = scratch + 15 * blk; // 再帰用 scratch size_t w1n = 0, wm1n = 0, w2n = 0, wm2n = 0; size_t w3n = 0, wm3n = 0, w4n = 0, wm4n = 0; size_t w5n = 0, wm5n = 0; int wm1_sign = 0, wm2_sign = 0, wm3_sign = 0, wm4_sign = 0, wm5_sign = 0; // ============================================================ // 評価・乗算 (12 回の再帰呼び出し) // ============================================================ // --- Helper: 6 項の Horner 評価 a(t) = a0 + t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5)))) --- // 正の t のみ。|t| で偶奇分離する場合はインライン記述。 // Point 0: v0 = a0 * b0 → r[0..] size_t w0n = 0; if (a0n > 0 && b0n > 0) { mul_toomcook6(r, a0, a0n, b0, b0n, rec_scratch); w0n = normalized_size(r, std::min(a0n + b0n, rn)); } // Point ∞: vinf = a5 * b5 → r[10k..] size_t winfn = 0; size_t winf_off = 10 * k; if (a5n > 0 && b5n > 0 && winf_off < rn) { mul_toomcook6(r + winf_off, a5, a5n, b5, b5n, rec_scratch); winfn = normalized_size(r + winf_off, std::min(a5n + b5n, rn - winf_off)); } // ギャップ領域のゼロクリア { size_t gap_start = w0n; size_t gap_end = std::min(winf_off, rn); if (gap_start < gap_end) std::memset(r + gap_start, 0, (gap_end - gap_start) * sizeof(uint64_t)); size_t tail_start = std::min(winf_off + winfn, rn); if (tail_start < rn) std::memset(r + tail_start, 0, (rn - tail_start) * sizeof(uint64_t)); } // ±1, ±2, ±3, ±4, ±5 の評価は偶奇分離パターンを使う // a(t) = (a0 + a2*t² + a4*t⁴) + t*(a1 + a3*t² + a5*t⁴) // a(-t) = (a0 + a2*t² + a4*t⁴) - t*(a1 + a3*t² + a5*t⁴) // --- Helper lambda: 偶部 a_even(t²) = a0 + a2*t² + a4*t⁴ を tmp に計算 --- // --- Helper lambda: 奇部 a_odd(t²) = a1 + a3*t² + a5*t⁴ を tmp に計算 --- // t=1,2,3,4,5 についてインライン展開する // Point ±1: a(±1) = (a0+a2+a4) ± (a1+a3+a5) { // a_even = a0 + a2 + a4 size_t aen = add_any(tmp1, a0, a0n, a2, a2n); aen = add_any(tmp1, tmp1, aen, a4, a4n); // a_odd = a1 + a3 + a5 size_t aon = add_any(tmp3, a1, a1n, a3, a3n); aon = add_any(tmp3, tmp3, aon, a5, a5n); // b_even = b0 + b2 + b4 size_t ben = add_any(tmp2, b0, b0n, b2, b2n); ben = add_any(tmp2, tmp2, ben, b4, b4n); // b_odd = b1 + b3 + b5 size_t bon = add_any(tmp4, b1, b1n, b3, b3n); bon = add_any(tmp4, tmp4, bon, b5, b5n); // a(1) = a_even + a_odd, a(-1) = a_even - a_odd // 評価値を tmp5 (正), tmp1 (負) に格納して乗算 size_t ap_n = add_any(tmp5, tmp1, aen, tmp3, aon); // a(1) int am_sign = 1; size_t am_n = abs_sub(tmp1, am_sign, tmp1, aen, tmp3, aon); // a(-1) size_t bp_n = add_any(wm1_buf, tmp2, ben, tmp4, bon); // b(1) → wm1_buf (一時) int bm_sign = 1; size_t bm_n = abs_sub(tmp2, bm_sign, tmp2, ben, tmp4, bon); // b(-1) // v(1) = a(1) * b(1) if (ap_n > 0 && bp_n > 0) { mul_toomcook6(w1_buf, tmp5, ap_n, wm1_buf, bp_n, rec_scratch); w1n = normalized_size(w1_buf, ap_n + bp_n); } // v(-1) = a(-1) * b(-1) wm1_sign = am_sign * bm_sign; if (am_n > 0 && bm_n > 0) { mul_toomcook6(wm1_buf, tmp1, am_n, tmp2, bm_n, rec_scratch); wm1n = normalized_size(wm1_buf, am_n + bm_n); } if (wm1n == 0) wm1_sign = 0; } // Point ±2: a(±2) = (a0 + 4*a2 + 16*a4) ± 2*(a1 + 4*a3 + 16*a5) { // a_even = a0 + 4*a2 + 16*a4 size_t aen = a0n; if (a0n > 0) std::memcpy(tmp1, a0, a0n * sizeof(uint64_t)); if (a2n > 0) { uint64_t ov = lshift(tmp5, a2, a2n, 2); size_t tn = a2n; if (ov) { tmp5[tn] = ov; tn++; } aen = add_any(tmp1, tmp1, aen, tmp5, tn); } if (a4n > 0) { uint64_t ov = lshift(tmp5, a4, a4n, 4); size_t tn = a4n; if (ov) { tmp5[tn] = ov; tn++; } aen = add_any(tmp1, tmp1, aen, tmp5, tn); } // a_odd_half = a1 + 4*a3 + 16*a5 (あとで 2 倍する) size_t aon = a1n; if (a1n > 0) std::memcpy(tmp3, a1, a1n * sizeof(uint64_t)); if (a3n > 0) { uint64_t ov = lshift(tmp5, a3, a3n, 2); size_t tn = a3n; if (ov) { tmp5[tn] = ov; tn++; } aon = add_any(tmp3, tmp3, aon, tmp5, tn); } if (a5n > 0) { uint64_t ov = lshift(tmp5, a5, a5n, 4); size_t tn = a5n; if (ov) { tmp5[tn] = ov; tn++; } aon = add_any(tmp3, tmp3, aon, tmp5, tn); } // a_odd = 2 * a_odd_half if (aon > 0) { uint64_t ov = lshift(tmp3, tmp3, aon, 1); if (ov) { tmp3[aon] = ov; aon++; } } // b_even = b0 + 4*b2 + 16*b4 size_t ben = b0n; if (b0n > 0) std::memcpy(tmp2, b0, b0n * sizeof(uint64_t)); if (b2n > 0) { uint64_t ov = lshift(tmp5, b2, b2n, 2); size_t tn = b2n; if (ov) { tmp5[tn] = ov; tn++; } ben = add_any(tmp2, tmp2, ben, tmp5, tn); } if (b4n > 0) { uint64_t ov = lshift(tmp5, b4, b4n, 4); size_t tn = b4n; if (ov) { tmp5[tn] = ov; tn++; } ben = add_any(tmp2, tmp2, ben, tmp5, tn); } // b_odd_half = b1 + 4*b3 + 16*b5 size_t bon = b1n; if (b1n > 0) std::memcpy(tmp4, b1, b1n * sizeof(uint64_t)); if (b3n > 0) { uint64_t ov = lshift(tmp5, b3, b3n, 2); size_t tn = b3n; if (ov) { tmp5[tn] = ov; tn++; } bon = add_any(tmp4, tmp4, bon, tmp5, tn); } if (b5n > 0) { uint64_t ov = lshift(tmp5, b5, b5n, 4); size_t tn = b5n; if (ov) { tmp5[tn] = ov; tn++; } bon = add_any(tmp4, tmp4, bon, tmp5, tn); } // b_odd = 2 * b_odd_half if (bon > 0) { uint64_t ov = lshift(tmp4, tmp4, bon, 1); if (ov) { tmp4[bon] = ov; bon++; } } // a(2) = a_even + a_odd, a(-2) = a_even - a_odd size_t ap_n = add_any(tmp5, tmp1, aen, tmp3, aon); int am_sign = 1; size_t am_n = abs_sub(tmp1, am_sign, tmp1, aen, tmp3, aon); size_t bp_n = add_any(wm2_buf, tmp2, ben, tmp4, bon); // b(2) → wm2_buf (一時) int bm_sign = 1; size_t bm_n = abs_sub(tmp2, bm_sign, tmp2, ben, tmp4, bon); if (ap_n > 0 && bp_n > 0) { mul_toomcook6(w2_buf, tmp5, ap_n, wm2_buf, bp_n, rec_scratch); w2n = normalized_size(w2_buf, ap_n + bp_n); } wm2_sign = am_sign * bm_sign; if (am_n > 0 && bm_n > 0) { mul_toomcook6(wm2_buf, tmp1, am_n, tmp2, bm_n, rec_scratch); wm2n = normalized_size(wm2_buf, am_n + bm_n); } if (wm2n == 0) wm2_sign = 0; } // Point ±3: a(±3) = (a0+9a2+81a4) ± 3(a1+9a3+81a5) { // a_even = a0 + 9*a2 + 81*a4 size_t aen = a0n; if (a0n > 0) std::memcpy(tmp1, a0, a0n * sizeof(uint64_t)); if (a2n > 0) { uint64_t ov = mul_1(tmp5, a2, a2n, 9); size_t tn = a2n; if (ov) { tmp5[tn] = ov; tn++; } aen = add_any(tmp1, tmp1, aen, tmp5, tn); } if (a4n > 0) { uint64_t ov = mul_1(tmp5, a4, a4n, 81); size_t tn = a4n; if (ov) { tmp5[tn] = ov; tn++; } aen = add_any(tmp1, tmp1, aen, tmp5, tn); } // a_odd_third = a1 + 9*a3 + 81*a5, then *3 size_t aon = a1n; if (a1n > 0) std::memcpy(tmp3, a1, a1n * sizeof(uint64_t)); if (a3n > 0) { uint64_t ov = mul_1(tmp5, a3, a3n, 9); size_t tn = a3n; if (ov) { tmp5[tn] = ov; tn++; } aon = add_any(tmp3, tmp3, aon, tmp5, tn); } if (a5n > 0) { uint64_t ov = mul_1(tmp5, a5, a5n, 81); size_t tn = a5n; if (ov) { tmp5[tn] = ov; tn++; } aon = add_any(tmp3, tmp3, aon, tmp5, tn); } if (aon > 0) { uint64_t ov = mul_1(tmp3, tmp3, aon, 3); if (ov) { tmp3[aon] = ov; aon++; } } // b_even = b0 + 9*b2 + 81*b4 size_t ben = b0n; if (b0n > 0) std::memcpy(tmp2, b0, b0n * sizeof(uint64_t)); if (b2n > 0) { uint64_t ov = mul_1(tmp5, b2, b2n, 9); size_t tn = b2n; if (ov) { tmp5[tn] = ov; tn++; } ben = add_any(tmp2, tmp2, ben, tmp5, tn); } if (b4n > 0) { uint64_t ov = mul_1(tmp5, b4, b4n, 81); size_t tn = b4n; if (ov) { tmp5[tn] = ov; tn++; } ben = add_any(tmp2, tmp2, ben, tmp5, tn); } // b_odd_third = b1 + 9*b3 + 81*b5, then *3 size_t bon = b1n; if (b1n > 0) std::memcpy(tmp4, b1, b1n * sizeof(uint64_t)); if (b3n > 0) { uint64_t ov = mul_1(tmp5, b3, b3n, 9); size_t tn = b3n; if (ov) { tmp5[tn] = ov; tn++; } bon = add_any(tmp4, tmp4, bon, tmp5, tn); } if (b5n > 0) { uint64_t ov = mul_1(tmp5, b5, b5n, 81); size_t tn = b5n; if (ov) { tmp5[tn] = ov; tn++; } bon = add_any(tmp4, tmp4, bon, tmp5, tn); } if (bon > 0) { uint64_t ov = mul_1(tmp4, tmp4, bon, 3); if (ov) { tmp4[bon] = ov; bon++; } } size_t ap_n = add_any(tmp5, tmp1, aen, tmp3, aon); int am_sign = 1; size_t am_n = abs_sub(tmp1, am_sign, tmp1, aen, tmp3, aon); size_t bp_n = add_any(wm3_buf, tmp2, ben, tmp4, bon); // b(3) → wm3_buf (一時) int bm_sign = 1; size_t bm_n = abs_sub(tmp2, bm_sign, tmp2, ben, tmp4, bon); if (ap_n > 0 && bp_n > 0) { mul_toomcook6(w3_buf, tmp5, ap_n, wm3_buf, bp_n, rec_scratch); w3n = normalized_size(w3_buf, ap_n + bp_n); } wm3_sign = am_sign * bm_sign; if (am_n > 0 && bm_n > 0) { mul_toomcook6(wm3_buf, tmp1, am_n, tmp2, bm_n, rec_scratch); wm3n = normalized_size(wm3_buf, am_n + bm_n); } if (wm3n == 0) wm3_sign = 0; } // Point ±4: a(±4) = (a0+16a2+256a4) ± 4(a1+16a3+256a5) { size_t aen = a0n; if (a0n > 0) std::memcpy(tmp1, a0, a0n * sizeof(uint64_t)); if (a2n > 0) { uint64_t ov = lshift(tmp5, a2, a2n, 4); size_t tn = a2n; if (ov) { tmp5[tn] = ov; tn++; } aen = add_any(tmp1, tmp1, aen, tmp5, tn); } if (a4n > 0) { uint64_t ov = lshift(tmp5, a4, a4n, 8); size_t tn = a4n; if (ov) { tmp5[tn] = ov; tn++; } aen = add_any(tmp1, tmp1, aen, tmp5, tn); } size_t aon = a1n; if (a1n > 0) std::memcpy(tmp3, a1, a1n * sizeof(uint64_t)); if (a3n > 0) { uint64_t ov = lshift(tmp5, a3, a3n, 4); size_t tn = a3n; if (ov) { tmp5[tn] = ov; tn++; } aon = add_any(tmp3, tmp3, aon, tmp5, tn); } if (a5n > 0) { uint64_t ov = lshift(tmp5, a5, a5n, 8); size_t tn = a5n; if (ov) { tmp5[tn] = ov; tn++; } aon = add_any(tmp3, tmp3, aon, tmp5, tn); } if (aon > 0) { uint64_t ov = lshift(tmp3, tmp3, aon, 2); // *4 if (ov) { tmp3[aon] = ov; aon++; } } size_t ben = b0n; if (b0n > 0) std::memcpy(tmp2, b0, b0n * sizeof(uint64_t)); if (b2n > 0) { uint64_t ov = lshift(tmp5, b2, b2n, 4); size_t tn = b2n; if (ov) { tmp5[tn] = ov; tn++; } ben = add_any(tmp2, tmp2, ben, tmp5, tn); } if (b4n > 0) { uint64_t ov = lshift(tmp5, b4, b4n, 8); size_t tn = b4n; if (ov) { tmp5[tn] = ov; tn++; } ben = add_any(tmp2, tmp2, ben, tmp5, tn); } size_t bon = b1n; if (b1n > 0) std::memcpy(tmp4, b1, b1n * sizeof(uint64_t)); if (b3n > 0) { uint64_t ov = lshift(tmp5, b3, b3n, 4); size_t tn = b3n; if (ov) { tmp5[tn] = ov; tn++; } bon = add_any(tmp4, tmp4, bon, tmp5, tn); } if (b5n > 0) { uint64_t ov = lshift(tmp5, b5, b5n, 8); size_t tn = b5n; if (ov) { tmp5[tn] = ov; tn++; } bon = add_any(tmp4, tmp4, bon, tmp5, tn); } if (bon > 0) { uint64_t ov = lshift(tmp4, tmp4, bon, 2); // *4 if (ov) { tmp4[bon] = ov; bon++; } } size_t ap_n = add_any(tmp5, tmp1, aen, tmp3, aon); int am_sign = 1; size_t am_n = abs_sub(tmp1, am_sign, tmp1, aen, tmp3, aon); size_t bp_n = add_any(wm4_buf, tmp2, ben, tmp4, bon); // b(4) → wm4_buf (一時) int bm_sign = 1; size_t bm_n = abs_sub(tmp2, bm_sign, tmp2, ben, tmp4, bon); if (ap_n > 0 && bp_n > 0) { mul_toomcook6(w4_buf, tmp5, ap_n, wm4_buf, bp_n, rec_scratch); w4n = normalized_size(w4_buf, ap_n + bp_n); } wm4_sign = am_sign * bm_sign; if (am_n > 0 && bm_n > 0) { mul_toomcook6(wm4_buf, tmp1, am_n, tmp2, bm_n, rec_scratch); wm4n = normalized_size(wm4_buf, am_n + bm_n); } if (wm4n == 0) wm4_sign = 0; } // Point ±5: a(±5) = (a0+25a2+625a4) ± 5(a1+25a3+625a5) { size_t aen = a0n; if (a0n > 0) std::memcpy(tmp1, a0, a0n * sizeof(uint64_t)); if (a2n > 0) { uint64_t ov = mul_1(tmp5, a2, a2n, 25); size_t tn = a2n; if (ov) { tmp5[tn] = ov; tn++; } aen = add_any(tmp1, tmp1, aen, tmp5, tn); } if (a4n > 0) { uint64_t ov = mul_1(tmp5, a4, a4n, 625); size_t tn = a4n; if (ov) { tmp5[tn] = ov; tn++; } aen = add_any(tmp1, tmp1, aen, tmp5, tn); } size_t aon = a1n; if (a1n > 0) std::memcpy(tmp3, a1, a1n * sizeof(uint64_t)); if (a3n > 0) { uint64_t ov = mul_1(tmp5, a3, a3n, 25); size_t tn = a3n; if (ov) { tmp5[tn] = ov; tn++; } aon = add_any(tmp3, tmp3, aon, tmp5, tn); } if (a5n > 0) { uint64_t ov = mul_1(tmp5, a5, a5n, 625); size_t tn = a5n; if (ov) { tmp5[tn] = ov; tn++; } aon = add_any(tmp3, tmp3, aon, tmp5, tn); } if (aon > 0) { uint64_t ov = mul_1(tmp3, tmp3, aon, 5); if (ov) { tmp3[aon] = ov; aon++; } } size_t ben = b0n; if (b0n > 0) std::memcpy(tmp2, b0, b0n * sizeof(uint64_t)); if (b2n > 0) { uint64_t ov = mul_1(tmp5, b2, b2n, 25); size_t tn = b2n; if (ov) { tmp5[tn] = ov; tn++; } ben = add_any(tmp2, tmp2, ben, tmp5, tn); } if (b4n > 0) { uint64_t ov = mul_1(tmp5, b4, b4n, 625); size_t tn = b4n; if (ov) { tmp5[tn] = ov; tn++; } ben = add_any(tmp2, tmp2, ben, tmp5, tn); } size_t bon = b1n; if (b1n > 0) std::memcpy(tmp4, b1, b1n * sizeof(uint64_t)); if (b3n > 0) { uint64_t ov = mul_1(tmp5, b3, b3n, 25); size_t tn = b3n; if (ov) { tmp5[tn] = ov; tn++; } bon = add_any(tmp4, tmp4, bon, tmp5, tn); } if (b5n > 0) { uint64_t ov = mul_1(tmp5, b5, b5n, 625); size_t tn = b5n; if (ov) { tmp5[tn] = ov; tn++; } bon = add_any(tmp4, tmp4, bon, tmp5, tn); } if (bon > 0) { uint64_t ov = mul_1(tmp4, tmp4, bon, 5); if (ov) { tmp4[bon] = ov; bon++; } } size_t ap_n = add_any(tmp5, tmp1, aen, tmp3, aon); int am_sign = 1; size_t am_n = abs_sub(tmp1, am_sign, tmp1, aen, tmp3, aon); size_t bp_n = add_any(wm5_buf, tmp2, ben, tmp4, bon); // b(5) → wm5_buf (一時) int bm_sign = 1; size_t bm_n = abs_sub(tmp2, bm_sign, tmp2, ben, tmp4, bon); if (ap_n > 0 && bp_n > 0) { mul_toomcook6(w5_buf, tmp5, ap_n, wm5_buf, bp_n, rec_scratch); w5n = normalized_size(w5_buf, ap_n + bp_n); } wm5_sign = am_sign * bm_sign; if (am_n > 0 && bm_n > 0) { mul_toomcook6(wm5_buf, tmp1, am_n, tmp2, bm_n, rec_scratch); wm5n = normalized_size(wm5_buf, am_n + bm_n); } if (wm5n == 0) wm5_sign = 0; } // ============================================================ // 補間 — ± 対称性で偶奇分離 → Vandermonde 解法 // ============================================================ // // c0 = v0 (r[0..]), c10 = vinf (r[10k..]) // // ± coupling: // e_k = (v(k) + v(-k)) / 2 = c0 + c2*k² + c4*k⁴ + c6*k⁶ + c8*k⁸ + c10*k¹⁰ // o_k = (v(k) - v(-k)) / (2k) = c1 + c3*k² + c5*k⁴ + c7*k⁶ + c9*k⁸ // // Even: E_k = e_k - c0 - c10*k¹⁰ → 4x4 Vandermonde (k=1,2,3,4) for c2,c4,c6,c8 // Odd: o_k for k=1,2,3,4,5 → 5x5 Vandermonde for c1,c3,c5,c7,c9 // バッファ再利用: 評価完了後、w*_buf は補間に利用可能 // e_k → w(2k-1)_buf を再利用、o_k → wm(2k-1)_buf を再利用 // ただしバッファサイズ管理のため、新たに名前付けする // ---- Step 1: ± coupling (偶奇分離) ---- // e1 = (v1 + vm1) / 2, o1 = (v1 - vm1) / 2 // バッファ配置: e1 → tmp1, o1 → tmp2 size_t e1n, o1n; { // e1 = (w1 + wm1) / 2 if (wm1_sign >= 0) { e1n = add_any(tmp1, w1_buf, w1n, wm1_buf, wm1n); } else { std::memcpy(tmp1, w1_buf, w1n * sizeof(uint64_t)); e1n = w1n; if (wm1n > 0) { sub(tmp1, tmp1, e1n, wm1_buf, wm1n); e1n = normalized_size(tmp1, e1n); } } if (e1n > 0) e1n = rshift_1(tmp1, tmp1, e1n); // o1 = (w1 - wm1) / 2 if (wm1_sign >= 0) { std::memcpy(tmp2, w1_buf, w1n * sizeof(uint64_t)); o1n = w1n; if (wm1n > 0) { sub(tmp2, tmp2, o1n, wm1_buf, wm1n); o1n = normalized_size(tmp2, o1n); } } else { o1n = add_any(tmp2, w1_buf, w1n, wm1_buf, wm1n); } if (o1n > 0) o1n = rshift_1(tmp2, tmp2, o1n); } // e1 in tmp1, o1 in tmp2 // e2 = (w2 + wm2) / 2, o2 = (w2 - wm2) / 4 size_t e2n, o2n; { if (wm2_sign >= 0) { e2n = add_any(w1_buf, w2_buf, w2n, wm2_buf, wm2n); } else { std::memcpy(w1_buf, w2_buf, w2n * sizeof(uint64_t)); e2n = w2n; if (wm2n > 0) { sub(w1_buf, w1_buf, e2n, wm2_buf, wm2n); e2n = normalized_size(w1_buf, e2n); } } if (e2n > 0) e2n = rshift_1(w1_buf, w1_buf, e2n); if (wm2_sign >= 0) { std::memcpy(wm1_buf, w2_buf, w2n * sizeof(uint64_t)); o2n = w2n; if (wm2n > 0) { sub(wm1_buf, wm1_buf, o2n, wm2_buf, wm2n); o2n = normalized_size(wm1_buf, o2n); } } else { o2n = add_any(wm1_buf, w2_buf, w2n, wm2_buf, wm2n); } // /4 = >>2 if (o2n > 0) o2n = shift_right_2(wm1_buf, wm1_buf, o2n); } // e2 in w1_buf, o2 in wm1_buf // e3 = (w3 + wm3) / 2, o3 = (w3 - wm3) / 6 size_t e3n, o3n; { if (wm3_sign >= 0) { e3n = add_any(w2_buf, w3_buf, w3n, wm3_buf, wm3n); } else { std::memcpy(w2_buf, w3_buf, w3n * sizeof(uint64_t)); e3n = w3n; if (wm3n > 0) { sub(w2_buf, w2_buf, e3n, wm3_buf, wm3n); e3n = normalized_size(w2_buf, e3n); } } if (e3n > 0) e3n = rshift_1(w2_buf, w2_buf, e3n); if (wm3_sign >= 0) { std::memcpy(wm2_buf, w3_buf, w3n * sizeof(uint64_t)); o3n = w3n; if (wm3n > 0) { sub(wm2_buf, wm2_buf, o3n, wm3_buf, wm3n); o3n = normalized_size(wm2_buf, o3n); } } else { o3n = add_any(wm2_buf, w3_buf, w3n, wm3_buf, wm3n); } // /6 = >>1 then /3 if (o3n > 0) { o3n = rshift_1(wm2_buf, wm2_buf, o3n); o3n = divexact_by3(wm2_buf, wm2_buf, o3n); } } // e3 in w2_buf, o3 in wm2_buf // e4 = (w4 + wm4) / 2, o4 = (w4 - wm4) / 8 size_t e4n, o4n; { if (wm4_sign >= 0) { e4n = add_any(w3_buf, w4_buf, w4n, wm4_buf, wm4n); } else { std::memcpy(w3_buf, w4_buf, w4n * sizeof(uint64_t)); e4n = w4n; if (wm4n > 0) { sub(w3_buf, w3_buf, e4n, wm4_buf, wm4n); e4n = normalized_size(w3_buf, e4n); } } if (e4n > 0) e4n = rshift_1(w3_buf, w3_buf, e4n); if (wm4_sign >= 0) { std::memcpy(wm3_buf, w4_buf, w4n * sizeof(uint64_t)); o4n = w4n; if (wm4n > 0) { sub(wm3_buf, wm3_buf, o4n, wm4_buf, wm4n); o4n = normalized_size(wm3_buf, o4n); } } else { o4n = add_any(wm3_buf, w4_buf, w4n, wm4_buf, wm4n); } // /8 = >>3 if (o4n > 0) o4n = shift_right_3(wm3_buf, wm3_buf, o4n); } // e4 in w3_buf, o4 in wm3_buf // e5 = (w5 + wm5) / 2, o5 = (w5 - wm5) / 10 size_t e5n, o5n; { if (wm5_sign >= 0) { e5n = add_any(w4_buf, w5_buf, w5n, wm5_buf, wm5n); } else { std::memcpy(w4_buf, w5_buf, w5n * sizeof(uint64_t)); e5n = w5n; if (wm5n > 0) { sub(w4_buf, w4_buf, e5n, wm5_buf, wm5n); e5n = normalized_size(w4_buf, e5n); } } if (e5n > 0) e5n = rshift_1(w4_buf, w4_buf, e5n); if (wm5_sign >= 0) { std::memcpy(wm4_buf, w5_buf, w5n * sizeof(uint64_t)); o5n = w5n; if (wm5n > 0) { sub(wm4_buf, wm4_buf, o5n, wm5_buf, wm5n); o5n = normalized_size(wm4_buf, o5n); } } else { o5n = add_any(wm4_buf, w5_buf, w5n, wm5_buf, wm5n); } // /10 = >>1 then /5 if (o5n > 0) { o5n = rshift_1(wm4_buf, wm4_buf, o5n); o5n = divexact_by5(wm4_buf, wm4_buf, o5n); } } // e5 in w4_buf, o5 in wm4_buf // ---- バッファ配置のまとめ ---- // e1: tmp1, e2: w1_buf, e3: w2_buf, e4: w3_buf, e5: w4_buf // o1: tmp2, o2: wm1_buf, o3: wm2_buf, o4: wm3_buf, o5: wm4_buf // 空きバッファ: w5_buf, wm5_buf, tmp3, tmp4, tmp5 // ---- Step 2: Even 系の解法 ---- // E_k = e_k - c0 - c10 * k^10 for k=1,2,3,4 // E1 = e1 - c0 - c10 → tmp1 上書き // E2 = e2 - c0 - 1024*c10 → w1_buf 上書き // E3 = e3 - c0 - 59049*c10 → w2_buf 上書き // E4 = e4 - c0 - 1048576*c10 → w3_buf 上書き // E1 size_t E1n = e1n; if (w0n > 0 && E1n > 0) { sub(tmp1, tmp1, E1n, r, w0n); E1n = normalized_size(tmp1, E1n); } if (winfn > 0 && E1n > 0) { sub(tmp1, tmp1, E1n, r + winf_off, winfn); E1n = normalized_size(tmp1, E1n); } // E2: e2 - c0 - 1024*c10 size_t E2n = e2n; if (w0n > 0 && E2n > 0) { sub(w1_buf, w1_buf, E2n, r, w0n); E2n = normalized_size(w1_buf, E2n); } if (winfn > 0 && E2n > 0) { uint64_t ov = lshift(w5_buf, r + winf_off, winfn, 10); // 1024 = 2^10 size_t tn = winfn; if (ov) { w5_buf[tn] = ov; tn++; } sub(w1_buf, w1_buf, E2n, w5_buf, tn); E2n = normalized_size(w1_buf, E2n); } // E3: e3 - c0 - 59049*c10 (3^10 = 59049) size_t E3n = e3n; if (w0n > 0 && E3n > 0) { sub(w2_buf, w2_buf, E3n, r, w0n); E3n = normalized_size(w2_buf, E3n); } if (winfn > 0 && E3n > 0) { uint64_t ov = mul_1(w5_buf, r + winf_off, winfn, 59049); size_t tn = winfn; if (ov) { w5_buf[tn] = ov; tn++; } sub(w2_buf, w2_buf, E3n, w5_buf, tn); E3n = normalized_size(w2_buf, E3n); } // E4: e4 - c0 - 1048576*c10 (4^10 = 2^20 = 1048576) size_t E4n = e4n; if (w0n > 0 && E4n > 0) { sub(w3_buf, w3_buf, E4n, r, w0n); E4n = normalized_size(w3_buf, E4n); } if (winfn > 0 && E4n > 0) { uint64_t ov = lshift(w5_buf, r + winf_off, winfn, 20); // 2^20 size_t tn = winfn; if (ov) { w5_buf[tn] = ov; tn++; } sub(w3_buf, w3_buf, E4n, w5_buf, tn); E4n = normalized_size(w3_buf, E4n); } // Even Vandermonde 解法 (4x4, 変数 c2,c4,c6,c8, ノード u=1,4,9,16): // E1 = c2 + c4 + c6 + c8 // E2 = 4c2 + 16c4 + 64c6 + 256c8 // E3 = 9c2 + 81c4 + 729c6 + 6561c8 // E4 = 16c2 + 256c4 + 4096c6 + 65536c8 // // A = E2 - 4*E1 = 12c4 + 60c6 + 252c8 // B = E3 - 9*E1 = 72c4 + 720c6 + 6552c8 // C = E4 - 16*E1 = 240c4 + 4080c6 + 65520c8 // D = B - 6*A = 360c6 + 5040c8 // E_ = C - 20*A = 2880c6 + 60480c8 // H = E_ - 8*D = 20160*c8 // E1 を tmp5 に退避 (c2 計算用) size_t saved_E1n = E1n; if (E1n > 0) std::memcpy(tmp5, tmp1, E1n * sizeof(uint64_t)); // A = E2 - 4*E1 → w5_buf size_t An; { if (E1n > 0) { uint64_t ov = lshift(wm5_buf, tmp1, E1n, 2); // 4*E1 size_t tn = E1n; if (ov) { wm5_buf[tn] = ov; tn++; } std::memcpy(w5_buf, w1_buf, E2n * sizeof(uint64_t)); An = E2n; sub(w5_buf, w5_buf, An, wm5_buf, tn); An = normalized_size(w5_buf, An); } else { std::memcpy(w5_buf, w1_buf, E2n * sizeof(uint64_t)); An = E2n; } } // B = E3 - 9*E1 → wm5_buf size_t Bn; { if (E1n > 0) { uint64_t ov = mul_1(tmp3, tmp1, E1n, 9); size_t tn = E1n; if (ov) { tmp3[tn] = ov; tn++; } std::memcpy(wm5_buf, w2_buf, E3n * sizeof(uint64_t)); Bn = E3n; sub(wm5_buf, wm5_buf, Bn, tmp3, tn); Bn = normalized_size(wm5_buf, Bn); } else { std::memcpy(wm5_buf, w2_buf, E3n * sizeof(uint64_t)); Bn = E3n; } } // C = E4 - 16*E1 → tmp3 size_t Cn; { if (E1n > 0) { uint64_t ov = lshift(tmp4, tmp1, E1n, 4); size_t tn = E1n; if (ov) { tmp4[tn] = ov; tn++; } std::memcpy(tmp3, w3_buf, E4n * sizeof(uint64_t)); Cn = E4n; sub(tmp3, tmp3, Cn, tmp4, tn); Cn = normalized_size(tmp3, Cn); } else { std::memcpy(tmp3, w3_buf, E4n * sizeof(uint64_t)); Cn = E4n; } } // A in w5_buf, B in wm5_buf, C in tmp3 // E1 in tmp1 (no longer needed after this) // D = B - 6*A → tmp4 size_t Dn; { std::memcpy(tmp4, wm5_buf, Bn * sizeof(uint64_t)); Dn = Bn; if (An > 0) { uint64_t ov = mul_1(tmp1, w5_buf, An, 6); size_t tn = An; if (ov) { tmp1[tn] = ov; tn++; } sub(tmp4, tmp4, Dn, tmp1, tn); Dn = normalized_size(tmp4, Dn); } } // E_ = C - 20*A → tmp1 size_t E_n; { std::memcpy(tmp1, tmp3, Cn * sizeof(uint64_t)); E_n = Cn; if (An > 0) { uint64_t ov = mul_1(tmp3, w5_buf, An, 20); size_t tn = An; if (ov) { tmp3[tn] = ov; tn++; } sub(tmp1, tmp1, E_n, tmp3, tn); E_n = normalized_size(tmp1, E_n); } } // H = E_ - 8*D → tmp3 size_t Hn; { std::memcpy(tmp3, tmp1, E_n * sizeof(uint64_t)); Hn = E_n; if (Dn > 0) { uint64_t ov = lshift(tmp1, tmp4, Dn, 3); size_t tn = Dn; if (ov) { tmp1[tn] = ov; tn++; } sub(tmp3, tmp3, Hn, tmp1, tn); Hn = normalized_size(tmp3, Hn); } } // c8 = H / 20160 = >>6, /9, /5, /7 (20160 = 2^6 * 3^2 * 5 * 7) size_t c8n = 0; if (Hn > 0) { // >>6 for (size_t i = 0; i + 1 < Hn; i++) tmp3[i] = (tmp3[i] >> 6) | (tmp3[i + 1] << 58); tmp3[Hn - 1] >>= 6; c8n = normalized_size(tmp3, Hn); // /9 = /3, /3 c8n = divexact_by3(tmp3, tmp3, c8n); c8n = divexact_by3(tmp3, tmp3, c8n); // /5 c8n = divexact_by5(tmp3, tmp3, c8n); // /7 c8n = divexact_by7(tmp3, tmp3, c8n); } uint64_t* c8_ptr = tmp3; // c6 = (D - 5040*c8) / 360 // 5040 = 7*720 = 7*16*45 = 2^4 * 3^2 * 5 * 7 size_t c6n = Dn; // tmp4 に D がある if (c8n > 0 && c6n > 0) { uint64_t ov = mul_1(tmp1, c8_ptr, c8n, 5040); size_t tn = c8n; if (ov) { tmp1[tn] = ov; tn++; } sub(tmp4, tmp4, c6n, tmp1, tn); c6n = normalized_size(tmp4, c6n); } // /360 = >>3, /3, /3, /5 if (c6n > 0) { c6n = shift_right_3(tmp4, tmp4, c6n); c6n = divexact_by3(tmp4, tmp4, c6n); c6n = divexact_by3(tmp4, tmp4, c6n); c6n = divexact_by5(tmp4, tmp4, c6n); } uint64_t* c6_ptr = tmp4; // c4 = (A - 60*c6 - 252*c8) / 12 size_t c4n = An; // w5_buf に A がある if (c6n > 0 && c4n > 0) { uint64_t ov = mul_1(tmp1, c6_ptr, c6n, 60); size_t tn = c6n; if (ov) { tmp1[tn] = ov; tn++; } sub(w5_buf, w5_buf, c4n, tmp1, tn); c4n = normalized_size(w5_buf, c4n); } if (c8n > 0 && c4n > 0) { uint64_t ov = mul_1(tmp1, c8_ptr, c8n, 252); size_t tn = c8n; if (ov) { tmp1[tn] = ov; tn++; } sub(w5_buf, w5_buf, c4n, tmp1, tn); c4n = normalized_size(w5_buf, c4n); } if (c4n > 0) c4n = divexact_by12(w5_buf, w5_buf, c4n); uint64_t* c4_ptr = w5_buf; // c2 = E1 - c4 - c6 - c8 (E1 は tmp5 に退避済み) size_t c2n = saved_E1n; if (saved_E1n > 0) std::memcpy(wm5_buf, tmp5, saved_E1n * sizeof(uint64_t)); if (c4n > 0 && c2n > 0) { sub(wm5_buf, wm5_buf, c2n, c4_ptr, c4n); c2n = normalized_size(wm5_buf, c2n); } if (c6n > 0 && c2n > 0) { sub(wm5_buf, wm5_buf, c2n, c6_ptr, c6n); c2n = normalized_size(wm5_buf, c2n); } if (c8n > 0 && c2n > 0) { sub(wm5_buf, wm5_buf, c2n, c8_ptr, c8n); c2n = normalized_size(wm5_buf, c2n); } uint64_t* c2_ptr = wm5_buf; // ---- Step 3: Odd 系の解法 ---- // o_k = c1 + c3*k² + c5*k⁴ + c7*k⁶ + c9*k⁸ for k=1,2,3,4,5 // 5x5 Vandermonde (ノード u=1,4,9,16,25) // // o1 in tmp2, o2 in wm1_buf, o3 in wm2_buf, o4 in wm3_buf, o5 in wm4_buf // // D1 = o2 - o1 = 3c3 + 15c5 + 63c7 + 255c9 // D2 = o3 - o1 = 8c3 + 80c5 + 728c7 + 6560c9 // D3 = o4 - o1 = 15c3 + 255c5 + 4095c7 + 65535c9 // D4 = o5 - o1 = 24c3 + 624c5 + 15624c7 + 390624c9 // D1 = o2 - o1 → w1_buf size_t D1n; { std::memcpy(w1_buf, wm1_buf, o2n * sizeof(uint64_t)); D1n = o2n; if (o1n > 0 && D1n > 0) { sub(w1_buf, w1_buf, D1n, tmp2, o1n); D1n = normalized_size(w1_buf, D1n); } } // D2 = o3 - o1 → w2_buf size_t D2n; { std::memcpy(w2_buf, wm2_buf, o3n * sizeof(uint64_t)); D2n = o3n; if (o1n > 0 && D2n > 0) { sub(w2_buf, w2_buf, D2n, tmp2, o1n); D2n = normalized_size(w2_buf, D2n); } } // D3 = o4 - o1 → w3_buf size_t D3n; { std::memcpy(w3_buf, wm3_buf, o4n * sizeof(uint64_t)); D3n = o4n; if (o1n > 0 && D3n > 0) { sub(w3_buf, w3_buf, D3n, tmp2, o1n); D3n = normalized_size(w3_buf, D3n); } } // D4 = o5 - o1 → w4_buf (e5 was there but we've consumed it) size_t D4n; { std::memcpy(w4_buf, wm4_buf, o5n * sizeof(uint64_t)); D4n = o5n; if (o1n > 0 && D4n > 0) { sub(w4_buf, w4_buf, D4n, tmp2, o1n); D4n = normalized_size(w4_buf, D4n); } } // F1 = 3*D2 - 8*D1 = 120c5 + 1680c7 + 17640c9 → wm1_buf size_t F1n; { if (D2n > 0) { uint64_t ov = mul_1(wm1_buf, w2_buf, D2n, 3); F1n = D2n; if (ov) { wm1_buf[F1n] = ov; F1n++; } } else { F1n = 0; } if (D1n > 0 && F1n > 0) { uint64_t ov = lshift(tmp1, w1_buf, D1n, 3); size_t tn = D1n; if (ov) { tmp1[tn] = ov; tn++; } sub(wm1_buf, wm1_buf, F1n, tmp1, tn); F1n = normalized_size(wm1_buf, F1n); } } // F2 = 3*D3 - 15*D1 = 540c5 + 11340c7 + 192780c9 → wm2_buf size_t F2n; { if (D3n > 0) { uint64_t ov = mul_1(wm2_buf, w3_buf, D3n, 3); F2n = D3n; if (ov) { wm2_buf[F2n] = ov; F2n++; } } else { F2n = 0; } if (D1n > 0 && F2n > 0) { uint64_t ov = mul_1(tmp1, w1_buf, D1n, 15); size_t tn = D1n; if (ov) { tmp1[tn] = ov; tn++; } sub(wm2_buf, wm2_buf, F2n, tmp1, tn); F2n = normalized_size(wm2_buf, F2n); } } // F3 = 3*D4 - 24*D1 = 1512c5 + 45360c7 + 1165752c9 → wm3_buf size_t F3n; { if (D4n > 0) { uint64_t ov = mul_1(wm3_buf, w4_buf, D4n, 3); F3n = D4n; if (ov) { wm3_buf[F3n] = ov; F3n++; } } else { F3n = 0; } if (D1n > 0 && F3n > 0) { uint64_t ov = mul_1(tmp1, w1_buf, D1n, 24); size_t tn = D1n; if (ov) { tmp1[tn] = ov; tn++; } sub(wm3_buf, wm3_buf, F3n, tmp1, tn); F3n = normalized_size(wm3_buf, F3n); } } // G1 = 2*F2 - 9*F1 = 7560c7 + 226800c9 → wm4_buf size_t G1n; { if (F2n > 0) { uint64_t ov = lshift(wm4_buf, wm2_buf, F2n, 1); G1n = F2n; if (ov) { wm4_buf[G1n] = ov; G1n++; } } else { G1n = 0; } if (F1n > 0 && G1n > 0) { uint64_t ov = mul_1(tmp1, wm1_buf, F1n, 9); size_t tn = F1n; if (ov) { tmp1[tn] = ov; tn++; } sub(wm4_buf, wm4_buf, G1n, tmp1, tn); G1n = normalized_size(wm4_buf, G1n); } } // G2 = 5*F3 - 63*F1 = 120960c7 + 4717440c9 → w2_buf size_t G2n; { if (F3n > 0) { uint64_t ov = mul_1(w2_buf, wm3_buf, F3n, 5); G2n = F3n; if (ov) { w2_buf[G2n] = ov; G2n++; } } else { G2n = 0; } if (F1n > 0 && G2n > 0) { uint64_t ov = mul_1(tmp1, wm1_buf, F1n, 63); size_t tn = F1n; if (ov) { tmp1[tn] = ov; tn++; } sub(w2_buf, w2_buf, G2n, tmp1, tn); G2n = normalized_size(w2_buf, G2n); } } // H_odd = G2 - 16*G1 = 1088640*c9 → w3_buf size_t H_odd_n; { std::memcpy(w3_buf, w2_buf, G2n * sizeof(uint64_t)); H_odd_n = G2n; if (G1n > 0 && H_odd_n > 0) { uint64_t ov = lshift(tmp1, wm4_buf, G1n, 4); size_t tn = G1n; if (ov) { tmp1[tn] = ov; tn++; } sub(w3_buf, w3_buf, H_odd_n, tmp1, tn); H_odd_n = normalized_size(w3_buf, H_odd_n); } } // c9 = H_odd / 1088640 = >>7, /3^5, /5, /7 // 1088640 = 2^7 * 3^5 * 5 * 7 ... wait let me recheck. // 1088640 = 1088640. // 1088640 / 128 = 8505 // 8505 / 3 = 2835, /3 = 945, /3 = 315, /3 = 105, /3 = 35, /5 = 7, /7 = 1 // So 1088640 = 2^7 * 3^5 * 5 * 7 size_t c9n = H_odd_n; if (c9n > 0) { // >>7 for (size_t i = 0; i + 1 < c9n; i++) w3_buf[i] = (w3_buf[i] >> 7) | (w3_buf[i + 1] << 57); w3_buf[c9n - 1] >>= 7; c9n = normalized_size(w3_buf, c9n); // /3^5 c9n = divexact_by3(w3_buf, w3_buf, c9n); c9n = divexact_by3(w3_buf, w3_buf, c9n); c9n = divexact_by3(w3_buf, w3_buf, c9n); c9n = divexact_by3(w3_buf, w3_buf, c9n); c9n = divexact_by3(w3_buf, w3_buf, c9n); // /5 c9n = divexact_by5(w3_buf, w3_buf, c9n); // /7 c9n = divexact_by7(w3_buf, w3_buf, c9n); } uint64_t* c9_ptr = w3_buf; // c7 = (G1 - 226800*c9) / 7560 // 7560 = 2^3 * 3^3 * 5 * 7 size_t c7n = G1n; // wm4_buf に G1 がある if (c9n > 0 && c7n > 0) { uint64_t ov = mul_1(tmp1, c9_ptr, c9n, 226800); size_t tn = c9n; if (ov) { tmp1[tn] = ov; tn++; } sub(wm4_buf, wm4_buf, c7n, tmp1, tn); c7n = normalized_size(wm4_buf, c7n); } if (c7n > 0) { c7n = shift_right_3(wm4_buf, wm4_buf, c7n); // /8 c7n = divexact_by3(wm4_buf, wm4_buf, c7n); // /3 c7n = divexact_by3(wm4_buf, wm4_buf, c7n); // /3 c7n = divexact_by3(wm4_buf, wm4_buf, c7n); // /3 c7n = divexact_by5(wm4_buf, wm4_buf, c7n); // /5 c7n = divexact_by7(wm4_buf, wm4_buf, c7n); // /7 } uint64_t* c7_ptr = wm4_buf; // c5 = (F1 - 1680*c7 - 17640*c9) / 120 size_t c5n = F1n; // wm1_buf に F1 がある if (c7n > 0 && c5n > 0) { uint64_t ov = mul_1(tmp1, c7_ptr, c7n, 1680); size_t tn = c7n; if (ov) { tmp1[tn] = ov; tn++; } sub(wm1_buf, wm1_buf, c5n, tmp1, tn); c5n = normalized_size(wm1_buf, c5n); } if (c9n > 0 && c5n > 0) { uint64_t ov = mul_1(tmp1, c9_ptr, c9n, 17640); size_t tn = c9n; if (ov) { tmp1[tn] = ov; tn++; } sub(wm1_buf, wm1_buf, c5n, tmp1, tn); c5n = normalized_size(wm1_buf, c5n); } if (c5n > 0) c5n = divexact_by120(wm1_buf, wm1_buf, c5n); uint64_t* c5_ptr = wm1_buf; // c3 = (D1 - 15*c5 - 63*c7 - 255*c9) / 3 size_t c3n = D1n; // w1_buf に D1 がある if (c5n > 0 && c3n > 0) { uint64_t ov = mul_1(tmp1, c5_ptr, c5n, 15); size_t tn = c5n; if (ov) { tmp1[tn] = ov; tn++; } sub(w1_buf, w1_buf, c3n, tmp1, tn); c3n = normalized_size(w1_buf, c3n); } if (c7n > 0 && c3n > 0) { uint64_t ov = mul_1(tmp1, c7_ptr, c7n, 63); size_t tn = c7n; if (ov) { tmp1[tn] = ov; tn++; } sub(w1_buf, w1_buf, c3n, tmp1, tn); c3n = normalized_size(w1_buf, c3n); } if (c9n > 0 && c3n > 0) { uint64_t ov = mul_1(tmp1, c9_ptr, c9n, 255); size_t tn = c9n; if (ov) { tmp1[tn] = ov; tn++; } sub(w1_buf, w1_buf, c3n, tmp1, tn); c3n = normalized_size(w1_buf, c3n); } if (c3n > 0) c3n = divexact_by3(w1_buf, w1_buf, c3n); uint64_t* c3_ptr = w1_buf; // c1 = o1 - c3 - c5 - c7 - c9 size_t c1n = o1n; // tmp2 に o1 がある if (c3n > 0 && c1n > 0) { sub(tmp2, tmp2, c1n, c3_ptr, c3n); c1n = normalized_size(tmp2, c1n); } if (c5n > 0 && c1n > 0) { sub(tmp2, tmp2, c1n, c5_ptr, c5n); c1n = normalized_size(tmp2, c1n); } if (c7n > 0 && c1n > 0) { sub(tmp2, tmp2, c1n, c7_ptr, c7n); c1n = normalized_size(tmp2, c1n); } if (c9n > 0 && c1n > 0) { sub(tmp2, tmp2, c1n, c9_ptr, c9n); c1n = normalized_size(tmp2, c1n); } uint64_t* c1_ptr = tmp2; // ============================================================ // 組み立て: r += c1*B^k + c2*B^(2k) + ... + c9*B^(9k) // ============================================================ // c0 は r[0..] に、c10 は r[10k..] に配置済み if (c1n > 0 && k < rn) { size_t space = rn - k; add(r + k, r + k, space, c1_ptr, std::min(c1n, space)); } if (c2n > 0 && 2 * k < rn) { size_t space = rn - 2 * k; add(r + 2 * k, r + 2 * k, space, c2_ptr, std::min(c2n, space)); } if (c3n > 0 && 3 * k < rn) { size_t space = rn - 3 * k; add(r + 3 * k, r + 3 * k, space, c3_ptr, std::min(c3n, space)); } if (c4n > 0 && 4 * k < rn) { size_t space = rn - 4 * k; add(r + 4 * k, r + 4 * k, space, c4_ptr, std::min(c4n, space)); } if (c5n > 0 && 5 * k < rn) { size_t space = rn - 5 * k; add(r + 5 * k, r + 5 * k, space, c5_ptr, std::min(c5n, space)); } if (c6n > 0 && 6 * k < rn) { size_t space = rn - 6 * k; add(r + 6 * k, r + 6 * k, space, c6_ptr, std::min(c6n, space)); } if (c7n > 0 && 7 * k < rn) { size_t space = rn - 7 * k; add(r + 7 * k, r + 7 * k, space, c7_ptr, std::min(c7n, space)); } if (c8n > 0 && 8 * k < rn) { size_t space = rn - 8 * k; add(r + 8 * k, r + 8 * k, space, c8_ptr, std::min(c8n, space)); } if (c9n > 0 && 9 * k < rn) { size_t space = rn - 9 * k; add(r + 9 * k, r + 9 * k, space, c9_ptr, std::min(c9n, space)); } } // ============================================================================ // Schönhage-Strassen FFT 乗算 — 前方宣言と閾値定数 // ============================================================================ // multiply → mul_fft 切り替え閾値 (limb 数) // NTT ベースの FFT 乗算。Z/(B^F+1)Z 上で畳み込みを行う。 // 内部の点ごと乗算は F ≈ 2√N ≪ N なので再入ガード不要。 constexpr size_t FFT_THRESHOLD = 3000; // Fermat NTT (フォールバック用) // Prime NTT (3素数 + CRT) 切り替え閾値 (MUL) // 2026-03-11 sweep: n>=1400 で Toom-3 より高速 constexpr size_t PRIME_NTT_THRESHOLD = 1400; // SQR 用 Prime NTT 閾値 // sqr_basecase の対称性利用で Toom-3 がより長く有利 constexpr size_t SQR_PRIME_NTT_THRESHOLD = 2600; // Double FFT 切り替え閾値 // YC-8e 最適化後ベンチマーク (AVX2 pointwise + wider B with C_SAFETY=2): // MUL 1100: FFT 124us vs Toom 134us vs NTT 180us → FFT最速 // MUL 1200: FFT 139us vs Toom 217us vs NTT 186us → FFT最速 // MUL 1300: FFT 295us vs NTT 189us → NTT勝ち (FFT の N が 16384 にジャンプ) // Sweet spot: 1050-1250 limbs (balanced) // 上限: NTT がより高速になる 1250+ は PRIME_NTT_DIRECT_THRESHOLD で処理 constexpr size_t DOUBLE_FFT_THRESHOLD = 1050; constexpr size_t DOUBLE_FFT_MAX_THRESHOLD = 1250; constexpr size_t SQR_DOUBLE_FFT_THRESHOLD = 1050; constexpr size_t SQR_DOUBLE_FFT_MAX_THRESHOLD = 1250; // NTT 直接呼び出し閾値: この以上のサイズでは NTT が Toom より高速 constexpr size_t PRIME_NTT_DIRECT_THRESHOLD = 1250; constexpr size_t SQR_PRIME_NTT_DIRECT_THRESHOLD = 1250; // mul_fft / sqr_fft の前方宣言 (実装はファイル末尾) inline void mul_fft(uint64_t* rp, const uint64_t* ap, size_t an, const uint64_t* bp, size_t bn); inline void sqr_fft(uint64_t* rp, const uint64_t* ap, size_t an); // ============================================================================ // 汎用乗算ディスパッチャー (アンバランス乗算対応) // ============================================================================ // multiply の scratch サイズ: 両オペランドのサイズを考慮 // FFT (mul_fft) は内部で scratch を自動確保するため、 // FFT 対象サイズでは 0 を返す。 inline size_t multiply_scratch_size(size_t an, size_t bn) { if (an < bn) std::swap(an, bn); if (bn == 0) return 0; if (bn < KARATSUBA_THRESHOLD) return 0; if (bn >= PRIME_NTT_DIRECT_THRESHOLD && an < 2 * bn) return 0; // NTT は自動確保 if (bn >= DOUBLE_FFT_THRESHOLD && bn < DOUBLE_FFT_MAX_THRESHOLD && an < 2 * bn) return 0; // double_fft / NTT は自動確保 if (an >= 2 * bn) { // アンバランス: tmp product buffer (2*bn) + バランス乗算の scratch return 2 * bn + multiply_scratch_size(bn, bn); } if (bn >= TOOMCOOK6_THRESHOLD) return mul_toomcook6_scratch_size(an); if (bn >= TOOMCOOK4_THRESHOLD) return mul_toomcook4_scratch_size(an); if (bn >= TOOMCOOK3_THRESHOLD) return mul_toomcook3_scratch_size(an); return mul_karatsuba_scratch_size(an); } // 前方宣言 (mul_unbalanced と相互参照) inline void multiply(uint64_t* r, const uint64_t* a, size_t an, const uint64_t* b, size_t bn, uint64_t* scratch); // アンバランス乗算: 大きい方をチャンクに分割して累積 // 前提: an >= 2*bn, an >= bn >= KARATSUBA_THRESHOLD inline void mul_unbalanced(uint64_t* r, const uint64_t* a, size_t an, const uint64_t* b, size_t bn, uint64_t* scratch) { size_t chunk = bn; size_t rn = an + bn; std::memset(r, 0, rn * sizeof(uint64_t)); size_t prod_buf_size = chunk + bn; // 最大 2*bn words uint64_t* tmp = scratch; uint64_t* sub_scratch = scratch + prod_buf_size; // 最初のチャンク: r に直接書き込み size_t first_an = std::min(chunk, an); multiply(r, a, first_an, b, bn, sub_scratch); // 残りのチャンク: tmp に乗算してから r に加算 for (size_t offset = chunk; offset < an; offset += chunk) { size_t this_an = std::min(chunk, an - offset); size_t this_prod_n = this_an + bn; std::memset(tmp, 0, this_prod_n * sizeof(uint64_t)); multiply(tmp, a + offset, this_an, b, bn, sub_scratch); size_t actual_n = normalized_size(tmp, this_prod_n); if (actual_n > 0) { add(r + offset, r + offset, rn - offset, tmp, actual_n); } } } // 汎用乗算: サイズに応じてアルゴリズムを自動選択 // r[0..an+bn-1] = a[0..an-1] * b[0..bn-1] // r は a, b と重なってはならない inline void multiply(uint64_t* r, const uint64_t* a, size_t an, const uint64_t* b, size_t bn, uint64_t* scratch) { if (an < bn) { std::swap(a, b); std::swap(an, bn); } if (bn == 0) { std::memset(r, 0, an * sizeof(uint64_t)); return; } if (bn < KARATSUBA_THRESHOLD) { mul_basecase(r, a, an, b, bn); return; } if (bn >= PRIME_NTT_DIRECT_THRESHOLD && an < 2 * bn) { prime_ntt::mul_prime_ntt(r, a, an, b, bn); return; } if (bn >= DOUBLE_FFT_THRESHOLD && bn < DOUBLE_FFT_MAX_THRESHOLD && an < 2 * bn) { if (double_fft::mul_double_fft(r, a, an, b, bn)) return; prime_ntt::mul_prime_ntt(r, a, an, b, bn); return; } if (an >= 2 * bn) { mul_unbalanced(r, a, an, b, bn, scratch); return; } if (bn >= TOOMCOOK6_THRESHOLD) { mul_toomcook6(r, a, an, b, bn, scratch); return; } if (bn >= TOOMCOOK4_THRESHOLD) { mul_toomcook4(r, a, an, b, bn, scratch); return; } if (bn >= TOOMCOOK3_THRESHOLD) { mul_toomcook3(r, a, an, b, bn, scratch); return; } mul_karatsuba(r, a, an, b, bn, scratch); } // ================================================================ // mulhigh_n 実装 // ================================================================ // mulhigh_n 用閾値: これ未満は mulhigh_basecase を使用 // bench_mulhigh.cpp で計測: n=52 で basecase が 0.74x (26%高速), n=128 で 0.97x // basecase は下位列スキップで ~50% 節約。Karatsuba は再帰オーバーヘッドで不利。 constexpr size_t MULHIGH_BASECASE_THRESHOLD = 128; inline size_t mulhigh_n_scratch_size(size_t n) { if (n < MULHIGH_BASECASE_THRESHOLD) return 0; // 大サイズ: full multiply + 上位抽出 return 2 * n + multiply_scratch_size(n, n); } inline void mulhigh_n(uint64_t* rp, const uint64_t* ap, const uint64_t* bp, size_t n, uint64_t* scratch) { if (n == 0) return; // 小サイズ: mulhigh_basecase は下位列をスキップして ~50% 節約 if (n < MULHIGH_BASECASE_THRESHOLD) { mulhigh_basecase(rp, ap, n, bp, n, n); return; } // 大サイズ: full multiply して上位 n limbs を抽出 // TODO: Karatsuba ベースの正確な mulhigh (cross term で再帰的 mulhigh 使用) uint64_t* prod = scratch; uint64_t* mul_work = scratch + 2 * n; multiply(prod, ap, n, bp, n, mul_work); std::memcpy(rp, prod + n, n * sizeof(uint64_t)); } // ============================================================================ // Squaring (a×a の対称性を利用した専用自乗) // ============================================================================ // sqr_basecase: r[0..2n-1] = a[0..n-1]² // 対称性利用: off-diagonal は n(n-1)/2 回の乗算 + 2 倍、diagonal は n 回の乗算 // scratch 不要。r は a と重なってはならない。 inline void sqr_basecase(uint64_t* r, const uint64_t* a, size_t n) { if (n == 0) return; #if defined(_MSC_VER) && defined(_M_X64) // n=1: 単一 128-bit 乗算 (ASM の push/pop を回避) if (n == 1) { r[0] = _umul128(a[0], a[0], &r[1]); return; } #endif #ifdef CALX_INT_HAS_ASM if (detail::has_bmi2_adx()) { mpn_sqr_basecase_mulx(r, a, n); return; } #endif std::memset(r, 0, 2 * n * sizeof(uint64_t)); // Step 1: off-diagonal — Σ_{i= n) ? col - n + 1 : 0; size_t i_max = col / 2; // i <= i_max for off-diagonal (i < j) for (size_t i = i_min; i <= i_max; ++i) { size_t j = col - i; if (j >= n) continue; uint64_t hi, lo; lo = _umul128(a[i], a[j], &hi); if (i < j) { // Off-diagonal: 2*(hi:lo) を加算 unsigned char top = static_cast(hi >> 63); hi = (hi << 1) | (lo >> 63); lo <<= 1; unsigned char carry; carry = _addcarry_u64(0, c0, lo, &c0); carry = _addcarry_u64(carry, c1, hi, &c1); _addcarry_u64(carry, c2, top, &c2); } else { // Diagonal: (hi:lo) を加算 unsigned char carry; carry = _addcarry_u64(0, c0, lo, &c0); carry = _addcarry_u64(carry, c1, hi, &c1); _addcarry_u64(carry, c2, 0, &c2); } } r[col] = c0; c0 = c1; c1 = c2; c2 = 0; } r[2 * n - 1] = c0; #else // フォールバック: 標準 sqr_basecase を使用 sqr_basecase(r, a, n); #endif } // sqr_karatsuba scratch サイズ inline size_t sqr_karatsuba_scratch_size(size_t n) { return (n < 16) ? 128 : 8 * n + 64; } // sqr_karatsuba: r[0..2n-1] = a[0..n-1]² // a² = a0² + ((a0+a1)² - a0² - a1²)·B^half + a1²·B^(2·half) // 3 回の再帰 squaring inline void sqr_karatsuba(uint64_t* r, const uint64_t* a, size_t n, uint64_t* scratch) { if (n < SQR_KARATSUBA_THRESHOLD) { sqr_basecase(r, a, n); return; } size_t half = (n + 1) / 2; const uint64_t* a0 = a; size_t a0n = std::min(half, n); const uint64_t* a1 = a + half; size_t a1n = (n > half) ? n - half : 0; a0n = normalized_size(a0, a0n); a1n = normalized_size(a1, a1n); size_t rn = 2 * n; std::memset(r, 0, rn * sizeof(uint64_t)); uint64_t* s = scratch; // a0 + a1: half+1 limbs uint64_t* middle = scratch + (half + 1); // (a0+a1)² の結果: 2*(half+1)+2 limbs size_t middle_max = 2 * (half + 2); uint64_t* rec_scratch = scratch + (half + 1) + middle_max; // r[0..] = a0² if (a0n > 0) sqr_karatsuba(r, a0, a0n, rec_scratch); // r[2*half..] = a1² if (a1n > 0) sqr_karatsuba(r + 2 * half, a1, a1n, rec_scratch); // s = a0 + a1 size_t sn; std::memset(s, 0, (half + 1) * sizeof(uint64_t)); if (a0n >= a1n) { if (a1n > 0) { uint64_t carry = add(s, a0, a0n, a1, a1n); sn = a0n; if (carry) { s[sn] = carry; sn++; } } else { std::memcpy(s, a0, a0n * sizeof(uint64_t)); sn = a0n; } } else { uint64_t carry = add(s, a1, a1n, a0, a0n); sn = a1n; if (carry) { s[sn] = carry; sn++; } } // middle = (a0+a1)² std::memset(middle, 0, middle_max * sizeof(uint64_t)); if (sn > 0) sqr_karatsuba(middle, s, sn, rec_scratch); size_t mn = normalized_size(middle, 2 * sn); // middle -= a0² (= r[0..2*a0n-1]) { size_t t0n = normalized_size(r, a0n + a0n); if (t0n > 0 && mn > 0) { sub(middle, middle, mn, r, t0n); mn = normalized_size(middle, mn); } } // middle -= a1² (= r[2*half..]) { size_t t2n = normalized_size(r + 2 * half, a1n + a1n); if (t2n > 0 && mn > 0) { sub(middle, middle, mn, r + 2 * half, t2n); mn = normalized_size(middle, mn); } } // r[half..] += middle if (mn > 0) { uint64_t carry = add(r + half, r + half, rn - half, middle, mn); (void)carry; } } // 自乗用 Toom-Cook-3 閾値 (乗算の 80 より高い) // sqr_basecase の対称性活用により sqr_karatsuba の優位区間が広い // ベンチマーク: n=96-140 で Karatsuba が 2-13% 高速、n=144+ で TC3 が 9-24% 高速 constexpr size_t SQR_TOOMCOOK3_THRESHOLD = 200; // sqr_toomcook3 scratch サイズ inline size_t sqr_toomcook3_scratch_size(size_t n) { if (n < SQR_TOOMCOOK3_THRESHOLD) return sqr_karatsuba_scratch_size(n); return 30 * n + 256; // mul_toomcook3 と同じ保守的なサイズ } // 前方宣言 (sqr_toomcook3 → square の相互再帰) inline size_t square_scratch_size(size_t n); inline void square(uint64_t* r, const uint64_t* a, size_t n, uint64_t* scratch); // sqr_toomcook3: r[0..2n-1] = a[0..n-1]² // 評価点 {0, 1, -1, 2, ∞} — ポリノミアル 1 本のみ評価 (対称性利用) // 5 回の再帰 squaring + 補間 inline void sqr_toomcook3(uint64_t* r, const uint64_t* a, size_t n, uint64_t* scratch) { if (n < SQR_TOOMCOOK3_THRESHOLD) { sqr_karatsuba(r, a, n, scratch); return; } size_t k = (n + 2) / 3; size_t rn = 2 * n; // 3 分割 const uint64_t* a0 = a; size_t a0n = normalized_size(a0, std::min(k, n)); const uint64_t* a1 = a + k; size_t a1n = (n > k) ? normalized_size(a1, std::min(k, n - k)) : 0; const uint64_t* a2 = a + 2 * k; size_t a2n = (n > 2 * k) ? normalized_size(a2, n - 2 * k) : 0; // Scratch レイアウト (乗算版より小さい: b側バッファ不要) size_t blk = 2 * (k + 4); uint64_t* v1_buf = scratch; uint64_t* vm1_buf = scratch + blk; uint64_t* v2_buf = scratch + 2 * blk; uint64_t* tmp1 = scratch + 3 * blk; // 評価テンポラリ uint64_t* interp_buf = scratch + 4 * blk; // 補間/評価ワーク uint64_t* rec_scratch = scratch + 5 * blk; size_t v1n = 0, vm1n = 0, v2n = 0; std::memset(r, 0, rn * sizeof(uint64_t)); // Point 0: v0 = a0² → r[0..] if (a0n > 0) square(r, a0, a0n, rec_scratch); size_t v0n = normalized_size(r, std::min(2 * a0n, rn)); // Point ∞: vinf = a2² → r[4k..] size_t vinfn = 0; if (a2n > 0 && 4 * k < rn) { square(r + 4 * k, a2, a2n, rec_scratch); vinfn = normalized_size(r + 4 * k, std::min(2 * a2n, rn - 4 * k)); } // Point 1: v1 = (a0+a1+a2)² { std::memset(tmp1, 0, blk * sizeof(uint64_t)); size_t ean = add_any(tmp1, a0, a0n, a1, a1n); ean = add_any(tmp1, tmp1, ean, a2, a2n); std::memset(v1_buf, 0, blk * sizeof(uint64_t)); if (ean > 0) square(v1_buf, tmp1, ean, rec_scratch); v1n = normalized_size(v1_buf, 2 * ean); } // Point -1: vm1 = (a0-a1+a2)² [常に非負 — 自乗なので符号不要] { std::memset(tmp1, 0, blk * sizeof(uint64_t)); size_t t_n = add_any(tmp1, a0, a0n, a2, a2n); int ea_sign = 1; size_t ean = abs_sub(tmp1, ea_sign, tmp1, t_n, a1, a1n); std::memset(vm1_buf, 0, blk * sizeof(uint64_t)); if (ean > 0) square(vm1_buf, tmp1, ean, rec_scratch); vm1n = normalized_size(vm1_buf, 2 * ean); } // Point 2: v2 = (a0+2*a1+4*a2)² { std::memset(tmp1, 0, blk * sizeof(uint64_t)); if (a0n > 0) std::memcpy(tmp1, a0, a0n * sizeof(uint64_t)); size_t ean = a0n; if (a1n > 0) { std::memset(interp_buf, 0, blk * sizeof(uint64_t)); uint64_t ov = lshift(interp_buf, a1, a1n, 1); size_t tn = a1n; if (ov) { interp_buf[tn] = ov; tn++; } ean = add_any(tmp1, tmp1, ean, interp_buf, tn); } if (a2n > 0) { std::memset(interp_buf, 0, blk * sizeof(uint64_t)); uint64_t ov = lshift(interp_buf, a2, a2n, 2); size_t tn = a2n; if (ov) { interp_buf[tn] = ov; tn++; } ean = add_any(tmp1, tmp1, ean, interp_buf, tn); } std::memset(v2_buf, 0, blk * sizeof(uint64_t)); if (ean > 0) square(v2_buf, tmp1, ean, rec_scratch); v2n = normalized_size(v2_buf, 2 * ean); } // ============================================================ // 補間 (元の mul_toomcook3 と同一構造、vm1_sign は常に >= 0) // ============================================================ // バッファ配置: // tmp1: A → c2 (Step 1,3) // interp_buf: B → C → c1 (Step 2,4,8) ※ tmp2 として使用 // v1_buf: 一時バッファ (Step 5,7) // vm1_buf: c3 (Step 7) // v2_buf: D → E (Step 5,6) ※ 最終的に空き uint64_t* tmp2 = interp_buf; // interp_buf を tmp2 として再利用 // Step 1: A = v1 + vm1 → tmp1 (vm1_sign >= 0 なので常に加算) std::memset(tmp1, 0, blk * sizeof(uint64_t)); size_t An = add_any(tmp1, v1_buf, v1n, vm1_buf, vm1n); // Step 2: B = v1 - vm1 → tmp2 (v1 >= vm1 保証: f(1)² >= f(-1)²) std::memset(tmp2, 0, blk * sizeof(uint64_t)); if (v1n > 0) std::memcpy(tmp2, v1_buf, v1n * sizeof(uint64_t)); size_t Bn = v1n; if (vm1n > 0) { sub(tmp2, tmp2, Bn, vm1_buf, vm1n); Bn = normalized_size(tmp2, Bn); } // Step 3: c2 = A/2 - v0 - vinf → tmp1 if (An > 0) An = rshift_1(tmp1, tmp1, An); if (v0n > 0 && An > 0) { sub(tmp1, tmp1, An, r, v0n); An = normalized_size(tmp1, An); } if (vinfn > 0 && An > 0) { sub(tmp1, tmp1, An, r + 4 * k, vinfn); An = normalized_size(tmp1, An); } size_t c2n = An; uint64_t* c2_ptr = tmp1; // Step 4: C = B/2 → tmp2 if (Bn > 0) Bn = rshift_1(tmp2, tmp2, Bn); size_t Cn = Bn; // Step 5: D = v2 - v0 - 4*c2 - 16*vinf → interp_buf として v2_buf を使う // (元のコードに合わせて interp_buf 相当の場所に D を蓄積) { // D の初期値 = v2 (v2_buf にそのまま) size_t Dn = v2n; // D -= v0 if (v0n > 0 && Dn > 0) { sub(v2_buf, v2_buf, Dn, r, v0n); Dn = normalized_size(v2_buf, Dn); } // D -= 4*c2 if (c2n > 0 && Dn > 0) { std::memset(v1_buf, 0, blk * sizeof(uint64_t)); uint64_t ov = lshift(v1_buf, c2_ptr, c2n, 2); size_t tn = c2n; if (ov) { v1_buf[tn] = ov; tn++; } sub(v2_buf, v2_buf, Dn, v1_buf, tn); Dn = normalized_size(v2_buf, Dn); } // D -= 16*vinf if (vinfn > 0 && Dn > 0) { std::memset(v1_buf, 0, blk * sizeof(uint64_t)); uint64_t ov = lshift(v1_buf, r + 4 * k, vinfn, 4); size_t tn = vinfn; if (ov) { v1_buf[tn] = ov; tn++; } sub(v2_buf, v2_buf, Dn, v1_buf, tn); Dn = normalized_size(v2_buf, Dn); } // Step 6: E = D/2 → v2_buf if (Dn > 0) Dn = rshift_1(v2_buf, v2_buf, Dn); // Step 7: c3 = (E - C) / 3 → vm1_buf size_t c3n_local = 0; if (Dn > 0 || Cn > 0) { std::memset(v1_buf, 0, blk * sizeof(uint64_t)); if (Dn >= Cn) { if (Cn > 0) sub(v1_buf, v2_buf, Dn, tmp2, Cn); else std::memcpy(v1_buf, v2_buf, Dn * sizeof(uint64_t)); } else { // Dn < Cn: v2_buf を Cn limb に拡張 (上位ゼロ埋め) for (size_t i = Dn; i < Cn; ++i) v2_buf[i] = 0; sub(v1_buf, v2_buf, Cn, tmp2, Cn); } size_t Fn = normalized_size(v1_buf, std::max(Dn, Cn)); if (Fn > 0) c3n_local = divexact_by3(vm1_buf, v1_buf, Fn); } v2n = c3n_local; // v2n を c3n として再利用 } size_t c3n = v2n; uint64_t* c3_ptr = vm1_buf; // Step 8: c1 = C - c3 → v2_buf size_t c1n = 0; if (Cn > 0) { if (c3n > 0) { sub(v2_buf, tmp2, Cn, c3_ptr, c3n); c1n = normalized_size(v2_buf, Cn); } else { std::memcpy(v2_buf, tmp2, Cn * sizeof(uint64_t)); c1n = Cn; } } uint64_t* c1_ptr = v2_buf; // ============================================================ // 組み立て: r += c1*B^k + c2*B^(2k) + c3*B^(3k) // ============================================================ if (c1n > 0 && k < rn) { size_t space = rn - k; add(r + k, r + k, space, c1_ptr, std::min(c1n, space)); } if (c2n > 0 && 2 * k < rn) { size_t space = rn - 2 * k; add(r + 2 * k, r + 2 * k, space, c2_ptr, std::min(c2n, space)); } if (c3n > 0 && 3 * k < rn) { size_t space = rn - 3 * k; add(r + 3 * k, r + 3 * k, space, c3_ptr, std::min(c3n, space)); } } // square scratch サイズ inline size_t square_scratch_size(size_t n) { if (n < SQR_KARATSUBA_THRESHOLD) return 0; if (n >= SQR_PRIME_NTT_DIRECT_THRESHOLD) return 0; // NTT は自動確保 if (n >= SQR_DOUBLE_FFT_THRESHOLD && n < SQR_DOUBLE_FFT_MAX_THRESHOLD) return 0; // double_fft / NTT は自動確保 if (n >= SQR_TOOMCOOK3_THRESHOLD) return sqr_toomcook3_scratch_size(n); return sqr_karatsuba_scratch_size(n); } // 汎用自乗: サイズに応じてアルゴリズムを自動選択 // r[0..2n-1] = a[0..n-1]² // r は a と重なってはならない inline void square(uint64_t* r, const uint64_t* a, size_t n, uint64_t* scratch) { if (n == 0) { return; } if (n < SQR_KARATSUBA_THRESHOLD) { sqr_basecase(r, a, n); return; } if (n >= SQR_PRIME_NTT_DIRECT_THRESHOLD) { prime_ntt::sqr_prime_ntt(r, a, n); return; } if (n >= SQR_DOUBLE_FFT_THRESHOLD && n < SQR_DOUBLE_FFT_MAX_THRESHOLD) { if (double_fft::sqr_double_fft(r, a, n)) return; prime_ntt::sqr_prime_ntt(r, a, n); return; } if (n >= SQR_TOOMCOOK3_THRESHOLD) { sqr_toomcook3(r, a, n, scratch); return; } sqr_karatsuba(r, a, n, scratch); } // ============================================================================ // Division Operations (Burnikel-Ziegler / Schoolbook) // ============================================================================ // BZ 再帰の閾値 (limb 数)。これ以下は schoolbook (div_basecase) constexpr size_t BZ_THRESHOLD = 64; // Svoboda Division の閾値 (limb 数)。 // 注: 元の「除数 2 倍 → q̂=上位 limb」方式は商推定誤差が 1 を超えるため不正確。 // Möller-Granlund 3/2 ベースの商推定に置き換えるまで無効化。 constexpr size_t SVOBODA_THRESHOLD = BZ_THRESHOLD + 1; // 事実上無効 // mu-division (Newton 逆数反復) の閾値 (limb 数)。 // unbalanced (an > 2*bn+1) の閾値は an/bn 比率に応じて動的に決定: // k=3: bn≥2000, k=5: bn≥1500, k≥8: bn≥1000 // balanced (an ≤ 2*bn+1) では BZ が O(M(n) log n) なのに対し // mu_div_qr は O(M(n)) のため、大きな bn で mu が優位。 constexpr size_t MU_DIV_THRESHOLD_LOW = 100; // k≥8 の高比率向け constexpr size_t MU_DIV_THRESHOLD_MID = 200; // k=4-7 の中比率向け constexpr size_t MU_DIV_THRESHOLD_HIGH = 2500; // k=3 の低比率向け constexpr size_t MU_DIV_BALANCED_THRESHOLD = 10000; // balanced mu 閾値 (mulhigh_n が最適化されるまで BZ 優位) // invert_approx: Newton 再帰の底。n ≤ INV_NEWTON_THRESHOLD では // div_basecase による直接逆数計算 (schoolbook inversion) を使用。 constexpr size_t INV_NEWTON_THRESHOLD = 10; inline size_t mu_div_threshold(size_t an, size_t bn) { size_t k = an / bn; // 整数比率 if (k >= 8) return MU_DIV_THRESHOLD_LOW; if (k >= 4) return MU_DIV_THRESHOLD_MID; return MU_DIV_THRESHOLD_HIGH; } // -------------------------------------------------------------------------- // invert_limb: 正規化された除数の逆数 // -------------------------------------------------------------------------- // v = floor((B² - 1) / d) - B (B = 2^64) // 前提: d の MSB がセット済み (d >= 2^63) // div_basecase / divmod_1 のループ内でハードウェア除算命令 (~35-90 cycles) を // 乗算ベースの商推定 (~10-15 cycles) に置き換えるために使用 inline uint64_t invert_limb(uint64_t d) { // v = floor([~d, ~0] / d) // ~d < d (d >= 2^63 なので ~d = 2^64 - 1 - d < 2^63 <= d) #if defined(_MSC_VER) && defined(_M_X64) uint64_t dummy; return _udiv128(~d, ~uint64_t(0), d, &dummy); #elif defined(__SIZEOF_INT128__) __uint128_t num = (static_cast<__uint128_t>(~d) << 64) | ~uint64_t(0); return static_cast(num / d); #else auto [v, rem] = UInt128::divmod_fast(~d, ~uint64_t(0), d); return v; #endif } // -------------------------------------------------------------------------- // udiv_qrnnd_preinv: 逆数を使った 2-limb ÷ 1-limb 除算 // -------------------------------------------------------------------------- // [u1, u0] / d → (商, 余り) // 前提: u1 < d, d の MSB がセット済み, dinv = invert_limb(d) // Möller & Granlund (2011) のアルゴリズム // q1 + 1 のオーバーフロー時も剰余演算で正しく回復する inline std::pair udiv_qrnnd_preinv( uint64_t u1, uint64_t u0, uint64_t d, uint64_t dinv) { // (q1, q0) = u1 * dinv + (u1, u0) uint64_t p_hi; uint64_t p_lo = _umul128(u1, dinv, &p_hi); uint64_t q0 = p_lo + u0; uint64_t q1 = p_hi + u1 + (q0 < p_lo); q1 += 1; // tentative quotient (may wrap to 0; handled below) uint64_t r = u0 - q1 * d; // mod B if (r > q0) { // q1 was too large by 1 (or wrapped from overflow) q1--; r += d; } if (r >= d) { q1++; r -= d; } return {q1, r}; } // -------------------------------------------------------------------------- // invert_pi1: 3-by-2 除算用の逆数計算 // -------------------------------------------------------------------------- // floor((B^3 - 1) / (d1*B + d0)) - B を返す。 // GMP の invert_pi1 マクロに相当。invert_limb(d1) から出発し d0 で補正する。 inline uint64_t invert_pi1(uint64_t d1, uint64_t d0) { uint64_t v = invert_limb(d1); // floor((B^2-1)/d1) - B uint64_t p = d1 * v; p += d0; if (p < d0) { v--; uint64_t mask = (p >= d1) ? UINT64_MAX : 0; p -= d1; v += mask; // mask は -1 (減算) or 0 p -= mask & d1; } // {t1, t0} = d0 * v UInt128 t = UInt128::multiply(d0, v); p += t.high; if (p < t.high) { v--; if (p >= d1) { if (p > d1 || t.low >= d0) v--; } } return v; } // -------------------------------------------------------------------------- // udiv_qrnnd_3by2: Möller-Granlund 3-by-2 商推定 // -------------------------------------------------------------------------- // {u2, u1, u0} / {d1, d0} → 商 q, 余り {r1, r0} // 前提: d1 の MSB がセット済み, u2 < d1 || (u2 == d1 && u1 <= d0 の場合は未定義) // dinv = invert_pi1(d1, d0) (3-by-2 逆数) inline void udiv_qrnnd_3by2(uint64_t& q_out, uint64_t& r1_out, uint64_t& r0_out, uint64_t u2, uint64_t u1, uint64_t u0, uint64_t d1, uint64_t d0, uint64_t dinv) { // {q1, q0} = u2 * dinv + {u2, u1} uint64_t p_hi; uint64_t p_lo = _umul128(u2, dinv, &p_hi); uint64_t q0 = p_lo + u1; uint64_t q1 = p_hi + u2 + (q0 < p_lo); // r1 = u1 - d1*q1, {r1, r0} = {r1, u0} - {d1, d0} uint64_t r1 = u1 - d1 * q1; uint64_t r0 = u0 - d0; r1 -= d1 + (u0 < d0); // {r1, r0} -= d0 * q1 uint64_t t_hi; uint64_t t_lo = _umul128(d0, q1, &t_hi); uint64_t old_r0 = r0; r0 -= t_lo; r1 -= t_hi + (r0 > old_r0); q1++; // 条件付き補正 (branchless) uint64_t mask = static_cast(0) - static_cast(r1 >= q0); q1 += mask; old_r0 = r0; r0 += mask & d0; r1 += (mask & d1) + (r0 < old_r0); if (r1 >= d1 && (r1 > d1 || r0 >= d0)) { q1++; old_r0 = r0; r0 -= d0; r1 -= d1 + (r0 > old_r0); } q_out = q1; r1_out = r1; r0_out = r0; } // 前方宣言 (invert_approx の n=2 base case から使用) inline void div_basecase(uint64_t* q, uint64_t* a, size_t an, const uint64_t* b, size_t bn); // -------------------------------------------------------------------------- // invert_approx: N ワードの近似逆数 (Newton 逆数反復) // -------------------------------------------------------------------------- // d[0..n-1]: 正規化された除数 (d[n-1] の MSB がセット済み) // inv[0..n-1]: 逆数の出力 (暗黙の先頭 1 を含まない) // I = B^n + inv と解釈し、D * I ≈ B^(2n) (誤差 < D) // scratch: invert_approx_scratch_size(n) limbs // 前提: n >= 1, d[n-1] の MSB がセット済み inline size_t invert_approx_scratch_size(size_t n) { if (n <= INV_NEWTON_THRESHOLD) { // base case: div_basecase で floor((B^(2n)-1)/D) を直接計算 // num: 2n+1 limbs, q: n+1 limbs → 3n+4 return 3 * n + 4; } size_t h = (n + 2) / 2; // ceil(n/2) + (n%2==0) で 2h > n を保証 size_t mul_nh = multiply_scratch_size(n, h); size_t mul_hh = multiply_scratch_size(h, h); size_t mul_sz = std::max(mul_nh, mul_hh); size_t rec_sz = invert_approx_scratch_size(h); // Newton ステップ: xp (n+h) + err (h) + multiply_scratch (mul_sz) // 再帰呼び出しは Newton ステップの前に完了するので scratch を共有可能 size_t newton_sz = n + 2 * h + mul_sz; return std::max(newton_sz, rec_sz); } // 内部実装 (GMP mpn_ni_invertappr 準拠) // D の切り詰め加算 + 正/負剰余クラス分岐 + h×h 補正乗算。 // 結果は真の逆数に対して ±1 以内の誤差。 // 内部実装: 戻り値は cy フラグ (1 または 2)。 // cy=1: 逆数は真の値に対して ±0 (正確な floor) // cy=2: 逆数は真の値に対して +1 の可能性あり inline int invert_approx_inner(uint64_t* inv, const uint64_t* d, size_t n, uint64_t* scratch) { // ================ Base case: n = 1 ================ if (n == 1) { inv[0] = invert_limb(d[0]); return 1; } // ================ Base case: n <= INV_NEWTON_THRESHOLD ================ // I = floor((β^(2n) - 1) / D), inv = I - β^n. // D は正規化 (MSB set) なので β^n/2 ≤ D < β^n、β^n < I ≤ 2β^n - 1. // div_basecase の前提 a_high < b を満たすため、先に β^n 分を引く: // num = (β^(2n) - 1) - D × β^n とすれば num[n..2n-1] = (β^n-1) - D < D. // floor(num / D) = floor((β^(2n)-1)/D) - β^n = inv. if (n <= INV_NEWTON_THRESHOLD) { uint64_t* num = scratch; uint64_t* q = scratch + 2 * n + 1; // num[0..2n-1] = β^(2n) - 1 (全ビット 1) std::memset(num, 0xFF, 2 * n * sizeof(uint64_t)); // num[n..2n-1] -= D → (β^n - 1) - D < D sub(num + n, num + n, n, d, n); // div_basecase 用センチネル num[2 * n] = 0; std::memset(q, 0, (n + 1) * sizeof(uint64_t)); div_basecase(q, num, 2 * n, d, n); std::memcpy(inv, q, n * sizeof(uint64_t)); return 1; // div_basecase は正確な商を返す } // ================ Recursive case (n > INV_NEWTON_THRESHOLD) ================ // h = ceil(n/2) + (n%2==0): 偶数 n で h = n/2+1 とし、2h > n を保証する。 // Newton の誤差増幅は eps_new ≈ eps_old^2 × β^{n-2h}。 // 2h > n とすることで β^{n-2h} < 1 となり、誤差が常に収束する。 size_t h = (n + 2) / 2; size_t l = n - h; // Step 1: 上位 h limbs の h-limb 逆数を再帰的に計算 // inv[l..n-1] に格納 int cy_rec = invert_approx_inner(inv + l, d + l, h, scratch); // I_h = B^h + inv[l..n-1] // Step 2: Newton ステップで h → n limbs に拡張 (GMP 方式) // scratch layout: // xp[0..n+h-1]: D × I_h の積 (n+h limbs) // err[0..h-1]: 誤差 (h limbs) // mul_work[...]: 乗算 scratch uint64_t* xp = scratch; uint64_t* err_buf = scratch + n + h; uint64_t* mul_work = scratch + n + 2 * h; // (a) xp = D × I_h (n × h → n+h limbs) multiply(xp, d, n, inv + l, h, mul_work); // (b) xp[h..n] += D[0..l] (切り詰め加算: D の下位 l+1 limbs のみ) // GMP と同じく D[l+1..n-1] × B^{n+1} の寄与は無視。 // これにより xp[0..n] に T mod B^{n+1} を得る。 // carry は無視 (T mod B^{n+1} の一部として吸収される)。 (void)add(xp + h, xp + h, l + 1, d, l + 1); // (c) xp[n] に基づいて正規化 (GMP の正/負剰余クラス分岐) // T ≈ B^{n+h} なので xp[n] は 0,1 (正) または UINT64_MAX,MAX-1 (負)。 int ret_cy; if (xp[n] < 2) { // === 正の剰余クラス === // cy は「I_h を何だけ減算するか」。切り詰め分の +1 を含む。 uint64_t cy = xp[n]; // cy++ で切り詰め分を加算 (GMP: "Remember we truncated") if (cy++) { // xp[n] が 1 だった → xp[0..n-1] から D を引く uint64_t borrow = sub(xp, xp, n, d, n); if (!borrow) { // borrow なし → まだ xp >= D → もう 1 回引く sub(xp, xp, n, d, n); ++cy; } } // 1 <= cy <= 3 // 最終チェック: xp > D なら 1 回引く if (cmp(xp, n, d, n) > 0) { sub(xp, xp, n, d, n); ++cy; } // 1 <= cy <= 4 // I_h を cy だけ減算 sub_1(inv + l, h, cy); // 誤差 = D_h - xp_upper (h limbs) // D_h = d[l..n-1], xp_upper = xp[l..n-1] // 下位からの借り: xp[0..l-1] > d[0..l-1] なら 1 sub(err_buf, d + l, h, xp + l, h); int _borrow_low = (cmp(xp, l, d, l) > 0) ? 1 : 0; if (_borrow_low) { sub_1(err_buf, h, 1); } // 正の分岐: cy_rec を伝播 (Newton は精度を倍増、誤差を増幅しない) ret_cy = cy_rec; } else { // === 負の剰余クラス (xp[n] >= UINT64_MAX - 1) === // 切り詰めフラグ (1) を xp から引く sub_1(xp, n + 1, 1); if (xp[n] != UINT64_MAX) { // I_h を 1 増やす add_1(inv + l, h, 1); // xp に D を加算 (キャリーが出て xp[n] = UINT64_MAX に戻るはず) add(xp, xp, n, d, n); } // 誤差 = ~xp[l..n-1] (1の補数 ≈ B^h - 1 - xp_upper) for (size_t i = 0; i < h; i++) { err_buf[i] = ~xp[l + i]; } // 負の分岐: 補数操作で正確 → cy=1 ret_cy = 1; } // (d) 補正乗算: err × I_h (h × h → 2h limbs) // 旧コードの n × h 乗算より効率的。 multiply(xp, err_buf, h, inv + l, h, mul_work); // xp[0..2h-1] = err × I_h // (e) 暗黙の B^h 成分を加算: δ = err + (err × I_h) >> h // GMP 方式で 2 パートに分割: // Part 1: xp[h..3h-n-1] += err[0..2h-n-1] (2h-n limbs) uint64_t cy_corr = add(xp + h, xp + h, 2 * h - n, err_buf, 2 * h - n); // Part 2: inv[0..l-1] = xp[3h-n..2h-1] + err[2h-n..h-1] + cy (l limbs) uint64_t cy2 = add(inv, xp + 3 * h - n, l, err_buf + 2 * h - n, l); if (cy_corr) cy2 += add_1(inv, l, 1); // キャリーを inv[l..n-1] に伝播 if (cy2) add_1(inv + l, h, cy2); return ret_cy; } // 公開 API: 近似逆数を計算し、D × I ≤ β^(2n) を保証する。 // 誤差上限: β^(2n) - D × I ≤ D (mu_div_qr の商補正ループで処理可能) // DIV-1c: cy は帰納的に常に 1 (base case=1, 正の分岐=cy_rec, 負の分岐=1)。 // sub_1(inv, n, 2) のケースは発生しない。 inline void invert_approx(uint64_t* inv, const uint64_t* d, size_t n, uint64_t* scratch) { [[maybe_unused]] int cy = invert_approx_inner(inv, d, n, scratch); assert(cy == 1 && "invert_approx_inner should always return cy=1"); // inv を 1 減算して過小評価側にバイアスし、 // D × (B^n + inv) ≤ B^{2n} を保証する。 // D ≈ β^n (inv ≈ 0) の場合、sub_1(1) が underflow するため、 // borrow を検査して inv = 0 にフォールバック (I = β^n, 安全な過小評価)。 uint64_t borrow = sub_1(inv, n, 1ULL); if (borrow) { std::memset(inv, 0, n * sizeof(uint64_t)); } } // 前方宣言 (div_basecase から呼ばれる) inline void div_basecase_svoboda(uint64_t* q, uint64_t* a, size_t an, const uint64_t* b, size_t bn); // -------------------------------------------------------------------------- // div_basecase: Knuth Algorithm D の mpn ポート (preinv 最適化版) // -------------------------------------------------------------------------- // 前提: // - bn >= 2 // - b[bn-1] の MSB がセット済み (正規化) // - a[] のスペースは an+1 limbs (a[an] をセンチネルとして使う) // - 呼び出し元が a[an] = 0 をセットしておく // 結果: // - q[0 .. an-bn] に商 // - a[0 .. bn-1] に余り (a を上書き) inline void div_basecase(uint64_t* q, uint64_t* a, size_t an, const uint64_t* b, size_t bn) { // Svoboda: 除数・商桁が共に閾値以上なら Svoboda Division を使う if (bn >= SVOBODA_THRESHOLD && (an - bn) >= SVOBODA_THRESHOLD) { div_basecase_svoboda(q, a, an, b, bn); return; } const uint64_t d1 = b[bn - 1]; // 最上位 limb (MSB set) const uint64_t d0 = b[bn - 2]; // 2 番目の limb if (bn == 2) { // ── divrem_2 パス (bn == 2) ── // GMP の mpn_divrem_2 に相当: 3-by-2 で正確な商を得る。 // bn-2 = 0 なので submul 不要。余り {n1, n0} をレジスタで保持。 const uint64_t dinv3 = invert_pi1(d1, d0); uint64_t n1 = 0; uint64_t n0 = a[an - 1]; for (size_t j = an - 2; ; ) { uint64_t q_hat; udiv_qrnnd_3by2(q_hat, n1, n0, n1, n0, a[j], d1, d0, dinv3); q[j] = q_hat; if (j == 0) break; --j; } a[0] = n0; a[1] = n1; } else if (bn >= 7) { // ── Möller-Granlund 3-by-2 パス (bn >= 7) ── const uint64_t dinv3 = invert_pi1(d1, d0); #ifdef CALX_INT_HAS_ASM if (detail::has_bmi2_adx()) { mpn_sbpi1_div_qr_asm(q, a, an - bn, b, bn, dinv3); return; } #endif for (size_t j = an - bn; ; ) { uint64_t q_hat; const uint64_t u2 = a[j + bn]; const uint64_t u1 = a[j + bn - 1]; if ((u2 == d1) && u1 == d0) { q_hat = UINT64_MAX; uint64_t borrow = submul_1(a + j, b, bn, q_hat); uint64_t prev = a[j + bn]; a[j + bn] = prev - borrow; if (prev < borrow) { --q_hat; add(a + j, a + j, bn + 1, b, bn); } } else { uint64_t n1, n0; udiv_qrnnd_3by2(q_hat, n1, n0, u2, u1, a[j + bn - 2], d1, d0, dinv3); uint64_t cy = submul_1(a + j, b, bn - 2, q_hat); uint64_t cy0 = (n0 < cy) ? 1 : 0; n0 -= cy; uint64_t cy1 = (n1 < cy0) ? 1 : 0; n1 -= cy0; a[j + bn - 2] = n0; if (cy1) { n1 += d1 + add(a + j, a + j, bn - 1, b, bn - 1); --q_hat; } a[j + bn - 1] = n1; a[j + bn] = 0; } q[j] = q_hat; if (j == 0) break; --j; } } else { // ── 2-by-1 パス (bn = 3..6) ── // 小さい bn では 3-by-2 のオーバーヘッドが支配的なため従来方式を使用。 const uint64_t dinv = invert_limb(d1); for (size_t j = an - bn; ; ) { uint64_t q_hat; const uint64_t u2 = a[j + bn]; const uint64_t u1 = a[j + bn - 1]; if (u2 >= d1) { q_hat = UINT64_MAX; } else { auto [qh, rh] = udiv_qrnnd_preinv(u2, u1, d1, dinv); q_hat = qh; } uint64_t borrow = submul_1(a + j, b, bn, q_hat); uint64_t prev = a[j + bn]; a[j + bn] = prev - borrow; if (prev < borrow) { --q_hat; uint64_t cy = add(a + j, a + j, bn + 1, b, bn); if (!cy) { --q_hat; add(a + j, a + j, bn + 1, b, bn); } } q[j] = q_hat; if (j == 0) break; --j; } } } // -------------------------------------------------------------------------- // div_basecase_svoboda: Svoboda Division (商推定を最上位 limb 読み取りに簡略化) // -------------------------------------------------------------------------- // 除数を 2 倍して V'[bn] = 1 にし、商推定 q̂ = a[j+bn] とする。 // udiv_qrnnd_preinv (~12 cycles/桁) を除去し、代わりに submul_1 が bn+1 limbs になる。 // 前提: // - bn >= 2, b[bn-1] の MSB がセット済み (正規化) // - a[] のスペースは an+2 limbs (a[an] sentinel + a[an+1] Svoboda sentinel) // - 呼び出し元が a[an] = 0, a[an+1] = 0 をセットしておく (lshift で上書きされる) // 結果: // - q[0 .. an-bn] に商 // - a[0 .. bn-1] に余り (a を上書き) inline void div_basecase_svoboda(uint64_t* q, uint64_t* a, size_t an, const uint64_t* b, size_t bn) { // bp = 2*b (bn+1 limbs, bp[bn] = 1) uint64_t bp[BZ_THRESHOLD + 1]; bp[bn] = lshift(bp, b, bn, 1); // a = 2*a (in-place) a[an] = lshift(a, a, an, 1); a[an + 1] = 0; // Svoboda 用追加 sentinel for (size_t j = an - bn; ; ) { // 商推定: 最上位 limb の読み取り (Svoboda の核心) uint64_t q_hat = a[j + bn]; // 乗減算: a[j .. j+bn] -= bp[0..bn] * q_hat uint64_t borrow = submul_1(a + j, bp, bn + 1, q_hat); uint64_t prev = a[j + bn + 1]; a[j + bn + 1] = prev - borrow; if (prev < borrow) { // add-back: q_hat was 1 too large (Svoboda 保証: 最大 1 回) --q_hat; add(a + j, a + j, bn + 2, bp, bn + 1); } q[j] = q_hat; if (j == 0) break; --j; } // Svoboda スケール除去: 余り = a[0..bn] / 2 // 倍化余り 2r は bn+1 limbs (a[bn] は 0 or 1)。半減で bn limbs に収まる。 for (size_t i = 0; i < bn - 1; i++) { a[i] = (a[i] >> 1) | (a[i + 1] << 63); } a[bn - 1] = (a[bn - 1] >> 1) | (a[bn] << 63); } // -------------------------------------------------------------------------- // BZ 再帰コア: 前方宣言 // -------------------------------------------------------------------------- inline void div_2n_by_n(uint64_t* q, uint64_t* r, const uint64_t* a, const uint64_t* b, size_t n, uint64_t* scratch); // div_3n_by_2n: BZ の中核 // A[3*half] を B[n] = [B1, B0] (各 half limbs) で割る // q[0..half-1] に商, r[0..n-1] に余り // scratch: multiply 用 + 作業バッファ inline void div_3n_by_2n(uint64_t* q, uint64_t* r, const uint64_t* a, // 3*half limbs const uint64_t* b, // n = 2*half limbs size_t n, uint64_t* scratch) { const size_t half = n / 2; // A = [A2, A1, A0] each half limbs const uint64_t* a0 = a; const uint64_t* a1 = a + half; const uint64_t* a2 = a + 2 * half; // B = [B1, B0] each half limbs const uint64_t* b0 = b; const uint64_t* b1 = b + half; // (Q_hat, R1) = div_2n_by_n([A2, A1], B1) // A2,A1 → 2*half limbs, B1 → half limbs uint64_t* r1 = scratch; // half limbs uint64_t* q_hat = q; // half limbs (output) uint64_t* sub_scratch = scratch + half; // for recursion // a_top = [A1, A2] (2*half limbs, little-endian = A1 が下) // 既に a の中で連続 div_2n_by_n(q_hat, r1, a1, b1, half, sub_scratch); // D = Q_hat * B0 (乗算: half × half → 2*half limbs) uint64_t* d = scratch + half; // 2*half limbs uint64_t* mul_scratch = d + 2 * half; size_t mul_sz = multiply_scratch_size(half, half); std::memset(d, 0, 2 * half * sizeof(uint64_t)); multiply(d, q_hat, half, b0, half, mul_scratch); // R = [R1, A0] - D // [R1, A0] は n limbs (A0 が下位, R1 が上位) // r[] に組み立て std::memcpy(r, a0, half * sizeof(uint64_t)); // 下位 half = A0 std::memcpy(r + half, r1, half * sizeof(uint64_t)); // 上位 half = R1 // r -= D uint64_t borrow = sub(r, r, n, d, n); // 補正: r < 0 なら Q_hat--, r += B (最大 2 回) while (borrow != 0) { // Q_hat を 1 減算 (multi-limb decrement) for (size_t i = 0; i < half; i++) { if (q_hat[i] != 0) { q_hat[i]--; break; } q_hat[i] = UINT64_MAX; } // r += B uint64_t carry = add(r, r, n, b, n); borrow -= carry; } } // div_2n_by_n: A[2n] を B[n] で割る // q[0..n-1] に商, r[0..n-1] に余り inline void div_2n_by_n(uint64_t* q, uint64_t* r, const uint64_t* a, // 2n limbs const uint64_t* b, // n limbs size_t n, uint64_t* scratch) { // ベースケース: schoolbook if (n < BZ_THRESHOLD) { // 除数の先頭ゼロを除去 (BZ 再帰の奇数パディングで生じうる) size_t bn_real = normalized_size(b, n); if (bn_real == 0) { // ゼロ除算 — 商・余りをゼロに設定 std::memset(q, 0, n * sizeof(uint64_t)); std::memset(r, 0, n * sizeof(uint64_t)); return; } // 被除数の実効サイズ size_t an_real = normalized_size(a, 2 * n); if (an_real == 0) { std::memset(q, 0, n * sizeof(uint64_t)); std::memset(r, 0, n * sizeof(uint64_t)); return; } // 被除数 < 除数 の場合: 商=0, 余り=被除数 // BZ 再帰で上位部分が除数より小さくなることがある if (an_real < bn_real || (an_real == bn_real && cmp(a, an_real, b, bn_real) < 0)) { std::memset(q, 0, n * sizeof(uint64_t)); std::memset(r, 0, n * sizeof(uint64_t)); std::memcpy(r, a, an_real * sizeof(uint64_t)); return; } // 正規化: 除数の MSB をセット unsigned shift = std::countl_zero(b[bn_real - 1]); // scratch レイアウト: // nb: bn_real limbs (正規化された除数) // tmp_a: an_real + 2 limbs (正規化された被除数 + sentinel + Svoboda sentinel) // tmp_q: an_real - bn_real + 1 limbs (商) uint64_t* nb = scratch; uint64_t* tmp_a = nb + bn_real; uint64_t* tmp_q = tmp_a + an_real + 2; size_t nan; if (shift > 0) { lshift(nb, b, bn_real, shift); tmp_a[an_real] = lshift(tmp_a, a, an_real, shift); nan = an_real + (tmp_a[an_real] ? 1 : 0); tmp_a[nan] = 0; div_basecase(tmp_q, tmp_a, nan, nb, bn_real); } else { // ★ shift==0: nb コピー省略、b を直接使用 std::memcpy(tmp_a, a, an_real * sizeof(uint64_t)); tmp_a[an_real] = 0; nan = an_real; tmp_a[nan] = 0; div_basecase(tmp_q, tmp_a, nan, b, bn_real); } // 商をコピー (n limbs、余分はゼロ埋め) size_t qn = nan - bn_real + 1; std::memset(q, 0, n * sizeof(uint64_t)); std::memcpy(q, tmp_q, std::min(qn, n) * sizeof(uint64_t)); // 余りの逆正規化 if (shift > 0) { for (size_t i = 0; i < bn_real - 1; i++) { tmp_a[i] = (tmp_a[i] >> shift) | (tmp_a[i + 1] << (64 - shift)); } tmp_a[bn_real - 1] >>= shift; } std::memset(r, 0, n * sizeof(uint64_t)); std::memcpy(r, tmp_a, bn_real * sizeof(uint64_t)); return; } // n が奇数のとき: 1 word 左シフトして偶数化 // (上位に 0 を置くと BZ 分割で B1 の先頭がゼロになりバグる) // A/B = (A*2^64)/(B*2^64) — 商は同じ、余りは 1 word 右シフトで復元 if (n & 1) { size_t nn = n + 1; // 偶数 // B' = B * 2^64 = [0, b[0], b[1], ..., b[n-1]] (nn limbs, MSB は b[n-1] のまま) uint64_t* bp = scratch; bp[0] = 0; std::memcpy(bp + 1, b, n * sizeof(uint64_t)); // A' = A * 2^64 = [0, a[0], a[1], ..., a[2n-1], 0] (2*nn = 2n+2 limbs) uint64_t* ap = scratch + nn; ap[0] = 0; std::memcpy(ap + 1, a, 2 * n * sizeof(uint64_t)); ap[2 * n + 1] = 0; // top limb // 再帰 (偶数版) uint64_t* qp = scratch + nn + 2 * nn; // nn limbs uint64_t* rp = qp + nn; // nn limbs uint64_t* sub_scratch = rp + nn; div_2n_by_n(qp, rp, ap, bp, nn, sub_scratch); // 商: qp の下位 n limbs (上位は qp[n] だが、通常 0) std::memcpy(q, qp, n * sizeof(uint64_t)); // 余り: rp は (A mod B)*2^64 なので 1 word 右シフト std::memcpy(r, rp + 1, n * sizeof(uint64_t)); return; } // n が偶数 → BZ 分割 const size_t half = n / 2; // A = [A3, A2, A1, A0] each half limbs (little-endian) const uint64_t* a0 = a; const uint64_t* a1 = a + half; // (Q1, R1) = div_3n_by_2n([A3, A2, A1], B) // [A1, A2, A3] は a + half から 3*half limbs uint64_t* q1 = q + half; // 上位 half limbs of Q uint64_t* r1_tmp = scratch; // n limbs (余り) uint64_t* sub_scratch1 = scratch + n; div_3n_by_2n(q1, r1_tmp, a + half, b, n, sub_scratch1); // (Q0, R0) = div_3n_by_2n([R1, A0], B) // [A0, R1] = A0 (half limbs) + R1 (n limbs) → 3*half limbs uint64_t* combined = scratch + n; // 3*half limbs std::memcpy(combined, a0, half * sizeof(uint64_t)); std::memcpy(combined + half, r1_tmp, n * sizeof(uint64_t)); uint64_t* q0 = q; // 下位 half limbs of Q uint64_t* sub_scratch2 = combined + 3 * half; div_3n_by_2n(q0, r, combined, b, n, sub_scratch2); } // -------------------------------------------------------------------------- // div_unbalanced: an > 2*bn のケースをチャンク分割で処理 // -------------------------------------------------------------------------- inline size_t div_unbalanced(uint64_t* q, uint64_t* r, const uint64_t* a, size_t an, const uint64_t* b, size_t bn, uint64_t* scratch) { // bn limbs ずつのチャンクで上位から処理 // 各ステップ: 2*bn limbs (前回の余り bn limbs + 次のチャンク bn limbs) を bn で割る size_t qn = an - bn; // 商のサイズ (最大) // 最初のチャンク: 最上位の (an mod bn) + bn limbs // → an が bn の倍数でない場合に対応 size_t first_chunk = an % bn; if (first_chunk == 0) first_chunk = bn; size_t pos = an; // 処理位置 (上端、下方向に進む) // remainder buffer (scratch 内): 最初は a の最上位チャンク uint64_t* rem = scratch; // bn + 1 limbs uint64_t* div_scratch = scratch + bn + 1; // 最初のチャンク: a[pos - first_chunk .. pos - 1] pos -= first_chunk; std::memcpy(rem, a + pos, first_chunk * sizeof(uint64_t)); if (first_chunk < bn) { std::memset(rem + first_chunk, 0, (bn - first_chunk) * sizeof(uint64_t)); } // 最初のチャンクが bn より短い場合: 商の上位は 0 if (first_chunk < bn) { // a の上位 first_chunk limbs < b (bn limbs) なので商は 0 // rem にそのまま保持、次のチャンクへ q[qn] = 0; } else { // first_chunk == bn: rem が b 以上なら 1 回引いて商の最上位を立てる // (b は正規化済みで MSB がセットされているため rem < 2*b → 最大 1 回) if (cmp(rem, bn, b, bn) >= 0) { sub(rem, rem, bn, b, bn); q[qn] = 1; } else { q[qn] = 0; } } // メインループ: 下方向にチャンク処理 while (pos > 0) { size_t chunk = (pos >= bn) ? bn : pos; pos -= chunk; // 被除数を組み立て: [a[pos..pos+chunk-1], rem[0..bn-1]] // → combined[0..chunk+bn-1] = a[pos..] (下位) + rem (上位) uint64_t* combined = div_scratch; std::memcpy(combined, a + pos, chunk * sizeof(uint64_t)); std::memcpy(combined + chunk, rem, bn * sizeof(uint64_t)); size_t cn = chunk + bn; size_t real_cn = normalized_size(combined, cn); if (real_cn <= bn) { // 被除数 < 除数 → 商 = 0, 余り = combined std::memset(q + pos, 0, chunk * sizeof(uint64_t)); std::memcpy(rem, combined, bn * sizeof(uint64_t)); } else { // div_2n_by_n で割る (cn を 2*bn にパディング) uint64_t* padded_a = div_scratch + cn + 1; std::memcpy(padded_a, combined, cn * sizeof(uint64_t)); if (cn < 2 * bn) { std::memset(padded_a + cn, 0, (2 * bn - cn) * sizeof(uint64_t)); } uint64_t* temp_q = padded_a + 2 * bn; // bn limbs uint64_t* inner_scratch = temp_q + bn; div_2n_by_n(temp_q, rem, padded_a, b, bn, inner_scratch); // temp_q のうち必要な分だけ q にコピー (バッファオーバーフロー防止) size_t q_copy = std::min(chunk, qn + 1 - pos); std::memcpy(q + pos, temp_q, q_copy * sizeof(uint64_t)); } } // 余りを出力 std::memcpy(r, rem, bn * sizeof(uint64_t)); return normalized_size(q, qn + 1); } // -------------------------------------------------------------------------- // mulmod_bnm1: CRT ベース巡回乗算 (mod β^n - 1) // -------------------------------------------------------------------------- // (A × B) mod (β^n - 1) を計算する。結果は n limbs。 // β^n ≡ 1 (mod β^n - 1) なので、位置 ≥ n の limbs は下位に折り返す。 // // β^n - 1 = (β^h - 1)(β^h + 1) (h = n/2) を利用して CRT 分解: // xm = A*B mod (β^h - 1) — 再帰呼び出し // xp = A*B mod (β^h + 1) — negacyclic 乗算 // CRT 合成: コスト ≈ 2*M(h) ≈ M(n)/2 // // GMP mpn_mulmod_bnm1 準拠。 constexpr size_t MULMOD_BNM1_THRESHOLD = 16; constexpr size_t MUL_TO_MULMOD_BNM1_FOR_2NXN_THRESHOLD = 20; // n を CRT フレンドリーなサイズ (偶数) に切り上げ inline size_t mulmod_bnm1_next_size(size_t n) { if (n < MULMOD_BNM1_THRESHOLD) return n; if (n < 4 * (MULMOD_BNM1_THRESHOLD - 1) + 1) return (n + 1) & ~size_t(1); if (n < 8 * (MULMOD_BNM1_THRESHOLD - 1) + 1) return (n + 3) & ~size_t(3); return (n + 7) & ~size_t(7); } inline size_t mulmod_bnm1_scratch_size(size_t rn, size_t an, size_t bn) { // ベースケース: 完全積 + multiply workspace size_t bc = (an + bn) + multiply_scratch_size(an, bn); if ((rn & 1) != 0 || rn < MULMOD_BNM1_THRESHOLD) return bc; // CRT 再帰: 各レベルで固定 5h+3 + max(bnp1, recursive) // bnp1 scratch = 2*h + multiply_scratch_size(h, h) // 保守的上限: 8*rn + multiply_scratch_size(rn/2, rn/2) + 16 size_t crt = 8 * rn + multiply_scratch_size(rn / 2, rn / 2) + 16; return std::max(bc, crt); } // neg: r = 0 - a (mod β^n). 戻り値: a != 0 なら 1 (borrow) inline uint64_t neg(uint64_t* r, const uint64_t* a, size_t n) { size_t i = 0; while (i < n && a[i] == 0) { r[i] = 0; i++; } if (i == n) return 0; r[i] = ~a[i] + 1; // = -a[i] (uint64 wrap) i++; for (; i < n; i++) r[i] = ~a[i]; return 1; } // mulmod_bnp1_bc: negacyclic 乗算 (ベースケース) // r[0..n] = (a[0..n] × b[0..n]) mod (β^n + 1) // a[n], b[n] ∈ {0, 1}. r[n] ∈ {0, 1}. // scratch: 2*n + multiply_scratch_size(n, n) limbs inline void mulmod_bnp1_bc(uint64_t* r, const uint64_t* a, const uint64_t* b, size_t n, uint64_t* scratch) { uint64_t cy; if (a[n] | b[n]) { // a[n]=1 → a mod (β^n+1) = -1, 積 = -b (or +1 if both high) if (a[n]) cy = b[n] + neg(r, b, n); else cy = neg(r, a, n); } else { // 完全乗算 → negacyclic fold uint64_t* tp = scratch; uint64_t* mw = scratch + 2 * n; std::memset(tp, 0, 2 * n * sizeof(uint64_t)); multiply(tp, a, n, b, n, mw); cy = sub(r, tp, n, tp + n, n); } r[n] = 0; if (cy) add_1(r, n + 1, cy); } // ベースケース mulmod_bnm1: 完全乗算 + fold inline void mulmod_bnm1_bc(uint64_t* r, const uint64_t* a, size_t an, const uint64_t* b, size_t bn, size_t rn, uint64_t* scratch) { size_t pn = an + bn; uint64_t* prod = scratch; uint64_t* mul_work = scratch + pn; std::memset(prod, 0, pn * sizeof(uint64_t)); multiply(prod, a, an, b, bn, mul_work); if (pn <= rn) { std::memcpy(r, prod, pn * sizeof(uint64_t)); if (pn < rn) std::memset(r + pn, 0, (rn - pn) * sizeof(uint64_t)); } else { std::memcpy(r, prod, rn * sizeof(uint64_t)); size_t high_n = pn - rn; uint64_t carry = add(r, r, rn, prod + rn, high_n); while (carry > 0) carry = add_1(r, rn, carry); } } // r[0..rn-1] = (a[0..an-1] × b[0..bn-1]) mod (β^rn - 1) // 前提: 0 < bn <= an <= rn, an + bn > rn/2 // r は a, b と重なってはならない // scratch: mulmod_bnm1_scratch_size(rn, an, bn) limbs inline void mulmod_bnm1(uint64_t* r, size_t rn, const uint64_t* a, size_t an, const uint64_t* b, size_t bn, uint64_t* scratch) { // ベースケース: rn が奇数 or 小さい場合 if ((rn & 1) != 0 || rn < MULMOD_BNM1_THRESHOLD) { if (an + bn <= rn) { // 積が rn 以下: 完全乗算で十分 uint64_t* prod = scratch; uint64_t* mw = scratch + an + bn; std::memset(prod, 0, (an + bn) * sizeof(uint64_t)); multiply(prod, a, an, b, bn, mw); std::memcpy(r, prod, (an + bn) * sizeof(uint64_t)); if (an + bn < rn) std::memset(r + an + bn, 0, (rn - an - bn) * sizeof(uint64_t)); } else { mulmod_bnm1_bc(r, a, an, b, bn, rn, scratch); } return; } // CRT 分解: rn = 2*h size_t h = rn >> 1; uint64_t cy; // scratch layout (GMP 準拠): // xp[0..2h+1]: mod (β^h+1) の積 + 一時バッファ // sp1[0..]: ap1/bp1 の格納 (xp の後) uint64_t* xp = scratch; uint64_t* sp1 = scratch + 2 * h + 2; #define a0 a #define a1 (a + h) #define b0 b #define b1 (b + h) // ====== mod (β^h - 1) 側: xm = A*B mod (β^h - 1) ====== { const uint64_t* am1; const uint64_t* bm1; size_t anm, bnm; uint64_t* so; bm1 = b0; bnm = bn; if (an > h) { // am1 = a0 + a1 mod (β^h - 1) am1 = xp; cy = add(xp, a0, h, a1, an - h); if (cy) add_1(xp, h, cy); anm = h; so = xp + h; if (bn > h) { // bm1 = b0 + b1 mod (β^h - 1) bm1 = so; cy = add(so, b0, h, b1, bn - h); if (cy) add_1(so, h, cy); bnm = h; so += h; } } else { so = xp; am1 = a0; anm = an; } // 再帰呼び出し: 結果を r[0..h-1] に出力 mulmod_bnm1(r, h, am1, anm, bm1, bnm, so); } // ====== mod (β^h + 1) 側: xp = A*B mod (β^h + 1) ====== { const uint64_t* ap1; const uint64_t* bp1; size_t anp, bnp; bp1 = b0; bnp = bn; if (an > h) { ap1 = sp1; cy = sub(sp1, a0, h, a1, an - h); sp1[h] = 0; if (cy) add_1(sp1, h + 1, cy); anp = h + ap1[h]; if (bn > h) { bp1 = sp1 + h + 1; cy = sub(sp1 + h + 1, b0, h, b1, bn - h); sp1[2 * h + 1] = 0; if (cy) add_1(sp1 + h + 1, h + 1, cy); bnp = h + bp1[h]; } } else { ap1 = a0; anp = an; } // mulmod_bnp1: negacyclic 乗算 // 入力は (h+1) limbs (高位 0 or 1)、出力は xp[0..h] // bnp1_scratch は sp1 の後 (sp1 + 2*(h+1)) から配置 uint64_t* bnp1_scratch = sp1 + 2 * (h + 1); if (bp1 == b0) { // bp1 が short (folding なし): 完全乗算 + negacyclic fold // ap1 も short の場合がある (an <= h のとき) size_t a1n = std::min(anp, h); // 高位ビット除外 size_t b1n = std::min(bnp, h); uint64_t* tp2 = bnp1_scratch; size_t pn2 = a1n + b1n; uint64_t* mw = tp2 + pn2; std::memset(tp2, 0, pn2 * sizeof(uint64_t)); if (a1n >= b1n) multiply(tp2, ap1, a1n, bp1, b1n, mw); else multiply(tp2, bp1, b1n, ap1, a1n, mw); // negacyclic fold if (pn2 > h) { cy = sub(xp, tp2, h, tp2 + h, pn2 - h); } else { std::memcpy(xp, tp2, pn2 * sizeof(uint64_t)); if (pn2 < h) std::memset(xp + pn2, 0, (h - pn2) * sizeof(uint64_t)); cy = 0; } xp[h] = 0; if (cy) add_1(xp, h + 1, cy); } else { // 両方 (h+1) limbs: mulmod_bnp1_bc mulmod_bnp1_bc(xp, ap1, bp1, h, bnp1_scratch); } } // ====== CRT 合成 ====== // xm = r[0..h-1] (mod β^h - 1) // xp = xp[0..h] (mod β^h + 1, 正規化: xp[h] ∈ {0,1}) // // CRT: x = -xp * β^h + (β^h + 1) * [(xp + xm)/2 mod (β^h - 1)] // // Step 1: r[0..h-1] = (xm + xp) / 2 mod (β^h - 1) cy = xp[h] + add(r, r, h, xp, h); cy += (r[0] & 1); // 1-bit right shift (固定サイズ h, carry-in = cy) #ifdef CALX_INT_HAS_ASM mpn_rshift_asm(r, r, h, 1); #else for (size_t i = 0; i < h - 1; i++) r[i] = (r[i] >> 1) | (r[i + 1] << 63); r[h - 1] = r[h - 1] >> 1; #endif uint64_t hi = (cy & 1) << 63; cy >>= 1; r[h - 1] |= hi; if (cy) add_1(r, h, cy); // Step 2: r[h..2h-1] = r[0..h-1] - xp[0..h-1] cy = xp[h] + sub(r + h, r, h, xp, h); if (cy) sub_1(r, 2 * h, cy); #undef a0 #undef a1 #undef b0 #undef b1 } // -------------------------------------------------------------------------- // fold_mod_bnm1: 配列を n limbs に fold (mod β^n - 1) // -------------------------------------------------------------------------- // r[0..n-1] = a[0..an-1] mod (β^n - 1) // r == a でも可 (in-place)。an <= 2*n を前提とする。 inline void fold_mod_bnm1(uint64_t* r, size_t n, const uint64_t* a, size_t an) { if (an <= n) { if (r != a) std::memcpy(r, a, an * sizeof(uint64_t)); if (an < n) std::memset(r + an, 0, (n - an) * sizeof(uint64_t)); return; } if (r != a) std::memcpy(r, a, n * sizeof(uint64_t)); size_t high_n = an - n; uint64_t carry = add(r, r, n, a + n, high_n); while (carry > 0) { carry = add_1(r, n, carry); } } // -------------------------------------------------------------------------- // mu_div_qr: Newton 逆数反復を使った O(M(n)) 除算 // -------------------------------------------------------------------------- // GMP mpn_mu_div_qr / mpn_preinv_mu_div_qr 準拠。 // 逆数サイズ in ≈ qn/2 (balanced) で計算し、商をチャンク処理。 // 各チャンクで in limbs の商を推定し、Q×D を減算して余りを更新。 // // 前提: // - bn >= 2, b[bn-1] の MSB がセット済み (正規化) // - an >= bn // - a[] は作業用に上書きされる (余りが a[0..bn-1] に残る) // 結果: // - q[0..an-bn] に商 // - a[0..bn-1] に余り (呼び出し元が r[] にコピー) // 返値: 商の normalized size // scratch: mu_div_qr_scratch_size(an, bn) limbs // 逆数サイズの選択 (GMP mpn_mu_div_qr_choose_in 準拠) inline size_t mu_div_qr_choose_in(size_t qn, size_t bn) { if (qn > bn) { // qn/bn ブロックに分割し、均等なチャンクサイズを選択 size_t blocks = (qn - 1) / bn + 1; // ceil(qn/bn) return (qn - 1) / blocks + 1; // ceil(qn/blocks) } else if (3 * qn > bn) { return (qn - 1) / 2 + 1; // 2 チャンク (balanced case) } else { return qn; // 1 チャンク (qn が小さい場合) } } inline size_t mu_div_qr_scratch_size(size_t an, size_t bn) { size_t qn = (an > bn) ? an - bn : 1; size_t in = mu_div_qr_choose_in(qn, bn); // scratch layout: ip[0..in-1] + work[0..] // work は逆数計算時と preinv ループ時で共用 // 逆数計算時の work: size_t inv_work = 2 * (in + 1) + invert_approx_scratch_size(in + 1); // preinv ループ時の work: // rp[0..bn-1]: 部分余り // tp[0..tn-1]: 積/mulmod バッファ // mul_work / mulmod_scratch: workspace size_t mul_in_in = multiply_scratch_size(in, in); size_t mul_bn_in = multiply_scratch_size(bn, in); size_t mh_in = mulhigh_n_scratch_size(in); // mulhigh_n for Step 1 size_t mul_sz = std::max({mul_in_in, mul_bn_in, mh_in}); size_t tn_full = bn + in + 1; size_t preinv_full = bn + tn_full + mul_sz; // mulmod パス: tp[tn_mod] + mulmod_scratch size_t preinv_mulmod = 0; if (in >= MUL_TO_MULMOD_BNM1_FOR_2NXN_THRESHOLD) { size_t tn_mod = mulmod_bnm1_next_size(bn + 1); size_t mm_scratch = mulmod_bnm1_scratch_size(tn_mod, bn, in); preinv_mulmod = bn + tn_mod + mm_scratch; } size_t preinv_work = std::max(preinv_full, preinv_mulmod); return in + std::max(inv_work, preinv_work); } // 前方宣言 inline size_t mu_div_qr(uint64_t* q, uint64_t* a, size_t an, const uint64_t* b, size_t bn, uint64_t* scratch); // preinv_mu_div_qr: 事前計算済み逆数を使ったチャンクベース除算ループ // (GMP mpn_preinv_mu_div_qr 準拠) // // np[0..nn-1]: 被除数 (読み取り専用) // dp[0..dn-1]: 除数 // ip[0..in-1]: 近似逆数 (暗黙の MSB 1) // qp[0..nn-dn-1]: 商の出力 // rp[0..dn-1]: 余りの出力 // scratch: tp + mul_work // 返値: qh (商の最上位ビット、0 or 1) inline uint64_t preinv_mu_div_qr(uint64_t* qp, uint64_t* rp, const uint64_t* np, size_t nn, const uint64_t* dp, size_t dn, const uint64_t* ip, size_t in, uint64_t* scratch) { size_t qn = nn - dn; uint64_t* tp = scratch; // tp には最大 dn + in + 1 limbs を格納 (Q_chunk × D の積) size_t tn = dn + in + 1; uint64_t* mul_work = scratch + tn; np += qn; qp += qn; // 初期比較: np[0..dn-1] >= dp[0..dn-1] なら 1 回引く uint64_t qh = (cmp(np, dn, dp, dn) >= 0) ? 1 : 0; if (qh) sub(rp, np, dn, dp, dn); else std::memcpy(rp, np, dn * sizeof(uint64_t)); while (qn > 0) { if (qn < in) { // 最後のチャンクが in より小さい場合、逆数の上位部分のみ使用 ip += in - qn; in = qn; } np -= in; qp -= in; // ====== Step 1: 商推定 ====== // Q_chunk = mulhi(rp[dn-in..dn-1], ip[0..in-1]) + rp[dn-in..dn-1] // (逆数の暗黙の MSB 1 分を加算) mulhigh_n(tp, rp + dn - in, ip, in, mul_work); uint64_t cy = add(qp, tp, in, rp + dn - in, in); // cy は通常 0 (逆数は過小評価側にバイアスされている) (void)cy; qn -= in; // ====== Step 2: Q_chunk × D ====== if (in >= MUL_TO_MULMOD_BNM1_FOR_2NXN_THRESHOLD) { // mulmod_bnm1 で低コスト化 (CRT 分解: ~M(n)/2) size_t tn_mod = mulmod_bnm1_next_size(dn + 1); mulmod_bnm1(tp, tn_mod, dp, dn, qp, in, mul_work); // wrapping correction: 積の上位 wn limbs が mulmod で折り返された分を補正 // wn limbs (位置 tn_mod..dn+in-1) は rp[dn-wn..dn-1] に近い size_t wn = dn + in - tn_mod; if (wn > 0) { cy = sub(tp, tp, wn, rp + dn - wn, wn); if (tn_mod > wn) cy = sub_1(tp + wn, tn_mod - wn, cy); int cx = (cmp(rp + dn - in, tn_mod - dn, tp + dn, tn_mod - dn) < 0) ? 1 : 0; // cx >= cy (GMP invariant) if (cx != cy) add_1(tp, tn_mod, static_cast(cx - cy)); } } else { std::memset(tp, 0, tn * sizeof(uint64_t)); if (dn >= in) multiply(tp, dp, dn, qp, in, mul_work); else multiply(tp, qp, in, dp, dn, mul_work); } // r = rp[dn - in] - tp[dn]: 商推定の精度指標 // r != 0 → 補正が必要 (過大推定) uint64_t r = rp[dn - in] - tp[dn]; // ====== Step 3: 新しい部分余り ====== // new_rp = [np[0..in-1], rp[0..dn-1]] - tp[0..dn] // = (旧余り × β^in + 次の被除数チャンク) - Q_chunk × D if (dn != in) { // 下位 in limbs: np[0..in-1] - tp[0..in-1] cy = sub(tp, np, in, tp, in); // 上位 dn-in limbs: rp[0..dn-in-1] - tp[in..dn-1] - borrow // sub_nc が無いので手動でボローを処理 for (size_t i = 0; i < dn - in; i++) { uint64_t a_val = rp[i]; uint64_t b_val = tp[in + i]; uint64_t diff = a_val - b_val; uint64_t bw1 = (a_val < b_val) ? 1 : 0; uint64_t diff2 = diff - cy; uint64_t bw2 = (diff < cy) ? 1 : 0; tp[in + i] = diff2; cy = bw1 + bw2; } std::memcpy(rp, tp, dn * sizeof(uint64_t)); } else { cy = sub(rp, np, in, tp, in); } // ====== Step 4: 補正 ====== // 過大推定: r != 0 → Q++, rp -= D r -= cy; while (r != 0) { add_1(qp, in, 1); uint64_t sub_cy = sub(rp, rp, dn, dp, dn); r -= sub_cy; } // 過小推定: rp >= D → Q++, rp -= D if (cmp(rp, dn, dp, dn) >= 0) { add_1(qp, in, 1); sub(rp, rp, dn, dp, dn); } } return qh; } inline size_t mu_div_qr(uint64_t* q, uint64_t* a, size_t an, const uint64_t* b, size_t bn, uint64_t* scratch) { size_t qn = an - bn; size_t in = mu_div_qr_choose_in(qn, bn); // scratch layout: // ip[0..in-1]: 逆数 // work[0..]: 逆数計算後は preinv ループ用に再利用 uint64_t* ip = scratch; uint64_t* work = scratch + in; // ====== Step 1: in-limb 逆数を計算 ====== // D の上位 (in+1) limbs の逆数を計算し、上位 in limbs を使用 // (GMP mpn_mu_div_qr2 方式) // // work layout (一時使用、Step 2 で上書きされる): // ip_full[0..in]: (in+1)-limb 逆数出力 // tp_inv[0..in]: (in+1)-limb 除数コピー // inv_scratch[..]: invert_approx workspace { uint64_t* ip_full = work; uint64_t* tp_inv = work + in + 1; uint64_t* inv_scratch = tp_inv + in + 1; if (bn == in) { // 完全逆数: D 全体の逆数 // (1, D[0..in-1]) の (in+1)-limb 逆数を計算 tp_inv[0] = 1; std::memcpy(tp_inv + 1, b, in * sizeof(uint64_t)); invert_approx(ip_full, tp_inv, in + 1, inv_scratch); // 最下位 limb を捨てて ip[0..in-1] にコピー std::memcpy(ip, ip_full + 1, in * sizeof(uint64_t)); } else { // 部分逆数: D の上位 (in+1) limbs の逆数 // D[bn-in-1..bn-1] + 1 の逆数 (切り上げで安全な過小評価) std::memcpy(tp_inv, b + bn - (in + 1), (in + 1) * sizeof(uint64_t)); uint64_t cy = add_1(tp_inv, in + 1, 1); if (cy) { // オーバーフロー: D の上位が全て 1 → 逆数 ≈ 0 std::memset(ip, 0, in * sizeof(uint64_t)); } else { invert_approx(ip_full, tp_inv, in + 1, inv_scratch); std::memcpy(ip, ip_full + 1, in * sizeof(uint64_t)); } } } // ====== Step 2: preinv ループで除算 ====== // rp: scratch から bn limbs 確保 uint64_t* rp = work; uint64_t* preinv_scratch = work + bn; uint64_t qh = preinv_mu_div_qr(q, rp, a, an, b, bn, ip, in, preinv_scratch); // 余りを a[0..bn-1] にコピー std::memcpy(a, rp, bn * sizeof(uint64_t)); if (qh) { q[qn] = qh; return normalized_size(q, qn + 1); } return normalized_size(q, qn); } // -------------------------------------------------------------------------- // divide_scratch_size / divide: 汎用除算ディスパッチャー // -------------------------------------------------------------------------- inline size_t divide_scratch_size(size_t an, size_t bn) { if (bn < 2) return 0; // 正規化コピー: bn (nb) + an + 2 (na + sentinel) + work size_t norm_sz = bn + an + 8; // +8: Svoboda sentinel + 余裕 // BZ 再帰は log2(bn/BZ_THRESHOLD) レベル。各レベルで ~4*half の scratch。 // 安全な上限: 30*bn + mul_sz (大きめに確保) size_t mul_sz = multiply_scratch_size(bn, bn); size_t bz_sz = norm_sz + 30 * bn + mul_sz; if (bn >= MU_DIV_BALANCED_THRESHOLD) { // mu-division (balanced + unbalanced) と BZ の両方に対応 size_t mu_sz = norm_sz + mu_div_qr_scratch_size(an, bn); return std::max(bz_sz, mu_sz); } return bz_sz; } // 汎用除算 // q[0..an-bn] に商, r[0..bn-1] に余り // 返値: 商の normalized size // scratch: divide_scratch_size(an, bn) limbs inline size_t divide(uint64_t* q, uint64_t* r, const uint64_t* a, size_t an, const uint64_t* b, size_t bn, uint64_t* scratch) { // an < bn → 商 = 0, 余り = a if (an < bn || (an == bn && cmp(a, an, b, bn) < 0)) { std::memcpy(r, a, an * sizeof(uint64_t)); if (an < bn) std::memset(r + an, 0, (bn - an) * sizeof(uint64_t)); return 0; } // 1-limb 除数は呼び出し元で処理されるはず (fast path) // ここでは bn >= 2 を想定 // --- 正規化: b の最上位 limb の MSB をセット --- unsigned shift = 0; { uint64_t hi = b[bn - 1]; shift = std::countl_zero(hi); } // scratch 内に正規化版のコピーを作る uint64_t* nb = scratch; // bn limbs uint64_t* na = scratch + bn; // an + 2 limbs (Svoboda sentinel 用に +1) uint64_t* work = scratch + bn + an + 2; // 残り scratch if (shift > 0) { lshift(nb, b, bn, shift); na[an] = lshift(na, a, an, shift); } else { std::memcpy(nb, b, bn * sizeof(uint64_t)); std::memcpy(na, a, an * sizeof(uint64_t)); na[an] = 0; } size_t nan = an + (na[an] ? 1 : 0); // 正規化後の実効サイズ size_t qn; if (bn < BZ_THRESHOLD) { // small divisor: always schoolbook na[nan] = 0; // sentinel for div_basecase div_basecase(q, na, nan, nb, bn); qn = normalized_size(q, nan - bn + 1); } else if (bn >= MU_DIV_BALANCED_THRESHOLD) { // 大きな除数: Newton 逆数反復 (balanced/unbalanced 共通) // BZ O(M(n) log n) → mu O(M(n)) で高速化 qn = mu_div_qr(q, na, nan, nb, bn, work); } else if (nan > 2 * bn + 1 && bn >= mu_div_threshold(nan, bn)) { qn = mu_div_qr(q, na, nan, nb, bn, work); } else if (nan < 2 * bn) { if (nan - bn < BZ_THRESHOLD) { // 商が小さい (< BZ_THRESHOLD): schoolbook na[nan] = 0; div_basecase(q, na, nan, nb, bn); qn = normalized_size(q, nan - bn + 1); } else { // 商が大きいが nan < 2*bn: zero-pad して BZ で処理 // ★ na は scratch 上なので直接 zero-pad (コピー不要) std::memset(na + nan, 0, (2 * bn - nan) * sizeof(uint64_t)); uint64_t* bz_q = work; // bn limbs uint64_t* bz_r = work + bn; // bn limbs uint64_t* bz_scratch = work + 2 * bn; div_2n_by_n(bz_q, bz_r, na, nb, bn, bz_scratch); std::memcpy(q, bz_q, (nan - bn + 1) * sizeof(uint64_t)); std::memcpy(na, bz_r, bn * sizeof(uint64_t)); qn = normalized_size(q, nan - bn + 1); } } else if (nan == 2 * bn) { // balanced BZ (nan = 2*bn のみ) // div_2n_by_n は a < B^n * b (商が n limbs に収まる) を前提とする。 // 上位 n limbs >= b のとき、1 回引いて商の最上位を 1 にする。 uint64_t q_top = 0; // 上位 n limbs >= nb のチェック if (cmp(na + bn, bn, nb, bn) >= 0) { q_top++; sub(na + bn, na + bn, bn, nb, bn); } q[bn] = q_top; // na[0..2*bn-1] を BZ で処理 (a の上位 < b が保証される) uint64_t* padded_a = work; std::memcpy(padded_a, na, 2 * bn * sizeof(uint64_t)); uint64_t* bz_q = padded_a + 2 * bn; // bn limbs uint64_t* bz_r = bz_q + bn; // bn limbs uint64_t* bz_scratch = bz_r + bn; div_2n_by_n(bz_q, bz_r, padded_a, nb, bn, bz_scratch); std::memcpy(q, bz_q, bn * sizeof(uint64_t)); std::memcpy(na, bz_r, bn * sizeof(uint64_t)); qn = normalized_size(q, nan - bn + 1); } else { // heavily unbalanced: チャンク分割 qn = div_unbalanced(q, na, na, nan, nb, bn, work); } // 余りの逆正規化 (shift < 64 なのでインラインで処理) if (shift > 0) { unsigned rsh = shift; unsigned lsh = 64 - rsh; for (size_t i = 0; i < bn - 1; i++) { r[i] = (na[i] >> rsh) | (na[i + 1] << lsh); } r[bn - 1] = na[bn - 1] >> rsh; } else { std::memcpy(r, na, bn * sizeof(uint64_t)); } return qn; } // ============================================================================ // GCD Operations (Binary GCD / Stein's Algorithm) // ============================================================================ // Count trailing zeros in a limb array (bits, not limbs) // Returns the total number of trailing zero bits across all limbs // Used for extracting common power-of-2 factor in binary GCD inline size_t ctz_limb_array(const uint64_t* a, size_t n) { for (size_t i = 0; i < n; i++) { if (a[i] != 0) { return i * 64 + std::countr_zero(a[i]); } } // All limbs are zero return n * 64; } // Right shift by arbitrary number of bits // Returns normalized size after shift inline size_t rshift(uint64_t* r, const uint64_t* a, size_t n, size_t shift_bits) { if (shift_bits == 0 || n == 0) { if (r != a) { for (size_t i = 0; i < n; i++) r[i] = a[i]; } return n; } size_t limb_shift = shift_bits / 64; size_t bit_shift = shift_bits % 64; if (limb_shift >= n) { // Shifted everything away return 0; } size_t remaining = n - limb_shift; if (bit_shift == 0) { // Pure limb shift for (size_t i = 0; i < remaining; i++) { r[i] = a[i + limb_shift]; } return normalized_size(r, remaining); } // Combined shift: limb ずらし + ビットシフト #ifdef CALX_INT_HAS_ASM mpn_rshift_asm(r, a + limb_shift, remaining, static_cast(bit_shift)); #else for (size_t i = 0; i < remaining - 1; i++) { uint64_t low = a[i + limb_shift] >> bit_shift; uint64_t high = a[i + limb_shift + 1] << (64 - bit_shift); r[i] = low | high; } r[remaining - 1] = a[n - 1] >> bit_shift; #endif return normalized_size(r, remaining); } // Left shift by arbitrary number of bits // Returns normalized size after shift // r must have space for at least n + (shift_bits + 63) / 64 limbs // In-place safe (r == a) inline size_t lshift_arbitrary(uint64_t* r, const uint64_t* a, size_t n, size_t shift_bits) { if (shift_bits == 0 || n == 0) { if (r != a) { for (size_t i = 0; i < n; i++) r[i] = a[i]; } return n; } size_t limb_shift = shift_bits / 64; size_t bit_shift = shift_bits % 64; if (bit_shift == 0) { // Pure limb shift — reverse order for in-place safety for (size_t i = n; i-- > 0; ) { r[i + limb_shift] = a[i]; } } else { // Combined shift — reverse order for in-place safety // Carry must be computed before any writes (a may alias r) uint64_t top_carry = a[n - 1] >> (64 - bit_shift); if (top_carry != 0) { r[n + limb_shift] = top_carry; } for (size_t i = n - 1; i > 0; i--) { r[i + limb_shift] = (a[i] << bit_shift) | (a[i - 1] >> (64 - bit_shift)); } r[limb_shift] = a[0] << bit_shift; // Clear lower limbs (after data is moved, safe for in-place) for (size_t i = 0; i < limb_shift; i++) { r[i] = 0; } if (top_carry != 0) { return n + limb_shift + 1; } return n + limb_shift; } // Clear lower limbs for pure limb shift for (size_t i = 0; i < limb_shift; i++) { r[i] = 0; } return n + limb_shift; } // Binary GCD (Stein's algorithm) - division-free GCD // Input: a[0..an-1], b[0..bn-1] (normalized, non-zero) // Output: r[0..return_value-1] = gcd(a, b) // scratch: temporary buffer, size >= 3 * max(an, bn) // Returns: size of result // Algorithm: Cohen "Computational Algebraic Number Theory" p.14 inline size_t gcd_binary(uint64_t* r, const uint64_t* a, size_t an, const uint64_t* b, size_t bn, uint64_t* scratch) { if (an == 0 || bn == 0) { // gcd(0, b) = b, gcd(a, 0) = a if (an == 0) { for (size_t i = 0; i < bn; i++) r[i] = b[i]; return bn; } else { for (size_t i = 0; i < an; i++) r[i] = a[i]; return an; } } size_t max_size = (an > bn) ? an : bn; // Allocate scratch buffers uint64_t* a_work = scratch; uint64_t* b_work = scratch + max_size; uint64_t* temp = scratch + 2 * max_size; // Copy inputs to work buffers for (size_t i = 0; i < an; i++) a_work[i] = a[i]; for (size_t i = 0; i < bn; i++) b_work[i] = b[i]; size_t a_size = an; size_t b_size = bn; // Step 1: Count trailing zeros size_t k_a = ctz_limb_array(a_work, a_size); size_t k_b = ctz_limb_array(b_work, b_size); size_t k = (k_a < k_b) ? k_a : k_b; // Common factor = 2^k // Step 2: Remove trailing zeros from both if (k_a > 0) { a_size = rshift(a_work, a_work, a_size, k_a); } if (k_b > 0) { b_size = rshift(b_work, b_work, b_size, k_b); } // Step 3: Binary GCD loop while (true) { // Compare a_work and b_work int c = cmp(a_work, a_size, b_work, b_size); if (c == 0) { // a_work == b_work, we're done break; } // Subtract smaller from larger int sign_dummy; size_t diff_size; if (c > 0) { // a_work > b_work: a_work = a_work - b_work diff_size = abs_sub(temp, sign_dummy, a_work, a_size, b_work, b_size); for (size_t i = 0; i < diff_size; i++) a_work[i] = temp[i]; a_size = diff_size; // Remove trailing zeros from difference size_t shift = ctz_limb_array(a_work, a_size); if (shift > 0) { a_size = rshift(a_work, a_work, a_size, shift); } } else { // b_work > a_work: b_work = b_work - a_work diff_size = abs_sub(temp, sign_dummy, b_work, b_size, a_work, a_size); for (size_t i = 0; i < diff_size; i++) b_work[i] = temp[i]; b_size = diff_size; // Remove trailing zeros from difference size_t shift = ctz_limb_array(b_work, b_size); if (shift > 0) { b_size = rshift(b_work, b_work, b_size, shift); } } // Check for zero (shouldn't happen in normal cases) if (a_size == 0 || b_size == 0) { break; } } // Step 4: Restore common factor by left shift size_t result_size; if (k > 0) { result_size = lshift_arbitrary(r, a_work, a_size, k); } else { for (size_t i = 0; i < a_size; i++) r[i] = a_work[i]; result_size = a_size; } return result_size; } // Calculate scratch size needed for gcd_binary inline size_t gcd_binary_scratch_size(size_t an, size_t bn) { size_t max_size = (an > bn) ? an : bn; return 3 * max_size; } // ============================================================================ // Lehmer GCD (GMP の mpn_gcd に相当) // hgcd2 で上位 2 limb から変換行列を構築し、n-limb ベクトルに一括適用 // 計算量: O(n² / 64) — Binary GCD の O(n²) に対して ~64 倍高速 // ============================================================================ // 1-limb GCD (両方奇数が前提) // GMP の mpn_gcd_11 に相当 inline uint64_t gcd_11(uint64_t u, uint64_t v) { // u, v ともに奇数であること // GMP 方式ブランチレス Binary GCD // 冗長な LSB を除去して表現 (暗黙の最下位ビット) u >>= 1; v >>= 1; while (u != v) { uint64_t t = u - v; // vgtu = (u < v) ? ~0 : 0 (符号ビットを全ビットに拡張) uint64_t vgtu = static_cast(static_cast(t) >> 63); // v = min(u, v) v += (vgtu & t); // u = |u - v| u = (t ^ vgtu) - vgtu; // 末尾ゼロ除去 (ctz(t) は ctz(|t|) と同じ) int c = std::countr_zero(t); // (u >> 1) >> c: 1ビット追加シフトは ctz と独立に実行可能 u = (u >> 1) >> c; } return (u << 1) + 1; } // 2-limb GCD: gcd({u1,u0}, {v1,v0}) // 両方奇数が前提 // GMP の mpn_gcd_22 に相当 struct DoubleLimb { uint64_t d0, d1; // little-endian: d0 = low, d1 = high }; inline DoubleLimb gcd_22(uint64_t u1, uint64_t u0, uint64_t v1, uint64_t v0) { // GMP 方式ブランチレス 2-limb Binary GCD // 暗黙の LSB: 右に 1 ビットシフト u0 = (u0 >> 1) | (u1 << 63); u1 >>= 1; v0 = (v0 >> 1) | (v1 << 63); v1 >>= 1; while (u1 || v1) { // sub_ddmmss: (t1, t0) = (u1, u0) - (v1, v0) uint64_t borrow = (u0 < v0) ? 1ULL : 0ULL; uint64_t t0 = u0 - v0; uint64_t t1 = u1 - v1 - borrow; // vgtu: u < v なら全ビット 1、そうでなければ 0 uint64_t vgtu = static_cast(static_cast(t1) >> 63); if (t0 == 0) { if (t1 == 0) { // u == v: GCD 発見 DoubleLimb g; g.d1 = (u1 << 1) | (u0 >> 63); g.d0 = (u0 << 1) | 1; return g; } int c = std::countr_zero(t1); // v1 = min(u1, v1) via branchless v1 += (vgtu & t1); // u0 = |u1 - v1| u0 = (t1 ^ vgtu) - vgtu; u0 >>= c + 1; u1 = 0; } else { int c = std::countr_zero(t0) + 1; // V <-- min(U, V) via branchless add uint64_t add0 = vgtu & t0; uint64_t add1 = vgtu & t1; uint64_t carry2 = 0; v0 += add0; carry2 = (v0 < add0) ? 1ULL : 0ULL; v1 += add1 + carry2; // U <-- |U - V| u0 = (t0 ^ vgtu) - vgtu; u1 = t1 ^ vgtu; if (c == 64) { u0 = u1; u1 = 0; } else { u0 = (u0 >> c) | (u1 << (64 - c)); u1 >>= c; } } } // 片方の high limb が 0 → 値が MSB 付近にある場合の遷移ループ while ((v0 | u0) & (1ULL << 63)) { uint64_t t0 = u0 - v0; uint64_t vgtu = static_cast(-(u0 < v0)); // borrow → mask if (t0 == 0) { DoubleLimb g; g.d1 = u0 >> 63; g.d0 = (u0 << 1) | 1; return g; } v0 += (vgtu & t0); u0 = (t0 ^ vgtu) - vgtu; int c = std::countr_zero(t0); u0 = (u0 >> 1) >> c; } // 1-limb GCD へフォールスルー DoubleLimb g; g.d0 = gcd_11((u0 << 1) + 1, (v0 << 1) + 1); g.d1 = 0; return g; } // 2×2 変換行列 (hgcd2 の結果) struct HgcdMatrix1 { uint64_t u[2][2]; // u[row][col] }; // 2-limb 除算: (ah,al) / (bh,bl) → 商を返し、余りを r に格納 // GMP の hgcd2-div.h の div2 に相当 inline uint64_t hgcd2_div2(uint64_t r[2], uint64_t n1, uint64_t n0, uint64_t d1, uint64_t d0) { // GMP Method 2 方式: ブランチレスビットワイズ除算 // (n1,n0) / (d1,d0) の商と余りを計算 uint64_t q = 0; int ncnt = std::countl_zero(n1); int dcnt = std::countl_zero(d1); int cnt = dcnt - ncnt; // d を左シフトして n と桁を揃える // d1 = (d1 << cnt) + (d0 >> 1 >> (63 - cnt)) // d0 <<= cnt if (cnt > 0) { d1 = (d1 << cnt) | (d0 >> (64 - cnt)); d0 <<= cnt; } // ブランチレスループ: cnt+1 回のイテレーション do { uint64_t mask; q <<= 1; // mask = (n >= d) ? ~0 : 0 (ブランチレス) if (n1 == d1) mask = static_cast(-static_cast(n0 >= d0)); else mask = static_cast(-static_cast(n1 > d1)); q -= mask; // mask が ~0 なら q += 1 // n -= d & mask (条件付き減算) uint64_t sub0 = mask & d0; uint64_t sub1 = mask & d1; uint64_t borrow = (n0 < sub0) ? 1ULL : 0ULL; n0 -= sub0; n1 -= sub1 + borrow; // d >>= 1 d0 = (d1 << 63) | (d0 >> 1); d1 >>= 1; } while (cnt--); r[0] = n0; r[1] = n1; return q; } // 1-limb 除算: ah / bh → 商と余り inline DoubleLimb hgcd2_div1(uint64_t ah, uint64_t bh) { DoubleLimb result; result.d1 = ah / bh; // quotient result.d0 = ah % bh; // remainder return result; } // hgcd2: 上位 2 limb で複数の Euclidean ステップを実行し、変換行列を構築 // GMP の mpn_hgcd2 に相当 // 返値: 1 = 進捗あり (行列が有効), 0 = 進捗なし inline int hgcd2(uint64_t ah, uint64_t al, uint64_t bh, uint64_t bl, HgcdMatrix1* M) { uint64_t u00, u01, u10, u11; if (ah < 2 || bh < 2) return 0; // 初期減算で a >= b を保証 if (ah > bh || (ah == bh && al > bl)) { // a -= b uint64_t borrow = (al < bl) ? 1ULL : 0ULL; ah = ah - bh - borrow; al = al - bl; if (ah < 2) return 0; u00 = u01 = u11 = 1; u10 = 0; } else { // b -= a uint64_t borrow = (bl < al) ? 1ULL : 0ULL; bh = bh - ah - borrow; bl = bl - al; if (bh < 2) return 0; u00 = u10 = u11 = 1; u01 = 0; } if (ah < bh) goto subtract_a; // Double-precision ループ for (;;) { if (ah == bh) goto done; if (ah < (1ULL << 32)) { // 上位半分に縮約して single-precision ループへ ah = (ah << 32) | (al >> 32); bh = (bh << 32) | (bl >> 32); break; } // a -= q*b, 行列の第2列を更新 { uint64_t borrow = (al < bl) ? 1ULL : 0ULL; ah = ah - bh - borrow; al = al - bl; } if (ah < 2) goto done; if (ah <= bh) { // q = 1 u01 += u00; u11 += u10; } else { uint64_t r[2]; uint64_t q = hgcd2_div2(r, ah, al, bh, bl); al = r[0]; ah = r[1]; if (ah < 2) { u01 += q * u00; u11 += q * u10; goto done; } q++; u01 += q * u00; u11 += q * u10; } subtract_a: if (ah == bh) goto done; if (bh < (1ULL << 32)) { ah = (ah << 32) | (al >> 32); bh = (bh << 32) | (bl >> 32); goto subtract_a1; } // b -= q*a, 行列の第1列を更新 { uint64_t borrow = (bl < al) ? 1ULL : 0ULL; bh = bh - ah - borrow; bl = bl - al; } if (bh < 2) goto done; if (bh <= ah) { u00 += u01; u10 += u11; } else { uint64_t r[2]; uint64_t q = hgcd2_div2(r, bh, bl, ah, al); bl = r[0]; bh = r[1]; if (bh < 2) { u00 += q * u01; u10 += q * u11; goto done; } q++; u00 += q * u01; u10 += q * u11; } } // Single-precision ループ (上位半分のみ) for (;;) { ah -= bh; if (ah < (1ULL << 33)) break; if (ah <= bh) { u01 += u00; u11 += u10; } else { DoubleLimb rq = hgcd2_div1(ah, bh); uint64_t q = rq.d1; ah = rq.d0; if (ah < (1ULL << 33)) { u01 += q * u00; u11 += q * u10; break; } q++; u01 += q * u00; u11 += q * u10; } subtract_a1: bh -= ah; if (bh < (1ULL << 33)) break; if (bh <= ah) { u00 += u01; u10 += u11; } else { DoubleLimb rq = hgcd2_div1(bh, ah); uint64_t q = rq.d1; bh = rq.d0; if (bh < (1ULL << 33)) { u00 += q * u01; u10 += q * u11; break; } q++; u00 += q * u01; u10 += q * u11; } } done: M->u[0][0] = u00; M->u[0][1] = u01; M->u[1][0] = u10; M->u[1][1] = u11; return 1; } // 行列 M の逆行列を n-limb ベクトル (ap, bp) に適用 // GMP の mpn_matrix22_mul1_inverse_vector に準拠 // // hgcd2 が格納する行列 M_stored は逆変換: // original = M_stored × reduced (det(M_stored) = 1) // M_stored^{-1} = [[u11, -u01], [-u10, u00]] // // 計算: // rp = u11 * ap - u01 * bp (carry と borrow は必ず等しい) // bp_new = u00 * bp - u10 * ap (同上) // // rp は ap とは別バッファが必要、bp は in-place 更新 inline size_t hgcd_mul_matrix1_vector(const HgcdMatrix1* M, uint64_t* rp, const uint64_t* ap, uint64_t* bp, size_t n) { uint64_t u00 = M->u[0][0], u01 = M->u[0][1]; uint64_t u10 = M->u[1][0], u11 = M->u[1][1]; // rp = u11 * ap - u01 * bp (carry == borrow が保証される) mul_1(rp, ap, n, u11); submul_1(rp, bp, n, u01); // bp = u00 * bp - u10 * ap (同上) mul_1(bp, bp, n, u00); submul_1(bp, ap, n, u10); // GMP 準拠の正規化: 最上位 limb が両方ゼロなら 1 減 n -= (rp[n - 1] | bp[n - 1]) == 0; return n; } // ============================================================================ // 再帰的 HGCD (Half-GCD) — O(M(n) log n) の準線形 GCD // GMP の mpn_hgcd に相当 (Möller 2008) // // 概要: // hgcd(ap, bp, n, M, tp) は n-limb の (ap, bp) を n/2+1 limb 以下に // 縮小する変換行列 M を計算する。M は蓄積された Euclidean ステップの積。 // // hgcd2 が格納する行列の規約に従い: // M^{-1} = [[M11, -M01], [-M10, M00]] (det(M) = 1) // 新 a = M11 * old_a - M01 * old_b // 新 b = M00 * old_b - M10 * old_a // // GCD 用エントリポイント: // hgcd_reduce() が gcd_lehmer() から呼ばれ、HGCD で高速縮小する。 // HGCD_THRESHOLD 未満は従来の Lehmer ループがそのまま動く。 // ============================================================================ // HGCD 閾値: これ未満は Lehmer (hgcd2) のみ使用 // 小サイズでの GMP に対する優位 (0.7-0.8x) を維持するため、十分大きく設定 constexpr size_t HGCD_THRESHOLD = 126; // 多倍長 2×2 変換行列 // M = [[p[0][0], p[0][1]], [p[1][0], p[1][1]]] // det(M) = 1 (常に) struct HgcdMatrix { size_t alloc; // 各エントリの最大 limb 数 size_t n; // 使用中の最大 limb 数 uint64_t* p[2][2]; // 行列エントリへのポインタ }; // HgcdMatrix の初期化 (単位行列、arena から割り当て) // 全バッファをゼロ初期化してから対角要素を 1 に設定 inline void hgcd_matrix_init(HgcdMatrix* M, size_t max_alloc) { M->alloc = max_alloc; M->n = 1; auto& arena = getThreadArena(); for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++) { M->p[i][j] = arena.alloc_limbs(max_alloc); std::memset(M->p[i][j], 0, max_alloc * sizeof(uint64_t)); if (i == j) M->p[i][j][0] = 1; } } // M = M × M1 (多倍長 × 単一 limb 行列) // M1 は hgcd2 の出力 (HgcdMatrix1) // tp は M->n + 1 以上のサイズが必要 // // 計算: // [M00' M01'] = [M00 M01] × [u00 u01] // [M10' M11'] [M10 M11] [u10 u11] // // 行ごとに、M[row][0] を tp に退避してから更新 inline void hgcd_matrix_mul_1(HgcdMatrix* M, const HgcdMatrix1* M1, uint64_t* tp) { size_t n = M->n; uint64_t u00 = M1->u[0][0], u01 = M1->u[0][1]; uint64_t u10 = M1->u[1][0], u11 = M1->u[1][1]; // 行 0: [M00', M01'] = [M00, M01] × [[u00, u01], [u10, u11]] uint64_t c0 = mul_1(tp, M->p[0][0], n, u00); uint64_t c1 = addmul_1(tp, M->p[0][1], n, u10); tp[n] = c0 + c1; uint64_t c2 = mul_1(M->p[0][1], M->p[0][1], n, u11); uint64_t c3 = addmul_1(M->p[0][1], M->p[0][0], n, u01); M->p[0][1][n] = c2 + c3; std::memcpy(M->p[0][0], tp, (n + 1) * sizeof(uint64_t)); // 行 1 c0 = mul_1(tp, M->p[1][0], n, u00); c1 = addmul_1(tp, M->p[1][1], n, u10); tp[n] = c0 + c1; c2 = mul_1(M->p[1][1], M->p[1][1], n, u11); c3 = addmul_1(M->p[1][1], M->p[1][0], n, u01); M->p[1][1][n] = c2 + c3; std::memcpy(M->p[1][0], tp, (n + 1) * sizeof(uint64_t)); // サイズ更新 n++; if (M->p[0][0][n-1] | M->p[0][1][n-1] | M->p[1][0][n-1] | M->p[1][1][n-1]) M->n = n; } // M = M × M2 (多倍長 × 多倍長行列) // tp レイアウト: [save: n1] [t1: rn] [t2: rn] [mul_scratch] inline void hgcd_matrix_mul(HgcdMatrix* M, const HgcdMatrix* M2, uint64_t* tp) { size_t n1 = M->n; size_t n2 = M2->n; size_t rn = n1 + n2; uint64_t* save = tp; uint64_t* t1 = tp + n1; uint64_t* t2 = tp + n1 + rn; uint64_t* ms = tp + n1 + 2 * rn; size_t new_n = 0; for (int row = 0; row < 2; row++) { std::memcpy(save, M->p[row][0], n1 * sizeof(uint64_t)); for (int col = 0; col < 2; col++) { multiply(t1, save, n1, M2->p[0][col], n2, ms); multiply(t2, M->p[row][1], n1, M2->p[1][col], n2, ms); size_t s1 = normalized_size(t1, rn); size_t s2 = normalized_size(t2, rn); if (s1 >= s2) { std::memcpy(M->p[row][col], t1, s1 * sizeof(uint64_t)); if (s2 > 0) { uint64_t carry = add(M->p[row][col], M->p[row][col], s1, t2, s2); if (carry) { M->p[row][col][s1] = carry; s1++; } } } else { std::memcpy(M->p[row][col], t2, s2 * sizeof(uint64_t)); uint64_t carry = add(M->p[row][col], M->p[row][col], s2, t1, s1); if (carry) { M->p[row][col][s2] = carry; s1 = s2 + 1; } else s1 = s2; } if (s1 > new_n) new_n = s1; } } M->n = new_n; } // 行列列更新: M[i][col] += q * M[i][col^1] // ユークリッド除算ステップ後の行列追跡に使用 // col = 1: a ← a - q*b のステップ、col = 0: b ← b - q*a のステップ // qp[0..qn-1]: 商 (多倍長) // tp: scratch バッファ (多倍長 q の場合に使用) inline void hgcd_matrix_update_q(HgcdMatrix* M, const uint64_t* qp, size_t qn, int col, uint64_t* tp) { int other = col ^ 1; size_t mn = M->n; // 全エントリを max 可能サイズにゼロ拡張 size_t max_new = mn + qn; for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++) std::memset(M->p[i][j] + mn, 0, (max_new + 1 - mn) * sizeof(uint64_t)); size_t new_n = mn; for (int row = 0; row < 2; row++) { if (qn == 1) { // 単一 limb: addmul_1 で効率的に処理 uint64_t cy = addmul_1(M->p[row][col], M->p[row][other], mn, qp[0]); if (cy) { M->p[row][col][mn] += cy; if (M->p[row][col][mn] != 0 && mn + 1 > new_n) new_n = mn + 1; } } else { // 多倍長 q: multiply → add size_t prod_n = mn + qn; uint64_t* ms = tp + prod_n; if (qn >= mn) multiply(tp, qp, qn, M->p[row][other], mn, ms); else multiply(tp, M->p[row][other], mn, qp, qn, ms); uint64_t cy = add(M->p[row][col], M->p[row][col], max_new, tp, prod_n); if (cy) { M->p[row][col][max_new] = cy; if (max_new + 1 > new_n) new_n = max_new + 1; } else { if (prod_n > new_n) new_n = prod_n; } } } // 正規化: 最上位の非ゼロ位置を探す while (new_n > 1) { bool any = false; for (int i = 0; i < 2 && !any; i++) for (int j = 0; j < 2 && !any; j++) if (M->p[i][j][new_n - 1] != 0) any = true; if (any) break; new_n--; } M->n = new_n; } // M^{-1} をフルサイズベクトルに適用 (GMP mpn_hgcd_matrix_adjust 準拠) // // HGCD(ap+p, bp+p, nn) の後に呼び出す。 // ap[0..p-1]: 未変更の下位部分 (a_lo) // bp[0..p-1]: 未変更の下位部分 (b_lo) // ap[p..n-1]: HGCD 縮小後の上位部分 (a_hi') ← n = p + nn // bp[p..n-1]: HGCD 縮小後の上位部分 (b_hi') // // 計算 (det(M) = 1, M^{-1} = [[M11,-M01],[-M10,M00]]): // new_a = M11 * a_lo + B^p * a_hi' - M01 * b_lo // new_b = M00 * b_lo + B^p * b_hi' - M10 * a_lo // // tp サイズ: 2 * (p + M.n) + multiply_scratch inline size_t hgcd_matrix_adjust(const HgcdMatrix* M, size_t n, uint64_t* ap, uint64_t* bp, size_t p, uint64_t* tp) { size_t mn = M->n; size_t prod_n = mn + p; uint64_t* t0 = tp; uint64_t* t1 = tp + prod_n; uint64_t* ms = tp + 2 * prod_n; uint64_t ah, bh, cy; // a を上書きする前に a_lo に依存する 2 つの積を先に計算 if (mn >= p) multiply(t0, M->p[1][1], mn, ap, p, ms); else multiply(t0, ap, p, M->p[1][1], mn, ms); if (mn >= p) multiply(t1, M->p[1][0], mn, ap, p, ms); else multiply(t1, ap, p, M->p[1][0], mn, ms); // --- new_a = M11 * a_lo + B^p * a_hi' - M01 * b_lo --- std::memcpy(ap, t0, p * sizeof(uint64_t)); ah = add(ap + p, ap + p, n - p, t0 + p, mn); if (mn >= p) multiply(t0, M->p[0][1], mn, bp, p, ms); else multiply(t0, bp, p, M->p[0][1], mn, ms); cy = sub(ap, ap, n, t0, prod_n); ah -= cy; // --- new_b = M00 * b_lo + B^p * b_hi' - M10 * a_lo --- if (mn >= p) multiply(t0, M->p[0][0], mn, bp, p, ms); else multiply(t0, bp, p, M->p[0][0], mn, ms); std::memcpy(bp, t0, p * sizeof(uint64_t)); bh = add(bp + p, bp + p, n - p, t0 + p, mn); cy = sub(bp, bp, n, t1, prod_n); bh -= cy; if (ah > 0 || bh > 0) { ap[n] = ah; bp[n] = bh; n++; } else { if (ap[n - 1] == 0 && bp[n - 1] == 0) n--; } return n; } // 前方宣言 inline size_t hgcd(uint64_t* ap, uint64_t* bp, size_t n, HgcdMatrix* M, uint64_t* tp); // 単一ステップ: hgcd2 を試行し、失敗時は除算フォールバック // GMP の mpn_hgcd_step + mpn_gcd_subdiv_step に相当 // 返値: 縮小後のサイズ (0 = 進捗なしまたは GCD 発見) inline size_t hgcd_step(size_t n, uint64_t* ap, uint64_t* bp, size_t s, HgcdMatrix* M, uint64_t* tp) { HgcdMatrix1 M1; uint64_t mask = ap[n-1] | bp[n-1]; if (mask == 0) { n--; return n; } // GMP 準拠: 上位 2 limb の抽出 (3 分岐) { uint64_t uh, ul, vh, vl; if (n == s + 1) { // 閾値直前: mask < 4 なら hgcd2 をスキップ、 // それ以外はシフト正規化なしで生の limb を使用 (保守的) if (mask < 4) goto subtract; uh = ap[n-1]; ul = ap[n-2]; vh = bp[n-1]; vl = bp[n-2]; } else if (mask >> 63) { // MSB が立っている: シフト不要 uh = ap[n-1]; ul = ap[n-2]; vh = bp[n-1]; vl = bp[n-2]; } else { // シフト正規化 int shift = std::countl_zero(mask); if (n >= 3) { uh = (ap[n-1] << shift) | (ap[n-2] >> (64 - shift)); ul = (ap[n-2] << shift) | (ap[n-3] >> (64 - shift)); vh = (bp[n-1] << shift) | (bp[n-2] >> (64 - shift)); vl = (bp[n-2] << shift) | (bp[n-3] >> (64 - shift)); } else { uh = (ap[n-1] << shift) | (ap[n-2] >> (64 - shift)); ul = ap[n-2] << shift; vh = (bp[n-1] << shift) | (bp[n-2] >> (64 - shift)); vl = bp[n-2] << shift; } } // hgcd2 による変換行列構築を試行 if (hgcd2(uh, ul, vh, vl, &M1)) { n = hgcd_mul_matrix1_vector(&M1, tp, ap, bp, n); std::memcpy(ap, tp, n * sizeof(uint64_t)); hgcd_matrix_mul_1(M, &M1, tp); return n; } } // ブロック終端 (uh, ul, vh, vl スコープ) // hgcd2 失敗 (またはスキップ) → GMP gcd_subdiv_step に準拠したフォールバック // 1. 減算 (q=1) → 行列更新 // 2. 除算 → 行列更新 // s 閾値チェックでオーバーシュートを防止 subtract: { size_t an = normalized_size(ap, n); size_t bn = normalized_size(bp, n); if (an == 0 || bn == 0) return 0; // a < b に配置 (ポインタスワップ、GMP MP_PTR_SWAP に相当) uint64_t* lp = ap; // 小さい方 uint64_t* rp = bp; // 大きい方 size_t ln = an, rn = bn; int col = 0; // bp 側の縮小 → col=0 if (ln == rn) { int c = cmp(lp, ln, rp, rn); if (c == 0) return 0; // GCD 発見 if (c > 0) { std::swap(lp, rp); col ^= 1; } } else if (ln > rn) { std::swap(lp, rp); std::swap(ln, rn); col ^= 1; } // lp[0:ln] < rp[0:rn] (ln <= rn) // 小さい方 ≤ s なら進捗なし if (ln <= s) return 0; // Step 1: 減算 rp -= lp (大 -= 小) sub(rp, rp, rn, lp, ln); rn = normalized_size(rp, rn); if (rn == 0) return 0; // lp | rp → GCD = lp // s チェック: 結果 ≤ s なら元に戻して進捗なし if (rn <= s) { uint64_t cy = add(rp, lp, ln, rp, rn); if (cy) rp[ln] = cy; return 0; } // q=1 の行列更新 { M->p[0][col][M->n] = 0; M->p[1][col][M->n] = 0; uint64_t cy = addmul_1(M->p[0][col], M->p[0][col^1], M->n, 1); if (cy) M->p[0][col][M->n] = cy; cy = addmul_1(M->p[1][col], M->p[1][col^1], M->n, 1); if (cy) M->p[1][col][M->n] = cy; size_t new_mn = M->n; if (M->p[0][col][new_mn] || M->p[1][col][new_mn]) new_mn++; M->n = new_mn; } // 再配置 a < b (減算後、大小関係が変わることがある) ln = normalized_size(lp, n); rn = normalized_size(rp, n); if (ln == 0 || rn == 0) goto step_done; if (ln == rn) { int c = cmp(lp, ln, rp, rn); if (c == 0) goto step_done; // GCD 発見 if (c > 0) { std::swap(lp, rp); std::swap(ln, rn); col ^= 1; } } else if (ln > rn) { std::swap(lp, rp); std::swap(ln, rn); col ^= 1; } // Step 2: 除算 rp /= lp { auto& arena = getThreadArena(); size_t div_mark = arena.mark(); size_t qn = rn - ln + 1; uint64_t* qp = arena.alloc_limbs(qn + 4); uint64_t* dividend = arena.alloc_limbs(rn + 4); std::memcpy(dividend, rp, rn * sizeof(uint64_t)); size_t ds = divide_scratch_size(rn, ln); uint64_t* work = arena.alloc_limbs(ds + 4); divide(qp, rp, dividend, rn, lp, ln, work); // rp[0:ln-1] = 余り。上位の stale limb をクリア。 if (ln < n) std::memset(rp + ln, 0, (n - ln) * sizeof(uint64_t)); size_t rem_n = normalized_size(rp, ln); qn = normalized_size(qp, qn); // s チェック: 余り ≤ s なら商を調整 if (rem_n <= s) { if (qn == 0) { // 矛盾 (rp > lp なのに q=0) arena.rewind(div_mark); goto step_done; } // 商を 1 減らし、余りに除数を加算 (r' = r + lp > s) sub_1(qp, qn, 1); qn = normalized_size(qp, qn); if (rem_n > 0) { uint64_t cy = add(rp, lp, ln, rp, rem_n); if (cy) rp[ln] = cy; } else { std::memcpy(rp, lp, ln * sizeof(uint64_t)); } } // 行列更新 (q > 0 の場合のみ) qn = normalized_size(qp, qn); if (qn > 0) { if (qn == 1) { M->p[0][col][M->n] = 0; M->p[1][col][M->n] = 0; uint64_t cy = addmul_1(M->p[0][col], M->p[0][col^1], M->n, qp[0]); if (cy) M->p[0][col][M->n] = cy; cy = addmul_1(M->p[1][col], M->p[1][col^1], M->n, qp[0]); if (cy) M->p[1][col][M->n] = cy; size_t new_mn = M->n; if (M->p[0][col][new_mn] || M->p[1][col][new_mn]) new_mn++; M->n = new_mn; } else { size_t mn = M->n; size_t prod_n = mn + qn; size_t ms_sz = multiply_scratch_size( std::max(mn, qn), std::min(mn, qn)); uint64_t* update_tp = arena.alloc_limbs(prod_n + ms_sz + 16); hgcd_matrix_update_q(M, qp, qn, col, update_tp); } } arena.rewind(div_mark); } } step_done: // サイズを再計算して正規化 (stale limb をクリア) { size_t an = normalized_size(ap, n); size_t bn = normalized_size(bp, n); n = std::max(an, bn); if (n == 0) return 0; if (an < n) std::memset(ap + an, 0, (n - an) * sizeof(uint64_t)); if (bn < n) std::memset(bp + bn, 0, (n - bn) * sizeof(uint64_t)); } return n; } // 再帰的 HGCD 本体 (Möller 2008 の 4-phase アルゴリズム) // GMP の mpn_hgcd に準拠 // // Phase 1: 上位 ceil(n/2) limb で再帰 → M を構築 // Phase 2: hgcd_step で追加縮小 (目標: 3n/4 付近) // Phase 3: 2 回目の再帰 → M2 を構築、M = M × M2 // Phase 4: 最終 hgcd_step ループ (目標: s = n/2 + 1) inline size_t hgcd(uint64_t* ap, uint64_t* bp, size_t n, HgcdMatrix* M, uint64_t* tp) { size_t s = n / 2 + 1; size_t n_orig = n; bool success = false; // ベースケース: hgcd_step ループのみ if (n < HGCD_THRESHOLD) { while (n > s) { size_t nn = hgcd_step(n, ap, bp, s, M, tp); if (nn == 0) break; n = nn; success = true; } return success ? n : 0; } auto& arena = getThreadArena(); // ===== Phase 1: 上位 ceil(n/2) limb で再帰 ===== // GMP 準拠: p = floor(n/2)、再帰サイズ = n - p = ceil(n/2) size_t p = n / 2; size_t nn = hgcd(ap + p, bp + p, n - p, M, tp); if (nn > 0) { // GMP 準拠: nn > 0 なら常に adjust を実行 (p + nn を渡す) size_t mark = arena.mark(); size_t mn = M->n; size_t prod_n = mn + p; size_t ms_size = multiply_scratch_size( std::max(mn, p), std::min(mn, p)); size_t adj_tp_size = 2 * prod_n + ms_size; uint64_t* adj_tp = arena.alloc_limbs(adj_tp_size + 16); n = hgcd_matrix_adjust(M, p + nn, ap, bp, p, adj_tp); arena.rewind(mark); success = true; } // 正規化 while (n > s && (ap[n-1] | bp[n-1]) == 0) n--; if (n <= s) return success ? n : 0; // ===== Phase 2: hgcd_step で追加縮小 ===== // GMP 準拠: 中間目標 n2 = 3*n_orig/4 + 1 (Phase 3 の再帰に備える) { size_t n2 = 3 * (n_orig / 4) + 1; if (n2 < s) n2 = s; size_t mark = arena.mark(); uint64_t* ltp = arena.alloc_limbs(std::max(n, M->alloc) + 4); while (n > n2) { size_t step_n = hgcd_step(n, ap, bp, s, M, ltp); if (step_n == 0) break; n = step_n; success = true; } arena.rewind(mark); } if (n <= s) return success ? n : 0; // ===== Phase 3: 2 回目の再帰 (GMP 準拠の分割点) ===== // p2 = 2*s - n + 1、再帰サイズ = n - p2 = 2*(n-s) - 1 if (n > s + 2) { size_t p2 = 2 * s - n + 1; size_t nn2 = n - p2; size_t mark = arena.mark(); HgcdMatrix M2; hgcd_matrix_init(&M2, 2 * nn2 + 4); // nn2+2 では不足: Phase 2/4 で M が成長 uint64_t* rec_tp = arena.alloc_limbs(std::max(nn2, M2.alloc) + 4); size_t nn_result = hgcd(ap + p2, bp + p2, nn2, &M2, rec_tp); if (nn_result > 0) { // GMP 準拠: nn_result > 0 なら常に adjust + 行列乗算 (p2 + nn_result を渡す) size_t mn2 = M2.n; size_t prod_n2 = mn2 + p2; size_t ms2 = multiply_scratch_size( std::max(mn2, p2), std::min(mn2, p2)); size_t adj2_size = 2 * prod_n2 + ms2; uint64_t* adj2_tp = arena.alloc_limbs(adj2_size + 16); n = hgcd_matrix_adjust(&M2, p2 + nn_result, ap, bp, p2, adj2_tp); // M = M × M2 size_t n1 = M->n, n2m = M2.n; size_t rn = n1 + n2m; size_t mul_ms = multiply_scratch_size(n1, n2m); size_t mul_tp_size = n1 + 2 * rn + mul_ms; uint64_t* mul_tp = arena.alloc_limbs(mul_tp_size + 16); hgcd_matrix_mul(M, &M2, mul_tp); success = true; } arena.rewind(mark); } // 正規化 while (n > s && (ap[n-1] | bp[n-1]) == 0) n--; if (n <= s) return success ? n : 0; // ===== Phase 4: 最終 hgcd_step ループ ===== { size_t mark = arena.mark(); uint64_t* ltp = arena.alloc_limbs(std::max(n, M->alloc) + 4); while (n > s) { size_t step_n = hgcd_step(n, ap, bp, s, M, ltp); if (step_n == 0) break; n = step_n; success = true; } arena.rewind(mark); } return success ? n : 0; } // GCD 用 HGCD ラッパー // gcd_lehmer から呼ばれ、HGCD で高速に (up, vp) を縮小する // 返値: 新しいサイズ inline size_t hgcd_reduce(uint64_t* up, uint64_t* vp, size_t n) { auto& arena = getThreadArena(); size_t mark = arena.mark(); // 正規化 size_t un = normalized_size(up, n); size_t vn = normalized_size(vp, n); if (un == 0 || vn == 0) { arena.rewind(mark); return std::max(un, vn); } n = std::max(un, vn); if (n < HGCD_THRESHOLD) { arena.rewind(mark); return n; } // 同サイズにパディング if (un < n) std::memset(up + un, 0, (n - un) * sizeof(uint64_t)); if (vn < n) std::memset(vp + vn, 0, (n - vn) * sizeof(uint64_t)); // up >= vp を保証 if (cmp(up, n, vp, n) < 0) { for (size_t i = 0; i < n; i++) std::swap(up[i], vp[i]); } // 分割点: p = n/2 (ベンチマーク実測で n/3 より高速) size_t p = n / 2; // HGCD 行列を初期化 HgcdMatrix M; hgcd_matrix_init(&M, 2 * (n - p) + 4); // HGCD 用 scratch uint64_t* tp = arena.alloc_limbs(std::max(n - p, M.alloc) + 4); // 上位 n-p limb で HGCD 実行 size_t nn = hgcd(up + p, vp + p, n - p, &M, tp); // 行列の結果をフルベクトルに反映 if (nn > 0) { // GMP 準拠: p + nn を渡す size_t mn = M.n; size_t prod_n = mn + p; size_t ms_sz = multiply_scratch_size( std::max(mn, p), std::min(mn, p)); size_t adj_size = 2 * prod_n + ms_sz; uint64_t* adj_tp = arena.alloc_limbs(adj_size + 16); n = hgcd_matrix_adjust(&M, p + nn, up, vp, p, adj_tp); } // 正規化 n = std::max(normalized_size(up, n), normalized_size(vp, n)); arena.rewind(mark); return n; } // GCD subdiv step: hgcd2 失敗時のフォールバック // 1 回の減算 + 除算で a, b を縮小 // 返値: 新しいサイズ (0 = GCD 発見、gp に結果格納) inline size_t gcd_subdiv_step(uint64_t* ap, uint64_t* bp, size_t n, uint64_t* gp, size_t* gn, uint64_t* tp) { size_t an = n, bn = n; // 正規化 while (an > 0 && ap[an - 1] == 0) an--; while (bn > 0 && bp[bn - 1] == 0) bn--; if (an == 0) { for (size_t i = 0; i < bn; i++) gp[i] = bp[i]; *gn = bn; return 0; } if (bn == 0) { for (size_t i = 0; i < an; i++) gp[i] = ap[i]; *gn = an; return 0; } int swapped = 0; // a < b を保証 if (an == bn) { int c = cmp(ap, an, bp, bn); if (c == 0) { // a == b → GCD = a for (size_t i = 0; i < an; i++) gp[i] = ap[i]; *gn = an; return 0; } if (c > 0) { std::swap(ap, bp); swapped = 1; } } else if (an > bn) { std::swap(ap, bp); std::swap(an, bn); swapped = 1; } // b -= a (b > a) sub(bp, bp, bn, ap, an); bn = normalized_size(bp, bn); if (bn == 0) { for (size_t i = 0; i < an; i++) gp[i] = ap[i]; *gn = an; return 0; } // a < b を再度保証 if (an == bn) { int c = cmp(ap, an, bp, bn); if (c == 0) { for (size_t i = 0; i < an; i++) gp[i] = ap[i]; *gn = an; return 0; } if (c > 0) { std::swap(ap, bp); swapped ^= 1; } } else if (an > bn) { std::swap(ap, bp); std::swap(an, bn); swapped ^= 1; } // b = b mod a (除算) // tp を商バッファとして使用 size_t qn = bn - an + 1; // divide は divide(q, r, a, an, b, bn, scratch) の形式 // ここでは bp を a で割る: bp mod ap // tp を商に、bp を余りに size_t scratch_needed = divide_scratch_size(bn, an); // tp の先頭 qn limb を商用、残りを scratch として使用 // ただし tp が十分大きい前提 uint64_t* qp = tp; uint64_t* work = tp + qn; // bp の内容を一時バッファにコピー (divide は被除数を上書きするため) uint64_t* dividend = work; for (size_t i = 0; i < bn; i++) dividend[i] = bp[i]; work += bn; divide(qp, bp, dividend, bn, ap, an, work); bn = normalized_size(bp, an); // 余りは an limb 以下 if (bn == 0) { // 割り切れた → GCD = ap for (size_t i = 0; i < an; i++) gp[i] = ap[i]; *gn = an; return 0; } // ap と bp を正規化してサイズを返す // an は変わらず、bn は余りのサイズ return an; } // Lehmer GCD scratch サイズ inline size_t gcd_lehmer_scratch_size(size_t an, size_t bn) { size_t n = (an > bn) ? an : bn; // Lehmer ループ用: // tp バッファ (hgcd_mul_matrix1_vector の出力、up と swap する): n+1 limb // subdiv_scratch (gcd_subdiv_step 用): 商(n) + 被除数コピー(n) + divide scratch // tp と subdiv_scratch は分離 (swap 後 up が tp 領域を指すため) size_t tp_size = n + 1; size_t div_scratch = divide_scratch_size(n, n); size_t subdiv_scratch = n + n + div_scratch; size_t lehmer_total = tp_size + subdiv_scratch; // 初期不均衡除算用: qn + an + divide_scratch(an, bn) size_t init_total = 0; if (an > bn) { size_t qn = an - bn + 1; size_t init_div = divide_scratch_size(an, bn); init_total = qn + an + init_div; } return (lehmer_total > init_total) ? lehmer_total : init_total; } // Lehmer GCD メインループ // GMP の mpn_gcd に相当 (HGCD 再帰なし、Lehmer のみ) // 入力: up[0..usize-1], vp[0..n-1] は上書きされる // 前提: usize >= n, n >= 1, vp[n-1] > 0, 両方とも奇数 // 出力: gp[0..return_value-1] = gcd(up, vp) inline size_t gcd_lehmer(uint64_t* gp, uint64_t* up, size_t usize, uint64_t* vp, size_t n, uint64_t* scratch) { size_t gn = 0; // 初期の不均衡除算: usize > n の場合、up = up mod vp if (usize > n) { size_t qn = usize - n + 1; uint64_t* qp = scratch; uint64_t* tmp = scratch + qn; // up の内容を一時バッファにコピー for (size_t i = 0; i < usize; i++) tmp[i] = up[i]; uint64_t* work = tmp + usize; divide(qp, up, tmp, usize, vp, n, work); // up は余り (n limb 以下) // ゼロチェック bool zero = true; for (size_t i = 0; i < n; i++) { if (up[i] != 0) { zero = false; break; } } if (zero) { for (size_t i = 0; i < n; i++) gp[i] = vp[i]; return n; } } // Lehmer ループ: n > 2 の間 // scratch レイアウト: [tp: n+1] [subdiv_scratch: ...] // tp は hgcd_mul_matrix1_vector の出力先、up と pointer swap する // subdiv_scratch は gcd_subdiv_step 専用 (tp 領域と分離) uint64_t* tp = scratch; uint64_t* subdiv_scratch = scratch + n + 1; while (n > 2) { // 大サイズの場合は HGCD で高速縮小 if (n >= HGCD_THRESHOLD) { size_t new_n = hgcd_reduce(up, vp, n); if (new_n < n) { n = new_n; // HGCD 後のゼロチェック size_t un2 = normalized_size(up, n); size_t vn2 = normalized_size(vp, n); if (un2 == 0) { for (size_t i = 0; i < vn2; i++) gp[i] = vp[i]; return vn2; } if (vn2 == 0) { for (size_t i = 0; i < un2; i++) gp[i] = up[i]; return un2; } continue; } // HGCD が進捗しなかった場合は Lehmer にフォールスルー } HgcdMatrix1 M_mat; uint64_t uh, ul, vh, vl; uint64_t mask = up[n - 1] | vp[n - 1]; if (mask == 0) { // 両方の最上位が 0 → 正規化 n--; continue; } // 上位 2 limb を正規化して取り出す if (mask >> 63) { // 最上位ビットが立っている → そのまま使用 uh = up[n - 1]; ul = up[n - 2]; vh = vp[n - 1]; vl = vp[n - 2]; } else { // 左シフトで正規化 int shift = std::countl_zero(mask); if (n >= 3) { uh = (up[n - 1] << shift) | (up[n - 2] >> (64 - shift)); ul = (up[n - 2] << shift) | (up[n - 3] >> (64 - shift)); vh = (vp[n - 1] << shift) | (vp[n - 2] >> (64 - shift)); vl = (vp[n - 2] << shift) | (vp[n - 3] >> (64 - shift)); } else { uh = (up[n - 1] << shift) | (up[n - 2] >> (64 - shift)); ul = up[n - 2] << shift; vh = (vp[n - 1] << shift) | (vp[n - 2] >> (64 - shift)); vl = vp[n - 2] << shift; } } // hgcd2 で変換行列を構築 if (hgcd2(uh, ul, vh, vl, &M_mat)) { // 行列を n-limb ベクトルに適用 // tp を rp として使用、up は ap として読まれる n = hgcd_mul_matrix1_vector(&M_mat, tp, up, vp, n); // GMP と同様にポインタスワップ (O(1)) // tp には新しい a の値、vp は in-place 更新済み std::swap(up, tp); // hgcd_mul_matrix1_vector 内で正規化済み if (n == 0) { gp[0] = 0; return 0; } } else { // hgcd2 失敗 → subdiv_step (tp 領域と分離した scratch を使用) n = gcd_subdiv_step(up, vp, n, gp, &gn, subdiv_scratch); if (n == 0) return gn; } } // ベースケース: n <= 2 // 正規化 size_t un = normalized_size(up, n); size_t vn = normalized_size(vp, n); if (un == 0) { for (size_t i = 0; i < vn; i++) gp[i] = vp[i]; return vn; } if (vn == 0) { for (size_t i = 0; i < un; i++) gp[i] = up[i]; return un; } // 奇数にする if ((up[0] & 1) == 0) { std::swap(up, vp); std::swap(un, vn); } // vp も奇数にする if ((vp[0] & 1) == 0) { // trailing zeros を除去 size_t tz = ctz_limb_array(vp, vn); if (tz > 0) { vn = rshift(vp, vp, vn, tz); } } if (n <= 1 || (un <= 1 && vn <= 1)) { uint64_t u0 = up[0]; uint64_t v0 = vp[0]; // 両方奇数にする if ((u0 & 1) == 0) { u0 >>= std::countr_zero(u0); } if ((v0 & 1) == 0) { v0 >>= std::countr_zero(v0); } gp[0] = gcd_11(u0, v0); return 1; } // n == 2 uint64_t u0 = up[0], u1 = (un > 1) ? up[1] : 0; uint64_t v0 = vp[0], v1 = (vn > 1) ? vp[1] : 0; // v が偶数なら奇数にシフト if (v0 == 0) { v0 = v1; v1 = 0; } if ((v0 & 1) == 0) { int cnt = std::countr_zero(v0); v0 = (v0 >> cnt) | (v1 << (64 - cnt)); v1 >>= cnt; } DoubleLimb g = gcd_22(u1, u0, v1, v0); gp[0] = g.d0; if (g.d1 > 0) { gp[1] = g.d1; return 2; } return 1; } // ============================================================================ // mod_34lsub1: N mod (2^48-1) を除算なしで計算 // GMP の mpn_mod_34lsub1 に相当。IsSquare の高速フィルタリングに使用。 // ============================================================================ // 64-bit limb 用: B1=16, B2=32, B3=48 // 2^48-1 の素因数: 3, 5, 7, 13, 17, 97, 241, 257, 673 // 返値は正確な余りではなく、N mod (2^48-1) と合同な値 inline uint64_t mod_34lsub1(const uint64_t* p, size_t n) { constexpr unsigned B1 = 16; constexpr unsigned B2 = 32; constexpr unsigned B3 = 48; constexpr uint64_t M1 = (1ULL << B1) - 1; // 0xFFFF constexpr uint64_t M2 = (1ULL << B2) - 1; // 0xFFFFFFFF constexpr uint64_t M3 = (1ULL << B3) - 1; // 0xFFFFFFFFFFFF uint64_t a0 = 0, a1 = 0, a2 = 0; uint64_t c0 = 0, c1 = 0, c2 = 0; size_t i = 0; // 3 limb ずつ処理 while (i + 3 <= n) { // a0 += p[i] with carry tracking uint64_t s0 = a0 + p[i]; c0 += (s0 < a0) ? 1 : 0; a0 = s0; uint64_t s1 = a1 + p[i + 1]; c1 += (s1 < a1) ? 1 : 0; a1 = s1; uint64_t s2 = a2 + p[i + 2]; c2 += (s2 < a2) ? 1 : 0; a2 = s2; i += 3; } // 端数処理 if (i < n) { uint64_t s0 = a0 + p[i]; c0 += (s0 < a0) ? 1 : 0; a0 = s0; i++; if (i < n) { uint64_t s1 = a1 + p[i]; c1 += (s1 < a1) ? 1 : 0; a1 = s1; } } // PARTS0(x) = (x & M3) + (x >> B3) // PARTS1(x) = ((x & M2) << B1) + (x >> B2) // PARTS2(x) = ((x & M1) << B2) + (x >> B1) auto PARTS0 = [&](uint64_t x) -> uint64_t { return (x & M3) + (x >> B3); }; auto PARTS1 = [&](uint64_t x) -> uint64_t { return ((x & M2) << B1) + (x >> B2); }; auto PARTS2 = [&](uint64_t x) -> uint64_t { return ((x & M1) << B2) + (x >> B1); }; return PARTS0(a0) + PARTS1(a1) + PARTS2(a2) + PARTS1(c0) + PARTS2(c1) + PARTS0(c2); } // ============================================================================ // 平方根 (Zimmermann 再帰的アルゴリズム, mpn レベル) // ============================================================================ // 逆数平方根テーブル (GMP sqrtrem.c より) // invsqrttab[i] ≈ 256/sqrt((i+128)/256) - 256 static constexpr unsigned char invsqrttab[384] = { 0xff,0xfd,0xfb,0xf9,0xf7,0xf5,0xf3,0xf2, 0xf0,0xee,0xec,0xea,0xe9,0xe7,0xe5,0xe4, 0xe2,0xe0,0xdf,0xdd,0xdb,0xda,0xd8,0xd7, 0xd5,0xd4,0xd2,0xd1,0xcf,0xce,0xcc,0xcb, 0xc9,0xc8,0xc6,0xc5,0xc4,0xc2,0xc1,0xc0, 0xbe,0xbd,0xbc,0xba,0xb9,0xb8,0xb7,0xb5, 0xb4,0xb3,0xb2,0xb0,0xaf,0xae,0xad,0xac, 0xaa,0xa9,0xa8,0xa7,0xa6,0xa5,0xa4,0xa3, 0xa2,0xa0,0x9f,0x9e,0x9d,0x9c,0x9b,0x9a, 0x99,0x98,0x97,0x96,0x95,0x94,0x93,0x92, 0x91,0x90,0x8f,0x8e,0x8d,0x8c,0x8c,0x8b, 0x8a,0x89,0x88,0x87,0x86,0x85,0x84,0x83, 0x83,0x82,0x81,0x80,0x7f,0x7e,0x7e,0x7d, 0x7c,0x7b,0x7a,0x79,0x79,0x78,0x77,0x76, 0x76,0x75,0x74,0x73,0x72,0x72,0x71,0x70, 0x6f,0x6f,0x6e,0x6d,0x6d,0x6c,0x6b,0x6a, 0x6a,0x69,0x68,0x68,0x67,0x66,0x66,0x65, 0x64,0x64,0x63,0x62,0x62,0x61,0x60,0x60, 0x5f,0x5e,0x5e,0x5d,0x5c,0x5c,0x5b,0x5a, 0x5a,0x59,0x59,0x58,0x57,0x57,0x56,0x56, 0x55,0x54,0x54,0x53,0x53,0x52,0x52,0x51, 0x50,0x50,0x4f,0x4f,0x4e,0x4e,0x4d,0x4d, 0x4c,0x4b,0x4b,0x4a,0x4a,0x49,0x49,0x48, 0x48,0x47,0x47,0x46,0x46,0x45,0x45,0x44, 0x44,0x43,0x43,0x42,0x42,0x41,0x41,0x40, 0x40,0x3f,0x3f,0x3e,0x3e,0x3d,0x3d,0x3c, 0x3c,0x3b,0x3b,0x3a,0x3a,0x39,0x39,0x39, 0x38,0x38,0x37,0x37,0x36,0x36,0x35,0x35, 0x35,0x34,0x34,0x33,0x33,0x32,0x32,0x32, 0x31,0x31,0x30,0x30,0x2f,0x2f,0x2f,0x2e, 0x2e,0x2d,0x2d,0x2d,0x2c,0x2c,0x2b,0x2b, 0x2b,0x2a,0x2a,0x29,0x29,0x29,0x28,0x28, 0x27,0x27,0x27,0x26,0x26,0x26,0x25,0x25, 0x24,0x24,0x24,0x23,0x23,0x23,0x22,0x22, 0x21,0x21,0x21,0x20,0x20,0x20,0x1f,0x1f, 0x1f,0x1e,0x1e,0x1e,0x1d,0x1d,0x1d,0x1c, 0x1c,0x1b,0x1b,0x1b,0x1a,0x1a,0x1a,0x19, 0x19,0x19,0x18,0x18,0x18,0x18,0x17,0x17, 0x17,0x16,0x16,0x16,0x15,0x15,0x15,0x14, 0x14,0x14,0x13,0x13,0x13,0x12,0x12,0x12, 0x12,0x11,0x11,0x11,0x10,0x10,0x10,0x0f, 0x0f,0x0f,0x0f,0x0e,0x0e,0x0e,0x0d,0x0d, 0x0d,0x0c,0x0c,0x0c,0x0c,0x0b,0x0b,0x0b, 0x0a,0x0a,0x0a,0x0a,0x09,0x09,0x09,0x09, 0x08,0x08,0x08,0x07,0x07,0x07,0x07,0x06, 0x06,0x06,0x06,0x05,0x05,0x05,0x04,0x04, 0x04,0x04,0x03,0x03,0x03,0x03,0x02,0x02, 0x02,0x02,0x01,0x01,0x01,0x01,0x00,0x00 }; // sqrtrem1: 1-limb 平方根 (除算不要, 乗算のみ) // 前提: a0 >= 2^62 (上位 2 ビットの少なくとも 1 つが 1) // 返値: floor(sqrt(a0)), *rp = a0 - result^2 inline uint64_t sqrtrem1(uint64_t* rp, uint64_t a0) { unsigned abits = static_cast(a0 >> 55); uint64_t x0 = 0x100ULL | invsqrttab[abits - 0x80]; // Newton 反復 1: 8-bit → ~16-bit (1/√a の近似) uint64_t a1 = a0 >> 31; int64_t t = static_cast( static_cast(0x2000000000000ULL) - 0x30000ULL - a1 * x0 * x0 ) >> 16; x0 = (x0 << 16) + (static_cast(x0 * static_cast(t)) >> 18); // Newton 反復 2: ~16-bit → ~32-bit (√a の近似) uint64_t t2 = x0 * (a0 >> 24); uint64_t t3 = t2 >> 25; t = static_cast((a0 << 14) - t3 * t3 - 0x10000000000ULL) >> 24; x0 = t2 + (static_cast(x0 * static_cast(t)) >> 15); x0 >>= 32; // 最終補正 (最大 1 回のインクリメント) uint64_t x2 = x0 * x0; if (x2 + 2 * x0 <= a0 - 1) { x2 += 2 * x0 + 1; x0++; } *rp = a0 - x2; return x0; } // sqrtrem2: 2-limb 平方根 (sqrtrem1 を拡張) // 前提: np[1] >= 2^62 // sp[0] = floor(sqrt(np[0..1])), rp[0] = 余り下位 // 返値: 余り上位キャリー (0 or 1) // rp は np と同じポインタでも可 (in-place) inline int sqrtrem2(uint64_t* sp, uint64_t* rp, const uint64_t* np) { constexpr unsigned Prec = 32; uint64_t np0 = np[0]; uint64_t sp0 = sqrtrem1(rp, np[1]); uint64_t rp0 = rp[0]; // 余りと np0 の上位部分を合成 → sp0 で除算 rp0 = (rp0 << (Prec - 1)) + (np0 >> (Prec + 1)); uint64_t q = rp0 / sp0; // q が 2^Prec に到達した場合の補正 q -= q >> Prec; uint64_t u = rp0 - q * sp0; sp0 = (sp0 << Prec) | q; int cc = static_cast(u >> (Prec - 1)); rp0 = ((u << (Prec + 1)) & UINT64_MAX) + (np0 & ((1ULL << (Prec + 1)) - 1)); // q^2 を減算 uint64_t q2 = q * q; cc -= (rp0 < q2) ? 1 : 0; rp0 -= q2; // 補正 (最大 1 回) if (cc < 0) { rp0 += sp0; cc += (rp0 < sp0) ? 1 : 0; --sp0; rp0 += sp0; cc += (rp0 < sp0) ? 1 : 0; } rp[0] = rp0; sp[0] = sp0; return cc; } // divmod_1: 多倍長 / 1-limb 除算 (preinv 最適化版) // q[0..an-1] = a[0..an-1] / d, 返値 = 余り // 正規化してから逆数乗算ベースで各 limb を処理 inline uint64_t divmod_1(uint64_t* q, const uint64_t* a, size_t an, uint64_t d) { if (an == 0) return 0; unsigned shift = std::countl_zero(d); uint64_t d_norm = d << shift; uint64_t dinv = invert_limb(d_norm); if (shift == 0) { // d は既に正規化済み (MSB set) uint64_t r = 0; for (size_t i = an; i-- > 0; ) { auto [qq, rr] = udiv_qrnnd_preinv(r, a[i], d_norm, dinv); q[i] = qq; r = rr; } return r; } else { // d を正規化 (左シフト): 被除数も同時にシフトして処理 // ループ内の条件分岐 (i > 0 ? ...) を除去するため、最終要素を分離 unsigned rshift = 64 - shift; uint64_t r = a[an - 1] >> rshift; for (size_t i = an; i-- > 1; ) { uint64_t n1 = (a[i] << shift) | (a[i - 1] >> rshift); auto [qq, rr] = udiv_qrnnd_preinv(r, n1, d_norm, dinv); q[i] = qq; r = rr; } { uint64_t n1 = a[0] << shift; auto [qq, rr] = udiv_qrnnd_preinv(r, n1, d_norm, dinv); q[0] = qq; r = rr; } return r >> shift; // 余りの逆正規化 } } // dc_sqrtrem に必要な scratch サイズ (再帰的) inline size_t dc_sqrtrem_scratch_size(size_t n) { if (n <= 1) return 0; size_t l = n / 2; size_t h = n - l; // q_buf(l+2) + op_scratch (divide or square) size_t div_sz = (h >= 2) ? divide_scratch_size(n, h) : (n + 4); size_t sqr_sz = square_scratch_size(l); size_t level_scratch = (l + 2) + std::max(div_sz, sqr_sz); size_t rec_scratch = dc_sqrtrem_scratch_size(h); return std::max(level_scratch, rec_scratch); } // dc_sqrtrem: Zimmermann 再帰的平方根 // 入力: np[0..2n-1] (正規化済み: np[2n-1] >= 2^62), in-place で書き換え // 出力: sp[0..n-1] = floor(sqrt(np)) の下位部分 // 返値: 余りのキャリー c (余りは np[0..n-1] に格納, 全余り = c*B^n + np[0..n-1]) // scratch: dc_sqrtrem_scratch_size(n) limbs inline int dc_sqrtrem(uint64_t* sp, uint64_t* np, size_t n, uint64_t* scratch) { // --- ベースケース: n = 1 (2-limb 入力) --- if (n == 1) { return sqrtrem2(sp, np, np); // in-place: rp = np } size_t l = n / 2; size_t h = n - l; // === Step 1: 上位 2h limbs の平方根を再帰的に計算 === int q; if (h == 1) { q = sqrtrem2(sp + l, np + 2 * l, np + 2 * l); } else { q = dc_sqrtrem(sp + l, np + 2 * l, h, scratch); } // sp[l..l+h-1] = sqrt, np[2l..2l+h-1] = 余り, q = 余りキャリー // === Step 2: 余りキャリーの調整 === if (q != 0) { // 余りから s' を引く (キャリーの吸収) sub(np + 2 * l, np + 2 * l, h, sp + l, h); } // === Step 3: np[l..l+n-1] を sp[l..l+h-1] で除算 === // 被除数: [a1, 調整済み余り] = np[l..l+n-1] (n limbs) // 除数: sp[l..l+h-1] (h limbs) // 商: q_buf[0..l] (最大 l+1 limbs) // 余り: np[l..l+h-1] に上書き uint64_t* q_buf = scratch; uint64_t* op_scratch = scratch + l + 2; if (h == 1) { // 1-limb 除数: divmod_1 を使用 uint64_t rem = divmod_1(q_buf, np + l, n, sp[l]); np[l] = rem; q_buf[n] = 0; } else { // h >= 2: divide を使用 // divide は a を const で受け取るので np+l を直接渡せる // 余りは np+l に戻す (divide は r に別バッファを要求) // → 一旦 op_scratch 領域に余りを受けてからコピーバック uint64_t* r_tmp = op_scratch; uint64_t* div_work = op_scratch + h; size_t qn = divide(q_buf, r_tmp, np + l, n, sp + l, h, div_work); // 余りを np[l..l+h-1] に書き戻し std::memcpy(np + l, r_tmp, h * sizeof(uint64_t)); // 商の上位を 0 埋め for (size_t i = qn; i <= l; i++) q_buf[i] = 0; } // === Step 4: 商の処理 === // q (キャリー) += 商の最上位 q += static_cast(q_buf[l]); // 商の偶奇を保存 int c_parity = static_cast(q_buf[0] & 1); // sp[0..l-1] = q_buf[0..l-1] >> 1 (商を 2 で割る) for (size_t i = 0; i < l - 1; i++) { sp[i] = (q_buf[i] >> 1) | (q_buf[i + 1] << 63); } sp[l - 1] = (q_buf[l - 1] >> 1) | (static_cast(q) << 63); q >>= 1; // === Step 5: 商が奇数だった場合の余り調整 === int c = 0; if (c_parity != 0) { c = static_cast(add(np + l, np + l, h, sp + l, h)); } // === Step 6: sp[0..l-1]^2 を計算して余りから引く === // np[n..n+2l-1] を自乗結果の一時バッファとして使用 uint64_t* sq_buf = np + n; std::memset(sq_buf, 0, 2 * l * sizeof(uint64_t)); square(sq_buf, sp, l, op_scratch); // np[0..2l-1] -= sq_buf[0..2l-1] uint64_t borrow = sub(np, np, 2 * l, sq_buf, 2 * l); int b = static_cast(q) + static_cast(borrow); if (l == h) { c -= b; } else { // l < h (h = l+1): np[2l] からボローを伝播 c -= static_cast(sub_1(np + 2 * l, h, static_cast(b))); } // === Step 7: 補正 (余りが負の場合) === if (c < 0) { // sp[l..l+h-1] にキャリーを加算 uint64_t q_carry = add_1(sp + l, h, static_cast(q)); // np += 2 * sp (addmul_1 で sp * 2 を加算) uint64_t am_carry = addmul_1(np, sp, n, 2) + 2 * q_carry; c += static_cast(am_carry); // np -= 1 c -= static_cast(sub_1(np, n, 1)); // sp -= 1 q = static_cast(q - sub_1(sp, n, 1)); } return c; } // sqrtrem に必要な scratch サイズ inline size_t sqrtrem_scratch_size(size_t an) { if (an <= 2) return 16; size_t sn = (an + 1) / 2; // root のサイズ // 作業バッファ: np_buf(2*sn+2) + dc_scratch + rem_compute(2*sn + mul_scratch) size_t dc_sz = dc_sqrtrem_scratch_size(sn); size_t sqr_sz = square_scratch_size(sn); return 2 * sn + 2 + dc_sz + 2 * sn + 2 + sqr_sz; } // floor(sqrt(a[0..an-1])), 余りを rp に格納 (rp=nullptr 可) // 返値: sqrt の normalized size // 前提: an >= 1, a[an-1] != 0 // scratch: sqrtrem_scratch_size(an) limbs inline size_t sqrtrem(uint64_t* sp, uint64_t* rp, const uint64_t* ap, size_t an, uint64_t* scratch) { // --- 1-limb ベースケース --- if (an == 1) { uint64_t val = ap[0]; // 正規化: 上位 2 ビットの少なくとも 1 つを 1 にする unsigned shift = std::countl_zero(val); shift &= ~1u; // 偶数に切り下げ uint64_t norm_val = val << shift; uint64_t rem; uint64_t s = sqrtrem1(&rem, norm_val); // 逆正規化 s >>= (shift / 2); sp[0] = s; if (rp) rp[0] = val - s * s; return (s > 0) ? 1 : 0; } // --- 2-limb ベースケース --- if (an == 2) { // 正規化: np[1] の上位 2 ビットの少なくとも 1 つを 1 にする unsigned shift = std::countl_zero(ap[1]); shift &= ~1u; uint64_t np_buf[2]; if (shift > 0) { np_buf[1] = (ap[1] << shift) | (ap[0] >> (64 - shift)); np_buf[0] = ap[0] << shift; } else { np_buf[0] = ap[0]; np_buf[1] = ap[1]; } sqrtrem2(sp, scratch, np_buf); // 逆正規化 sp[0] >>= (shift / 2); if (rp) { UInt128 sq_final = UInt128::multiply(sp[0], sp[0]); rp[0] = ap[0] - sq_final.low; uint64_t borrow = (ap[0] < sq_final.low) ? 1ULL : 0ULL; rp[1] = ap[1] - sq_final.high - borrow; } return (sp[0] > 0) ? 1 : 0; } // --- 一般ケース: an >= 3, Zimmermann 再帰 --- size_t sn = (an + 1) / 2; // root のサイズ (limbs) // scratch レイアウト: // np_buf[0..2*sn+1]: 入力コピー (パディング+正規化) // dc_scratch: dc_sqrtrem 用 // sq_buf[0..2*sn+1] + mul_work: 余り計算用 uint64_t* np_buf = scratch; uint64_t* dc_scratch = np_buf + 2 * sn + 2; uint64_t* sq_buf = dc_scratch + dc_sqrtrem_scratch_size(sn); uint64_t* mul_work = sq_buf + 2 * sn + 2; // 入力を np_buf にコピー (奇数 limb の場合はパディング) std::memset(np_buf, 0, (2 * sn + 2) * sizeof(uint64_t)); std::memcpy(np_buf, ap, an * sizeof(uint64_t)); // an が奇数なら np_buf[an] は 0 (パディング済み) // 正規化: np_buf[2*sn-1] の上位 2 ビットの少なくとも 1 つを 1 にする unsigned total_shift = 0; { uint64_t top = np_buf[2 * sn - 1]; if (top == 0) { // パディングで最上位が 0 の場合 (an が奇数) // np_buf[2*sn-2] が実質的な最上位 // 2*sn-1 番目を非ゼロにするために大きくシフトが必要 // ただし an が奇数 → 2*sn = an+1 → np_buf[an] = 0 は正常 // top = np_buf[2*sn-2] = ap[an-1] (which is != 0) // この場合は 64 ビット分のシフト + さらにbit単位のシフト unsigned clz2 = std::countl_zero(np_buf[2 * sn - 2]); total_shift = 64 + (clz2 & ~1u); } else { unsigned clz = std::countl_zero(top); total_shift = clz & ~1u; } } if (total_shift > 0) { // total_shift ビット左シフト size_t limb_shift = total_shift / 64; unsigned bit_shift = total_shift % 64; if (limb_shift > 0) { // limb 単位のシフト (上位方向に移動) for (size_t i = 2 * sn - 1; i >= limb_shift; i--) { np_buf[i] = np_buf[i - limb_shift]; } for (size_t i = 0; i < limb_shift; i++) { np_buf[i] = 0; } } if (bit_shift > 0) { lshift(np_buf, np_buf, 2 * sn, bit_shift); } } // dc_sqrtrem を呼び出し uint64_t* sp_work = sp; // 結果を直接 sp に書く std::memset(sp_work, 0, sn * sizeof(uint64_t)); dc_sqrtrem(sp_work, np_buf, sn, dc_scratch); // 逆正規化: root を total_shift/2 ビット右シフト unsigned root_shift = total_shift / 2; if (root_shift > 0) { rshift(sp_work, sp_work, sn, root_shift); } size_t sp_n = normalized_size(sp_work, sn); // 余り計算: dc_sqrtrem の余り (np_buf) を逆正規化して返す // ★ Codex 指摘 #1: sp² 再計算を排除 — dc_sqrtrem の余りを直接利用 if (rp) { // np_buf[0..2*sn-1] に正規化済み入力の余りが残っている // total_shift ビット右シフトで元の入力に対する余りを復元 if (total_shift > 0) { size_t limb_shift = total_shift / 64; unsigned bit_shift = total_shift % 64; if (bit_shift > 0) { rshift(np_buf, np_buf, 2 * sn, bit_shift); } if (limb_shift > 0) { for (size_t i = 0; i + limb_shift < 2 * sn; i++) { np_buf[i] = np_buf[i + limb_shift]; } for (size_t i = 2 * sn - limb_shift; i < 2 * sn; i++) { np_buf[i] = 0; } } } // rp にコピー (余りは sn limbs 以下) size_t rem_n = std::min((size_t)sn, an); std::memcpy(rp, np_buf, rem_n * sizeof(uint64_t)); if (rem_n < an) std::memset(rp + rem_n, 0, (an - rem_n) * sizeof(uint64_t)); } return sp_n; } // sqrtrem_check_exact: sqrt を計算し、余りがゼロかどうかを返す // M(n) の二乗計算を省略し、dc_sqrtrem の余りを直接検査 // 返値: {sqrt の normalized size, 余りがゼロなら true} // scratch: sqrtrem_scratch_size(an) limbs (二乗計算不要だが同じサイズで安全) inline std::pair sqrtrem_check_exact( uint64_t* sp, const uint64_t* ap, size_t an, uint64_t* scratch) { // --- 1-limb ベースケース --- if (an == 1) { uint64_t val = ap[0]; unsigned shift = std::countl_zero(val); shift &= ~1u; uint64_t norm_val = val << shift; uint64_t rem; uint64_t s = sqrtrem1(&rem, norm_val); s >>= (shift / 2); sp[0] = s; // 余りゼロ判定: s^2 == val return {(s > 0) ? 1u : 0u, s * s == val}; } // --- 2-limb ベースケース --- if (an == 2) { unsigned shift = std::countl_zero(ap[1]); shift &= ~1u; uint64_t np_buf[2]; if (shift > 0) { np_buf[1] = (ap[1] << shift) | (ap[0] >> (64 - shift)); np_buf[0] = ap[0] << shift; } else { np_buf[0] = ap[0]; np_buf[1] = ap[1]; } sqrtrem2(sp, scratch, np_buf); sp[0] >>= (shift / 2); UInt128 sq = UInt128::multiply(sp[0], sp[0]); bool exact = (sq.low == ap[0] && sq.high == ap[1]); return {(sp[0] > 0) ? 1u : 0u, exact}; } // --- 一般ケース: an >= 3 --- size_t sn = (an + 1) / 2; uint64_t* np_buf = scratch; uint64_t* dc_scratch = np_buf + 2 * sn + 2; std::memset(np_buf, 0, (2 * sn + 2) * sizeof(uint64_t)); std::memcpy(np_buf, ap, an * sizeof(uint64_t)); // 正規化 unsigned total_shift = 0; { uint64_t top = np_buf[2 * sn - 1]; if (top == 0) { unsigned clz2 = std::countl_zero(np_buf[2 * sn - 2]); total_shift = 64 + (clz2 & ~1u); } else { unsigned clz = std::countl_zero(top); total_shift = clz & ~1u; } } if (total_shift > 0) { size_t limb_shift = total_shift / 64; unsigned bit_shift = total_shift % 64; if (limb_shift > 0) { for (size_t i = 2 * sn - 1; i >= limb_shift; i--) np_buf[i] = np_buf[i - limb_shift]; for (size_t i = 0; i < limb_shift; i++) np_buf[i] = 0; } if (bit_shift > 0) { lshift(np_buf, np_buf, 2 * sn, bit_shift); } } uint64_t* sp_work = sp; std::memset(sp_work, 0, sn * sizeof(uint64_t)); int carry = dc_sqrtrem(sp_work, np_buf, sn, dc_scratch); // dc_sqrtrem の余り = carry * B^sn + np_buf[0..sn-1] // 正規化入力 a' = a << total_shift は完全平方数 ⟺ a が完全平方数 // (total_shift は偶数なので 2^{total_shift} は完全平方数) // よって: carry == 0 かつ np_buf[0..sn-1] が全ゼロ ⟺ 余りゼロ bool exact = (carry == 0); if (exact) { for (size_t i = 0; i < sn; i++) { if (np_buf[i] != 0) { exact = false; break; } } } // 逆正規化 unsigned root_shift = total_shift / 2; if (root_shift > 0) { rshift(sp_work, sp_work, sn, root_shift); } size_t sp_n = normalized_size(sp_work, sn); return {sp_n, exact}; } // ============================================================================ // Schönhage-Strassen FFT 乗算 // ============================================================================ // // NTT ベースの Schönhage-Strassen FFT。Z/(B^F+1)Z 上で畳み込みを行う。 // B = 2^64 (1 ワード)。 // // 入力を M=2^l 個の K ワードピースに分割し、 // ω = B^(2F/M) を M 次の単位根として NTT → 点ごと乗算 → 逆 NTT → キャリー伝搬。 // // 点ごと乗算のサイズは F ≈ 2N/M ≪ N なので、multiply() 経由で // 再帰的に FFT を呼んでも自然にサイズが減少し、再入ガード不要。 // // 参考文献: // - Schönhage, Strassen: "Schnelle Multiplikation großer Zahlen" (1971) // - Knuth TAOCP Vol.2 §4.3.3.C // ============================================================================ namespace fft_detail { // 2の補数否定: rp = -ap mod B^n = B^n - ap // 返値: 1 (ap != 0), 0 (ap == 0) inline uint64_t neg_n(uint64_t* rp, const uint64_t* ap, size_t n) { size_t i = 0; while (i < n && ap[i] == 0) { rp[i] = 0; i++; } if (i == n) return 0; // all zero rp[i] = (~ap[i]) + 1; // negate first non-zero limb for (++i; i < n; i++) { rp[i] = ~ap[i]; // complement remaining } return 1; } // ================================================================ // FFT パラメータ選択 // ================================================================ struct FftParams { size_t M; // ピース数 (= 2^l) size_t K; // 各ピースのワード数 size_t F; // 剰余環のワード数 (B^F+1) size_t l; // log2(M) }; // F×F 乗算の相対コスト推定。実際の乗算閾値を反映。 // basecase O(F²), Karatsuba O(F^1.585), TC3 O(F^1.465), TC4 O(F^1.404) inline size_t fft_pointwise_cost(size_t F) { if (F < KARATSUBA_THRESHOLD) return F * F; // basecase if (F < TOOMCOOK3_THRESHOLD) return F * F * 3 / 4; // Karatsuba ~25% 削減 if (F < TOOMCOOK4_THRESHOLD) return F * F * 9 / 16; // TC3 ~44% 削減 return F * F * 7 / 16; // TC4 ~56% 削減 } inline FftParams fft_choose_params(size_t an, size_t bn) { size_t N = an + bn; // 最適な l を探索: M=2^l 個のピース、各 K ワード、F ワードの剰余環 // ビットレベル twiddle: 2^(128F/M) が整数ビットシフトになるため // F % (M/128) == 0 が必要 (M ≤ 128 なら任意) // コスト = M * pointwise_multiply(F) + M * l * butterfly_cost(F) // l の初期推定: M ≈ √N size_t l0 = 1; while ((1ULL << (2 * l0)) < N) ++l0; // l0-2 ~ l0+4 の範囲で最良を選択 // (ビットレベル twiddle で F が小さくなると大きい l が最適になり得る) size_t best_cost = SIZE_MAX; FftParams best = {}; size_t lo = (l0 > 4) ? l0 - 2 : 3; for (size_t l = lo; l <= l0 + 4; ++l) { size_t M = 1ULL << l; size_t K = (N + M - 1) / M; if (K < 1) continue; size_t min_F = 2 * K + (l + 63) / 64 + 1; // F の量子化: M/128 の倍数 (M ≤ 128 なら制約なし) size_t quant = (M <= 128) ? 1 : (M >> 7); size_t F = ((min_F + quant - 1) / quant) * quant; // pointwise: M 回の F×F 乗算 // butterfly: M*l 回、各 O(F) ワード操作 (add_mod + sub_mod + bitshift_mod) // pw_cost >> 5: F が大きい場合 (Toom-4 区間) の乗算コストを正しく反映 size_t pw_cost = fft_pointwise_cost(F); size_t cost = M * (pw_cost >> 5) + M * l * (F >> 2); if (cost < best_cost) { best_cost = cost; best = {M, K, F, l}; } } return best; } // ================================================================ // mod (B^F + 1) 算術 // ================================================================ // 各要素は F+1 ワード。a[F] は 0 または 1。 // a[F] == 1 のとき a[0..F-1] は全て 0 (値 = B^F)。 // (a + b) mod (B^F+1) // r, a, b: F+1 ワード。r は a と同じでもよい。 inline void fft_add_mod(uint64_t* r, const uint64_t* a, const uint64_t* b, size_t F) { uint64_t cy = add(r, a, F, b, F); cy += a[F] + b[F]; r[F] = 0; // total = cy * B^F + r。B^F ≡ -1 mod (B^F+1) なので total ≡ r - cy。 // cy (0..3) を r から減算し、負なら B^F+1 を加算して正規化。 if (cy > 0) { uint64_t bw = sub_1(r, F, cy); if (bw) { // 負 → (B^F+1) を加算: r + 1 (sub_1 borrow 後の r = B^F + r_old - cy) r[F] = add_1(r, F, 1); } } } // (a - b) mod (B^F+1) // r, a, b: F+1 ワード。r は a と同じでもよい。 inline void fft_sub_mod(uint64_t* r, const uint64_t* a, const uint64_t* b, size_t F) { uint64_t borrow = sub(r, a, F, b, F); // hi = a[F] - b[F] - borrow (-2..1) int64_t hi = (int64_t)a[F] - (int64_t)b[F] - (int64_t)borrow; r[F] = 0; // total = hi * B^F + r。B^F ≡ -1 なので total ≡ r - hi mod (B^F+1)。 if (hi > 0) { // r から hi を減算 uint64_t bw = sub_1(r, F, (uint64_t)hi); if (bw) { r[F] = add_1(r, F, 1); } } else if (hi < 0) { // r に |hi| を加算 uint64_t cy = add_1(r, F, (uint64_t)(-hi)); if (cy) { // オーバーフロー → B^F + r' を削減: r' - 1, borrow なら B^F uint64_t bw = sub_1(r, F, 1); if (bw) { r[F] = add_1(r, F, 1); } } } } // -a mod (B^F+1) inline void fft_negate_mod(uint64_t* r, const uint64_t* a, size_t F) { if (a[F] != 0) { // a = B^F → -B^F mod (B^F+1) = 1 std::memset(r, 0, F * sizeof(uint64_t)); r[0] = 1; r[F] = 0; return; } // a[0..F-1] が全て 0 なら結果も 0 uint64_t nonzero = neg_n(r, a, F); if (nonzero) { // -a mod (B^F+1) = (B^F+1) - a = (B^F - a) + 1 = neg_n + 1 // carry=1 のとき a=1 → 結果は B^F (r[F]=1, r[0..F-1]=0) r[F] = add_1(r, F, 1); } else { r[F] = 0; } } // a * B^j mod (B^F+1), j はワード単位 (0 <= j < 2F) // r != a (非 in-place)。r, a: F+1 ワード。 inline void fft_shift_mod(uint64_t* r, const uint64_t* a, size_t F, size_t j) { j %= (2 * F); bool negate = (j >= F); if (negate) j -= F; if (j == 0) { if (negate) { fft_negate_mod(r, a, F); } else { std::memcpy(r, a, (F + 1) * sizeof(uint64_t)); } return; } // a[F] のハンドリング: a[F]*B^F ≡ -a[F] mod (B^F+1) // よって a[F]*B^(F+j) ≡ -a[F]*B^j // B^j * a[0..F-1] mod B^F+1: // 上位: a[0..F-j-1] → r[j..F-1] に配置 // 下位: a[F-j..F-1] → 否定して r[0..j-1] に配置 (B^F ≡ -1) // // r[j..F-1] = a[0..F-j-1] std::memcpy(r + j, a, (F - j) * sizeof(uint64_t)); // r[0..j-1] = -a[F-j..F-1] mod B^j uint64_t borrow_from_neg = neg_n(r, a + (F - j), j); r[F] = 0; // borrow_from_neg == 0 は a[F-j..F-1] が全て 0 だったことを意味 // borrow_from_neg == 1 は否定が行われた → r[j..F-1] から 1 を引く (borrow 伝搬) if (borrow_from_neg) { uint64_t bw = sub_1(r + j, F - j, 1); // bw が発生した場合 → r[j..F-1] は全て 0 だった → wrap around // B^F+1 を加算して補正: r[0..F-1] 全体に +1、carry → r[F] if (bw) { r[F] = add_1(r, F, 1); } } // a[F] の処理: a[F]*B^(F+j) ≡ -a[F]*B^j // a[F] == 1 のとき、a[0..F-1] は全て 0 なので上の処理で r は全て 0 // -B^j mod (B^F+1): r[j..F-1] から 1 を引く (B^j を引く) if (a[F] != 0) { uint64_t bw = sub_1(r + j, F - j, 1); if (bw) { r[F] += add_1(r, F, 1); } } // 正規化: r[F] >= 2 → r -= (B^F+1) while (r[F] >= 2) { sub_1(r, F, 1); r[F] -= 1; } // r[F] == 1 のとき r[0..F-1] は全て 0 であるべき (正規化確認) if (negate) { // in-place negate mod (B^F+1) if (r[F] != 0) { // r = B^F → -B^F mod (B^F+1) = 1 std::memset(r, 0, F * sizeof(uint64_t)); r[0] = 1; r[F] = 0; } else { uint64_t nonzero = neg_n(r, r, F); if (nonzero) { // (B^F+1) - r_old = neg_n(r_old) + 1 // carry=1 のとき r_old=1 → 結果は B^F r[F] = add_1(r, F, 1); } } } } // a * 2^total_bits mod (B^F+1), r != a (非 in-place) // ビットレベル twiddle: ワードシフト + 端数ビットシフトの 2 段階 // r, a: F+1 ワード inline void fft_bitshift_mod(uint64_t* r, const uint64_t* a, size_t F, size_t total_bits) { total_bits %= (128 * F); // 周期 = 2*64*F ビット size_t word_shift = total_bits / 64; size_t bit_shift = total_bits % 64; // Step 1: ワードシフト (既存関数) fft_shift_mod(r, a, F, word_shift); // Step 2: ビットシフト in-place if (bit_shift == 0) return; if (r[F] != 0) { // r = B^F → r * 2^s ≡ -(2^s) mod (B^F+1) // = (B^F + 1) - 2^s (s >= 1) // r[0] = 1 - 2^s (2の補数wrap), r[1..F-1] = UINT64_MAX, r[F] = 0 r[0] = (uint64_t)1 - ((uint64_t)1 << bit_shift); std::memset(r + 1, 0xFF, (F - 1) * sizeof(uint64_t)); r[F] = 0; return; } // r[F] == 0: lshift + overflow fold // overflow * B^F ≡ -overflow mod (B^F+1) → r -= overflow uint64_t overflow = lshift(r, r, F, (unsigned)bit_shift); if (overflow) { uint64_t bw = sub_1(r, F, overflow); if (bw) r[F] = add_1(r, F, 1); } } // a * b mod (B^F+1): フル F×F 乗算 → fold // r, a, b: F+1 ワード。scratch: 2*F + multiply_scratch_size(F,F) ワード inline void fft_mulpoint(uint64_t* r, const uint64_t* a, const uint64_t* b, size_t F, uint64_t* scratch) { // 特殊ケース: a[F]!=0 or b[F]!=0 if (a[F] != 0 && b[F] != 0) { // a = b = B^F → a*b = B^(2F) ≡ 1 mod (B^F+1) std::memset(r, 0, (F + 1) * sizeof(uint64_t)); r[0] = 1; return; } if (a[F] != 0) { // a = B^F → a*b = B^F * b ≡ -b mod (B^F+1) fft_negate_mod(r, b, F); return; } if (b[F] != 0) { fft_negate_mod(r, a, F); return; } // 一般: フル乗算 → fold uint64_t* prod = scratch; uint64_t* mul_scratch = scratch + 2 * F; multiply(prod, a, F, b, F, mul_scratch); // fold: r[0..F-1] = prod[0..F-1] - prod[F..2F-1] mod (B^F+1) // B^F ≡ -1 なので上位 F ワードを下位から減算 uint64_t borrow = sub(r, prod, F, prod + F, F); r[F] = 0; if (borrow) { // 負 → (B^F+1) を加算: sub の結果に +1 すると (B^F+1) 加算完了 // carry=1 のとき結果は B^F (r[0..F-1]=0, r[F]=1) r[F] = add_1(r, F, 1); } } // a² mod (B^F+1): フル F×F 自乗 → fold inline void fft_sqrpoint(uint64_t* r, const uint64_t* a, size_t F, uint64_t* scratch) { if (a[F] != 0) { // a = B^F → a² = B^(2F) ≡ 1 mod (B^F+1) std::memset(r, 0, (F + 1) * sizeof(uint64_t)); r[0] = 1; return; } uint64_t* prod = scratch; uint64_t* sqr_scratch = scratch + 2 * F; size_t sqr_scratch_sz = square_scratch_size(F); square(prod, a, F, sqr_scratch); uint64_t borrow = sub(r, prod, F, prod + F, F); r[F] = 0; if (borrow) { r[F] = add_1(r, F, 1); } } // M = 2^l で除算 mod (B^F+1) // B^F+1 は奇数 → 2^(-1) が存在。各半減: 奇数なら B^F+1 を加算して偶数に、1ビット右シフト。 inline void fft_div_by_power_of_2(uint64_t* a, size_t F, size_t l) { for (size_t i = 0; i < l; i++) { if (a[0] & 1) { // a += (B^F + 1) → a[0..F-1] += 1, a[F] += 1 uint64_t cy = add_1(a, F, 1); a[F] += 1 + cy; } // a[F] は最大 2 (∵ a[F]=0 のときのみ加算実行、cy ∈ {0,1})。 // 右シフトは F+1 ワード全体に適用されるため、a[F]=2 でも正しく処理される: // a[F]=2, a[0..F-1]=0 → 右シフト → a[F]=1, a[0..F-1]=0 = B^F // 以前の while ループ (sub_1 + a[F]-=1) は borrow 未処理で // 2*B^F を 2*B^F-1 (奇数) に壊していた。 #ifdef CALX_INT_HAS_ASM mpn_rshift_asm(a, a, F + 1, 1); #else for (size_t w = 0; w < F; w++) { a[w] = (a[w] >> 1) | (a[w + 1] << 63); } a[F] = a[F] >> 1; #endif } } // ================================================================ // FFT Forward / Inverse NTT // ================================================================ // Forward NTT (DIF — Decimation in Frequency) // data: M 個の (F+1) ワード要素、連続配置 // temp: F+1 ワードの一時バッファ inline void fft_forward(uint64_t* data, size_t M, size_t F, size_t l, uint64_t* temp) { size_t stride = F + 1; for (int s = (int)l - 1; s >= 0; --s) { size_t half = 1ULL << s; size_t block = half << 1; for (size_t k = 0; k < M; k += block) { for (size_t m = 0; m < half; ++m) { uint64_t* u = data + (k + m) * stride; uint64_t* v = data + (k + m + half) * stride; // twiddle: ω^(m * M/block) = 2^(m * 128F / block) (ビット単位) size_t tw_bits = m * (128 * F) / block; // DIF butterfly: temp = u - v; u = u + v; v = shift(temp, tw_bits) fft_sub_mod(temp, u, v, F); fft_add_mod(u, u, v, F); if (tw_bits != 0) fft_bitshift_mod(v, temp, F, tw_bits); else std::memcpy(v, temp, stride * sizeof(uint64_t)); } } } } // Inverse NTT (DIT — Decimation in Time) // data: M 個の (F+1) ワード要素 // temp: F+1 ワードの一時バッファ inline void fft_inverse(uint64_t* data, size_t M, size_t F, size_t l, uint64_t* temp) { size_t stride = F + 1; for (size_t s = 0; s < l; ++s) { size_t half = 1ULL << s; size_t block = half << 1; for (size_t k = 0; k < M; k += block) { for (size_t m = 0; m < half; ++m) { uint64_t* u = data + (k + m) * stride; uint64_t* v = data + (k + m + half) * stride; size_t tw_bits = m * (128 * F) / block; size_t inv_tw_bits = (tw_bits == 0) ? 0 : (128 * F - tw_bits); // DIT butterfly: temp = shift(v, inv_tw); v = u - temp; u = u + temp if (inv_tw_bits != 0) fft_bitshift_mod(temp, v, F, inv_tw_bits); else std::memcpy(temp, v, stride * sizeof(uint64_t)); fft_sub_mod(v, u, temp, F); fft_add_mod(u, u, temp, F); } } } // M = 2^l で除算 for (size_t i = 0; i < M; ++i) { fft_div_by_power_of_2(data + i * stride, F, l); } } // ================================================================ // 入力分割・結果組立 // ================================================================ // 入力 a[0..an-1] を M 個の K ワードピースに分割、各 F+1 ワードにゼロ埋め inline void fft_split(uint64_t* data, const uint64_t* a, size_t an, size_t M, size_t K, size_t F) { size_t stride = F + 1; for (size_t i = 0; i < M; ++i) { uint64_t* dst = data + i * stride; size_t src_off = i * K; size_t copy_len = 0; if (src_off < an) { copy_len = std::min(K, an - src_off); std::memcpy(dst, a + src_off, copy_len * sizeof(uint64_t)); } // 残りをゼロ埋め std::memset(dst + copy_len, 0, (stride - copy_len) * sizeof(uint64_t)); } } // 逆 NTT 後の係数から結果を組み立て // 各 c_i を rp[i*K..] に累積。c_i は最大 F+1 ワード。 // full_rn: 内部バッファサイズ (M*K + F 以上)。切り捨て防止。 // 呼び出し元が actual_rn ワードを rp からコピーする。 inline void fft_recompose(uint64_t* rp, size_t full_rn, const uint64_t* data, size_t M, size_t K, size_t F) { size_t stride = F + 1; std::memset(rp, 0, full_rn * sizeof(uint64_t)); for (size_t i = 0; i < M; ++i) { const uint64_t* coeff = data + i * stride; size_t base = i * K; if (base >= full_rn) break; // 係数の有効ワード数 size_t cn = F + 1; while (cn > 0 && coeff[cn - 1] == 0) --cn; if (cn == 0) continue; // rp[base..] に加算 (切り捨てなし: full_rn は十分大きい) size_t space = full_rn - base; size_t add_len = std::min(cn, space); add(rp + base, rp + base, space, coeff, add_len); } } // ================================================================ // mul_fft / sqr_fft // ================================================================ // FFT 乗算: rp[0..an+bn-1] = ap[0..an-1] * bp[0..bn-1] // thread_local バッファでヒープ確保を回避。F < FFT_THRESHOLD なので再帰的 FFT は発生しない。 // NTT 並列化閾値: 両オペランドがこれ以上のとき forward NTT を並列実行 static constexpr size_t FFT_PARALLEL_THRESHOLD = 8000; inline void mul_fft(uint64_t* rp, const uint64_t* ap, size_t an, const uint64_t* bp, size_t bn) { if (an < bn) { std::swap(ap, bp); std::swap(an, bn); } auto p = fft_choose_params(an, bn); size_t stride = p.F + 1; size_t data_sz = p.M * stride; size_t mul_scratch_sz = 2 * p.F + multiply_scratch_size(p.F, p.F); bool parallel = (bn >= FFT_PARALLEL_THRESHOLD); // 並列時は temp を 2 つ確保 (各 forward NTT 用) size_t temp_count = parallel ? 2 : 1; size_t total = 2 * data_sz + mul_scratch_sz + stride * temp_count; thread_local std::vector work; if (work.size() < total) work.resize(total); uint64_t* data_a = work.data(); uint64_t* data_b = work.data() + data_sz; uint64_t* scratch = work.data() + 2 * data_sz; uint64_t* temp = scratch + mul_scratch_sz; uint64_t* temp2 = parallel ? temp + stride : nullptr; fft_split(data_a, ap, an, p.M, p.K, p.F); fft_split(data_b, bp, bn, p.M, p.K, p.F); if (parallel) { // 2 つの forward NTT を並列実行 (別 temp バッファ) auto future_b = calx::threadPool().submit([&]() { fft_forward(data_b, p.M, p.F, p.l, temp2); }); fft_forward(data_a, p.M, p.F, p.l, temp); future_b.get(); } else { fft_forward(data_a, p.M, p.F, p.l, temp); fft_forward(data_b, p.M, p.F, p.l, temp); } for (size_t i = 0; i < p.M; ++i) { fft_mulpoint(data_a + i * stride, data_a + i * stride, data_b + i * stride, p.F, scratch); } fft_inverse(data_a, p.M, p.F, p.l, temp); // data_b を recompose 用バッファとして再利用 (pointwise 後は不要) size_t rn = an + bn; size_t full_rn = p.M * p.K + p.F; uint64_t* recomp_buf = data_b; fft_recompose(recomp_buf, full_rn, data_a, p.M, p.K, p.F); std::memcpy(rp, recomp_buf, rn * sizeof(uint64_t)); } // FFT 自乗: rp[0..2n-1] = ap[0..n-1]² inline void sqr_fft(uint64_t* rp, const uint64_t* ap, size_t an) { auto p = fft_choose_params(an, an); size_t stride = p.F + 1; size_t data_sz = p.M * stride; size_t sqr_scratch_sz = 2 * p.F + square_scratch_size(p.F); size_t full_rn = p.M * p.K + p.F; size_t total = data_sz + sqr_scratch_sz + stride + full_rn; thread_local std::vector work; if (work.size() < total) work.resize(total); uint64_t* data_a = work.data(); uint64_t* scratch = work.data() + data_sz; uint64_t* temp = scratch + sqr_scratch_sz; uint64_t* recomp_buf = temp + stride; fft_split(data_a, ap, an, p.M, p.K, p.F); fft_forward(data_a, p.M, p.F, p.l, temp); for (size_t i = 0; i < p.M; ++i) { fft_sqrpoint(data_a + i * stride, data_a + i * stride, p.F, scratch); } fft_inverse(data_a, p.M, p.F, p.l, temp); size_t rn = 2 * an; fft_recompose(recomp_buf, full_rn, data_a, p.M, p.K, p.F); std::memcpy(rp, recomp_buf, rn * sizeof(uint64_t)); } } // namespace fft_detail // 公開インターフェース: fft_detail:: を呼ぶラッパー inline void mul_fft(uint64_t* rp, const uint64_t* ap, size_t an, const uint64_t* bp, size_t bn) { fft_detail::mul_fft(rp, ap, an, bp, bn); } inline void sqr_fft(uint64_t* rp, const uint64_t* ap, size_t an) { fft_detail::sqr_fft(rp, ap, an); } } // namespace mpn } // namespace calx