add pr_mixed, support_mixed, and dir for autodiff_pr
rtjoa committed Dec 24, 2023
1 parent 9808a41 commit ccd8903
Showing 10 changed files with 71 additions and 76 deletions.
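
At a high level, this commit moves the concrete-flip helpers out of examples/qc/stlc/lib/util.jl and src/dist/number/uint.jl into a new src/autodiff_pr/ directory and exports them as pr_mixed, support_mixed, and with_concrete_ad_flips. A minimal usage sketch of the new exports, based on the code added below (the names x, var_vals, and b are illustrative, not taken from the diff):

using Dice

# One learnable parameter driving a flip; sigmoid(0) = 0.5.
x = Var("x")
var_vals = Valuation(x => 0)
b = flip(sigmoid(x)) & flip(0.5)

# pr_mixed(var_vals) returns a closure that temporarily substitutes the concrete
# values from var_vals for any ADNode flip probabilities, calls pr, then restores
# the symbolic probabilities.
pr_mixed(var_vals)(b)[true]   # 0.25

# support_mixed ignores the learned probabilities (ADNode probs are set to 0.5)
# and reports which values are possible at all.
support_mixed(b)              # the possible values of b

# with_concrete_ad_flips is the underlying helper; the do-block runs while the
# flips carry concrete probabilities.
with_concrete_ad_flips(var_vals, b) do
    pr(b)
end
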
6 changes: 2 additions & 4 deletions examples/darts.jl
@@ -21,13 +21,11 @@ all_colors_receive_own_dart = all(
for (num_own_darts, weight) in zip(DARTS_PER_COLOR, weights)
)

pr_all_colors_receive_own_dart = exp(LogPr(all_colors_receive_own_dart))

compute_mixed(var_vals, pr_all_colors_receive_own_dart) # 0.182
pr_mixed(var_vals)(all_colors_receive_own_dart) # 0.182
train!(var_vals, -LogPr(all_colors_receive_own_dart); epochs=1000, learning_rate=0.3)

# We've increased the chance of success!
compute_mixed(var_vals, pr_all_colors_receive_own_dart) # 0.234
pr_mixed(var_vals)(all_colors_receive_own_dart) # 0.234

# Compute what ratio we actually need to paint the target:
[compute(var_vals, weight/sum(weights)) for weight in weights]
32 changes: 0 additions & 32 deletions examples/qc/stlc/lib/util.jl
@@ -151,35 +151,3 @@ function println_flush(io, args...)
println(io, args...)
flush(io)
end

function collect_flips(bools)
flips = Vector{Dice.Flip}()
Dice.foreach_down(bools) do x
x isa Dice.Flip && push!(flips, x)
end
flips
end

function with_concrete_flips(f, var_vals, dist)
flips = collect_flips(Dice.tobits(dist))
flip_to_original_prob = Dict()
a = ADComputer(var_vals)
for x in flips
if x.prob isa ADNode
flip_to_original_prob[x] = x.prob
x.prob = compute(a, x.prob)
end
end
res = f()
# restore
for (x, prob) in flip_to_original_prob
x.prob = prob
end
res
end

function pr_with_concrete_flips(var_vals, dist)
with_concrete_flips(var_vals, dist) do
pr(dist)
end
end
8 changes: 4 additions & 4 deletions examples/qc/stlc/main.jl
@@ -106,13 +106,13 @@ show(io, Dict(s => vals[adnode] for (s, adnode) in adnodes_of_interest))
println(io)

println_flush(io, "Inferring initial distribution...")
time_infer_init = @elapsed metric_dist = pr_with_concrete_flips(var_vals, metric)
time_infer_init = @elapsed metric_dist = pr_mixed(var_vals)(metric)
println(io, " $(time_infer_init) seconds")
save_metric_dist(joinpath(OUT_DIR, "dist_before.csv"), METRIC, metric_dist; io=io)
println(io)

println_flush(io, "Saving samples...")
time_sample_init = @elapsed with_concrete_flips(var_vals, e) do
time_sample_init = @elapsed with_concrete_ad_flips(var_vals, e) do
save_samples(joinpath(OUT_DIR, "terms_before.txt"), e; io=io)
end
println(io, " $(time_sample_init) seconds")
@@ -158,12 +158,12 @@ show(io, Dict(s => vals[adnode] for (s, adnode) in adnodes_of_interest))
println(io)

println(io, "Inferring trained distribution...")
time_infer_final = @elapsed metric_dist_after = pr_with_concrete_flips(var_vals, metric)
time_infer_final = @elapsed metric_dist_after = pr_mixed(var_vals)(metric)
save_metric_dist(joinpath(OUT_DIR, "dist_trained_" * OUT_FILE_TAG * ".csv"), METRIC, metric_dist_after; io=io)
println(io)

println(io, "Saving samples...")
time_sample_final = @elapsed with_concrete_flips(var_vals, e) do
time_sample_final = @elapsed with_concrete_ad_flips(var_vals, e) do
save_samples(joinpath(OUT_DIR, "terms_trained_" * OUT_FILE_TAG * ".txt"), e; io=io)
end
println(io, " $(time_sample_final) seconds")
2 changes: 2 additions & 0 deletions src/Dice.jl
@@ -35,6 +35,8 @@ include("autodiff/adnode.jl")
include("autodiff/core.jl")
include("dist/dist.jl")
include("inference/inference.jl")
include("autodiff_pr/train.jl")
include("autodiff_pr/losses.jl")
include("analysis/analysis.jl")
include("dsl.jl")
include("plot.jl")
File renamed without changes.
48 changes: 46 additions & 2 deletions src/inference/train_pr.jl → src/autodiff_pr/train.jl
@@ -1,6 +1,5 @@
# The bridge between autodiff and cudd
export LogPr, compute_mixed, train!
using DataStructures: Queue
export LogPr, compute_mixed, train!, pr_mixed, support_mixed, with_concrete_ad_flips

mutable struct LogPr <: ADNode
bool::Dist{Bool}
@@ -93,3 +92,48 @@ function train!(
push!(losses, compute_mixed(var_vals, loss))
losses
end

function collect_flips(bools)
flips = Vector{Flip}()
foreach_down(bools) do x
x isa Flip && push!(flips, x)
end
flips
end

function with_concrete_ad_flips(f, var_vals, dist)
flip_to_original_prob = Dict()
a = ADComputer(var_vals)
for x in collect_flips(tobits(dist))
if x.prob isa ADNode
flip_to_original_prob[x] = x.prob
x.prob = compute(a, x.prob)
end
end
res = f()
for (x, prob) in flip_to_original_prob
x.prob = prob
end
res
end

function pr_mixed(var_vals)
(args...; kwargs...) -> with_concrete_ad_flips(var_vals, args...) do
pr(args...; kwargs...)
end
end

function support_mixed(dist)
flip_to_original_prob = Dict()
for x in collect_flips(tobits(dist))
if x.prob isa ADNode
flip_to_original_prob[x] = x.prob
x.prob = 0.5
end
end
res = keys(pr(dist))
for (x, prob) in flip_to_original_prob
x.prob = prob
end
res
end
30 changes: 1 addition & 29 deletions src/dist/number/uint.jl
@@ -420,31 +420,6 @@ function unif_obs(lo::DistUInt{W}, hi::DistUInt{W}) where W
x, (x >= lo) & (x <= hi)
end

function collect_flips(bools)
flips = Vector{Flip}()
Dice.foreach_down(bools) do x
x isa Flip && push!(flips, x)
end
flips
end

function with_arb_ad_flips(f, dist)
flips = collect_flips(tobits(dist))
flip_to_original_prob = Dict()
for x in flips
if x.prob isa ADNode
flip_to_original_prob[x] = x.prob
x.prob = 0.5
end
end
res = f()
# restore
for (x, prob) in flip_to_original_prob
x.prob = prob
end
res
end

# Uniform from 0 to hi, exclusive
function unif_half(hi::DistUInt{W})::DistUInt{W} where W
# max_hi = maxvalue(hi)
@@ -455,10 +430,7 @@ function unif_half(hi::DistUInt{W})::DistUInt{W} where W
# end

# note: # could use path cond too
support = with_arb_ad_flips(hi) do
keys(pr(hi))
end
prod = lcm([BigInt(x) for x in support if x != 0])
prod = lcm([BigInt(x) for x in support_mixed(hi) if x != 0])
u = uniform(DistUInt{ndigits(prod, base=2)}, 0, prod)
rem_trunc(u, hi)
end
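
A brief note on why the lcm trick above works (a sketch, assuming rem_trunc computes the truncated remainder): support_mixed(hi) lists every value hi can take while ignoring the learned flip probabilities, and prod, the lcm of those values, is an exact multiple of each of them. A draw u that is uniform on 0:prod-1 therefore remains uniform after reduction modulo any particular value of hi. For example:

# Plain Julia check, not Dice: lcm(3, 4) == 12, and 0:11 reduced mod 3 (or mod 4)
# hits each residue equally often, so the remainder stays uniform.
prod = lcm(3, 4)                                   # 12
[count(u -> u % 3 == r, 0:prod-1) for r in 0:2]    # [4, 4, 4]
[count(u -> u % 4 == r, 0:prod-1) for r in 0:3]    # [3, 3, 3, 3]
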
3 changes: 0 additions & 3 deletions src/inference/inference.jl
@@ -139,7 +139,4 @@ include("cudd/wmc.jl")
# - pr(::Dist, evidence=..., errors=...)
include("pr.jl")

include("train_pr.jl")
include("train_pr_losses.jl")

include("sample.jl")
File renamed without changes.
18 changes: 16 additions & 2 deletions test/inference/train_test.jl → test/autodiff_pr/train_test.jl
@@ -46,14 +46,28 @@ end
var_vals = Valuation(psp => 0)
b = @dice_ite if flip(prob) true else flip(prob) end
train!(var_vals, mle_loss([prob_equals(b, x) for x in dataset]); epochs=200, learning_rate=0.003)
p1 = compute_mixed(var_vals, LogPr(b))
p1 = pr_mixed(var_vals)(b)[true]

# Train for 100 epochs, twice
b = @dice_ite if flip(prob) true else flip(prob) end
var_vals = Valuation(psp => 0)
train!(var_vals, mle_loss([prob_equals(b, x) for x in dataset]); epochs=100, learning_rate=0.003)
train!(var_vals, mle_loss([prob_equals(b, x) for x in dataset]); epochs=100, learning_rate=0.003)
p2 = compute_mixed(var_vals, LogPr(b))
p2 = pr_mixed(var_vals)(b)[true]

@test p1 ≈ p2
end

@testset "interleaving" begin
x = Var("x")
prob = sigmoid(x)
prob2 = exp(LogPr(flip(prob) & flip(prob)))
loss = mle_loss([flip(prob2) & flip(prob2) & !flip(prob2)])
var_vals = Valuation(x => 0)
train!(var_vals, loss, epochs=2000, learning_rate=0.1)

# loss is minimized if prob2 is 2/3
# therefore, prob should be sqrt(2/3)
@test compute(var_vals, prob) ≈ sqrt(2/3)
@test compute_mixed(var_vals, loss) ≈ -log(2/3*2/3*1/3)
end
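
For reference, the expected values in this new test follow from a short calculation: the single observation is flip(prob2) & flip(prob2) & !flip(prob2), so mle_loss is minimized where prob2^2 * (1 - prob2) is maximized, i.e. at prob2 = 2/3; and since prob2 = exp(LogPr(flip(prob) & flip(prob))) = prob^2, the trained prob should be sqrt(2/3). A quick numeric sanity check (plain Julia, not part of the commit):

# f is the likelihood of the observed outcome as a function of prob2.
f(p) = p^2 * (1 - p)
argmax(f, 0:0.0001:1)    # ≈ 0.6667, the optimal prob2
sqrt(2/3)                # ≈ 0.8165, the expected trained prob
-log(2/3 * 2/3 * 1/3)    # ≈ 1.91, the expected final loss
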
