Skip to content

Commit

Permalink
Implement map_variables + Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pedromxavier committed Sep 12, 2024
1 parent 5575b9e commit 3828e84
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 13 deletions.
10 changes: 10 additions & 0 deletions src/library/model/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -344,3 +344,13 @@ function Model{V,T,U}(f::F; kws...) where {V,T,U,F<:PBO.AbstractFunction{V,T}}

return Model{V,T,U}(L, Q; offset = β, sense = :min, domain = :bool, kws...)
end

function map_variables(::Type{V}, vm::Function, model::AbstractModel{_,T,U}) where {_,V,T,U}
new_model = copy(model)::AbstractModel{V,T,U}
new_model.variable_map = VariableMap{V}(Dict{Int,V}(i => vm(i)::V for i in indices(model)))

return new_model
end

map_variables(vm::Dict{Int,V}, model::AbstractModel) where {V} = map_variables(V, i -> vm[i], model)
map_variables(vm::AbstractVector{V}, model::AbstractModel) where {V} = map_variables(V, i -> vm[i], model)
18 changes: 18 additions & 0 deletions src/library/model/variable_map.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,24 @@ struct VariableMap{V}
end
end

function VariableMap{V}(vm::Dict{Int,V}) where {V}
map = sizehint!(Dict{V,Int}(), length(vm))
inv = Vector{V}(undef, length(vm))

for i = 1:length(vm)
if !haskey(vm, i)
error("Invalid variable mapping: Mappings should contain values for all indices")
else
let v = vm[i]
map[v] = i
inv[i] = v
end
end
end

return VariableMap{V}(map, inv)
end

function VariableMap{V}(
variables::X,
) where {V,X<:Union{AbstractVector{V},AbstractSet{V}}}
Expand Down
44 changes: 31 additions & 13 deletions test/unit/library/formats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,19 @@ function test_bqpjson_format()
@testset "⋅ BQPJSON" begin
@testset "bool" begin
for i = 0:2
file_path = joinpath(__TEST_PATH__, "data", Printf.@sprintf("%02d", i), "bool.json")
file_path =
joinpath(__TEST_PATH__, "data", Printf.@sprintf("%02d", i), "bool.json")
temp_path = "$(tempname()).bool.json"

src_model = QUBOTools.read_model(file_path)
variables = QUBOTools.variables(src_model)

@test src_model isa QUBOTools.Model

QUBOTools.write_model(temp_path, src_model)

dst_model = QUBOTools.read_model(temp_path)
dst_model =
QUBOTools.map_variables(variables, QUBOTools.read_model(temp_path))

@test dst_model isa QUBOTools.Model

Expand All @@ -49,16 +52,19 @@ function test_bqpjson_format()

@testset "spin" begin
for i = 0:2
file_path = joinpath(__TEST_PATH__, "data", Printf.@sprintf("%02d", i), "spin.json")
file_path =
joinpath(__TEST_PATH__, "data", Printf.@sprintf("%02d", i), "spin.json")
temp_path = "$(tempname()).spin.json"

src_model = QUBOTools.read_model(file_path)
variables = QUBOTools.variables(src_model)

@test src_model isa QUBOTools.Model

QUBOTools.write_model(temp_path, src_model)

dst_model = QUBOTools.read_model(temp_path)
dst_model =
QUBOTools.map_variables(variables, QUBOTools.read_model(temp_path))

@test dst_model isa QUBOTools.Model

Expand All @@ -75,17 +81,22 @@ function test_qubo_format()
src_fmt = QUBOTools.QUBO(:dwave)

for i = 0:2
file_path = joinpath(__TEST_PATH__, "data", Printf.@sprintf("%02d", i), "bool.qubo")
file_path =
joinpath(__TEST_PATH__, "data", Printf.@sprintf("%02d", i), "bool.qubo")
temp_path = "$(tempname()).bool.qubo"

src_model = QUBOTools.read_model(file_path, src_fmt)
variables = QUBOTools.variables(src_model)

@test src_model isa QUBOTools.Model

for dst_fmt in QUBOTools.QUBO.([:dwave, :mqlib])
QUBOTools.write_model(temp_path, src_model, dst_fmt)

dst_model = QUBOTools.read_model(temp_path, dst_fmt)
dst_model = QUBOTools.map_variables(
variables,
QUBOTools.read_model(temp_path, dst_fmt),
)

@test dst_model isa QUBOTools.Model

Expand All @@ -100,16 +111,18 @@ end
function test_qubist_format()
@testset "⋅ Qubist" begin
for i = 0:2
file_path = joinpath(__TEST_PATH__, "data", Printf.@sprintf("%02d", i), "spin.qh")
file_path =
joinpath(__TEST_PATH__, "data", Printf.@sprintf("%02d", i), "spin.qh")
temp_path = "$(tempname()).spin.qh"

src_model = QUBOTools.read_model(file_path)
variables = QUBOTools.variables(src_model)

@test src_model isa QUBOTools.Model

QUBOTools.write_model(temp_path, src_model)

dst_model = QUBOTools.read_model(temp_path)
dst_model = QUBOTools.map_variables(variables, QUBOTools.read_model(temp_path))

@test dst_model isa QUBOTools.Model

Expand All @@ -124,16 +137,19 @@ function test_qubin_format()
@testset "⋅ QUBin" begin
@testset "bool" begin
for i = 0:2
file_path = joinpath(__TEST_PATH__, "data", Printf.@sprintf("%02d", i), "bool.qb")
file_path =
joinpath(__TEST_PATH__, "data", Printf.@sprintf("%02d", i), "bool.qb")
temp_path = "$(tempname()).bool.qb"

src_model = QUBOTools.read_model(file_path)
variables = QUBOTools.variables(src_model)

@test src_model isa QUBOTools.Model

QUBOTools.write_model(temp_path, src_model)

dst_model = QUBOTools.read_model(temp_path)
dst_model =
QUBOTools.map_variables(variables, QUBOTools.read_model(temp_path))

@test dst_model isa QUBOTools.Model

Expand All @@ -148,12 +164,14 @@ function test_qubin_format()
temp_path = "$(tempname()).spin.qb"

src_model = QUBOTools.read_model(file_path)
variables = QUBOTools.variables(src_model)

@test src_model isa QUBOTools.Model

QUBOTools.write_model(temp_path, src_model)

dst_model = QUBOTools.read_model(temp_path)
dst_model =
QUBOTools.map_variables(variables, QUBOTools.read_model(temp_path))

@test dst_model isa QUBOTools.Model

Expand All @@ -175,7 +193,7 @@ function test_minizinc_format()
(1, 3) => -13.0,
(2, 3) => -23.0,
);
scale = 2.0,
scale = 2.0,
offset = -1.0,
sense = :min,
domain = :bool,
Expand Down Expand Up @@ -206,7 +224,7 @@ function test_minizinc_format()
(1, 3) => -13.0,
(2, 3) => -23.0,
);
scale = 2.0,
scale = 2.0,
offset = -1.0,
sense = :max,
domain = :spin,
Expand Down

0 comments on commit 3828e84

Please sign in to comment.