// Copyright (C) 2026 Kiyotsugu Arai // SPDX-License-Identifier: LGPL-3.0-or-later // IntSqrt.cpp // 整数平方根と平方数判定の実装 #include #include #include #include #include #include #include #include namespace calx { // ======================================================================== // メイン sqrt 関数 // ======================================================================== Int IntSqrt::sqrt(const Int& value) { // 特殊状態の処理 if (value.isNaN()) [[unlikely]] { return Int::NaN(); } if (value.isNegative()) [[unlikely]] { // 負の数の平方根は NaN Int result = Int::NaN(); result.setState(NumericState::NaN, NumericError::NegativeSqrt); return result; } if (value.getState() == NumericState::PositiveInfinity) [[unlikely]] { return Int::PositiveInfinity(); } if (value.isZero()) [[unlikely]] { return Int::Zero(); } if (value.isOne()) [[unlikely]] { return Int::One(); } // 1 limb fast path: isqrt(uint64_t) を直接計算 size_t an = value.size(); if (an == 1) { uint64_t v = value.word(0); uint64_t s = static_cast(std::sqrt(static_cast(v))); // Newton 補正 (double の丸め誤差) if (s > 0 && s * s > v) s--; if ((s + 1) * (s + 1) <= v) s++; return Int(s); } // 2+ limbs: mpn レベル Newton sqrt { // value.data() を直接参照 (コピー不要 — sqrtrem は ap を書き換えないため) const uint64_t* ap = value.data(); size_t sn = (an + 1) / 2; // スタックバッファで小サイズの arena アクセスを回避 constexpr size_t STACK_LIMIT = 64; uint64_t sp_stack[STACK_LIMIT]; size_t scratch_sz = mpn::sqrtrem_scratch_size(an); uint64_t scratch_stack[STACK_LIMIT * 4]; uint64_t* sp; uint64_t* scratch; ScratchScope scope; if (sn <= STACK_LIMIT && scratch_sz <= STACK_LIMIT * 4) { sp = sp_stack; scratch = scratch_stack; } else { sp = getThreadArena().alloc_limbs(sn); scratch = getThreadArena().alloc_limbs(scratch_sz); } size_t result_n = mpn::sqrtrem(sp, nullptr, ap, an, scratch); if (result_n == 0) return Int::Zero(); return Int::fromRawWords(std::span(sp, result_n), 1); } } // ======================================================================== // sqrt_small: double 精度での平方根計算 // ======================================================================== Int IntSqrt::sqrt_small(const Int& value) { // MSB < 52 ビットの場合、double で正確に計算できる double x = value.toDouble(); double sqrt_x = std::sqrt(x); uint64_t result = static_cast(sqrt_x); return Int(result); } // ======================================================================== // sqrtRem: 平方根と余りを返す // ======================================================================== Int IntSqrt::sqrtRem(const Int& value, Int& remainder) { // 特殊状態の処理 if (value.isNaN()) { remainder = Int::NaN(); return Int::NaN(); } if (value.isNegative()) { Int result = Int::NaN(); result.setState(NumericState::NaN, NumericError::NegativeSqrt); remainder = Int::NaN(); return result; } if (value.getState() == NumericState::PositiveInfinity) { remainder = Int::NaN(); // 余りは定義できない return Int::PositiveInfinity(); } if (value.isZero()) { remainder = Int::Zero(); return Int::Zero(); } if (value.isOne()) { remainder = Int::Zero(); return Int::One(); } // mpn レベル sqrtrem で平方根と余りを同時に計算 size_t bit_length = value.bitLength(); if (bit_length < 52) { Int s = sqrt_small(value); Int s_sq; IntOps::square(s, s_sq); remainder = value - s_sq; return s; } { size_t an = value.size(); ScratchScope scope; uint64_t* ap = getThreadArena().alloc_limbs(an); for (size_t i = 0; i < an; i++) ap[i] = value.word(i); size_t sn = (an + 1) / 2; uint64_t* sp = getThreadArena().alloc_limbs(sn); uint64_t* rem_p = getThreadArena().alloc_limbs(an); std::memset(rem_p, 0, an * sizeof(uint64_t)); size_t scratch_sz = mpn::sqrtrem_scratch_size(an); uint64_t* scratch = getThreadArena().alloc_limbs(scratch_sz); size_t result_n = mpn::sqrtrem(sp, rem_p, ap, an, scratch); size_t rem_n = mpn::normalized_size(rem_p, an); if (rem_n == 0) { remainder = Int::Zero(); } else { remainder = Int::fromRawWords(std::span(rem_p, rem_n), 1); } if (result_n == 0) return Int::Zero(); return Int::fromRawWords(std::span(sp, result_n), 1); } } // ======================================================================== // isSquare: 完全平方数判定 // GMP perfsqr.c に学んだ 3 層フィルタ: // Filter 1: up[0] % 256 ビットテーブル (O(1), 82.81% 棄却) // Filter 2: mod_34lsub1 + 小素数の平方剰余チェック (O(n), ~97.81% 追加棄却) // Filter 3: sqrtRem で最終検証 // 合計: 99.62% の非平方数を sqrtrem 前に棄却 // ======================================================================== // ================================================================ // consteval によるコンパイル時テーブル生成 // ================================================================ // mod p の平方剰余ビットマスクをコンパイル時に生成 (p < 64) consteval uint64_t make_qr_mask(unsigned p) { uint64_t mask = 0; for (unsigned i = 0; i < p; i++) mask |= 1ULL << ((i * i) % p); return mask; } // mod p の平方剰余ビットマスクをコンパイル時に生成 (p >= 64, 2 limb) struct QRMask128 { uint64_t lo; uint64_t hi; }; consteval QRMask128 make_qr_mask_wide(unsigned p) { uint64_t lo = 0, hi = 0; for (unsigned i = 0; i < p; i++) { unsigned r = (i * i) % p; if (r < 64) lo |= 1ULL << r; else hi |= 1ULL << (r - 64); } return {lo, hi}; } // up[0] % 256 の平方剰余ビットテーブル (GMP perfsqr.h の sq_res_0x100 相当) // ビットが 1 なら平方剰余の可能性あり、0 なら確実に非平方数 consteval std::array make_sq_res_256() { std::array t{}; for (unsigned i = 0; i < 256; i++) { unsigned r = (i * i) & 0xFF; t[r / 64] |= 1ULL << (r % 64); } return t; } static constexpr auto SQ_RES_256 = make_sq_res_256(); // 各素数の平方剰余マスク (コンパイル時検証可能) static constexpr uint64_t QR_9 = make_qr_mask(9); // 棄却率 55.56% static constexpr uint64_t QR_5 = make_qr_mask(5); // 棄却率 40% static constexpr uint64_t QR_7 = make_qr_mask(7); // 棄却率 42.86% static constexpr uint64_t QR_13 = make_qr_mask(13); // 棄却率 46.15% static constexpr uint64_t QR_17 = make_qr_mask(17); // 棄却率 47.06% static constexpr auto QR_97 = make_qr_mask_wide(97); // 棄却率 49.48% // mod_34lsub1 の結果 r を小素数で平方剰余チェック // 2^48-1 の素因数 (3,5,7,13,17,97) で検査 static bool check_sq_residue_mod34(uint64_t r) { constexpr uint64_t M48 = (1ULL << 48) - 1; r = (r & M48) + (r >> 48); if (((QR_9 >> (r % 9)) & 1) == 0) return false; if (((QR_5 >> (r % 5)) & 1) == 0) return false; if (((QR_7 >> (r % 7)) & 1) == 0) return false; if (((QR_13 >> (r % 13)) & 1) == 0) return false; if (((QR_17 >> (r % 17)) & 1) == 0) return false; { uint64_t r97 = r % 97; if (r97 < 64) { if (((QR_97.lo >> r97) & 1) == 0) return false; } else { if (((QR_97.hi >> (r97 - 64)) & 1) == 0) return false; } } return true; } bool IntSqrt::isSquare(const Int& value, Int* pSqrt) { // 特殊状態の処理 if (value.isNaN()) [[unlikely]] { if (pSqrt) *pSqrt = Int::NaN(); return false; } if (value.isInfinite()) [[unlikely]] { if (pSqrt) *pSqrt = Int::NaN(); return false; } // 負の数は平方数ではない if (value.isNegative()) [[unlikely]] { return false; } // 0 と 1 は平方数 if (value.isZero()) [[unlikely]] { if (pSqrt) *pSqrt = Int::Zero(); return true; } if (value.isOne()) [[unlikely]] { if (pSqrt) *pSqrt = Int::One(); return true; } // Filter 1: up[0] % 256 ビットテーブル (O(1), 82.81% 棄却) { unsigned idx = static_cast(value.word(0) & 0xFF); if (((SQ_RES_256[idx / 64] >> (idx % 64)) & 1) == 0) return false; } // Filter 2: mod_34lsub1 ベースの平方剰余フィルタ (O(n), ~97.81% 追加棄却) { std::vector words = value.words(); uint64_t r = mpn::mod_34lsub1(words.data(), words.size()); if (!check_sq_residue_mod34(r)) return false; } // Filter 3: sqrtRem で最終検証 (O(M(n))) Int rem; Int s = sqrtRem(value, rem); if (pSqrt) *pSqrt = s; return rem.isZero(); } Int IntSqrt::nthRoot(const Int& value, uint32_t n) { // n = 0 または n = 1 の特殊ケース if (n == 0) { // n = 0 は無効な引数 return Int::NaN(); } if (n == 1) { return value; } // n = 2 の場合は sqrt を使用 if (n == 2) { return sqrt(value); } // 特殊状態の処理 if (value.isNaN()) { return Int::NaN(); } // 負の値の処理 if (value.isNegative()) { // n が奇数の場合は負の根を返す: -root(|value|, n) if (n & 1) { Int result = nthRoot(-value, n); return -result; } else { // n が偶数の場合は NaN(負の偶数乗根は実数で存在しない) Int result = Int::NaN(); result.setState(NumericState::NaN, NumericError::NegativeSqrt); return result; } } if (value.getState() == NumericState::PositiveInfinity) { return Int::PositiveInfinity(); } if (value.isZero()) { return Int::Zero(); } if (value.isOne()) { return Int::One(); } // Precision Doubling Newton + floor 検証 // // アルゴリズム概要 (Brent & Zimmermann, Modern Computer Arithmetic §1.5.2): // 1. double 近似で ~53/n ビットの初期値を取得 // 2. 精度倍増 Newton: 各ステップで正確ビット数を約2倍に拡大 // 精度スケジュール: c_prev = ceil((c + L) / 2), L = 2*ceil(log2(n)) + 2 // 3. 最終ステップで remainder 追跡による Newton 補正 + floor 保証 // // 計算量: // PD ループ: pow(x, n-1) を各ステップで計算。等比級数の性質により // 全ステップの合計コストは最終ステップの約2倍(≈ 2·M(P)·log(n))。 // 後処理: 最終ステップの pow を Newton 補正と floor 検証に再利用。 // Newton 補正 Q > 0 の場合のみ検証用 pow を1回追加。 int B = static_cast(value.bitLength()); int ni = static_cast(n); // value < 2^n の場合、root は 1 if (B <= ni) { return Int::One(); } // --- 初期近似 (double ベース) --- Int x; { int shift = (B > 53) ? (B - 53) : 0; double m_d = (shift > 0) ? (value >> shift).toDouble() : value.toDouble(); int q = shift / ni; int r = shift % ni; double root_d = std::pow(m_d, 1.0 / n) * std::exp2(static_cast(r) / n); int64_t ri = std::max(static_cast(1), static_cast(root_d)); x = (q > 0) ? (Int(ri) << q) : Int(ri); } if (x.isZero()) x = Int::One(); const Int n_minus_1(static_cast(n - 1)); const Int n_int(static_cast(n)); int P = (B + ni - 1) / ni; // root の目標ビット精度 int x_bits = static_cast(x.bitLength()); // double 近似の正確なビット数 int P0 = std::min(x_bits, std::max(53 / ni - 1, 4)); if (P > P0 + 10) { // === Precision Doubling Newton (Brent-Zimmermann 式) === // ceil(log2(n)) int logk = 0; { unsigned tmp = n - 1; while (tmp > 0) { logk++; tmp >>= 1; } } int L = 2 * logk + 2; // 精度スケジュールを P から逆向きに構築 int sizes[64]; int ns = 0; sizes[0] = P; while (sizes[ns] > P0 && ns < 60) { int next = (sizes[ns] + L + 1) / 2; if (next >= sizes[ns]) break; sizes[ns + 1] = next; ns++; } // 正順に反転: sizes[0]=最小, sizes[ns]=P for (int i = 0, j = ns; i < j; i++, j--) { std::swap(sizes[i], sizes[j]); } // 初期精度に切り詰め if (x_bits > sizes[0]) { x = x >> (x_bits - sizes[0]); } int cur_prec = static_cast(x.bitLength()); // --- PD Newton ループ: step 1 .. ns-1 --- for (int step = 1; step < ns; step++) { int new_prec = sizes[step]; int extend = new_prec - cur_prec; if (extend > 0) x = x << extend; int new_scale = P - new_prec; int a_shift = ni * new_scale; Int a = (a_shift > 0) ? (value >> a_shift) : value; Int xpm1 = pow(x, n - 1); Int quot = a / xpm1; x = (n_minus_1 * x + quot) / n_int; cur_prec = static_cast(x.bitLength()); } // --- 最終ステップ: remainder 追跡 + Newton 補正 + floor 検証 --- // // PD ループの最終ステップで pow(x, n-1) を計算し、 // remainder R = value - x^n を得る。R >= 0 を保証した上で // Newton 補正 Q = R / (n·x^(n-1)) を適用。 // Q == 0 なら追加 pow 不要で確定、Q > 0 なら検証用 pow を1回実行。 { int extend = sizes[ns] - cur_prec; if (extend > 0) x = x << extend; Int xpm1 = pow(x, n - 1); Int R = value - xpm1 * x; // オーバーシュート補正(高々 1-2 回) for (int adj = 0; adj < 3 && R.isNegative(); adj++) { x = x - Int::One(); xpm1 = pow(x, n - 1); R = value - xpm1 * x; } // Newton 補正 Int D = n_int * xpm1; Int Q = R / D; if (Q.isZero()) { // R >= 0 かつ R < D = n·x^(n-1) < (x+1)^n - x^n // → x^n <= value < (x+1)^n → x = floor(value^(1/n)) return x; } x = x + Q; // Q > 0: 検証用 pow(Newton 補正で x が変化したため) xpm1 = pow(x, n - 1); R = value - xpm1 * x; if (R.isNegative()) { return x - Int::One(); } if (R >= n_int * xpm1) { return x + Int::One(); } return x; } } else { // 小さい root: 標準 Newton(PD 不要) constexpr int MAX_ITER = 200; for (int iter = 0; iter < MAX_ITER; iter++) { Int xpm1 = pow(x, n - 1); Int quot = value / xpm1; Int x_new = (n_minus_1 * x + quot) / n_int; if (x_new == x) break; Int diff = (x_new > x) ? (x_new - x) : (x - x_new); if (diff <= Int::One()) { if (x_new < x) x = x_new; break; } x = x_new; } } // === 小さい root 用の最終検証 === { Int xpm1 = pow(x, n - 1); Int R = value - xpm1 * x; while (R.isNegative()) { x = x - Int::One(); if (x.isZero()) return x; xpm1 = pow(x, n - 1); R = value - xpm1 * x; } if (R >= n_int * xpm1) { Int xp1_pow_n = pow(x + Int::One(), n); if (xp1_pow_n <= value) { x = x + Int::One(); } } } return x; } // ======================================================================== // nthRootRem: n乗根と余りを同時に計算 // ======================================================================== Int IntSqrt::nthRootRem(const Int& value, uint32_t n, Int& remainder) { // root を計算 Int root = nthRoot(value, n); // 特殊状態の処理 if (root.isNaN()) { remainder = Int::NaN(); return root; } if (root.isInfinite()) { remainder = Int::NaN(); // 余りは定義できない return root; } // remainder = value - root^n Int root_power = pow(root, n); remainder = value - root_power; return root; } // ============================================================================= // isPerfectPower — value = b^k (k >= 2) となる k が存在するか判定 // ============================================================================= bool IntSqrt::isPerfectPower(const Int& value, Int* pBase, uint32_t* pExp) { if (value.isNaN() || value.isInfinite()) return false; if (value.isNegative()) return false; // 負の完全冪は未対応 if (value.isZero()) { if (pBase) *pBase = Int(0); if (pExp) *pExp = 2; return true; // 0 = 0^2 } if (value == Int::One()) { if (pBase) *pBase = Int(1); if (pExp) *pExp = 2; return true; // 1 = 1^2 } size_t bits = value.bitLength(); // k の上限: 2^k <= value なので k <= bitLength uint32_t max_k = static_cast(bits); // 素数の指数のみ試行すれば十分 (e.g., a^6 = (a^2)^3 = (a^3)^2) // 小さい素数から順に試行 static const uint32_t primes[] = { 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61 }; for (uint32_t p : primes) { if (p > max_k) break; Int root = nthRoot(value, p); if (pow(root, p) == value) { if (pBase) *pBase = root; if (pExp) *pExp = p; return true; } } return false; } } // namespace calx