> {
static constexpr ModularInt add(const ModularInt
& a, const ModularInt
& b) noexcept {
return a + b;
}
static constexpr ModularInt
subtract(const ModularInt
& a, const ModularInt
& b) noexcept {
return a - b;
}
static constexpr ModularInt
multiply(const ModularInt
& a, const ModularInt
& b) noexcept {
return a * b;
}
static ModularInt
divide(const ModularInt
& a, const ModularInt
& b) {
return a / b;
}
static constexpr ModularInt
negate(const ModularInt
& a) noexcept {
return -a;
}
static constexpr ModularInt
zero() noexcept {
return ModularInt
(0);
}
static constexpr ModularInt
one() noexcept {
return ModularInt
(1);
}
static constexpr bool equals(const ModularInt
& a, const ModularInt
& b) noexcept {
return a == b;
}
static constexpr bool less_than(const ModularInt
& a, const ModularInt
& b) noexcept {
return a < b;
}
};
} // namespace concepts
// MKL対応
#if CALX_HAS_MKL
namespace detail {
namespace matrix_mkl {
// ModularInt用のMKL演算インターフェース(MKLは対応していないのでカスタム実装)
template
class ModularIntMklOperations {
public:
// ベクトル加算: y = a*x + y
static void axpy(int n, ModularInt alpha, const ModularInt
* x, int incx,
ModularInt
* y, int incy) {
for (int i = 0; i < n; ++i) {
y[i * incy] += alpha * x[i * incx];
}
}
// ベクトルのドット積: result = x・y
static ModularInt
dot(int n, const ModularInt
* x, int incx,
const ModularInt
* y, int incy) {
ModularInt
result(0);
for (int i = 0; i < n; ++i) {
result += x[i * incx] * y[i * incy];
}
return result;
}
// 行列乗算: C = alpha*A*B + beta*C
static void gemm(int layout, int transA, int transB, int m, int n, int k,
ModularInt
alpha, const ModularInt
* A, int lda,
const ModularInt
* B, int ldb, ModularInt
beta,
ModularInt
* C, int ldc) {
// Cに対するbetaの適用
if (beta != ModularInt
(1)) {
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
C[i*ldc + j] = beta * C[i*ldc + j];
}
}
}
// 行優先形式のみサポート(実装を簡略化)
if (layout != CblasRowMajor || transA != CblasNoTrans || transB != CblasNoTrans) {
throw MathError("Unsupported matrix layout or transposition for ModularInt");
}
// A*Bの計算とCへの加算
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
ModularInt
sum(0);
for (int l = 0; l < k; ++l) {
sum += A[i*lda + l] * B[l*ldb + j];
}
C[i*ldc + j] += alpha * sum;
}
}
}
// 追加のMKL互換関数を必要に応じて実装
};
} // namespace matrix_mkl
} // namespace detail
#endif // CALX_HAS_MKL
// ModularInt特有の数論関数
// extended_gcd, chinese_remainder_theorem, CRT クラスは CRT.hpp に定義
/**
* @brief 高速モジュラー累乗計算
* @tparam P モジュラス
* @param base 底
* @param exponent 指数
* @return 計算結果
*/
template
constexpr ModularInt mod_pow(const ModularInt
& base, int64_t exponent) {
return base.pow(exponent);
}
/**
* @brief 原始根の計算
* @tparam P 素数モジュラス
* @return Pの原始根またはnullopt(見つからない場合)
*/
template
std::optional> find_primitive_root() {
if (P <= 1) {
return std::nullopt;
}
// P-1の素因数分解
int64_t phi = P - 1;
std::vector prime_factors;
// 素因数分解(単純アルゴリズム)
int64_t n = phi;
for (int64_t i = 2; i * i <= n; ++i) {
if (n % i == 0) {
prime_factors.push_back(i);
while (n % i == 0) {
n /= i;
}
}
}
if (n > 1) {
prime_factors.push_back(n);
}
// 原始根の候補を検索
for (int g = 2; g < P; ++g) {
bool is_primitive = true;
ModularInt candidate(g);
for (int64_t prime : prime_factors) {
if (candidate.pow(phi / prime) == ModularInt
(1)) {
is_primitive = false;
break;
}
}
if (is_primitive) {
return ModularInt
(g);
}
}
return std::nullopt;
}
/**
* @brief 高速モジュラーべき乗計算(非メンバー関数)
* @tparam P モジュラス
* @param base 底
* @param exponent 指数
* @return 計算結果
*/
template
constexpr ModularInt pow_mod(int64_t base, int64_t exponent) {
return ModularInt
(base).pow(exponent);
}
/**
* @brief モジュラー逆元計算(フェルマーの小定理使用)
* @tparam P 素数モジュラス
* @param a 逆元を求める数
* @return aのモジュラー逆元
*/
template
constexpr ModularInt mod_inverse(int64_t a) {
return ModularInt
(a).inverse();
}
/**
* @brief モジュラー合同式ソルバー
* @tparam P モジュラス
* @param a 係数
* @param b 定数項
* @return ax ≡ b (mod P)の解xの配列
*/
template
std::vector> solve_congruence(int64_t a, int64_t b) {
a = (a % P + P) % P;
b = (b % P + P) % P;
std::vector> solutions;
// aとPの最大公約数を計算
int64_t g = std::gcd(a, static_cast(P));
if (b % g != 0) {
// 解なし
return solutions;
}
// a'x ≡ b' (mod P')、ここでa'=a/g, b'=b/g, P'=P/g
a /= g;
b /= g;
int64_t mod_prime = P / g;
// a'の逆元を計算
ModularInt a_inv = ModularInt
(a).inverse();
// 基本解x_0 = a^(-1) * b (mod P/g)
ModularInt
x0 = a_inv * ModularInt
(b);
// 全ての解を生成: x_k = x_0 + k*(P/g) for k = 0,1,...,g-1
for (int64_t k = 0; k < g; ++k) {
solutions.push_back(x0 + ModularInt
(k * mod_prime));
}
return solutions;
}
} // namespace calx
#endif // CALX_MODULAR_INT_TRAITS_HPP