From 3828e84c53c864a159ea99f514033fe1dacdec6e Mon Sep 17 00:00:00 2001 From: Pedro Maciel Xavier Date: Wed, 11 Sep 2024 21:25:33 -0400 Subject: [PATCH] Implement `map_variables` + Add tests --- src/library/model/model.jl | 10 +++++++ src/library/model/variable_map.jl | 18 +++++++++++++ test/unit/library/formats.jl | 44 ++++++++++++++++++++++--------- 3 files changed, 59 insertions(+), 13 deletions(-) diff --git a/src/library/model/model.jl b/src/library/model/model.jl index 5d403a9..9fe1526 100644 --- a/src/library/model/model.jl +++ b/src/library/model/model.jl @@ -344,3 +344,13 @@ function Model{V,T,U}(f::F; kws...) where {V,T,U,F<:PBO.AbstractFunction{V,T}} return Model{V,T,U}(L, Q; offset = β, sense = :min, domain = :bool, kws...) end + +function map_variables(::Type{V}, vm::Function, model::AbstractModel{_,T,U}) where {_,V,T,U} + new_model = copy(model)::AbstractModel{V,T,U} + new_model.variable_map = VariableMap{V}(Dict{Int,V}(i => vm(i)::V for i in indices(model))) + + return new_model +end + +map_variables(vm::Dict{Int,V}, model::AbstractModel) where {V} = map_variables(V, i -> vm[i], model) +map_variables(vm::AbstractVector{V}, model::AbstractModel) where {V} = map_variables(V, i -> vm[i], model) diff --git a/src/library/model/variable_map.jl b/src/library/model/variable_map.jl index 14af997..0441cb6 100644 --- a/src/library/model/variable_map.jl +++ b/src/library/model/variable_map.jl @@ -17,6 +17,24 @@ struct VariableMap{V} end end +function VariableMap{V}(vm::Dict{Int,V}) where {V} + map = sizehint!(Dict{V,Int}(), length(vm)) + inv = Vector{V}(undef, length(vm)) + + for i = 1:length(vm) + if !haskey(vm, i) + error("Invalid variable mapping: Mappings should contain values for all indices") + else + let v = vm[i] + map[v] = i + inv[i] = v + end + end + end + + return VariableMap{V}(map, inv) +end + function VariableMap{V}( variables::X, ) where {V,X<:Union{AbstractVector{V},AbstractSet{V}}} diff --git a/test/unit/library/formats.jl b/test/unit/library/formats.jl index 5ea8322..7f63397 100644 --- a/test/unit/library/formats.jl +++ b/test/unit/library/formats.jl @@ -30,16 +30,19 @@ function test_bqpjson_format() @testset "⋅ BQPJSON" begin @testset "bool" begin for i = 0:2 - file_path = joinpath(__TEST_PATH__, "data", Printf.@sprintf("%02d", i), "bool.json") + file_path = + joinpath(__TEST_PATH__, "data", Printf.@sprintf("%02d", i), "bool.json") temp_path = "$(tempname()).bool.json" src_model = QUBOTools.read_model(file_path) + variables = QUBOTools.variables(src_model) @test src_model isa QUBOTools.Model QUBOTools.write_model(temp_path, src_model) - dst_model = QUBOTools.read_model(temp_path) + dst_model = + QUBOTools.map_variables(variables, QUBOTools.read_model(temp_path)) @test dst_model isa QUBOTools.Model @@ -49,16 +52,19 @@ function test_bqpjson_format() @testset "spin" begin for i = 0:2 - file_path = joinpath(__TEST_PATH__, "data", Printf.@sprintf("%02d", i), "spin.json") + file_path = + joinpath(__TEST_PATH__, "data", Printf.@sprintf("%02d", i), "spin.json") temp_path = "$(tempname()).spin.json" src_model = QUBOTools.read_model(file_path) + variables = QUBOTools.variables(src_model) @test src_model isa QUBOTools.Model QUBOTools.write_model(temp_path, src_model) - dst_model = QUBOTools.read_model(temp_path) + dst_model = + QUBOTools.map_variables(variables, QUBOTools.read_model(temp_path)) @test dst_model isa QUBOTools.Model @@ -75,17 +81,22 @@ function test_qubo_format() src_fmt = QUBOTools.QUBO(:dwave) for i = 0:2 - file_path = joinpath(__TEST_PATH__, "data", Printf.@sprintf("%02d", i), "bool.qubo") + file_path = + joinpath(__TEST_PATH__, "data", Printf.@sprintf("%02d", i), "bool.qubo") temp_path = "$(tempname()).bool.qubo" src_model = QUBOTools.read_model(file_path, src_fmt) + variables = QUBOTools.variables(src_model) @test src_model isa QUBOTools.Model for dst_fmt in QUBOTools.QUBO.([:dwave, :mqlib]) QUBOTools.write_model(temp_path, src_model, dst_fmt) - dst_model = QUBOTools.read_model(temp_path, dst_fmt) + dst_model = QUBOTools.map_variables( + variables, + QUBOTools.read_model(temp_path, dst_fmt), + ) @test dst_model isa QUBOTools.Model @@ -100,16 +111,18 @@ end function test_qubist_format() @testset "⋅ Qubist" begin for i = 0:2 - file_path = joinpath(__TEST_PATH__, "data", Printf.@sprintf("%02d", i), "spin.qh") + file_path = + joinpath(__TEST_PATH__, "data", Printf.@sprintf("%02d", i), "spin.qh") temp_path = "$(tempname()).spin.qh" src_model = QUBOTools.read_model(file_path) + variables = QUBOTools.variables(src_model) @test src_model isa QUBOTools.Model QUBOTools.write_model(temp_path, src_model) - dst_model = QUBOTools.read_model(temp_path) + dst_model = QUBOTools.map_variables(variables, QUBOTools.read_model(temp_path)) @test dst_model isa QUBOTools.Model @@ -124,16 +137,19 @@ function test_qubin_format() @testset "⋅ QUBin" begin @testset "bool" begin for i = 0:2 - file_path = joinpath(__TEST_PATH__, "data", Printf.@sprintf("%02d", i), "bool.qb") + file_path = + joinpath(__TEST_PATH__, "data", Printf.@sprintf("%02d", i), "bool.qb") temp_path = "$(tempname()).bool.qb" src_model = QUBOTools.read_model(file_path) + variables = QUBOTools.variables(src_model) @test src_model isa QUBOTools.Model QUBOTools.write_model(temp_path, src_model) - dst_model = QUBOTools.read_model(temp_path) + dst_model = + QUBOTools.map_variables(variables, QUBOTools.read_model(temp_path)) @test dst_model isa QUBOTools.Model @@ -148,12 +164,14 @@ function test_qubin_format() temp_path = "$(tempname()).spin.qb" src_model = QUBOTools.read_model(file_path) + variables = QUBOTools.variables(src_model) @test src_model isa QUBOTools.Model QUBOTools.write_model(temp_path, src_model) - dst_model = QUBOTools.read_model(temp_path) + dst_model = + QUBOTools.map_variables(variables, QUBOTools.read_model(temp_path)) @test dst_model isa QUBOTools.Model @@ -175,7 +193,7 @@ function test_minizinc_format() (1, 3) => -13.0, (2, 3) => -23.0, ); - scale = 2.0, + scale = 2.0, offset = -1.0, sense = :min, domain = :bool, @@ -206,7 +224,7 @@ function test_minizinc_format() (1, 3) => -13.0, (2, 3) => -23.0, ); - scale = 2.0, + scale = 2.0, offset = -1.0, sense = :max, domain = :spin,