Skip to content

Commit

Permalink
Make time evolution solvers compatible with automatic differentiation (
Browse files Browse the repository at this point in the history
…#311)

* Working sesolve

* add `inplace` keywork argument

* add SciMLStructures and relax params type

* Working mcsolve (no type-stability)

* Fix type-instabilities for mcsolve

* Add SciMLStructures.jl methods

* Add callbacks helpers

* Fix dsf_mcsolve

* Remove ProgressBar from ODE parameters

* Fix abstol and reltol extraction

* Use Base allequal function

* Remove expvals from TimeEvolutionParameters

* Make NullParameters as default for params

* Remove custom PresetTimeCallback

* Update description of `inplace` argument

* Working mesolve

* Fix dfd_mesolve and dsf_mesolve

* Remove TimeEvolutionParameters (type-unstable)

* Fix type instabilities

* Fix type instabilities on Julia v1.10

* Format document
  • Loading branch information
albertomercurio authored Nov 18, 2024
1 parent 9567c45 commit 2836696
Show file tree
Hide file tree
Showing 17 changed files with 876 additions and 561 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

- *__We will start to write changelog once we have the first standard release.__*
- Change the parameters structure of `sesolve`, `mesolve` and `mcsolve` functions to possibly support automatic differentiation. ([#311])


## [v0.21.5] (2024-11-15)

Expand All @@ -21,6 +22,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
<!-- Links generated by Changelog.jl -->

[v0.21.4]: https://github.com/qutip/QuantumToolbox.jl/releases/tag/v0.21.4
[v0.21.5]: https://github.com/qutip/QuantumToolbox.jl/releases/tag/v0.21.5
[#139]: https://github.com/qutip/QuantumToolbox.jl/issues/139
[#306]: https://github.com/qutip/QuantumToolbox.jl/issues/306
[#309]: https://github.com/qutip/QuantumToolbox.jl/issues/309
[#311]: https://github.com/qutip/QuantumToolbox.jl/issues/311
1 change: 1 addition & 0 deletions docs/src/resources/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ qeye
## [Time evolution](@id doc-API:Time-evolution)

```@docs
TimeEvolutionProblem
TimeEvolutionSol
TimeEvolutionMCSol
TimeEvolutionSSESol
Expand Down
11 changes: 10 additions & 1 deletion src/QuantumToolbox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,22 @@ import SciMLBase:
reinit!,
remake,
u_modified!,
NullParameters,
ODEFunction,
ODEProblem,
SDEProblem,
EnsembleProblem,
EnsembleSerial,
EnsembleThreads,
EnsembleSplitThreads,
EnsembleDistributed,
FullSpecialize,
CallbackSet,
ContinuousCallback,
DiscreteCallback
DiscreteCallback,
AbstractSciMLProblem,
AbstractODEIntegrator,
AbstractODESolution
import StochasticDiffEq: StochasticDiffEqAlgorithm, SRA1
import SciMLOperators:
SciMLOperators,
Expand Down Expand Up @@ -88,6 +93,10 @@ include("qobj/synonyms.jl")

# time evolution
include("time_evolution/time_evolution.jl")
include("time_evolution/callback_helpers/sesolve_callback_helpers.jl")
include("time_evolution/callback_helpers/mesolve_callback_helpers.jl")
include("time_evolution/callback_helpers/mcsolve_callback_helpers.jl")
include("time_evolution/callback_helpers/callback_helpers.jl")
include("time_evolution/mesolve.jl")
include("time_evolution/lr_mesolve.jl")
include("time_evolution/sesolve.jl")
Expand Down
3 changes: 1 addition & 2 deletions src/correlations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ function correlation_3op_2t(
(H.dims == ψ0.dims && H.dims == A.dims && H.dims == B.dims && H.dims == C.dims) ||
throw(DimensionMismatch("The quantum objects are not of the same Hilbert dimension."))

kwargs2 = (; kwargs...)
kwargs2 = merge(kwargs2, (saveat = collect(t_l),))
kwargs2 = merge((saveat = collect(t_l),), (; kwargs...))
ρt = mesolve(H, ψ0, t_l, c_ops; kwargs2...).states

corr = map((t, ρ) -> mesolve(H, C * ρ * A, τ_l .+ t, c_ops, e_ops = [B]; kwargs...).expect[1, :], t_l, ρt)
Expand Down
17 changes: 9 additions & 8 deletions src/qobj/eigsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -391,14 +391,15 @@ function eigsolve_al(
kwargs...,
) where {DT1,HOpType<:Union{OperatorQuantumObject,SuperOperatorQuantumObject}}
L_evo = _mesolve_make_L_QobjEvo(H, c_ops)
prob = mesolveProblem(
L_evo,
QuantumObject(ρ0, type = Operator, dims = H.dims),
[zero(T), T];
params = params,
progress_bar = Val(false),
kwargs...,
)
prob =
mesolveProblem(
L_evo,
QuantumObject(ρ0, type = Operator, dims = H.dims),
[zero(T), T];
params = params,
progress_bar = Val(false),
kwargs...,
).prob
integrator = init(prob, alg)

# prog = ProgressUnknown(desc="Applications:", showspeed = true, enabled=progress)
Expand Down
2 changes: 1 addition & 1 deletion src/qobj/quantum_object_evo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ Parse the `op_func_list` and generate the data for the `QuantumObjectEvolution`
quote
dims = tuple($(dims_expr...))

length(unique(dims)) == 1 || throw(ArgumentError("The dimensions of the operators must be the same."))
allequal(dims) || throw(ArgumentError("The dimensions of the operators must be the same."))

data_expr_const = $qobj_expr_const isa Integer ? $qobj_expr_const : _make_SciMLOperator($qobj_expr_const, α)

Expand Down
89 changes: 89 additions & 0 deletions src/time_evolution/callback_helpers/callback_helpers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#=
This file contains helper functions for callbacks. The affect! function are defined taking advantage of the Julia struct, which allows to store some cache exclusively for the callback.
=#

##

# Multiple dispatch depending on the progress_bar and e_ops types
function _generate_se_me_kwargs(e_ops, progress_bar, tlist, kwargs, method)
cb = _generate_save_callback(e_ops, tlist, progress_bar, method)
return _merge_kwargs_with_callback(kwargs, cb)
end
_generate_se_me_kwargs(e_ops::Nothing, progress_bar::Val{false}, tlist, kwargs, method) = kwargs

function _merge_kwargs_with_callback(kwargs, cb)
kwargs2 =
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb, kwargs.callback),)) :
merge(kwargs, (callback = cb,))

return kwargs2
end

function _generate_save_callback(e_ops, tlist, progress_bar, method)
e_ops_data = e_ops isa Nothing ? nothing : _get_e_ops_data(e_ops, method)

progr = getVal(progress_bar) ? ProgressBar(length(tlist), enable = getVal(progress_bar)) : nothing

expvals = e_ops isa Nothing ? nothing : Array{ComplexF64}(undef, length(e_ops), length(tlist))

_save_affect! = method(e_ops_data, progr, Ref(1), expvals)
return PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
end

_get_e_ops_data(e_ops, ::Type{SaveFuncSESolve}) = get_data.(e_ops)
_get_e_ops_data(e_ops, ::Type{SaveFuncMESolve}) = [_generate_mesolve_e_op(op) for op in e_ops] # Broadcasting generates type instabilities on Julia v1.10

_generate_mesolve_e_op(op) = mat2vec(adjoint(get_data(op)))

##

# When e_ops is Nothing. Common for both mesolve and sesolve
function _save_func(integrator, progr)
next!(progr)
u_modified!(integrator, false)
return nothing
end

# When progr is Nothing. Common for both mesolve and sesolve
function _save_func(integrator, progr::Nothing)
u_modified!(integrator, false)
return nothing
end

##

# Get the e_ops from a given AbstractODESolution. Valid for `sesolve`, `mesolve` and `ssesolve`.
function _se_me_sse_get_expvals(sol::AbstractODESolution)
cb = _se_me_sse_get_save_callback(sol)
if cb isa Nothing
return nothing
else
return cb.affect!.expvals
end
end

function _se_me_sse_get_save_callback(sol::AbstractODESolution)
kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple to support Zygote.jl
if hasproperty(kwargs, :callback)
return _se_me_sse_get_save_callback(kwargs.callback)
else
return nothing
end
end
_se_me_sse_get_save_callback(integrator::AbstractODEIntegrator) = _se_me_sse_get_save_callback(integrator.opts.callback)
function _se_me_sse_get_save_callback(cb::CallbackSet)
cbs_discrete = cb.discrete_callbacks
if length(cbs_discrete) > 0
_cb = cb.discrete_callbacks[1]
return _se_me_sse_get_save_callback(_cb)
else
return nothing
end
end
_se_me_sse_get_save_callback(cb::DiscreteCallback) =
if (cb.affect! isa SaveFuncSESolve) || (cb.affect! isa SaveFuncMESolve)
return cb
else
return nothing
end
_se_me_sse_get_save_callback(cb::ContinuousCallback) = nothing
Loading

0 comments on commit 2836696

Please sign in to comment.