add LogPr
rtjoa committed Dec 22, 2023
1 parent 895cf0f commit 99eb642
Showing 4 changed files with 70 additions and 55 deletions.
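In brief, this commit introduces a LogPr <: Variable leaf node that wraps a Dist{Bool} and stands for its log-probability, replacing the earlier convention of writing Var(bool) for that purpose, and it moves the loss constructors (BoolToMax, mle_loss, kl_divergence) out of train_pr.jl into a new train_pr_losses.jl. A minimal sketch of a loss in the new style (the flip and & calls are standard Dice primitives assumed here, not part of this diff):

    b = flip(0.9) & flip(0.9)   # a Dist{Bool}
    loss = mle_loss([b])        # builds -1 * LogPr(b & true), an ADNode with a LogPr leaf
    # step_pr!/train_pr! in train_pr.jl lower each LogPr leaf to a real log-probability
    # via weighted model counting before differentiating the loss.
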
2 changes: 1 addition & 1 deletion src/autodiff/core.jl
@@ -3,7 +3,7 @@ export value, compute, differentiate, value, Valuation, Derivs, compute_one, var
using DirectedAcyclicGraphs
using DataStructures: DefaultDict

Valuation = Dict{Var, ADNodeCompatible}
Valuation = Dict{Variable, ADNodeCompatible}
Derivs = Dict{ADNode, ADNodeCompatible}

function compute_one(root, vals::Dict{ADNode, <:ADNodeCompatible})
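
Widening the key type from Var to Variable lets a Valuation assign values to either kind of leaf node. The hierarchy this presumes is sketched below; only LogPr <: Variable is declared in this commit, so Variable <: ADNode and Var <: Variable are inferences from how variables(loss), derivs[LogPr(bool)], and n.id are used in train_pr.jl:

    # Inferred context, not part of this diff:
    #   abstract type Variable <: ADNode end   # the leaf nodes of the ADNode graph
    #   struct Var <: Variable ... end         # a named trainable parameter (field `id`)
    # From this commit (see train_pr.jl below):
    struct LogPr <: Variable
        bool::Dist{Bool}                       # LogPr(b) stands for log Pr(b)
    end
    # The widened alias from this hunk: values for both Var and LogPr leaves.
    Valuation = Dict{Variable, ADNodeCompatible}
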
1 change: 1 addition & 0 deletions src/inference/inference.jl
@@ -146,5 +146,6 @@ include("pr.jl")
# - train_group_probs!(::Vector{<:AnyBool})
# - train_group_probs!(::Vector{<:Tuple{<:AnyBool, <:AnyBool}})
include("train_pr.jl")
include("train_pr_losses.jl")

include("sample.jl")
72 changes: 18 additions & 54 deletions src/inference/train_pr.jl
@@ -1,17 +1,11 @@
# The bridge between autodiff and cudd
export step_vars!, train_pr!, total_logprob, valuation_to_flip_pr_resolver
using DataStructures: Queue

export step_vars!, train_pr!, BoolToMax, total_logprob, valuation_to_flip_pr_resolver, mle_loss, kl_divergence

struct BoolToMax
bool::AnyBool
evid::AnyBool
weight::Real
BoolToMax(bool, evid, weight) = new(bool & evid, evid, weight)
struct LogPr <: Variable
bool::Dist{Bool}
end

function BoolToMax(bool; evidence=true, weight=1)
BoolToMax(bool, evidence, weight)
end

# Find the log-probabilities and the log-probability gradient of a BDD
function add_scaled_dict!(
@@ -24,18 +18,22 @@ function add_scaled_dict!(
end
end


function step_pr!(
var_vals::Valuation,
loss::ADNode, # var(distbool) represents logpr of that distbool
loss::ADNode,
learning_rate::Real
)
# loss refers to logprs of bools
# error to do with var(true)? just make it a vector of anybool and don't filter
bools = Vector{Dist{Bool}}([n.id for n in variables(loss) if !(n.id isa Bool)])
# error to do with LogPr(true)? just make it a vector of anybool and don't filter
bools = Vector{Dist{Bool}}([
n.bool for n in variables(loss)
if !(n isa Var) && !(n.bool isa Bool)
])

# so, calculate these logprs
w = WMC(BDDCompiler(bools), valuation_to_flip_pr_resolver(var_vals))
bool_logprs = Valuation(Var(bool) => logprob(w, bool) for bool in bools)
bool_logprs = Valuation(LogPr(bool) => logprob(w, bool) for bool in bools)
# TODO: have differentiate return vals as well to avoid this compute
# or have it take vals
loss_val = compute(bool_logprs, [loss])[loss] # for return value only
@@ -46,7 +44,7 @@ function step_pr!(
# find grad of loss w.r.t. each flip's probability
grad = DefaultDict{Flip, Float64}(0.)
for bool in bools
add_scaled_dict!(grad, grad_logprob(w, bool), derivs[Var(bool)])
add_scaled_dict!(grad, grad_logprob(w, bool), derivs[LogPr(bool)])
end

# move blame from flips probabilities to their adnode params
@@ -73,43 +71,6 @@ function step_pr!(
loss_val
end

function mle_loss(bools_to_max::Vector{BoolToMax})
loss = 0
for b in bools_to_max
if b.evid === true
loss -= b.weight * Var(b.bool)
else
loss -= b.weight * (Var(b.bool) - Var(b.evid))
end
end
loss
end

function mle_loss(bools_to_max::Vector{<:AnyBool})
mle_loss([BoolToMax(b, true, 1) for b in bools_to_max])
end

# This is valid but not what we usually want: when training a dist, the reference
# distribution should be constant, and the other should be symbolic.
# function kl_divergence(p::Dist, q::Dict{<:Any, <:Real}, domain::Set{<:Pair{<:Any, <:Dist}})
# res = 0
# for (x, x_dist) in domain
# logpx = Var(prob_equals(p, x_dist)) # Var(b) represents the logpr of b
# res += exp(logpx) * (logpx - log(q[x]))
# end
# res
# end

function kl_divergence(p::Dict{<:Any, <:Real}, q::Dist, domain::Set{<:Pair{<:Any, <:Dist}})
res = 0
for (x, x_dist) in domain
logqx = Var(prob_equals(q, x_dist)) # Var(b) represents the logpr of b
res += p[x] * (log(p[x]) - logqx)
end
res
end

# Train group_to_psp to such that generate() approximates dataset's distribution
function train_pr!(
var_vals::Valuation,
@@ -137,8 +98,11 @@ function compute_loss(
var_vals::Valuation,
loss::ADNode
)
bools = Vector{Dist{Bool}}([n.id for n in variables(loss) if !(n.id isa Bool)])
bools = Vector{Dist{Bool}}([
n.bool for n in variables(loss)
if !(n isa Var) && !(n.bool isa Bool)
])
w = WMC(BDDCompiler(bools), valuation_to_flip_pr_resolver(var_vals))
bool_logprs = Valuation(Var(bool) => logprob(w, bool) for bool in bools)
bool_logprs = Valuation(LogPr(bool) => logprob(w, bool) for bool in bools)
compute(bool_logprs, [loss])[loss] # for return value only
end
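
The gradient plumbing in step_pr! is a two-stage chain rule: differentiate yields dloss/dLogPr(b) over the ADNode graph, grad_logprob yields dlogPr(b)/dpr(f) for each flip f from the weighted model count, and add_scaled_dict! accumulates their products, so grad[f] = sum over b of dloss/dLogPr(b) * dlogPr(b)/dpr(f). A hypothetical end-to-end use (the Var constructor call, flip accepting an ADNode probability, and the literal values are assumptions, not shown in this diff):

    theta = Var("theta")                  # trainable parameter
    var_vals = Valuation(theta => 0.5)    # current parameter values
    b = flip(theta) & flip(theta)         # model whose probability depends on theta
    loss = mle_loss([b])                  # -1 * LogPr(b & true)
    for _ in 1:100
        step_pr!(var_vals, loss, 0.003)   # one gradient step; mutates var_vals[theta]
    end
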
50 changes: 50 additions & 0 deletions src/inference/train_pr_losses.jl
@@ -0,0 +1,50 @@

export BoolToMax, mle_loss, kl_divergence

struct BoolToMax
bool::AnyBool
evid::AnyBool
weight::Real
BoolToMax(bool, evid, weight) = new(bool & evid, evid, weight)
end

function BoolToMax(bool; evidence=true, weight=1)
BoolToMax(bool, evidence, weight)
end

function mle_loss(bools_to_max::Vector{BoolToMax})
loss = 0
for b in bools_to_max
if b.evid === true
loss -= b.weight * LogPr(b.bool)
else
loss -= b.weight * (LogPr(b.bool) - LogPr(b.evid))
end
end
loss
end

function mle_loss(bools_to_max::Vector{<:AnyBool})
mle_loss([BoolToMax(b, true, 1) for b in bools_to_max])
end
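
Here BoolToMax(bool; evidence, weight) contributes -weight * (LogPr(bool & evidence) - LogPr(evidence)), i.e. -weight * log Pr(bool given evidence), to the loss, so mle_loss is weighted conditional maximum likelihood; with the defaults it reduces to plain -log Pr(bool). An illustrative call (flip and | are assumed Dice primitives outside this diff):

    x, y = flip(0.5), flip(0.5)
    loss = mle_loss([BoolToMax(x, evidence=x | y, weight=2)])
    # contributes -2 * (LogPr(x & (x | y)) - LogPr(x | y)),
    # i.e. -2 * log Pr(x conditioned on x | y)
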

# This is valid but not what we usually want: when training a dist, the reference
# distribution should be constant, and the other should be symbolic.
# function kl_divergence(p::Dist, q::Dict{<:Any, <:Real}, domain::Set{<:Pair{<:Any, <:Dist}})
# res = 0
# for (x, x_dist) in domain
# logpx = Var(prob_equals(p, x_dist)) # Var(b) represents the logpr of b
# res += exp(logpx) * (logpx - log(q[x]))
# end
# res
# end

function kl_divergence(p::Dict{<:Any, <:Real}, q::Dist, domain::Set{<:Pair{<:Any, <:Dist}})
res = 0
for (x, x_dist) in domain
logqx = LogPr(prob_equals(q, x_dist))
res += p[x] * (log(p[x]) - logqx)
end
res
end
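
kl_divergence computes KL(p || q) = sum over x of p(x) * (log p(x) - log q(x)), with the reference distribution p supplied as a constant dictionary and each log q(x) left symbolic as LogPr(prob_equals(q, x_dist)), so the result is an ADNode that train_pr!/step_pr! can minimize like any other loss. A hypothetical call (DistUInt32, ifelse on a Dist{Bool} guard, and flip over an ADNode probability are assumed Dice features outside this diff):

    theta = Var("theta")
    q = ifelse(flip(theta), DistUInt32(0), DistUInt32(1))   # distribution to train
    p = Dict(0 => 0.75, 1 => 0.25)                          # constant reference distribution
    domain = Set(x => DistUInt32(x) for x in keys(p))       # support value => its Dist encoding
    loss = kl_divergence(p, q, domain)                      # ADNode over LogPr leaves
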
