Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Broadcasting for semiclassical objects #404

Merged
merged 13 commits into from
Aug 11, 2024
112 changes: 97 additions & 15 deletions src/semiclassical.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module semiclassical

using QuantumOpticsBase
import Base: ==
import Base: ==, isapprox, +, -, *, /
import ..timeevolution: integrate, recast!, jump, integrate_mcwf, jump_callback,
JumpRNGState, threshold, roll!, as_vector, QO_CHECKS
import LinearAlgebra: normalize, normalize!
Expand Down Expand Up @@ -31,26 +31,108 @@ mutable struct State{B,T,C}
new{B,T,C}(quantum, classical)
end
end

Base.length(state::State) = length(state.quantum) + length(state.classical)
Base.copy(state::State) = State(copy(state.quantum), copy(state.classical))
Base.eltype(state::State) = promote_type(eltype(state.quantum),eltype(state.classical))
normalize!(state::State) = (normalize!(state.quantum); state)
normalize(state::State) = State(normalize(state.quantum),copy(state.classical))

function ==(a::State, b::State)
QuantumOpticsBase.samebases(a.quantum, b.quantum) &&
length(a.classical)==length(b.classical) &&
(a.classical==b.classical) &&
(a.quantum==b.quantum)
end
State{B}(q::T, c::C) where {B,T<:QuantumState{B},C} = State(q,c)

# Standard interfaces
Base.zero(x::State) = State(zero(x.quantum), zero(x.classical))
Base.length(x::State) = length(x.quantum) + length(x.classical)
Base.axes(x::State) = (Base.OneTo(length(x)),)
Base.size(x::State) = size(x.quantum)
Base.ndims(x::State) = ndims(x.quantum)
Base.ndims(x::Type{<:State{B,T,C}}) where {B,T<:QuantumState{B},C} = ndims(T)
Base.copy(x::State) = State(copy(x.quantum), copy(x.classical))
Base.copyto!(x::State, y::State) = (copyto!(x.quantum, y.quantum); copyto!(x.classical, y.classical); x)
Base.fill!(x::State, a) = (fill!(x.quantum, a), fill!(x.classical, a))
Base.eltype(x::State) = promote_type(eltype(x.quantum),eltype(x.classical))
Base.eltype(x::Type{<:State{B,T,C}}) where {B,T<:QuantumState{B},C} = promote_type(eltype(T), eltype(C))
Base.similar(x::State, ::Type{T} = eltype(x)) where {T} = State(similar(x.quantum, T), similar(x.classical, T))
Base.getindex(x::State, idx) = idx <= length(x.quantum) ? getindex(x.quantum, idx) : getindex(x.classical, idx-length(x.quantum))
Base.setindex!(x::State, v, idx) = idx <= length(x.quantum) ? setindex(x.quantum, v, idx) : setindex(x.classical, v, idx-length(x.quantum))

normalize!(x::State) = (normalize!(x.quantum); x)
normalize(x::State) = State(normalize(x.quantum),copy(x.classical))
LinearAlgebra.norm(x::State) = LinearAlgebra.norm(x.quantum)
LinearAlgebra.norm(x::State, p::Int64) = LinearAlgebra.norm(x.quantum, p)

==(x::State{B}, y::State{B}) where {B} = (x.classical==y.classical) && (x.quantum==y.quantum)
==(x::State, y::State) = false

isapprox(x::State{B}, y::State{B}; kwargs...) where {B} = isapprox(x.quantum,y.quantum) && isapprox(x.classical,y.classical)
apkille marked this conversation as resolved.
Show resolved Hide resolved
isapprox(x::State, y::State; kwargs...) = false

QuantumOpticsBase.expect(op, state::State) = expect(op, state.quantum)
QuantumOpticsBase.variance(op, state::State) = variance(op, state.quantum)
QuantumOpticsBase.ptrace(state::State, indices) = State(ptrace(state.quantum, indices), state.classical)

QuantumOpticsBase.dm(x::State) = State(dm(x.quantum), x.classical)

Base.broadcastable(x::State) = x

# Custom broadcasting style
struct StateStyle{B} <: Broadcast.BroadcastStyle end

# Style precedence rules
Broadcast.BroadcastStyle(::Type{<:State{B}}) where {B} = StateStyle{B}()
Broadcast.BroadcastStyle(::StateStyle{B1}, ::StateStyle{B2}) where {B1,B2} = throw(IncompatibleBases())
Broadcast.BroadcastStyle(::StateStyle{B}, ::Broadcast.DefaultArrayStyle{0}) where {B} = StateStyle{B}()
Broadcast.BroadcastStyle(::Broadcast.DefaultArrayStyle{0}, ::StateStyle{B}) where {B} = StateStyle{B}()

# Out-of-place broadcasting
@inline function Base.copy(bc::Broadcast.Broadcasted{<:StateStyle{B},Axes,F,Args}) where {B,Axes,F,Args<:Tuple}
bcf = Broadcast.flatten(bc)
# extract quantum object from broadcast container
qobj = find_quantum(bcf)
data_q = zeros(eltype(qobj), size(qobj)...)
Nq = length(qobj)
# allocate quantum data from broadcast container
@inbounds @simd for I in 1:Nq
data_q[I] = bcf[I]
end
# extract classical object from broadcast container
cobj = find_classical(bcf)
data_c = zeros(eltype(cobj), length(cobj))
Nc = length(cobj)
# allocate classical data from broadcast container
@inbounds @simd for I in 1:Nc
data_c[I] = bcf[I+Nq]
end
type = eval(nameof(typeof(qobj)))
return State{B}(type(basis(qobj), data_q), data_c)
end

for f ∈ [:find_quantum, :find_classical]
@eval ($f)(bc::Broadcast.Broadcasted) = ($f)(bc.args)
@eval ($f)(args::Tuple) = ($f)(($f)(args[1]), Base.tail(args))
@eval ($f)(x) = x
@eval ($f)(::Any, rest) = ($f)(rest)
end
find_quantum(x::State, rest) = x.quantum
find_classical(x::State, rest) = x.classical

# In-place broadcasting
@inline function Base.copyto!(dest::State{B}, bc::Broadcast.Broadcasted{<:StateStyle{B},Axes,F,Args}) where {B,Axes,F,Args<:Tuple}
axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc))
bc′ = Base.Broadcast.preprocess(dest, bc)
# write broadcasted quantum data to dest
qobj = dest.quantum
@inbounds @simd for I in 1:length(qobj)
qobj.data[I] = bc′[I]
end
# write broadcasted classical data to dest
cobj = dest.classical
@inbounds @simd for I in 1:length(cobj)
cobj[I] = bc′[I+length(qobj)]
end
return dest
end
@inline Base.copyto!(dest::State{B1}, bc::Broadcast.Broadcasted{<:StateStyle{B2},Axes,F,Args}) where {B1,B2,Axes,F,Args<:Tuple} =
throw(IncompatibleBases())

Base.@propagate_inbounds Base.Broadcast._broadcast_getindex(x::State, i) = Base.getindex(x, i)
using RecursiveArrayTools
apkille marked this conversation as resolved.
Show resolved Hide resolved
RecursiveArrayTools.recursive_unitless_bottom_eltype(x::State) = eltype(x)
RecursiveArrayTools.recursivecopy!(dest::State, src::State) = copyto!(dest, src)
RecursiveArrayTools.recursivecopy(x::State) = copy(x)
RecursiveArrayTools.recursivefill!(x::State, a) = fill!(x, a)

"""
semiclassical.schroedinger_dynamic(tspan, state0, fquantum, fclassical[; fout, ...])
Expand Down
14 changes: 14 additions & 0 deletions test/test_semiclassical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,4 +175,18 @@ after_jump = findlast(t-> !(t∈T), tout4)
@test ψt4[before_jump].quantum == ψ0.quantum
@test ψt4[after_jump].quantum == spindown(ba)⊗fockstate(bf,0)

# Test broadcasting interface
b = FockBasis(10)
u0 = ComplexF64[0.7, 0.2]
psi = fockstate(b, 2)
rho = dm(psi)

sc_ket = semiclassical.State(psi, u0)
sc_dm = semiclassical.State(rho, u0)

@test sc_ket .* 1.0 == sc_ket
@test sc_dm .* 1.0 == sc_dm
@test sc_ket .+ 2.0 == semiclassical.State(psi .+ 2.0, u0 .+ 2.0)
@test sc_dm .+ 2.0 == semiclassical.State(rho .+ 2.0, u0 .+ 2.0)

end # testsets
Loading