// Copyright (C) 2026 Kiyotsugu Arai // SPDX-License-Identifier: LGPL-3.0-or-later // IntOps.hpp // 多倍長整数の操作に関するユーティリティ #ifndef CALX_INT_OPS_HPP #define CALX_INT_OPS_HPP #include #include namespace calx { /** * @brief 多倍長整数の操作に関するユーティリティクラス * * このクラスは、Intクラスの内部値を操作するための静的メソッドを提供します。 * 主に算術演算やビット演算の実装を含みます。 */ class IntOps { public: // 比較操作 static bool compareAbsLess(const Int& lhs, const Int& rhs); static bool compareAbsGreater(const Int& lhs, const Int& rhs); static bool compareAbsEqual(const Int& lhs, const Int& rhs); // 加算・減算操作 static void addAbsolute(Int& result, const Int& other); static void addAbsolute(const Int& lhs, const Int& rhs, Int& result); static void subtractAbsolute(const Int& lhs, const Int& rhs, Int& result); static void subtractAbsoluteInPlace(Int& result, const Int& other); // インクリメント/デクリメント: Int(1) 構築 + operator+/- を回避 // delta = +1 or -1 static void addDelta(Int& value, int delta); // 3引数版 add/sub/mul/div: result のバッファを再利用 (GMP mpz_add/sub/mul/tdiv_q 相当) // result は lhs, rhs と同一オブジェクトでも可 static void add(const Int& lhs, const Int& rhs, Int& result); static void sub(const Int& lhs, const Int& rhs, Int& result); static void mul(const Int& lhs, const Int& rhs, Int& result); static void div(const Int& dividend, const Int& divisor, Int& result); // チェック無し版: isSpecialState() を省略 (呼び出し元が保証) // 内部アルゴリズムのホットパスで使用 static void addUnchecked(const Int& lhs, const Int& rhs, Int& result); static void subUnchecked(const Int& lhs, const Int& rhs, Int& result); static void mulUnchecked(const Int& lhs, const Int& rhs, Int& result); static void divUnchecked(const Int& dividend, const Int& divisor, Int& result); // 単一ワード演算 (GMP _ui 相当: Int 一時オブジェクト構築を回避) // addWord: result += word (符号付き加算) static void addWord(Int& result, uint64_t word); // subWord: result -= word (符号付き減算) static void subWord(Int& result, uint64_t word); // divWord: result /= word, 余りを返す static uint64_t divWord(Int& result, uint64_t word); // 乗算操作 static void multiplyAbsolute(Int& result, const Int& other); static void multiplyAbsolute(const Int& lhs, const Int& rhs, Int& result); static void multiplyWord(Int& result, uint64_t word); static void square(const Int& value, Int& result); // NTT キャッシュ付き絶対値乗算: rhs の forward NTT をキャッシュして再利用 // 同じ rhs と複数の lhs を乗算する場面 (BS merge の QR 再利用等) で有効 static void mulAbsCached(const Int& lhs, const Int& rhs, Int& result, prime_ntt::NttCache& cache); // YC-2: Fused multiply-add: result = a*b + c*d (NTT 融合で高速化) // BS merge の T = TL*QR + PL*TR パターンを 1 回の inverse NTT + CRT で計算 // 符号処理込み (正+正, 正-負 等すべて対応) static void mulAdd(const Int& a, const Int& b, const Int& c, const Int& d, Int& result); // 除算・剰余操作 static void divideAbsolute(Int& result, const Int& divisor); static void moduloAbsolute(Int& result, const Int& divisor); static Int divmod(const Int& dividend, const Int& divisor, Int& remainder); // ビット操作 static void bitwiseAnd(Int& result, const Int& other); static void bitwiseOr(Int& result, const Int& other); static void bitwiseXor(Int& result, const Int& other); static void bitwiseNot(Int& result); static void leftShift(Int& value, int shift); static void rightShift(Int& value, int shift); // 3引数版: result = value << shift / value >> shift (バッファ再利用) static void leftShift(const Int& value, int shift, Int& result); static void rightShift(const Int& value, int shift, Int& result); // 2の累乗係数 static Int pow2(uint64_t exponent); // GCD (最大公約数) static Int gcd(const Int& a, const Int& b); // 因子除去 (removeFactor) // value から factor を全て除去し、除去した回数を返す // value = result × factor^count となる // GMP 互換: mpz_remove(rop, op, f) // 戻り値: 除去した回数 (指数) // result: 除去後の値 (value / factor^count) static uint64_t removeFactor(const Int& value, const Int& factor, Int& result); // 融合乗加算/乗減算: rop += a*b / rop -= a*b (GMP mpz_addmul/mpz_submul 相当) // 中間 Int の一時構築を最小化 static void addmul(Int& rop, const Int& a, const Int& b); static void submul(Int& rop, const Int& a, const Int& b); // 単一ワード版: rop += a*word / rop -= a*word static void addmul(Int& rop, const Int& a, uint64_t word); static void submul(Int& rop, const Int& a, uint64_t word); // 平方根 static Int sqrt(const Int& value); // 床除算(floor division) // 商を負の無限大方向に丸める除算 // 例: floorDiv(-7, 3) = -3 (通常の除算では -2) static Int floorDiv(const Int& dividend, const Int& divisor); // 床剰余(floor modulo) // 結果の符号は常に除数と同じ (Python の % と同一) // 例: floorMod(-7, 3) = 2, floorMod(7, -3) = -2 static Int floorMod(const Int& dividend, const Int& divisor); // 床除算と剰余を同時に返す // quotient = floorDiv(a, b), remainder = floorMod(a, b) static Int floorDivMod(const Int& dividend, const Int& divisor, Int& remainder); // 天井除算(ceiling division) // 商を正の無限大方向に丸める除算 // 例: ceilDiv(7, 3) = 3, ceilDiv(-7, 3) = -2 static Int ceilDiv(const Int& dividend, const Int& divisor); // 正確な除算(divisor が dividend を割り切ることが保証されている場合) // 割り切れない場合の動作は未定義(高速化のため余りチェックを省略) static Int divExact(const Int& dividend, const Int& divisor); private: // カラツバ法による乗算 static void karatsubaMultiply(const Int& a, const Int& b, Int& result); // Toom-Cook-3 による乗算 static void toomCookMultiply(const Int& a, const Int& b, Int& result); // Toom-Cook-4 による乗算 static void toomCook4Multiply(const Int& a, const Int& b, Int& result); // 汎用乗算 (アンバランス対応、サイズに応じた自動選択) static void generalMultiply(const Int& a, const Int& b, Int& result); // FFTを使用した乗算 static void fftMultiply(const Int& a, const Int& b, Int& result); // ニュートン法による除算 static void newtonDivision(const Int& dividend, const Int& divisor, Int& result); }; // ---------------------------------------------------------------- // インライン実装: 加算・減算 (3引数版) // inline にすることで、operator+/- から mpn::add/sub まで // 関数呼び出しなしで直結し、小サイズ演算のオーバーヘッドを除去。 // ---------------------------------------------------------------- inline void IntOps::addAbsolute(const Int& lhs, const Int& rhs, Int& result) { size_t lhsSize = lhs.m_words.size(); size_t rhsSize = rhs.m_words.size(); // mpn::add は an >= bn を要求 // result が lhs/rhs と同一オブジェクトの場合、resize 後のバッファを // そのまま使いインプレース加算 (不要なコピーを回避) size_t bigSize = std::max(lhsSize, rhsSize); size_t smallSize = std::min(lhsSize, rhsSize); if (&result == &lhs || &result == &rhs) { // result は lhs か rhs の一方と同一 const Int& other = (&result == &lhs) ? rhs : lhs; size_t otherSize = other.m_words.size(); size_t resultSize = result.m_words.size(); // result 側を bigSize+1 に拡張 (既存データは保持される) result.m_words.resize_uninitialized(bigSize + 1); uint64_t* rp = result.m_words.data(); const uint64_t* op = other.m_words.data(); uint64_t carry; if (resultSize >= otherSize) { // result が大きい側: r == a エイリアシング carry = mpn::add(rp, rp, resultSize, op, otherSize); } else { // result が小さい側: r == b エイリアシング // mpn::add は i < bn で b[i] を読んでから r[i] に書くので安全 carry = mpn::add(rp, op, otherSize, rp, resultSize); } if (carry) { rp[bigSize] = 1; } else { result.m_words.resize_uninitialized(bigSize); } } else { // result は lhs, rhs とは別オブジェクト result.m_words.resize_uninitialized(bigSize + 1); const uint64_t* bigData = (lhsSize >= rhsSize) ? lhs.m_words.data() : rhs.m_words.data(); const uint64_t* smallData = (lhsSize >= rhsSize) ? rhs.m_words.data() : lhs.m_words.data(); uint64_t carry = mpn::add(result.m_words.data(), bigData, bigSize, smallData, smallSize); if (carry) { result.m_words.data()[bigSize] = 1; } else { result.m_words.resize_uninitialized(bigSize); } } } inline void IntOps::subtractAbsolute(const Int& lhs, const Int& rhs, Int& result) { // 前提: |lhs| >= |rhs| (呼び出し元で保証) size_t lhsSize = lhs.m_words.size(); size_t rhsSize = rhs.m_words.size(); if (&result == &lhs) { // result == lhs: インプレース減算 (r == a エイリアシング) mpn::sub(result.m_words.data(), result.m_words.data(), lhsSize, rhs.m_words.data(), rhsSize); } else { result.m_words.resize_uninitialized(lhsSize); mpn::sub(result.m_words.data(), lhs.m_words.data(), lhsSize, rhs.m_words.data(), rhsSize); } // normalize: 先頭のゼロを除去 while (result.m_words.size() > 0 && result.m_words.back() == 0) { result.m_words.pop_back(); } if (result.m_words.empty()) { result.setSign(0); } } // ---------------------------------------------------------------- // 3引数版 add/sub: result のバッファを再利用 // GMP の mpz_add(r, a, b) / mpz_sub(r, a, b) 相当 // ---------------------------------------------------------------- // --- Unchecked 版 (isSpecialState チェック省略) --- inline void IntOps::addUnchecked(const Int& lhs, const Int& rhs, Int& result) { if (lhs.m_sign == 0) { if (&result != &rhs) result = rhs; return; } if (rhs.m_sign == 0) { if (&result != &lhs) result = lhs; return; } if (lhs.m_sign == rhs.m_sign) { addAbsolute(lhs, rhs, result); result.m_sign = lhs.m_sign; } else { int cmp = mpn::cmp(lhs.m_words.data(), lhs.m_words.size(), rhs.m_words.data(), rhs.m_words.size()); if (cmp < 0) { subtractAbsolute(rhs, lhs, result); result.m_sign = rhs.m_sign; } else if (cmp > 0) { subtractAbsolute(lhs, rhs, result); result.m_sign = lhs.m_sign; } else { result.m_words.clear(); result.m_sign = 0; } } result.m_state = NumericState::Normal; } inline void IntOps::subUnchecked(const Int& lhs, const Int& rhs, Int& result) { if (rhs.m_sign == 0) { if (&result != &lhs) result = lhs; return; } if (lhs.m_sign == 0) { if (&result != &rhs) result = rhs; result.m_sign = -result.m_sign; return; } if (lhs.m_sign != rhs.m_sign) { addAbsolute(lhs, rhs, result); result.m_sign = lhs.m_sign; } else { int cmp = mpn::cmp(lhs.m_words.data(), lhs.m_words.size(), rhs.m_words.data(), rhs.m_words.size()); if (cmp >= 0) { subtractAbsolute(lhs, rhs, result); result.m_sign = (cmp == 0) ? 0 : lhs.m_sign; } else { subtractAbsolute(rhs, lhs, result); result.m_sign = -rhs.m_sign; } } result.m_state = NumericState::Normal; } // --- チェック付き版 (薄いラッパー) --- inline void IntOps::add(const Int& lhs, const Int& rhs, Int& result) { if (lhs.isSpecialState() || rhs.isSpecialState()) [[unlikely]] { result = lhs + rhs; return; } addUnchecked(lhs, rhs, result); } inline void IntOps::sub(const Int& lhs, const Int& rhs, Int& result) { if (lhs.isSpecialState() || rhs.isSpecialState()) [[unlikely]] { result = lhs - rhs; return; } subUnchecked(lhs, rhs, result); } } // namespace calx #endif // CALX_INT_OPS_HPP