-
Notifications
You must be signed in to change notification settings - Fork 16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make time evolution solvers compatible with automatic differentiation #311
Conversation
I think for some of the parameters which must be used in all solvers (e.g., In this case, we don't need to define BTW, maybe call it |
Actually I'm trying to simplify it as much as possible. The gradient calculation of |
bf3450d
to
5579426
Compare
With the last commit, I finally succeeded to support automatic differentiation on using QuantumToolbox
using OrdinaryDiffEq
using SciMLSensitivity
using Zygote
##
const N = 20
const F = 1
const γ = 1
const a = destroy(N)
const ψ0 = fock(N, 0)
# coef1(p, t) = p.Δ
coef1(p, t) = p[1]
QobjEvo(a' * a, coef1) + F * (a' + a)
function ss_population(Δ)
# H = Δ * a' * a + F * (a' + a)
H = QobjEvo(a' * a, coef1) + F * (a' + a)
c_ops = [sqrt(γ) * a]
tlist = range(0, 1.5, 100)
ρ_ss = sesolve(H, ψ0, tlist, progress_bar=Val(false), inplace=Val(false), params = [Δ], saveat = [tlist[end]]).states[end].data
return real(sum(ρ_ss))
end
Δ = 1.0
ss_population(Δ)
##
Zygote.gradient(ss_population, Δ) |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #311 +/- ##
==========================================
- Coverage 93.69% 93.21% -0.48%
==========================================
Files 32 36 +4
Lines 2490 2581 +91
==========================================
+ Hits 2333 2406 +73
- Misses 157 175 +18 ☔ View full report in Codecov by Sentry. |
Checklist
Thank you for contributing to
QuantumToolbox.jl
! Please make sure you have finished the following tasks before opening the PR.make test
.julia
formatted by running:make format
.docs/
folder) related to code changes were updated and able to build locally by running:make docs
.CHANGELOG.md
should be updated (regarding to the code changes) and built by running:make changelog
.Request for a review after you have completed all the tasks. If you have not finished them all, you can also open a Draft Pull Request to let the others know this on-going work.
Description
With this PR I change the structure of the time evolution solver in order to support automatic differentiation.
Thanks to the SciMLSensitivity.jl package, it is possible to compute the gradient of a differential equation. It is almost straightforward to do with ODE parameters as a
Vector
type, but it is not easy to implement when we have a complicated structure of the parameters as in the current case of the package, where we have many variables, progress bar, etc inside the params.The main change here is to introduce a new struct for the parameters, instead of using the current
NamedTuple
. In this way, thanks to SciMLStructures.jl, we can say which part of the structure is differentiable and which not.As a first step, I'm trying to simplify the structure of the
params
struct. This involves the creation of a custom struct to handle theODEProblem
generated by functions likesesolveProblem
. In this way, many variables can be removed fromparams
.Currently, there are some limitations on the type of the differentiable part of the
params
struct, and the only supported type is theVector
one. For example, theparams
kwarg in themesolve
has to be aVector
and not aNamedTuple
. See this issue for more information. Nonetheless, theNamedTuple
type is still supported in standard simulations, where the gradient is not needed.EDIT:
The custom struct
TimeEvolutionParameters
is no longer needed. We can pass all the cache and temporary variables to their respective callbacks. In this way, theparams
variable is only composed by the true parameters (usually aVector{Number}
.To Do:
ODEProblem
s generated by the solversSciMLStructures.jl
rules for the customparams
structsesolve
differentiablemesolve
differentiablemcsolve
differentiable (maybe in another PR?)