Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve type-stability of sesolve #191

Merged
merged 7 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions benchmarks/timeevolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ function benchmark_timeevolution!(SUITE)
tlist = range(0, 2π * 10 / g, 1000)

SUITE["Time Evolution"]["time-independent"]["sesolve"] =
@benchmarkable sesolve($H, $ψ0, $tlist, e_ops = $e_ops, progress_bar = false)
@benchmarkable sesolve($H, $ψ0, $tlist, e_ops = $e_ops, progress_bar = Val(false))

## mesolve ##

Expand All @@ -49,7 +49,7 @@ function benchmark_timeevolution!(SUITE)
$c_ops,
n_traj = 100,
e_ops = $e_ops,
progress_bar = false,
progress_bar = Val(false),
ensemble_method = EnsembleSerial(),
)
SUITE["Time Evolution"]["time-independent"]["mcsolve"]["Multithreaded"] = @benchmarkable mcsolve(
Expand All @@ -59,7 +59,7 @@ function benchmark_timeevolution!(SUITE)
$c_ops,
n_traj = 100,
e_ops = $e_ops,
progress_bar = false,
progress_bar = Val(false),
ensemble_method = EnsembleThreads(),
)

Expand Down
111 changes: 56 additions & 55 deletions src/time_evolution/sesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,35 @@ function sesolve_td_dudt!(du, u, p, t)
return mul!(du, H_t, u, -1im, 1)
end

function _generate_sesolve_kwargs_with_callback(t_l, kwargs)
cb1 = PresetTimeCallback(t_l, _save_func_sesolve, save_positions = (false, false))
kwargs2 =
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(kwargs.callback, cb1),)) :
merge(kwargs, (callback = cb1,))

return kwargs2
end

function _generate_sesolve_kwargs(e_ops, progress_bar::Val{true}, t_l, kwargs)
return _generate_sesolve_kwargs_with_callback(t_l, kwargs)
end

function _generate_sesolve_kwargs(e_ops, progress_bar::Val{false}, t_l, kwargs)
if e_ops isa Nothing
return kwargs
end
return _generate_sesolve_kwargs_with_callback(t_l, kwargs)
end

@doc raw"""
sesolveProblem(H::QuantumObject,
ψ0::QuantumObject,
tlist::AbstractVector;
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm=Tsit5()
e_ops::AbstractVector=[],
e_ops::Union{Nothing,AbstractVector} = nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
params::NamedTuple=NamedTuple(),
progress_bar::Bool=true,
progress_bar::Union{Val,Bool}=Val(true),
kwargs...)

Generates the ODEProblem for the Schrödinger time evolution of a quantum system:
Expand All @@ -46,10 +66,10 @@ Generates the ODEProblem for the Schrödinger time evolution of a quantum system
- `ψ0::QuantumObject`: The initial state of the system ``|\psi(0)\rangle``.
- `tlist::AbstractVector`: The time list of the evolution.
- `alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm`: The algorithm used for the time evolution.
- `e_ops::AbstractVector`: The list of operators to be evaluated during the evolution.
- `e_ops::Union{Nothing,AbstractVector}`: The list of operators to be evaluated during the evolution.
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: The time-dependent Hamiltonian of the system. If `nothing`, the Hamiltonian is time-independent.
- `params::NamedTuple`: The parameters of the system.
- `progress_bar::Bool`: Whether to show the progress bar.
- `progress_bar::Union{Val,Bool}`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
- `kwargs...`: The keyword arguments passed to the `ODEProblem` constructor.

# Notes
Expand All @@ -65,31 +85,39 @@ Generates the ODEProblem for the Schrödinger time evolution of a quantum system
"""
function sesolveProblem(
H::QuantumObject{MT1,OperatorQuantumObject},
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
ψ0::QuantumObject{<:AbstractVector{T2},KetQuantumObject},
tlist::AbstractVector;
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm = Tsit5(),
e_ops::Vector{QuantumObject{MT2,OperatorQuantumObject}} = QuantumObject{MT1,OperatorQuantumObject}[],
e_ops::Union{Nothing,AbstractVector} = nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
params::NamedTuple = NamedTuple(),
progress_bar::Bool = true,
progress_bar::Union{Val,Bool} = Val(true),
kwargs...,
) where {MT1<:AbstractMatrix,T2,MT2<:AbstractMatrix}
) where {MT1<:AbstractMatrix,T2}
H.dims != ψ0.dims && throw(DimensionMismatch("The two quantum objects are not of the same Hilbert dimension."))

haskey(kwargs, :save_idxs) &&
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))

is_time_dependent = !(H_t === nothing)
is_time_dependent = !(H_t isa Nothing)
progress_bar_val = makeVal(progress_bar)

t_l = collect(tlist)

ϕ0 = get_data(ψ0)
U = -1im * get_data(H)

progr = ProgressBar(length(t_l), enable = progress_bar)
expvals = Array{ComplexF64}(undef, length(e_ops), length(t_l))
e_ops2 = get_data.(e_ops)
is_empty_e_ops = isempty(e_ops)
U = -1im * get_data(H)
progr = ProgressBar(length(t_l), enable = getVal(progress_bar_val))

if e_ops isa Nothing
expvals = Array{ComplexF64}(undef, 0, length(t_l))
e_ops2 = MT1[]
is_empty_e_ops = true
else
expvals = Array{ComplexF64}(undef, length(e_ops), length(t_l))
e_ops2 = get_data.(e_ops)
is_empty_e_ops = isempty(e_ops)
end

p = (
U = U,
Expand All @@ -102,53 +130,26 @@ function sesolveProblem(
params...,
)

saveat = is_empty_e_ops ? t_l : [t_l[end]]
saveat = e_ops isa Nothing ? t_l : [t_l[end]]
default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat)
kwargs2 = merge(default_values, kwargs)
if !isempty(e_ops) || progress_bar
cb1 = PresetTimeCallback(t_l, _save_func_sesolve, save_positions = (false, false))
kwargs2 =
haskey(kwargs2, :callback) ? merge(kwargs2, (callback = CallbackSet(kwargs2.callback, cb1),)) :
merge(kwargs2, (callback = cb1,))
end
kwargs3 = _generate_sesolve_kwargs(e_ops, progress_bar_val, t_l, kwargs2)

tspan = (t_l[1], t_l[end])
return _sesolveProblem(U, ϕ0, tspan, alg, Val(is_time_dependent), p; kwargs2...)
end
dudt! = is_time_dependent ? sesolve_td_dudt! : sesolve_ti_dudt!

function _sesolveProblem(
U::AbstractMatrix{<:T1},
ϕ0::AbstractVector{<:T2},
tspan::Tuple,
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm,
is_time_dependent::Val{false},
p;
kwargs...,
) where {T1,T2}
return ODEProblem{true,SciMLBase.FullSpecialize}(sesolve_ti_dudt!, ϕ0, tspan, p; kwargs...)
end

function _sesolveProblem(
U::AbstractMatrix{<:T1},
ϕ0::AbstractVector{<:T2},
tspan::Tuple,
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm,
is_time_dependent::Val{true},
p;
kwargs...,
) where {T1,T2}
return ODEProblem{true,SciMLBase.FullSpecialize}(sesolve_td_dudt!, ϕ0, tspan, p; kwargs...)
tspan = (t_l[1], t_l[end])
return ODEProblem{true,SciMLBase.FullSpecialize}(dudt!, ϕ0, tspan, p; kwargs3...)
end

@doc raw"""
sesolve(H::QuantumObject,
ψ0::QuantumObject,
tlist::AbstractVector;
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm=Tsit5(),
e_ops::AbstractVector=[],
e_ops::Union{Nothing,AbstractVector} = nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
params::NamedTuple=NamedTuple(),
progress_bar::Bool=true,
progress_bar::Union{Val,Bool}=Val(true),
kwargs...)

Time evolution of a closed quantum system using the Schrödinger equation:
Expand All @@ -163,10 +164,10 @@ Time evolution of a closed quantum system using the Schrödinger equation:
- `ψ0::QuantumObject`: The initial state of the system ``|\psi(0)\rangle``.
- `tlist::AbstractVector`: List of times at which to save the state of the system.
- `alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm`: Algorithm to use for the time evolution.
- `e_ops::AbstractVector`: List of operators for which to calculate expectation values.
- `e_ops::Union{Nothing,AbstractVector}`: List of operators for which to calculate expectation values.
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: Time-dependent part of the Hamiltonian.
- `params::NamedTuple`: Dictionary of parameters to pass to the solver.
- `progress_bar::Bool`: Whether to show the progress bar.
- `progress_bar::Union{Val,Bool}`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
- `kwargs...`: Additional keyword arguments to pass to the solver.

# Notes
Expand All @@ -182,15 +183,15 @@ Time evolution of a closed quantum system using the Schrödinger equation:
"""
function sesolve(
H::QuantumObject{MT1,OperatorQuantumObject},
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
ψ0::QuantumObject{<:AbstractVector{T2},KetQuantumObject},
tlist::AbstractVector;
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm = Tsit5(),
e_ops::Vector{QuantumObject{MT2,OperatorQuantumObject}} = QuantumObject{MT1,OperatorQuantumObject}[],
e_ops::Union{Nothing,AbstractVector} = nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
params::NamedTuple = NamedTuple(),
progress_bar::Bool = true,
progress_bar::Union{Val,Bool} = Val(true),
kwargs...,
) where {MT1<:AbstractMatrix,T2,MT2<:AbstractMatrix}
) where {MT1<:AbstractMatrix,T2}
prob = sesolveProblem(
H,
ψ0,
Expand All @@ -199,7 +200,7 @@ function sesolve(
e_ops = e_ops,
H_t = H_t,
params = params,
progress_bar = progress_bar,
progress_bar = makeVal(progress_bar),
kwargs...,
)

Expand Down
5 changes: 5 additions & 0 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,8 @@ _get_dense_similar(A::AbstractArray, args...) = similar(A, args...)
_get_dense_similar(A::AbstractSparseMatrix, args...) = similar(nonzeros(A), args...)

_Ginibre_ensemble(n::Int, rank::Int = n) = randn(ComplexF64, n, rank) / sqrt(n)

makeVal(x::Val{T}) where {T} = x
makeVal(x) = Val(x)

getVal(x::Val{T}) where {T} = T
19 changes: 10 additions & 9 deletions test/time_evolution_and_partial_trace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
psi0 = kron(fock(N, 0), fock(2, 0))
t_l = LinRange(0, 1000, 1000)
e_ops = [a_d * a]
sol = sesolve(H, psi0, t_l, e_ops = e_ops, progress_bar = false)
sol2 = sesolve(H, psi0, t_l, progress_bar = false)
sol3 = sesolve(H, psi0, t_l, e_ops = e_ops, saveat = t_l, progress_bar = false)
sol = sesolve(H, psi0, t_l, e_ops = e_ops, progress_bar = Val(false))
sol2 = sesolve(H, psi0, t_l, progress_bar = Val(false))
sol3 = sesolve(H, psi0, t_l, e_ops = e_ops, saveat = t_l, progress_bar = Val(false))
sol_string = sprint((t, s) -> show(t, "text/plain", s), sol)
@test sum(abs.(sol.expect[1, :] .- sin.(η * t_l) .^ 2)) / length(t_l) < 0.1
@test ptrace(sol.states[end], 1) ≈ ptrace(ket2dm(sol.states[end]), 1)
Expand All @@ -36,9 +36,10 @@

@testset "Type Inference sesolve" begin
if VERSION >= v"1.10"
@inferred sesolve(H, psi0, t_l, e_ops = e_ops, progress_bar = false)
@inferred sesolve(H, psi0, t_l, progress_bar = false)
@inferred sesolve(H, psi0, t_l, e_ops = e_ops, saveat = t_l, progress_bar = false)
@inferred sesolveProblem(H, psi0, t_l)
@inferred sesolve(H, psi0, t_l, e_ops = e_ops, progress_bar = Val(false))
@inferred sesolve(H, psi0, t_l, progress_bar = Val(false))
@inferred sesolve(H, psi0, t_l, e_ops = e_ops, saveat = t_l, progress_bar = Val(false))
end
end
end
Expand All @@ -52,10 +53,10 @@
e_ops = [a_d * a]
psi0 = basis(N, 3)
t_l = LinRange(0, 100, 1000)
sol_me = mesolve(H, psi0, t_l, c_ops, e_ops = e_ops, alg = Vern7(), progress_bar = false)
sol_me = mesolve(H, psi0, t_l, c_ops, e_ops = e_ops, progress_bar = false)
sol_me2 = mesolve(H, psi0, t_l, c_ops, progress_bar = false)
sol_me3 = mesolve(H, psi0, t_l, c_ops, e_ops = e_ops, saveat = t_l, progress_bar = false)
sol_mc = mcsolve(H, psi0, t_l, c_ops, n_traj = 500, e_ops = e_ops, progress_bar = false)
sol_mc = mcsolve(H, psi0, t_l, c_ops, n_traj = 500, e_ops = e_ops, progress_bar = Val(false))
sol_me_string = sprint((t, s) -> show(t, "text/plain", s), sol_me)
sol_mc_string = sprint((t, s) -> show(t, "text/plain", s), sol_mc)
@test sum(abs.(sol_mc.expect .- sol_me.expect)) / length(t_l) < 0.1
Expand Down Expand Up @@ -121,7 +122,7 @@
psi0 = kron(psi0_1, psi0_2)
t_l = LinRange(0, 20 / γ1, 1000)
sol_me = mesolve(H, psi0, t_l, c_ops, e_ops = [sp1 * sm1, sp2 * sm2], progress_bar = false)
sol_mc = mcsolve(H, psi0, t_l, c_ops, n_traj = 500, e_ops = [sp1 * sm1, sp2 * sm2], progress_bar = false)
sol_mc = mcsolve(H, psi0, t_l, c_ops, n_traj = 500, e_ops = [sp1 * sm1, sp2 * sm2], progress_bar = Val(false))
@test sum(abs.(sol_mc.expect[1:2, :] .- sol_me.expect[1:2, :])) / length(t_l) < 0.1
@test expect(sp1 * sm1, sol_me.states[end]) ≈ expect(sigmap() * sigmam(), ptrace(sol_me.states[end], 1))
end
Expand Down