Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[KKT] Add K2.5 formulation for augmented KKT system #352

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions src/IPM/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,34 @@ function solve!(kkt::AbstractReducedKKTSystem, w::AbstractKKTVector)
return w
end

# Solve the scaled (K2.5) augmented KKT system in place and return `w`.
#
# Steps, in order:
#   1. scatter the bound-dual right-hand sides into two work buffers, rescaled
#      by the square roots of `l_diag` / `u_diag`;
#   2. scale the primal block by `scaling_factor` and add the two buffers;
#   3. call the linear solver on the primal-dual part of `w`;
#   4. scale the primal block back and recover the bound multipliers.
function solve!(kkt::ScaledSparseKKTSystem, w::AbstractKKTVector)
    lb_term = kkt.buffer1
    ub_term = kkt.buffer2
    fill!(lb_term, 0.0)
    fill!(ub_term, 0.0)

    rhs_zl = dual_lb(w)  # size nlb
    rhs_zu = dual_ub(w)  # size nub

    # Lower-bound contribution: zero wherever there is no lower bound.
    lb_term[kkt.ind_lb] .= rhs_zl
    lb_term[kkt.ind_ub] .*= sqrt.(kkt.u_diag)
    lb_term[kkt.ind_lb] ./= sqrt.(kkt.l_diag)

    # Upper-bound contribution: zero wherever there is no upper bound.
    ub_term[kkt.ind_ub] .= rhs_zu
    ub_term[kkt.ind_lb] .*= sqrt.(kkt.l_diag)
    ub_term[kkt.ind_ub] ./= sqrt.(kkt.u_diag)

    # Build RHS: scale the primal block, then add the bound contributions.
    @. w.xp = w.xp * kkt.scaling_factor + (lb_term + ub_term)

    # Backsolve
    solve!(kkt.linear_solver, primal_dual(w))

    # Unpack solution: undo the scaling on the primal block.
    w.xp .*= kkt.scaling_factor

    # Recover the bound multipliers from the unscaled primal step.
    @. rhs_zl = (rhs_zl - kkt.l_lower * w.xp_lr) / kkt.l_diag
    @. rhs_zu = (-rhs_zu + kkt.u_lower * w.xp_ur) / kkt.u_diag
    return w
end

function solve!(
kkt::SparseKKTSystem{T, VT, MT, QN},
w::AbstractKKTVector
Expand Down Expand Up @@ -167,6 +195,20 @@ function mul!(w::AbstractKKTVector{T}, kkt::Union{SparseKKTSystem{T,VT,MT,QN},Sp
return w
end

# In-place product with the KKT operator for the exact-Hessian case.
#
# The Hessian/Jacobian products are delegated to the 5-argument `mul!`; the
# diagonal terms (`reg`, `du_diag`) and the bound blocks are then accumulated
# by hand. Returns `w`.
function mul!(
    w::AbstractKKTVector{T},
    kkt::ScaledSparseKKTSystem{T,VT,MT,QN},
    x::AbstractKKTVector,
    alpha = one(T),
    beta = zero(T),
) where {T, VT, MT, QN<:ExactHessian}
    wx, wy = primal(w), dual(w)
    xx, xy = primal(x), dual(x)
    xzl, xzu = dual_lb(x), dual_ub(x)

    # Sparse blocks: only the lower triangle of the Hessian is stored.
    mul!(wx, Symmetric(kkt.hess_com, :L), xx, alpha, beta)
    mul!(wx, kkt.jac_com', xy, alpha, one(T))
    mul!(wy, kkt.jac_com, xx, alpha, beta)

    # Custom reduction: diagonal regularization terms.
    wx .+= alpha .* kkt.reg .* xx
    wy .+= alpha .* kkt.du_diag .* xy

    # Coupling between the bounded primal entries and the bound multipliers.
    w.xp_lr .-= alpha .* xzl
    w.xp_ur .+= alpha .* xzu

    wzl = dual_lb(w)
    wzu = dual_ub(w)
    wzl .= beta .* wzl .+ alpha .* (x.xp_lr .* kkt.l_lower .+ xzl .* kkt.l_diag)
    wzu .= beta .* wzu .+ alpha .* (x.xp_ur .* kkt.u_lower .- xzu .* kkt.u_diag)
    return w
end

function mul!(w::AbstractKKTVector{T}, kkt::Union{SparseKKTSystem{T,VT,MT,QN},SparseUnreducedKKTSystem{T,VT,MT,QN}}, x::AbstractKKTVector, alpha = one(T), beta = zero(T)) where {T, VT, MT, QN<:CompactLBFGS}
qn = kkt.quasi_newton
n, p = size(qn)
Expand Down
55 changes: 53 additions & 2 deletions src/IPM/kernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ function set_aug_diagonal!(kkt::AbstractKKTSystem{T}, solver::MadNLPSolver{T}) w

fill!(kkt.reg, zero(T))
fill!(kkt.du_diag, zero(T))
kkt.l_diag .= solver.xl_r .- solver.x_lr
kkt.u_diag .= solver.x_ur .- solver.xu_r
kkt.l_diag .= solver.xl_r .- solver.x_lr # (Xˡ - X)
kkt.u_diag .= solver.x_ur .- solver.xu_r # (X - Xᵘ)
copyto!(kkt.l_lower, solver.zl_r)
copyto!(kkt.u_lower, solver.zu_r)

Expand All @@ -33,6 +33,41 @@ function _set_aug_diagonal!(kkt::AbstractUnreducedKKTSystem)
return
end

# Refresh the diagonal data of the scaled KKT system from the current iterate.
#
# Both regularization vectors are reset to zero, the bound multipliers are
# copied from the solver, and the bound slacks are stored as `x - xl` and
# `xu - x` before the scaled diagonal is rebuilt.
function set_aug_diagonal!(kkt::ScaledSparseKKTSystem{T}, solver::MadNLPSolver{T}) where T
    for buf in (kkt.reg, kkt.du_diag)
        fill!(buf, zero(T))
    end

    copyto!(kkt.l_lower, solver.zl_r)
    copyto!(kkt.u_lower, solver.zu_r)

    # Ensure l_diag and u_diag have only non negative entries
    @. kkt.l_diag = solver.x_lr - solver.xl_r  # (X - Xˡ)
    @. kkt.u_diag = solver.xu_r - solver.x_ur  # (Xᵘ - X)

    return _set_aug_diagonal!(kkt)
end

# Rebuild `scaling_factor` and the scaled primal diagonal `pr_diag` from
# `l_diag`, `u_diag`, `l_lower`, `u_lower` and `reg`.
#
# `buffer1` and `buffer2` are used as scratch space and are overwritten.
function _set_aug_diagonal!(kkt::ScaledSparseKKTSystem{T}) where T
    ub_term = kkt.buffer1
    lb_term = kkt.buffer2
    fill!(ub_term, zero(T))
    fill!(lb_term, zero(T))

    # zᵘ on upper-bounded entries, times (X - Xˡ) where a lower bound also exists.
    ub_term[kkt.ind_ub] .= kkt.u_lower
    ub_term[kkt.ind_lb] .*= kkt.l_diag

    # zˡ on lower-bounded entries, times (Xᵘ - X) where an upper bound also exists.
    lb_term[kkt.ind_lb] .= kkt.l_lower
    lb_term[kkt.ind_ub] .*= kkt.u_diag

    # Scaling is 1 on free entries and the product of the square-rooted
    # bound slacks elsewhere.
    scaling = kkt.scaling_factor
    fill!(scaling, one(T))
    scaling[kkt.ind_lb] .*= sqrt.(kkt.l_diag)
    scaling[kkt.ind_ub] .*= sqrt.(kkt.u_diag)

    # Bound terms plus the regularization scaled by the squared scaling factor.
    kkt.pr_diag .= (ub_term .+ lb_term) .+ kkt.reg .* scaling.^2
    return
end


# Robust restoration
function set_aug_RR!(kkt::AbstractKKTSystem, solver::MadNLPSolver, RR::RobustRestorer)
x = full(solver.x)
Expand All @@ -48,7 +83,23 @@ function set_aug_RR!(kkt::AbstractKKTSystem, solver::MadNLPSolver, RR::RobustRes
kkt.u_diag .= solver.x_ur .- solver.xu_r

_set_aug_diagonal!(kkt)
return
end

# Set the diagonal data of the scaled KKT system for the robust restoration
# phase.
#
# The primal regularization is `zeta * D_R^2` and the dual diagonal is
# `-pp/zp - nn/zn`, both read from the restorer `RR`. Bound multipliers and
# bound slacks are taken from `solver` with the same sign convention as
# `set_aug_diagonal!` for this KKT type (`x - xl`, `xu - x`), after which the
# scaled diagonal is rebuilt.
function set_aug_RR!(kkt::ScaledSparseKKTSystem, solver::MadNLPSolver, RR::RobustRestorer)
    kkt.reg .= RR.zeta .* RR.D_R .^ 2
    kkt.du_diag .= .- RR.pp ./ RR.zp .- RR.nn ./ RR.zn

    copyto!(kkt.l_lower, solver.zl_r)
    copyto!(kkt.u_lower, solver.zu_r)
    kkt.l_diag .= solver.x_lr .- solver.xl_r  # (X - Xˡ)
    kkt.u_diag .= solver.xu_r .- solver.x_ur  # (Xᵘ - X)

    _set_aug_diagonal!(kkt)
    return
end

Expand Down
1 change: 1 addition & 0 deletions src/KKT/KKTsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,5 +243,6 @@ include("Dense/utils.jl")
include("Sparse/unreduced.jl")
include("Sparse/augmented.jl")
include("Sparse/condensed.jl")
include("Sparse/scaled_augmented.jl")
include("Sparse/utils.jl")

243 changes: 243 additions & 0 deletions src/KKT/Sparse/scaled_augmented.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
"""
    ScaledSparseKKTSystem{T, VT, MT, QN} <: AbstractReducedKKTSystem{T, VT, MT, QN}

Scaled version of the [`AbstractReducedKKTSystem`](@ref) (using the K2.5 formulation introduced in [GOS]).

The K2.5 formulation of the augmented KKT system is better conditioned
than the original (K2) formulation. Switching to a `ScaledSparseKKTSystem`
is recommended if you encounter numerical difficulties in MadNLP.

At a primal-dual iterate ``(x, s, y, z)``, the linear system is defined by
the matrix (left) applied to the scaled unknowns (right):
```
[√Ξₓ Wₓₓ √Ξₓ + Δₓ   0      √Ξₓ Aₑ'   √Ξₓ Aᵢ']  [√Ξₓ⁻¹ Δx]
[0                  Δₛ     0         -√Ξₛ   ]  [√Ξₛ⁻¹ Δs]
[Aₑ√Ξₓ              0      0         0      ]  [Δy      ]
[Aᵢ√Ξₓ              -√Ξₛ   0         0      ]  [Δz      ]
```
with
* ``Wₓₓ``: Hessian of the Lagrangian.
* ``Aₑ``: Jacobian of the equality constraints
* ``Aᵢ``: Jacobian of the inequality constraints
* ``Δₓ = Xᵤ Zₗˣ + Xₗ Zᵤˣ``
* ``Δₛ = Sᵤ Zₗˢ + Sₗ Zᵤˢ``
* ``Ξₓ = Xₗ Xᵤ``
* ``Ξₛ = Sₗ Sᵤ``

Here ``Xₗ`` and ``Xᵤ`` (resp. ``Sₗ`` and ``Sᵤ``) denote the distances of the
iterate to its lower and upper bounds; they are stored in `l_diag` and `u_diag`.

# References
[GOS] Ghannad, Alexandre, Dominique Orban, and Michael A. Saunders.
"Linear systems arising in interior methods for convex optimization: a symmetric formulation with bounded condition number."
Optimization Methods and Software 37, no. 4 (2022): 1344-1369.

"""
struct ScaledSparseKKTSystem{T, VT, MT, QN, LS, VI, VI32} <: AbstractReducedKKTSystem{T, VT, MT, QN}
# Hessian nonzeros (wrapped from the value vector of `aug_raw` in `create_kkt_system`).
hess::VT
# Leading part of `jac` holding only the entries returned by the Jacobian callback.
jac_callback::VT
# Jacobian nonzeros followed by one entry per slack variable.
jac::VT
# Quasi-Newton approximation object built by `create_quasi_newton`.
quasi_newton::QN
# Unscaled primal regularization.
reg::VT
# Primal diagonal of the scaled system (wrapped from the value vector of `aug_raw`).
pr_diag::VT
# Dual diagonal (wrapped from the value vector of `aug_raw`).
du_diag::VT
# Distance to the lower bounds (one entry per index in `ind_lb`).
l_diag::VT
# Distance to the upper bounds (one entry per index in `ind_ub`).
u_diag::VT
# Lower-bound multipliers.
l_lower::VT
# Upper-bound multipliers.
u_lower::VT
# Per-variable scaling: 1, multiplied by sqrt(l_diag) / sqrt(u_diag) on bounded entries.
scaling_factor::VT
# Work arrays with one entry per primal variable (including slacks).
buffer1::VT
buffer2::VT
# Augmented system
# `aug_raw` holds the unscaled COO values, `scaled_aug_raw` the scaled values
# that are transferred into `aug_com` by `build_kkt!`.
aug_raw::SparseMatrixCOO{T,Int32,VT, VI32}
aug_com::MT
aug_csc_map::Union{Nothing, VI}
scaled_aug_raw::SparseMatrixCOO{T,Int32,VT, VI32}
# Hessian
hess_raw::SparseMatrixCOO{T,Int32,VT, VI32}
hess_com::MT
hess_csc_map::Union{Nothing, VI}
# Jacobian
jac_raw::SparseMatrixCOO{T,Int32,VT, VI32}
jac_com::MT
jac_csc_map::Union{Nothing, VI}
# LinearSolver
linear_solver::LS
# Info
# Indices of the inequality constraints and of the lower-/upper-bounded variables.
ind_ineq::VI
ind_lb::VI
ind_ub::VI
end

# Build KKT system directly from SparseCallback
#
# The augmented matrix is assembled in COO format. Its value vector `V` is
# laid out in five consecutive segments:
#
#   1. primal diagonal            (n_tot entries)
#   2. Hessian nonzeros           (n_hess entries)
#   3. Jacobian callback nonzeros (n_jac entries)
#   4. slack Jacobian entries     (n_slack entries)
#   5. dual diagonal              (m entries)
#
# `pr_diag`, `hess`, `jac`, `jac_callback` and `du_diag` are wrapped from `V`
# with `_madnlp_unsafe_wrap` at the matching offsets (presumably without
# copying, so that writing to them updates `aug_raw` — confirm against the
# helper). `scaled_aug_raw` reuses the same index vectors `I`/`J` but owns a
# separate copy of the values.
#
# Keyword arguments:
#   - `opt_linear_solver`: options passed to the linear solver constructor.
#   - `hessian_approximation`: Hessian approximation type (default `ExactHessian`).
function create_kkt_system(
    ::Type{ScaledSparseKKTSystem},
    cb::SparseCallback{T,VT},
    ind_cons,
    linear_solver::Type;
    opt_linear_solver=default_options(linear_solver),
    hessian_approximation=ExactHessian,
) where {T,VT}
    n = cb.nvar
    m = cb.ncon

    # Evaluate sparsity pattern
    jac_sparsity_I = create_array(cb, Int32, cb.nnzj)
    jac_sparsity_J = create_array(cb, Int32, cb.nnzj)
    _jac_sparsity_wrapper!(cb, jac_sparsity_I, jac_sparsity_J)

    quasi_newton = create_quasi_newton(hessian_approximation, cb, n)
    hess_sparsity_I, hess_sparsity_J = build_hessian_structure(cb, hessian_approximation)
    force_lower_triangular!(hess_sparsity_I, hess_sparsity_J)

    ind_ineq = ind_cons.ind_ineq
    nlb = length(ind_cons.ind_lb)
    nub = length(ind_cons.ind_ub)

    # Deduce KKT size.
    n_slack = length(ind_ineq)
    n_jac = length(jac_sparsity_I)
    n_hess = length(hess_sparsity_I)
    n_tot = n + n_slack

    aug_vec_length = n_tot + m
    aug_mat_length = n_tot + m + n_hess + n_jac + n_slack

    # Number of entries preceding each segment of the COO arrays.
    hess_start = n_tot
    jac_start = hess_start + n_hess
    slack_start = jac_start + n_jac
    du_start = slack_start + n_slack

    pr_range = 1:n_tot
    hess_range = (hess_start + 1):(hess_start + n_hess)
    jac_range = (jac_start + 1):(jac_start + n_jac)
    slack_range = (slack_start + 1):(slack_start + n_slack)
    du_range = (du_start + 1):aug_mat_length

    I = create_array(cb, Int32, aug_mat_length)
    J = create_array(cb, Int32, aug_mat_length)
    V = VT(undef, aug_mat_length)
    fill!(V, zero(T)) # Need to initiate V to avoid NaN

    # Row indices: constraint rows are shifted below the primal block.
    I[pr_range] .= 1:n_tot
    I[hess_range] = hess_sparsity_I
    I[jac_range] .= jac_sparsity_I .+ n_tot
    I[slack_range] .= ind_ineq .+ n_tot
    I[du_range] .= (n_tot + 1):(n_tot + m)

    # Column indices: slack columns follow the n original variables.
    J[pr_range] .= 1:n_tot
    J[hess_range] = hess_sparsity_J
    J[jac_range] .= jac_sparsity_J
    J[slack_range] .= (n + 1):(n + n_slack)
    J[du_range] .= (n_tot + 1):(n_tot + m)

    pr_diag = _madnlp_unsafe_wrap(V, n_tot)
    du_diag = _madnlp_unsafe_wrap(V, m, du_start + 1)

    # Allocated uninitialized here; `initialize!` fills them.
    reg = VT(undef, n_tot)
    l_diag = VT(undef, nlb)
    u_diag = VT(undef, nub)
    l_lower = VT(undef, nlb)
    u_lower = VT(undef, nub)

    scaling_factor = VT(undef, n_tot)
    buffer1 = VT(undef, n_tot)
    buffer2 = VT(undef, n_tot)

    hess = _madnlp_unsafe_wrap(V, n_hess, hess_start + 1)
    # `jac` spans the callback entries plus the slack entries;
    # `jac_callback` is its leading part.
    jac = _madnlp_unsafe_wrap(V, n_jac + n_slack, jac_start + 1)
    jac_callback = _madnlp_unsafe_wrap(V, n_jac, jac_start + 1)

    aug_raw = SparseMatrixCOO(aug_vec_length, aug_vec_length, I, J, V)
    jac_raw = SparseMatrixCOO(
        m, n_tot,
        Int32[jac_sparsity_I; ind_ineq],
        Int32[jac_sparsity_J; n+1:n+n_slack],
        jac,
    )
    hess_raw = SparseMatrixCOO(
        n_tot, n_tot,
        hess_sparsity_I,
        hess_sparsity_J,
        hess,
    )
    scaled_aug_raw = SparseMatrixCOO(aug_vec_length, aug_vec_length, I, J, copy(V))

    aug_com, aug_csc_map = coo_to_csc(aug_raw)
    jac_com, jac_csc_map = coo_to_csc(jac_raw)
    hess_com, hess_csc_map = coo_to_csc(hess_raw)

    _linear_solver = linear_solver(
        aug_com; opt = opt_linear_solver
    )

    return ScaledSparseKKTSystem(
        hess, jac_callback, jac, quasi_newton, reg, pr_diag, du_diag,
        l_diag, u_diag, l_lower, u_lower,
        scaling_factor, buffer1, buffer2,
        aug_raw, aug_com, aug_csc_map, scaled_aug_raw,
        hess_raw, hess_com, hess_csc_map,
        jac_raw, jac_com, jac_csc_map,
        _linear_solver,
        ind_ineq, ind_cons.ind_lb, ind_cons.ind_ub,
    )
end

# Number of primal variables, taken as the length of the primal diagonal.
function num_variables(kkt::ScaledSparseKKTSystem)
    return length(kkt.pr_diag)
end
# Return the buffer holding the Jacobian entries written by the callback.
function get_jacobian(kkt::ScaledSparseKKTSystem)
    return kkt.jac_callback
end

# Reset the KKT buffers to their initial values: ones for the regularization,
# primal diagonal, bound slacks and scaling; zeros for the dual diagonal,
# Hessian values and bound multipliers.
function initialize!(kkt::ScaledSparseKKTSystem{T}) where T
    for buf in (kkt.reg, kkt.pr_diag, kkt.l_diag, kkt.u_diag, kkt.scaling_factor)
        fill!(buf, one(T))
    end
    for buf in (kkt.du_diag, kkt.hess, kkt.l_lower, kkt.u_lower)
        fill!(buf, zero(T))
    end
    # so that mul! in the initial primal-dual solve has no effect
    return fill!(nonzeros(kkt.hess_com), zero(T))
end

# Store in `y` the product of the adjoint of the compressed Jacobian with `x`.
function jtprod!(y::AbstractVector, kkt::ScaledSparseKKTSystem, x::AbstractVector)
    return mul!(y, adjoint(kkt.jac_com), x)
end

# Set the slack entries at the tail of `jac` to -1, then transfer the COO
# Jacobian into its compressed form.
function compress_jacobian!(kkt::ScaledSparseKKTSystem)
    n_slack = length(kkt.ind_ineq)
    last_idx = lastindex(kkt.jac)
    kkt.jac[(last_idx - n_slack + 1):last_idx] .= -1.0
    return transfer!(kkt.jac_com, kkt.jac_raw, kkt.jac_csc_map)
end

# Transfer the COO Hessian values into the compressed Hessian matrix.
function compress_hessian!(kkt::ScaledSparseKKTSystem)
    return transfer!(kkt.hess_com, kkt.hess_raw, kkt.hess_csc_map)
end

# N.B. Matrices are assumed to have an augmented KKT structure and be lower-triangular.
#
# Write into `dest.V` the values of `src.V` with the primal scaling applied:
#   - the first `n` entries (primal diagonal) and the dual-dual block are copied;
#   - Hessian entries are multiplied by `scaling[i] * scaling[j]`;
#   - Jacobian entries are multiplied by `scaling[j]`.
# Entries matching none of these cases are left untouched in `dest`.
function _build_scale_augmented_system_coo!(dest, src, scaling, n, m)
    is_dual(idx) = n + 1 <= idx <= n + m
    for (k, row, col) in zip(1:nnz(src), src.I, src.J)
        if k <= n || (is_dual(row) && is_dual(col))
            # Primal regularization pr_diag / dual regularization du_diag
            dest.V[k] = src.V[k]
        elseif row <= n && col <= n
            # Hessian block
            dest.V[k] = src.V[k] * scaling[row] * scaling[col]
        elseif is_dual(row) && col <= n
            # Jacobian block
            dest.V[k] = src.V[k] * scaling[col]
        end
    end
    return nothing
end

# Assemble the scaled augmented matrix: scale the raw COO values into
# `scaled_aug_raw`, then transfer them into the compressed matrix `aug_com`.
function build_kkt!(kkt::ScaledSparseKKTSystem)
    # The Jacobian has one row per constraint and one column per primal variable.
    n_dual, n_primal = size(kkt.jac_raw)
    _build_scale_augmented_system_coo!(
        kkt.scaled_aug_raw, kkt.aug_raw, kkt.scaling_factor, n_primal, n_dual,
    )
    return transfer!(kkt.aug_com, kkt.scaled_aug_raw, kkt.aug_csc_map)
end

# Apply inertia-correction style regularization to the diagonals.
#
# `primal` is accumulated into the unscaled `reg` and, multiplied by the
# squared scaling factor, into the scaled primal diagonal. The dual diagonal
# is overwritten with `-dual`.
function regularize_diagonal!(kkt::ScaledSparseKKTSystem, primal, dual)
    @. kkt.reg += primal
    kkt.pr_diag .+= primal .* kkt.scaling_factor.^2
    return kkt.du_diag .= .-dual
end

1 change: 1 addition & 0 deletions test/kkt_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ end
(MadNLP.SparseKKTSystem, MadNLP.SparseCallback),
(MadNLP.SparseUnreducedKKTSystem, MadNLP.SparseCallback),
(MadNLP.SparseCondensedKKTSystem, MadNLP.SparseCallback),
(MadNLP.ScaledSparseKKTSystem, MadNLP.SparseCallback),
(MadNLP.DenseKKTSystem, MadNLP.DenseCallback),
(MadNLP.DenseCondensedKKTSystem, MadNLP.DenseCallback),
]
Expand Down
Loading