From 9a6aa8ac903101d2c4e407c46cd8dfad7f11f469 Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Sat, 26 Oct 2024 15:14:10 -0700 Subject: [PATCH] Fix block_gmres on GPUs --- src/block_gmres.jl | 18 +++++++++++++++--- src/block_krylov_solvers.jl | 4 +++- src/block_krylov_utils.jl | 10 ++++++++++ 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/src/block_gmres.jl b/src/block_gmres.jl index 28851cc08..bc323494a 100644 --- a/src/block_gmres.jl +++ b/src/block_gmres.jl @@ -196,7 +196,11 @@ kwargs_block_gmres = (:M, :N, :ldiv, :restart, :reorthogonalization, :atol, :rto # Initial Γ and V₁ copyto!(V[1], R₀) - householder!(V[1], Z[1], τ[1]) + if C isa Matrix + householder!(V[1], Z[1], τ[1]) + else + householder!(V[1], Z[1], τ[1], solver.tmp) + end npass = npass + 1 inner_iter = 0 @@ -236,7 +240,11 @@ kwargs_block_gmres = (:M, :N, :ldiv, :restart, :reorthogonalization, :atol, :rto end # Vₖ₊₁ and Ψₖ₊₁.ₖ are stored in Q and C. - householder!(Q, C, τ[inner_iter]) + if C isa Matrix + householder!(Q, C, τ[inner_iter]) + else + householder!(Q, C, τ[inner_iter], solver.tmp) + end # Update the QR factorization of Hₖ₊₁.ₖ. # Apply previous Householder reflections Ωᵢ. @@ -251,7 +259,11 @@ kwargs_block_gmres = (:M, :N, :ldiv, :restart, :reorthogonalization, :atol, :rto # Compute and apply current Householder reflection Ωₖ. H[inner_iter][1:p,:] .= R[nr+inner_iter] H[inner_iter][p+1:2p,:] .= C - householder!(H[inner_iter], R[nr+inner_iter], τ[inner_iter], compact=true) + if C isa Matrix + householder!(H[inner_iter], R[nr+inner_iter], τ[inner_iter], compact=true) + else + householder!(H[inner_iter], R[nr+inner_iter], τ[inner_iter], solver.tmp, compact=true) + end # Update Zₖ = (Qₖ)ᴴΓE₁ = (Λ₁, ..., Λₖ, Λbarₖ₊₁) D1 .= Z[inner_iter] diff --git a/src/block_krylov_solvers.jl b/src/block_krylov_solvers.jl index 67cdace91..a1e8db13a 100644 --- a/src/block_krylov_solvers.jl +++ b/src/block_krylov_solvers.jl @@ -35,6 +35,7 @@ mutable struct BlockGmresSolver{T,FC,SV,SM} <: BlockKrylovSolver{T,FC,SV,SM} R :: Vector{SM} H :: Vector{SM} τ :: Vector{SV} + tmp :: SM warm_start :: Bool stats :: SimpleStats{T} end @@ -55,8 +56,9 @@ function BlockGmresSolver(m, n, p, memory, SV, SM) R = SM[SM(undef, p, p) for i = 1 : div(memory * (memory+1), 2)] H = SM[SM(undef, 2p, p) for i = 1 : memory] τ = SV[SV(undef, p) for i = 1 : memory] + tmp = C isa Matrix ? SM(undef, 0, 0) : SM(undef, p, p) stats = SimpleStats(0, false, false, T[], T[], T[], 0.0, "unknown") - solver = BlockGmresSolver{T,FC,SV,SM}(m, n, p, ΔX, X, W, P, Q, C, D, V, Z, R, H, τ, false, stats) + solver = BlockGmresSolver{T,FC,SV,SM}(m, n, p, ΔX, X, W, P, Q, C, D, V, Z, R, H, τ, tmp, false, stats) return solver end diff --git a/src/block_krylov_utils.jl b/src/block_krylov_utils.jl index 0429a82bc..dae7be166 100644 --- a/src/block_krylov_utils.jl +++ b/src/block_krylov_utils.jl @@ -192,6 +192,16 @@ function householder(A::AbstractMatrix{FC}; compact::Bool=false) where FC <: Flo householder!(Q, R, τ; compact) end +function householder!(Q::AbstractMatrix{FC}, R::AbstractMatrix{FC}, τ::AbstractVector{FC}, tmp::AbstractMatrix{FC}; compact::Bool=false) where FC <: FloatOrComplex + n, k = size(Q) + kfill!(R, zero(FC)) + kgeqrf!(Q, τ) + copyto!(tmp, view(Q, 1:k, 1:k)) + copy_triangle(tmp, R, k) + !compact && korgqr!(Q, τ) + return Q, R +end + function householder!(Q::AbstractMatrix{FC}, R::AbstractMatrix{FC}, τ::AbstractVector{FC}; compact::Bool=false) where FC <: FloatOrComplex n, k = size(Q) kfill!(R, zero(FC))