diff --git a/src/autodiff/core.jl b/src/autodiff/core.jl
index 14b56a8b..f8c1239a 100644
--- a/src/autodiff/core.jl
+++ b/src/autodiff/core.jl
@@ -3,7 +3,7 @@ export value, compute, differentiate, value, Valuation, Derivs, compute_one, var
 using DirectedAcyclicGraphs
 using DataStructures: DefaultDict
 
-Valuation = Dict{Var, ADNodeCompatible}
+Valuation = Dict{Variable, ADNodeCompatible}
 Derivs = Dict{ADNode, ADNodeCompatible}
 
 function compute_one(root, vals::Dict{ADNode, <:ADNodeCompatible})
diff --git a/src/inference/inference.jl b/src/inference/inference.jl
index b44ee5ff..fbd8c0a5 100644
--- a/src/inference/inference.jl
+++ b/src/inference/inference.jl
@@ -146,5 +146,6 @@
 include("pr.jl")
 # - train_group_probs!(::Vector{<:AnyBool}))
 # - train_group_probs!(::Vector{<:Tuple{<:AnyBool, <:AnyBool}})
 include("train_pr.jl")
+include("train_pr_losses.jl")
 include("sample.jl")
\ No newline at end of file
diff --git a/src/inference/train_pr.jl b/src/inference/train_pr.jl
index 8d8c16d0..b0dd6deb 100644
--- a/src/inference/train_pr.jl
+++ b/src/inference/train_pr.jl
@@ -1,17 +1,11 @@
 # The bridge between autodiff and cudd
+export step_vars!, train_pr!, total_logprob, valuation_to_flip_pr_resolver
+using DataStructures: Queue
 
-export step_vars!, train_pr!, BoolToMax, total_logprob, valuation_to_flip_pr_resolver, mle_loss, kl_divergence
-
-struct BoolToMax
-    bool::AnyBool
-    evid::AnyBool
-    weight::Real
-    BoolToMax(bool, evid, weight) = new(bool & evid, evid, weight)
+struct LogPr <: Variable
+    bool::Dist{Bool}
 end
 
-function BoolToMax(bool; evidence=true, weight=1)
-    BoolToMax(bool, evidence, weight)
-end
 
 # Find the log-probabilities and the log-probability gradient of a BDD
 function add_scaled_dict!(
@@ -24,18 +18,22 @@ function add_scaled_dict!(
     end
 end
 
+
 function step_pr!(
     var_vals::Valuation,
-    loss::ADNode, # var(distbool) represents logpr of that distbool
+    loss::ADNode,
     learning_rate::Real
 )
     # loss refers to logprs of bools
-    # error to do with var(true)? just make it a vector of anybool and don't filter
-    bools = Vector{Dist{Bool}}([n.id for n in variables(loss) if !(n.id isa Bool)])
+    # error to do with LogPr(true)? just make it a vector of anybool and don't filter
+    bools = Vector{Dist{Bool}}([
+        n.bool for n in variables(loss)
+        if !(n isa Var) && !(n.bool isa Bool)
+    ])
 
     # so, calculate these logprs
     w = WMC(BDDCompiler(bools), valuation_to_flip_pr_resolver(var_vals))
-    bool_logprs = Valuation(Var(bool) => logprob(w, bool) for bool in bools)
+    bool_logprs = Valuation(LogPr(bool) => logprob(w, bool) for bool in bools)
 
     # TODO: have differentiate return vals as well to avoid this compute
     # or have it take vals
     loss_val = compute(bool_logprs, [loss])[loss] # for return value only
@@ -46,7 +44,7 @@ function step_pr!(
     # find grad of loss w.r.t. each flip's probability
     grad = DefaultDict{Flip, Float64}(0.)
     for bool in bools
-        add_scaled_dict!(grad, grad_logprob(w, bool), derivs[Var(bool)])
+        add_scaled_dict!(grad, grad_logprob(w, bool), derivs[LogPr(bool)])
     end
 
     # move blame from flips probabilities to their adnode params
@@ -73,43 +71,6 @@ function step_pr!(
     loss_val
 end
 
-function mle_loss(bools_to_max::Vector{BoolToMax})
-    loss = 0
-    for b in bools_to_max
-        if b.evid === true
-            loss -= b.weight * Var(b.bool)
-        else
-            loss -= b.weight * (Var(b.bool) - Var(b.evid))
-        end
-    end
-    loss
-end
-
-function mle_loss(bools_to_max::Vector{<:AnyBool})
-    mle_loss([BoolToMax(b, true, 1) for b in bools_to_max])
-end
-
-# This is valid but not what we usually want: when training a dist, the reference
-# distribution should be constant, and the other should be symbolic.
-# reference distribution to be constant.
-# function kl_divergence(p::Dist, q::Dict{<:Any, <:Real}, domain::Set{<:Pair{<:Any, <:Dist}})
-#     res = 0
-#     for (x, x_dist) in domain
-#         logpx = Var(prob_equals(p, x_dist)) # Var(b) represents the logpr of b
-#         res += exp(logpx) * (logpx - log(q[x]))
-#     end
-#     res
-# end
-
-function kl_divergence(p::Dict{<:Any, <:Real}, q::Dist, domain::Set{<:Pair{<:Any, <:Dist}})
-    res = 0
-    for (x, x_dist) in domain
-        logqx = Var(prob_equals(q, x_dist)) # Var(b) represents the logpr of b
-        res += p[x] * (log(p[x]) - logqx)
-    end
-    res
-end
-
 # Train group_to_psp to such that generate() approximates dataset's distribution
 function train_pr!(
     var_vals::Valuation,
@@ -137,8 +98,11 @@ function compute_loss(
     var_vals::Valuation,
     loss::ADNode
 )
-    bools = Vector{Dist{Bool}}([n.id for n in variables(loss) if !(n.id isa Bool)])
+    bools = Vector{Dist{Bool}}([
+        n.bool for n in variables(loss)
+        if !(n isa Var) && !(n.bool isa Bool)
+    ])
     w = WMC(BDDCompiler(bools), valuation_to_flip_pr_resolver(var_vals))
-    bool_logprs = Valuation(Var(bool) => logprob(w, bool) for bool in bools)
+    bool_logprs = Valuation(LogPr(bool) => logprob(w, bool) for bool in bools)
     compute(bool_logprs, [loss])[loss] # for return value only
 end
diff --git a/src/inference/train_pr_losses.jl b/src/inference/train_pr_losses.jl
new file mode 100644
index 00000000..5c4377fc
--- /dev/null
+++ b/src/inference/train_pr_losses.jl
@@ -0,0 +1,50 @@
+
+export BoolToMax, mle_loss, kl_divergence
+
+struct BoolToMax
+    bool::AnyBool
+    evid::AnyBool
+    weight::Real
+    BoolToMax(bool, evid, weight) = new(bool & evid, evid, weight)
+end
+
+function BoolToMax(bool; evidence=true, weight=1)
+    BoolToMax(bool, evidence, weight)
+end
+
+function mle_loss(bools_to_max::Vector{BoolToMax})
+    loss = 0
+    for b in bools_to_max
+        if b.evid === true
+            loss -= b.weight * LogPr(b.bool)
+        else
+            loss -= b.weight * (LogPr(b.bool) - LogPr(b.evid))
+        end
+    end
+    loss
+end
+
+function mle_loss(bools_to_max::Vector{<:AnyBool})
+    mle_loss([BoolToMax(b, true, 1) for b in bools_to_max])
+end
+
+# This is valid but not what we usually want: when training a dist, the reference
+# distribution should be constant, and the other should be symbolic.
+# reference distribution to be constant.
+# function kl_divergence(p::Dist, q::Dict{<:Any, <:Real}, domain::Set{<:Pair{<:Any, <:Dist}})
+#     res = 0
+#     for (x, x_dist) in domain
+#         logpx = Var(prob_equals(p, x_dist)) # Var(b) represents the logpr of b
+#         res += exp(logpx) * (logpx - log(q[x]))
+#     end
+#     res
+# end
+
+function kl_divergence(p::Dict{<:Any, <:Real}, q::Dist, domain::Set{<:Pair{<:Any, <:Dist}})
+    res = 0
+    for (x, x_dist) in domain
+        logqx = LogPr(prob_equals(q, x_dist))
+        res += p[x] * (log(p[x]) - logqx)
+    end
+    res
+end
\ No newline at end of file
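
Usage note (not part of the patch): a minimal sketch of how the refactored pieces might fit together after this change. Trainable parameters are Vars stored in a Valuation, a loss is an ADNode built from LogPr terms via mle_loss, and train_pr! / compute_loss consume both. Var("psp"), sigmoid, flip-on-ADNode, and the epochs/learning_rate keywords are assumptions based on the surrounding codebase, not confirmed by this diff.

# Hypothetical usage sketch. Assumptions: Var accepts a name, sigmoid/flip
# accept ADNode probabilities, and train_pr! takes epochs/learning_rate
# keywords; treat these as illustrative, not as the confirmed signature.
psp = Var("psp")                      # pre-sigmoid probability parameter
var_vals = Valuation(psp => 0.0)      # initial parameter value
b = flip(sigmoid(psp))                # Dist{Bool} whose probability is trainable

# Observed dataset: true, true, false. mle_loss builds a negated sum of LogPr
# terms, so minimizing it maximizes the likelihood of the observations.
loss = mle_loss([b, b, !b])

train_pr!(var_vals, loss; epochs=200, learning_rate=0.1)
compute_loss(var_vals, loss)          # loss at the trained parameter values

Design-wise, the gain over the old Var(bool) encoding is that LogPr is its own Variable subtype, so step_pr! can separate "log-probability of a Dist{Bool}" leaves from ordinary trainable Var parameters when it walks variables(loss).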
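
A companion sketch for kl_divergence as it lands in train_pr_losses.jl: the reference distribution p stays a constant Dict, while the symbolic distribution q enters only through LogPr(prob_equals(q, x_dist)) terms, so the same train_pr! loop applies. The ifelse-over-Dists call, DistUInt32 constants, and the train_pr! keywords below are assumptions, not confirmed by this diff.

# Hypothetical usage sketch: pull a symbolic q toward a fixed reference p.
# Assumptions: ifelse dispatches on Dist{Bool} conditions, DistUInt32(i)
# builds a constant, and train_pr! accepts epochs/learning_rate keywords.
theta = Var("theta")
var_vals = Valuation(theta => 0.0)
q = ifelse(flip(sigmoid(theta)), DistUInt32(1), DistUInt32(0))

domain = Set([0 => DistUInt32(0), 1 => DistUInt32(1)])  # value => Dist encoding
p = Dict(0 => 0.25, 1 => 0.75)                          # constant reference distribution

# kl_divergence(p, q, domain) expands to sum_x p(x) * (log p(x) - LogPr(q == x)).
loss = kl_divergence(p, q, domain)
train_pr!(var_vals, loss; epochs=200, learning_rate=0.1)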