Skip to content

Commit

Permalink
Continue block_minres
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Oct 31, 2024
1 parent 0eb30ec commit 2b56bdc
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 45 deletions.
4 changes: 3 additions & 1 deletion src/block_krylov_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ mutable struct BlockMinresSolver{T,FC,SV,SM} <: BlockKrylovSolver{T,FC,SV,SM}
C :: SM
D :: SM
τ :: SV
tmp :: SM
warm_start :: Bool
stats :: SimpleStats{T}
end
Expand All @@ -46,8 +47,9 @@ function BlockMinresSolver(m, n, p, SV, SM)
C = SM(undef, p, p)
D = SM(undef, 2p, p)
τ = SV(undef, p)
tmp = C isa Matrix ? SM(undef, 0, 0) : SM(undef, p, p)
stats = SimpleStats(0, false, false, T[], T[], T[], 0.0, "unknown")
solver = BlockMinresSolver{T,FC,SV,SM}(m, n, p, ΔX, X, W, P, Q, C, D, τ, false, stats)
solver = BlockMinresSolver{T,FC,SV,SM}(m, n, p, ΔX, X, W, P, Q, C, D, τ, tmp, false, stats)
return solver
end

Expand Down
113 changes: 70 additions & 43 deletions src/block_minres.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ Solve the Hermitian linear system AX = B of size n with p right-hand sides using
#### Input arguments
* `A`: a linear operator that models a Hermitian matrix of dimension n;
* `B`: a matrix of size n × p.
* `A`: a linear operator that models a Hermitian matrix of dimension `n`;
* `B`: a matrix of size `n × p`.
#### Optional argument
* `X0`: a matrix of size n × p that represents an initial guess of the solution X.
* `X0`: a matrix of size `n × p` that represents an initial guess of the solution `X`.
#### Keyword arguments
Expand All @@ -45,7 +45,7 @@ Solve the Hermitian linear system AX = B of size n with p right-hand sides using
#### Output arguments
* `X`: a dense matrix of size n × p;
* `X`: a dense matrix of size `n × p`;
* `stats`: statistics collected on the run in a [`SimpleStats`](@ref) structure.
"""
function block_minres end
Expand Down Expand Up @@ -83,27 +83,6 @@ optargs_block_minres = (:X0,)
kwargs_block_minres = (:M, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :history, :callback, :iostream)

@eval begin
function block_minres($(def_args_block_minres...), $(def_optargs_block_minres...); $(def_kwargs_block_minres...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = BlockMinresSolver(A, B)
warm_start!(solver, $(optargs_block_minres...))
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
block_minres!(solver, $(args_block_minres...); $(kwargs_block_minres...))
solver.stats.timer += elapsed_time
return solver.X, solver.stats
end

function block_minres($(def_args_block_minres...); $(def_kwargs_block_minres...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}}
start_time = time_ns()
solver = BlockMinresSolver(A, B)
elapsed_time = ktimer(start_time)
timemax -= elapsed_time
block_minres!(solver, $(args_block_minres...); $(kwargs_block_minres...))
solver.stats.timer += elapsed_time
return solver.X, solver.stats
end

function block_minres!(solver :: BlockMinresSolver{T,FC,SV,SM}, $(def_args_block_minres...); $(def_kwargs_block_minres...)) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, SV <: AbstractVector{FC}, SM <: AbstractMatrix{FC}}

# Timer
Expand Down Expand Up @@ -174,39 +153,87 @@ kwargs_block_minres = (:M, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :his

# Initial Ψ₁ and V₁
copyto!(V, R₀)
householder!(V, Z, τ)
if C isa Matrix
householder!(V, Z, τ)
else
householder!(V, Z, τ, solver.tmp)
end

# Continue the block-Lanczos process.
mul!(W, A, V) # Q ← AVₖ
for i = 1 : inner_iter
mul!(Ω, V', W) # Ωₖ = Vₖᴴ * Q
# (iter ≥ 2) && mul!(Q, ...) # Q ← Q - βₖ * Vₖ₋₁ * Ψₖᴴ
mul!(Q, V, R, α, β) # Q = Q - Vₖ * Ωₖ
mul!(Ω, V', W) # Ωₖ = Vₖᴴ * Q
(iter 2) && mul!(Q, Vold, Ψ') # Q ← Q - Vₖ₋₁ * Ψₖᴴ
mul!(Q, V, R, α, β) # Q = Q - Vₖ * Ωₖ
end

# Vₖ₊₁ and Ψₖ₊₁ are stored in Q and C.
householder!(Q, C, τ)
if C isa Matrix
householder!(Q, C, τ)
else
householder!(Q, C, τ, solver.tmp)
end

# Update the QR factorization of Tₖ₊₁.ₖ.
# Apply previous Householder reflections Ωᵢ.
for i = 1 : inner_iter-1
D1 .= R[nr+i]
D2 .= R[nr+i+1]
@kormqr!('L', trans, H[i], τ[i], D)
R[nr+i] .= D1
R[nr+i+1] .= D2
# Apply previous Householder reflections Θₖ₋₂.
if k 3
D1 .= Rₖ₋₂.
D2 .= Rₖ₋₁.
kormqr!('L', trans, H[i-2], τ[i-2], D)
Rₖ₋₂. .= D1
Rₖ₋₁. .= D2
end

# Compute and apply current Householder reflection Ωₖ.
H[inner_iter][1:p,:] .= R[nr+inner_iter]
# Apply previous Householder reflections Θₖ₋₁.
if k 2
D1 .= Rₖ₋₁.
D2 .= Rₖ.
kormqr!('L', trans, H[i-1], τ[i-1], D)
Rₖ₋₁.ₖ .= D1
Rₖ.ₖ .= D2
end

# Compute and apply current Householder reflection θₖ.
H[inner_iter][1:p,:] .= Rₖ.
H[inner_iter][p+1:2p,:] .= C
householder!(H[inner_iter], R[nr+inner_iter], τ[inner_iter], compact=true)
if C isa Matrix
householder!(H[i], Rₖ.ₖ, τ[i], compact=true)
else
householder!(H[i], Rₖ.ₖ, τ[i], solver.tmp, compact=true)
end

# Update Zₖ = (Qₖ)ᴴΓE₁ = (Λ₁, ..., Λₖ, Λbarₖ₊₁)
D1 .= Z[inner_iter]
D1 .= Λbarₖ
D2 .= zero(FC)
@kormqr!('L', trans, H[inner_iter], τ[inner_iter], D)
Z[inner_iter] .= D1
kormqr!('L', trans, H[i], τ[i], D)
Λₖ .= D1

# Compute the directions Wₖ, the last columns of Wₖ = Vₖ(Rₖ)⁻¹ ⟷ (Rₖ)ᵀ(Wₖ)ᵀ = (Vₖ)ᵀ
# R₁₁w₁ = v₁
if iter == 1
# wₖ = wₖ₋₁
# kaxpy!(n, one(FC), uₖ, wₖ)
# wₖ .= wₖ ./ δₖ
end
# R₂₂w₂ = (v₂ - R₂₁w₁)
if iter == 2
# wₖ = wₖ₋₂
# kaxpy!(n, -λₖ₋₁, wₖ₋₁, wₖ)
# kaxpy!(n, one(FC), uₖ, wₖ)
# wₖ .= wₖ ./ δₖ
end
# Rₖₖwₖ = (vₖ - Rₖₖ₋₁wₖ₋₁ - Rₖₖ₋₂wₖ₋₂)
if iter 3
# kscal!(n, -ϵₖ₋₂, wₖ₋₂)
# wₖ = wₖ₋₂
# kaxpy!(n, -λₖ₋₁, wₖ₋₁, wₖ)
# kaxpy!(n, one(FC), uₖ, wₖ)
# wₖ .= wₖ ./ δₖ
end

# Update Xₖ = VₖYₖ = WₖZₖ
# Xₖ = Xₖ₋₁ + Λₖ * wₖ
mul!(X, Λₖ, W[i], γ, β)

# Update residual norm estimate.
# ‖ M(B - AXₖ) ‖_F = ‖Λbarₖ₊₁‖_F
Expand Down
3 changes: 2 additions & 1 deletion src/krylov_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ for (workspace, krylov, args, def_args, optargs, def_optargs, kwargs, def_kwargs
(:TrilqrSolver , :trilqr , args_trilqr , def_args_trilqr , optargs_trilqr , def_optargs_trilqr , kwargs_trilqr , def_kwargs_trilqr )
(:CrmrSolver , :crmr , args_crmr , def_args_crmr , () , () , kwargs_crmr , def_kwargs_crmr )
(:CgSolver , :cg , args_cg , def_args_cg , optargs_cg , def_optargs_cg , kwargs_cg , def_kwargs_cg )
(:CgLanczosShiftSolver, :cg_lanczos_shift, args_cg_lanczos_shift, def_args_cg_lanczos_shift, () , () , kwargs_cg_lanczos_shift, def_kwargs_cg_lanczos_shift)
(:CglsSolver , :cgls , args_cgls , def_args_cgls , () , () , kwargs_cgls , def_kwargs_cgls )
(:CgLanczosSolver , :cg_lanczos , args_cg_lanczos , def_args_cg_lanczos , optargs_cg_lanczos, def_optargs_cg_lanczos, kwargs_cg_lanczos , def_kwargs_cg_lanczos )
(:BilqSolver , :bilq , args_bilq , def_args_bilq , optargs_bilq , def_optargs_bilq , kwargs_bilq , def_kwargs_bilq )
Expand All @@ -58,6 +57,8 @@ for (workspace, krylov, args, def_args, optargs, def_optargs, kwargs, def_kwargs
(:FgmresSolver , :fgmres , args_fgmres , def_args_fgmres , optargs_fgmres , def_optargs_fgmres , kwargs_fgmres , def_kwargs_fgmres )
(:FomSolver , :fom , args_fom , def_args_fom , optargs_fom , def_optargs_fom , kwargs_fom , def_kwargs_fom )
(:GpmrSolver , :gpmr , args_gpmr , def_args_gpmr , optargs_gpmr , def_optargs_gpmr , kwargs_gpmr , def_kwargs_gpmr )
(:CgLanczosShiftSolver , :cg_lanczos_shift , args_cg_lanczos_shift , def_args_cg_lanczos_shift , (), (), kwargs_cg_lanczos_shift , def_kwargs_cg_lanczos_shift )
(:CglsLanczosShiftSolver, :cgls_lanczos_shift, args_cgls_lanczos_shift, def_args_cgls_lanczos_shift, (), (), kwargs_cgls_lanczos_shift, def_kwargs_cgls_lanczos_shift)
]
# Create the symbol for the in-place method
krylov! = Symbol(krylov, :!)
Expand Down

0 comments on commit 2b56bdc

Please sign in to comment.