From 863d7ecd4c4d5009136c1095d5577471e1a21c79 Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Mon, 7 Oct 2024 22:07:23 +0200 Subject: [PATCH] Dispatch progress bar method in EnsembleProblems --- src/QuantumToolbox.jl | 2 + src/time_evolution/mcsolve.jl | 196 +++++++++++------- src/time_evolution/ssesolve.jl | 159 +++++++++----- .../time_evolution_dynamical.jl | 57 +++-- 4 files changed, 254 insertions(+), 160 deletions(-) diff --git a/src/QuantumToolbox.jl b/src/QuantumToolbox.jl index 2326a084..87e6a215 100644 --- a/src/QuantumToolbox.jl +++ b/src/QuantumToolbox.jl @@ -24,7 +24,9 @@ import SciMLBase: ODEProblem, SDEProblem, EnsembleProblem, + EnsembleSerial, EnsembleThreads, + EnsembleDistributed, FullSpecialize, CallbackSet, ContinuousCallback, diff --git a/src/time_evolution/mcsolve.jl b/src/time_evolution/mcsolve.jl index 6b63570c..4eaf30ad 100644 --- a/src/time_evolution/mcsolve.jl +++ b/src/time_evolution/mcsolve.jl @@ -80,13 +80,29 @@ function _mcsolve_prob_func(prob, i, repeat) return remake(prob, p = prm) end +# Standard output function function _mcsolve_output_func(sol, i) resize!(sol.prob.p.jump_times, sol.prob.p.jump_times_which_idx[] - 1) resize!(sol.prob.p.jump_which, sol.prob.p.jump_times_which_idx[] - 1) - put!(sol.prob.p.progr_channel, true) return (sol, false) end +# Output function with progress bar update +function _mcsolve_output_func_progress(sol, i) + next!(sol.prob.p.progr_trajectories) + return _mcsolve_output_func(sol, i) +end + +# Output function with distributed channel update for progress bar +function _mcsolve_output_func_distributed(sol, i) + put!(sol.prob.p.progr_channel, true) + return _mcsolve_output_func(sol, i) +end + +_mcsolve_dispatch_output_func() = _mcsolve_output_func +_mcsolve_dispatch_output_func(::ET) where {ET<:Union{EnsembleSerial,EnsembleThreads}} = _mcsolve_output_func_progress +_mcsolve_dispatch_output_func(::EnsembleDistributed) = _mcsolve_output_func_distributed + function _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which) sol_i = sol[:, i] !isempty(sol_i.prob.kwargs[:saveat]) ? @@ -293,9 +309,12 @@ end e_ops::Union{Nothing,AbstractVector,Tuple}=nothing, H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing, params::NamedTuple=NamedTuple(), + ntraj::Int=1, + ensemble_method=EnsembleThreads(), jump_callback::TJC=ContinuousLindbladJumpCallback(), prob_func::Function=_mcsolve_prob_func, output_func::Function=_mcsolve_output_func, + progress_bar::Union{Val,Bool}=Val(true), kwargs...) Generates the `EnsembleProblem` of `ODEProblem`s for the ensemble of trajectories of the Monte Carlo wave function time evolution of an open quantum system. @@ -343,9 +362,12 @@ If the environmental measurements register a quantum jump, the wave function und - `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: Time-dependent part of the Hamiltonian. - `params::NamedTuple`: Dictionary of parameters to pass to the solver. - `seeds::Union{Nothing, Vector{Int}}`: List of seeds for the random number generator. Length must be equal to the number of trajectories provided. +- `ntraj::Int`: Number of trajectories to use. +- `ensemble_method`: Ensemble method to use. - `jump_callback::LindbladJumpCallbackType`: The Jump Callback type: Discrete or Continuous. - `prob_func::Function`: Function to use for generating the ODEProblem. - `output_func::Function`: Function to use for generating the output of a single trajectory. +- `progress_bar::Union{Val,Bool}`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities. - `kwargs...`: Additional keyword arguments to pass to the solver. # Notes @@ -369,29 +391,51 @@ function mcsolveEnsembleProblem( e_ops::Union{Nothing,AbstractVector,Tuple} = nothing, H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing, params::NamedTuple = NamedTuple(), + ntraj::Int = 1, + ensemble_method = EnsembleThreads(), jump_callback::TJC = ContinuousLindbladJumpCallback(), seeds::Union{Nothing,Vector{Int}} = nothing, prob_func::Function = _mcsolve_prob_func, - output_func::Function = _mcsolve_output_func, + output_func::Function = _mcsolve_dispatch_output_func(ensemble_method), + progress_bar::Union{Val,Bool} = Val(true), kwargs..., ) where {MT1<:AbstractMatrix,T2,TJC<:LindbladJumpCallbackType} - prob_mc = mcsolveProblem( - H, - ψ0, - tlist, - c_ops; - alg = alg, - e_ops = e_ops, - H_t = H_t, - params = params, - seeds = seeds, - jump_callback = jump_callback, - kwargs..., - ) + progr = ProgressBar(ntraj, enable = getVal(progress_bar)) + if ensemble_method isa EnsembleDistributed + progr_channel::RemoteChannel{Channel{Bool}} = RemoteChannel(() -> Channel{Bool}(1)) + @async while take!(progr_channel) + next!(progr) + end + params = merge(params, (progr_channel = progr_channel,)) + else + params = merge(params, (progr_trajectories = progr,)) + end + + # Stop the async task if an error occurs + try + prob_mc = mcsolveProblem( + H, + ψ0, + tlist, + c_ops; + alg = alg, + e_ops = e_ops, + H_t = H_t, + params = params, + seeds = seeds, + jump_callback = jump_callback, + kwargs..., + ) - ensemble_prob = EnsembleProblem(prob_mc, prob_func = prob_func, output_func = output_func, safetycopy = false) + ensemble_prob = EnsembleProblem(prob_mc, prob_func = prob_func, output_func = output_func, safetycopy = false) - return ensemble_prob + return ensemble_prob + catch e + if ensemble_method isa EnsembleDistributed + put!(progr_channel, false) + end + rethrow() + end end @doc raw""" @@ -408,7 +452,7 @@ end ensemble_method = EnsembleThreads(), jump_callback::TJC = ContinuousLindbladJumpCallback(), prob_func::Function = _mcsolve_prob_func, - output_func::Function = _mcsolve_output_func, + output_func::Function = _mcsolve_dispatch_output_func(ensemble_method), progress_bar::Union{Val,Bool} = Val(true), kwargs..., ) @@ -493,7 +537,7 @@ function mcsolve( ensemble_method = EnsembleThreads(), jump_callback::TJC = ContinuousLindbladJumpCallback(), prob_func::Function = _mcsolve_prob_func, - output_func::Function = _mcsolve_output_func, + output_func::Function = _mcsolve_dispatch_output_func(ensemble_method), progress_bar::Union{Val,Bool} = Val(true), kwargs..., ) where {MT1<:AbstractMatrix,T2,TJC<:LindbladJumpCallbackType} @@ -501,35 +545,26 @@ function mcsolve( throw(ArgumentError("Length of seeds must match ntraj ($ntraj), but got $(length(seeds))")) end - progr = ProgressBar(ntraj, enable = getVal(progress_bar)) - progr_channel::RemoteChannel{Channel{Bool}} = RemoteChannel(() -> Channel{Bool}(1)) - @async while take!(progr_channel) - next!(progr) - end - - # Stop the async task if an error occurs - try - ens_prob_mc = mcsolveEnsembleProblem( - H, - ψ0, - tlist, - c_ops; - alg = alg, - e_ops = e_ops, - H_t = H_t, - params = merge(params, (progr_channel = progr_channel,)), - seeds = seeds, - jump_callback = jump_callback, - prob_func = prob_func, - output_func = output_func, - kwargs..., - ) + ens_prob_mc = mcsolveEnsembleProblem( + H, + ψ0, + tlist, + c_ops; + alg = alg, + e_ops = e_ops, + H_t = H_t, + params = params, + seeds = seeds, + ntraj = ntraj, + ensemble_method = ensemble_method, + jump_callback = jump_callback, + prob_func = prob_func, + output_func = output_func, + progress_bar = progress_bar, + kwargs..., + ) - return mcsolve(ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method) - catch e - put!(progr_channel, false) - rethrow() - end + return mcsolve(ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method) end function mcsolve( @@ -538,33 +573,42 @@ function mcsolve( ntraj::Int = 1, ensemble_method = EnsembleThreads(), ) - sol = solve(ens_prob_mc, alg, ensemble_method, trajectories = ntraj) - - put!(sol[:, 1].prob.p.progr_channel, false) - - _sol_1 = sol[:, 1] - - expvals_all = Array{ComplexF64}(undef, length(sol), size(_sol_1.prob.p.expvals)...) - states = - isempty(_sol_1.prob.kwargs[:saveat]) ? fill(QuantumObject[], length(sol)) : - Vector{Vector{QuantumObject}}(undef, length(sol)) - jump_times = Vector{Vector{Float64}}(undef, length(sol)) - jump_which = Vector{Vector{Int16}}(undef, length(sol)) - - foreach(i -> _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which), eachindex(sol)) - expvals = dropdims(sum(expvals_all, dims = 1), dims = 1) ./ length(sol) - - return TimeEvolutionMCSol( - ntraj, - _sol_1.prob.p.times, - states, - expvals, - expvals_all, - jump_times, - jump_which, - sol.converged, - _sol_1.alg, - _sol_1.prob.kwargs[:abstol], - _sol_1.prob.kwargs[:reltol], - ) + try + sol = solve(ens_prob_mc, alg, ensemble_method, trajectories = ntraj) + + if ensemble_method isa EnsembleDistributed + put!(sol[:, 1].prob.p.progr_channel, false) + end + + _sol_1 = sol[:, 1] + + expvals_all = Array{ComplexF64}(undef, length(sol), size(_sol_1.prob.p.expvals)...) + states = + isempty(_sol_1.prob.kwargs[:saveat]) ? fill(QuantumObject[], length(sol)) : + Vector{Vector{QuantumObject}}(undef, length(sol)) + jump_times = Vector{Vector{Float64}}(undef, length(sol)) + jump_which = Vector{Vector{Int16}}(undef, length(sol)) + + foreach(i -> _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which), eachindex(sol)) + expvals = dropdims(sum(expvals_all, dims = 1), dims = 1) ./ length(sol) + + return TimeEvolutionMCSol( + ntraj, + _sol_1.prob.p.times, + states, + expvals, + expvals_all, + jump_times, + jump_which, + sol.converged, + _sol_1.alg, + _sol_1.prob.kwargs[:abstol], + _sol_1.prob.kwargs[:reltol], + ) + catch e + if ensemble_method isa EnsembleDistributed + put!(ens_prob_mc.prob.p.progr_channel, false) + end + rethrow() + end end diff --git a/src/time_evolution/ssesolve.jl b/src/time_evolution/ssesolve.jl index bdd92ba2..fae96117 100644 --- a/src/time_evolution/ssesolve.jl +++ b/src/time_evolution/ssesolve.jl @@ -52,11 +52,25 @@ function _ssesolve_prob_func(prob, i, repeat) return remake(prob, p = prm, noise = noise, noise_rate_prototype = noise_rate_prototype) end -function _ssesolve_output_func(sol, i) +# Standard output function +_ssesolve_output_func(sol, i) = (sol, false) + +# Output function with progress bar update +function _ssesolve_output_func_progress(sol, i) + next!(sol.prob.p.progr) + return _ssesolve_output_func(sol, i) +end + +# Output function with distributed channel update for progress bar +function _ssesolve_output_func_distributed(sol, i) put!(sol.prob.p.progr_channel, true) - return (sol, false) + return _ssesolve_output_func(sol, i) end +_ssesolve_dispatch_output_func() = _ssesolve_output_func +_ssesolve_dispatch_output_func(::ET) where {ET<:Union{EnsembleSerial,EnsembleThreads}} = _ssesolve_output_func_progress +_ssesolve_dispatch_output_func(::EnsembleDistributed) = _ssesolve_output_func_distributed + function _ssesolve_generate_statistics!(sol, i, states, expvals_all) sol_i = sol[:, i] !isempty(sol_i.prob.kwargs[:saveat]) ? @@ -209,8 +223,11 @@ end e_ops::Union{Nothing,AbstractVector,Tuple} = nothing, H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing, params::NamedTuple=NamedTuple(), + ntraj::Int=1, + ensemble_method=EnsembleThreads(), prob_func::Function=_mcsolve_prob_func, - output_func::Function=_mcsolve_output_func, + output_func::Function=_ssesolve_dispatch_output_func(ensemble_method), + progress_bar::Union{Val,Bool}=Val(true), kwargs...) Generates the SDE EnsembleProblem for the Stochastic Schrödinger time evolution of a quantum system. This is defined by the following stochastic differential equation: @@ -244,8 +261,11 @@ Above, `C_n` is the `n`-th collapse operator and `dW_j(t)` is the real Wiener i - `e_ops::Union{Nothing,AbstractVector,Tuple}=nothing`: The list of operators to be evaluated during the evolution. - `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: The time-dependent Hamiltonian of the system. If `nothing`, the Hamiltonian is time-independent. - `params::NamedTuple`: The parameters of the system. +- `ntraj::Int`: Number of trajectories to use. +- `ensemble_method`: Ensemble method to use. - `prob_func::Function`: Function to use for generating the SDEProblem. - `output_func::Function`: Function to use for generating the output of a single trajectory. +- `progress_bar::Union{Val,Bool}`: Whether to show a progress bar. - `kwargs...`: The keyword arguments passed to the `SDEProblem` constructor. # Notes @@ -269,15 +289,38 @@ function ssesolveEnsembleProblem( e_ops::Union{Nothing,AbstractVector,Tuple} = nothing, H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing, params::NamedTuple = NamedTuple(), + ntraj::Int = 1, + ensemble_method = EnsembleThreads(), prob_func::Function = _ssesolve_prob_func, - output_func::Function = _ssesolve_output_func, + output_func::Function = _ssesolve_dispatch_output_func(ensemble_method), + progress_bar::Union{Val,Bool} = Val(true), kwargs..., ) where {MT1<:AbstractMatrix,T2} - prob_sse = ssesolveProblem(H, ψ0, tlist, sc_ops; alg = alg, e_ops = e_ops, H_t = H_t, params = params, kwargs...) + progr = ProgressBar(ntraj, enable = getVal(progress_bar)) + if ensemble_method isa EnsembleDistributed + progr_channel::RemoteChannel{Channel{Bool}} = RemoteChannel(() -> Channel{Bool}(1)) + @async while take!(progr_channel) + next!(progr) + end + params = merge(params, (progr_channel = progr_channel,)) + else + params = merge(params, (progr_trajectories = progr,)) + end - ensemble_prob = EnsembleProblem(prob_sse, prob_func = prob_func, output_func = output_func, safetycopy = false) + # Stop the async task if an error occurs + try + prob_sse = + ssesolveProblem(H, ψ0, tlist, sc_ops; alg = alg, e_ops = e_ops, H_t = H_t, params = params, kwargs...) - return ensemble_prob + ensemble_prob = EnsembleProblem(prob_sse, prob_func = prob_func, output_func = output_func, safetycopy = false) + + return ensemble_prob + catch e + if ensemble_method isa EnsembleDistributed + put!(progr_channel, false) + end + rethrow(e) + end end @doc raw""" @@ -291,8 +334,8 @@ end params::NamedTuple=NamedTuple(), ntraj::Int=1, ensemble_method=EnsembleThreads(), - prob_func::Function=_mcsolve_prob_func, - output_func::Function=_mcsolve_output_func, + prob_func::Function=_ssesolve_prob_func, + output_func::Function=_ssesolve_dispatch_output_func(ensemble_method), progress_bar::Union{Val,Bool} = Val(true), kwargs...) @@ -363,7 +406,7 @@ function ssesolve( ntraj::Int = 1, ensemble_method = EnsembleThreads(), prob_func::Function = _ssesolve_prob_func, - output_func::Function = _ssesolve_output_func, + output_func::Function = _ssesolve_dispatch_output_func(ensemble_method), progress_bar::Union{Val,Bool} = Val(true), kwargs..., ) where {MT1<:AbstractMatrix,T2} @@ -373,26 +416,24 @@ function ssesolve( next!(progr) end - try - ens_prob = ssesolveEnsembleProblem( - H, - ψ0, - tlist, - sc_ops; - alg = alg, - e_ops = e_ops, - H_t = H_t, - params = merge(params, (progr_channel = progr_channel,)), - prob_func = prob_func, - output_func = output_func, - kwargs..., - ) + ens_prob = ssesolveEnsembleProblem( + H, + ψ0, + tlist, + sc_ops; + alg = alg, + e_ops = e_ops, + H_t = H_t, + params = params, + ntraj = ntraj, + ensemble_method = ensemble_method, + prob_func = prob_func, + output_func = output_func, + progress_bar = progress_bar, + kwargs..., + ) - return ssesolve(ens_prob; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method) - catch e - put!(progr_channel, false) - rethrow() - end + return ssesolve(ens_prob; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method) end function ssesolve( @@ -401,29 +442,39 @@ function ssesolve( ntraj::Int = 1, ensemble_method = EnsembleThreads(), ) - sol = solve(ens_prob, alg, ensemble_method, trajectories = ntraj) - - put!(sol[:, 1].prob.p.progr_channel, false) - - _sol_1 = sol[:, 1] - - expvals_all = Array{ComplexF64}(undef, length(sol), size(_sol_1.prob.p.expvals)...) - states = - isempty(_sol_1.prob.kwargs[:saveat]) ? fill(QuantumObject[], length(sol)) : - Vector{Vector{QuantumObject}}(undef, length(sol)) - - foreach(i -> _ssesolve_generate_statistics!(sol, i, states, expvals_all), eachindex(sol)) - expvals = dropdims(sum(expvals_all, dims = 1), dims = 1) ./ length(sol) - - return TimeEvolutionSSESol( - ntraj, - _sol_1.prob.p.times, - states, - expvals, - expvals_all, - sol.converged, - _sol_1.alg, - _sol_1.prob.kwargs[:abstol], - _sol_1.prob.kwargs[:reltol], - ) + # Stop the async task if an error occurs + try + sol = solve(ens_prob, alg, ensemble_method, trajectories = ntraj) + + if ensemble_method isa EnsembleDistributed + put!(sol[:, 1].prob.p.progr_channel, false) + end + + _sol_1 = sol[:, 1] + + expvals_all = Array{ComplexF64}(undef, length(sol), size(_sol_1.prob.p.expvals)...) + states = + isempty(_sol_1.prob.kwargs[:saveat]) ? fill(QuantumObject[], length(sol)) : + Vector{Vector{QuantumObject}}(undef, length(sol)) + + foreach(i -> _ssesolve_generate_statistics!(sol, i, states, expvals_all), eachindex(sol)) + expvals = dropdims(sum(expvals_all, dims = 1), dims = 1) ./ length(sol) + + return TimeEvolutionSSESol( + ntraj, + _sol_1.prob.p.times, + states, + expvals, + expvals_all, + sol.converged, + _sol_1.alg, + _sol_1.prob.kwargs[:abstol], + _sol_1.prob.kwargs[:reltol], + ) + catch e + if ensemble_method isa EnsembleDistributed + put!(ens_prob.prob.p.progr_channel, false) + end + rethrow(e) + end end diff --git a/src/time_evolution/time_evolution_dynamical.jl b/src/time_evolution/time_evolution_dynamical.jl index 496f2513..4ffe3187 100644 --- a/src/time_evolution/time_evolution_dynamical.jl +++ b/src/time_evolution/time_evolution_dynamical.jl @@ -620,9 +620,12 @@ function dsf_mcsolveEnsembleProblem( e_ops::Function = (op_list, p) -> Vector{TOl}([]), H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing, params::NamedTuple = NamedTuple(), + ntraj::Int = 1, + ensemble_method = EnsembleThreads(), δα_list::Vector{<:Real} = fill(0.2, length(op_list)), jump_callback::TJC = ContinuousLindbladJumpCallback(), krylov_dim::Int = min(5, cld(length(ψ0.data), 3)), + progress_bar::Union{Bool,Val} = Val(true), kwargs..., ) where {T,TOl,TJC<:LindbladJumpCallbackType} op_l = op_list @@ -669,8 +672,11 @@ function dsf_mcsolveEnsembleProblem( alg = alg, H_t = H_t, params = params2, + ntraj = ntraj, + ensemble_method = ensemble_method, jump_callback = jump_callback, prob_func = _dsf_mcsolve_prob_func, + progress_bar = progress_bar, kwargs2..., ) end @@ -720,35 +726,26 @@ function dsf_mcsolve( progress_bar::Union{Bool,Val} = Val(true), kwargs..., ) where {T,TOl,TJC<:LindbladJumpCallbackType} - progr = ProgressBar(ntraj, enable = getVal(progress_bar)) - progr_channel::RemoteChannel{Channel{Bool}} = RemoteChannel(() -> Channel{Bool}(1)) - @async while take!(progr_channel) - next!(progr) - end + ens_prob_mc = dsf_mcsolveEnsembleProblem( + H, + ψ0, + t_l, + c_ops, + op_list, + α0_l, + dsf_params; + alg = alg, + e_ops = e_ops, + H_t = H_t, + params = params, + ntraj = ntraj, + ensemble_method = ensemble_method, + δα_list = δα_list, + jump_callback = jump_callback, + krylov_dim = krylov_dim, + progress_bar = progress_bar, + kwargs..., + ) - # Stop the async task if an error occurs - try - ens_prob_mc = dsf_mcsolveEnsembleProblem( - H, - ψ0, - t_l, - c_ops, - op_list, - α0_l, - dsf_params; - alg = alg, - e_ops = e_ops, - H_t = H_t, - params = merge(params, (progr_channel = progr_channel,)), - δα_list = δα_list, - jump_callback = jump_callback, - krylov_dim = krylov_dim, - kwargs..., - ) - - return mcsolve(ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method) - catch e - put!(progr_channel, false) - rethrow() - end + return mcsolve(ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method) end