Skip to content

Commit

Permalink
Merge pull request #6 from JuliaDiffEq/neural_ode
Browse files Browse the repository at this point in the history
add neural ode and sde layers
  • Loading branch information
ChrisRackauckas authored Jan 22, 2019
2 parents 73aa7c8 + 3c42d3b commit a50aa3e
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/DiffEqFlux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ module DiffEqFlux
using DiffEqBase, Flux, DiffResults, DiffEqSensitivity, ForwardDiff

include("Flux/layers.jl")
include("Flux/neural_de.jl")
include("Flux/utils.jl")

export diffeq_fd, diffeq_rd, diffeq_adjoint
export diffeq_fd, diffeq_rd, diffeq_adjoint, neural_ode, neural_msde
end
41 changes: 41 additions & 0 deletions src/Flux/neural_de.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Convert a DE solution object into a plain dense Array so that the
# fd/rd AD wrappers receive an ordinary numeric result.
function neural_ode_reduction(sol)
    return Array(sol)
end
"""
    neural_ode(x, model, tspan, [ad_func::Function], args...; kwargs...)

Solve the neural ODE `du/dt = model(u)` from initial condition `x` over
`tspan`, differentiating through the solve with `ad_func` (defaults to
`diffeq_adjoint`). Remaining `args`/`kwargs` (e.g. the solver algorithm,
`saveat`, ...) are forwarded to the underlying AD wrapper call.
"""
function neural_ode(x, model, tspan, args...; kwargs...)
    return neural_ode(x, model, tspan, diffeq_adjoint, args...; kwargs...)
end

function neural_ode(x, model, tspan, ad_func::Function, args...; kwargs...)
    # Flatten the network weights into an untracked parameter vector.
    p = Flux.data(destructure(model))
    # In-place RHS. The TrackedArray method keeps Tracker gradients flowing;
    # the AbstractArray method strips tracking for plain numeric solves.
    rhs!(du, u::TrackedArray, p, t) = du .= restructure(model, p)(u)
    rhs!(du, u::AbstractArray, p, t) = du .= Flux.data(restructure(model, p)(u))
    ode = ODEProblem(rhs!, x, tspan, p)

    # Each AD wrapper has its own calling convention.
    ad_func === diffeq_adjoint &&
        return ad_func(p, ode, args...; kwargs...)
    ad_func === diffeq_fd &&
        return ad_func(p, neural_ode_reduction, length(p), ode, args...; kwargs...)
    return ad_func(p, neural_ode_reduction, ode, args...; kwargs...)
end

"""
    neural_msde(x, model, mp, tspan, [ad_func::Function], args...; kwargs...)

Solve a neural SDE with drift `model(u)` and multiplicative diagonal noise
`mp .* u`, from initial condition `x` over `tspan`. Differentiation goes
through `ad_func` (defaults to `diffeq_fd`); remaining `args`/`kwargs` are
forwarded to the underlying AD wrapper call.
"""
function neural_msde(x, model, mp, tspan, args...; kwargs...)
    return neural_msde(x, model, mp, tspan, diffeq_fd, args...; kwargs...)
end

function neural_msde(x, model, mp, tspan, ad_func::Function, args...; kwargs...)
    # Flatten the network weights into an untracked parameter vector.
    p = Flux.data(destructure(model))
    # Drift term, dispatching on tracked vs. plain state (see neural_ode).
    drift!(du, u::TrackedArray, p, t) = du .= restructure(model, p)(u)
    drift!(du, u::AbstractArray, p, t) = du .= Flux.data(restructure(model, p)(u))
    # Multiplicative noise: each state component scaled by the matching mp entry.
    noise!(du, u, p, t) = du .= mp .* u
    sde = SDEProblem(drift!, noise!, x, tspan, p)

    # Each AD wrapper has its own calling convention.
    ad_func === diffeq_adjoint &&
        return ad_func(p, sde, args...; kwargs...)
    ad_func === diffeq_fd &&
        return ad_func(p, neural_ode_reduction, length(p), sde, args...; kwargs...)
    return ad_func(p, neural_ode_reduction, sde, args...; kwargs...)
end
13 changes: 13 additions & 0 deletions test/neural_de.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Smoke tests for the neural ODE / SDE layer constructors: each call should
# run a solve end-to-end without erroring.
using OrdinaryDiffEq, StochasticDiffEq, Flux, DiffEqFlux

u0 = Float32[2.0; 0.0]
tspan = (0.0f0, 25.0f0)
model = Chain(Dense(2, 50, tanh), Dense(50, 2))

# Default AD (adjoint) path, with and without dense saving.
neural_ode(u0, model, tspan, Tsit5(), save_everystep=false, save_start=false)
neural_ode(u0, model, tspan, Tsit5(), saveat=0.1)
# Each explicitly selected AD backend.
neural_ode(u0, model, tspan, diffeq_adjoint, Tsit5(), saveat=0.1)
neural_ode(u0, model, tspan, diffeq_fd, Tsit5(), saveat=0.1)
neural_ode(u0, model, tspan, diffeq_rd, Tsit5(), saveat=0.1)

# Multiplicative-noise neural SDE via its default AD path.
neural_msde(u0, model, [0.1, 0.1], tspan, SOSRI(), saveat=0.1)
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@ using DiffEqFlux, Test

include("layers.jl")
include("utils.jl")
include("neural_de.jl")

end

0 comments on commit a50aa3e

Please sign in to comment.