Skip to content

Commit

Permalink
Merge pull request #1147 from SciML/dg/nnrev3
Browse files Browse the repository at this point in the history
chore: reinstate repack
  • Loading branch information
ChrisRackauckas authored Nov 18, 2024
2 parents 849d093 + 074bf22 commit aef3a2c
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/gauss_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand)
copyto!(vec(out), ReverseDiff.deriv(tp))
elseif sensealg.autojacvec isa ZygoteVJP
_dy, back = Zygote.pullback(tunables) do tunables
vec(f(y, tunables, t))
vec(f(y, repack(tunables), t))
end
tmp = back(λ)
if tmp[1] === nothing
Expand Down
4 changes: 2 additions & 2 deletions test/alternative_ad_frontend.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,5 +247,5 @@ grad_fd = ForwardDiff.gradient(loss2, p)
grad_zg = Zygote.gradient(loss2, p)[1]
grad_rd = ReverseDiff.gradient(loss2, p)
@test grad_fdgrad_fi atol=1e-2
@test grad_fd grad_zg
@test grad_fd grad_rd
@test grad_fd grad_zg atol=1e-4
@test grad_fd grad_rd atol=1e-4

0 comments on commit aef3a2c

Please sign in to comment.