have with_concrete_ad_flips support LogPr
rtjoa committed Dec 24, 2023
1 parent ccd8903 commit d80e01a
Showing 3 changed files with 6 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/autodiff/adnode.jl
@@ -198,7 +198,7 @@ function backward(n::Transpose, vals, derivs)
     add_deriv(derivs, n.x, transpose(derivs[n]))
 end
 
-# Give override for add_logprobs so logprob in wmc.jl is differentiable
+# Give override for add_logprobs so logprob in wmc.jl is differentiable.
 # computes log(exp(x) + exp(y))
 mutable struct NodeLogPr <: ADNode
     pr::ADNode
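
Aside, not part of the diff: the log(exp(x) + exp(y)) that NodeLogPr's comment
describes is the standard log-add-exp operation. A minimal standalone Julia
sketch of its numerically stable form (an illustration, not this repository's
actual implementation):

# Stable log(exp(x) + exp(y)): factor out the max so neither exp can overflow.
function logaddexp(x::Real, y::Real)
    m = max(x, y)
    m == -Inf && return -Inf  # both inputs are log(0)
    return m + log1p(exp(-abs(x - y)))
end

logaddexp(log(0.2), log(0.3))  # ≈ log(0.5)
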
3 changes: 2 additions & 1 deletion src/autodiff_pr/train.jl
@@ -104,10 +104,11 @@ end
 function with_concrete_ad_flips(f, var_vals, dist)
     flip_to_original_prob = Dict()
     a = ADComputer(var_vals)
+    l = LogPrExpander(WMC(BDDCompiler()))
     for x in collect_flips(tobits(dist))
         if x.prob isa ADNode
             flip_to_original_prob[x] = x.prob
-            x.prob = compute(a, x.prob)
+            x.prob = compute(a, expand_logprs(l, x.prob))
         end
     end
     res = f()
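
Aside, not part of the diff: the new LogPrExpander line is what makes this
function accept distributions whose flip probabilities mention LogPr. Before,
compute(a, x.prob) hit the LogPr node directly and ADComputer could not
evaluate it; expand_logprs now rewrites it into an ADNode first. A hypothetical
usage sketch (the identifiers come from this repository, the values are
invented):

x = Var("x")
# A flip whose probability is itself defined via LogPr of another distribution.
d = flip(exp(LogPr(flip(sigmoid(x)))))
with_concrete_ad_flips(Valuation(x => 0), d) do
    # Here d's flip probability has been replaced by a concrete value;
    # flip_to_original_prob records the ADNodes so they can be put back.
end
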
4 changes: 3 additions & 1 deletion test/autodiff_pr/train_test.jl
@@ -62,12 +62,14 @@ end
     x = Var("x")
     prob = sigmoid(x)
     prob2 = exp(LogPr(flip(prob) & flip(prob)))
-    loss = mle_loss([flip(prob2) & flip(prob2) & !flip(prob2)])
+    bool = flip(prob2) & flip(prob2) & !flip(prob2)
+    loss = mle_loss([bool])
     var_vals = Valuation(x => 0)
     train!(var_vals, loss, epochs=2000, learning_rate=0.1)
 
     # loss is minimized if prob2 is 2/3
     # therefore, prob should be sqrt(2/3)
     @test compute(var_vals, prob) ≈ sqrt(2/3)
     @test compute_mixed(var_vals, loss) ≈ -log(2/3*2/3*1/3)
+    pr_mixed(var_vals)(bool)
 end
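
Aside, not part of the diff: the expected values in this test follow from a
short calculation. mle_loss([bool]) is minimized by maximizing
Pr(bool) = p^2 * (1 - p) with p = prob2; the derivative 2p - 3p^2 = p(2 - 3p)
vanishes at p = 2/3, and since prob2 is the probability that two independent
flip(prob) draws both come up true, prob2 = prob^2 and so prob = sqrt(2/3).
The minimal loss is then -log(2/3 * 2/3 * 1/3), matching the compute_mixed
assertion. A quick numeric check of the argmax:

# p^2 * (1 - p) peaks at p = 2/3: scan a grid and take the argmax.
ps = 0:0.001:1
_, i = findmax([p^2 * (1 - p) for p in ps])
ps[i]  # ≈ 0.667, i.e. 2/3 up to grid resolution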
