diff --git a/Project.toml b/Project.toml index c6080e37..a20c6559 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,7 @@ RecursiveArrayTools = "2, 3" Reexport = "0.2, 1.0" StochasticDiffEq = "6" WignerSymbols = "1, 2" -julia = "1.3" +julia = "1.10" [extras] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/QuantumOptics.jl b/src/QuantumOptics.jl index ed9e5b37..30e7d23e 100644 --- a/src/QuantumOptics.jl +++ b/src/QuantumOptics.jl @@ -3,6 +3,7 @@ module QuantumOptics using Reexport @reexport using QuantumOpticsBase using SparseArrays, LinearAlgebra +import RecursiveArrayTools export ylm, diff --git a/src/semiclassical.jl b/src/semiclassical.jl index 55505df2..d70dcbad 100644 --- a/src/semiclassical.jl +++ b/src/semiclassical.jl @@ -1,10 +1,12 @@ module semiclassical using QuantumOpticsBase -import Base: == +import QuantumOpticsBase: IncompatibleBases +import Base: ==, isapprox, +, -, *, / import ..timeevolution: integrate, recast!, jump, integrate_mcwf, jump_callback, JumpRNGState, threshold, roll!, as_vector, QO_CHECKS import LinearAlgebra: normalize, normalize! +import RecursiveArrayTools using Random, LinearAlgebra import OrdinaryDiffEq @@ -31,26 +33,104 @@ mutable struct State{B,T,C} new{B,T,C}(quantum, classical) end end - -Base.length(state::State) = length(state.quantum) + length(state.classical) -Base.copy(state::State) = State(copy(state.quantum), copy(state.classical)) -Base.eltype(state::State) = promote_type(eltype(state.quantum),eltype(state.classical)) -normalize!(state::State) = (normalize!(state.quantum); state) -normalize(state::State) = State(normalize(state.quantum),copy(state.classical)) - -function ==(a::State, b::State) - QuantumOpticsBase.samebases(a.quantum, b.quantum) && - length(a.classical)==length(b.classical) && - (a.classical==b.classical) && - (a.quantum==b.quantum) -end +State{B}(q::T, c::C) where {B,T<:QuantumState{B},C} = State(q,c) + +# Standard interfaces +Base.zero(x::State) = State(zero(x.quantum), zero(x.classical)) +Base.length(x::State) = length(x.quantum) + length(x.classical) +Base.axes(x::State) = (Base.OneTo(length(x)),) +Base.size(x::State) = size(x.quantum) +Base.ndims(x::Type{<:State{B,T,C}}) where {B,T<:QuantumState{B},C} = ndims(T) +Base.copy(x::State) = State(copy(x.quantum), copy(x.classical)) +Base.copyto!(x::State, y::State) = (copyto!(x.quantum, y.quantum); copyto!(x.classical, y.classical); x) +Base.fill!(x::State, a) = (fill!(x.quantum, a), fill!(x.classical, a)) +Base.eltype(x::State) = promote_type(eltype(x.quantum),eltype(x.classical)) +Base.eltype(x::Type{<:State{B,T,C}}) where {B,T<:QuantumState{B},C} = promote_type(eltype(T), eltype(C)) +Base.similar(x::State, ::Type{T} = eltype(x)) where {T} = State(similar(x.quantum, T), similar(x.classical, T)) +Base.getindex(x::State, idx) = idx <= length(x.quantum) ? getindex(x.quantum, idx) : getindex(x.classical, idx-length(x.quantum)) + +normalize!(x::State) = (normalize!(x.quantum); x) +normalize(x::State) = State(normalize(x.quantum),copy(x.classical)) +LinearAlgebra.norm(x::State) = LinearAlgebra.norm(x.quantum) + +==(x::State{B}, y::State{B}) where {B} = (x.classical==y.classical) && (x.quantum==y.quantum) +==(x::State, y::State) = false + +isapprox(x::State{B}, y::State{B}; kwargs...) where {B} = isapprox(x.quantum,y.quantum; kwargs...) && isapprox(x.classical,y.classical; kwargs...) +isapprox(x::State, y::State; kwargs...) = false QuantumOpticsBase.expect(op, state::State) = expect(op, state.quantum) QuantumOpticsBase.variance(op, state::State) = variance(op, state.quantum) QuantumOpticsBase.ptrace(state::State, indices) = State(ptrace(state.quantum, indices), state.classical) - QuantumOpticsBase.dm(x::State) = State(dm(x.quantum), x.classical) +Base.broadcastable(x::State) = x + +# Custom broadcasting style +struct StateStyle{B} <: Broadcast.BroadcastStyle end + +# Style precedence rules +Broadcast.BroadcastStyle(::Type{<:State{B}}) where {B} = StateStyle{B}() +Broadcast.BroadcastStyle(::StateStyle{B1}, ::StateStyle{B2}) where {B1,B2} = throw(IncompatibleBases()) +Broadcast.BroadcastStyle(::StateStyle{B}, ::Broadcast.DefaultArrayStyle{0}) where {B} = StateStyle{B}() +Broadcast.BroadcastStyle(::Broadcast.DefaultArrayStyle{0}, ::StateStyle{B}) where {B} = StateStyle{B}() + +# Out-of-place broadcasting +@inline function Base.copy(bc::Broadcast.Broadcasted{<:StateStyle{B},Axes,F,Args}) where {B,Axes,F,Args<:Tuple} + bcf = Broadcast.flatten(bc) + # extract quantum object from broadcast container + qobj = find_quantum(bcf) + data_q = zeros(eltype(qobj), size(qobj)...) + Nq = length(qobj) + # allocate quantum data from broadcast container + @inbounds @simd for I in 1:Nq + data_q[I] = bcf[I] + end + # extract classical object from broadcast container + cobj = find_classical(bcf) + data_c = zeros(eltype(cobj), length(cobj)) + Nc = length(cobj) + # allocate classical data from broadcast container + @inbounds @simd for I in 1:Nc + data_c[I] = bcf[I+Nq] + end + type = eval(nameof(typeof(qobj))) + return State{B}(type(basis(qobj), data_q), data_c) +end + +for f ∈ [:find_quantum, :find_classical] + @eval ($f)(bc::Broadcast.Broadcasted) = ($f)(bc.args) + @eval ($f)(args::Tuple) = ($f)(($f)(args[1]), Base.tail(args)) + @eval ($f)(x) = x + @eval ($f)(::Any, rest) = ($f)(rest) +end +find_quantum(x::State, rest) = x.quantum +find_classical(x::State, rest) = x.classical + +# In-place broadcasting +@inline function Base.copyto!(dest::State{B}, bc::Broadcast.Broadcasted{<:StateStyle{B},Axes,F,Args}) where {B,Axes,F,Args<:Tuple} + axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc)) + bc′ = Base.Broadcast.preprocess(dest, bc) + # write broadcasted quantum data to dest + qobj = dest.quantum + @inbounds @simd for I in 1:length(qobj) + qobj.data[I] = bc′[I] + end + # write broadcasted classical data to dest + cobj = dest.classical + @inbounds @simd for I in 1:length(cobj) + cobj[I] = bc′[I+length(qobj)] + end + return dest +end +@inline Base.copyto!(dest::State{B1}, bc::Broadcast.Broadcasted{<:StateStyle{B2},Axes,F,Args}) where {B1,B2,Axes,F,Args<:Tuple} = + throw(IncompatibleBases()) + +Base.@propagate_inbounds Base.Broadcast._broadcast_getindex(x::State, i) = Base.getindex(x, i) +RecursiveArrayTools.recursive_unitless_bottom_eltype(x::State) = eltype(x) +RecursiveArrayTools.recursivecopy!(dest::State, src::State) = copyto!(dest, src) +RecursiveArrayTools.recursivecopy(x::State) = copy(x) +RecursiveArrayTools.recursivefill!(x::State, a) = fill!(x, a) """ semiclassical.schroedinger_dynamic(tspan, state0, fquantum, fclassical[; fout, ...]) diff --git a/test/runtests.jl b/test/runtests.jl index a5f92631..487c03fd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,6 +23,7 @@ names = [ "test_timeevolution_abstractdata.jl", + "test_sciml_broadcast_interfaces.jl", "test_ForwardDiff.jl" ] diff --git a/test/test_sciml_broadcast_interfaces.jl b/test/test_sciml_broadcast_interfaces.jl new file mode 100644 index 00000000..6f454161 --- /dev/null +++ b/test/test_sciml_broadcast_interfaces.jl @@ -0,0 +1,25 @@ +using Test +using QuantumOptics +using OrdinaryDiffEq + +@testset "sciml interface" begin + +# semiclassical ODE problem +b = SpinBasis(1//2) +psi0 = spindown(b) +u0 = ComplexF64[0.5, 0.75] +sc = semiclassical.State(psi0, u0) +t₀, t₁ = (0.0, pi) +σx = sigmax(b) + +fquantum(t, q, u) = σx + cos(u[1])*identityoperator(σx) +fclassical!(du, u, q, t) = (du[1] = sin(u[2]); du[2] = 2*u[1]) +f!(dstate, state, p, t) = semiclassical.dschroedinger_dynamic!(dstate, fquantum, fclassical!, state, t) +prob = ODEProblem(f!, sc, (t₀, t₁)) + +sol = solve(prob, DP5(); reltol = 1.0e-8, abstol = 1.0e-10, save_everystep=false) +tout, ψt = semiclassical.schroedinger_dynamic([t₀, t₁], sc, fquantum, fclassical!; reltol = 1.0e-8, abstol = 1.0e-10) + +@test sol[end] ≈ ψt[end] + +end \ No newline at end of file diff --git a/test/test_semiclassical.jl b/test/test_semiclassical.jl index 72198612..fdc2c3c1 100644 --- a/test/test_semiclassical.jl +++ b/test/test_semiclassical.jl @@ -1,6 +1,7 @@ using Test using QuantumOptics using LinearAlgebra +using QuantumOpticsBase: IncompatibleBases @testset "semiclassical" begin @@ -175,4 +176,28 @@ after_jump = findlast(t-> !(t∈T), tout4) @test ψt4[before_jump].quantum == ψ0.quantum @test ψt4[after_jump].quantum == spindown(ba)⊗fockstate(bf,0) +# Test broadcasting interface +b = FockBasis(10) +bn = FockBasis(20) +u0 = ComplexF64[0.7, 0.2] +psi = fockstate(b, 2) +psin = fockstate(bn, 2) +rho = dm(psi) + +sc_ket = semiclassical.State(psi, u0) +sc_ketn = semiclassical.State(psin, u0) +sc_dm = semiclassical.State(rho, u0) + +@test Base.size(sc_dm) == Base.size(rho) +@test (copy_sc = copy(sc_ket); Base.fill!(copy_sc, 0.0); copy_sc) == semiclassical.State(fill!(copy(psi), 0.0), fill!(copy(u0), 0.0)) +@test Base.similar(sc_ket, Int) isa semiclassical.State +@test normalize!(copy(sc_ket)) == semiclassical.State(normalize!(copy(psi)), u0) +@test !(sc_ket == sc_ketn) +@test !(isapprox(sc_ket, sc_ketn)) +@test sc_ket .* 1.0 == sc_ket +@test sc_dm .* 1.0 == sc_dm +@test sc_ket .+ 2.0 == semiclassical.State(psi .+ 2.0, u0 .+ 2.0) +@test sc_dm .+ 2.0 == semiclassical.State(rho .+ 2.0, u0 .+ 2.0) +@test_throws IncompatibleBases sc_ket .+ semiclassical.State(spinup(SpinBasis(10)), u0) + end # testsets