From 8ec96472d7f5a31dce05f98b9da77673dcac7bd2 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 9 Nov 2024 14:35:28 -0100 Subject: [PATCH] Don't allow picking GaussAdjoint with TrackerVJP in default alg Fixes https://github.com/SciML/SciMLSensitivity.jl/issues/1116 . I don't have a case that actually hits this though, no MWE provided, so I can't add a test but it's at least a thing that needs to be done. --- src/concrete_solve.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index d646ddbf6..aeb85ad6f 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -184,7 +184,7 @@ function automatic_sensealg_choice( # QuadratureAdjoint skips all p calculations until the end # So it's the fastest when there are no parameters QuadratureAdjoint(autodiff = false, autojacvec = vjp) - elseif prob isa ODEProblem + elseif prob isa ODEProblem && !(vjp isa TrackerVJP) GaussAdjoint(autodiff = false, autojacvec = vjp) else InterpolatingAdjoint(autodiff = false, autojacvec = vjp) @@ -194,7 +194,7 @@ function automatic_sensealg_choice( # QuadratureAdjoint skips all p calculations until the end # So it's the fastest when there are no parameters QuadratureAdjoint(autojacvec = vjp) - elseif prob isa ODEProblem + elseif prob isa ODEProblem && !(vjp isa TrackerVJP) GaussAdjoint(autojacvec = vjp) else InterpolatingAdjoint(autojacvec = vjp) @@ -209,7 +209,7 @@ function automatic_sensealg_choice( # If reverse-mode isn't working, just fallback to numerical vjps if p === nothing || p === SciMLBase.NullParameters() QuadratureAdjoint(autodiff = false, autojacvec = vjp) - elseif prob isa ODEProblem + elseif prob isa ODEProblem && !(vjp isa TrackerVJP) GaussAdjoint(autodiff = false, autojacvec = vjp) else InterpolatingAdjoint(autodiff = false, autojacvec = vjp) @@ -217,7 +217,7 @@ function automatic_sensealg_choice( else if p === nothing || p === SciMLBase.NullParameters() QuadratureAdjoint(autojacvec = vjp) - elseif prob isa ODEProblem + elseif prob isa ODEProblem && !(vjp isa TrackerVJP) GaussAdjoint(autojacvec = vjp) else InterpolatingAdjoint(autojacvec = vjp)