Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make time evolution solvers compatible with automatic differentiation #311

Merged
merged 21 commits into from
Nov 18, 2024

Conversation

albertomercurio
Copy link
Member

@albertomercurio albertomercurio commented Nov 14, 2024

Checklist

Thank you for contributing to QuantumToolbox.jl! Please make sure you have finished the following tasks before opening the PR.

  • Please read Contributing to QuantumToolbox.jl.
  • Any code changes were done in a way that does not break public API.
  • Appropriate tests were added and tested locally by running: make test.
  • Any code changes should be julia formatted by running: make format.
  • All documents (in docs/ folder) related to code changes were updated and able to build locally by running: make docs.
  • (If necessary) the CHANGELOG.md should be updated (regarding to the code changes) and built by running: make changelog.

Request for a review after you have completed all the tasks. If you have not finished them all, you can also open a Draft Pull Request to let the others know this on-going work.

Description

With this PR I change the structure of the time evolution solver in order to support automatic differentiation.

Thanks to the SciMLSensitivity.jl package, it is possible to compute the gradient of a differential equation. It is almost straightforward to do with ODE parameters as a Vector type, but it is not easy to implement when we have a complicated structure of the parameters as in the current case of the package, where we have many variables, progress bar, etc inside the params.

The main change here is to introduce a new struct for the parameters, instead of using the current NamedTuple. In this way, thanks to SciMLStructures.jl, we can say which part of the structure is differentiable and which not.

As a first step, I'm trying to simplify the structure of the params struct. This involves the creation of a custom struct to handle the ODEProblem generated by functions like sesolveProblem. In this way, many variables can be removed from params.

Currently, there are some limitations on the type of the differentiable part of the params struct, and the only supported type is the Vector one. For example, the params kwarg in the mesolve has to be a Vector and not a NamedTuple. See this issue for more information. Nonetheless, the NamedTuple type is still supported in standard simulations, where the gradient is not needed.

EDIT:

The custom struct TimeEvolutionParameters is no longer needed. We can pass all the cache and temporary variables to their respective callbacks. In this way, the params variable is only composed by the true parameters (usually a Vector{Number}.

To Do:

  • Implement a custom struct for the ODE Parameters
  • Implement a custom struct for the ODEProblems generated by the solvers
  • Implement the SciMLStructures.jl rules for the custom params struct
  • Make sesolve differentiable
  • Make mesolve differentiable
  • Make mcsolve differentiable (maybe in another PR?)

@ytdHuang
Copy link
Member

ytdHuang commented Nov 15, 2024

I think for some of the parameters which must be used in all solvers (e.g., tlist and progress_bar, dims ...) can still be a independent field and put it in QuantumTimeEvoParameters.

In this case, we don't need to define QuantumTimeEvoProblem right ? Just also put them in QuantumTimeEvoParameters.

BTW, maybe call it TimeEvoParameters is also a good choice.

@albertomercurio
Copy link
Member Author

Actually I'm trying to simplify it as much as possible. The gradient calculation of sesolve still complaints about the ProgressBar. So I think that I will try to remove it in the next commits to see if this fixes the problem. I'm basically trying to leave inside the params struct only AbstractArrays.

@albertomercurio
Copy link
Member Author

With the last commit, I finally succeeded to support automatic differentiation on sesolve. The following code works

using QuantumToolbox
using OrdinaryDiffEq
using SciMLSensitivity
using Zygote

##

const N = 20
const F = 1
const γ = 1

const a = destroy(N)

const ψ0 = fock(N, 0)

# coef1(p, t) = p.Δ
coef1(p, t) = p[1]

QobjEvo(a' * a, coef1) + F * (a' + a)

function ss_population(Δ)
  # H = Δ * a' * a + F * (a' + a)
  H = QobjEvo(a' * a, coef1) + F * (a' + a)
  c_ops = [sqrt(γ) * a]
  
  tlist = range(0, 1.5, 100)

  ρ_ss = sesolve(H, ψ0, tlist, progress_bar=Val(false), inplace=Val(false), params = [Δ], saveat = [tlist[end]]).states[end].data

  return real(sum(ρ_ss))
end

Δ = 1.0
ss_population(Δ)

##

Zygote.gradient(ss_population, Δ)

@albertomercurio albertomercurio marked this pull request as ready for review November 18, 2024 01:26
Copy link

codecov bot commented Nov 18, 2024

Codecov Report

Attention: Patch coverage is 90.25788% with 34 lines in your changes missing coverage. Please review.

Project coverage is 93.21%. Comparing base (afded9a) to head (a426df6).
Report is 3 commits behind head on main.

Files with missing lines Patch % Lines
src/time_evolution/mcsolve.jl 70.76% 19 Missing ⚠️
...ution/callback_helpers/mcsolve_callback_helpers.jl 93.79% 9 Missing ⚠️
...ime_evolution/callback_helpers/callback_helpers.jl 95.45% 2 Missing ⚠️
...ution/callback_helpers/mesolve_callback_helpers.jl 86.66% 2 Missing ⚠️
...ution/callback_helpers/sesolve_callback_helpers.jl 88.88% 1 Missing ⚠️
src/time_evolution/time_evolution.jl 90.90% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #311      +/-   ##
==========================================
- Coverage   93.69%   93.21%   -0.48%     
==========================================
  Files          32       36       +4     
  Lines        2490     2581      +91     
==========================================
+ Hits         2333     2406      +73     
- Misses        157      175      +18     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@albertomercurio albertomercurio merged commit 2836696 into main Nov 18, 2024
13 of 16 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants