Commit

Merge pull request #150 from Juice-jl/qc6
Autodiff + Dice.jl!
guyvdbroeck authored Sep 1, 2023
2 parents d7f8f60 + 0d2a64a commit 64ecfe6
Showing 45 changed files with 1,610 additions and 849 deletions.
34 changes: 34 additions & 0 deletions examples/darts.jl
@@ -0,0 +1,34 @@
#==
"Dart target painting"
- We can paint a target using red, green, and blue
- Fixed numbers of red, green, and blue darts will randomly hit the target
- What proportion of paint colors maximizes the probability that at least one
dart of each color lands in its own color?
==#

using Dice
import Base: all, any

all(itr) = reduce(&, itr)
any(itr) = reduce(|, itr)

DARTS_PER_COLOR = [1, 2, 10] # number of red, green, and blue darts
weights = [var!("r", 1), var!("g", 1), var!("b", 1)]

all_colors_receive_own_dart = all(
any(flip(weight / sum(weights)) for _ in 1:num_own_darts)
for (num_own_darts, weight) in zip(DARTS_PER_COLOR, weights)
)

pr(all_colors_receive_own_dart) # 0.182
train_vars!([all_colors_receive_own_dart]; epochs=1000, learning_rate=0.3)

# We've increased the chance of success!
pr(all_colors_receive_own_dart) # 0.234

# Compute what ratio we actually need to paint the target:
[compute(weight/sum(weights)) for weight in weights]
# 3-element Vector{Float64}:
# 0.46536681058883267
# 0.3623861813855715
# 0.17224700802559573
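
For intuition, this probability has a closed form: multiply, over the three colors, the chance 1 - (1 - p)^n that at least one of that color's n darts lands on its paint proportion p. A quick plain-Julia check (no Dice involved; the hard-coded proportions are simply the values printed above) reproduces both reported numbers:

```julia
# Closed-form sanity check: P(success) = prod over colors of 1 - (1 - p)^n,
# where p is the paint proportion and n the number of darts of that color.
success(p, n) = prod(1 - (1 - pc)^nc for (pc, nc) in zip(p, n))

n = [1, 2, 10]                                    # red, green, blue darts
success(fill(1/3, 3), n)                          # ≈ 0.182, the pre-training value
success([0.46536681, 0.36238618, 0.17224701], n)  # ≈ 0.234, the post-training value
```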
2 changes: 2 additions & 0 deletions examples/qc/README.md
@@ -13,5 +13,7 @@ The following related programs are included. The expected output of each is in a
- Generator for closed untyped lambda calculus expressions ([`demo_utlc.jl`](demo_utlc.jl))
- Given a generator for UTLC exprs with a hole dependent on size, chooses probabilities such that the AST has near uniform depth
- 50 example generated expressions are visible at [`samples/utlc.txt`](samples/utlc.txt).
- Generator for well-typed, simply-typed lambda calculus expressions ([`stlc`](stlc))
- Configure and run [`stlc/main.jl`](stlc/main.jl)

Beware that the programs expected to run on this version of Dice.jl are the examples listed above and the tests. Other examples are known to be broken.
16 changes: 6 additions & 10 deletions examples/qc/demo_bst.jl
@@ -2,15 +2,12 @@

using Dice
include("lib/dist_tree.jl") # DistLeaf, DistBranch, depth
include("lib/sample.jl") # sample

# Return tree
function gen_bst(size, lo, hi)
-size == 0 && return DistLeaf(DistUInt32)

# Try changing the parameter to flip_for to a constant, which would force
# all sizes to use the same probability.
-@dice_ite if flip_for(size)
+@dice_ite if size == 0 || flip_for(size)
DistLeaf(DistUInt32)
else
# The flips used in the uniform aren't tracked via flip_for, so we
@@ -28,33 +25,32 @@ INIT_SIZE = 3
DATASET = [DistUInt32(x) for x in 0:INIT_SIZE]

# Use Dice to build computation graph
-gen() = gen_bst(
+tree = gen_bst(
INIT_SIZE,
DistUInt32(1),
DistUInt32(2 * INIT_SIZE),
)
-tree_depth = depth(gen())
+tree_depth = depth(tree)

println("Distribution before training:")
display(pr(tree_depth))
println()

bools_to_maximize = [prob_equals(tree_depth, x) for x in DATASET]
-train_group_probs!(bools_to_maximize)
+train_group_probs!(bools_to_maximize, 1000, 0.3) # epochs and lr

# Done!
println("Learned flip probability for each size:")
display(get_group_probs())
println()

println("Distribution over depths after training:")
-tree = gen()
-display(pr(depth(tree)))
+display(pr(tree_depth))
println()

println("A few sampled trees:")
for _ in 1:3
-print_tree(sample((tree, true)))
+print_tree(sample(tree))
println()
end

118 changes: 0 additions & 118 deletions examples/qc/demo_bst_obs.jl

This file was deleted.

11 changes: 6 additions & 5 deletions examples/qc/demo_natlist.jl
@@ -24,22 +24,23 @@ INIT_SIZE = 5
DATASET = [DistUInt32(x) for x in 0:INIT_SIZE]

# Use Dice to build computation graph
-generated = length(gen_list(INIT_SIZE))
+list = gen_list(INIT_SIZE)
+list_len = length(list)

println("Distribution before training:")
-display(pr(generated))
+display(pr(list_len))
println()

-bools_to_maximize = AnyBool[prob_equals(generated, x) for x in DATASET]
-train_group_probs!(bools_to_maximize)
+bools_to_maximize = AnyBool[prob_equals(list_len, x) for x in DATASET]
+train_group_probs!(bools_to_maximize, 1000, 0.3) # epochs and lr

# Done!
println("Learned flip probability for each size:")
display(get_group_probs())
println()

println("Distribution over lengths after training:")
-display(pr(length(gen_list(INIT_SIZE))))
+display(pr(list_len))

#==
Distribution before training:
11 changes: 4 additions & 7 deletions examples/qc/demo_sortednatlist.jl
@@ -1,7 +1,6 @@
# Demo of using BDD MLE to learn flip probs for a sorted nat list of uniform length.

using Dice
include("lib/sample.jl") # sample

# Return a List
function gen_sorted_list(size, lo, hi)
@@ -28,34 +27,32 @@ INIT_SIZE = 5
DATASET = [DistUInt32(x) for x in 0:INIT_SIZE]

# Use Dice to build computation graph
-gen() = gen_sorted_list(
+list = gen_sorted_list(
INIT_SIZE,
DistUInt32(1),
DistUInt32(INIT_SIZE),
)
-list_len = length(gen())
+list_len = length(list)

println("Distribution before training:")
display(pr(list_len))
println()

bools_to_maximize = [prob_equals(list_len, x) for x in DATASET]
-train_group_probs!(bools_to_maximize)
+train_group_probs!(bools_to_maximize, 1000, 0.3) # epochs and lr

# Done!
println("Learned flip probability for each size:")
display(get_group_probs())
println()

println("Distribution over lengths after training:")
-list_len = length(gen())
display(pr(list_len))
println()

println("A few sampled lists:")
-l = gen()
for _ in 1:3
-print_tree(sample((l, true)))
+print_tree(sample(list))
println()
end

