Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add progress_bar in mcsolve, ssesolve and dsf_mcsolve #254

Merged
merged 1 commit into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
IncompleteLU = "40713840-3770-5561-ab4c-a76e7d0d7895"
Expand Down Expand Up @@ -37,23 +38,24 @@ CUDA = "5"
DiffEqBase = "6"
DiffEqCallbacks = "2 - 3.1, 3.8, 4"
DiffEqNoiseProcess = "5"
Distributed = "1"
FFTW = "1.5"
Graphs = "1.7"
IncompleteLU = "0.2"
LinearAlgebra = "<0.0.1, 1"
LinearAlgebra = "1"
LinearSolve = "2"
OrdinaryDiffEqCore = "1"
OrdinaryDiffEqTsit5 = "1"
Pkg = "<0.0.1, 1"
Random = "<0.0.1, 1"
Pkg = "1"
Random = "1"
Reexport = "1"
SciMLBase = "2"
SciMLOperators = "0.3"
SparseArrays = "<0.0.1, 1"
SparseArrays = "1"
SpecialFunctions = "2"
StaticArraysCore = "1"
StochasticDiffEq = "6"
Test = "<0.0.1, 1"
Test = "1"
julia = "1.10"

[extras]
Expand Down
1 change: 1 addition & 0 deletions src/QuantumToolbox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import DiffEqNoiseProcess: RealWienerProcess

# other dependencies (in alphabetical order)
import ArrayInterface: allowed_getindex, allowed_setindex!
import Distributed: RemoteChannel
import FFTW: fft, fftshift
import Graphs: connected_components, DiGraph
import IncompleteLU: ilu
Expand Down
2 changes: 1 addition & 1 deletion src/qobj/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ function tunneling(N::Int, m::Int = 1; sparse::Union{Bool,Val} = Val(false))
(m < 1) && throw(ArgumentError("The number of excitations (m) cannot be less than 1"))

data = ones(ComplexF64, N - m)
if getVal(makeVal(sparse))
if getVal(sparse)
return QuantumObject(spdiagm(m => data, -m => data); type = Operator, dims = N)
else
return QuantumObject(diagm(m => data, -m => data); type = Operator, dims = N)
Expand Down
4 changes: 2 additions & 2 deletions src/qobj/states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ It is also possible to specify the list of dimensions `dims` if different subsys
If you want to keep type stability, it is recommended to use `fock(N, j, dims=dims, sparse=Val(sparse))` instead of `fock(N, j, dims=dims, sparse=sparse)`. Consider also to use `dims` as a `Tuple` or `SVector` instead of `Vector`. See [this link](https://docs.julialang.org/en/v1/manual/performance-tips/#man-performance-value-type) and the [related Section](@ref doc:Type-Stability) about type stability for more details.
"""
function fock(N::Int, j::Int = 0; dims::Union{Int,AbstractVector{Int},Tuple} = N, sparse::Union{Bool,Val} = Val(false))
if getVal(makeVal(sparse))
if getVal(sparse)
array = sparsevec([j + 1], [1.0 + 0im], N)
else
array = zeros(ComplexF64, N)
Expand Down Expand Up @@ -130,7 +130,7 @@ function thermal_dm(N::Int, n::Real; sparse::Union{Bool,Val} = Val(false))
β = log(1.0 / n + 1.0)
N_list = Array{Float64}(0:N-1)
data = exp.(-β .* N_list)
if getVal(makeVal(sparse))
if getVal(sparse)
return QuantumObject(spdiagm(0 => data ./ sum(data)), Operator, N)
else
return QuantumObject(diagm(0 => data ./ sum(data)), Operator, N)
Expand Down
76 changes: 50 additions & 26 deletions src/time_evolution/mcsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ end
function _mcsolve_output_func(sol, i)
resize!(sol.prob.p.jump_times, sol.prob.p.jump_times_which_idx[] - 1)
resize!(sol.prob.p.jump_which, sol.prob.p.jump_times_which_idx[] - 1)
put!(sol.prob.p.progr_channel, true)
return (sol, false)
end

Expand Down Expand Up @@ -204,7 +205,8 @@ function mcsolveProblem(
end

saveat = e_ops isa Nothing ? t_l : [t_l[end]]
default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat)
# We disable the progress bar of the sesolveProblem because we use a global progress bar for all the trajectories
default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat, progress_bar = Val(false))
kwargs2 = merge(default_values, kwargs)

cache_mc = similar(ψ0.data)
Expand Down Expand Up @@ -396,15 +398,20 @@ end
mcsolve(H::QuantumObject{<:AbstractArray{T1},OperatorQuantumObject},
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
tlist::AbstractVector,
c_ops::Union{Nothing,AbstractVector,Tuple}=nothing;
alg::OrdinaryDiffEqAlgorithm=Tsit5(),
e_ops::Union{Nothing,AbstractVector,Tuple}=nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
params::NamedTuple=NamedTuple(),
ntraj::Int=1,
ensemble_method=EnsembleThreads(),
jump_callback::TJC=ContinuousLindbladJumpCallback(),
kwargs...)
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
params::NamedTuple = NamedTuple(),
seeds::Union{Nothing,Vector{Int}} = nothing,
ntraj::Int = 1,
ensemble_method = EnsembleThreads(),
jump_callback::TJC = ContinuousLindbladJumpCallback(),
prob_func::Function = _mcsolve_prob_func,
output_func::Function = _mcsolve_output_func,
progress_bar::Union{Val,Bool} = Val(true),
kwargs...,
)

Time evolution of an open quantum system using quantum trajectories.

Expand Down Expand Up @@ -457,6 +464,7 @@ If the environmental measurements register a quantum jump, the wave function und
- `prob_func::Function`: Function to use for generating the ODEProblem.
- `output_func::Function`: Function to use for generating the output of a single trajectory.
- `kwargs...`: Additional keyword arguments to pass to the solver.
- `progress_bar::Union{Val,Bool}`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.

# Notes

Expand Down Expand Up @@ -486,29 +494,42 @@ function mcsolve(
jump_callback::TJC = ContinuousLindbladJumpCallback(),
prob_func::Function = _mcsolve_prob_func,
output_func::Function = _mcsolve_output_func,
progress_bar::Union{Val,Bool} = Val(true),
kwargs...,
) where {MT1<:AbstractMatrix,T2,TJC<:LindbladJumpCallbackType}
if !isnothing(seeds) && length(seeds) != ntraj
throw(ArgumentError("Length of seeds must match ntraj ($ntraj), but got $(length(seeds))"))
end

ens_prob_mc = mcsolveEnsembleProblem(
H,
ψ0,
tlist,
c_ops;
alg = alg,
e_ops = e_ops,
H_t = H_t,
params = params,
seeds = seeds,
jump_callback = jump_callback,
prob_func = prob_func,
output_func = output_func,
kwargs...,
)
progr = ProgressBar(ntraj, enable = getVal(progress_bar))
progr_channel::RemoteChannel{Channel{Bool}} = RemoteChannel(() -> Channel{Bool}(1))
@async while take!(progr_channel)
next!(progr)
end

return mcsolve(ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method)
# Stop the async task if an error occurs
try
ens_prob_mc = mcsolveEnsembleProblem(
H,
ψ0,
tlist,
c_ops;
alg = alg,
e_ops = e_ops,
H_t = H_t,
params = merge(params, (progr_channel = progr_channel,)),
seeds = seeds,
jump_callback = jump_callback,
prob_func = prob_func,
output_func = output_func,
kwargs...,
)

return mcsolve(ens_prob_mc; alg = alg, ntraj = ntraj, ensemble_method = ensemble_method)
catch e
put!(progr_channel, false)
rethrow()
end
end

function mcsolve(
Expand All @@ -518,6 +539,9 @@ function mcsolve(
ensemble_method = EnsembleThreads(),
)
sol = solve(ens_prob_mc, alg, ensemble_method, trajectories = ntraj)

put!(sol[:, 1].prob.p.progr_channel, false)

_sol_1 = sol[:, 1]

expvals_all = Array{ComplexF64}(undef, length(sol), size(_sol_1.prob.p.expvals)...)
Expand Down
7 changes: 3 additions & 4 deletions src/time_evolution/mesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,13 @@ function mesolveProblem(
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))

is_time_dependent = !(H_t isa Nothing)
progress_bar_val = makeVal(progress_bar)

ρ0 = sparse_to_dense(_CType(ψ0), mat2vec(ket2dm(ψ0).data)) # Convert it to dense vector with complex element type

t_l = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl

L = liouvillian(H, c_ops).data
progr = ProgressBar(length(t_l), enable = getVal(progress_bar_val))
progr = ProgressBar(length(t_l), enable = getVal(progress_bar))

if e_ops isa Nothing
expvals = Array{ComplexF64}(undef, 0, length(t_l))
Expand Down Expand Up @@ -158,7 +157,7 @@ function mesolveProblem(
saveat = e_ops isa Nothing ? t_l : [t_l[end]]
default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat)
kwargs2 = merge(default_values, kwargs)
kwargs3 = _generate_mesolve_kwargs(e_ops, progress_bar_val, t_l, kwargs2)
kwargs3 = _generate_mesolve_kwargs(e_ops, makeVal(progress_bar), t_l, kwargs2)

dudt! = is_time_dependent ? mesolve_td_dudt! : mesolve_ti_dudt!

Expand Down Expand Up @@ -241,7 +240,7 @@ function mesolve(
e_ops = e_ops,
H_t = H_t,
params = params,
progress_bar = makeVal(progress_bar),
progress_bar = progress_bar,
kwargs...,
)

Expand Down
7 changes: 3 additions & 4 deletions src/time_evolution/sesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,13 @@ function sesolveProblem(
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))

is_time_dependent = !(H_t isa Nothing)
progress_bar_val = makeVal(progress_bar)

ϕ0 = sparse_to_dense(_CType(ψ0), get_data(ψ0)) # Convert it to dense vector with complex element type

t_l = convert(Vector{_FType(ψ0)}, tlist) # Convert it to support GPUs and avoid type instabilities for OrdinaryDiffEq.jl

U = -1im * get_data(H)
progr = ProgressBar(length(t_l), enable = getVal(progress_bar_val))
progr = ProgressBar(length(t_l), enable = getVal(progress_bar))

if e_ops isa Nothing
expvals = Array{ComplexF64}(undef, 0, length(t_l))
Expand All @@ -135,7 +134,7 @@ function sesolveProblem(
saveat = e_ops isa Nothing ? t_l : [t_l[end]]
default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat)
kwargs2 = merge(default_values, kwargs)
kwargs3 = _generate_sesolve_kwargs(e_ops, progress_bar_val, t_l, kwargs2)
kwargs3 = _generate_sesolve_kwargs(e_ops, makeVal(progress_bar), t_l, kwargs2)

dudt! = is_time_dependent ? sesolve_td_dudt! : sesolve_ti_dudt!

Expand Down Expand Up @@ -203,7 +202,7 @@ function sesolve(
e_ops = e_ops,
H_t = H_t,
params = params,
progress_bar = makeVal(progress_bar),
progress_bar = progress_bar,
kwargs...,
)

Expand Down
Loading
Loading