// Copyright (C) 2026 Kiyotsugu Arai // SPDX-License-Identifier: LGPL-3.0-or-later // blas.hpp // BLAS レベル 1/2/3 インターフェース // // BLAS (Basic Linear Algebra Subprograms) 風の API を提供する。 // 将来の MKL/OpenBLAS バックエンド切り替えの基盤。 // 現時点ではネイティブ実装 (calx の Vector/Matrix を直接操作)。 #ifndef CALX_BLAS_HPP #define CALX_BLAS_HPP #include #include #include #include namespace calx { namespace blas { // ==================================================================== // Level 1: ベクトル-ベクトル演算 // ==================================================================== /** * @brief 内積: result = x^T * y */ template T dot(const Vector& x, const Vector& y) { if (x.size() != y.size()) { throw DimensionError("blas::dot: size mismatch"); } if constexpr (std::is_same_v || std::is_same_v) { return computation::simd::dot_product_simd(x.data(), y.data(), x.size()); } else { T result = T{0}; for (std::size_t i = 0; i < x.size(); ++i) result += x[i] * y[i]; return result; } } /** * @brief 2-ノルム: result = ‖x‖₂ */ template T nrm2(const Vector& x) { if constexpr (std::is_same_v || std::is_same_v) { T sum = computation::simd::dot_product_simd(x.data(), x.data(), x.size()); return std::sqrt(sum); } else { T sum = T{0}; for (std::size_t i = 0; i < x.size(); ++i) sum += x[i] * x[i]; return static_cast(std::sqrt(static_cast(sum))); } } /** * @brief 1-ノルム (絶対値の和): result = ‖x‖₁ = Σ|x_i| * * BLAS の dasum に相当。 */ template T asum(const Vector& x) { T sum = T{0}; for (std::size_t i = 0; i < x.size(); ++i) sum += static_cast(std::abs(static_cast(x[i]))); return sum; } /** * @brief 絶対値最大要素のインデックス: result = argmax_i |x_i| * * BLAS の idamax に相当。空ベクトルの場合は 0 を返す。 */ template std::size_t iamax(const Vector& x) { if (x.size() == 0) return 0; std::size_t idx = 0; double max_val = std::abs(static_cast(x[0])); for (std::size_t i = 1; i < x.size(); ++i) { double ai = std::abs(static_cast(x[i])); if (ai > max_val) { max_val = ai; idx = i; } } return idx; } /** * @brief スケーリング: x ← alpha * x */ template void scal(T alpha, Vector& x) { if constexpr (std::is_same_v || std::is_same_v) { computation::simd::scale_simd(x.data(), alpha, x.size()); } else { for (std::size_t i = 0; i < x.size(); ++i) x[i] *= alpha; } } /** * @brief AXPY: y ← alpha * x + y */ template void axpy(T alpha, const Vector& x, Vector& y) { if (x.size() != y.size()) { throw DimensionError("blas::axpy: size mismatch"); } if constexpr (std::is_same_v || std::is_same_v) { computation::simd::axpy_simd(y.data(), alpha, x.data(), x.size()); } else { for (std::size_t i = 0; i < x.size(); ++i) y[i] += alpha * x[i]; } } /** * @brief コピー: y ← x */ template void copy(const Vector& x, Vector& y) { if (x.size() != y.size()) { throw DimensionError("blas::copy: size mismatch"); } for (std::size_t i = 0; i < x.size(); ++i) y[i] = x[i]; } /** * @brief スワップ: x ↔ y */ template void swap(Vector& x, Vector& y) { if (x.size() != y.size()) { throw DimensionError("blas::swap: size mismatch"); } for (std::size_t i = 0; i < x.size(); ++i) std::swap(x[i], y[i]); } // ==================================================================== // Level 2: 行列-ベクトル演算 // ==================================================================== /** * @brief 一般行列-ベクトル積: y ← alpha * op(A) * x + beta * y * * @param trans false: op(A) = A, true: op(A) = A^T * @param alpha スカラー係数 * @param A m×n 行列 * @param x 入力ベクトル * @param beta スカラー係数 * @param y 出力ベクトル (in-place 更新) */ template void gemv(bool trans, T alpha, const Matrix& A, const Vector& x, T beta, Vector& y) { const auto m = A.rows(); const auto n = A.cols(); if (!trans) { // y ← alpha * A * x + beta * y if (x.size() != n || y.size() != m) { throw DimensionError("blas::gemv: dimension mismatch"); } if constexpr (std::is_same_v || std::is_same_v) { for (std::size_t i = 0; i < m; ++i) { T sum = computation::simd::dot_product_simd(&A(i, 0), x.data(), n); y[i] = alpha * sum + beta * y[i]; } } else { for (std::size_t i = 0; i < m; ++i) { T sum = T{0}; for (std::size_t j = 0; j < n; ++j) sum += A(i, j) * x[j]; y[i] = alpha * sum + beta * y[i]; } } } else { // y ← alpha * A^T * x + beta * y if (x.size() != m || y.size() != n) { throw DimensionError("blas::gemv: dimension mismatch (transpose)"); } if constexpr (std::is_same_v || std::is_same_v) { computation::simd::scale_simd(y.data(), beta, n); for (std::size_t i = 0; i < m; ++i) { T ax = alpha * x[i]; computation::simd::axpy_simd(y.data(), ax, &A(i, 0), n); } } else { for (std::size_t j = 0; j < n; ++j) y[j] *= beta; for (std::size_t i = 0; i < m; ++i) { T ax = alpha * x[i]; for (std::size_t j = 0; j < n; ++j) y[j] += ax * A(i, j); } } } } /** * @brief 三角行列の連立方程式を解く: x ← op(A)⁻¹ * x * * @param upper true: 上三角, false: 下三角 * @param trans true: A^T, false: A * @param unit_diag true: 対角要素を 1 とみなす * @param A 三角行列 * @param x 右辺ベクトル (in-place 更新で解を返す) */ template void trsv(bool upper, bool trans, bool unit_diag, const Matrix& A, Vector& x) { const auto n = A.rows(); if (A.cols() != n || x.size() != n) { throw DimensionError("blas::trsv: dimension mismatch"); } if (!trans) { if (upper) { // 上三角・後退代入 for (std::size_t ii = 0; ii < n; ++ii) { std::size_t i = n - 1 - ii; for (std::size_t j = i + 1; j < n; ++j) x[i] -= A(i, j) * x[j]; if (!unit_diag) x[i] /= A(i, i); } } else { // 下三角・前方代入 for (std::size_t i = 0; i < n; ++i) { for (std::size_t j = 0; j < i; ++j) x[i] -= A(i, j) * x[j]; if (!unit_diag) x[i] /= A(i, i); } } } else { if (upper) { // 上三角転置 → 下三角として前方代入 for (std::size_t i = 0; i < n; ++i) { for (std::size_t j = 0; j < i; ++j) x[i] -= A(j, i) * x[j]; if (!unit_diag) x[i] /= A(i, i); } } else { // 下三角転置 → 上三角として後退代入 for (std::size_t ii = 0; ii < n; ++ii) { std::size_t i = n - 1 - ii; for (std::size_t j = i + 1; j < n; ++j) x[i] -= A(j, i) * x[j]; if (!unit_diag) x[i] /= A(i, i); } } } } /** * @brief ランク 1 更新: A ← alpha * x * y^T + A */ template void ger(T alpha, const Vector& x, const Vector& y, Matrix& A) { if (x.size() != A.rows() || y.size() != A.cols()) { throw DimensionError("blas::ger: dimension mismatch"); } for (std::size_t i = 0; i < A.rows(); ++i) { T ax = alpha * x[i]; for (std::size_t j = 0; j < A.cols(); ++j) A(i, j) += ax * y[j]; } } /** * @brief 対称ランク 1 更新: A ← alpha * x * x^T + A * * A は対称行列 (上三角のみ更新し、下三角にコピー)。 */ template void syr(bool upper, T alpha, const Vector& x, Matrix& A) { const auto n = x.size(); if (A.rows() != n || A.cols() != n) { throw DimensionError("blas::syr: dimension mismatch"); } if (upper) { for (std::size_t i = 0; i < n; ++i) { T ax = alpha * x[i]; for (std::size_t j = i; j < n; ++j) { A(i, j) += ax * x[j]; if (i != j) A(j, i) = A(i, j); } } } else { for (std::size_t j = 0; j < n; ++j) { T ax = alpha * x[j]; for (std::size_t i = j; i < n; ++i) { A(i, j) += ax * x[i]; if (i != j) A(j, i) = A(i, j); } } } } // ==================================================================== // Level 3: 行列-行列演算 // ==================================================================== /** * @brief 一般行列積: C ← alpha * op(A) * op(B) + beta * C * * @param transA false: op(A) = A, true: op(A) = A^T * @param transB false: op(B) = B, true: op(B) = B^T */ template void gemm(bool transA, bool transB, T alpha, const Matrix& A, const Matrix& B, T beta, Matrix& C) { const auto opA_rows = transA ? A.cols() : A.rows(); const auto opA_cols = transA ? A.rows() : A.cols(); const auto opB_rows = transB ? B.cols() : B.rows(); const auto opB_cols = transB ? B.rows() : B.cols(); if (opA_cols != opB_rows || C.rows() != opA_rows || C.cols() != opB_cols) { throw DimensionError("blas::gemm: dimension mismatch"); } const auto m = opA_rows; const auto n = opB_cols; const auto k = opA_cols; // C ← beta * C if (beta == T{0}) { for (std::size_t i = 0; i < m; ++i) for (std::size_t j = 0; j < n; ++j) C(i, j) = T{0}; } else if (beta != T{1}) { for (std::size_t i = 0; i < m; ++i) for (std::size_t j = 0; j < n; ++j) C(i, j) *= beta; } // C += alpha * op(A) * op(B) // 4 つの転置組み合わせ if (!transA && !transB) { if constexpr (std::is_same_v || std::is_same_v) { // ブロッキング + SIMD (ipj ループ順) constexpr std::size_t BLK = 64; for (std::size_t i0 = 0; i0 < m; i0 += BLK) { const std::size_t i1 = (std::min)(i0 + BLK, m); for (std::size_t p0 = 0; p0 < k; p0 += BLK) { const std::size_t p1 = (std::min)(p0 + BLK, k); for (std::size_t j0 = 0; j0 < n; j0 += BLK) { const std::size_t j1 = (std::min)(j0 + BLK, n); const std::size_t jlen = j1 - j0; // マイクロカーネル for (std::size_t i = i0; i < i1; ++i) { for (std::size_t p = p0; p < p1; ++p) { T aip = alpha * A(i, p); computation::simd::axpy_simd( &C(i, j0), aip, &B(p, j0), jlen); } } } } } } else { for (std::size_t i = 0; i < m; ++i) for (std::size_t p = 0; p < k; ++p) { T aip = alpha * A(i, p); for (std::size_t j = 0; j < n; ++j) C(i, j) += aip * B(p, j); } } } else if (transA && !transB) { for (std::size_t i = 0; i < m; ++i) for (std::size_t p = 0; p < k; ++p) { T api = alpha * A(p, i); for (std::size_t j = 0; j < n; ++j) C(i, j) += api * B(p, j); } } else if (!transA && transB) { for (std::size_t i = 0; i < m; ++i) for (std::size_t j = 0; j < n; ++j) { T sum = T{0}; for (std::size_t p = 0; p < k; ++p) sum += A(i, p) * B(j, p); C(i, j) += alpha * sum; } } else { // transA && transB for (std::size_t i = 0; i < m; ++i) for (std::size_t j = 0; j < n; ++j) { T sum = T{0}; for (std::size_t p = 0; p < k; ++p) sum += A(p, i) * B(j, p); C(i, j) += alpha * sum; } } } /** * @brief 三角行列の連立方程式を解く (行列版): B ← op(A)⁻¹ * B or B ← B * op(A)⁻¹ * * @param side true: left (A⁻¹B), false: right (BA⁻¹) * @param upper true: 上三角, false: 下三角 * @param trans true: A^T, false: A * @param unit_diag true: 対角要素を 1 とみなす * @param alpha スカラー係数 (B ← alpha * ...) * @param A 三角行列 (n×n) * @param B 入出力行列 */ template void trsm(bool side_left, bool upper, bool trans, bool unit_diag, T alpha, const Matrix& A, Matrix& B) { const auto n = A.rows(); if (A.cols() != n) { throw DimensionError("blas::trsm: A must be square"); } // alpha スケーリング if (alpha != T{1}) { for (std::size_t i = 0; i < B.rows(); ++i) for (std::size_t j = 0; j < B.cols(); ++j) B(i, j) *= alpha; } if (side_left) { // B ← A⁻¹ * B: 各列を独立に trsv if (B.rows() != n) { throw DimensionError("blas::trsm: dimension mismatch (left)"); } for (std::size_t col = 0; col < B.cols(); ++col) { Vector b_col(n); for (std::size_t i = 0; i < n; ++i) b_col[i] = B(i, col); trsv(upper, trans, unit_diag, A, b_col); for (std::size_t i = 0; i < n; ++i) B(i, col) = b_col[i]; } } else { // B ← B * A⁻¹: 各行を独立に trsv (A^T の trsv と等価) if (B.cols() != n) { throw DimensionError("blas::trsm: dimension mismatch (right)"); } for (std::size_t row = 0; row < B.rows(); ++row) { Vector b_row(n); for (std::size_t j = 0; j < n; ++j) b_row[j] = B(row, j); trsv(upper, !trans, unit_diag, A, b_row); for (std::size_t j = 0; j < n; ++j) B(row, j) = b_row[j]; } } } /** * @brief 対称ランク k 更新: C ← alpha * A * A^T + beta * C (trans=false) * C ← alpha * A^T * A + beta * C (trans=true) * * C は対称行列 (n×n)。 * trans=false: A は n×k, C ← alpha*A*A^T + beta*C * trans=true: A は k×n, C ← alpha*A^T*A + beta*C */ template void syrk(bool upper, bool trans, T alpha, const Matrix& A, T beta, Matrix& C) { const auto n = C.rows(); if (C.cols() != n) { throw DimensionError("blas::syrk: C must be square"); } std::size_t k; if (!trans) { if (A.rows() != n) throw DimensionError("blas::syrk: dimension mismatch"); k = A.cols(); } else { if (A.cols() != n) throw DimensionError("blas::syrk: dimension mismatch"); k = A.rows(); } // C ← beta * C (上三角 or 下三角のみ) for (std::size_t i = 0; i < n; ++i) { std::size_t j_start = upper ? i : 0; std::size_t j_end = upper ? n : i + 1; for (std::size_t j = j_start; j < j_end; ++j) C(i, j) *= beta; } // C += alpha * op(A) * op(A)^T if (!trans) { // C += alpha * A * A^T for (std::size_t i = 0; i < n; ++i) { std::size_t j_start = upper ? i : 0; std::size_t j_end = upper ? n : i + 1; for (std::size_t j = j_start; j < j_end; ++j) { T sum = T{0}; for (std::size_t p = 0; p < k; ++p) sum += A(i, p) * A(j, p); C(i, j) += alpha * sum; } } } else { // C += alpha * A^T * A for (std::size_t i = 0; i < n; ++i) { std::size_t j_start = upper ? i : 0; std::size_t j_end = upper ? n : i + 1; for (std::size_t j = j_start; j < j_end; ++j) { T sum = T{0}; for (std::size_t p = 0; p < k; ++p) sum += A(p, i) * A(p, j); C(i, j) += alpha * sum; } } } // 対称部分をコピー for (std::size_t i = 0; i < n; ++i) for (std::size_t j = i + 1; j < n; ++j) { if (upper) C(j, i) = C(i, j); else C(i, j) = C(j, i); } } } // namespace blas } // namespace calx #endif // CALX_BLAS_HPP