// Copyright (C) 2026 Kiyotsugu Arai // SPDX-License-Identifier: LGPL-3.0-or-later // IntMultiplication.hpp // 高速乗算アルゴリズム (Karatsuba, Toom-Cook, FFT) #pragma once #include #include #include #include #include #include #include #include namespace calx { // 乗算アルゴリズムの閾値(ワード数) // ベンチマーク結果 (2026-02) に基づいて調整済み: // - Basecase → Karatsuba: 128 ワードで 44% 高速化 // - Karatsuba → Toom-Cook: 256 ワードで 15% 高速化 // - FFT (double ベース): 全サイズで Toom-Cook より遅い (1.5x-2.8x) // → NTT 実装後に再評価予定。現在は operator* で未使用。 namespace MultiplicationThresholds { constexpr size_t KARATSUBA_THRESHOLD = 128; // 128 ワード (8K-bit) 以上で Karatsuba constexpr size_t TOOMCOOK_THRESHOLD = 256; // 256 ワード (16K-bit) 以上で Toom-Cook-3 constexpr size_t TOOMCOOK4_THRESHOLD = 640; // 640 ワード (40K-bit) 以上で Toom-Cook-4 constexpr size_t FFT_THRESHOLD = 256; // 256 ワード — 現在未使用 (FFT は Toom-Cook より遅い) } class IntMultiplication { public: // Karatsuba 乗算 (mpn ベース) // 計算量: O(n^1.585) where n = max(X.words, Y.words) // 内部は生 limb ポインタで演算し、Int 生成は入口と出口の 1 回のみ static Int karatsubaMultiply(const Int& X, const Int& Y) { int result_sign = X.getSign() * Y.getSign(); if (result_sign == 0) return Int(0); if (X.isSpecialState() || Y.isSpecialState()) { return basecaseMultiply(X, Y); } const uint64_t* a = X.m_words.data(); size_t an = X.m_words.size(); const uint64_t* b = Y.m_words.data(); size_t bn = Y.m_words.size(); // 小さい場合は basecase(mpn ベース) if (an < MultiplicationThresholds::KARATSUBA_THRESHOLD || bn < MultiplicationThresholds::KARATSUBA_THRESHOLD) { return basecaseMultiply(X, Y); } // Arena から結果バッファと scratch を確保 ScratchScope scope; size_t rn = an + bn; uint64_t* r = getThreadArena().alloc_limbs(rn); size_t scratch_size = mpn::mul_karatsuba_scratch_size(std::max(an, bn)); uint64_t* scratch = getThreadArena().alloc_limbs(scratch_size); // mpn Karatsuba 実行 mpn::mul_karatsuba(r, a, an, b, bn, scratch); // 結果を Int に変換 (ここだけ 1 回コピー) size_t actual_rn = mpn::normalized_size(r, rn); if (actual_rn == 0) return Int(0); return Int::fromRawWords(std::span(r, actual_rn), result_sign); } // Toom-Cook-3 乗算 (mpn ベース) // 計算量: O(n^1.465) where n = max(X.words, Y.words) // 評価点 {0, 1, -1, 2, ∞} (GMP 方式) // 内部は生 limb ポインタで演算し、Int 生成は入口と出口の 1 回のみ static Int toomCook3Multiply(const Int& X, const Int& Y) { int result_sign = X.getSign() * Y.getSign(); if (result_sign == 0) return Int(0); if (X.isSpecialState() || Y.isSpecialState()) { return basecaseMultiply(X, Y); } const uint64_t* a = X.m_words.data(); size_t an = X.m_words.size(); const uint64_t* b = Y.m_words.data(); size_t bn = Y.m_words.size(); if (an < MultiplicationThresholds::TOOMCOOK_THRESHOLD || bn < MultiplicationThresholds::TOOMCOOK_THRESHOLD) { return karatsubaMultiply(X, Y); } // Arena から結果バッファと scratch を確保 ScratchScope scope; size_t rn = an + bn; uint64_t* r = getThreadArena().alloc_limbs(rn); size_t scratch_size = mpn::mul_toomcook3_scratch_size(std::max(an, bn)); uint64_t* scratch = getThreadArena().alloc_limbs(scratch_size); // mpn Toom-Cook-3 実行 mpn::mul_toomcook3(r, a, an, b, bn, scratch); // 結果を Int に変換 size_t actual_rn = mpn::normalized_size(r, rn); if (actual_rn == 0) return Int(0); return Int::fromRawWords(std::span(r, actual_rn), result_sign); } // Toom-Cook-4 乗算 (mpn ベース) // 計算量: O(n^1.404) where n = max(X.words, Y.words) // 評価点 {0, 1, -1, 2, -2, 3, ∞} // 内部は生 limb ポインタで演算し、Int 生成は入口と出口の 1 回のみ static Int toomCook4Multiply(const Int& X, const Int& Y) { int result_sign = X.getSign() * Y.getSign(); if (result_sign == 0) return Int(0); if (X.isSpecialState() || Y.isSpecialState()) { return basecaseMultiply(X, Y); } const uint64_t* a = X.m_words.data(); size_t an = X.m_words.size(); const uint64_t* b = Y.m_words.data(); size_t bn = Y.m_words.size(); if (an < MultiplicationThresholds::TOOMCOOK4_THRESHOLD || bn < MultiplicationThresholds::TOOMCOOK4_THRESHOLD) { return toomCook3Multiply(X, Y); } // Arena から結果バッファと scratch を確保 ScratchScope scope; size_t rn = an + bn; uint64_t* r = getThreadArena().alloc_limbs(rn); size_t scratch_size = mpn::mul_toomcook4_scratch_size(std::max(an, bn)); uint64_t* scratch = getThreadArena().alloc_limbs(scratch_size); // mpn Toom-Cook-4 実行 mpn::mul_toomcook4(r, a, an, b, bn, scratch); // 結果を Int に変換 size_t actual_rn = mpn::normalized_size(r, rn); if (actual_rn == 0) return Int(0); return Int::fromRawWords(std::span(r, actual_rn), result_sign); } // 汎用乗算 (アンバランス乗算対応) // サイズに応じてアルゴリズムを自動選択し、 // サイズ比が大きい場合はチャンク方式で最適化 static Int multiply(const Int& X, const Int& Y) { int result_sign = X.getSign() * Y.getSign(); if (result_sign == 0) return Int(0); if (X.isSpecialState() || Y.isSpecialState()) { return basecaseMultiply(X, Y); } const uint64_t* a = X.m_words.data(); size_t an = X.m_words.size(); const uint64_t* b = Y.m_words.data(); size_t bn = Y.m_words.size(); ScratchScope scope; size_t rn = an + bn; uint64_t* r = getThreadArena().alloc_limbs(rn); size_t scratch_size = mpn::multiply_scratch_size(an, bn); uint64_t* scratch = (scratch_size > 0) ? getThreadArena().alloc_limbs(scratch_size) : nullptr; mpn::multiply(r, a, an, b, bn, scratch); size_t actual_rn = mpn::normalized_size(r, rn); if (actual_rn == 0) return Int(0); return Int::fromRawWords(std::span(r, actual_rn), result_sign); } // 積の上位 rn word のみ計算 (short multiplication) // basecase サイズ以下の場合のみ mulhigh_basecase を使用。 // それ以上は通常の全積 → 上位抽出にフォールバック。 static Int multiplyHigh(const Int& X, const Int& Y, size_t rn) { int result_sign = X.getSign() * Y.getSign(); if (result_sign == 0) return Int(0); const uint64_t* a = X.m_words.data(); size_t an = X.m_words.size(); const uint64_t* b = Y.m_words.data(); size_t bn = Y.m_words.size(); size_t total = an + bn; if (rn >= total) { return multiply(X, Y); // 全積で十分 } ScratchScope scope; uint64_t* rp = getThreadArena().alloc_limbs(rn); if (std::min(an, bn) < mpn::KARATSUBA_THRESHOLD) { // Basecase: mulhigh で下位スキップ mpn::mulhigh_basecase(rp, a, an, b, bn, rn); } else { // Karatsuba 以上: 全積を計算して上位を抽出 size_t scratch_size = mpn::multiply_scratch_size(an, bn); uint64_t* full = getThreadArena().alloc_limbs(total); uint64_t* scratch = (scratch_size > 0) ? getThreadArena().alloc_limbs(scratch_size) : nullptr; mpn::multiply(full, a, an, b, bn, scratch); std::memcpy(rp, full + (total - rn), rn * sizeof(uint64_t)); } size_t actual_rn = mpn::normalized_size(rp, rn); if (actual_rn == 0) return Int(0); return Int::fromRawWords(std::span(rp, actual_rn), result_sign); } // FFT 乗算 // 計算量: O(n log n) where n = FFT size // 精度限界: N < 2^17 点 (約63万桁) - double (53-bit) の制約 static Int fftMultiply(const Int& X, const Int& Y) { // 符号を保存して、絶対値で計算 int result_sign = X.getSign() * Y.getSign(); // 絶対値を取得 Int absX = (X.getSign() < 0) ? -X : X; Int absY = (Y.getSign() < 0) ? -Y : Y; size_t Nx = absX.words().size(); size_t Ny = absY.words().size(); // 小さい数は Toom-Cook if (Nx < MultiplicationThresholds::FFT_THRESHOLD || Ny < MultiplicationThresholds::FFT_THRESHOLD) { Int result = toomCook3Multiply(absX, absY); if (result_sign < 0) result = -result; return result; } // Int の 64-bit ワードを 16-bit ワードに分解 // (FFT の精度を保つため、小さいワードサイズを使用) constexpr size_t WORD_BITS = 16; constexpr size_t WORDS_PER_U64 = 64 / WORD_BITS; // = 4 size_t x_u16_words = Nx * WORDS_PER_U64; size_t y_u16_words = Ny * WORDS_PER_U64; // FFT サイズ決定 (次の2のべき乗) size_t result_u16_words = x_u16_words + y_u16_words; int nfft = 1; while (static_cast(nfft) < result_u16_words) { nfft *= 2; } // 精度限界チェック (N ≤ 2^16) if (nfft > (1 << 16)) { // FFT の精度限界を超える場合は Toom-Cook にフォールバック Int result = toomCook3Multiply(absX, absY); if (result_sign < 0) result = -result; return result; } // FFT 初期化 FFTEngine fft(nfft, FFTDivMode::Inverse); // X を 16-bit ワード配列に変換 std::vector x_data(nfft, 0.0); auto x_words = absX.words(); for (size_t i = 0; i < Nx; ++i) { uint64_t word = x_words[i]; for (size_t j = 0; j < WORDS_PER_U64; ++j) { x_data[i * WORDS_PER_U64 + j] = static_cast(word & 0xFFFF); word >>= WORD_BITS; } } // Y を 16-bit ワード配列に変換 std::vector y_data(nfft, 0.0); auto y_words = absY.words(); for (size_t i = 0; i < Ny; ++i) { uint64_t word = y_words[i]; for (size_t j = 0; j < WORDS_PER_U64; ++j) { y_data[i * WORDS_PER_U64 + j] = static_cast(word & 0xFFFF); word >>= WORD_BITS; } } // 実 FFT 変換 auto X_freq = fft.real_transform(x_data); auto Y_freq = fft.real_transform(y_data); // 周波数領域で点ごとの乗算 std::vector> Z_freq(nfft / 2 + 1); for (int i = 0; i <= nfft / 2; ++i) { Z_freq[i] = X_freq[i] * Y_freq[i]; } // 実 FFT 逆変換 auto z_data = fft.real_inverse(Z_freq); // 16-bit ワードから 64-bit ワードへ変換 (キャリー処理) std::vector result_words; result_words.reserve((nfft / WORDS_PER_U64) + 1); uint64_t carry = 0; uint64_t current_word = 0; size_t shift = 0; for (int i = 0; i < nfft; ++i) { // 四捨五入して整数化 uint64_t val = static_cast(z_data[i] + 0.5) + carry; // 16-bit ワードとして処理 uint64_t low16 = val & 0xFFFF; carry = val >> WORD_BITS; // 64-bit ワードに詰め込む current_word |= (low16 << shift); shift += WORD_BITS; if (shift == 64) { result_words.push_back(current_word); current_word = 0; shift = 0; } } // 最後のキャリー処理 if (shift > 0) { current_word |= (carry << shift); result_words.push_back(current_word); // carry の上位ビットが shift 分溢れる場合を処理 if (shift < 64) { uint64_t remaining = carry >> (64 - shift); if (remaining > 0) { result_words.push_back(remaining); } } } else if (carry > 0) { result_words.push_back(carry); } // 上位の 0 を削除 while (!result_words.empty() && result_words.back() == 0) { result_words.pop_back(); } if (result_words.empty()) { return Int(0); } Int result = Int::fromRawWords(result_words, result_sign); return result; } // 基本乗算 (O(n²) long multiplication, mpn ベース) // 小さい数や Karatsuba の閾値以下で使用 static Int basecaseMultiply(const Int& X, const Int& Y) { int result_sign = X.getSign() * Y.getSign(); if (result_sign == 0) return Int(0); const uint64_t* a = X.m_words.data(); size_t an = X.m_words.size(); const uint64_t* b = Y.m_words.data(); size_t bn = Y.m_words.size(); if (an == 0 || bn == 0) return Int(0); std::vector product(an + bn); mpn::mul_basecase(product.data(), a, an, b, bn); size_t actual = mpn::normalized_size(product.data(), product.size()); if (actual == 0) return Int(0); product.resize(actual); return Int::fromRawWords(product, result_sign); } }; } // namespace calx