Skip to content

Commit

Permalink
Don't allow picking GaussAdjoint with TrackerVJP in default alg
Browse files Browse the repository at this point in the history
Fixes #1116 . I don't have a case that actually hits this though, no MWE provided, so I can't add a test but it's at least a thing that needs to be done.
  • Loading branch information
ChrisRackauckas authored Nov 9, 2024
1 parent 6d81222 commit 8ec9647
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ function automatic_sensealg_choice(
# QuadratureAdjoint skips all p calculations until the end
# So it's the fastest when there are no parameters
QuadratureAdjoint(autodiff = false, autojacvec = vjp)
elseif prob isa ODEProblem
elseif prob isa ODEProblem && !(vjp isa TrackerVJP)
GaussAdjoint(autodiff = false, autojacvec = vjp)
else
InterpolatingAdjoint(autodiff = false, autojacvec = vjp)
Expand All @@ -194,7 +194,7 @@ function automatic_sensealg_choice(
# QuadratureAdjoint skips all p calculations until the end
# So it's the fastest when there are no parameters
QuadratureAdjoint(autojacvec = vjp)
elseif prob isa ODEProblem
elseif prob isa ODEProblem && !(vjp isa TrackerVJP)
GaussAdjoint(autojacvec = vjp)
else
InterpolatingAdjoint(autojacvec = vjp)
Expand All @@ -209,15 +209,15 @@ function automatic_sensealg_choice(
# If reverse-mode isn't working, just fallback to numerical vjps
if p === nothing || p === SciMLBase.NullParameters()
QuadratureAdjoint(autodiff = false, autojacvec = vjp)
elseif prob isa ODEProblem
elseif prob isa ODEProblem && !(vjp isa TrackerVJP)
GaussAdjoint(autodiff = false, autojacvec = vjp)
else
InterpolatingAdjoint(autodiff = false, autojacvec = vjp)
end
else
if p === nothing || p === SciMLBase.NullParameters()
QuadratureAdjoint(autojacvec = vjp)
elseif prob isa ODEProblem
elseif prob isa ODEProblem && !(vjp isa TrackerVJP)
GaussAdjoint(autojacvec = vjp)
else
InterpolatingAdjoint(autojacvec = vjp)
Expand Down

0 comments on commit 8ec9647

Please sign in to comment.