// Copyright (C) 2026 Kiyotsugu Arai // SPDX-License-Identifier: LGPL-3.0-or-later /** * @file sparse_lu.hpp * @brief 疎行列用 LU 分解 (SparseLU) * * 左向き列 LU 分解 + 部分ピボット。 * L, U を CSC 形式で格納し、疎構造を保持。 * * API は Eigen の SparseLU に準拠: * SparseLU lu; * lu.compute(A); * Vector x = lu.solve(b); */ #ifndef CALX_SPARSE_LU_HPP #define CALX_SPARSE_LU_HPP #include #include #include #include #include #include #include "../core/sparse_matrix.hpp" #include "../core/vector.hpp" namespace calx { /** * @brief 疎行列直接 LU 分解 (CSC 格納) * * P * A = L * U (P は行置換) * L: 単位下三角 (CSC), U: 上三角 (CSC) * 左向き列分解で fill-in を管理。 */ template class SparseLU { public: SparseLU() = default; /// 疎行列を LU 分解する void compute(const SparseMatrix& A) { const auto n = static_cast(A.rows()); if (static_cast(A.cols()) != n) throw std::invalid_argument("SparseLU: matrix must be square"); n_ = n; computed_ = false; // CSC データ取得 const auto& a_val = A.csc_values(); const auto& a_row = A.csc_row_indices(); const auto& a_col = A.csc_col_ptr(); // 置換配列 perm_.resize(n); // perm_[k] = ピボット行 k の元の行番号 inv_perm_.resize(n); // inv_perm_[元の行] = 現在の行番号 std::iota(perm_.begin(), perm_.end(), std::size_t(0)); std::iota(inv_perm_.begin(), inv_perm_.end(), std::size_t(0)); // L, U を列ごとに構築 (一時的に vector of vector) std::vector> L_rows(n), U_rows(n); std::vector> L_vals(n), U_vals(n); // 密ワークベクトル (列ごとに再利用) std::vector x(n, T(0)); std::vector x_nz(n, 0); // 非ゼロフラグ (vector は proxy で swap 不可) std::vector nz_list; // 非ゼロ位置リスト nz_list.reserve(n); for (std::size_t k = 0; k < n; ++k) { // 1. A の列 k を散布 (元の行番号 → 現在の行番号) nz_list.clear(); auto col_start = static_cast(a_col[k]); auto col_end = static_cast(a_col[k + 1]); for (auto p = col_start; p < col_end; ++p) { std::size_t orig_row = static_cast(a_row[p]); std::size_t cur_row = inv_perm_[orig_row]; x[cur_row] = a_val[p]; if (!x_nz[cur_row]) { x_nz[cur_row] = 1; nz_list.push_back(cur_row); } } // 2. 左向き解法: j = 0..k-1 で x[j] ≠ 0 なら L(:,j) で更新 // 行番号順に処理するため nz_list をソート std::sort(nz_list.begin(), nz_list.end()); for (std::size_t idx = 0; idx < nz_list.size(); ++idx) { std::size_t j = nz_list[idx]; if (j >= k) break; if (x[j] == T(0)) continue; // L(:,j) の各エントリで更新 for (std::size_t q = 0; q < L_rows[j].size(); ++q) { std::size_t i = L_rows[j][q]; T lij = L_vals[j][q]; T old = x[i]; x[i] -= lij * x[j]; if (!x_nz[i] && x[i] != T(0)) { x_nz[i] = 1; nz_list.push_back(i); // 挿入後にソート位置を維持 (末尾追加→後で処理) } } } // fill-in で追加された要素があるので再ソート std::sort(nz_list.begin(), nz_list.end()); // 重複除去 nz_list.erase(std::unique(nz_list.begin(), nz_list.end()), nz_list.end()); // 3. 部分ピボット: x[k:n-1] で最大値を探す std::size_t pivot = k; T pivot_val = std::abs(x[k]); for (auto r : nz_list) { if (r < k) continue; T v = std::abs(x[r]); if (v > pivot_val) { pivot_val = v; pivot = r; } } if (pivot_val < std::numeric_limits::epsilon() * T(100)) throw std::runtime_error("SparseLU: singular matrix"); // 4. 行交換 (pivot ↔ k) if (pivot != k) { std::swap(x[k], x[pivot]); std::swap(x_nz[k], x_nz[pivot]); // 置換更新 std::size_t ok = perm_[k], op = perm_[pivot]; std::swap(perm_[k], perm_[pivot]); inv_perm_[ok] = pivot; inv_perm_[op] = k; // L の既存列 (0..k-1) で行 k と pivot を交換 for (std::size_t j = 0; j < k; ++j) { T* vk = nullptr; T* vp = nullptr; for (std::size_t q = 0; q < L_rows[j].size(); ++q) { if (L_rows[j][q] == k) vk = &L_vals[j][q]; if (L_rows[j][q] == pivot) vp = &L_vals[j][q]; } if (vk && vp) { std::swap(*vk, *vp); } else if (vk) { // k にはあるが pivot にはない // k の位置を pivot に変更 for (std::size_t q = 0; q < L_rows[j].size(); ++q) { if (L_rows[j][q] == k) { L_rows[j][q] = pivot; break; } } } else if (vp) { for (std::size_t q = 0; q < L_rows[j].size(); ++q) { if (L_rows[j][q] == pivot) { L_rows[j][q] = k; break; } } } } } // 5. U 列 k: x[0..k] の非ゼロ部分 for (auto r : nz_list) { if (r > k) break; if (x[r] != T(0)) { U_rows[k].push_back(r); U_vals[k].push_back(x[r]); } } // 対角が含まれていなければ追加 (数値的にゼロ近辺) T diag = x[k]; // 6. L 列 k: x[k+1..n-1] / diag の非ゼロ部分 T inv_diag = T(1) / diag; for (auto r : nz_list) { if (r <= k) continue; if (x[r] != T(0)) { L_rows[k].push_back(r); L_vals[k].push_back(x[r] * inv_diag); } } // 7. ワークベクトルのクリア (非ゼロ位置のみ) for (auto r : nz_list) { x[r] = T(0); x_nz[r] = 0; } } // CSC 形式に変換 build_csc(L_rows, L_vals, L_col_ptr_, L_row_ind_, L_val_); build_csc(U_rows, U_vals, U_col_ptr_, U_row_ind_, U_val_); computed_ = true; } /// 連立方程式 A*x = b を解く [[nodiscard]] Vector solve(const Vector& b) const { if (!computed_) throw std::runtime_error("SparseLU: compute() not called"); if (b.size() != n_) throw std::invalid_argument("SparseLU::solve: dimension mismatch"); // P * b (行置換適用) Vector y(n_); for (std::size_t i = 0; i < n_; ++i) y[i] = b[perm_[i]]; // 前方代入: L * z = y (L は単位下三角, CSC) // 列 k を処理: z[i] -= L(i,k) * z[k] (i > k) for (std::size_t k = 0; k < n_; ++k) { if (y[k] == T(0)) continue; auto start = L_col_ptr_[k]; auto end = L_col_ptr_[k + 1]; for (auto p = start; p < end; ++p) y[L_row_ind_[p]] -= L_val_[p] * y[k]; } // 後退代入: U * x = z (U は上三角, CSC) // 列 k を右から処理: x[k] = (z[k] - sum U(i,k)*x[i]) / U(k,k) for (std::size_t kk = 0; kk < n_; ++kk) { std::size_t k = n_ - 1 - kk; auto start = U_col_ptr_[k]; auto end = U_col_ptr_[k + 1]; // U 列 k の対角は最後の要素 (行番号 k) T diag = T(0); for (auto p = start; p < end; ++p) { if (U_row_ind_[p] == k) { diag = U_val_[p]; } else if (U_row_ind_[p] < k) { // nothing — will handle in column processing from right } } y[k] /= diag; // 列 k の対角以外のエントリで更新 for (auto p = start; p < end; ++p) { if (U_row_ind_[p] < k) y[U_row_ind_[p]] -= U_val_[p] * y[k]; } } return y; } [[nodiscard]] bool computed() const { return computed_; } [[nodiscard]] std::size_t rows() const { return n_; } [[nodiscard]] std::size_t cols() const { return n_; } private: std::size_t n_ = 0; bool computed_ = false; std::vector perm_; std::vector inv_perm_; // L (単位下三角, 対角は暗黙的に 1) を CSC で格納 std::vector L_col_ptr_; std::vector L_row_ind_; std::vector L_val_; // U (上三角) を CSC で格納 std::vector U_col_ptr_; std::vector U_row_ind_; std::vector U_val_; /// vector of vector → CSC 変換 static void build_csc(const std::vector>& rows, const std::vector>& vals, std::vector& col_ptr, std::vector& row_ind, std::vector& val) { const auto n = rows.size(); col_ptr.resize(n + 1); col_ptr[0] = 0; std::size_t total = 0; for (std::size_t j = 0; j < n; ++j) { total += rows[j].size(); col_ptr[j + 1] = total; } row_ind.resize(total); val.resize(total); std::size_t pos = 0; for (std::size_t j = 0; j < n; ++j) { for (std::size_t q = 0; q < rows[j].size(); ++q) { row_ind[pos] = rows[j][q]; val[pos] = vals[j][q]; ++pos; } } } }; } // namespace calx #endif // CALX_SPARSE_LU_HPP