From 01184f007342ff6fbb8c3702c0052a55841f2d7e Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Tue, 9 Apr 2019 18:14:37 -0400 Subject: [PATCH 1/3] fix and test partial neural ODEs --- src/Flux/layers.jl | 4 ++-- test/partial_neural.jl | 38 ++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 41 insertions(+), 2 deletions(-) create mode 100644 test/partial_neural.jl diff --git a/src/Flux/layers.jl b/src/Flux/layers.jl index 42a7452539..2884756455 100644 --- a/src/Flux/layers.jl +++ b/src/Flux/layers.jl @@ -4,7 +4,7 @@ using DiffEqSensitivity: adjoint_sensitivities_u0 ## Reverse-Mode via Flux.jl function diffeq_rd(p,prob,args...;u0=prob.u0,kwargs...) - if typeof(u0) <: AbstractArray + if typeof(u0) <: AbstractArray && !(typeof(u0) <: TrackedArray) if DiffEqBase.isinplace(prob) # use Array{TrackedReal} for mutation to work # Recurse to all Array{TrackedArray} @@ -14,7 +14,7 @@ function diffeq_rd(p,prob,args...;u0=prob.u0,kwargs...) _prob = remake(prob,u0=convert(recursive_bottom_eltype(p),u0),p=p) end else # u0 is functional, ignore the change - _prob = remake(prob,p=p) + _prob = remake(prob,u0=u0,p=p) end solve(_prob,args...;kwargs...) end diff --git a/test/partial_neural.jl b/test/partial_neural.jl new file mode 100644 index 0000000000..5c907ed521 --- /dev/null +++ b/test/partial_neural.jl @@ -0,0 +1,38 @@ +using DiffEqFlux, Flux, OrdinaryDiffEq, Plots + +x = Float32[0.8; 0.8] +tspan = (0.0f0,25.0f0) + +ann = Chain(Dense(2,10,tanh), Dense(10,1)) +p = param(Float32[-2.0,1.1]) + +function dudt_(u::TrackedArray,p,t) + x, y = u + Flux.Tracker.collect([ann(u)[1],p[1]*y + p[2]*x]) +end +function dudt_(u::AbstractArray,p,t) + x, y = u + [Flux.data(ann(u)[1]),p[1]*y + p[2]*x*y] +end + +prob = ODEProblem(dudt_,x,tspan,p) +diffeq_rd(p,prob,Tsit5()) +_x = param(x) + +function predict_rd() + Flux.Tracker.collect(diffeq_rd(p,prob,Tsit5(),u0=_x)) +end +loss_rd() = sum(abs2,x-1 for x in predict_rd()) +loss_rd() + +data = Iterators.repeated((), 10) +opt = ADAM(0.1) +cb = function () + display(loss_rd()) + display(plot(solve(remake(prob,u0=Flux.data(_x),p=Flux.data(p)),Tsit5(),saveat=0.1),ylim=(0,6))) +end + +# Display the ODE with the current parameter values. +cb() + +Flux.train!(loss_rd, params(ann,p,_x), data, opt, cb = cb) diff --git a/test/runtests.jl b/test/runtests.jl index afee9ff236..6c008ace81 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,5 +5,6 @@ using DiffEqFlux, Test include("layers.jl") include("utils.jl") include("neural_de.jl") +include("partial_neural.jl") end From ecd240c496d9e9007b3225d6209e2f23656c7dd9 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Tue, 9 Apr 2019 18:52:52 -0400 Subject: [PATCH 2/3] add partial neural adjoint --- test/partial_neural.jl | 42 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/test/partial_neural.jl b/test/partial_neural.jl index 5c907ed521..fb2be51ddb 100644 --- a/test/partial_neural.jl +++ b/test/partial_neural.jl @@ -1,4 +1,4 @@ -using DiffEqFlux, Flux, OrdinaryDiffEq, Plots +using DiffEqFlux, Flux, OrdinaryDiffEq x = Float32[0.8; 0.8] tspan = (0.0f0,25.0f0) @@ -29,10 +29,48 @@ data = Iterators.repeated((), 10) opt = ADAM(0.1) cb = function () display(loss_rd()) - display(plot(solve(remake(prob,u0=Flux.data(_x),p=Flux.data(p)),Tsit5(),saveat=0.1),ylim=(0,6))) + #display(plot(solve(remake(prob,u0=Flux.data(_x),p=Flux.data(p)),Tsit5(),saveat=0.1),ylim=(0,6))) end # Display the ODE with the current parameter values. cb() Flux.train!(loss_rd, params(ann,p,_x), data, opt, cb = cb) + +## Partial Neural Adjoint + +u0 = param(Float32[0.8; 0.8]) +tspan = (0.0f0,25.0f0) + +ann = Chain(Dense(2,10,tanh), Dense(10,1)) + +p1 = Flux.data(DiffEqFlux.destructure(ann)) +p2 = Float32[-2.0,1.1] +p3 = param([p1;p2]) +ps = Flux.params(p3,u0) + +function dudt_(du,u,p,t) + x, y = u + du[1] = DiffEqFlux.restructure(ann,p[1:41])(u)[1] + du[2] = p[end-1]*y + p[end]*x +end +prob = ODEProblem(dudt_,u0,tspan,p3) +diffeq_adjoint(p3,prob,Tsit5(),u0=u0,abstol=1e-8,reltol=1e-6) + +function predict_adjoint() + diffeq_adjoint(p3,prob,Tsit5(),u0=u0,saveat=0.0:0.1:25.0) +end +loss_adjoint() = sum(abs2,x-1 for x in predict_adjoint()) +loss_adjoint() + +data = Iterators.repeated((), 10) +opt = ADAM(0.1) +cb = function () + display(loss_adjoint()) + #display(plot(solve(remake(prob,p=Flux.data(p3),u0=Flux.data(u0)),Tsit5(),saveat=0.1),ylim=(0,6))) +end + +# Display the ODE with the current parameter values. +cb() + +Flux.train!(loss_adjoint, ps, data, opt, cb = cb) From b65d3426aef0684602d40194a5f45124cc5ccce6 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Wed, 10 Apr 2019 10:37:43 -0400 Subject: [PATCH 3/3] fix out of place type conversions --- src/Flux/layers.jl | 2 +- test/REQUIRE | 1 + test/runtests.jl | 14 +++++--------- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/src/Flux/layers.jl b/src/Flux/layers.jl index 2884756455..5c89dccb69 100644 --- a/src/Flux/layers.jl +++ b/src/Flux/layers.jl @@ -11,7 +11,7 @@ function diffeq_rd(p,prob,args...;u0=prob.u0,kwargs...) _prob = remake(prob,u0=convert.(recursive_bottom_eltype(p),u0),p=p) else # use TrackedArray for efficiency of the tape - _prob = remake(prob,u0=convert(recursive_bottom_eltype(p),u0),p=p) + _prob = remake(prob,u0=convert(typeof(p),u0),p=p) end else # u0 is functional, ignore the change _prob = remake(prob,u0=u0,p=p) diff --git a/test/REQUIRE b/test/REQUIRE index fba0dbd5bd..da448ac078 100644 --- a/test/REQUIRE +++ b/test/REQUIRE @@ -1,3 +1,4 @@ OrdinaryDiffEq StochasticDiffEq DelayDiffEq +SafeTestsets diff --git a/test/runtests.jl b/test/runtests.jl index 6c008ace81..8d7d779e1c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,10 +1,6 @@ -using DiffEqFlux, Test +using DiffEqFlux, Test, SafeTestsets -@testset "DiffEqFlux" begin - -include("layers.jl") -include("utils.jl") -include("neural_de.jl") -include("partial_neural.jl") - -end +@safetestset "Utils Tests" begin include("utils.jl") end +@safetestset "Layers Tests" begin include("layers.jl") end +@safetestset "Neural DE Tests" begin include("neural_de.jl") end +@safetestset "Partial Neural Tests" begin include("partial_neural.jl") end