From 283669648d286d20f24568fc1ed7143dd3a7d60d Mon Sep 17 00:00:00 2001 From: Alberto Mercurio <61953577+albertomercurio@users.noreply.github.com> Date: Mon, 18 Nov 2024 10:07:28 +0100 Subject: [PATCH] Make time evolution solvers compatible with automatic differentiation (#311) * Working sesolve * add `inplace` keywork argument * add SciMLStructures and relax params type * Working mcsolve (no type-stability) * Fix type-instabilities for mcsolve * Add SciMLStructures.jl methods * Add callbacks helpers * Fix dsf_mcsolve * Remove ProgressBar from ODE parameters * Fix abstol and reltol extraction * Use Base allequal function * Remove expvals from TimeEvolutionParameters * Make NullParameters as default for params * Remove custom PresetTimeCallback * Update description of `inplace` argument * Working mesolve * Fix dfd_mesolve and dsf_mesolve * Remove TimeEvolutionParameters (type-unstable) * Fix type instabilities * Fix type instabilities on Julia v1.10 * Format document --- CHANGELOG.md | 5 +- docs/src/resources/api.md | 1 + src/QuantumToolbox.jl | 11 +- src/correlations.jl | 3 +- src/qobj/eigsolve.jl | 17 +- src/qobj/quantum_object_evo.jl | 2 +- .../callback_helpers/callback_helpers.jl | 89 ++++ .../mcsolve_callback_helpers.jl | 340 +++++++++++++ .../mesolve_callback_helpers.jl | 39 ++ .../sesolve_callback_helpers.jl | 25 + src/time_evolution/mcsolve.jl | 445 ++++++------------ src/time_evolution/mesolve.jl | 105 ++--- src/time_evolution/sesolve.jl | 112 ++--- src/time_evolution/ssesolve.jl | 54 +-- src/time_evolution/time_evolution.jl | 60 ++- .../time_evolution_dynamical.jl | 94 ++-- test/core-test/time_evolution.jl | 35 +- 17 files changed, 876 insertions(+), 561 deletions(-) create mode 100644 src/time_evolution/callback_helpers/callback_helpers.jl create mode 100644 src/time_evolution/callback_helpers/mcsolve_callback_helpers.jl create mode 100644 src/time_evolution/callback_helpers/mesolve_callback_helpers.jl create mode 100644 src/time_evolution/callback_helpers/sesolve_callback_helpers.jl diff --git a/CHANGELOG.md b/CHANGELOG.md index ef074162..df5f4204 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased -- *__We will start to write changelog once we have the first standard release.__* +- Change the parameters structure of `sesolve`, `mesolve` and `mcsolve` functions to possibly support automatic differentiation. ([#311]) + ## [v0.21.5] (2024-11-15) @@ -21,6 +22,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [v0.21.4]: https://github.com/qutip/QuantumToolbox.jl/releases/tag/v0.21.4 +[v0.21.5]: https://github.com/qutip/QuantumToolbox.jl/releases/tag/v0.21.5 [#139]: https://github.com/qutip/QuantumToolbox.jl/issues/139 [#306]: https://github.com/qutip/QuantumToolbox.jl/issues/306 [#309]: https://github.com/qutip/QuantumToolbox.jl/issues/309 +[#311]: https://github.com/qutip/QuantumToolbox.jl/issues/311 diff --git a/docs/src/resources/api.md b/docs/src/resources/api.md index c12753b0..0b0c37a5 100644 --- a/docs/src/resources/api.md +++ b/docs/src/resources/api.md @@ -181,6 +181,7 @@ qeye ## [Time evolution](@id doc-API:Time-evolution) ```@docs +TimeEvolutionProblem TimeEvolutionSol TimeEvolutionMCSol TimeEvolutionSSESol diff --git a/src/QuantumToolbox.jl b/src/QuantumToolbox.jl index 4077ed05..676c61e3 100644 --- a/src/QuantumToolbox.jl +++ b/src/QuantumToolbox.jl @@ -23,17 +23,22 @@ import SciMLBase: reinit!, remake, u_modified!, + NullParameters, ODEFunction, ODEProblem, SDEProblem, EnsembleProblem, EnsembleSerial, EnsembleThreads, + EnsembleSplitThreads, EnsembleDistributed, FullSpecialize, CallbackSet, ContinuousCallback, - DiscreteCallback + DiscreteCallback, + AbstractSciMLProblem, + AbstractODEIntegrator, + AbstractODESolution import StochasticDiffEq: StochasticDiffEqAlgorithm, SRA1 import SciMLOperators: SciMLOperators, @@ -88,6 +93,10 @@ include("qobj/synonyms.jl") # time evolution include("time_evolution/time_evolution.jl") +include("time_evolution/callback_helpers/sesolve_callback_helpers.jl") +include("time_evolution/callback_helpers/mesolve_callback_helpers.jl") +include("time_evolution/callback_helpers/mcsolve_callback_helpers.jl") +include("time_evolution/callback_helpers/callback_helpers.jl") include("time_evolution/mesolve.jl") include("time_evolution/lr_mesolve.jl") include("time_evolution/sesolve.jl") diff --git a/src/correlations.jl b/src/correlations.jl index 38d7cb99..cbec3da9 100644 --- a/src/correlations.jl +++ b/src/correlations.jl @@ -49,8 +49,7 @@ function correlation_3op_2t( (H.dims == ψ0.dims && H.dims == A.dims && H.dims == B.dims && H.dims == C.dims) || throw(DimensionMismatch("The quantum objects are not of the same Hilbert dimension.")) - kwargs2 = (; kwargs...) - kwargs2 = merge(kwargs2, (saveat = collect(t_l),)) + kwargs2 = merge((saveat = collect(t_l),), (; kwargs...)) ρt = mesolve(H, ψ0, t_l, c_ops; kwargs2...).states corr = map((t, ρ) -> mesolve(H, C * ρ * A, τ_l .+ t, c_ops, e_ops = [B]; kwargs...).expect[1, :], t_l, ρt) diff --git a/src/qobj/eigsolve.jl b/src/qobj/eigsolve.jl index 63b4be99..47bbeecc 100644 --- a/src/qobj/eigsolve.jl +++ b/src/qobj/eigsolve.jl @@ -391,14 +391,15 @@ function eigsolve_al( kwargs..., ) where {DT1,HOpType<:Union{OperatorQuantumObject,SuperOperatorQuantumObject}} L_evo = _mesolve_make_L_QobjEvo(H, c_ops) - prob = mesolveProblem( - L_evo, - QuantumObject(ρ0, type = Operator, dims = H.dims), - [zero(T), T]; - params = params, - progress_bar = Val(false), - kwargs..., - ) + prob = + mesolveProblem( + L_evo, + QuantumObject(ρ0, type = Operator, dims = H.dims), + [zero(T), T]; + params = params, + progress_bar = Val(false), + kwargs..., + ).prob integrator = init(prob, alg) # prog = ProgressUnknown(desc="Applications:", showspeed = true, enabled=progress) diff --git a/src/qobj/quantum_object_evo.jl b/src/qobj/quantum_object_evo.jl index 8abadb3b..cec0010e 100644 --- a/src/qobj/quantum_object_evo.jl +++ b/src/qobj/quantum_object_evo.jl @@ -269,7 +269,7 @@ Parse the `op_func_list` and generate the data for the `QuantumObjectEvolution` quote dims = tuple($(dims_expr...)) - length(unique(dims)) == 1 || throw(ArgumentError("The dimensions of the operators must be the same.")) + allequal(dims) || throw(ArgumentError("The dimensions of the operators must be the same.")) data_expr_const = $qobj_expr_const isa Integer ? $qobj_expr_const : _make_SciMLOperator($qobj_expr_const, α) diff --git a/src/time_evolution/callback_helpers/callback_helpers.jl b/src/time_evolution/callback_helpers/callback_helpers.jl new file mode 100644 index 00000000..d1180afe --- /dev/null +++ b/src/time_evolution/callback_helpers/callback_helpers.jl @@ -0,0 +1,89 @@ +#= +This file contains helper functions for callbacks. The affect! function are defined taking advantage of the Julia struct, which allows to store some cache exclusively for the callback. +=# + +## + +# Multiple dispatch depending on the progress_bar and e_ops types +function _generate_se_me_kwargs(e_ops, progress_bar, tlist, kwargs, method) + cb = _generate_save_callback(e_ops, tlist, progress_bar, method) + return _merge_kwargs_with_callback(kwargs, cb) +end +_generate_se_me_kwargs(e_ops::Nothing, progress_bar::Val{false}, tlist, kwargs, method) = kwargs + +function _merge_kwargs_with_callback(kwargs, cb) + kwargs2 = + haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb, kwargs.callback),)) : + merge(kwargs, (callback = cb,)) + + return kwargs2 +end + +function _generate_save_callback(e_ops, tlist, progress_bar, method) + e_ops_data = e_ops isa Nothing ? nothing : _get_e_ops_data(e_ops, method) + + progr = getVal(progress_bar) ? ProgressBar(length(tlist), enable = getVal(progress_bar)) : nothing + + expvals = e_ops isa Nothing ? nothing : Array{ComplexF64}(undef, length(e_ops), length(tlist)) + + _save_affect! = method(e_ops_data, progr, Ref(1), expvals) + return PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false)) +end + +_get_e_ops_data(e_ops, ::Type{SaveFuncSESolve}) = get_data.(e_ops) +_get_e_ops_data(e_ops, ::Type{SaveFuncMESolve}) = [_generate_mesolve_e_op(op) for op in e_ops] # Broadcasting generates type instabilities on Julia v1.10 + +_generate_mesolve_e_op(op) = mat2vec(adjoint(get_data(op))) + +## + +# When e_ops is Nothing. Common for both mesolve and sesolve +function _save_func(integrator, progr) + next!(progr) + u_modified!(integrator, false) + return nothing +end + +# When progr is Nothing. Common for both mesolve and sesolve +function _save_func(integrator, progr::Nothing) + u_modified!(integrator, false) + return nothing +end + +## + +# Get the e_ops from a given AbstractODESolution. Valid for `sesolve`, `mesolve` and `ssesolve`. +function _se_me_sse_get_expvals(sol::AbstractODESolution) + cb = _se_me_sse_get_save_callback(sol) + if cb isa Nothing + return nothing + else + return cb.affect!.expvals + end +end + +function _se_me_sse_get_save_callback(sol::AbstractODESolution) + kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple to support Zygote.jl + if hasproperty(kwargs, :callback) + return _se_me_sse_get_save_callback(kwargs.callback) + else + return nothing + end +end +_se_me_sse_get_save_callback(integrator::AbstractODEIntegrator) = _se_me_sse_get_save_callback(integrator.opts.callback) +function _se_me_sse_get_save_callback(cb::CallbackSet) + cbs_discrete = cb.discrete_callbacks + if length(cbs_discrete) > 0 + _cb = cb.discrete_callbacks[1] + return _se_me_sse_get_save_callback(_cb) + else + return nothing + end +end +_se_me_sse_get_save_callback(cb::DiscreteCallback) = + if (cb.affect! isa SaveFuncSESolve) || (cb.affect! isa SaveFuncMESolve) + return cb + else + return nothing + end +_se_me_sse_get_save_callback(cb::ContinuousCallback) = nothing diff --git a/src/time_evolution/callback_helpers/mcsolve_callback_helpers.jl b/src/time_evolution/callback_helpers/mcsolve_callback_helpers.jl new file mode 100644 index 00000000..c27b8cbc --- /dev/null +++ b/src/time_evolution/callback_helpers/mcsolve_callback_helpers.jl @@ -0,0 +1,340 @@ +#= +Helper functions for the mcsolve callbacks. +=# + +struct SaveFuncMCSolve{TE,IT,TEXPV} + e_ops::TE + iter::IT + expvals::TEXPV +end + +(f::SaveFuncMCSolve)(integrator) = _save_func_mcsolve(integrator, f.e_ops, f.iter, f.expvals) + +struct LindbladJump{ + T1, + T2, + RNGType<:AbstractRNG, + RandT, + CT<:AbstractVector, + WT<:AbstractVector, + JTT<:AbstractVector, + JWT<:AbstractVector, + JTWIT, +} + c_ops::T1 + c_ops_herm::T2 + traj_rng::RNGType + random_n::RandT + cache_mc::CT + weights_mc::WT + cumsum_weights_mc::WT + jump_times::JTT + jump_which::JWT + jump_times_which_idx::JTWIT +end + +(f::LindbladJump)(integrator) = _lindblad_jump_affect!( + integrator, + f.c_ops, + f.c_ops_herm, + f.traj_rng, + f.random_n, + f.cache_mc, + f.weights_mc, + f.cumsum_weights_mc, + f.jump_times, + f.jump_which, + f.jump_times_which_idx, +) + +## + +function _save_func_mcsolve(integrator, e_ops, iter, expvals) + cache_mc = _mc_get_jump_callback(integrator).affect!.cache_mc + + copyto!(cache_mc, integrator.u) + normalize!(cache_mc) + ψ = cache_mc + _expect = op -> dot(ψ, op, ψ) + @. expvals[:, iter[]] = _expect(e_ops) + iter[] += 1 + + u_modified!(integrator, false) + return nothing +end + +function _generate_mcsolve_kwargs(ψ0, T, e_ops, tlist, c_ops, jump_callback, rng, kwargs) + c_ops_data = get_data.(c_ops) + c_ops_herm_data = map(op -> op' * op, c_ops_data) + + cache_mc = similar(ψ0.data, T) + weights_mc = Vector{Float64}(undef, length(c_ops)) + cumsum_weights_mc = similar(weights_mc) + + jump_times = Vector{Float64}(undef, JUMP_TIMES_WHICH_INIT_SIZE) + jump_which = Vector{Int}(undef, JUMP_TIMES_WHICH_INIT_SIZE) + jump_times_which_idx = Ref(1) + + random_n = Ref(rand(rng)) + + _affect! = LindbladJump( + c_ops_data, + c_ops_herm_data, + rng, + random_n, + cache_mc, + weights_mc, + cumsum_weights_mc, + jump_times, + jump_which, + jump_times_which_idx, + ) + + if jump_callback isa DiscreteLindbladJumpCallback + cb1 = DiscreteCallback(_mcsolve_discrete_condition, _affect!, save_positions = (false, false)) + else + cb1 = ContinuousCallback( + _mcsolve_continuous_condition, + _affect!, + nothing, + interp_points = jump_callback.interp_points, + save_positions = (false, false), + ) + end + + if e_ops isa Nothing + # We are implicitly saying that we don't have a `ProgressBar` + kwargs2 = + haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb1, kwargs.callback),)) : + merge(kwargs, (callback = cb1,)) + return kwargs2 + else + expvals = Array{ComplexF64}(undef, length(e_ops), length(tlist)) + + _save_affect! = SaveFuncMCSolve(get_data.(e_ops), Ref(1), expvals) + cb2 = PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false)) + kwargs2 = + haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb1, cb2, kwargs.callback),)) : + merge(kwargs, (callback = CallbackSet(cb1, cb2),)) + return kwargs2 + end +end + +function _lindblad_jump_affect!( + integrator, + c_ops, + c_ops_herm, + traj_rng, + random_n, + cache_mc, + weights_mc, + cumsum_weights_mc, + jump_times, + jump_which, + jump_times_which_idx, +) + ψ = integrator.u + + @inbounds for i in eachindex(weights_mc) + weights_mc[i] = real(dot(ψ, c_ops_herm[i], ψ)) + end + cumsum!(cumsum_weights_mc, weights_mc) + r = rand(traj_rng) * sum(weights_mc) + collapse_idx = getindex(1:length(weights_mc), findfirst(>(r), cumsum_weights_mc)) + mul!(cache_mc, c_ops[collapse_idx], ψ) + normalize!(cache_mc) + copyto!(integrator.u, cache_mc) + + random_n[] = rand(traj_rng) + + idx = jump_times_which_idx[] + @inbounds jump_times[idx] = integrator.t + @inbounds jump_which[idx] = collapse_idx + jump_times_which_idx[] += 1 + if jump_times_which_idx[] > length(jump_times) + resize!(jump_times, length(jump_times) + JUMP_TIMES_WHICH_INIT_SIZE) + resize!(jump_which, length(jump_which) + JUMP_TIMES_WHICH_INIT_SIZE) + end + u_modified!(integrator, true) + return nothing +end + +_mcsolve_continuous_condition(u, t, integrator) = + @inbounds _mc_get_jump_callback(integrator).affect!.random_n[] - real(dot(u, u)) + +_mcsolve_discrete_condition(u, t, integrator) = + @inbounds real(dot(u, u)) < _mc_get_jump_callback(integrator).affect!.random_n[] + +## + +#= + _mc_get_save_callback + +Return the Callback that is responsible for saving the expectation values of the system. +=# +function _mc_get_save_callback(sol::AbstractODESolution) + kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple to support Zygote.jl + return _mc_get_save_callback(kwargs.callback) # There is always the Jump callback +end +_mc_get_save_callback(integrator::AbstractODEIntegrator) = _mc_get_save_callback(integrator.opts.callback) +function _mc_get_save_callback(cb::CallbackSet) + cbs_discrete = cb.discrete_callbacks + + if length(cbs_discrete) > 0 + idx = _mcsolve_has_continuous_jump(cb) ? 1 : 2 + _cb = cb.discrete_callbacks[idx] + return _mc_get_save_callback(_cb) + else + return nothing + end +end +_mc_get_save_callback(cb::DiscreteCallback) = + if cb.affect! isa SaveFuncMCSolve + return cb + else + return nothing + end +_mc_get_save_callback(cb::ContinuousCallback) = nothing + +## + +function _mc_get_jump_callback(sol::AbstractODESolution) + kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple to support Zygote.jl + return _mc_get_jump_callback(kwargs.callback) # There is always the Jump callback +end +_mc_get_jump_callback(integrator::AbstractODEIntegrator) = _mc_get_jump_callback(integrator.opts.callback) +_mc_get_jump_callback(cb::CallbackSet) = + if _mcsolve_has_continuous_jump(cb) + return cb.continuous_callbacks[1] + else + return cb.discrete_callbacks[1] + end +_mc_get_jump_callback(cb::ContinuousCallback) = cb +_mc_get_jump_callback(cb::DiscreteCallback) = cb + +## + +#= +With this function we extract the c_ops and c_ops_herm from the LindbladJump `affect!` function of the callback of the integrator. +This callback can be a DiscreteLindbladJumpCallback or a ContinuousLindbladJumpCallback. +=# +function _mcsolve_get_c_ops(integrator::AbstractODEIntegrator) + cb = _mc_get_jump_callback(integrator) + if cb isa Nothing + return nothing + else + return cb.affect!.c_ops, cb.affect!.c_ops_herm + end +end + +#= +With this function we extract the e_ops from the SaveFuncMCSolve `affect!` function of the callback of the integrator. +This callback can only be a PresetTimeCallback (DiscreteCallback). +=# +function _mcsolve_get_e_ops(integrator::AbstractODEIntegrator) + cb = _mc_get_save_callback(integrator) + if cb isa Nothing + return nothing + else + return cb.affect!.e_ops + end +end + +function _mcsolve_get_expvals(sol::AbstractODESolution) + cb = _mc_get_save_callback(sol) + if cb isa Nothing + return nothing + else + return cb.affect!.expvals + end +end + +#= + _mcsolve_initialize_callbacks(prob, tlist) + +Return the same callbacks of the `prob`, but with the `iter` variable reinitialized to 1 and the `expvals` variable reinitialized to a new matrix. +=# +function _mcsolve_initialize_callbacks(prob, tlist, traj_rng) + cb = prob.kwargs[:callback] + return _mcsolve_initialize_callbacks(cb, tlist, traj_rng) +end +function _mcsolve_initialize_callbacks(cb::CallbackSet, tlist, traj_rng) + cb_continuous = cb.continuous_callbacks + cb_discrete = cb.discrete_callbacks + + if _mcsolve_has_continuous_jump(cb) + idx = 1 + if cb_discrete[idx].affect! isa SaveFuncMCSolve + e_ops = cb_discrete[idx].affect!.e_ops + expvals = similar(cb_discrete[idx].affect!.expvals) + _save_affect! = SaveFuncMCSolve(e_ops, Ref(1), expvals) + cb_save = (PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false)),) + else + cb_save = () + end + + _jump_affect! = _similar_affect!(cb_continuous[1].affect!, traj_rng) + cb_jump = _modify_field(cb_continuous[1], :affect!, _jump_affect!) + + return CallbackSet((cb_jump, cb_continuous[2:end]...), (cb_save..., cb_discrete[2:end]...)) + else + idx = 2 + if cb_discrete[idx].affect! isa SaveFuncMCSolve + e_ops = cb_discrete[idx].affect!.e_ops + expvals = similar(cb_discrete[idx].affect!.expvals) + _save_affect! = SaveFuncMCSolve(e_ops, Ref(1), expvals) + cb_save = (PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false)),) + else + cb_save = () + end + + _jump_affect! = _similar_affect!(cb_discrete[1].affect!, traj_rng) + cb_jump = _modify_field(cb_discrete[1], :affect!, _jump_affect!) + + return CallbackSet(cb_continuous, (cb_jump, cb_save..., cb_discrete[3:end]...)) + end +end +function _mcsolve_initialize_callbacks(cb::CBT, tlist, traj_rng) where {CBT<:Union{ContinuousCallback,DiscreteCallback}} + _jump_affect! = _similar_affect!(cb.affect!, traj_rng) + return _modify_field(cb, :affect!, _jump_affect!) +end + +#= + _similar_affect! + +Return a new LindbladJump with the same fields as the input LindbladJump but with new memory. +=# +function _similar_affect!(affect::LindbladJump, traj_rng) + random_n = Ref(rand(traj_rng)) + cache_mc = similar(affect.cache_mc) + weights_mc = similar(affect.weights_mc) + cumsum_weights_mc = similar(affect.cumsum_weights_mc) + jump_times = similar(affect.jump_times) + jump_which = similar(affect.jump_which) + jump_times_which_idx = Ref(1) + + return LindbladJump( + affect.c_ops, + affect.c_ops_herm, + traj_rng, + random_n, + cache_mc, + weights_mc, + cumsum_weights_mc, + jump_times, + jump_which, + jump_times_which_idx, + ) +end + +Base.@constprop :aggressive function _modify_field(obj::T, field_name::Symbol, field_val) where {T} + # Create a NamedTuple of fields, deepcopying only the selected ones + fields = (name != field_name ? (getfield(obj, name)) : field_val for name in fieldnames(T)) + # Reconstruct the struct with the updated fields + return Base.typename(T).wrapper(fields...) +end + +_mcsolve_has_continuous_jump(cb::CallbackSet) = + (length(cb.continuous_callbacks) > 0) && (cb.continuous_callbacks[1].affect! isa LindbladJump) +_mcsolve_has_continuous_jump(cb::ContinuousCallback) = true +_mcsolve_has_continuous_jump(cb::DiscreteCallback) = false diff --git a/src/time_evolution/callback_helpers/mesolve_callback_helpers.jl b/src/time_evolution/callback_helpers/mesolve_callback_helpers.jl new file mode 100644 index 00000000..449e2645 --- /dev/null +++ b/src/time_evolution/callback_helpers/mesolve_callback_helpers.jl @@ -0,0 +1,39 @@ +#= +Helper functions for the mesolve callbacks. +=# + +struct SaveFuncMESolve{TE,PT<:Union{Nothing,ProgressBar},IT,TEXPV<:Union{Nothing,AbstractMatrix}} + e_ops::TE + progr::PT + iter::IT + expvals::TEXPV +end + +(f::SaveFuncMESolve)(integrator) = _save_func_mesolve(integrator, f.e_ops, f.progr, f.iter, f.expvals) +(f::SaveFuncMESolve{Nothing})(integrator) = _save_func(integrator, f.progr) + +## + +# When e_ops is a list of operators +function _save_func_mesolve(integrator, e_ops, progr, iter, expvals) + # This is equivalent to tr(op * ρ), when both are matrices. + # The advantage of using this convention is that We don't need + # to reshape u to make it a matrix, but we reshape the e_ops once. + + ρ = integrator.u + _expect = op -> dot(op, ρ) + @. expvals[:, iter[]] = _expect(e_ops) + iter[] += 1 + + return _save_func(integrator, progr) +end + +function _mesolve_callbacks_new_e_ops!(integrator::AbstractODEIntegrator, e_ops) + cb = _se_me_sse_get_save_callback(integrator) + if cb isa Nothing + return nothing + else + cb.affect!.e_ops .= e_ops # Only works if e_ops is a Vector of operators + return nothing + end +end diff --git a/src/time_evolution/callback_helpers/sesolve_callback_helpers.jl b/src/time_evolution/callback_helpers/sesolve_callback_helpers.jl new file mode 100644 index 00000000..54f0945f --- /dev/null +++ b/src/time_evolution/callback_helpers/sesolve_callback_helpers.jl @@ -0,0 +1,25 @@ +#= +Helper functions for the sesolve callbacks. +=# + +struct SaveFuncSESolve{TE,PT<:Union{Nothing,ProgressBar},IT,TEXPV<:Union{Nothing,AbstractMatrix}} + e_ops::TE + progr::PT + iter::IT + expvals::TEXPV +end + +(f::SaveFuncSESolve)(integrator) = _save_func_sesolve(integrator, f.e_ops, f.progr, f.iter, f.expvals) +(f::SaveFuncSESolve{Nothing})(integrator) = _save_func(integrator, f.progr) # Common for both mesolve and sesolve + +## + +# When e_ops is a list of operators +function _save_func_sesolve(integrator, e_ops, progr, iter, expvals) + ψ = integrator.u + _expect = op -> dot(ψ, op, ψ) + @. expvals[:, iter[]] = _expect(e_ops) + iter[] += 1 + + return _save_func(integrator, progr) +end diff --git a/src/time_evolution/mcsolve.jl b/src/time_evolution/mcsolve.jl index 0c58b042..db8ab031 100644 --- a/src/time_evolution/mcsolve.jl +++ b/src/time_evolution/mcsolve.jl @@ -1,128 +1,70 @@ export mcsolveProblem, mcsolveEnsembleProblem, mcsolve export ContinuousLindbladJumpCallback, DiscreteLindbladJumpCallback -function _save_func_mcsolve(integrator) - internal_params = integrator.p - progr = internal_params.progr_mc - - if !internal_params.is_empty_e_ops_mc - e_ops = internal_params.e_ops_mc - expvals = internal_params.expvals - cache_mc = internal_params.cache_mc - - copyto!(cache_mc, integrator.u) - normalize!(cache_mc) - ψ = cache_mc - _expect = op -> dot(ψ, op, ψ) - @. expvals[:, progr.counter[]+1] = _expect(e_ops) - end - next!(progr) - return u_modified!(integrator, false) -end - -function LindbladJumpAffect!(integrator) - internal_params = integrator.p - c_ops = internal_params.c_ops - c_ops_herm = internal_params.c_ops_herm - cache_mc = internal_params.cache_mc - weights_mc = internal_params.weights_mc - cumsum_weights_mc = internal_params.cumsum_weights_mc - random_n = internal_params.random_n - jump_times = internal_params.jump_times - jump_which = internal_params.jump_which - traj_rng = internal_params.traj_rng - ψ = integrator.u - - @inbounds for i in eachindex(weights_mc) - weights_mc[i] = real(dot(ψ, c_ops_herm[i], ψ)) - end - cumsum!(cumsum_weights_mc, weights_mc) - r = rand(traj_rng) * sum(weights_mc) - collapse_idx = getindex(1:length(weights_mc), findfirst(>(r), cumsum_weights_mc)) - mul!(cache_mc, c_ops[collapse_idx], ψ) - normalize!(cache_mc) - copyto!(integrator.u, cache_mc) - - random_n[] = rand(traj_rng) - jump_times[internal_params.jump_times_which_idx[]] = integrator.t - jump_which[internal_params.jump_times_which_idx[]] = collapse_idx - internal_params.jump_times_which_idx[] += 1 - if internal_params.jump_times_which_idx[] > length(jump_times) - resize!(jump_times, length(jump_times) + internal_params.jump_times_which_init_size) - resize!(jump_which, length(jump_which) + internal_params.jump_times_which_init_size) - end -end - -LindbladJumpContinuousCondition(u, t, integrator) = integrator.p.random_n[] - real(dot(u, u)) - -LindbladJumpDiscreteCondition(u, t, integrator) = real(dot(u, u)) < integrator.p.random_n[] - -function _mcsolve_prob_func(prob, i, repeat) - internal_params = prob.p - - global_rng = internal_params.global_rng - seed = internal_params.seeds[i] +function _mcsolve_prob_func(prob, i, repeat, global_rng, seeds, tlist) + seed = seeds[i] traj_rng = typeof(global_rng)() seed!(traj_rng, seed) - prm = merge( - internal_params, - ( - expvals = similar(internal_params.expvals), - cache_mc = similar(internal_params.cache_mc), - weights_mc = similar(internal_params.weights_mc), - cumsum_weights_mc = similar(internal_params.weights_mc), - traj_rng = traj_rng, - random_n = Ref(rand(traj_rng)), - progr_mc = ProgressBar(size(internal_params.expvals, 2), enable = false), - jump_times_which_idx = Ref(1), - jump_times = similar(internal_params.jump_times), - jump_which = similar(internal_params.jump_which), - ), - ) - f = deepcopy(prob.f.f) + cb = _mcsolve_initialize_callbacks(prob, tlist, traj_rng) + + return remake(prob, f = f, callback = cb) +end - return remake(prob, f = f, p = prm) +function _mcsolve_dispatch_prob_func(rng, ntraj, tlist) + seeds = map(i -> rand(rng, UInt64), 1:ntraj) + return (prob, i, repeat) -> _mcsolve_prob_func(prob, i, repeat, rng, seeds, tlist) end # Standard output function function _mcsolve_output_func(sol, i) - resize!(sol.prob.p.jump_times, sol.prob.p.jump_times_which_idx[] - 1) - resize!(sol.prob.p.jump_which, sol.prob.p.jump_times_which_idx[] - 1) + idx = _mc_get_jump_callback(sol).affect!.jump_times_which_idx[] + resize!(_mc_get_jump_callback(sol).affect!.jump_times, idx - 1) + resize!(_mc_get_jump_callback(sol).affect!.jump_which, idx - 1) return (sol, false) end # Output function with progress bar update -function _mcsolve_output_func_progress(sol, i) - next!(sol.prob.p.progr_trajectories) +function _mcsolve_output_func_progress(sol, i, progr) + next!(progr) return _mcsolve_output_func(sol, i) end # Output function with distributed channel update for progress bar -function _mcsolve_output_func_distributed(sol, i) - put!(sol.prob.p.progr_channel, true) +function _mcsolve_output_func_distributed(sol, i, channel) + put!(channel, true) return _mcsolve_output_func(sol, i) end -_mcsolve_dispatch_output_func() = _mcsolve_output_func -_mcsolve_dispatch_output_func(::ET) where {ET<:Union{EnsembleSerial,EnsembleThreads}} = _mcsolve_output_func_progress -_mcsolve_dispatch_output_func(::EnsembleDistributed) = _mcsolve_output_func_distributed - -function _normalize_state!(u, dims, normalize_states) - getVal(normalize_states) && normalize!(u) - return QuantumObject(u, dims = dims) +function _mcsolve_dispatch_output_func(::ET, progress_bar, ntraj) where {ET<:Union{EnsembleSerial,EnsembleThreads}} + if getVal(progress_bar) + progr = ProgressBar(ntraj, enable = getVal(progress_bar)) + f = (sol, i) -> _mcsolve_output_func_progress(sol, i, progr) + return (f, progr, nothing) + else + return (_mcsolve_output_func, nothing, nothing) + end end +function _mcsolve_dispatch_output_func( + ::ET, + progress_bar, + ntraj, +) where {ET<:Union{EnsembleSplitThreads,EnsembleDistributed}} + if getVal(progress_bar) + progr = ProgressBar(ntraj, enable = getVal(progress_bar)) + progr_channel::RemoteChannel{Channel{Bool}} = RemoteChannel(() -> Channel{Bool}(1)) -function _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which, normalize_states) - sol_i = sol[:, i] - dims = sol_i.prob.p.Hdims - !isempty(sol_i.prob.kwargs[:saveat]) ? states[i] = map(u -> _normalize_state!(u, dims, normalize_states), sol_i.u) : - nothing + f = (sol, i) -> _mcsolve_output_func_distributed(sol, i, progr_channel) + return (f, progr, progr_channel) + else + return (_mcsolve_output_func, nothing, nothing) + end +end - copyto!(view(expvals_all, i, :, :), sol_i.prob.p.expvals) - jump_times[i] = sol_i.prob.p.jump_times - return jump_which[i] = sol_i.prob.p.jump_which +function _normalize_state!(u, dims, normalize_states) + getVal(normalize_states) && normalize!(u) + return QuantumObject(u, type = Ket, dims = dims) end function _mcsolve_make_Heff_QobjEvo(H::QuantumObject, c_ops) @@ -145,7 +87,7 @@ end tlist::AbstractVector, c_ops::Union{Nothing,AbstractVector,Tuple} = nothing; e_ops::Union{Nothing,AbstractVector,Tuple} = nothing, - params::NamedTuple = NamedTuple(), + params = NullParameters(), rng::AbstractRNG = default_rng(), jump_callback::TJC = ContinuousLindbladJumpCallback(), kwargs..., @@ -192,7 +134,7 @@ If the environmental measurements register a quantum jump, the wave function und - `tlist`: List of times at which to save either the state or the expectation values of the system. - `c_ops`: List of collapse operators ``\{\hat{C}_n\}_n``. It can be either a `Vector` or a `Tuple`. - `e_ops`: List of operators for which to calculate expectation values. It can be either a `Vector` or a `Tuple`. -- `params`: `NamedTuple` of parameters to pass to the solver. +- `params`: Parameters to pass to the solver. This argument is usually expressed as a `NamedTuple` or `AbstractVector` of parameters. For more advanced usage, any custom struct can be used. - `rng`: Random number generator for reproducibility. - `jump_callback`: The Jump Callback type: Discrete or Continuous. The default is `ContinuousLindbladJumpCallback()`, which is more precise. - `kwargs`: The keyword arguments for the ODEProblem. @@ -206,7 +148,7 @@ If the environmental measurements register a quantum jump, the wave function und # Returns -- `prob::ODEProblem`: The ODEProblem for the Monte Carlo wave function time evolution. +- `prob`: The [`TimeEvolutionProblem`](@ref) containing the `ODEProblem` for the Monte Carlo wave function time evolution. """ function mcsolveProblem( H::Union{AbstractQuantumObject{DT1,OperatorQuantumObject},Tuple}, @@ -214,7 +156,7 @@ function mcsolveProblem( tlist::AbstractVector, c_ops::Union{Nothing,AbstractVector,Tuple} = nothing; e_ops::Union{Nothing,AbstractVector,Tuple} = nothing, - params::NamedTuple = NamedTuple(), + params = NullParameters(), rng::AbstractRNG = default_rng(), jump_callback::TJC = ContinuousLindbladJumpCallback(), kwargs..., @@ -229,93 +171,17 @@ function mcsolveProblem( H_eff_evo = _mcsolve_make_Heff_QobjEvo(H, c_ops) - if e_ops isa Nothing - expvals = Array{ComplexF64}(undef, 0, length(tlist)) - is_empty_e_ops_mc = true - e_ops_data = () - else - expvals = Array{ComplexF64}(undef, length(e_ops), length(tlist)) - e_ops_data = get_data.(e_ops) - is_empty_e_ops_mc = isempty(e_ops) - end + T = Base.promote_eltype(H_eff_evo, ψ0) - saveat = is_empty_e_ops_mc ? tlist : [tlist[end]] + is_empty_e_ops = e_ops isa Nothing ? true : isempty(e_ops) + + saveat = is_empty_e_ops ? tlist : [tlist[end]] # We disable the progress bar of the sesolveProblem because we use a global progress bar for all the trajectories default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat, progress_bar = Val(false)) kwargs2 = merge(default_values, kwargs) + kwargs3 = _generate_mcsolve_kwargs(ψ0, T, e_ops, tlist, c_ops, jump_callback, rng, kwargs2) - cache_mc = similar(ψ0.data) - weights_mc = Array{Float64}(undef, length(c_ops)) - cumsum_weights_mc = similar(weights_mc) - - jump_times_which_init_size = 200 - jump_times = Vector{Float64}(undef, jump_times_which_init_size) - jump_which = Vector{Int16}(undef, jump_times_which_init_size) - - c_ops_data = get_data.(c_ops) - c_ops_herm_data = map(op -> op' * op, c_ops_data) - - params2 = ( - expvals = expvals, - e_ops_mc = e_ops_data, - is_empty_e_ops_mc = is_empty_e_ops_mc, - progr_mc = ProgressBar(length(tlist), enable = false), - traj_rng = rng, - c_ops = c_ops_data, - c_ops_herm = c_ops_herm_data, - cache_mc = cache_mc, - weights_mc = weights_mc, - cumsum_weights_mc = cumsum_weights_mc, - jump_times = jump_times, - jump_which = jump_which, - jump_times_which_init_size = jump_times_which_init_size, - jump_times_which_idx = Ref(1), - params..., - ) - - return mcsolveProblem(H_eff_evo, ψ0, tlist, params2, jump_callback; kwargs2...) -end - -function mcsolveProblem( - H_eff_evo::QuantumObjectEvolution{DT1,OperatorQuantumObject}, - ψ0::QuantumObject{DT2,KetQuantumObject}, - tlist::AbstractVector, - params::NamedTuple, - jump_callback::DiscreteLindbladJumpCallback; - kwargs..., -) where {DT1,DT2} - cb1 = DiscreteCallback(LindbladJumpDiscreteCondition, LindbladJumpAffect!, save_positions = (false, false)) - cb2 = PresetTimeCallback(tlist, _save_func_mcsolve, save_positions = (false, false)) - kwargs2 = (; kwargs...) - kwargs2 = - haskey(kwargs2, :callback) ? merge(kwargs2, (callback = CallbackSet(cb1, cb2, kwargs2.callback),)) : - merge(kwargs2, (callback = CallbackSet(cb1, cb2),)) - - return sesolveProblem(H_eff_evo, ψ0, tlist; params = params, kwargs2...) -end - -function mcsolveProblem( - H_eff_evo::QuantumObjectEvolution{DT1,OperatorQuantumObject}, - ψ0::QuantumObject{DT2,KetQuantumObject}, - tlist::AbstractVector, - params::NamedTuple, - jump_callback::ContinuousLindbladJumpCallback; - kwargs..., -) where {DT1,DT2} - cb1 = ContinuousCallback( - LindbladJumpContinuousCondition, - LindbladJumpAffect!, - nothing, - interp_points = jump_callback.interp_points, - save_positions = (false, false), - ) - cb2 = PresetTimeCallback(tlist, _save_func_mcsolve, save_positions = (false, false)) - kwargs2 = (; kwargs...) - kwargs2 = - haskey(kwargs2, :callback) ? merge(kwargs2, (callback = CallbackSet(cb1, cb2, kwargs2.callback),)) : - merge(kwargs2, (callback = CallbackSet(cb1, cb2),)) - - return sesolveProblem(H_eff_evo, ψ0, tlist; params = params, kwargs2...) + return sesolveProblem(H_eff_evo, ψ0, tlist; params = params, kwargs3...) end @doc raw""" @@ -325,14 +191,14 @@ end tlist::AbstractVector, c_ops::Union{Nothing,AbstractVector,Tuple} = nothing; e_ops::Union{Nothing,AbstractVector,Tuple} = nothing, - params::NamedTuple = NamedTuple(), + params = NullParameters(), rng::AbstractRNG = default_rng(), ntraj::Int = 1, ensemble_method = EnsembleThreads(), jump_callback::TJC = ContinuousLindbladJumpCallback(), - prob_func::Function = _mcsolve_prob_func, - output_func::Function = _mcsolve_dispatch_output_func(ensemble_method), progress_bar::Union{Val,Bool} = Val(true), + prob_func::Union{Function, Nothing} = nothing, + output_func::Union{Tuple,Nothing} = nothing, kwargs..., ) @@ -377,14 +243,14 @@ If the environmental measurements register a quantum jump, the wave function und - `tlist`: List of times at which to save either the state or the expectation values of the system. - `c_ops`: List of collapse operators ``\{\hat{C}_n\}_n``. It can be either a `Vector` or a `Tuple`. - `e_ops`: List of operators for which to calculate expectation values. It can be either a `Vector` or a `Tuple`. -- `params`: `NamedTuple` of parameters to pass to the solver. +- `params`: Parameters to pass to the solver. This argument is usually expressed as a `NamedTuple` or `AbstractVector` of parameters. For more advanced usage, any custom struct can be used. - `rng`: Random number generator for reproducibility. - `ntraj`: Number of trajectories to use. - `ensemble_method`: Ensemble method to use. Default to `EnsembleThreads()`. - `jump_callback`: The Jump Callback type: Discrete or Continuous. The default is `ContinuousLindbladJumpCallback()`, which is more precise. -- `prob_func`: Function to use for generating the ODEProblem. -- `output_func`: Function to use for generating the output of a single trajectory. - `progress_bar`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities. +- `prob_func`: Function to use for generating the ODEProblem. +- `output_func`: a `Tuple` containing the `Function` to use for generating the output of a single trajectory, the (optional) `ProgressBar` object, and the (optional) `RemoteChannel` object. - `kwargs`: The keyword arguments for the ODEProblem. # Notes @@ -396,7 +262,7 @@ If the environmental measurements register a quantum jump, the wave function und # Returns -- `prob::EnsembleProblem with ODEProblem`: The Ensemble ODEProblem for the Monte Carlo wave function time evolution. +- `prob`: The [`TimeEvolutionProblem`](@ref) containing the Ensemble `ODEProblem` for the Monte Carlo wave function time evolution. """ function mcsolveEnsembleProblem( H::Union{AbstractQuantumObject{DT1,OperatorQuantumObject},Tuple}, @@ -404,51 +270,40 @@ function mcsolveEnsembleProblem( tlist::AbstractVector, c_ops::Union{Nothing,AbstractVector,Tuple} = nothing; e_ops::Union{Nothing,AbstractVector,Tuple} = nothing, - params::NamedTuple = NamedTuple(), + params = NullParameters(), rng::AbstractRNG = default_rng(), ntraj::Int = 1, ensemble_method = EnsembleThreads(), jump_callback::TJC = ContinuousLindbladJumpCallback(), - prob_func::Function = _mcsolve_prob_func, - output_func::Function = _mcsolve_dispatch_output_func(ensemble_method), progress_bar::Union{Val,Bool} = Val(true), + prob_func::Union{Function,Nothing} = nothing, + output_func::Union{Tuple,Nothing} = nothing, kwargs..., ) where {DT1,DT2,TJC<:LindbladJumpCallbackType} - progr = ProgressBar(ntraj, enable = getVal(progress_bar)) - if ensemble_method isa EnsembleDistributed - progr_channel::RemoteChannel{Channel{Bool}} = RemoteChannel(() -> Channel{Bool}(1)) - @async while take!(progr_channel) - next!(progr) - end - params = merge(params, (progr_channel = progr_channel,)) - else - params = merge(params, (progr_trajectories = progr,)) - end + _prob_func = prob_func isa Nothing ? _mcsolve_dispatch_prob_func(rng, ntraj, tlist) : prob_func + _output_func = + output_func isa Nothing ? _mcsolve_dispatch_output_func(ensemble_method, progress_bar, ntraj) : output_func - # Stop the async task if an error occurs - try - seeds = map(i -> rand(rng, UInt64), 1:ntraj) - prob_mc = mcsolveProblem( - H, - ψ0, - tlist, - c_ops; - e_ops = e_ops, - params = merge(params, (global_rng = rng, seeds = seeds)), - rng = rng, - jump_callback = jump_callback, - kwargs..., - ) - - ensemble_prob = EnsembleProblem(prob_mc, prob_func = prob_func, output_func = output_func, safetycopy = false) - - return ensemble_prob - catch e - if ensemble_method isa EnsembleDistributed - put!(progr_channel, false) - end - rethrow() - end + prob_mc = mcsolveProblem( + H, + ψ0, + tlist, + c_ops; + e_ops = e_ops, + params = params, + rng = rng, + jump_callback = jump_callback, + kwargs..., + ) + + ensemble_prob = TimeEvolutionProblem( + EnsembleProblem(prob_mc.prob, prob_func = _prob_func, output_func = _output_func[1], safetycopy = false), + prob_mc.times, + prob_mc.dims, + (progr = _output_func[2], channel = _output_func[3]), + ) + + return ensemble_prob end @doc raw""" @@ -459,14 +314,14 @@ end c_ops::Union{Nothing,AbstractVector,Tuple} = nothing; alg::OrdinaryDiffEqAlgorithm = Tsit5(), e_ops::Union{Nothing,AbstractVector,Tuple} = nothing, - params::NamedTuple = NamedTuple(), + params = NullParameters(), rng::AbstractRNG = default_rng(), ntraj::Int = 1, ensemble_method = EnsembleThreads(), jump_callback::TJC = ContinuousLindbladJumpCallback(), - prob_func::Function = _mcsolve_prob_func, - output_func::Function = _mcsolve_dispatch_output_func(ensemble_method), progress_bar::Union{Val,Bool} = Val(true), + prob_func::Union{Function, Nothing} = nothing, + output_func::Union{Tuple,Nothing} = nothing, normalize_states::Union{Val,Bool} = Val(true), kwargs..., ) @@ -513,14 +368,14 @@ If the environmental measurements register a quantum jump, the wave function und - `c_ops`: List of collapse operators ``\{\hat{C}_n\}_n``. It can be either a `Vector` or a `Tuple`. - `alg`: The algorithm to use for the ODE solver. Default to `Tsit5()`. - `e_ops`: List of operators for which to calculate expectation values. It can be either a `Vector` or a `Tuple`. -- `params`: `NamedTuple` of parameters to pass to the solver. +- `params`: Parameters to pass to the solver. This argument is usually expressed as a `NamedTuple` or `AbstractVector` of parameters. For more advanced usage, any custom struct can be used. - `rng`: Random number generator for reproducibility. - `ntraj`: Number of trajectories to use. - `ensemble_method`: Ensemble method to use. Default to `EnsembleThreads()`. - `jump_callback`: The Jump Callback type: Discrete or Continuous. The default is `ContinuousLindbladJumpCallback()`, which is more precise. -- `prob_func`: Function to use for generating the ODEProblem. -- `output_func`: Function to use for generating the output of a single trajectory. - `progress_bar`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities. +- `prob_func`: Function to use for generating the ODEProblem. +- `output_func`: a `Tuple` containing the `Function` to use for generating the output of a single trajectory, the (optional) `ProgressBar` object, and the (optional) `RemoteChannel` object. - `normalize_states`: Whether to normalize the states. Default to `Val(true)`. - `kwargs`: The keyword arguments for the ODEProblem. @@ -544,14 +399,14 @@ function mcsolve( c_ops::Union{Nothing,AbstractVector,Tuple} = nothing; alg::OrdinaryDiffEqAlgorithm = Tsit5(), e_ops::Union{Nothing,AbstractVector,Tuple} = nothing, - params::NamedTuple = NamedTuple(), + params = NullParameters(), rng::AbstractRNG = default_rng(), ntraj::Int = 1, ensemble_method = EnsembleThreads(), jump_callback::TJC = ContinuousLindbladJumpCallback(), - prob_func::Function = _mcsolve_prob_func, - output_func::Function = _mcsolve_dispatch_output_func(ensemble_method), progress_bar::Union{Val,Bool} = Val(true), + prob_func::Union{Function,Nothing} = nothing, + output_func::Union{Tuple,Nothing} = nothing, normalize_states::Union{Val,Bool} = Val(true), kwargs..., ) where {DT1,DT2,TJC<:LindbladJumpCallbackType} @@ -567,67 +422,79 @@ function mcsolve( ntraj = ntraj, ensemble_method = ensemble_method, jump_callback = jump_callback, + progress_bar = progress_bar, prob_func = prob_func, output_func = output_func, - progress_bar = progress_bar, kwargs..., ) - return mcsolve( - ens_prob_mc; - alg = alg, - ntraj = ntraj, - ensemble_method = ensemble_method, - normalize_states = normalize_states, - ) + return mcsolve(ens_prob_mc, alg, ntraj, ensemble_method, normalize_states) +end + +function _mcsolve_solve_ens( + ens_prob_mc::TimeEvolutionProblem, + alg::OrdinaryDiffEqAlgorithm, + ensemble_method::ET, + ntraj::Int, +) where {ET<:Union{EnsembleSplitThreads,EnsembleDistributed}} + sol = nothing + + @sync begin + @async while take!(ens_prob_mc.kwargs.channel) + next!(ens_prob_mc.kwargs.progr) + end + + @async begin + sol = solve(ens_prob_mc.prob, alg, ensemble_method, trajectories = ntraj) + put!(ens_prob_mc.kwargs.channel, false) + end + end + + return sol +end + +function _mcsolve_solve_ens( + ens_prob_mc::TimeEvolutionProblem, + alg::OrdinaryDiffEqAlgorithm, + ensemble_method, + ntraj::Int, +) + sol = solve(ens_prob_mc.prob, alg, ensemble_method, trajectories = ntraj) + return sol end function mcsolve( - ens_prob_mc::EnsembleProblem; + ens_prob_mc::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5(), ntraj::Int = 1, ensemble_method = EnsembleThreads(), - normalize_states::Union{Val,Bool} = Val(true), + normalize_states = Val(true), ) - try - sol = solve(ens_prob_mc, alg, ensemble_method, trajectories = ntraj) - - if ensemble_method isa EnsembleDistributed - put!(sol[:, 1].prob.p.progr_channel, false) - end - - _sol_1 = sol[:, 1] - - expvals_all = Array{ComplexF64}(undef, length(sol), size(_sol_1.prob.p.expvals)...) - states = - isempty(_sol_1.prob.kwargs[:saveat]) ? fill(QuantumObject[], length(sol)) : - Vector{Vector{QuantumObject}}(undef, length(sol)) - jump_times = Vector{Vector{Float64}}(undef, length(sol)) - jump_which = Vector{Vector{Int16}}(undef, length(sol)) - - foreach( - i -> _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which, normalize_states), - eachindex(sol), - ) - expvals = dropdims(sum(expvals_all, dims = 1), dims = 1) ./ length(sol) - - return TimeEvolutionMCSol( - ntraj, - _sol_1.prob.p.times, - states, - expvals, - expvals_all, - jump_times, - jump_which, - sol.converged, - _sol_1.alg, - _sol_1.prob.kwargs[:abstol], - _sol_1.prob.kwargs[:reltol], - ) - catch e - if ensemble_method isa EnsembleDistributed - put!(ens_prob_mc.prob.p.progr_channel, false) - end - rethrow() - end + sol = _mcsolve_solve_ens(ens_prob_mc, alg, ensemble_method, ntraj) + + dims = ens_prob_mc.dims + _sol_1 = sol[:, 1] + _expvals_sol_1 = _mcsolve_get_expvals(_sol_1) + + _expvals_all = _expvals_sol_1 isa Nothing ? nothing : map(i -> _mcsolve_get_expvals(sol[:, i]), eachindex(sol)) + expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all) + states = map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states), eachindex(sol)) + jump_times = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.jump_times, eachindex(sol)) + jump_which = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.jump_which, eachindex(sol)) + + expvals = _expvals_sol_1 isa Nothing ? nothing : dropdims(sum(expvals_all, dims = 3), dims = 3) ./ length(sol) + + return TimeEvolutionMCSol( + ntraj, + ens_prob_mc.times, + states, + expvals, + expvals_all, + jump_times, + jump_which, + sol.converged, + _sol_1.alg, + NamedTuple(_sol_1.prob.kwargs).abstol, + NamedTuple(_sol_1.prob.kwargs).reltol, + ) end diff --git a/src/time_evolution/mesolve.jl b/src/time_evolution/mesolve.jl index 76714c57..2d9f42ef 100644 --- a/src/time_evolution/mesolve.jl +++ b/src/time_evolution/mesolve.jl @@ -1,46 +1,5 @@ export mesolveProblem, mesolve -function _save_func_mesolve(integrator) - internal_params = integrator.p - progr = internal_params.progr - - if !internal_params.is_empty_e_ops - expvals = internal_params.expvals - e_ops = internal_params.e_ops - # This is equivalent to tr(op * ρ), when both are matrices. - # The advantage of using this convention is that I don't need - # to reshape u to make it a matrix, but I reshape the e_ops once. - - ρ = integrator.u - _expect = op -> dot(op, ρ) - @. expvals[:, progr.counter[]+1] = _expect(e_ops) - end - next!(progr) - return u_modified!(integrator, false) -end - -_generate_mesolve_e_op(op) = mat2vec(adjoint(get_data(op))) - -function _generate_mesolve_kwargs_with_callback(tlist, kwargs) - cb1 = PresetTimeCallback(tlist, _save_func_mesolve, save_positions = (false, false)) - kwargs2 = - haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(kwargs.callback, cb1),)) : - merge(kwargs, (callback = cb1,)) - - return kwargs2 -end - -function _generate_mesolve_kwargs(e_ops, progress_bar::Val{true}, tlist, kwargs) - return _generate_mesolve_kwargs_with_callback(tlist, kwargs) -end - -function _generate_mesolve_kwargs(e_ops, progress_bar::Val{false}, tlist, kwargs) - if e_ops isa Nothing - return kwargs - end - return _generate_mesolve_kwargs_with_callback(tlist, kwargs) -end - _mesolve_make_L_QobjEvo(H::QuantumObject, c_ops) = QobjEvo(liouvillian(H, c_ops); type = SuperOperator) _mesolve_make_L_QobjEvo(H::Union{QuantumObjectEvolution,Tuple}, c_ops) = liouvillian(QobjEvo(H), c_ops) @@ -51,8 +10,9 @@ _mesolve_make_L_QobjEvo(H::Union{QuantumObjectEvolution,Tuple}, c_ops) = liouvil tlist, c_ops::Union{Nothing,AbstractVector,Tuple} = nothing; e_ops::Union{Nothing,AbstractVector,Tuple} = nothing, - params::NamedTuple = NamedTuple(), + params = NullParameters(), progress_bar::Union{Val,Bool} = Val(true), + inplace::Union{Val,Bool} = Val(true), kwargs..., ) @@ -75,8 +35,9 @@ where - `tlist`: List of times at which to save either the state or the expectation values of the system. - `c_ops`: List of collapse operators ``\{\hat{C}_n\}_n``. It can be either a `Vector` or a `Tuple`. - `e_ops`: List of operators for which to calculate expectation values. It can be either a `Vector` or a `Tuple`. -- `params`: `NamedTuple` of parameters to pass to the solver. +- `params`: Parameters to pass to the solver. This argument is usually expressed as a `NamedTuple` or `AbstractVector` of parameters. For more advanced usage, any custom struct can be used. - `progress_bar`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities. +- `inplace`: Whether to use the inplace version of the ODEProblem. The default is `Val(true)`. It is recommended to use `Val(true)` for better performance, but it is sometimes necessary to use `Val(false)`, for example when performing automatic differentiation using [Zygote.jl](https://github.com/FluxML/Zygote.jl). - `kwargs`: The keyword arguments for the ODEProblem. # Notes @@ -96,8 +57,9 @@ function mesolveProblem( tlist, c_ops::Union{Nothing,AbstractVector,Tuple} = nothing; e_ops::Union{Nothing,AbstractVector,Tuple} = nothing, - params::NamedTuple = NamedTuple(), + params = NullParameters(), progress_bar::Union{Val,Bool} = Val(true), + inplace::Union{Val,Bool} = Val(true), kwargs..., ) where { DT1, @@ -113,38 +75,21 @@ function mesolveProblem( L_evo = _mesolve_make_L_QobjEvo(H, c_ops) check_dims(L_evo, ψ0) - ρ0 = sparse_to_dense(_CType(ψ0), mat2vec(ket2dm(ψ0).data)) # Convert it to dense vector with complex element type + T = Base.promote_eltype(L_evo, ψ0) + ρ0 = sparse_to_dense(_CType(T), mat2vec(ket2dm(ψ0).data)) # Convert it to dense vector with complex element type L = L_evo.data - progr = ProgressBar(length(tlist), enable = getVal(progress_bar)) - - if e_ops isa Nothing - expvals = Array{ComplexF64}(undef, 0, length(tlist)) - e_ops_data = () - is_empty_e_ops = true - else - expvals = Array{ComplexF64}(undef, length(e_ops), length(tlist)) - e_ops_data = [_generate_mesolve_e_op(op) for op in e_ops] - is_empty_e_ops = isempty(e_ops) - end - - p = ( - e_ops = e_ops_data, - expvals = expvals, - progr = progr, - times = tlist, - Hdims = L_evo.dims, - is_empty_e_ops = is_empty_e_ops, - params..., - ) + is_empty_e_ops = (e_ops isa Nothing) ? true : isempty(e_ops) saveat = is_empty_e_ops ? tlist : [tlist[end]] default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat) kwargs2 = merge(default_values, kwargs) - kwargs3 = _generate_mesolve_kwargs(e_ops, makeVal(progress_bar), tlist, kwargs2) + kwargs3 = _generate_se_me_kwargs(e_ops, makeVal(progress_bar), tlist, kwargs2, SaveFuncMESolve) tspan = (tlist[1], tlist[end]) - return ODEProblem{true,FullSpecialize}(L, ρ0, tspan, p; kwargs3...) + prob = ODEProblem{getVal(inplace),FullSpecialize}(L, ρ0, tspan, params; kwargs3...) + + return TimeEvolutionProblem(prob, tlist, L_evo.dims) end @doc raw""" @@ -155,8 +100,9 @@ end c_ops::Union{Nothing,AbstractVector,Tuple} = nothing; alg::OrdinaryDiffEqAlgorithm = Tsit5(), e_ops::Union{Nothing,AbstractVector,Tuple} = nothing, - params::NamedTuple = NamedTuple(), + params = NullParameters(), progress_bar::Union{Val,Bool} = Val(true), + inplace::Union{Val,Bool} = Val(true), kwargs..., ) @@ -180,8 +126,9 @@ where - `c_ops`: List of collapse operators ``\{\hat{C}_n\}_n``. It can be either a `Vector` or a `Tuple`. - `alg`: The algorithm for the ODE solver. The default value is `Tsit5()`. - `e_ops`: List of operators for which to calculate expectation values. It can be either a `Vector` or a `Tuple`. -- `params`: `NamedTuple` of parameters to pass to the solver. +- `params`: Parameters to pass to the solver. This argument is usually expressed as a `NamedTuple` or `AbstractVector` of parameters. For more advanced usage, any custom struct can be used. - `progress_bar`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities. +- `inplace`: Whether to use the inplace version of the ODEProblem. The default is `Val(true)`. It is recommended to use `Val(true)` for better performance, but it is sometimes necessary to use `Val(false)`, for example when performing automatic differentiation using [Zygote.jl](https://github.com/FluxML/Zygote.jl). - `kwargs`: The keyword arguments for the ODEProblem. # Notes @@ -203,8 +150,9 @@ function mesolve( c_ops::Union{Nothing,AbstractVector,Tuple} = nothing; alg::OrdinaryDiffEqAlgorithm = Tsit5(), e_ops::Union{Nothing,AbstractVector,Tuple} = nothing, - params::NamedTuple = NamedTuple(), + params = NullParameters(), progress_bar::Union{Val,Bool} = Val(true), + inplace::Union{Val,Bool} = Val(true), kwargs..., ) where { DT1, @@ -221,24 +169,25 @@ function mesolve( e_ops = e_ops, params = params, progress_bar = progress_bar, + inplace = inplace, kwargs..., ) return mesolve(prob, alg) end -function mesolve(prob::ODEProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5()) - sol = solve(prob, alg) +function mesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5()) + sol = solve(prob.prob, alg) - ρt = map(ϕ -> QuantumObject(vec2mat(ϕ), type = Operator, dims = sol.prob.p.Hdims), sol.u) + ρt = map(ϕ -> QuantumObject(vec2mat(ϕ), type = Operator, dims = prob.dims), sol.u) return TimeEvolutionSol( - sol.prob.p.times, + prob.times, ρt, - sol.prob.p.expvals, + _se_me_sse_get_expvals(sol), sol.retcode, sol.alg, - sol.prob.kwargs[:abstol], - sol.prob.kwargs[:reltol], + NamedTuple(sol.prob.kwargs).abstol, + NamedTuple(sol.prob.kwargs).reltol, ) end diff --git a/src/time_evolution/sesolve.jl b/src/time_evolution/sesolve.jl index a7218d98..0c5a3305 100644 --- a/src/time_evolution/sesolve.jl +++ b/src/time_evolution/sesolve.jl @@ -1,41 +1,5 @@ export sesolveProblem, sesolve -function _save_func_sesolve(integrator) - internal_params = integrator.p - progr = internal_params.progr - - if !internal_params.is_empty_e_ops - e_ops = internal_params.e_ops - expvals = internal_params.expvals - - ψ = integrator.u - _expect = op -> dot(ψ, op, ψ) - @. expvals[:, progr.counter[]+1] = _expect(e_ops) - end - next!(progr) - return u_modified!(integrator, false) -end - -function _generate_sesolve_kwargs_with_callback(tlist, kwargs) - cb1 = PresetTimeCallback(tlist, _save_func_sesolve, save_positions = (false, false)) - kwargs2 = - haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(kwargs.callback, cb1),)) : - merge(kwargs, (callback = cb1,)) - - return kwargs2 -end - -function _generate_sesolve_kwargs(e_ops, progress_bar::Val{true}, tlist, kwargs) - return _generate_sesolve_kwargs_with_callback(tlist, kwargs) -end - -function _generate_sesolve_kwargs(e_ops, progress_bar::Val{false}, tlist, kwargs) - if e_ops isa Nothing - return kwargs - end - return _generate_sesolve_kwargs_with_callback(tlist, kwargs) -end - _sesolve_make_U_QobjEvo(H::QuantumObjectEvolution{<:MatrixOperator}) = QobjEvo(MatrixOperator(-1im * H.data.A), dims = H.dims, type = Operator) _sesolve_make_U_QobjEvo(H) = QobjEvo(H, -1im) @@ -46,8 +10,9 @@ _sesolve_make_U_QobjEvo(H) = QobjEvo(H, -1im) ψ0::QuantumObject{DT2,KetQuantumObject}, tlist::AbstractVector; e_ops::Union{Nothing,AbstractVector,Tuple} = nothing, - params::NamedTuple = NamedTuple(), + params = NullParameters(), progress_bar::Union{Val,Bool} = Val(true), + inplace::Union{Val,Bool} = Val(true), kwargs..., ) @@ -63,8 +28,9 @@ Generate the ODEProblem for the Schrödinger time evolution of a quantum system: - `ψ0`: Initial state of the system ``|\psi(0)\rangle``. - `tlist`: List of times at which to save either the state or the expectation values of the system. - `e_ops`: List of operators for which to calculate expectation values. It can be either a `Vector` or a `Tuple`. -- `params`: `NamedTuple` of parameters to pass to the solver. +- `params`: Parameters to pass to the solver. This argument is usually expressed as a `NamedTuple` or `AbstractVector` of parameters. For more advanced usage, any custom struct can be used. - `progress_bar`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities. +- `inplace`: Whether to use the inplace version of the ODEProblem. The default is `Val(true)`. It is recommended to use `Val(true)` for better performance, but it is sometimes necessary to use `Val(false)`, for example when performing automatic differentiation using [Zygote.jl](https://github.com/FluxML/Zygote.jl). - `kwargs`: The keyword arguments for the ODEProblem. # Notes @@ -76,15 +42,16 @@ Generate the ODEProblem for the Schrödinger time evolution of a quantum system: # Returns -- `prob`: The `ODEProblem` for the Schrödinger time evolution of the system. +- `prob`: The [`TimeEvolutionProblem`](@ref) containing the `ODEProblem` for the Schrödinger time evolution of the system. """ function sesolveProblem( H::Union{AbstractQuantumObject{DT1,OperatorQuantumObject},Tuple}, ψ0::QuantumObject{DT2,KetQuantumObject}, tlist::AbstractVector; e_ops::Union{Nothing,AbstractVector,Tuple} = nothing, - params::NamedTuple = NamedTuple(), + params = NullParameters(), progress_bar::Union{Val,Bool} = Val(true), + inplace::Union{Val,Bool} = Val(true), kwargs..., ) where {DT1,DT2} haskey(kwargs, :save_idxs) && @@ -96,38 +63,21 @@ function sesolveProblem( isoper(H_evo) || throw(ArgumentError("The Hamiltonian must be an Operator.")) check_dims(H_evo, ψ0) - ψ0 = sparse_to_dense(_CType(ψ0), get_data(ψ0)) # Convert it to dense vector with complex element type + T = Base.promote_eltype(H_evo, ψ0) + ψ0 = sparse_to_dense(_CType(T), get_data(ψ0)) # Convert it to dense vector with complex element type U = H_evo.data - progr = ProgressBar(length(tlist), enable = getVal(progress_bar)) - - if e_ops isa Nothing - expvals = Array{ComplexF64}(undef, 0, length(tlist)) - e_ops_data = () - is_empty_e_ops = true - else - expvals = Array{ComplexF64}(undef, length(e_ops), length(tlist)) - e_ops_data = get_data.(e_ops) - is_empty_e_ops = isempty(e_ops) - end - - p = ( - e_ops = e_ops_data, - expvals = expvals, - progr = progr, - times = tlist, - Hdims = H_evo.dims, - is_empty_e_ops = is_empty_e_ops, - params..., - ) + is_empty_e_ops = (e_ops isa Nothing) ? true : isempty(e_ops) saveat = is_empty_e_ops ? tlist : [tlist[end]] default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat) kwargs2 = merge(default_values, kwargs) - kwargs3 = _generate_sesolve_kwargs(e_ops, makeVal(progress_bar), tlist, kwargs2) + kwargs3 = _generate_se_me_kwargs(e_ops, makeVal(progress_bar), tlist, kwargs2, SaveFuncSESolve) tspan = (tlist[1], tlist[end]) - return ODEProblem{true,FullSpecialize}(U, ψ0, tspan, p; kwargs3...) + prob = ODEProblem{getVal(inplace),FullSpecialize}(U, ψ0, tspan, params; kwargs3...) + + return TimeEvolutionProblem(prob, tlist, H_evo.dims) end @doc raw""" @@ -137,8 +87,9 @@ end tlist::AbstractVector; alg::OrdinaryDiffEqAlgorithm = Tsit5(), e_ops::Union{Nothing,AbstractVector,Tuple} = nothing, - params::NamedTuple = NamedTuple(), + params = NullParameters(), progress_bar::Union{Val,Bool} = Val(true), + inplace::Union{Val,Bool} = Val(true), kwargs..., ) @@ -155,8 +106,9 @@ Time evolution of a closed quantum system using the Schrödinger equation: - `tlist`: List of times at which to save either the state or the expectation values of the system. - `alg`: The algorithm for the ODE solver. The default is `Tsit5()`. - `e_ops`: List of operators for which to calculate expectation values. It can be either a `Vector` or a `Tuple`. -- `params`: `NamedTuple` of parameters to pass to the solver. +- `params`: Parameters to pass to the solver. This argument is usually expressed as a `NamedTuple` or `AbstractVector` of parameters. For more advanced usage, any custom struct can be used. - `progress_bar`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities. +- `inplace`: Whether to use the inplace version of the ODEProblem. The default is `Val(true)`. It is recommended to use `Val(true)` for better performance, but it is sometimes necessary to use `Val(false)`, for example when performing automatic differentiation using [Zygote.jl](https://github.com/FluxML/Zygote.jl). - `kwargs`: The keyword arguments for the ODEProblem. # Notes @@ -177,27 +129,37 @@ function sesolve( tlist::AbstractVector; alg::OrdinaryDiffEqAlgorithm = Tsit5(), e_ops::Union{Nothing,AbstractVector,Tuple} = nothing, - params::NamedTuple = NamedTuple(), + params = NullParameters(), progress_bar::Union{Val,Bool} = Val(true), + inplace::Union{Val,Bool} = Val(true), kwargs..., ) where {DT1,DT2} - prob = sesolveProblem(H, ψ0, tlist; e_ops = e_ops, params = params, progress_bar = progress_bar, kwargs...) + prob = sesolveProblem( + H, + ψ0, + tlist; + e_ops = e_ops, + params = params, + progress_bar = progress_bar, + inplace = inplace, + kwargs..., + ) return sesolve(prob, alg) end -function sesolve(prob::ODEProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5()) - sol = solve(prob, alg) +function sesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5()) + sol = solve(prob.prob, alg) - ψt = map(ϕ -> QuantumObject(ϕ, type = Ket, dims = sol.prob.p.Hdims), sol.u) + ψt = map(ϕ -> QuantumObject(ϕ, type = Ket, dims = prob.dims), sol.u) return TimeEvolutionSol( - sol.prob.p.times, + prob.times, ψt, - sol.prob.p.expvals, + _se_me_sse_get_expvals(sol), sol.retcode, sol.alg, - sol.prob.kwargs[:abstol], - sol.prob.kwargs[:reltol], + NamedTuple(sol.prob.kwargs).abstol, + NamedTuple(sol.prob.kwargs).reltol, ) end diff --git a/src/time_evolution/ssesolve.jl b/src/time_evolution/ssesolve.jl index 3d126b9f..2dc7fac1 100644 --- a/src/time_evolution/ssesolve.jl +++ b/src/time_evolution/ssesolve.jl @@ -83,15 +83,6 @@ _ssesolve_dispatch_output_func() = _ssesolve_output_func _ssesolve_dispatch_output_func(::ET) where {ET<:Union{EnsembleSerial,EnsembleThreads}} = _ssesolve_output_func_progress _ssesolve_dispatch_output_func(::EnsembleDistributed) = _ssesolve_output_func_distributed -function _ssesolve_generate_statistics!(sol, i, states, expvals_all) - sol_i = sol[:, i] - !isempty(sol_i.prob.kwargs[:saveat]) ? - states[i] = [QuantumObject(sol_i.u[i], dims = sol_i.prob.p.Hdims) for i in 1:length(sol_i.u)] : nothing - - copyto!(view(expvals_all, i, :, :), sol_i.prob.p.expvals) - return nothing -end - _ScalarOperator_e(op, f = +) = ScalarOperator(one(eltype(op)), (a, u, p, t) -> f(_ssesolve_update_coeff(u, p, t, op))) _ScalarOperator_e2_2(op, f = +) = @@ -182,16 +173,6 @@ function ssesolveProblem( progr = ProgressBar(length(tlist), enable = getVal(progress_bar)) - if e_ops isa Nothing - expvals = Array{ComplexF64}(undef, 0, length(tlist)) - e_ops_data = () - is_empty_e_ops = true - else - expvals = Array{ComplexF64}(undef, length(e_ops), length(tlist)) - e_ops_data = get_data.(e_ops) - is_empty_e_ops = isempty(e_ops) - end - sc_ops_evo_data = Tuple(map(get_data ∘ QobjEvo, sc_ops)) # Here the coefficients depend on the state, so this is a non-linear operator, which should be implemented with FunctionOperator instead. However, the nonlinearity is only on the coefficients, and it should be safe. @@ -205,21 +186,14 @@ function ssesolveProblem( D_l = map(op -> op + _ScalarOperator_e(op, -) * IdentityOperator(prod(dims)), sc_ops_evo_data) D = DiffusionOperator(D_l) - p = ( - e_ops = e_ops_data, - expvals = expvals, - progr = progr, - times = tlist, - Hdims = dims, - is_empty_e_ops = is_empty_e_ops, - n_sc_ops = length(sc_ops), - params..., - ) + p = (progr = progr, times = tlist, Hdims = dims, n_sc_ops = length(sc_ops), params...) + + is_empty_e_ops = (e_ops isa Nothing) ? true : isempty(e_ops) saveat = is_empty_e_ops ? tlist : [tlist[end]] default_values = (DEFAULT_SDE_SOLVER_OPTIONS..., saveat = saveat) kwargs2 = merge(default_values, kwargs) - kwargs3 = _generate_sesolve_kwargs(e_ops, makeVal(progress_bar), tlist, kwargs2) + kwargs3 = _generate_se_me_kwargs(e_ops, makeVal(progress_bar), tlist, kwargs2, SaveFuncSESolve) tspan = (tlist[1], tlist[end]) noise = @@ -469,14 +443,18 @@ function ssesolve( end _sol_1 = sol[:, 1] - - expvals_all = Array{ComplexF64}(undef, length(sol), size(_sol_1.prob.p.expvals)...) - states = - isempty(_sol_1.prob.kwargs[:saveat]) ? fill(QuantumObject[], length(sol)) : - Vector{Vector{QuantumObject}}(undef, length(sol)) - - foreach(i -> _ssesolve_generate_statistics!(sol, i, states, expvals_all), eachindex(sol)) - expvals = dropdims(sum(expvals_all, dims = 1), dims = 1) ./ length(sol) + _expvals_sol_1 = _se_me_sse_get_expvals(_sol_1) + + normalize_states = Val(false) + dims = _sol_1.prob.p.Hdims + _expvals_all = + _expvals_sol_1 isa Nothing ? nothing : map(i -> _se_me_sse_get_expvals(sol[:, i]), eachindex(sol)) + expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all) + states = map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states), eachindex(sol)) + + expvals = + _se_me_sse_get_expvals(_sol_1) isa Nothing ? nothing : + dropdims(sum(expvals_all, dims = 3), dims = 3) ./ length(sol) return TimeEvolutionSSESol( ntraj, diff --git a/src/time_evolution/time_evolution.jl b/src/time_evolution/time_evolution.jl index 9297cbed..b0f701f3 100644 --- a/src/time_evolution/time_evolution.jl +++ b/src/time_evolution/time_evolution.jl @@ -4,6 +4,28 @@ export liouvillian_floquet, liouvillian_generalized const DEFAULT_ODE_SOLVER_OPTIONS = (abstol = 1e-8, reltol = 1e-6, save_everystep = false, save_end = true) const DEFAULT_SDE_SOLVER_OPTIONS = (abstol = 1e-2, reltol = 1e-2, save_everystep = false, save_end = true) +const JUMP_TIMES_WHICH_INIT_SIZE = 200 + +@doc raw""" + struct TimeEvolutionProblem + +A Julia constructor for handling the `ODEProblem` of the time evolution of quantum systems. + +# Fields (Attributes) + +- `prob::AbstractSciMLProblem`: The `ODEProblem` of the time evolution. +- `times::Abstractvector`: The time list of the evolution. +- `dims::Abstractvector`: The dimensions of the Hilbert space. +- `kwargs::KWT`: Generic keyword arguments. +""" +struct TimeEvolutionProblem{PT<:AbstractSciMLProblem,TT<:AbstractVector,DT<:AbstractVector,KWT} + prob::PT + times::TT + dims::DT + kwargs::KWT +end + +TimeEvolutionProblem(prob, times, dims) = TimeEvolutionProblem(prob, times, dims, nothing) @doc raw""" struct TimeEvolutionSol @@ -14,7 +36,7 @@ A structure storing the results and some information from solving time evolution - `times::AbstractVector`: The time list of the evolution. - `states::Vector{QuantumObject}`: The list of result states. -- `expect::Matrix`: The expectation values corresponding to each time point in `times`. +- `expect::Union{AbstractMatrix,Nothing}`: The expectation values corresponding to each time point in `times`. - `retcode`: The return code from the solver. - `alg`: The algorithm which is used during the solving process. - `abstol::Real`: The absolute tolerance which is used during the solving process. @@ -23,7 +45,7 @@ A structure storing the results and some information from solving time evolution struct TimeEvolutionSol{ TT<:AbstractVector{<:Real}, TS<:AbstractVector, - TE<:Matrix, + TE<:Union{AbstractMatrix,Nothing}, RETT<:Enum, AlgT<:OrdinaryDiffEqAlgorithm, AT<:Real, @@ -43,7 +65,11 @@ function Base.show(io::IO, sol::TimeEvolutionSol) print(io, "(return code: $(sol.retcode))\n") print(io, "--------------------------\n") print(io, "num_states = $(length(sol.states))\n") - print(io, "num_expect = $(size(sol.expect, 1))\n") + if sol.expect isa Nothing + print(io, "num_expect = 0\n") + else + print(io, "num_expect = $(size(sol.expect, 1))\n") + end print(io, "ODE alg.: $(sol.alg)\n") print(io, "abstol = $(sol.abstol)\n") print(io, "reltol = $(sol.reltol)\n") @@ -60,8 +86,8 @@ A structure storing the results and some information from solving quantum trajec - `ntraj::Int`: Number of trajectories - `times::AbstractVector`: The time list of the evolution. - `states::Vector{Vector{QuantumObject}}`: The list of result states in each trajectory. -- `expect::Matrix`: The expectation values (averaging all trajectories) corresponding to each time point in `times`. -- `expect_all::Array`: The expectation values corresponding to each trajectory and each time point in `times` +- `expect::Union{AbstractMatrix,Nothing}`: The expectation values (averaging all trajectories) corresponding to each time point in `times`. +- `expect_all::Union{AbstractMatrix,Nothing}`: The expectation values corresponding to each trajectory and each time point in `times` - `jump_times::Vector{Vector{Real}}`: The time records of every quantum jump occurred in each trajectory. - `jump_which::Vector{Vector{Int}}`: The indices of the jump operators in `c_ops` that describe the corresponding quantum jumps occurred in each trajectory. - `converged::Bool`: Whether the solution is converged or not. @@ -72,8 +98,8 @@ A structure storing the results and some information from solving quantum trajec struct TimeEvolutionMCSol{ TT<:AbstractVector{<:Real}, TS<:AbstractVector, - TE<:Matrix{ComplexF64}, - TEA<:Array{ComplexF64,3}, + TE<:Union{AbstractMatrix,Nothing}, + TEA<:Union{AbstractArray,Nothing}, TJT<:Vector{<:Vector{<:Real}}, TJW<:Vector{<:Vector{<:Integer}}, AlgT<:OrdinaryDiffEqAlgorithm, @@ -99,7 +125,11 @@ function Base.show(io::IO, sol::TimeEvolutionMCSol) print(io, "--------------------------------\n") print(io, "num_trajectories = $(sol.ntraj)\n") print(io, "num_states = $(length(sol.states[1]))\n") - print(io, "num_expect = $(size(sol.expect, 1))\n") + if sol.expect isa Nothing + print(io, "num_expect = 0\n") + else + print(io, "num_expect = $(size(sol.expect, 1))\n") + end print(io, "ODE alg.: $(sol.alg)\n") print(io, "abstol = $(sol.abstol)\n") print(io, "reltol = $(sol.reltol)\n") @@ -116,8 +146,8 @@ A structure storing the results and some information from solving trajectories o - `ntraj::Int`: Number of trajectories - `times::AbstractVector`: The time list of the evolution. - `states::Vector{Vector{QuantumObject}}`: The list of result states in each trajectory. -- `expect::Matrix`: The expectation values (averaging all trajectories) corresponding to each time point in `times`. -- `expect_all::Array`: The expectation values corresponding to each trajectory and each time point in `times` +- `expect::Union{AbstractMatrix,Nothing}`: The expectation values (averaging all trajectories) corresponding to each time point in `times`. +- `expect_all::Union{AbstractArray,Nothing}`: The expectation values corresponding to each trajectory and each time point in `times` - `converged::Bool`: Whether the solution is converged or not. - `alg`: The algorithm which is used during the solving process. - `abstol::Real`: The absolute tolerance which is used during the solving process. @@ -126,8 +156,8 @@ A structure storing the results and some information from solving trajectories o struct TimeEvolutionSSESol{ TT<:AbstractVector{<:Real}, TS<:AbstractVector, - TE<:Matrix{ComplexF64}, - TEA<:Array{ComplexF64,3}, + TE<:Union{AbstractMatrix,Nothing}, + TEA<:Union{AbstractArray,Nothing}, AlgT<:StochasticDiffEqAlgorithm, AT<:Real, RT<:Real, @@ -149,7 +179,11 @@ function Base.show(io::IO, sol::TimeEvolutionSSESol) print(io, "--------------------------------\n") print(io, "num_trajectories = $(sol.ntraj)\n") print(io, "num_states = $(length(sol.states[1]))\n") - print(io, "num_expect = $(size(sol.expect, 1))\n") + if sol.expect isa Nothing + print(io, "num_expect = 0\n") + else + print(io, "num_expect = $(size(sol.expect, 1))\n") + end print(io, "SDE alg.: $(sol.alg)\n") print(io, "abstol = $(sol.abstol)\n") print(io, "reltol = $(sol.reltol)\n") diff --git a/src/time_evolution/time_evolution_dynamical.jl b/src/time_evolution/time_evolution_dynamical.jl index 270168df..650f406e 100644 --- a/src/time_evolution/time_evolution_dynamical.jl +++ b/src/time_evolution/time_evolution_dynamical.jl @@ -131,7 +131,8 @@ function _DFDIncreaseReduceAffect!(integrator) copyto!(integrator.u, mat2vec(ρt)) # By doing this, we are assuming that the system is time-independent and f is a MatrixOperator integrator.f = ODEFunction{true,FullSpecialize}(MatrixOperator(L)) - integrator.p = merge(internal_params, (e_ops = e_ops2, dfd_ρt_cache = similar(integrator.u))) + integrator.p = merge(internal_params, (dfd_ρt_cache = similar(integrator.u),)) + _mesolve_callbacks_new_e_ops!(integrator, e_ops2) return nothing end @@ -232,7 +233,7 @@ function dfd_mesolve( kwargs..., ) - sol = solve(dfd_prob, alg) + sol = solve(dfd_prob.prob, alg) ρt = map( i -> QuantumObject( @@ -244,13 +245,13 @@ function dfd_mesolve( ) return TimeEvolutionSol( - sol.prob.p.times, + dfd_prob.times, ρt, - sol.prob.p.expvals, + _se_me_sse_get_expvals(sol), sol.retcode, sol.alg, - sol.prob.kwargs[:abstol], - sol.prob.kwargs[:reltol], + NamedTuple(sol.prob.kwargs).abstol, + NamedTuple(sol.prob.kwargs).reltol, ) end @@ -282,7 +283,6 @@ function _DSF_mesolve_Affect!(integrator) H = internal_params.H_fun c_ops = internal_params.c_ops_fun e_ops = internal_params.e_ops_fun - e_ops_vec = internal_params.e_ops dsf_cache = internal_params.dsf_cache dsf_params = internal_params.dsf_params expv_cache = internal_params.expv_cache @@ -333,8 +333,7 @@ function _DSF_mesolve_Affect!(integrator) op_l2 = op_list .+ αt_list e_ops2 = e_ops(op_l2, dsf_params) - _mat2vec_data = op -> mat2vec(get_data(op)') - @. e_ops_vec = _mat2vec_data(e_ops2) + _mesolve_callbacks_new_e_ops!(integrator, [_generate_mesolve_e_op(op) for op in e_ops2]) # By doing this, we are assuming that the system is time-independent and f is a MatrixOperator copyto!(integrator.f.f.A, liouvillian(H(op_l2, dsf_params), c_ops(op_l2, dsf_params), dsf_identity).data) return u_modified!(integrator, true) @@ -373,7 +372,6 @@ function dsf_mesolveProblem( dsf_displace_cache_full = dsf_displace_cache_left + dsf_displace_cache_left_dag + dsf_displace_cache_right + dsf_displace_cache_right_dag - params2 = params params2 = merge( params, ( @@ -489,9 +487,9 @@ end # Dynamical Shifted Fock mcsolve function _DSF_mcsolve_Condition(u, t, integrator) - internal_params = integrator.p - op_list = internal_params.op_list - δα_list = internal_params.δα_list + params = integrator.p + op_list = params.op_list + δα_list = params.δα_list ψt = u @@ -508,20 +506,24 @@ function _DSF_mcsolve_Condition(u, t, integrator) end function _DSF_mcsolve_Affect!(integrator) - internal_params = integrator.p - op_list = internal_params.op_list - αt_list = internal_params.αt_list - δα_list = internal_params.δα_list - H = internal_params.H_fun - c_ops = internal_params.c_ops_fun - e_ops = internal_params.e_ops_fun - e_ops0 = internal_params.e_ops_mc - c_ops0 = internal_params.c_ops - ψt = internal_params.dsf_cache1 - dsf_cache = internal_params.dsf_cache2 - expv_cache = internal_params.expv_cache - dsf_params = internal_params.dsf_params - dsf_displace_cache_full = internal_params.dsf_displace_cache_full + params = integrator.p + op_list = params.op_list + αt_list = params.αt_list + δα_list = params.δα_list + H = params.H_fun + c_ops = params.c_ops_fun + e_ops = params.e_ops_fun + ψt = params.dsf_cache1 + dsf_cache = params.dsf_cache2 + expv_cache = params.expv_cache + dsf_params = params.dsf_params + dsf_displace_cache_full = params.dsf_displace_cache_full + + # e_ops0 = params.e_ops + # c_ops0 = params.c_ops + + e_ops0 = _mcsolve_get_e_ops(integrator) + c_ops0, c_ops0_herm = _mcsolve_get_c_ops(integrator) copyto!(ψt, integrator.u) normalize!(ψt) @@ -561,42 +563,38 @@ function _DSF_mcsolve_Affect!(integrator) op_l2 = op_list .+ αt_list e_ops2 = e_ops(op_l2, dsf_params) c_ops2 = c_ops(op_l2, dsf_params) + + ## By copying the data, we are assuming that the variables are Vectors and not Tuple @. e_ops0 = get_data(e_ops2) @. c_ops0 = get_data(c_ops2) - H_nh = lmul!(convert(eltype(ψt), 0.5im), mapreduce(op -> op' * op, +, c_ops0)) + c_ops0_herm .= map(op -> op' * op, c_ops0) + + H_nh = convert(eltype(ψt), 0.5im) * sum(c_ops0_herm) # By doing this, we are assuming that the system is time-independent and f is a MatrixOperator copyto!(integrator.f.f.A, lmul!(-1im, H(op_l2, dsf_params).data - H_nh)) return u_modified!(integrator, true) end function _dsf_mcsolve_prob_func(prob, i, repeat) - internal_params = prob.p + params = prob.p prm = merge( - internal_params, + params, ( - e_ops_mc = deepcopy(internal_params.e_ops_mc), - c_ops = deepcopy(internal_params.c_ops), - expvals = similar(internal_params.expvals), - cache_mc = similar(internal_params.cache_mc), - weights_mc = similar(internal_params.weights_mc), - cumsum_weights_mc = similar(internal_params.weights_mc), - random_n = Ref(rand()), - progr_mc = ProgressBar(size(internal_params.expvals, 2), enable = false), - jump_times_which_idx = Ref(1), - jump_times = similar(internal_params.jump_times), - jump_which = similar(internal_params.jump_which), - αt_list = copy(internal_params.αt_list), - dsf_cache1 = similar(internal_params.dsf_cache1), - dsf_cache2 = similar(internal_params.dsf_cache2), - expv_cache = copy(internal_params.expv_cache), - dsf_displace_cache_full = deepcopy(internal_params.dsf_displace_cache_full), # This brutally copies also the MatrixOperators, and it is inefficient. + αt_list = copy(params.αt_list), + dsf_cache1 = similar(params.dsf_cache1), + dsf_cache2 = similar(params.dsf_cache2), + expv_cache = copy(params.expv_cache), + dsf_displace_cache_full = deepcopy(params.dsf_displace_cache_full), # This brutally copies also the MatrixOperators, and it is inefficient. ), ) f = deepcopy(prob.f.f) - return remake(prob, f = f, p = prm) + # We need to deepcopy the callbacks because they contain the c_ops and e_ops, which are modified in the affect function. They also contain all the cache variables needed for mcsolve. + cb = deepcopy(prob.kwargs[:callback]) + + return remake(prob, f = f, p = prm, callback = cb) end function dsf_mcsolveEnsembleProblem( @@ -731,5 +729,5 @@ function dsf_mcsolve( kwargs..., ) - return mcsolve(ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method) + return mcsolve(ens_prob_mc, alg, ntraj, ensemble_method) end diff --git a/test/core-test/time_evolution.jl b/test/core-test/time_evolution.jl index c760b410..9695c20c 100644 --- a/test/core-test/time_evolution.jl +++ b/test/core-test/time_evolution.jl @@ -26,20 +26,21 @@ sol2 = sesolve(H, ψ0, tlist, progress_bar = Val(false)) sol3 = sesolve(H, ψ0, tlist, e_ops = e_ops, saveat = tlist, progress_bar = Val(false)) sol_string = sprint((t, s) -> show(t, "text/plain", s), sol) + sol_string2 = sprint((t, s) -> show(t, "text/plain", s), sol2) ## Analytical solution for the expectation value of a' * a Ω_rabi = sqrt(g^2 + ((ωc - ωq) / 2)^2) amp_rabi = g^2 / Ω_rabi^2 ## - @test prob.f.f isa MatrixOperator + @test prob.prob.f.f isa MatrixOperator @test sum(abs.(sol.expect[1, :] .- amp_rabi .* sin.(Ω_rabi * tlist) .^ 2)) / length(tlist) < 0.1 @test length(sol.times) == length(tlist) @test length(sol.states) == 1 @test size(sol.expect) == (length(e_ops), length(tlist)) @test length(sol2.times) == length(tlist) @test length(sol2.states) == length(tlist) - @test size(sol2.expect) == (0, length(tlist)) + @test sol2.expect === nothing @test length(sol3.times) == length(tlist) @test length(sol3.states) == length(tlist) @test size(sol3.expect) == (length(e_ops), length(tlist)) @@ -52,6 +53,15 @@ "ODE alg.: $(sol.alg)\n" * "abstol = $(sol.abstol)\n" * "reltol = $(sol.reltol)\n" + @test sol_string2 == + "Solution of time evolution\n" * + "(return code: $(sol2.retcode))\n" * + "--------------------------\n" * + "num_states = $(length(sol2.states))\n" * + "num_expect = 0\n" * + "ODE alg.: $(sol2.alg)\n" * + "abstol = $(sol2.abstol)\n" * + "reltol = $(sol2.reltol)\n" @testset "Memory Allocations" begin allocs_tot = @allocations sesolve(H, ψ0, tlist, e_ops = e_ops, progress_bar = Val(false)) # Warm-up @@ -116,9 +126,10 @@ sol_me_string = sprint((t, s) -> show(t, "text/plain", s), sol_me) sol_mc_string = sprint((t, s) -> show(t, "text/plain", s), sol_mc) + sol_mc_string_states = sprint((t, s) -> show(t, "text/plain", s), sol_mc_states) sol_sse_string = sprint((t, s) -> show(t, "text/plain", s), sol_sse) - @test prob_me.f.f isa MatrixOperator - @test prob_mc.f.f isa MatrixOperator + @test prob_me.prob.f.f isa MatrixOperator + @test prob_mc.prob.f.f isa MatrixOperator @test sum(abs.(sol_mc.expect .- sol_me.expect)) / length(tlist) < 0.1 @test sum(abs.(sol_mc2.expect .- sol_me.expect)) / length(tlist) < 0.1 @test sum(abs.(vec(expect_mc_states_mean) .- vec(sol_me.expect[1, :]))) / length(tlist) < 0.1 @@ -129,14 +140,14 @@ @test size(sol_me.expect) == (length(e_ops), length(tlist)) @test length(sol_me2.times) == length(tlist) @test length(sol_me2.states) == length(tlist) - @test size(sol_me2.expect) == (0, length(tlist)) + @test sol_me2.expect === nothing @test length(sol_me3.times) == length(tlist) @test length(sol_me3.states) == length(tlist) @test size(sol_me3.expect) == (length(e_ops), length(tlist)) @test length(sol_mc.times) == length(tlist) @test size(sol_mc.expect) == (length(e_ops), length(tlist)) @test length(sol_mc_states.times) == length(tlist) - @test size(sol_mc_states.expect) == (0, length(tlist)) + @test sol_mc_states.expect === nothing @test length(sol_sse.times) == length(tlist) @test size(sol_sse.expect) == (length(e_ops), length(tlist)) @test sol_me_string == @@ -158,6 +169,16 @@ "ODE alg.: $(sol_mc.alg)\n" * "abstol = $(sol_mc.abstol)\n" * "reltol = $(sol_mc.reltol)\n" + @test sol_mc_string_states == + "Solution of quantum trajectories\n" * + "(converged: $(sol_mc_states.converged))\n" * + "--------------------------------\n" * + "num_trajectories = $(sol_mc_states.ntraj)\n" * + "num_states = $(length(sol_mc_states.states[1]))\n" * + "num_expect = 0\n" * + "ODE alg.: $(sol_mc_states.alg)\n" * + "abstol = $(sol_mc_states.abstol)\n" * + "reltol = $(sol_mc_states.reltol)\n" @test sol_sse_string == "Solution of quantum trajectories\n" * "(converged: $(sol_sse.converged))\n" * @@ -462,7 +483,7 @@ @test sol_mc1.jump_times ≈ sol_mc2.jump_times atol = 1e-10 @test sol_mc1.jump_which ≈ sol_mc2.jump_which atol = 1e-10 - @test sol_mc1.expect_all ≈ sol_mc3.expect_all[1:500, :, :] atol = 1e-10 + @test sol_mc1.expect_all ≈ sol_mc3.expect_all[:, :, 1:500] atol = 1e-10 @test sol_sse1.expect ≈ sol_sse2.expect atol = 1e-10 @test sol_sse1.expect_all ≈ sol_sse2.expect_all atol = 1e-10