// Copyright (C) 2026 Kiyotsugu Arai // SPDX-License-Identifier: LGPL-3.0-or-later // IntModular.cpp // モジュラ演算の実装 #include "math/core/mp/Int/IntModular.hpp" #include "math/core/mp/Int/IntOps.hpp" #include "math/core/mp/Int/IntGCD.hpp" #include "math/core/mp/Int/IntSpecialStates.hpp" #include "math/core/mp/Int/MpnOps.hpp" #include #include #if defined(_MSC_VER) && defined(_M_X64) #include #include #endif // ASM Montgomery 関数のプロトタイプ (グローバルスコープで宣言) #ifdef CALX_INT_HAS_ASM extern "C" void mpn_mont_mul_mulx(uint64_t* rp, const uint64_t* ap, const uint64_t* bp, size_t n, const uint64_t* mp, uint64_t m_inv, uint64_t* scratch); extern "C" void mpn_mont_redc_mulx(uint64_t* rp, uint64_t* tp, size_t tn, const uint64_t* mp, size_t n, uint64_t m_inv); extern "C" void mpn_mont_mul_8(uint64_t* rp, const uint64_t* ap, const uint64_t* bp, const uint64_t* mp, uint64_t m_inv); extern "C" void mpn_mont_sqr_8(uint64_t* rp, const uint64_t* ap, const uint64_t* mp, uint64_t m_inv); extern "C" void mpn_mont_redc_16(uint64_t* rp, uint64_t* tp, const uint64_t* mp, uint64_t m_inv); extern "C" void mpn_mont_redc_32(uint64_t* rp, uint64_t* tp, const uint64_t* mp, uint64_t m_inv); extern "C" void mpn_mont_sqr_16(uint64_t* rp, const uint64_t* ap, const uint64_t* mp, uint64_t m_inv); extern "C" void mpn_mont_mul_4(uint64_t* rp, const uint64_t* ap, const uint64_t* bp, const uint64_t* mp, uint64_t m_inv); extern "C" void mpn_mont_sqr_4(uint64_t* rp, const uint64_t* ap, const uint64_t* mp, uint64_t m_inv); extern "C" void mpn_mont_mul_2(uint64_t* rp, const uint64_t* ap, const uint64_t* bp, const uint64_t* mp, uint64_t m_inv); extern "C" void mpn_mont_sqr_2(uint64_t* rp, const uint64_t* ap, const uint64_t* mp, uint64_t m_inv); #endif namespace calx { // =========================================================================== // Montgomery 乗算 (内部実装) // =========================================================================== namespace { // -m^{-1} mod 2^64 を計算 (m は奇数) // Newton 反復: x ← x·(2 - m·x) mod 2^64 inline uint64_t mont_neg_inv(uint64_t m0) { uint64_t x = 1; for (int i = 0; i < 6; ++i) // 6回で 64-bit 精度に収束 x *= 2 - m0 * x; // mod 2^64 は自動 (uint64_t) return static_cast(0) - x; // -x mod 2^64 } // Montgomery リダクション: T * R^{-1} mod m // T: 2n ワード, m: n ワード, m_inv = -m^{-1} mod 2^64 // 結果: r (n ワード), return: 0 or 1 (r >= m) inline void mont_redc(uint64_t* r, uint64_t* t, size_t tn, // T (破壊される, 2n+1 ワード必要) const uint64_t* m, size_t n, uint64_t m_inv) { // CIOS (Coarsely Integrated Operand Scanning) 方式 for (size_t i = 0; i < n; ++i) { uint64_t q = t[i] * m_inv; // q = T[i] · (-m^{-1}) mod 2^64 uint64_t carry = mpn::addmul_1(t + i, m, n, q); // T += q · m · B^i // carry を上位に伝播 for (size_t j = i + n; carry && j < tn; ++j) { uint64_t sum = t[j] + carry; carry = (sum < t[j]) ? 1 : 0; t[j] = sum; } } // 結果は T[n..2n-1] std::memcpy(r, t + n, n * sizeof(uint64_t)); // r >= m なら r -= m if (mpn::cmp(r, n, m, n) >= 0) { mpn::sub(r, r, n, m, n); } } // 2-word 逆数の計算: mip * m ≡ -1 mod B^2 // mip[0] = m_inv (既存), mip[1] を計算 // GMP sec_powm.c と同等のアルゴリズム inline uint64_t mont_neg_inv2(uint64_t m0, uint64_t m1, uint64_t m_inv) { // m_inv * m[0] ≡ -1 mod B なので lo(m_inv * m[0]) = 0xFFFF...FFFF // h = hi(m_inv * m[0]) uint64_t h; #if defined(_MSC_VER) && defined(_M_X64) _umul128(m_inv, m0, &h); #else unsigned __int128 p = static_cast(m_inv) * m0; h = static_cast(p >> 64); #endif // t = h + m_inv * m[1] (word 1 of m_inv * m, mod B) uint64_t t = h + m_inv * m1; // mip[1] = (t + 1) * m_inv (mod B) return (t + 1) * m_inv; } // Montgomery リダクション k=2: GMP mpn_redc_2 アルゴリズム // 2 ワードずつ REDC 反復を処理し、ループ回数を半減 // tp[0..2n] は破壊される (2n+1 ワード必要) // n は偶数でなければならない (奇数の場合は呼び出し元で先頭 1 反復を処理) inline void mont_redc_2(uint64_t* rp, uint64_t* tp, const uint64_t* mp, size_t n, uint64_t mip0, uint64_t mip1) { // GMP redc_2 アルゴリズム: // for each pair (i, i+1): // q = umul2low(mip, tp[i..i+1]) — 2-word quotient // save tp[n], addmul_2(tp, mp, n, q), bookkeeping // final: add_n(rp, tp+n, tp, n) uint64_t* up = tp; // n が奇数の場合: 先頭 1 反復を通常処理 if (n & 1) { uint64_t q = up[0] * mip0; uint64_t cy = mpn::addmul_1(up, mp, n, q); up[n] += cy; up++; } for (size_t j = (n & 1) ? (n - 1) : n; j >= 2; j -= 2) { // 2-word quotient: (q1, q0) = umul2low(mip, up[0..1]) uint64_t q0, q1; q0 = up[0] * mip0; // q1 = hi(mip0 * up[0]) + mip0 * up[1] + mip1 * up[0] uint64_t h; #if defined(_MSC_VER) && defined(_M_X64) _umul128(mip0, up[0], &h); #else unsigned __int128 p = static_cast(mip0) * up[0]; h = static_cast(p >> 64); #endif q1 = h + mip0 * up[1] + mip1 * up[0]; // addmul_2: up[0..n] += mp[0..n-1] * (q0 + q1*B) uint64_t upn = up[n]; // save (will be overwritten) up[n] = mpn::addmul_1(up, mp, n, q0); // carry → up[n] uint64_t cy = mpn::addmul_1(up + 1, mp, n, q1); // carry returned // GMP bookkeeping: up[1] = cy; // addmul_2 return value up[0] = up[n]; // accumulated carry → position 0 up[n] = upn; // restore original up[n] up += 2; } // 上位半分 + 下位半分 (キャリー情報) を加算 uint64_t cy = mpn::add(rp, up, n, up - n, n); // cy > 0 or rp >= mp なら rp -= mp if (cy != 0 || mpn::cmp(rp, n, mp, n) >= 0) { mpn::sub(rp, rp, n, mp, n); } } // redc_n 用閾値: これ以上なら redc_n (O(M(n))) を使う // mul_low が Karatsuba 化されたため、閾値を下げられる static constexpr size_t REDC_N_THRESHOLD = 48; // Hensel 逆元の計算: result * m ≡ -1 (mod B^n) // Newton 反復: x ← x * (2 + m * x) mod B^{2k} // (m_inv = -m^{-1} mod B は事前計算済み) // scratch: 3*n ワード必要 inline void hensel_inverse(uint64_t* result, const uint64_t* m, size_t n, uint64_t m_inv, uint64_t* scratch) { // m_inv は -m[0]^{-1} mod B なので、result[0] = m_inv // これは m_inv * m[0] ≡ -1 (mod B) を満たす std::memset(result, 0, n * sizeof(uint64_t)); result[0] = m_inv; // scratch layout: t[0..2k-1], p[0..2k-1], x_new[0..k-1] を共有 // 各段で必要なのは最大 3*n ワード size_t cur = 1; // 現在の有効桁数 while (cur < n) { size_t next = std::min(2 * cur, n); // t = m[0..next-1] * result[0..cur-1] の下位 next ワード // = m * x mod B^next uint64_t* t = scratch; // next ワード使用 uint64_t* p = scratch + next; // next ワード使用 // basecase 乗算で下位 next ワードを計算 (上位は不要) // t[0..next-1] = m[0..next-1] * result[0..cur-1] の下位 next ワード std::memset(t, 0, next * sizeof(uint64_t)); for (size_t i = 0; i < cur; ++i) { if (result[i] == 0) continue; size_t jmax = std::min(next - i, next); uint64_t carry = mpn::addmul_1(t + i, m, jmax, result[i]); // carry は next ワード目以降に伝播するが、mod B^next なので無視 (void)carry; } // t = m * x mod B^next // 2 + t の下位 next ワードを計算: t[0] は -1 mod B (m_inv*m[0] = -1 mod B) // なので 2 + t: t[0] += 2 → t[0] = 1 (mod B) (キャリーなし) // 実際は t = m*x ≡ -1 (mod B^cur) なので全桁が 0xFFF...FFF // 2 + t mod B^next = 1 (下位 cur 桁) + (2+t)[cur..next-1]*B^cur // Newton: x_new = x * (2 + m*x) mod B^next // ここで 2 + t を計算 t[0] += 2; // 2 + m*x (mod B^next); 下位 cur 桁は 0...01 になる // t[0..cur-1] = 1, 0, 0, ..., 0 (キャリー伝播済み) // ただし実際には t[0] が FFFFFFFFFFFFFFFFh + 2 = 1 (carry=1) // t[1] が FFFFFFFFFFFFFFFFh + 1 = 0 (carry=1) ... t[cur-1]+1=0, carry=1 // t[cur] += 1 (carry from below) // これは面倒なので、別のアプローチ: t = 2 + m*x mod B^next を直接計算 // 下位 cur ワードは必ず [1, 0, 0, ..., 0] になる (Newton の性質) // → x_new = x * t mod B^next の上位 (next-cur) ワードのみが新しい // → result[cur..next-1] = (x * t[cur..next-1]) の下位 (next-cur) ワード // + x * [1,0,...,0] の上位分 (= result[0..cur-1] そのまま) // 簡潔に: result[0..next-1] = result[0..cur-1] * t[0..next-1] mod B^next // p = result[0..cur-1] * t[0..next-1] の下位 next ワード std::memset(p, 0, next * sizeof(uint64_t)); for (size_t i = 0; i < cur; ++i) { if (result[i] == 0) continue; size_t jmax = std::min(next - i, next); mpn::addmul_1(p + i, t, jmax, result[i]); } // result[0..next-1] = p[0..next-1] std::memcpy(result, p, next * sizeof(uint64_t)); cur = next; } } // mul_low: (a * b) mod B^n — 下位 n ワードのみ計算 // Karatsuba 版: a1*b1 を省略し、a0*b1 + a1*b0 は mul_low(h) で再帰 // basecase 閾値以下は addmul_1 ループ static constexpr size_t MUL_LOW_THRESHOLD = 24; // Karatsuba 切り替え閾値 static void mul_low_n(uint64_t* rp, const uint64_t* ap, const uint64_t* bp, size_t n, uint64_t* scratch) { if (n < MUL_LOW_THRESHOLD) { // basecase: addmul_1 ループ (上位を捨てる) std::memset(rp, 0, n * sizeof(uint64_t)); for (size_t i = 0; i < n; ++i) { if (ap[i] == 0) continue; size_t jmax = n - i; mpn::addmul_1(rp + i, bp, jmax, ap[i]); } return; } // Karatsuba mul_low (交差項を 1 回の乗算に統合) size_t h = n / 2; size_t l = n - h; // l >= h // a = a1·B^h + a0, b = b1·B^h + b0 // (a*b) mod B^n = a0*b0 + (a0*b1 + a1*b0)·B^h mod B^n // // Karatsuba トリック: // a0*b1 + a1*b0 = (a0+a1)*(b0+b1) - a0*b0 - a1*b1 // a1*b1 は B^{2h} 以上にしか寄与しないので mod B^n で不要 // → cross = (a0+a1)*(b0+b1) - a0*b0 の下位 l ワードのみ必要 // → 2 回の mul_low → 1 回のフル乗算に削減 // scratch layout: s[h+1] + t[h+1] + p[2(h+1)] + mul_scratch uint64_t* s = scratch; // h+1 ワード uint64_t* t = scratch + h + 1; // h+1 ワード uint64_t* p = scratch + 2 * (h + 1); // 2*(h+1) ワード uint64_t* rec_scratch = scratch + 4 * (h + 1); // 残り // Step 1: rp[0..2h-1] = a0 * b0 (フル乗算) size_t mul_s = mpn::multiply_scratch_size(h, h); if (mul_s > 0) mpn::multiply(rp, ap, h, bp, h, rec_scratch); else mpn::multiply(rp, ap, h, bp, h, nullptr); // Step 2: s = a0 + a1, t = b0 + b1 (各 h+1 ワード) s[h] = mpn::add(s, ap, h, ap + h, h); t[h] = mpn::add(t, bp, h, bp + h, h); size_t sn = h + (s[h] ? 1 : 0); size_t tn = h + (t[h] ? 1 : 0); // Step 3: p = s * t (フル乗算) size_t mul_st = mpn::multiply_scratch_size(sn, tn); if (sn >= tn) { if (mul_st > 0) mpn::multiply(p, s, sn, t, tn, rec_scratch); else mpn::multiply(p, s, sn, t, tn, nullptr); } else { if (mul_st > 0) mpn::multiply(p, t, tn, s, sn, rec_scratch); else mpn::multiply(p, t, tn, s, sn, nullptr); } size_t pn = mpn::normalized_size(p, sn + tn); // Step 4: cross = p - a0*b0 の下位 l ワード // p -= rp[0..2h-1] (a0*b0) size_t sub_n = std::min(pn, 2 * h); mpn::sub(p, p, pn, rp, sub_n); // p の下位 l ワードを rp[h..n-1] に加算 size_t add_n = std::min(pn, l); mpn::add(rp + h, rp + h, l, p, add_n); } // redc_n: O(M(n)) Montgomery リダクション // GMP mpn_redc_n と同等のアルゴリズム: // Q = T[0..n-1] * m_inv_n[0..n-1] mod B^n (下位 n ワードのみ) // P = Q * m (フル乗算) // result = T[n..2n-1] + P[n..2n-1] // 条件付き減算 // // scratch: multiply_scratch_size(n,n) ワード (multiply 用) // tp は破壊される // mul_low に必要な scratch サイズ (再帰考慮) static size_t mul_low_scratch_size(size_t n) { if (n < MUL_LOW_THRESHOLD) return 0; size_t h = n / 2; size_t l = n - h; // t[l] + recursive (multiply scratch or mul_low scratch) size_t mul_s = mpn::multiply_scratch_size(h, h); size_t rec = mul_low_scratch_size(l); return l + std::max(mul_s, rec); } inline void mont_redc_n(uint64_t* rp, uint64_t* tp, const uint64_t* mp, size_t n, const uint64_t* m_inv_n, uint64_t* scratch) { // Step 1: Q = T_low * m_inv_n mod B^n (mul_low で下位 n ワードのみ) uint64_t* Q = static_cast(_alloca(n * sizeof(uint64_t))); // scratch を mul_low の作業領域として使用 mul_low_n(Q, tp, m_inv_n, n, scratch); // Step 2: P = Q * m (フル乗算, 2n ワード) uint64_t* P = static_cast(_alloca(2 * n * sizeof(uint64_t))); size_t ms = mpn::multiply_scratch_size(n, n); if (ms > 0) { mpn::multiply(P, Q, n, mp, n, scratch); } else { mpn::multiply(P, Q, n, mp, n, nullptr); } // Step 3: rp = tp[n..2n-1] + P[n..2n-1] uint64_t cy = mpn::add(rp, tp + n, n, P + n, n); // Step 4: 条件付き減算 if (cy != 0 || mpn::cmp(rp, n, mp, n) >= 0) { mpn::sub(rp, rp, n, mp, n); } } // FIOS (Finely Integrated Operand Scanning) Montgomery 乗算 // 乗算とリダクションを 1 パスで融合。メモリアクセスを半減。 // n≥16 の大サイズで CIOS より高速になる可能性がある。 // scratch 不要 (内部バッファ n+2 ワードのみ)。 #if defined(_MSC_VER) && defined(_M_X64) inline void mont_mul_fios(uint64_t* r, const uint64_t* a, const uint64_t* b, const uint64_t* m, size_t n, uint64_t m_inv) { // t[0..n+1]: シフトアキュムレータ uint64_t* t = static_cast(_alloca((n + 2) * sizeof(uint64_t))); std::memset(t, 0, (n + 2) * sizeof(uint64_t)); for (size_t i = 0; i < n; ++i) { uint64_t bi = b[i]; uint64_t hi_a, lo_a, hi_m, lo_m; // === Word 0: S = t[0] + a[0]*bi, q = S*m_inv === lo_a = _umul128(a[0], bi, &hi_a); uint64_t S = t[0] + lo_a; uint64_t Ca = hi_a + (S < t[0] ? 1ULL : 0); unsigned Ca_hi = 0; // Ca の 65 bit 目 (0 or 1) uint64_t q = S * m_inv; lo_m = _umul128(m[0], q, &hi_m); uint64_t tmp = S + lo_m; // ≡ 0 mod 2^64 uint64_t Cm = hi_m + (tmp < S ? 1ULL : 0); unsigned Cm_hi = 0; // === Words 1..n-1: 融合 addmul === for (size_t j = 1; j < n; ++j) { // Phase A: S = t[j] + a[j]*bi + Ca (+ Ca_hi<<64) lo_a = _umul128(a[j], bi, &hi_a); uint64_t s1 = t[j] + lo_a; unsigned c1 = (s1 < t[j]) ? 1u : 0u; S = s1 + Ca; c1 += (S < s1) ? 1u : 0u; // new Ca = hi_a + c1 + Ca_hi (65-bit) uint64_t new_Ca = hi_a + c1; unsigned new_Ca_hi = (new_Ca < c1) ? 1u : 0u; new_Ca += Ca_hi; new_Ca_hi += (new_Ca < Ca_hi) ? 1u : 0u; // Phase M: t[j-1] = S + m[j]*q + Cm (+ Cm_hi<<64) lo_m = _umul128(m[j], q, &hi_m); uint64_t s2 = S + lo_m; unsigned c2 = (s2 < S) ? 1u : 0u; uint64_t res = s2 + Cm; c2 += (res < s2) ? 1u : 0u; uint64_t new_Cm = hi_m + c2; unsigned new_Cm_hi = (new_Cm < c2) ? 1u : 0u; new_Cm += Cm_hi; new_Cm_hi += (new_Cm < Cm_hi) ? 1u : 0u; t[j - 1] = res; Ca = new_Ca; Ca_hi = new_Ca_hi; Cm = new_Cm; Cm_hi = new_Cm_hi; } // === Final word: t[n-1] = t[n] + Ca + Cm + overflow === uint64_t f = t[n] + Ca; unsigned fc = (f < t[n]) ? 1u : 0u; f += Cm; fc += (f < Cm) ? 1u : 0u; t[n - 1] = f; t[n] = static_cast(fc + Ca_hi + Cm_hi) + t[n + 1]; t[n + 1] = 0; } // 条件付き減算 if (t[n] != 0 || mpn::cmp(t, n, m, n) >= 0) { mpn::sub(r, t, n, m, n); } else { std::memcpy(r, t, n * sizeof(uint64_t)); } } #endif // Montgomery 乗算: (a * b * R^{-1}) mod m // a, b: n ワード (Montgomery 形式), m: n ワード // 結果: r (n ワード) inline void mont_mul(uint64_t* r, const uint64_t* a, const uint64_t* b, const uint64_t* m, size_t n, uint64_t m_inv, uint64_t mip1, uint64_t* scratch, // 2n+1 ワード (積バッファ) uint64_t* mul_scratch = nullptr, // multiply_scratch_size(n,n) ワード const uint64_t* m_inv_n = nullptr) { // redc_n 用 n-word 逆元 (nullable) #ifdef CALX_INT_HAS_ASM if (mpn::detail::has_bmi2_adx()) { if (n == 2) { mpn_mont_mul_2(r, a, b, m, m_inv); return; } if (n == 4) { mpn_mont_mul_4(r, a, b, m, m_inv); return; } if (n == 8) { mpn_mont_mul_8(r, a, b, m, m_inv); return; } if (n == 16) { // SOS 分離方式: basecase 乗算 + 特化 REDC // n=16 < KARATSUBA_THRESHOLD(24) なので basecase, 追加 scratch 不要 mpn::mul_basecase(scratch, a, n, b, n); scratch[2 * n] = 0; mpn_mont_redc_16(r, scratch, m, m_inv); return; } if (n == 32) { // SOS 分離方式: basecase 乗算 + 特化 REDC // basecase は Karatsuba より MULX 数は多いが追加 scratch 不要 // REDC のアンロール最適化 (ADDMUL_BLOCK_4) が CIOS より高速 mpn::mul_basecase(scratch, a, n, b, n); scratch[2 * n] = 0; mpn_mont_redc_32(r, scratch, m, m_inv); return; } // 汎用パス: SOS 分離方式 + REDC // n >= 64 で Karatsuba のアルゴリズム優位が発揮される if (mul_scratch && n >= 64) { mpn::multiply(scratch, a, n, b, n, mul_scratch); } else { mpn::mul_basecase(scratch, a, n, b, n); } scratch[2 * n] = 0; // redc_n: O(M(n)) — n >= REDC_N_THRESHOLD かつ逆元が事前計算済みの場合 if (n >= REDC_N_THRESHOLD && m_inv_n && mul_scratch) { mont_redc_n(r, scratch, m, n, m_inv_n, mul_scratch); return; } mont_redc_2(r, scratch, m, n, m_inv, mip1); return; } #endif #if defined(_MSC_VER) && defined(_M_X64) // C++ フォールバック: FIOS (n≥16) / CIOS (n<16) if (n >= 16) { mont_mul_fios(r, a, b, m, n, m_inv); return; } #endif // C++ フォールバック: CIOS (Coarsely Integrated Operand Scanning) std::memset(scratch, 0, (2 * n + 1) * sizeof(uint64_t)); for (size_t i = 0; i < n; ++i) { uint64_t carry = mpn::addmul_1(scratch + i, a, n, b[i]); scratch[i + n] += carry; if (scratch[i + n] < carry) { for (size_t j = i + n + 1; j <= 2 * n; ++j) { ++scratch[j]; if (scratch[j] != 0) break; } } uint64_t q = scratch[i] * m_inv; carry = mpn::addmul_1(scratch + i, m, n, q); scratch[i + n] += carry; if (scratch[i + n] < carry) { for (size_t j = i + n + 1; j <= 2 * n; ++j) { ++scratch[j]; if (scratch[j] != 0) break; } } } if (scratch[2 * n] != 0 || mpn::cmp(scratch + n, n, m, n) >= 0) { mpn::sub(r, scratch + n, n, m, n); } else { std::memcpy(r, scratch + n, n * sizeof(uint64_t)); } } // Montgomery 二乗: (a^2 * R^{-1}) mod m // mpn::square (対称性: 1.5n² MULX) + mont_redc で高速化 // // 検証済み: n=8 CIOS アンロール版 / CIOS mont_mul(a,a) を試したが、 // Zen 3 では MULX/ADCX/ADOX のポート競合 (全て port 0/3) が律速であり // ループ除去・方式変更の効果なし。Intel Skylake+ では ADCX/ADOX が // 別ポートに分散するため、異なる結果が予想される。 inline void mont_sqr(uint64_t* r, const uint64_t* a, const uint64_t* m, size_t n, uint64_t m_inv, uint64_t mip1, uint64_t* scratch, // 2n+1 ワード uint64_t* sq_scratch, // square_scratch_size(n) ワード const uint64_t* m_inv_n = nullptr, // redc_n 用 n-word 逆元 (nullable) uint64_t* mul_scratch = nullptr) { // redc_n 用 multiply scratch #ifdef CALX_INT_HAS_ASM if (mpn::detail::has_bmi2_adx()) { if (n == 2) { mpn_mont_sqr_2(r, a, m, m_inv); return; } if (n == 4) { mpn_mont_sqr_4(r, a, m, m_inv); return; } if (n == 8) { mpn_mont_sqr_8(r, a, m, m_inv); return; } if (n == 16) { mpn_mont_sqr_16(r, a, m, m_inv); return; } mpn::square(scratch, a, n, sq_scratch); scratch[2 * n] = 0; // redc_n: O(M(n)) — n >= REDC_N_THRESHOLD かつ逆元が事前計算済みの場合 if (n >= REDC_N_THRESHOLD && m_inv_n && mul_scratch) { mont_redc_n(r, scratch, m, n, m_inv_n, mul_scratch); return; } // 汎用パス: redc_2 (k=2 REDC) でループ回数半減 mont_redc_2(r, scratch, m, n, m_inv, mip1); return; } #endif #if defined(_MSC_VER) && defined(_M_X64) if (n >= 16) { mont_mul_fios(r, a, a, m, n, m_inv); return; } #endif mpn::square(scratch, a, n, sq_scratch); scratch[2 * n] = 0; mont_redc(r, scratch, 2 * n + 1, m, n, m_inv); } // Montgomery 形式への変換: aR mod m // R = 2^(64*n), aR mod m を計算 // 方法: a を n ワード左シフトして m で割った余り inline void mont_encode(uint64_t* r, const uint64_t* a, size_t an, const uint64_t* m, size_t n) { // aR = a << (64*n) を m で除算 // 2n ワードの被除数を構築して除算 Int shifted; // a を Int に変換 std::vector awords(a, a + an); while (!awords.empty() && awords.back() == 0) awords.pop_back(); if (awords.empty()) { std::memset(r, 0, n * sizeof(uint64_t)); return; } Int aInt = Int::fromRawWords(awords, 1); // aR = a << (64*n) IntOps::leftShift(aInt, static_cast(64 * n)); // aR mod m Int mInt = Int::fromRawWords(std::vector(m, m + n), 1); Int result = aInt % mInt; // 結果を n ワードにパディング std::memset(r, 0, n * sizeof(uint64_t)); size_t rn = result.size(); if (rn > 0) { std::memcpy(r, result.words().data(), std::min(rn, n) * sizeof(uint64_t)); } } // =========================================================================== // AVX2 radix-2^29 Montgomery (SOS: 乗算 + REDC 分離) // =========================================================================== // VPMULUDQ (4×32→64 bit 並列乗算) をキャリーフリー蓄積で使用。 // 29-bit リムの積は 58 bit → 64-bit アキュムレータに ~64 回蓄積可能。 // 512-bit (n29=18) ではキャリー伝搬不要、4096-bit (n29=142) でも // ~30 反復ごとのキャリー伝搬のみ。 #if defined(_MSC_VER) && defined(_M_X64) namespace { namespace r29 { constexpr int BITS = 29; constexpr uint64_t MASK = (1ULL << BITS) - 1; // 安全に蓄積できる回数 (64-bit overflow 前): // 各蓄積は最大 (2^29-1)^2 ≈ 2^58, floor(2^64 / 2^58) = 64 // CIOS は 1 反復で 2 回 addmul → 安全反復数 = 30 constexpr size_t SAFE_ITERS = 30; inline bool has_avx2() { static const bool result = []() { int info[4]; __cpuidex(info, 7, 0); return (info[1] >> 5) & 1; // EBX bit 5 = AVX2 }(); return result; } inline bool has_avx512_ifma() { static const bool result = []() { int info[4]; __cpuidex(info, 7, 0); bool avx512f = (info[1] >> 16) & 1; // EBX bit 16 bool ifma = (info[1] >> 21) & 1; // EBX bit 21 return avx512f && ifma; }(); return result; } // 64-bit リム → radix-2^29 リム変換 // dst は (n64 * 64 + 28) / 29 + 1 要素以上確保されていること inline size_t to_r29(uint64_t* dst, const uint64_t* src, size_t n64) { size_t total_bits = n64 * 64; size_t n29 = (total_bits + BITS - 1) / BITS; const uint8_t* bytes = reinterpret_cast(src); size_t total_bytes = n64 * 8; for (size_t i = 0; i < n29; ++i) { size_t bit_start = i * BITS; size_t byte_start = bit_start / 8; size_t bit_off = bit_start % 8; // 5 バイト読み (最大 40 bit カバー) uint64_t val = 0; for (size_t b = 0; b < 5 && byte_start + b < total_bytes; ++b) val |= static_cast(bytes[byte_start + b]) << (b * 8); dst[i] = (val >> bit_off) & MASK; } return n29; } // radix-2^29 → 64-bit リム変換 inline void from_r29(uint64_t* dst, const uint64_t* src, size_t n29, size_t n64) { std::memset(dst, 0, n64 * 8); uint8_t* bytes = reinterpret_cast(dst); size_t total_bytes = n64 * 8; for (size_t i = 0; i < n29; ++i) { size_t bit_start = i * BITS; size_t byte_start = bit_start / 8; size_t bit_off = bit_start % 8; uint64_t val = (src[i] & MASK) << bit_off; for (size_t b = 0; b < 5 && byte_start + b < total_bytes; ++b) bytes[byte_start + b] |= static_cast(val >> (b * 8)); } } // キャリー伝搬 (radix-2^29) inline void carry_prop(uint64_t* t, size_t n) { uint64_t carry = 0; for (size_t i = 0; i < n; ++i) { t[i] += carry; carry = t[i] >> BITS; t[i] &= MASK; } } // -m29[0]^{-1} mod 2^29 inline uint64_t neg_inv_r29(uint64_t m0) { uint64_t x = 1; for (int i = 0; i < 5; ++i) x = (x * (2 - ((m0 * x) & MASK))) & MASK; return (MASK + 1 - x) & MASK; } // AVX2 vectorized addmul_1: t[0..n-1] += a[0..n-1] * b // 事前条件: a[i], b は 29-bit 以下, t[i] は 64-bit に収まること inline void addmul_1(uint64_t* t, const uint64_t* a, size_t n, uint64_t b) { __m256i vb = _mm256_set1_epi64x(static_cast(b)); size_t i = 0; for (; i + 4 <= n; i += 4) { __m256i va = _mm256_loadu_si256(reinterpret_cast(a + i)); __m256i vt = _mm256_loadu_si256(reinterpret_cast(t + i)); vt = _mm256_add_epi64(vt, _mm256_mul_epu32(va, vb)); _mm256_storeu_si256(reinterpret_cast<__m256i*>(t + i), vt); } for (; i < n; ++i) t[i] += static_cast(a[i]) * static_cast(b); } // r29 domain REDC: T[0..2n29+1] → r29_out[0..n29-1] // T は破壊される inline void redc_r29(uint64_t* r29_out, uint64_t* T, const uint64_t* m29, size_t n29, uint64_t m_inv29) { for (size_t i = 0; i < n29; ++i) { uint64_t q = (T[i] * m_inv29) & MASK; addmul_1(T + i, m29, n29, q); uint64_t carry = T[i] >> BITS; T[i] = 0; T[i + 1] += carry; if (i > 0 && (i % SAFE_ITERS == 0)) carry_prop(T + i + 1, n29 - 1); } carry_prop(T + n29, n29 + 1); // 条件付き減算 int cmp = 0; for (int j = static_cast(n29) - 1; j >= 0; --j) { if (T[n29 + j] > m29[j]) { cmp = 1; break; } if (T[n29 + j] < m29[j]) { cmp = -1; break; } } if (T[2 * n29] != 0 || cmp >= 0) { uint64_t borrow = 0; for (size_t j = 0; j < n29; ++j) { int64_t diff = static_cast(T[n29 + j]) - static_cast(m29[j]) - static_cast(borrow); borrow = (diff < 0) ? 1 : 0; T[n29 + j] = static_cast(diff) & MASK; } } std::memcpy(r29_out, T + n29, n29 * sizeof(uint64_t)); } // r29 domain Montgomery 乗算: r29_out = r29_a * r29_b * R_r29^{-1} mod m // 全入出力 radix-2^29 (R_r29 = 2^(29*n29)) // T: 2*n29+2 ワード以上の作業バッファ void mont_mul_r29(uint64_t* r29_out, const uint64_t* r29_a, const uint64_t* r29_b, const uint64_t* m29, size_t n29, uint64_t m_inv29, uint64_t* T) { size_t T_len = 2 * n29 + 2; std::memset(T, 0, T_len * sizeof(uint64_t)); for (size_t i = 0; i < n29; ++i) addmul_1(T + i, r29_a, n29, r29_b[i]); carry_prop(T, 2 * n29 + 1); redc_r29(r29_out, T, m29, n29, m_inv29); } // r29 domain Montgomery 自乗: r29_out = r29_a² * R_r29^{-1} mod m // 対称性利用: off-diagonal を 1 回だけ計算して 2 倍 // T: 2*n29+2 ワード以上の作業バッファ void mont_sqr_r29(uint64_t* r29_out, const uint64_t* r29_a, const uint64_t* m29, size_t n29, uint64_t m_inv29, uint64_t* T) { size_t T_len = 2 * n29 + 2; std::memset(T, 0, T_len * sizeof(uint64_t)); for (size_t i = 0; i < n29; ++i) addmul_1(T + 2 * i + 1, r29_a + i + 1, n29 - i - 1, r29_a[i]); for (size_t i = 0; i < 2 * n29; ++i) T[i] <<= 1; for (size_t i = 0; i < n29; ++i) T[2 * i] += static_cast(r29_a[i]) * static_cast(r29_a[i]); carry_prop(T, 2 * n29 + 1); redc_r29(r29_out, T, m29, n29, m_inv29); } // Montgomery encode (r29 domain): a * R_r29 mod m → r29 form // R_r29 = 2^(29 * n29), a は正規化済み (0 <= a < m) void mont_encode_r29(uint64_t* r29_out, size_t n29, const uint64_t* a, size_t an, const uint64_t* m64, size_t n64) { std::vector awords(a, a + an); while (!awords.empty() && awords.back() == 0) awords.pop_back(); if (awords.empty()) { std::memset(r29_out, 0, n29 * sizeof(uint64_t)); return; } Int aInt = Int::fromRawWords(awords, 1); IntOps::leftShift(aInt, static_cast(29 * n29)); Int mInt = Int::fromRawWords(std::vector(m64, m64 + n64), 1); Int result = aInt % mInt; // 64-bit → r29 変換 std::vector tmp64(n64, 0); size_t rn = result.size(); if (rn > 0) std::memcpy(tmp64.data(), result.words().data(), std::min(rn, n64) * sizeof(uint64_t)); to_r29(r29_out, tmp64.data(), n64); } // Final REDC: r29 Montgomery form → 64-bit plain value // work: 2*n29+2 + n29 ワード (T + r29_tmp) void mont_redc_final_r29(uint64_t* r64_out, size_t n64, const uint64_t* r29_in, const uint64_t* m29, size_t n29, uint64_t m_inv29, uint64_t* work) { uint64_t* T = work; uint64_t* r29_tmp = work + 2 * n29 + 2; std::memcpy(T, r29_in, n29 * sizeof(uint64_t)); std::memset(T + n29, 0, (n29 + 2) * sizeof(uint64_t)); redc_r29(r29_tmp, T, m29, n29, m_inv29); from_r29(r64_out, r29_tmp, n29, n64); } } // namespace r29 } // anonymous namespace #endif // _MSC_VER && _M_X64 // ウィンドウ幅の選択 (指数のビット長に応じて動的) // GMP mpn_sec_powm と同等の閾値。事前計算テーブルのコスト (2^(w-1) 回の乗算) // と指数ループの乗算削減 (expBits/w 回) のバランスで最適幅を決定。 inline int choose_window_width(size_t expBits) { if (expBits <= 24) return 1; if (expBits <= 64) return 3; if (expBits <= 256) return 4; if (expBits <= 1024) return 5; if (expBits <= 4096) return 6; return 7; } // Montgomery 冪剰余: base^exp mod m (m は奇数) // Left-to-right sliding window + Montgomery 乗算 // 専用 squaring (mpn::square 対称性利用) を組み合わせ // // Sliding window の利点 (vs 固定 k-ary): // - 奇数べきのみ事前計算 (テーブル半分) // - 0 ビット列はスキップ (自乗のみ) // - 窓を奇数値に限定し、乗算回数を ~10% 削減 // // AVX2 radix-2^29 パス: // - R_r29 = 2^(29*n29) で統一 (R_64 との不一致を回避) // - encode/precompute/loop/final REDC 全て r29 domain で実行 // - VPMULUDQ (4×32→64 並列乗算) によるスループット向上 // n <= 8 専用高速パス: ヒープ割り当てゼロ、ASM カーネル直接呼び出し // スタック上の固定サイズバッファで全処理を完結 #ifdef CALX_INT_HAS_ASM template // N = ワード数 (2, 4, 8) static Int mont_power_mod_fixed(const Int& base, const Int& exp, const Int& m) { static_assert(N == 2 || N == 4 || N == 8); auto _prof_lap = [](const char*) {}; const uint64_t* mdata = m.data(); uint64_t m_inv = mont_neg_inv(mdata[0]); size_t expBits = exp.bitLength(); int w = choose_window_width(expBits); int oddTableSize = 1 << (w - 1); // 全バッファをスタック上に配置 (ヒープ割り当てゼロ) alignas(64) uint64_t R2[N] = {}; alignas(64) uint64_t baseR[N] = {}; alignas(64) uint64_t oneR[N] = {}; alignas(64) uint64_t result[N], temp_buf[N]; alignas(64) uint64_t pad[N] = {}; // g_buf: 最大 oddTableSize (w=7 → 64) エントリ。 // w は expBits に依存するが N<=8 なら w<=5 (oddTableSize<=16) alignas(64) uint64_t g_buf[16 * N] = {}; // 最大 16 エントリ alignas(64) uint64_t base2R[N]; // Montgomery 乗算/二乗の直接呼び出し (ディスパッチ不要) auto do_mul = [&](uint64_t* dst, const uint64_t* a, const uint64_t* b) { if constexpr (N == 2) mpn_mont_mul_2(dst, a, b, mdata, m_inv); else if constexpr (N == 4) mpn_mont_mul_4(dst, a, b, mdata, m_inv); else mpn_mont_mul_8(dst, a, b, mdata, m_inv); }; auto do_sqr = [&](uint64_t* dst, const uint64_t* src) { if constexpr (N == 2) mpn_mont_sqr_2(dst, src, mdata, m_inv); else if constexpr (N == 4) mpn_mont_sqr_4(dst, src, mdata, m_inv); else mpn_mont_sqr_8(dst, src, mdata, m_inv); }; // R² mod m: 繰り返し二倍法 (ヒープ割り当て・alloca 不要) // d = 1 → d = 2*d mod m を 2*64*N 回繰り返す → d = 2^(128*N) mod m = R² mod m // N=4 なら 512 回のシフト+条件減算。各回は N ワード操作のみ。 _prof_lap("setup"); R2[0] = 1; for (size_t iter = 0; iter < 2 * 64 * N; ++iter) { // d <<= 1 uint64_t carry = 0; for (size_t j = 0; j < N; ++j) { uint64_t old = R2[j]; R2[j] = (old << 1) | carry; carry = old >> 63; } // d >= m なら d -= m if (carry || mpn::cmp(R2, N, mdata, N) >= 0) { mpn::sub(R2, R2, N, mdata, N); } } _prof_lap("R2_compute"); // base を Montgomery 形式に変換 { Int bmod = IntModular::mod(base, m); _prof_lap("base_mod"); size_t bn = bmod.size(); std::memset(pad, 0, N * sizeof(uint64_t)); if (bn > 0) std::memcpy(pad, bmod.data(), std::min(bn, N) * sizeof(uint64_t)); do_mul(baseR, pad, R2); } // 1R { std::memset(pad, 0, N * sizeof(uint64_t)); pad[0] = 1; do_mul(oneR, pad, R2); } _prof_lap("encode"); // ウィンドウテーブル uint64_t* g[16]; for (int ii = 0; ii < oddTableSize && ii < 16; ++ii) g[ii] = g_buf + static_cast(ii) * N; std::memcpy(g[0], baseR, N * sizeof(uint64_t)); if (oddTableSize > 1) { do_sqr(base2R, baseR); for (int ii = 1; ii < oddTableSize; ++ii) do_mul(g[ii], g[ii - 1], base2R); } _prof_lap("precomp"); // Sliding window exponentiation uint64_t* rp = result; uint64_t* tp = temp_buf; std::memcpy(rp, oneR, N * sizeof(uint64_t)); int i = static_cast(expBits) - 1; while (i >= 0) { if (!exp.getBit(i)) { do_sqr(tp, rp); std::swap(rp, tp); --i; } else { int j = (i - w + 1 > 0) ? (i - w + 1) : 0; while (!exp.getBit(j)) ++j; int wval = 0; for (int k = i; k >= j; --k) wval = (wval << 1) | (exp.getBit(k) ? 1 : 0); for (int s = 0; s < i - j + 1; ++s) { do_sqr(tp, rp); std::swap(rp, tp); } do_mul(tp, rp, g[(wval - 1) / 2]); std::swap(rp, tp); i = j - 1; } } _prof_lap("loop"); // Montgomery → 通常形式 (REDC with input = result * 1) { std::memset(pad, 0, N * sizeof(uint64_t)); pad[0] = 1; do_mul(tp, rp, pad); } size_t rn = N; while (rn > 0 && tp[rn - 1] == 0) --rn; if (rn == 0) return Int::Zero(); std::vector rwords(tp, tp + rn); _prof_lap("final_redc"); return Int::fromRawWords(rwords, 1); } #endif // CALX_INT_HAS_ASM Int mont_power_mod(const Int& base, const Int& exp, const Int& m) { size_t n = m.size(); // m のワード数 const uint64_t* mdata = m.data(); // n <= 8 専用高速パス (ヒープ割り当てゼロ、ASM カーネル直接呼び出し) #ifdef CALX_INT_HAS_ASM if (mpn::detail::has_bmi2_adx()) { if (n == 2) return mont_power_mod_fixed<2>(base, exp, m); if (n == 4) return mont_power_mod_fixed<4>(base, exp, m); if (n == 8) return mont_power_mod_fixed<8>(base, exp, m); } #endif uint64_t m_inv = mont_neg_inv(mdata[0]); uint64_t mip1 = (n >= 2) ? mont_neg_inv2(mdata[0], mdata[1], m_inv) : 0; size_t expBits = exp.bitLength(); int w = choose_window_width(expBits); int oddTableSize = 1 << (w - 1); // 奇数べきのみ: 2^(w-1) エントリ // --- AVX2 radix-2^29 domain 判定 --- #if defined(_MSC_VER) && defined(_M_X64) bool use_avx2 = false; // schoolbook r29 は scalar MULX+Karatsuba より遅い (1024-bit: 2倍, 2048-bit: 1.9倍) size_t n29 = 0; uint64_t m_inv29 = 0; std::vector m29_buf; // work: 2*n29+2 (T) + n29 (r29_tmp for final REDC) = 3*n29+2 std::vector avx2_work; if (use_avx2) { n29 = (n * 64 + r29::BITS - 1) / r29::BITS; m29_buf.resize(n29 + 1, 0); r29::to_r29(m29_buf.data(), mdata, n); m_inv29 = r29::neg_inv_r29(m29_buf[0]); avx2_work.resize(3 * n29 + 4, 0); } #else constexpr bool use_avx2 = false; constexpr size_t n29 = 0; #endif // リム数 (AVX2: r29 リム, scalar: 64-bit リム) size_t limb_n = use_avx2 ? n29 : n; // ======================================================================== // 単一バッファ割り当て: 全 scratch/作業領域を 1 回のヒープ割り当てに統合 // 従来は 15-18 回の std::vector 割り当てが発生していたのを 1 回に削減 // ======================================================================== size_t sq_scratch_sz = 0, mul_scratch_sz = 0, extra_sz = 0; if (!use_avx2) { sq_scratch_sz = mpn::square_scratch_size(n); mul_scratch_sz = mpn::multiply_scratch_size(n, n); extra_sz = std::max(sq_scratch_sz, mul_scratch_sz); } // R² 計算用 scratch サイズ (n >= 2 のとき) size_t r2_scratch_sz = 0; if (!use_avx2 && n >= 2) { size_t ds1 = mpn::divide_scratch_size(n + 1, n); size_t ds2 = mpn::divide_scratch_size(2 * n, n); size_t msz = mpn::multiply_scratch_size(n, n); r2_scratch_sz = std::max({ds1, ds2, msz}); } // redc_n 用: n >= REDC_N_THRESHOLD のとき Hensel 逆元を事前計算 bool use_redc_n = (!use_avx2 && n >= REDC_N_THRESHOLD); size_t inv_n_sz = use_redc_n ? n : 0; // m_inv_n バッファ size_t inv_scratch_sz = use_redc_n ? 3 * n : 0; // hensel_inverse scratch // バッファレイアウト (scalar パス): // [0..2n] scratch (mont_mul/sqr の product + REDC) // [2n+1..2n+extra] sq_scratch / mul_scratch // [+n] R2 // [+n] baseR // [+n] oneR // [+oddTableSize*n] g_buf (ウィンドウテーブル) // [+n] base2R // [+n] result // [+n] temp // [+n] base_padded (encode 用) // [+inv_n_sz] m_inv_n (redc_n 用 Hensel 逆元) // [+n+1+2+n+2n+r2_scr] R² 計算用一時領域 // [+inv_scratch_sz] hensel_inverse scratch (一時使用、R²計算後に不要) size_t mont_area = 2 * n + 1 + extra_sz; size_t buf_total = mont_area + n // R2 + n // baseR + n // oneR + static_cast(oddTableSize) * n // g_buf + n // base2R + n // result + n // temp + n // base_padded / one_padded + inv_n_sz // m_inv_n + (n >= 2 ? (n + 1) + 2 + n + 2 * n + r2_scratch_sz : 0) + inv_scratch_sz; // hensel scratch (一時) std::vector arena(buf_total, 0); uint64_t* ap = arena.data(); // 各領域へのポインタ割り当て uint64_t* scratch_ptr = ap; ap += mont_area; uint64_t* sq_scratch_p = scratch_ptr + 2 * n + 1; uint64_t* mul_scratch_p = scratch_ptr + 2 * n + 1; uint64_t* R2_p = ap; ap += n; uint64_t* baseR_p = ap; ap += n; uint64_t* oneR_p = ap; ap += n; uint64_t* g_buf_p = ap; ap += static_cast(oddTableSize) * n; uint64_t* base2R_p = ap; ap += n; uint64_t* result_p = ap; ap += n; uint64_t* temp_p = ap; ap += n; uint64_t* pad_p = ap; ap += n; // redc_n 用 Hensel 逆元 uint64_t* m_inv_n_p = use_redc_n ? ap : nullptr; if (use_redc_n) ap += inv_n_sz; // R² 計算用一時領域 (n >= 2) uint64_t* r2_R_buf = ap; // n+1 words uint64_t* r2_q = r2_R_buf + n + 1; // 2 words uint64_t* r2_Rmod = r2_q + 2; // n words uint64_t* r2_full = r2_Rmod + n; // 2n words uint64_t* r2_work = r2_full + 2 * n; // r2_scratch_sz words // R² mod m を事前計算 if (!use_avx2) { if (n == 1) { auto [q_hi, Rmod1] = UInt128::divmod_fast(1ULL, 0ULL, mdata[0]); uint64_t hi, lo; #if defined(_MSC_VER) lo = _umul128(Rmod1, Rmod1, &hi); #else __uint128_t prod = static_cast<__uint128_t>(Rmod1) * Rmod1; hi = static_cast(prod >> 64); lo = static_cast(prod); #endif auto [q2, R2mod1] = UInt128::divmod_fast(hi, lo, mdata[0]); R2_p[0] = R2mod1; } else { // R = 2^(64n) r2_R_buf[n] = 1; // 他は arena のゼロ初期化済み // R mod m mpn::divide(r2_q, r2_Rmod, r2_R_buf, n + 1, mdata, n, r2_work); // (R mod m)² mod m size_t rmod_n = n; while (rmod_n > 0 && r2_Rmod[rmod_n - 1] == 0) --rmod_n; if (rmod_n == 0) rmod_n = 1; mpn::multiply(r2_full, r2_Rmod, rmod_n, r2_Rmod, rmod_n, r2_work); size_t r2fn = 2 * rmod_n; while (r2fn > 0 && r2_full[r2fn - 1] == 0) --r2fn; if (r2fn == 0) r2fn = 1; if (r2fn <= n) { std::memcpy(R2_p, r2_full, r2fn * sizeof(uint64_t)); } else { mpn::divide(r2_q, R2_p, r2_full, r2fn, mdata, n, r2_work); } } } // redc_n 用 Hensel 逆元の事前計算 if (use_redc_n) { // arena 末尾の inv_scratch 領域を一時バッファとして使用 uint64_t* inv_scratch = arena.data() + buf_total - inv_scratch_sz; hensel_inverse(m_inv_n_p, mdata, n, m_inv, inv_scratch); } // base を Montgomery 形式に変換 { Int bmod = IntModular::mod(base, m); size_t bn = bmod.size(); const uint64_t* bdata = bmod.data(); #if defined(_MSC_VER) && defined(_M_X64) if (use_avx2) { r29::mont_encode_r29(baseR_p, n29, bdata, bn, mdata, n); } else #endif { std::memset(pad_p, 0, n * sizeof(uint64_t)); if (bn > 0) std::memcpy(pad_p, bdata, std::min(bn, n) * sizeof(uint64_t)); mont_mul(baseR_p, pad_p, R2_p, mdata, n, m_inv, mip1, scratch_ptr, mul_scratch_p); } } // 1R = Mont(1, R²) = R mod m { #if defined(_MSC_VER) && defined(_M_X64) if (use_avx2) { uint64_t one = 1; r29::mont_encode_r29(oneR_p, n29, &one, 1, mdata, n); } else #endif { std::memset(pad_p, 0, n * sizeof(uint64_t)); pad_p[0] = 1; mont_mul(oneR_p, pad_p, R2_p, mdata, n, m_inv, mip1, scratch_ptr, mul_scratch_p); } } // === Sliding window: 奇数べきのみ事前計算 === uint64_t** g = static_cast(alloca(oddTableSize * sizeof(uint64_t*))); for (int ii = 0; ii < oddTableSize; ++ii) g[ii] = g_buf_p + static_cast(ii) * limb_n; // ラムダ: mont_sqr / mont_mul のディスパッチ auto do_sqr = [&](uint64_t* dst, const uint64_t* src) { #if defined(_MSC_VER) && defined(_M_X64) if (use_avx2) { r29::mont_sqr_r29(dst, src, m29_buf.data(), n29, m_inv29, avx2_work.data()); return; } #endif mont_sqr(dst, src, mdata, n, m_inv, mip1, scratch_ptr, sq_scratch_p, m_inv_n_p, mul_scratch_p); }; auto do_mul = [&](uint64_t* dst, const uint64_t* a, const uint64_t* b) { #if defined(_MSC_VER) && defined(_M_X64) if (use_avx2) { r29::mont_mul_r29(dst, a, b, m29_buf.data(), n29, m_inv29, avx2_work.data()); return; } #endif mont_mul(dst, a, b, mdata, n, m_inv, mip1, scratch_ptr, mul_scratch_p, m_inv_n_p); }; // g[0] = baseR std::memcpy(g[0], baseR_p, limb_n * sizeof(uint64_t)); if (oddTableSize > 1) { do_sqr(base2R_p, baseR_p); for (int ii = 1; ii < oddTableSize; ++ii) do_mul(g[ii], g[ii - 1], base2R_p); } // === Left-to-right sliding window exponentiation === std::memcpy(result_p, oneR_p, limb_n * sizeof(uint64_t)); int i = static_cast(expBits) - 1; while (i >= 0) { if (!exp.getBit(i)) { do_sqr(temp_p, result_p); std::swap(result_p, temp_p); --i; } else { int j = (i - w + 1 > 0) ? (i - w + 1) : 0; while (!exp.getBit(j)) ++j; int wval = 0; for (int k = i; k >= j; --k) wval = (wval << 1) | (exp.getBit(k) ? 1 : 0); int sqr_count = i - j + 1; for (int s = 0; s < sqr_count; ++s) { do_sqr(temp_p, result_p); std::swap(result_p, temp_p); } do_mul(temp_p, result_p, g[(wval - 1) / 2]); std::swap(result_p, temp_p); i = j - 1; } } // === Montgomery 形式から通常形式に戻す === #if defined(_MSC_VER) && defined(_M_X64) if (use_avx2) { std::vector r64(n, 0); r29::mont_redc_final_r29(r64.data(), n, result_p, m29_buf.data(), n29, m_inv29, avx2_work.data()); while (n > 0 && r64[n - 1] == 0) --n; if (n == 0) return Int::Zero(); return Int::fromRawWords(r64, 1); } #endif { std::memset(scratch_ptr, 0, (2 * n + 1) * sizeof(uint64_t)); std::memcpy(scratch_ptr, result_p, n * sizeof(uint64_t)); mont_redc_2(temp_p, scratch_ptr, mdata, n, m_inv, mip1); size_t rn = n; while (rn > 0 && temp_p[rn - 1] == 0) --rn; if (rn == 0) return Int::Zero(); std::vector rwords(temp_p, temp_p + rn); return Int::fromRawWords(rwords, 1); } } // 定数時間冪剰余: base^exp mod m (GMP mpz_powm_sec 相当) // - 固定ウィンドウ幅 (スライディングウィンドウではなく) // - 全テーブルエントリを毎回読み、条件付き選択 (cmov 相当) // - ビット 0/1 に関わらず sqr+mul を実行し、条件付きコピーで結果を選択 Int mont_power_mod_sec(const Int& base, const Int& exp, const Int& m) { size_t n = m.size(); const uint64_t* mdata = m.data(); uint64_t m_inv = mont_neg_inv(mdata[0]); uint64_t mip1 = (n >= 2) ? mont_neg_inv2(mdata[0], mdata[1], m_inv) : 0; size_t expBits = exp.bitLength(); int w = choose_window_width(expBits); int tableSize = 1 << w; // 全エントリ (0..2^w-1) // scratch バッファ size_t sq_scratch_sz = mpn::square_scratch_size(n); size_t mul_scratch_sz = mpn::multiply_scratch_size(n, n); size_t extra_sz = std::max(sq_scratch_sz, mul_scratch_sz); std::vector scratch(2 * n + 1 + extra_sz, 0); uint64_t* sq_scratch = scratch.data() + 2 * n + 1; uint64_t* mul_scratch_ptr = scratch.data() + 2 * n + 1; // R² mod m 事前計算 std::vector R2(n, 0); { Int mInt = Int::fromRawWords(std::vector(mdata, mdata + n), 1); Int R_int = Int::One(); IntOps::leftShift(R_int, static_cast(64 * n)); Int R_mod = R_int % mInt; Int R2_mod = (R_mod * R_mod) % mInt; size_t r2n = R2_mod.size(); if (r2n > 0) std::memcpy(R2.data(), R2_mod.words().data(), std::min(r2n, n) * sizeof(uint64_t)); } // base → Montgomery 形式 std::vector baseR(n, 0); { Int bmod = IntModular::mod(base, m); std::vector base_padded(n, 0); size_t bn = bmod.size(); if (bn > 0) { auto bw = bmod.words(); std::memcpy(base_padded.data(), bw.data(), std::min(bn, n) * sizeof(uint64_t)); } mont_mul(baseR.data(), base_padded.data(), R2.data(), mdata, n, m_inv, mip1, scratch.data(), mul_scratch_ptr); } // 1R = Montgomery 形式の 1 std::vector oneR(n, 0); { std::vector one_padded(n, 0); one_padded[0] = 1; mont_mul(oneR.data(), one_padded.data(), R2.data(), mdata, n, m_inv, mip1, scratch.data(), mul_scratch_ptr); } // === 固定ウィンドウ: 全べき事前計算 === // g[i] = base^i * R mod m (i = 0..2^w-1) std::vector g_buf(static_cast(tableSize) * n); std::vector g(tableSize); for (int i = 0; i < tableSize; ++i) g[i] = g_buf.data() + static_cast(i) * n; // g[0] = 1R std::memcpy(g[0], oneR.data(), n * sizeof(uint64_t)); // g[1] = baseR std::memcpy(g[1], baseR.data(), n * sizeof(uint64_t)); // g[i] = g[i-1] * baseR for (int i = 2; i < tableSize; ++i) mont_mul(g[i], g[i - 1], baseR.data(), mdata, n, m_inv, mip1, scratch.data(), mul_scratch_ptr); // === 定数時間テーブルルックアップ === // 全エントリを読み、idx == target のものだけ OR で選択 auto ct_select = [&](uint64_t* dst, int idx) { std::memset(dst, 0, n * sizeof(uint64_t)); for (int i = 0; i < tableSize; ++i) { uint64_t mask = static_cast(-(static_cast(i == idx))); for (size_t j = 0; j < n; ++j) dst[j] |= g[i][j] & mask; } }; // === 定数時間条件付きコピー === auto ct_cond_copy = [&](uint64_t* dst, const uint64_t* src, bool cond) { uint64_t mask = static_cast(-(static_cast(cond))); uint64_t nmask = ~mask; for (size_t j = 0; j < n; ++j) dst[j] = (dst[j] & nmask) | (src[j] & mask); }; // === 固定ウィンドウ冪剰余 (left-to-right) === // 指数を w ビットずつ処理。パディングして w の倍数にする。 size_t padded_bits = ((expBits + w - 1) / w) * w; std::vector result(n); std::vector temp(n); std::vector sel(n); std::memcpy(result.data(), oneR.data(), n * sizeof(uint64_t)); for (size_t pos = padded_bits; pos >= static_cast(w); pos -= w) { // w 回の二乗 for (int s = 0; s < w; ++s) { mont_sqr(temp.data(), result.data(), mdata, n, m_inv, mip1, scratch.data(), sq_scratch); std::memcpy(result.data(), temp.data(), n * sizeof(uint64_t)); } // ウィンドウ値の抽出 int wval = 0; for (int b = 0; b < w; ++b) { int bit_idx = static_cast(pos) - w + b; int bit = (bit_idx >= 0 && bit_idx < static_cast(expBits)) ? (exp.getBit(bit_idx) ? 1 : 0) : 0; wval |= bit << b; } // 定数時間テーブル選択 ct_select(sel.data(), wval); // 常に乗算を実行 mont_mul(temp.data(), result.data(), sel.data(), mdata, n, m_inv, mip1, scratch.data(), mul_scratch_ptr); std::memcpy(result.data(), temp.data(), n * sizeof(uint64_t)); } // === Montgomery → 通常形式 === std::memset(scratch.data(), 0, (2 * n + 1) * sizeof(uint64_t)); std::memcpy(scratch.data(), result.data(), n * sizeof(uint64_t)); mont_redc_2(temp.data(), scratch.data(), mdata, n, m_inv, mip1); std::vector rwords(temp.data(), temp.data() + n); while (!rwords.empty() && rwords.back() == 0) rwords.pop_back(); if (rwords.empty()) return Int::Zero(); return Int::fromRawWords(rwords, 1); } } // anonymous namespace // 正規化剰余(Mathematica の Mod[] と同じ動作) Int IntModular::mod(const Int& x, const Int& m) { // 1. 特殊状態の処理 if (x.isNaN() || m.isNaN()) { return Int::NaN(); } if (x.isInfinite() || m.isInfinite()) { return Int::NaN(); } if (m.isZero()) { // ゼロ除算 return Int::NaN(); } // 2. m が負の場合の処理 // mod(x, -m) = -mod(-x, m) if (m.getSign() < 0) { return -mod(-x, -m); } // 3. x が負の場合の処理 // mod(-a, m) = m - mod(a, m) (ただし mod(a,m)=0 なら 0) if (x.getSign() < 0) { Int t = mod(-x, m); if (t.isZero()) { return t; } else { return m - t; } } // 4. 正の値同士の通常の剰余 // x % m を使う(C++のoperator%) return x % m; } // 冪剰余: base^exp mod m (Binary exponentiation) Int IntModular::powerMod(const Int& base, const Int& exp, const Int& m) { // 1. 特殊状態の処理 if (base.isNaN() || exp.isNaN() || m.isNaN()) { return Int::NaN(); } if (base.isInfinite() || exp.isInfinite() || m.isInfinite()) { return Int::NaN(); } if (m.isZero()) { // ゼロ除算 return Int::NaN(); } // 2. m = 1 の特殊ケース(任意の数 mod 1 = 0) if (m.isOne()) { return Int::Zero(); } // 3. 指数が 0 の場合: base^0 = 1 if (exp.isZero()) { return Int::One(); } // 4. 底が 0 の場合: 0^n = 0 (n > 0) if (base.isZero()) { return Int::Zero(); } // 5. 負の指数の処理: base^(-n) = (base^n)^(-1) mod m if (exp.getSign() < 0) { // まず base^|exp| mod m を計算 Int pos_result = powerMod(base, -exp, m); // その逆元を返す return inverseMod(pos_result, m); } // 6. m が奇数なら Montgomery 冪剰余を使用 (除算不要で高速) if (m.getBit(0)) { return mont_power_mod(base, exp, m); } // 7. m が偶数: 従来の Binary exponentiation (right-to-left method) // 入口で特殊状態チェック済み → Unchecked 版を使用 Int x = mod(base, m); Int result = Int::One(); Int e = exp; Int temp; while (!e.isZero()) { if (e.getBit(0)) { IntOps::mulUnchecked(result, x, temp); result = mod(temp, m); } IntOps::square(x, temp); x = mod(temp, m); IntOps::rightShift(e, 1, e); } return result; } // 乗法的逆元: a^(-1) mod m Int IntModular::inverseMod(const Int& a, const Int& m, bool is_prime) { // 1. 特殊状態の処理 if (a.isNaN() || m.isNaN()) { return Int::NaN(); } if (a.isInfinite() || m.isInfinite()) { return Int::NaN(); } if (m.isZero()) { // ゼロ除算 return Int::NaN(); } // 2. m = 1 の特殊ケース(任意の数 mod 1 = 0、逆元なし) if (m.isOne()) { return Int::Zero(); } // 3. a = 0 の場合、逆元は存在しない if (a.isZero()) { return Int::Zero(); } // 4. a を正規化(0 <= a < m の範囲に) Int a_norm = mod(a, m); if (a_norm.isZero()) { return Int::Zero(); // a ≡ 0 mod m なので逆元なし } // 5. a = 1 の場合、逆元は 1 if (a_norm.isOne()) { return Int::One(); } // 6. m が素数の場合の最適化: a^(-1) ≡ a^(m-2) mod m (フェルマーの小定理) if (is_prime) { return powerMod(a_norm, m - 2, m); } // 7. 一般の場合: 拡張ユークリッド互除法 // a * x + m * y = gcd(a, m) を解く Int x, y; Int g = IntGCD::extendedGcd(a_norm, m, x, y); // 8. gcd(a, m) != 1 なら逆元は存在しない if (!g.isOne()) { return Int::Zero(); // 逆元なし } // 9. x を正規化して返す(0 <= x < m) Int result = mod(x, m); return result; } // 定数時間冪剰余: base^exp mod m (GMP mpz_powm_sec 相当) Int IntModular::powerModSec(const Int& base, const Int& exp, const Int& m) { // 1. 特殊状態 if (base.isNaN() || exp.isNaN() || m.isNaN()) return Int::NaN(); if (base.isInfinite() || exp.isInfinite() || m.isInfinite()) return Int::NaN(); if (m.isZero()) return Int::NaN(); // 2. m = 1 → 0 if (m.isOne()) return Int::Zero(); // 3. exp = 0 → 1 if (exp.isZero()) return Int::One(); // 4. 負の指数は不可 (定数時間の制約) if (exp.getSign() < 0) return Int::NaN(); // 5. m は奇数正整数でなければならない if (!m.getBit(0)) { // 偶数 m → powerMod にフォールバック return powerMod(base, exp, m); } // 6. Montgomery 定数時間冪剰余 return mont_power_mod_sec(base, exp, m); } } // namespace calx