From d80e01a7dd7019eb8daa543199b6276e7682dbc8 Mon Sep 17 00:00:00 2001 From: Ryan Tjoa Date: Sun, 24 Dec 2023 15:17:15 -0800 Subject: [PATCH] have with_concrete_ad_flips support LogPr --- src/autodiff/adnode.jl | 2 +- src/autodiff_pr/train.jl | 3 ++- test/autodiff_pr/train_test.jl | 4 +++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/autodiff/adnode.jl b/src/autodiff/adnode.jl index 71da97c5..8445777e 100644 --- a/src/autodiff/adnode.jl +++ b/src/autodiff/adnode.jl @@ -198,7 +198,7 @@ function backward(n::Transpose, vals, derivs) add_deriv(derivs, n.x, transpose(derivs[n])) end -# Give override for add_logprobs so logprob in wmc.jl is differentiable +# Give override for add_logprobs so logprob in wmc.jl is differentiable. # computes log(exp(x) + exp(y)) mutable struct NodeLogPr <: ADNode pr::ADNode diff --git a/src/autodiff_pr/train.jl b/src/autodiff_pr/train.jl index f76db487..c5e4183e 100644 --- a/src/autodiff_pr/train.jl +++ b/src/autodiff_pr/train.jl @@ -104,10 +104,11 @@ end function with_concrete_ad_flips(f, var_vals, dist) flip_to_original_prob = Dict() a = ADComputer(var_vals) + l = LogPrExpander(WMC(BDDCompiler())) for x in collect_flips(tobits(dist)) if x.prob isa ADNode flip_to_original_prob[x] = x.prob - x.prob = compute(a, x.prob) + x.prob = compute(a, expand_logprs(l, x.prob)) end end res = f() diff --git a/test/autodiff_pr/train_test.jl b/test/autodiff_pr/train_test.jl index 9b2ac3bc..2107a9c9 100644 --- a/test/autodiff_pr/train_test.jl +++ b/test/autodiff_pr/train_test.jl @@ -62,7 +62,8 @@ end x = Var("x") prob = sigmoid(x) prob2 = exp(LogPr(flip(prob) & flip(prob))) - loss = mle_loss([flip(prob2) & flip(prob2) & !flip(prob2)]) + bool = flip(prob2) & flip(prob2) & !flip(prob2) + loss = mle_loss([bool]) var_vals = Valuation(x => 0) train!(var_vals, loss, epochs=2000, learning_rate=0.1) @@ -70,4 +71,5 @@ end # therefore, prob should be sqrt(2/3) @test compute(var_vals, prob) ≈ sqrt(2/3) @test compute_mixed(var_vals, loss) ≈ -log(2/3*2/3*1/3) + pr_mixed(var_vals)(bool) end