Skip to content

Commit

Permalink
Start to optimize the block Krylov processes
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Sep 14, 2023
1 parent 114949d commit f7fc22b
Showing 1 changed file with 54 additions and 25 deletions.
79 changes: 54 additions & 25 deletions src/block_krylov_processes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,16 @@ function gs(A::AbstractMatrix{FC}) where FC <: FloatOrComplex
n, k = size(A)
Q = copy(A)
R = zeros(FC, k, k)
v = zeros(FC, n)
gs!(Q, R, v)
end

function gs!(Q::AbstractMatrix{FC}, R::AbstractMatrix{FC}, v::AbstractVector{FC}) where FC <: FloatOrComplex
n, k = size(Q)
aⱼ = v
for j = 1:k
aⱼ = view(A,:,j)
qⱼ = view(Q,:,j)
aⱼ .= qⱼ
for i = 1:j-1
qᵢ = view(Q,:,i)
R[i,j] = @kdot(n, qᵢ, aⱼ) # rᵢⱼ = ⟨qᵢ , aⱼ⟩
Expand All @@ -41,6 +48,11 @@ function mgs(A::AbstractMatrix{FC}) where FC <: FloatOrComplex
n, k = size(A)
Q = copy(A)
R = zeros(FC, k, k)
mgs!(Q, R)
end

function mgs!(Q::AbstractMatrix{FC}, R::AbstractMatrix{FC}) where FC <: FloatOrComplex
n, k = size(Q)
for i = 1:k
qᵢ = view(Q,:,i)
R[i,i] = @knrm2(n, qᵢ) # rᵢᵢ = ‖qᵢ‖
Expand Down Expand Up @@ -80,13 +92,12 @@ function householder(A::AbstractMatrix{FC}) where FC <: FloatOrComplex
end

function reduced_qr(V::AbstractMatrix{FC}, algo::String) where FC <: FloatOrComplex
algo == "gs" || algo == "mgs" || algo == "givens" || error("$algo is not a supported method to perform a reduced QR.")
if algo == "gs"
Q, R = gs(V)
elseif algo == "mgs"
Q, R = mgs(V)
else
Q, R = givens(V)
error("$algo is not a supported method to perform a reduced QR.")

Check warning on line 100 in src/block_krylov_processes.jl

View check run for this annotation

Codecov / codecov/patch

src/block_krylov_processes.jl#L100

Added line #L100 was not covered by tests
end
return Q, R
end
Expand All @@ -113,6 +124,9 @@ function hermitian_lanczos(A, B::AbstractMatrix{FC}, k::Int; algo::String="mgs")
V = zeros(FC, n, (k+1)*p)
T = spzeros(FC, (k+1)*p, k*p)

q = zeros(FC, n, p)
Ωᵢ = zeros(FC, p, p)

for i = 1:k
pos1 = (i-1)*p + 1
pos2 = i*p
Expand All @@ -122,17 +136,17 @@ function hermitian_lanczos(A, B::AbstractMatrix{FC}, k::Int; algo::String="mgs")
Q, Ψᵢ = reduced_qr(B, algo)
vᵢ .= Q
end
aux = A * vᵢ
mul!(q, A, vᵢ)
if i ≥ 2
vᵢ₋₁ = view(V,:,pos1-p:pos2-p)
Ψᵢ = T[pos1:pos2,pos1-p:pos2-p]
T[pos1-p:pos2-p,pos1:pos2] .= Ψᵢ'
aux = aux - vᵢ₋₁ * Ψᵢ'
q = q - vᵢ₋₁ * Ψᵢ'
end
Ωᵢ = vᵢ' * aux
mul!(Ωᵢ, vᵢ', q) # Ωᵢ = vᵢᴴq
T[pos1:pos2,pos1:pos2] .= Ωᵢ
aux = aux - vᵢ * Ωᵢ
Q, Ψᵢ₊₁ = reduced_qr(aux, algo)
q = q - vᵢ * Ωᵢ
Q, Ψᵢ₊₁ = reduced_qr(q, algo)
vᵢ₊₁ .= Q
T[pos1+p:pos2+p,pos1:pos2] .= Ψᵢ₊₁
end
Expand Down Expand Up @@ -167,6 +181,10 @@ function nonhermitian_lanczos(A, B::AbstractMatrix{FC}, C::AbstractMatrix{FC}, k
T = zeros(FC, (k+1)*p, k*p)
Tᴴ = zeros(FC, (k+1)*p, k*p)

qv = zeros(FC, n, p)
qu = zeros(FC, n, p)
Ωᵢ = zeros(FC, p, p)

for i = 1:k
pos1 = (i-1)*p + 1
pos2 = i*p
Expand All @@ -184,8 +202,8 @@ function nonhermitian_lanczos(A, B::AbstractMatrix{FC}, C::AbstractMatrix{FC}, k
vᵢ .= (Ψᵢ' \ B')'
uᵢ .= (Φᵢ \ C')'
end
qv = A * vᵢ
qu = Aᴴ * uᵢ
mul!(qv, A, vᵢ)
mul!(qu, Aᴴ, uᵢ)
if i ≥ 2
pos5 = pos1 - p
pos6 = pos2 - p
Expand All @@ -198,7 +216,7 @@ function nonhermitian_lanczos(A, B::AbstractMatrix{FC}, C::AbstractMatrix{FC}, k
Tᴴ[pos5:pos6,pos1:pos2] = TΨᵢ
qu = qu - uᵢ₋₁ * TΨᵢ
end
Ωᵢ = uᵢ' * qv
mul!(Ωᵢ, uᵢ', qv)
T[pos1:pos2,pos1:pos2] .= Ωᵢ
Tᴴ[pos1:pos2,pos1:pos2] .= Ωᵢ'
qv = qv - vᵢ * Ωᵢ
Expand Down Expand Up @@ -239,6 +257,7 @@ function arnoldi(A, B::AbstractMatrix{FC}, k::Int; algo::String="mgs", reorthogo

V = zeros(FC, n, (k+1)*p)
H = zeros(FC, (k+1)*p, k*p)
q = zeros(FC, n, p)

for j = 1:k
pos1 = (j-1)*p + 1
Expand All @@ -249,7 +268,7 @@ function arnoldi(A, B::AbstractMatrix{FC}, k::Int; algo::String="mgs", reorthogo
Q, Γ = reduced_qr(B, algo)
vⱼ .= Q
end
q = A * vⱼ
mul!(q, A, vⱼ)
for i = 1:j
pos3 = (i-1)*p + 1
pos4 = i*p
Expand Down Expand Up @@ -301,6 +320,9 @@ function golub_kahan(A, B::AbstractMatrix{FC}, k::Int; algo::String="mgs") where
U = zeros(FC, m, (k+1)*p)
L = spzeros(FC, (k+1)*p, (k+1)*p)

qv = zeros(FC, n, p)
qu = zeros(FC, m, p)

for i = 1:k
pos1 = (i-1)*p + 1
pos2 = i*p
Expand All @@ -313,20 +335,20 @@ function golub_kahan(A, B::AbstractMatrix{FC}, k::Int; algo::String="mgs") where
if i == 1
Qu, Ψᵢ = reduced_qr(B, algo)
uᵢ .= Qu
q = Aᴴ * uᵢ
Qv, TΩᵢ = reduced_qr(q, algo)
mul!(qv, Aᴴ, uᵢ)
Qv, TΩᵢ = reduced_qr(qv, algo)
vᵢ .= Qv
L[pos1:pos2,pos1:pos2] .= TΩᵢ'
end
aux1 = A * vᵢ
mul!(qu, A, vᵢ)
Ωᵢ = L[pos1:pos2,pos1:pos2]
aux1 = aux1 - uᵢ * Ωᵢ
Qu, Ψᵢ₊₁ = reduced_qr(aux1, algo)
qu = qu - uᵢ * Ωᵢ
Qu, Ψᵢ₊₁ = reduced_qr(qu, algo)
uᵢ₊₁ .= Qu
L[pos3:pos4,pos1:pos2] .= Ψᵢ₊₁
aux2 = Aᴴ * uᵢ₊₁
aux2 = aux2 - vᵢ * Ψᵢ₊₁'
Qv, TΩᵢ₊₁ = reduced_qr(aux2, algo)
mul!(qv, Aᴴ, uᵢ₊₁)
qv = qv - vᵢ * Ψᵢ₊₁'
Qv, TΩᵢ₊₁ = reduced_qr(qv, algo)
vᵢ₊₁ .= Qv
L[pos3:pos4,pos3:pos4] .= TΩᵢ₊₁'
end
Expand Down Expand Up @@ -360,6 +382,10 @@ function saunders_simon_yip(A, B::AbstractMatrix{FC}, C::AbstractMatrix{FC}, k::
T = zeros(FC, (k+1)*p, k*p)
Tᴴ = zeros(FC, (k+1)*p, k*p)

qv = zeros(FC, m, p)
qu = zeros(FC, n, p)
Ωᵢ = zeros(FC, p, p)

for i = 1:k
pos1 = (i-1)*p + 1
pos2 = i*p
Expand All @@ -375,8 +401,8 @@ function saunders_simon_yip(A, B::AbstractMatrix{FC}, C::AbstractMatrix{FC}, k::
Qu, TΦᵢ = reduced_qr(C, algo)
uᵢ .= Qu
end
qv = A * uᵢ
qu = Aᴴ * vᵢ
mul!(qv, A, uᵢ)
mul!(qu, Aᴴ, vᵢ)
if i ≥ 2
pos5 = pos1 - p
pos6 = pos2 - p
Expand All @@ -389,7 +415,7 @@ function saunders_simon_yip(A, B::AbstractMatrix{FC}, C::AbstractMatrix{FC}, k::
Tᴴ[pos5:pos6,pos1:pos2] = TΨᵢ
qu = qu - uᵢ₋₁ * TΨᵢ
end
Ωᵢ = vᵢ' * qv
mul!(Ωᵢ, vᵢ', qv)
T[pos1:pos2,pos1:pos2] .= Ωᵢ
Tᴴ[pos1:pos2,pos1:pos2] .= Ωᵢ'
qv = qv - vᵢ * Ωᵢ
Expand Down Expand Up @@ -435,6 +461,9 @@ function montoison_orban(A, B, D::AbstractMatrix{FC}, C::AbstractMatrix{FC}, k::
H = zeros(FC, (k+1)*p, k*p)
F = zeros(FC, (k+1)*p, k*p)

qv = zeros(FC, m, p)
qu = zeros(FC, n, p)

for j = 1:k
pos1 = (j-1)*p + 1
pos2 = j*p
Expand All @@ -448,8 +477,8 @@ function montoison_orban(A, B, D::AbstractMatrix{FC}, C::AbstractMatrix{FC}, k::
Qu, Λ = reduced_qr(C, algo)
uⱼ .= Qu
end
qv = A * uⱼ
qu = B * vⱼ
mul!(qv, A, uⱼ)
mul!(qu, B, vⱼ)
for i = 1:j
pos3 = (i-1)*p + 1
pos4 = i*p
Expand Down

0 comments on commit f7fc22b

Please sign in to comment.