Commit 86ad4b9

remove Variable
rtjoa committed Dec 23, 2023
1 parent 6ef8ea5 commit 86ad4b9
Showing 3 changed files with 7 additions and 20 deletions.
src/autodiff/adnode.jl: 15 changes (5 additions, 10 deletions)
@@ -1,16 +1,9 @@
-export ADNode, ADMatrix, Variable, Var, ad_map, sigmoid, deriv_sigmoid, inverse_sigmoid
+export ADNode, ADMatrix, Var, ad_map, sigmoid, deriv_sigmoid, inverse_sigmoid
 
 import DirectedAcyclicGraphs: NodeType, DAG, children
 
 abstract type ADNode <: DAG end
 
-# We always differentiate with respect to Variables
-abstract type Variable <: ADNode end
-
-NodeType(::Type{<:Variable}) = Leaf()
-compute_leaf(x::Variable) = error("The value of $(x) should've been provided in `vals`!")
-backward(::Variable, _, _) = nothing
-
 ADNodeCompatible = Union{Real, AbstractMatrix{<:Real}}
 
 function add_deriv(derivs, n::ADNode, amount::ADNodeCompatible)
@@ -21,10 +14,12 @@ function add_deriv(derivs, n::ADNode, amount::ADNodeCompatible)
     end
 end
 
-
-struct Var <: Variable
+struct Var <: ADNode
     id::Any
 end
+NodeType(::Type{Var}) = Leaf()
+compute_leaf(x::Var) = error("The value of $(x) should've been provided in `vals`!")
+backward(::Var, _, _) = nothing
 function Base.show(io::IO, x::Var)
     print(io, "Var(")
     show(io, x.id)
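
For orientation, a minimal usage sketch of the reshaped leaf type, assuming only the post-commit definitions above (the names `theta` and `leaf_vals` are illustrative, not from the repo):

    theta = Var("theta")    # Var is now a direct ADNode leaf; id can be any value
    # NodeType(Var) is Leaf(), so theta has no children to recurse into.
    # compute_leaf(theta) deliberately errors: leaf values must be supplied
    # externally at compute time, e.g.
    leaf_vals = Dict{ADNode, Real}(theta => 0.5)
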
src/autodiff/core.jl: 8 changes (2 additions, 6 deletions)
@@ -1,9 +1,9 @@
-export value, compute, differentiate, value, Valuation, Derivs, compute_one, variables
+export value, compute, differentiate, value, Valuation, Derivs, compute_one
 
 using DirectedAcyclicGraphs
 using DataStructures: DefaultDict
 
-Valuation = Dict{Variable, ADNodeCompatible}
+Valuation = Dict{Var, ADNodeCompatible}
 Derivs = Dict{ADNode, ADNodeCompatible}
 
 function compute_one(root, vals::Dict{ADNode, <:ADNodeCompatible})
@@ -51,7 +51,3 @@ function foreach_down(f::Function, roots)
         f(n)
     end
 end
-
-function variables(x::ADNode)
-    filter(node -> node isa Variable, x)
-end
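
Two consequences of this file's changes, sketched under the assumption that `filter` traverses ADNode DAGs exactly as in the deleted helper (`var_vals` and `loss` are illustrative names):

    # Valuation is now keyed by the concrete Var type:
    var_vals = Valuation(Var("theta") => 0.5)
    # With the `variables` helper removed, the equivalent query is written inline:
    filter(node -> node isa Var, loss)    # loss::ADNode
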
src/inference/train_pr.jl: 4 changes (0 additions, 4 deletions)
@@ -17,8 +17,6 @@ mutable struct LogPrExpander
     end
 end
 
-
-
 function expand_logprs(l::LogPrExpander, root::ADNode)::ADNode
     fl(x::LogPr) = expand_logprs(l, logprob(l.w, x.bool))
     fl(x::Var) = x
@@ -54,7 +52,6 @@ function bool_roots(root::ADNode)
     setdiff(keys(seen_bools), non_roots)
 end
 
-
 # Find the log-probabilities and the log-probability gradient of a BDD
 function add_scaled_dict!(
     x::AbstractDict{<:Any, <:Real},
@@ -66,7 +63,6 @@
     end
 end
 
-
 function step_pr!(
     var_vals::Valuation,
     loss::ADNode,
