From 8b26c43245e3099dfb3e5ca0f043183f75d9a87c Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 11 Nov 2024 17:41:50 +0530 Subject: [PATCH 1/2] chore: reinstate repack --- src/gauss_adjoint.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index baa8df50b..238d3f2ba 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -483,7 +483,7 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand) copyto!(vec(out), ReverseDiff.deriv(tp)) elseif sensealg.autojacvec isa ZygoteVJP _dy, back = Zygote.pullback(tunables) do tunables - vec(f(y, tunables, t)) + vec(f(y, repack(tunables), t)) end tmp = back(λ) if tmp[1] === nothing From 074bf22382ad66d2700dae076520399710419ee8 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 18 Nov 2024 21:01:27 +0530 Subject: [PATCH 2/2] test: set lower atol --- test/alternative_ad_frontend.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/alternative_ad_frontend.jl b/test/alternative_ad_frontend.jl index 51363fee9..79b45009c 100644 --- a/test/alternative_ad_frontend.jl +++ b/test/alternative_ad_frontend.jl @@ -247,5 +247,5 @@ grad_fd = ForwardDiff.gradient(loss2, p) grad_zg = Zygote.gradient(loss2, p)[1] grad_rd = ReverseDiff.gradient(loss2, p) @test grad_fd≈grad_fi atol=1e-2 -@test grad_fd ≈ grad_zg -@test grad_fd ≈ grad_rd +@test grad_fd ≈ grad_zg atol=1e-4 +@test grad_fd ≈ grad_rd atol=1e-4