Skip to content

Commit

Permalink
fix incorrect times in time evolution solutions (#244)
Browse files Browse the repository at this point in the history
  • Loading branch information
ytdHuang authored Sep 29, 2024
1 parent 4df91fb commit 905e658
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 38 deletions.
24 changes: 13 additions & 11 deletions src/time_evolution/lr_mesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,19 +147,19 @@ function _periodicsave_func(integrator)
return u_modified!(integrator, false)
end

_save_control_lr_mesolve(u, t, integrator) = t in integrator.p.t_l
_save_control_lr_mesolve(u, t, integrator) = t in integrator.p.times

function _save_affect_lr_mesolve!(integrator)
ip = integrator.p
N, M = ip.N, ip.M
idx = select(integrator.t, ip.t_l)
idx = select(integrator.t, ip.times)

@views z = reshape(integrator.u[1:N*M], N, M)
@views B = reshape(integrator.u[N*M+1:end], M, M)
_calculate_expectation!(ip, z, B, idx)

if integrator.p.opt.progress
print("\rProgress: $(round(Int, 100*idx/length(ip.t_l)))%")
print("\rProgress: $(round(Int, 100*idx/length(ip.times)))%")
flush(stdout)
end
return u_modified!(integrator, false)
Expand Down Expand Up @@ -365,7 +365,7 @@ end
#=======================================================#

@doc raw"""
lr_mesolveProblem(H, z, B, t_l, c_ops; e_ops=(), f_ops=(), opt=LRMesolveOptions(), kwargs...) where T
lr_mesolveProblem(H, z, B, tlist, c_ops; e_ops=(), f_ops=(), opt=LRMesolveOptions(), kwargs...) where T
Formulates the ODEproblem for the low-rank time evolution of the system. The function is called by lr_mesolve.
Parameters
Expand All @@ -376,7 +376,7 @@ end
The initial z matrix.
B : AbstractMatrix{T}
The initial B matrix.
t_l : AbstractVector{T}
tlist : AbstractVector{T}
The time steps at which the expectation values and function values are calculated.
c_ops : AbstractVector{QuantumObject}
The jump operators of the system.
Expand All @@ -393,7 +393,7 @@ function lr_mesolveProblem(
H::QuantumObject{<:AbstractArray{T1},OperatorQuantumObject},
z::AbstractArray{T2,2},
B::AbstractArray{T2,2},
t_l::AbstractVector,
tlist::AbstractVector,
c_ops::AbstractVector = [];
e_ops::Tuple = (),
f_ops::Tuple = (),
Expand All @@ -407,6 +407,8 @@ function lr_mesolveProblem(
c_ops = get_data.(c_ops)
e_ops = get_data.(e_ops)

t_l = convert(Vector{_FType(H)}, tlist)

# Initialization of Arrays
expvals = Array{ComplexF64}(undef, length(e_ops), length(t_l))
funvals = Array{ComplexF64}(undef, length(f_ops), length(t_l))
Expand All @@ -421,7 +423,7 @@ function lr_mesolveProblem(
e_ops = e_ops,
f_ops = f_ops,
opt = opt,
t_l = t_l,
times = t_l,
expvals = expvals,
funvals = funvals,
Ml = Ml,
Expand Down Expand Up @@ -489,14 +491,14 @@ function lr_mesolve(
H::QuantumObject{<:AbstractArray{T1},OperatorQuantumObject},
z::AbstractArray{T2,2},
B::AbstractArray{T2,2},
t_l::AbstractVector,
tlist::AbstractVector,
c_ops::AbstractVector = [];
e_ops::Tuple = (),
f_ops::Tuple = (),
opt::LRMesolveOptions{AlgType} = LRMesolveOptions(),
kwargs...,
) where {T1,T2,AlgType<:OrdinaryDiffEqAlgorithm}
prob = lr_mesolveProblem(H, z, B, t_l, c_ops; e_ops = e_ops, f_ops = f_ops, opt = opt, kwargs...)
prob = lr_mesolveProblem(H, z, B, tlist, c_ops; e_ops = e_ops, f_ops = f_ops, opt = opt, kwargs...)
return lr_mesolve(prob; kwargs...)
end

Expand All @@ -520,7 +522,7 @@ get_B(u::AbstractArray{T}, N::Integer, M::Integer) where {T} = reshape(view(u, (
Additional keyword arguments for the ODEProblem.
"""
function lr_mesolve(prob::ODEProblem; kwargs...)
sol = solve(prob, prob.p.opt.alg, tstops = prob.p.t_l)
sol = solve(prob, prob.p.opt.alg, tstops = prob.p.times)
prob.p.opt.progress && print("\n")

N = prob.p.N
Expand All @@ -535,5 +537,5 @@ function lr_mesolve(prob::ODEProblem; kwargs...)
zt = get_z(sol.u, N, Ml)
end

return LRTimeEvolutionSol(sol.t, zt, Bt, prob.p.expvals, prob.p.funvals, prob.p.Ml)
return LRTimeEvolutionSol(sol.prob.p.times, zt, Bt, prob.p.expvals, prob.p.funvals, prob.p.Ml)
end
11 changes: 3 additions & 8 deletions src/time_evolution/mcsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,12 @@ function _mcsolve_output_func(sol, i)
return (sol, false)
end

function _mcsolve_generate_statistics(sol, i, times, states, expvals_all, jump_times, jump_which)
function _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which)
sol_i = sol[:, i]
!isempty(sol_i.prob.kwargs[:saveat]) ?
states[i] = [QuantumObject(normalize!(sol_i.u[i]), dims = sol_i.prob.p.Hdims) for i in 1:length(sol_i.u)] : nothing

copyto!(view(expvals_all, i, :, :), sol_i.prob.p.expvals)
times[i] = sol_i.t
jump_times[i] = sol_i.prob.p.jump_times
return jump_which[i] = sol_i.prob.p.jump_which
end
Expand Down Expand Up @@ -522,22 +521,18 @@ function mcsolve(
_sol_1 = sol[:, 1]

expvals_all = Array{ComplexF64}(undef, length(sol), size(_sol_1.prob.p.expvals)...)
times = Vector{Vector{Float64}}(undef, length(sol))
states =
isempty(_sol_1.prob.kwargs[:saveat]) ? fill(QuantumObject[], length(sol)) :
Vector{Vector{QuantumObject}}(undef, length(sol))
jump_times = Vector{Vector{Float64}}(undef, length(sol))
jump_which = Vector{Vector{Int16}}(undef, length(sol))

foreach(
i -> _mcsolve_generate_statistics(sol, i, times, states, expvals_all, jump_times, jump_which),
eachindex(sol),
)
foreach(i -> _mcsolve_generate_statistics(sol, i, states, expvals_all, jump_times, jump_which), eachindex(sol))
expvals = dropdims(sum(expvals_all, dims = 1), dims = 1) ./ length(sol)

return TimeEvolutionMCSol(
ntraj,
times,
_sol_1.prob.p.times,
states,
expvals,
expvals_all,
Expand Down
3 changes: 2 additions & 1 deletion src/time_evolution/mesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ function mesolveProblem(
e_ops = e_ops2,
expvals = expvals,
H_t = H_t,
times = t_l,
is_empty_e_ops = is_empty_e_ops,
params...,
)
Expand Down Expand Up @@ -253,7 +254,7 @@ function mesolve(prob::ODEProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5())
ρt = map-> QuantumObject(vec2mat(ϕ), type = Operator, dims = sol.prob.p.Hdims), sol.u)

return TimeEvolutionSol(
sol.t,
sol.prob.p.times,
ρt,
sol.prob.p.expvals,
sol.retcode,
Expand Down
3 changes: 2 additions & 1 deletion src/time_evolution/sesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ function sesolveProblem(
progr = progr,
Hdims = H.dims,
H_t = H_t,
times = t_l,
is_empty_e_ops = is_empty_e_ops,
params...,
)
Expand Down Expand Up @@ -215,7 +216,7 @@ function sesolve(prob::ODEProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5())
ψt = map-> QuantumObject(ϕ, type = Ket, dims = sol.prob.p.Hdims), sol.u)

return TimeEvolutionSol(
sol.t,
sol.prob.p.times,
ψt,
sol.prob.p.expvals,
sol.retcode,
Expand Down
3 changes: 2 additions & 1 deletion src/time_evolution/ssesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ function ssesolveProblem(
progr = progr,
Hdims = H.dims,
H_t = H_t,
times = t_l,
is_empty_e_ops = is_empty_e_ops,
params...,
)
Expand Down Expand Up @@ -404,7 +405,7 @@ function ssesolve(

return TimeEvolutionSSESol(
ntraj,
_sol_1.t,
_sol_1.prob.p.times,
states,
expvals,
expvals_all,
Expand Down
33 changes: 18 additions & 15 deletions src/time_evolution/time_evolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ A structure storing the results and some information from solving quantum trajec
# Fields (Attributes)
- `ntraj::Int`: Number of trajectories
- `times::AbstractVector`: The time list of the evolution in each trajectory.
- `times::AbstractVector`: The time list of the evolution.
- `states::Vector{Vector{QuantumObject}}`: The list of result states in each trajectory.
- `expect::Matrix`: The expectation values (averaging all trajectories) corresponding to each time point in `times`.
- `expect_all::Array`: The expectation values corresponding to each trajectory and each time point in `times`
Expand All @@ -63,7 +63,7 @@ A structure storing the results and some information from solving quantum trajec
- `reltol::Real`: The relative tolerance which is used during the solving process.
"""
struct TimeEvolutionMCSol{
TT<:Vector{<:Vector{<:Real}},
TT<:Vector{<:Real},
TS<:AbstractVector,
TE<:Matrix{ComplexF64},
TEA<:Array{ComplexF64,3},
Expand Down Expand Up @@ -97,19 +97,22 @@ function Base.show(io::IO, sol::TimeEvolutionMCSol)
end

@doc raw"""
struct TimeEvolutionSSESol
A structure storing the results and some information from solving trajectories of the Stochastic Shrodinger equation time evolution.
# Fields (Attributes)
- `ntraj::Int`: Number of trajectories
- `times::AbstractVector`: The time list of the evolution in each trajectory.
- `states::Vector{Vector{QuantumObject}}`: The list of result states in each trajectory.
- `expect::Matrix`: The expectation values (averaging all trajectories) corresponding to each time point in `times`.
- `expect_all::Array`: The expectation values corresponding to each trajectory and each time point in `times`
- `converged::Bool`: Whether the solution is converged or not.
- `alg`: The algorithm which is used during the solving process.
- `abstol::Real`: The absolute tolerance which is used during the solving process.
- `reltol::Real`: The relative tolerance which is used during the solving process.
"""
struct TimeEvolutionSSESol
A structure storing the results and some information from solving trajectories of the Stochastic Shrodinger equation time evolution.
# Fields (Attributes)
- `ntraj::Int`: Number of trajectories
- `times::AbstractVector`: The time list of the evolution.
- `states::Vector{Vector{QuantumObject}}`: The list of result states in each trajectory.
- `expect::Matrix`: The expectation values (averaging all trajectories) corresponding to each time point in `times`.
- `expect_all::Array`: The expectation values corresponding to each trajectory and each time point in `times`
- `converged::Bool`: Whether the solution is converged or not.
- `alg`: The algorithm which is used during the solving process.
- `abstol::Real`: The absolute tolerance which is used during the solving process.
- `reltol::Real`: The relative tolerance which is used during the solving process.
"""
struct TimeEvolutionSSESol{
TT<:Vector{<:Real},
TS<:AbstractVector,
Expand Down
2 changes: 1 addition & 1 deletion src/time_evolution/time_evolution_dynamical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ function dfd_mesolve(
)

return TimeEvolutionSol(
sol.t,
sol.prob.p.times,
ρt,
sol.prob.p.expvals,
sol.retcode,
Expand Down
12 changes: 12 additions & 0 deletions test/core-test/time_evolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
sol3 = sesolve(H, psi0, t_l, e_ops = e_ops, saveat = t_l, progress_bar = Val(false))
sol_string = sprint((t, s) -> show(t, "text/plain", s), sol)
@test sum(abs.(sol.expect[1, :] .- sin.(η * t_l) .^ 2)) / length(t_l) < 0.1
@test length(sol.times) == length(t_l)
@test length(sol.states) == 1
@test size(sol.expect) == (length(e_ops), length(t_l))
@test length(sol2.times) == length(t_l)
@test length(sol2.states) == length(t_l)
@test size(sol2.expect) == (0, length(t_l))
@test length(sol3.times) == length(t_l)
@test length(sol3.states) == length(t_l)
@test size(sol3.expect) == (length(e_ops), length(t_l))
@test sol_string ==
Expand Down Expand Up @@ -68,12 +71,21 @@
@test sum(abs.(sol_mc.expect .- sol_me.expect)) / length(t_l) < 0.1
@test sum(abs.(vec(expect_mc_states_mean) .- vec(sol_me.expect))) / length(t_l) < 0.1
@test sum(abs.(sol_sse.expect .- sol_me.expect)) / length(t_l) < 0.1
@test length(sol_me.times) == length(t_l)
@test length(sol_me.states) == 1
@test size(sol_me.expect) == (length(e_ops), length(t_l))
@test length(sol_me2.times) == length(t_l)
@test length(sol_me2.states) == length(t_l)
@test size(sol_me2.expect) == (0, length(t_l))
@test length(sol_me3.times) == length(t_l)
@test length(sol_me3.states) == length(t_l)
@test size(sol_me3.expect) == (length(e_ops), length(t_l))
@test length(sol_mc.times) == length(t_l)
@test size(sol_mc.expect) == (length(e_ops), length(t_l))
@test length(sol_mc_states.times) == length(t_l)
@test size(sol_mc_states.expect) == (0, length(t_l))
@test length(sol_sse.times) == length(t_l)
@test size(sol_sse.expect) == (length(e_ops), length(t_l))
@test sol_me_string ==
"Solution of time evolution\n" *
"(return code: $(sol_me.retcode))\n" *
Expand Down

0 comments on commit 905e658

Please sign in to comment.