Skip to content

Commit

Permalink
Merge pull request #1135 from SciML/dg/nnrev2
Browse files Browse the repository at this point in the history
Avoid needing adjoints of SciMLStructures' constructor
  • Loading branch information
ChrisRackauckas authored Nov 8, 2024
2 parents 1458111 + 20d64fd commit 9143ea7
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 10 deletions.
39 changes: 31 additions & 8 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -639,15 +639,26 @@ function DiffEqBase._concrete_solve_adjoint(

du0 = reshape(du0, size(u0))

dp = p === nothing || p === SciMLBase.NullParameters() ? nothing :
dp isa AbstractArray ? reshape(dp', size(p)) : dp
dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing :
dp isa AbstractArray ? reshape(dp', size(tunables)) : dp

_, repack_adjoint = if p === nothing || p === DiffEqBase.NullParameters() ||
!isscimlstructure(p)
nothing, x -> (x,)
else
Zygote.pullback(p) do p
t, _, _ = canonicalize(Tunable(), p)
t
end
end

if originator isa SciMLBase.TrackerOriginator ||
originator isa SciMLBase.ReverseDiffOriginator
(NoTangent(), NoTangent(), du0, dp, NoTangent(),
(NoTangent(), NoTangent(), du0, repack_adjoint(dp)[1], NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
else
(NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(),
(NoTangent(), NoTangent(), NoTangent(),
du0, repack_adjoint(dp)[1], NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
end
end
Expand Down Expand Up @@ -835,7 +846,7 @@ function DiffEqBase._concrete_solve_adjoint(
pparts = typeof(tunables[1:1])[]
for j in 0:(num_chunks - 1)
local chunk
if ((j + 1) * chunk_size) <= length(p)
if ((j + 1) * chunk_size) <= length(tunables)
chunk = ((j * chunk_size + 1):((j + 1) * chunk_size))
pchunk = vec(tunables)[chunk]
pdualpart = seed_duals(pchunk, prob.f,
Expand Down Expand Up @@ -957,7 +968,7 @@ function DiffEqBase._concrete_solve_adjoint(
end
push!(pparts, vec(_dp))
end
SciMLStructures.replace(Tunable(), p, reduce(vcat, pparts))
reduce(vcat, pparts)
end
else
dp = nothing
Expand Down Expand Up @@ -1134,12 +1145,24 @@ function DiffEqBase._concrete_solve_adjoint(
end
end

_, repack_adjoint = if p === nothing || p === DiffEqBase.NullParameters() ||
!isscimlstructure(p)
nothing, x -> (x,)
else
Zygote.pullback(p) do p
t, _, _ = canonicalize(Tunable(), p)
t
end
end

if originator isa SciMLBase.TrackerOriginator ||
originator isa SciMLBase.ReverseDiffOriginator
(NoTangent(), NoTangent(), unthunk(du0), unthunk(dp), NoTangent(),
(NoTangent(), NoTangent(), unthunk(du0),
repack_adjoint(unthunk(dp))[1], NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
else
(NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(),
(NoTangent(), NoTangent(), NoTangent(),
du0, repack_adjoint(unthunk(dp))[1], NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
end
end
Expand Down
4 changes: 2 additions & 2 deletions src/gauss_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -482,8 +482,8 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand)
ReverseDiff.reverse_pass!(tape)
copyto!(vec(out), ReverseDiff.deriv(tp))
elseif sensealg.autojacvec isa ZygoteVJP
_dy, back = Zygote.pullback(p) do p
vec(f(y, p, t))
_dy, back = Zygote.pullback(tunables) do tunables
vec(f(y, tunables, t))
end
tmp = back(λ)
if tmp[1] === nothing
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ end
@time @safetestset "Prob Kwargs" include("prob_kwargs.jl")
@time @safetestset "DiscreteProblem Adjoints" include("discrete.jl")
@time @safetestset "Time Type Mixing Adjoints" include("time_type_mixing.jl")
@time @safetestset "SciMLStructures Interface" include("scimlstructures_interface.jl")
end
end

Expand Down
80 changes: 80 additions & 0 deletions test/scimlstructures_interface.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# taken from https://github.com/SciML/SciMLStructures.jl/pull/28
using OrdinaryDiffEq, SciMLSensitivity, Zygote
using LinearAlgebra
using Test
import SciMLStructures as SS

# Coefficients of one scalar subproblem. Only `p` is exposed as a tunable value
# further down in this file; `q` and `r` are carried along unchanged.
mutable struct SubproblemParameters{TP, TQ, TR}
    p::TP # tunable
    q::TQ
    r::TR
end
# Top-level parameter object handed to the ODE: a collection of per-subproblem
# parameters plus one coefficient matrix.
mutable struct Parameters{TS, TC}
    subparams::TS
    coeffs::TC # tunable matrix
end
# In-place right-hand side.
# For i in 1:length(p.subparams): `du[i] = p_i * u[i]^2 + q_i * u[i] + r_i * t`,
# where `p_i`, `q_i`, `r_i` are the fields of the i-th subproblem.
# The remaining entries of `du` are overwritten with `p.coeffs * u`.
function rhs!(du, u, p::Parameters, t)
    nsub = length(p.subparams)
    for (k, sp) in enumerate(p.subparams)
        uk = u[k]
        du[k] = sp.p * uk^2 + sp.q * uk + sp.r * t
    end
    # Write the linear part straight into the tail of `du` through a view.
    tail = view(du, (nsub + 1):length(du))
    mul!(tail, p.coeffs, u)
    return nothing
end
# Ten initial states; five subproblems and a 5x10 coefficient matrix.
u = map(sin, 0.1:0.1:1.0)
subparams = map(i -> SubproblemParameters(0.1i, 0.2i, 0.3i), 1:5)
p = Parameters(subparams, [cos(0.1i + 0.33j) for i in 1:5, j in 1:10])
tspan = (0.0, 1.0)
prob = ODEProblem(rhs!, u, tspan, p)
# Solve once with the plain struct as parameters; no SciMLStructures methods
# have been defined for `Parameters` at this point in the file.
solve(prob, Tsit5())

# Opt `Parameters` in to the SciMLStructures interface.
function SS.isscimlstructure(::Parameters)
    return true
end
# Declare it as a mutable SciMLStructure.
function SS.ismutablescimlstructure(::Parameters)
    return true
end
# Only the `Tunable` portion is declared here.
# A `Constants` portion holding the non-tunable values could be added as well;
# its implementation would look similar to this one.
function SS.hasportion(::SS.Tunable, ::Parameters)
    return true
end
# Flatten the tunable values of `p` into one vector and return it together with a
# function that rebuilds a `Parameters` from such a vector.
function SS.canonicalize(::SS.Tunable, p::Parameters)
    # Layout: the scalar `p` of every subproblem first, then the coefficient
    # matrix flattened with `vec`.
    subproblem_ps = [sp.p for sp in p.subparams]
    flat = vcat(subproblem_ps, vec(p.coeffs))
    # `repack` accepts a new vector of the same length as `flat` and builds a new
    # `Parameters` that takes its tunables from that vector while retaining the old
    # values of everything else. That is exactly what `SS.replace` does, so it is
    # simply forwarded to.
    repack = let params = p
        newflat -> SS.replace(SS.Tunable(), params, newflat)
    end
    # Third element: whether the returned vector aliases storage inside `p`.
    # `vcat` builds a fresh array, so it does not.
    return flat, repack, false
end
# Non-mutating replacement: build a new `Parameters` whose tunable values are read
# from `newbuffer` (same layout as produced by `SS.canonicalize` above) and whose
# non-tunable fields are taken from `p`.
function SS.replace(::SS.Tunable, p::Parameters, newbuffer)
    nsub = length(p.subparams)
    N = nsub + length(p.coeffs)
    @assert length(newbuffer) == N
    # The leading `nsub` entries become the per-subproblem `p`; `q` and `r` are reused.
    new_subparams = [SubproblemParameters(newbuffer[k], old.q, old.r)
                     for (k, old) in enumerate(p.subparams)]
    # The remaining entries are viewed (not copied) and given the shape of `p.coeffs`.
    tail = view(newbuffer, (nsub + 1):length(newbuffer))
    new_coeffs = reshape(tail, size(p.coeffs))
    return Parameters(new_subparams, new_coeffs)
end
# Mutating replacement: overwrite the tunable values stored in `p` with the entries
# of `newbuffer` (same layout as produced by `SS.canonicalize` above) and return `p`.
function SS.replace!(::SS.Tunable, p::Parameters, newbuffer)
    N = length(p.subparams) + length(p.coeffs)
    @assert length(newbuffer) == N
    # `zip` stops after the last subproblem, so only the leading entries are consumed.
    for (subpar, val) in zip(p.subparams, newbuffer)
        subpar.p = val
    end
    # The destination must be the matrix owned by `p`; a bare `coeffs` is not defined
    # in this scope and would throw an `UndefVarError`.
    copyto!(p.coeffs, view(newbuffer, (length(p.subparams) + 1):length(newbuffer)))
    return p
end

# Differentiate a scalar loss with respect to the flat tunable vector. The vector is
# turned into a `Parameters` via `SS.replace`, put into the problem with `remake`,
# and the problem is solved inside the differentiated closure.
tunables0 = 0.1ones(length(SS.canonicalize(SS.Tunable(), p)[1]))
grads = Zygote.gradient(tunables0) do tunables
    newp = SS.replace(SS.Tunable(), p, tunables)
    newprob = remake(prob; p = newp)
    sol = solve(newprob, Tsit5())
    return sum(sol.u[end])
end
# Check the result instead of discarding it: a missing (`nothing`), wrongly sized or
# non-finite gradient must fail the test rather than pass silently.
@test grads[1] !== nothing
@test length(grads[1]) == length(tunables0)
@test all(isfinite, grads[1])

0 comments on commit 9143ea7

Please sign in to comment.