Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Broadcasting for semiclassical objects #404

Merged
merged 13 commits into from
Aug 11, 2024
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ RecursiveArrayTools = "2, 3"
Reexport = "0.2, 1.0"
StochasticDiffEq = "6"
WignerSymbols = "1, 2"
julia = "1.3"
julia = "1.10"

[extras]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
1 change: 1 addition & 0 deletions src/QuantumOptics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module QuantumOptics
using Reexport
@reexport using QuantumOpticsBase
using SparseArrays, LinearAlgebra
import RecursiveArrayTools

export
ylm,
Expand Down
112 changes: 97 additions & 15 deletions src/semiclassical.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
module semiclassical

using QuantumOpticsBase
import Base: ==
import Base: ==, isapprox, +, -, *, /
import ..timeevolution: integrate, recast!, jump, integrate_mcwf, jump_callback,
JumpRNGState, threshold, roll!, as_vector, QO_CHECKS
import LinearAlgebra: normalize, normalize!
import RecursiveArrayTools

using Random, LinearAlgebra
import OrdinaryDiffEq
Expand All @@ -31,26 +32,107 @@
new{B,T,C}(quantum, classical)
end
end

Base.length(state::State) = length(state.quantum) + length(state.classical)
Base.copy(state::State) = State(copy(state.quantum), copy(state.classical))
Base.eltype(state::State) = promote_type(eltype(state.quantum),eltype(state.classical))
normalize!(state::State) = (normalize!(state.quantum); state)
normalize(state::State) = State(normalize(state.quantum),copy(state.classical))

function ==(a::State, b::State)
QuantumOpticsBase.samebases(a.quantum, b.quantum) &&
length(a.classical)==length(b.classical) &&
(a.classical==b.classical) &&
(a.quantum==b.quantum)
end
State{B}(q::T, c::C) where {B,T<:QuantumState{B},C} = State(q,c)

# Standard interfaces
Base.zero(x::State) = State(zero(x.quantum), zero(x.classical))
Base.length(x::State) = length(x.quantum) + length(x.classical)
Base.axes(x::State) = (Base.OneTo(length(x)),)
Base.size(x::State) = size(x.quantum)
Base.ndims(x::State) = ndims(x.quantum)

Check warning on line 42 in src/semiclassical.jl

View check run for this annotation

Codecov / codecov/patch

src/semiclassical.jl#L41-L42

Added lines #L41 - L42 were not covered by tests
Base.ndims(x::Type{<:State{B,T,C}}) where {B,T<:QuantumState{B},C} = ndims(T)
Base.copy(x::State) = State(copy(x.quantum), copy(x.classical))
Base.copyto!(x::State, y::State) = (copyto!(x.quantum, y.quantum); copyto!(x.classical, y.classical); x)
Base.fill!(x::State, a) = (fill!(x.quantum, a), fill!(x.classical, a))

Check warning on line 46 in src/semiclassical.jl

View check run for this annotation

Codecov / codecov/patch

src/semiclassical.jl#L46

Added line #L46 was not covered by tests
Base.eltype(x::State) = promote_type(eltype(x.quantum),eltype(x.classical))
Base.eltype(x::Type{<:State{B,T,C}}) where {B,T<:QuantumState{B},C} = promote_type(eltype(T), eltype(C))
Base.similar(x::State, ::Type{T} = eltype(x)) where {T} = State(similar(x.quantum, T), similar(x.classical, T))

Check warning on line 49 in src/semiclassical.jl

View check run for this annotation

Codecov / codecov/patch

src/semiclassical.jl#L49

Added line #L49 was not covered by tests
Base.getindex(x::State, idx) = idx <= length(x.quantum) ? getindex(x.quantum, idx) : getindex(x.classical, idx-length(x.quantum))
Base.setindex!(x::State, v, idx) = idx <= length(x.quantum) ? setindex(x.quantum, v, idx) : setindex(x.classical, v, idx-length(x.quantum))

Check warning on line 51 in src/semiclassical.jl

View check run for this annotation

Codecov / codecov/patch

src/semiclassical.jl#L51

Added line #L51 was not covered by tests

normalize!(x::State) = (normalize!(x.quantum); x)

Check warning on line 53 in src/semiclassical.jl

View check run for this annotation

Codecov / codecov/patch

src/semiclassical.jl#L53

Added line #L53 was not covered by tests
normalize(x::State) = State(normalize(x.quantum),copy(x.classical))
LinearAlgebra.norm(x::State) = LinearAlgebra.norm(x.quantum)
LinearAlgebra.norm(x::State, p::Int64) = LinearAlgebra.norm(x.quantum, p)

Check warning on line 56 in src/semiclassical.jl

View check run for this annotation

Codecov / codecov/patch

src/semiclassical.jl#L56

Added line #L56 was not covered by tests

==(x::State{B}, y::State{B}) where {B} = (x.classical==y.classical) && (x.quantum==y.quantum)
==(x::State, y::State) = false

Check warning on line 59 in src/semiclassical.jl

View check run for this annotation

Codecov / codecov/patch

src/semiclassical.jl#L59

Added line #L59 was not covered by tests

isapprox(x::State{B}, y::State{B}; kwargs...) where {B} = isapprox(x.quantum,y.quantum) && isapprox(x.classical,y.classical)
apkille marked this conversation as resolved.
Show resolved Hide resolved
isapprox(x::State, y::State; kwargs...) = false

Check warning on line 62 in src/semiclassical.jl

View check run for this annotation

Codecov / codecov/patch

src/semiclassical.jl#L62

Added line #L62 was not covered by tests

QuantumOpticsBase.expect(op, state::State) = expect(op, state.quantum)
QuantumOpticsBase.variance(op, state::State) = variance(op, state.quantum)
QuantumOpticsBase.ptrace(state::State, indices) = State(ptrace(state.quantum, indices), state.classical)

QuantumOpticsBase.dm(x::State) = State(dm(x.quantum), x.classical)

Base.broadcastable(x::State) = x

# Custom broadcasting style
struct StateStyle{B} <: Broadcast.BroadcastStyle end

# Style precedence rules
Broadcast.BroadcastStyle(::Type{<:State{B}}) where {B} = StateStyle{B}()
Broadcast.BroadcastStyle(::StateStyle{B1}, ::StateStyle{B2}) where {B1,B2} = throw(IncompatibleBases())

Check warning on line 76 in src/semiclassical.jl

View check run for this annotation

Codecov / codecov/patch

src/semiclassical.jl#L76

Added line #L76 was not covered by tests
Broadcast.BroadcastStyle(::StateStyle{B}, ::Broadcast.DefaultArrayStyle{0}) where {B} = StateStyle{B}()
Broadcast.BroadcastStyle(::Broadcast.DefaultArrayStyle{0}, ::StateStyle{B}) where {B} = StateStyle{B}()

# Out-of-place broadcasting
@inline function Base.copy(bc::Broadcast.Broadcasted{<:StateStyle{B},Axes,F,Args}) where {B,Axes,F,Args<:Tuple}
bcf = Broadcast.flatten(bc)
# extract quantum object from broadcast container
qobj = find_quantum(bcf)
data_q = zeros(eltype(qobj), size(qobj)...)
Nq = length(qobj)
# allocate quantum data from broadcast container
@inbounds @simd for I in 1:Nq
data_q[I] = bcf[I]
end

Check warning on line 90 in src/semiclassical.jl

View check run for this annotation

Codecov / codecov/patch

src/semiclassical.jl#L90

Added line #L90 was not covered by tests
# extract classical object from broadcast container
cobj = find_classical(bcf)
data_c = zeros(eltype(cobj), length(cobj))
Nc = length(cobj)
# allocate classical data from broadcast container
@inbounds @simd for I in 1:Nc
data_c[I] = bcf[I+Nq]
end

Check warning on line 98 in src/semiclassical.jl

View check run for this annotation

Codecov / codecov/patch

src/semiclassical.jl#L98

Added line #L98 was not covered by tests
type = eval(nameof(typeof(qobj)))
return State{B}(type(basis(qobj), data_q), data_c)
end

for f ∈ [:find_quantum, :find_classical]
@eval ($f)(bc::Broadcast.Broadcasted) = ($f)(bc.args)
@eval ($f)(args::Tuple) = ($f)(($f)(args[1]), Base.tail(args))
@eval ($f)(x) = x
@eval ($f)(::Any, rest) = ($f)(rest)

Check warning on line 107 in src/semiclassical.jl

View check run for this annotation

Codecov / codecov/patch

src/semiclassical.jl#L107

Added line #L107 was not covered by tests
end
find_quantum(x::State, rest) = x.quantum
find_classical(x::State, rest) = x.classical

# In-place broadcasting
@inline function Base.copyto!(dest::State{B}, bc::Broadcast.Broadcasted{<:StateStyle{B},Axes,F,Args}) where {B,Axes,F,Args<:Tuple}
axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc))
bc′ = Base.Broadcast.preprocess(dest, bc)
# write broadcasted quantum data to dest
qobj = dest.quantum
@inbounds @simd for I in 1:length(qobj)
qobj.data[I] = bc′[I]
end

Check warning on line 120 in src/semiclassical.jl

View check run for this annotation

Codecov / codecov/patch

src/semiclassical.jl#L120

Added line #L120 was not covered by tests
# write broadcasted classical data to dest
cobj = dest.classical
@inbounds @simd for I in 1:length(cobj)
cobj[I] = bc′[I+length(qobj)]
end

Check warning on line 125 in src/semiclassical.jl

View check run for this annotation

Codecov / codecov/patch

src/semiclassical.jl#L125

Added line #L125 was not covered by tests
return dest
end
@inline Base.copyto!(dest::State{B1}, bc::Broadcast.Broadcasted{<:StateStyle{B2},Axes,F,Args}) where {B1,B2,Axes,F,Args<:Tuple} =

Check warning on line 128 in src/semiclassical.jl

View check run for this annotation

Codecov / codecov/patch

src/semiclassical.jl#L128

Added line #L128 was not covered by tests
throw(IncompatibleBases())

Base.@propagate_inbounds Base.Broadcast._broadcast_getindex(x::State, i) = Base.getindex(x, i)
RecursiveArrayTools.recursive_unitless_bottom_eltype(x::State) = eltype(x)
RecursiveArrayTools.recursivecopy!(dest::State, src::State) = copyto!(dest, src)
RecursiveArrayTools.recursivecopy(x::State) = copy(x)
RecursiveArrayTools.recursivefill!(x::State, a) = fill!(x, a)

Check warning on line 135 in src/semiclassical.jl

View check run for this annotation

Codecov / codecov/patch

src/semiclassical.jl#L135

Added line #L135 was not covered by tests

"""
semiclassical.schroedinger_dynamic(tspan, state0, fquantum, fclassical[; fout, ...])
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ names = [

"test_timeevolution_abstractdata.jl",

"test_sciml_broadcast_interfaces.jl",
"test_ForwardDiff.jl"
]

Expand Down
25 changes: 25 additions & 0 deletions test/test_sciml_broadcast_interfaces.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using Test
using QuantumOptics
using OrdinaryDiffEq

@testset "sciml interface" begin

# semiclassical ODE problem
b = SpinBasis(1//2)
psi0 = spindown(b)
u0 = ComplexF64[0.5, 0.75]
sc = semiclassical.State(psi0, u0)
t₀, t₁ = (0.0, pi)
σx = sigmax(b)

fquantum(t, q, u) = σx + cos(u[1])*identityoperator(σx)
fclassical!(du, u, q, t) = (du[1] = sin(u[2]); du[2] = 2*u[1])
f!(dstate, state, p, t) = semiclassical.dschroedinger_dynamic!(dstate, fquantum, fclassical!, state, t)
prob = ODEProblem(f!, sc, (t₀, t₁))

sol = solve(prob, DP5(); reltol = 1.0e-8, abstol = 1.0e-10, save_everystep=false)
tout, ψt = semiclassical.schroedinger_dynamic([t₀, t₁], sc, fquantum, fclassical!; reltol = 1.0e-8, abstol = 1.0e-10)

@test sol[end] ≈ ψt[end]

end
14 changes: 14 additions & 0 deletions test/test_semiclassical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,4 +175,18 @@ after_jump = findlast(t-> !(t∈T), tout4)
@test ψt4[before_jump].quantum == ψ0.quantum
@test ψt4[after_jump].quantum == spindown(ba)⊗fockstate(bf,0)

# Test broadcasting interface
b = FockBasis(10)
u0 = ComplexF64[0.7, 0.2]
psi = fockstate(b, 2)
rho = dm(psi)

sc_ket = semiclassical.State(psi, u0)
sc_dm = semiclassical.State(rho, u0)

@test sc_ket .* 1.0 == sc_ket
@test sc_dm .* 1.0 == sc_dm
@test sc_ket .+ 2.0 == semiclassical.State(psi .+ 2.0, u0 .+ 2.0)
@test sc_dm .+ 2.0 == semiclassical.State(rho .+ 2.0, u0 .+ 2.0)

end # testsets