ReverseDiffAdjoint does not support non-Array types #489

Closed
JTaets opened this issue Sep 10, 2021 · 9 comments

JTaets commented Sep 10, 2021

On the latest version of DiffEqSensitivity, using ComponentArrays with ReverseDiffAdjoint fails.

using Pkg, DiffEqSensitivity, OrdinaryDiffEq, Zygote, ComponentArrays, ReverseDiff
Pkg.status("DiffEqSensitivity")
# DiffEqSensitivity v6.58.0

## Proof that ReverseDiff works with ComponentArrays
ReverseDiff.gradient(x->sum(x),ComponentArray(a=1.,b=7.,c=8.))
# ComponentVector{Float64}(a = 1.0, b = 1.0, c = 1.0)


## Proof that ReverseDiff for ODEs doesn't work with ComponentArrays
u0 = [1.0;1.0]
p = ComponentArray(a=1.5,b=1.0,c=3.0,d=1.0)
function fiip(du,u,p,t)
  du[1] = dx = p.a*u[1] - p.b*u[1]*u[2]
  du[2] = dy = -p.c*u[2] + p.d*u[1]*u[2]
end
prob = ODEProblem(fiip,u0,(0.0,10.0),p)

# ForwardDiff works
loss(p) = sum(solve(prob,Tsit5(),u0=u0,p=p,saveat=0.1,sensealg=ForwardDiffSensitivity()))
dp = Zygote.gradient(loss,p)[1]
# ComponentVector{Float64}(a = 7.349039781610272, b = -159.31079871982794, c = 74.93924771425637, d = -339.3272371527868)

# ReverseDiffAdjoint fails: the gradient comes back as all zeros
loss2(p) = sum(solve(prob,Tsit5(),u0=u0,p=p,saveat=0.1,sensealg=ReverseDiffAdjoint()))
dp2 = Zygote.gradient(loss2,p)[1]
# ComponentVector{Float64}(a = 0.0, b = 0.0, c = 0.0, d = 0.0)
@ChrisRackauckas ChrisRackauckas changed the title ReverseDiffAdjoint not working with ComponentArray ReverseDiffAdjoint does not support non-Array types Sep 10, 2021
@ChrisRackauckas
Member

This is an upstream issue. The real problem is that ReverseDiff doesn't work with non-Array types, so this probably isn't fixable.

@JTaets
Author

JTaets commented Sep 10, 2021

But something like this works. Am I missing something?

ReverseDiff.gradient(x->sum(x),ComponentArray(a=1.,b=7.,c=8.))
# ComponentVector{Float64}(a = 1.0, b = 1.0, c = 1.0)

(If it is due to ReverseDiff, then you can close this issue if you want.)

@ChrisRackauckas
Member

What happens if you do something like sum((x+x).a + (x+x).b)? IIRC, the problem is that no operation can produce a non-Array AbstractArray type as its output. The pushforward of the operations will change it internally from a ComponentVector to a Vector, but only in the context of the AD pushforward.

> (If it is due to ReverseDiff, then you can close this issue if you want.)

This is just a known feature of ReverseDiff, and not something that can be fixed due to its design IIRC.

@JTaets JTaets closed this as completed Sep 10, 2021
@acertain

@ChrisRackauckas are you sure? See SciML/ComponentArrays.jl#37. Also, your ReverseDiff.gradient(x->sum((x+x).a+(x+x).b),ComponentArray(a=1.,b=7.,c=8.)) example works:

julia> ReverseDiff.gradient(x->sum((x+x).a+(x+x).b),ComponentArray(a=1.,b=7.,c=8.))
ComponentVector{Float64}(a = 2.0, b = 2.0, c = 0.0)

@ChrisRackauckas
Member

This response was written in 2021. It means nothing for today.

@acertain

acertain commented Aug 10, 2024

Then maybe this should be reopened? afaict all of the reverse mode adjoints are either broken (Zygote) or don't work with ComponentArrays (ReverseDiff, Tracker).

@ChrisRackauckas
Member

But those are all upstream issues. Nothing in this library can fix the limitations of ReverseDiff.jl, Tracker.jl, or Zygote.jl. Those issues would need to be addressed in the respective libraries.

If anything is going to fix this issue, it's getting EnzymeAdjoint to work, which is tracked in SciML/OrdinaryDiffEq.jl#2282 and a few other threads.

@acertain

What's the current problem with ReverseDiff?

@ChrisRackauckas
Member

ReverseDiff.jl only supports defining @grad overloads on arrays and automatically converts all AbstractArray objects to an Array on gradient application. That's why it has issues with alternative array types.
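
Given that conversion to Array, one possible workaround is to hand ReverseDiff a plain Vector and re-attach the ComponentArray labels inside the function being differentiated. The following is only a minimal sketch under assumptions, not something confirmed in this thread: `named_loss` and `flat_loss` are hypothetical names invented for illustration, the loss is a toy expression rather than an ODE solve, and it has not been checked against `ReverseDiffAdjoint` or against any particular package version.

```julia
using ComponentArrays, ReverseDiff

# Labeled parameters, as in the original report
p = ComponentArray(a=1.5, b=1.0, c=3.0, d=1.0)
ax = getaxes(p)   # axis metadata that carries the field names

# Hypothetical scalar loss written against named fields (illustration only)
named_loss(q) = q.a^2 + q.b * q.c - q.d

# Differentiate with respect to the flat Vector and re-attach the labels
# inside the closure, so ReverseDiff.gradient is only handed a plain Array
flat_loss(v) = named_loss(ComponentArray(v, ax))
g_flat = ReverseDiff.gradient(flat_loss, getdata(p))

# Re-wrap the plain gradient with the original axes to recover the labels
g = ComponentArray(g_flat, ax)
```

If this runs as intended, `g` can be indexed by name (`g.a`, `g.b`, ...) like the original `p`, while the value passed to and returned from `ReverseDiff.gradient` is an ordinary `Vector{Float64}`.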
