From a4aac798ac433e000db80f890168d322a08b4cab Mon Sep 17 00:00:00 2001 From: apkille Date: Fri, 26 Jul 2024 19:37:21 -0400 Subject: [PATCH 01/12] broadcast interface --- src/semiclassical.jl | 91 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 77 insertions(+), 14 deletions(-) diff --git a/src/semiclassical.jl b/src/semiclassical.jl index 55505df2..847fcb7c 100644 --- a/src/semiclassical.jl +++ b/src/semiclassical.jl @@ -1,7 +1,7 @@ module semiclassical using QuantumOpticsBase -import Base: == +import Base: ==, isapprox, +, -, *, / import ..timeevolution: integrate, recast!, jump, integrate_mcwf, jump_callback, JumpRNGState, threshold, roll!, as_vector, QO_CHECKS import LinearAlgebra: normalize, normalize! @@ -31,19 +31,36 @@ mutable struct State{B,T,C} new{B,T,C}(quantum, classical) end end - -Base.length(state::State) = length(state.quantum) + length(state.classical) -Base.copy(state::State) = State(copy(state.quantum), copy(state.classical)) -Base.eltype(state::State) = promote_type(eltype(state.quantum),eltype(state.classical)) -normalize!(state::State) = (normalize!(state.quantum); state) -normalize(state::State) = State(normalize(state.quantum),copy(state.classical)) - -function ==(a::State, b::State) - QuantumOpticsBase.samebases(a.quantum, b.quantum) && - length(a.classical)==length(b.classical) && - (a.classical==b.classical) && - (a.quantum==b.quantum) -end +State{B}(q::T, c::C) where {B,T<:QuantumState{B},C} = State(q,c) + +Base.zero(x::State) = State(zero(x.quantum), zero(x.classical)) +Base.real(x::State) = State(real.(x.quantum), real(x.classical)) +Base.oneunit(x::State) = State(one.(x.quantum), one.(x.classical)) +Base.length(x::State) = length(x.quantum) + length(x.classical) +Base.size(x::State) = size(x.quantum) +Base.ndims(x::State) = ndims(x.quantum) +Base.ndims(x::Type{<:State{B,T,C}}) where {B,T<:QuantumState{B},C} = ndims(T) +Base.copy(x::State) = State(copy(x.quantum), copy(x.classical)) +Base.copyto!(x::State, y::State) = (copyto!(x.quantum, y.quantum); copyto!(x.classical, y.classical); x) +Base.fill!(x::State, a) = (fill!(x.quantum, a), fill!(x.classical, a)) +Base.eltype(x::State) = promote_type(eltype(x.quantum),eltype(x.classical)) +normalize!(x::State) = (normalize!(x.quantum); x) +normalize(x::State) = State(normalize(x.quantum),copy(x.classical)) +LinearAlgebra.norm(x::State) = LinearAlgebra.norm(x.quantum) +LinearAlgebra.norm(x::State, p::Int64) = LinearAlgebra.norm(x.quantum, p) + +==(x::State{B}, y::State{B}) where {B} = (x.classical==y.classical) && (x.quantum==y.quantum) +==(x::State, y::State) = false + ++(x::State, y::State) = State(x.quantum+y.quantum, x.classical+y.classical) +-(x::State, y::State) = State(x.quantum-y.quantum, x.classical-y.classical) +*(x::Number, y::State) = State(x*y.quantum, x*y.classical) +*(x::State, y::Number) = y*x +/(x::State, y::State) = State(x.quantum ./ y.quantum, x.classical ./ y.classical) +/(x::State, y::Number) = State(x.quantum/y, x.classical/y) + +isapprox(x::State{B}, y::State{B}; kwargs...) where {B} = isapprox(x.quantum,y.quantum) && isapprox(x.classical,y.classical) +isapprox(x::State, y::State; kwargs...) = false QuantumOpticsBase.expect(op, state::State) = expect(op, state.quantum) QuantumOpticsBase.variance(op, state::State) = variance(op, state.quantum) @@ -51,6 +68,52 @@ QuantumOpticsBase.ptrace(state::State, indices) = State(ptrace(state.quantum, in QuantumOpticsBase.dm(x::State) = State(dm(x.quantum), x.classical) +Base.broadcastable(x::State) = Ref(x) + +# Custom broadcasting style +struct StateStyle{B} <: Broadcast.BroadcastStyle end + +# Style precedence rules +Broadcast.BroadcastStyle(::Type{<:State{B}}) where {B} = StateStyle{B}() +Broadcast.BroadcastStyle(::StateStyle{B1}, ::StateStyle{B2}) where {B1,B2} = throw(IncompatibleBases()) + +# Broadcast with scalars +Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {B,T<:StateStyle{B}} = T() + +# Out-of-place broadcasting +@inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B,Style<:StateStyle{B},Axes,F,Args<:Tuple} + bcf = Broadcast.flatten(bc) + q, c = find_quantum(bcf), find_classical(bcf) + return State{B}(copy(q), copy(c)) +end + +for f ∈ [:find_quantum, :find_classical] + @eval ($f)(bc::Broadcast.Broadcasted) = ($f)(bc.args) + @eval ($f)(args::Tuple) = ($f)(($f)(args[1]), Base.tail(args)) + @eval ($f)(x) = x + @eval ($f)(::Any, rest) = ($f)(rest) +end +find_basis(x::State, rest) = QuantumOpticsBase.find_basis(x.quantum) +find_quantum(x::State, rest) = x.quantum +find_classical(x::State, rest) = x.classical +@inline Base.getindex(x::State, idx) = getindex([vec(x.quantum); x.classical], idx) +Base.@propagate_inbounds Base.Broadcast._broadcast_getindex(x::State, i) = [vec(x.quantum); x.classical][i] + +# In-place broadcasting +@inline function Base.copyto!(dest::State{B}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B,Style<:StateStyle{B},Axes,F,Args} + bc′ = Base.Broadcast.preprocess(dest, bc) + q, c = find_quantum(bc), find_classical(bc) + copyto!(dest.quantum, q) + copyto!(dest.classical, c) + return dest +end +@inline Base.copyto!(dest::State{B1}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B1,B2,Style<:StateStyle{B2},Axes,F,Args} = + throw(IncompatibleBases()) +@inline Base.copyto!(dest::State, bc::Broadcast.Broadcasted) = print(bc) + +Broadcast.similar(x::State, t) = State(similar(x.quantum), similar(x.classical)) +using RecursiveArrayTools +RecursiveArrayTools.recursive_unitless_bottom_eltype(x::State) = eltype(x) """ semiclassical.schroedinger_dynamic(tspan, state0, fquantum, fclassical[; fout, ...]) From a745b16282edcd6994ae94b2e7434bc88cf262b9 Mon Sep 17 00:00:00 2001 From: apkille Date: Sun, 28 Jul 2024 08:23:38 -0400 Subject: [PATCH 02/12] more interface --- src/semiclassical.jl | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/semiclassical.jl b/src/semiclassical.jl index 847fcb7c..4167c91c 100644 --- a/src/semiclassical.jl +++ b/src/semiclassical.jl @@ -34,16 +34,21 @@ end State{B}(q::T, c::C) where {B,T<:QuantumState{B},C} = State(q,c) Base.zero(x::State) = State(zero(x.quantum), zero(x.classical)) -Base.real(x::State) = State(real.(x.quantum), real(x.classical)) Base.oneunit(x::State) = State(one.(x.quantum), one.(x.classical)) Base.length(x::State) = length(x.quantum) + length(x.classical) Base.size(x::State) = size(x.quantum) Base.ndims(x::State) = ndims(x.quantum) +Base.axes(x::State) = axes(x.quantum) Base.ndims(x::Type{<:State{B,T,C}}) where {B,T<:QuantumState{B},C} = ndims(T) Base.copy(x::State) = State(copy(x.quantum), copy(x.classical)) Base.copyto!(x::State, y::State) = (copyto!(x.quantum, y.quantum); copyto!(x.classical, y.classical); x) Base.fill!(x::State, a) = (fill!(x.quantum, a), fill!(x.classical, a)) Base.eltype(x::State) = promote_type(eltype(x.quantum),eltype(x.classical)) +Base.similar(x::State) = State(similar(x.quantum), similar(x.classical)) +Base.getindex(x::State, idx) = idx <= length(x.quantum) ? getindex(x.quantum, idx) : getindex(x.classical, idx-length(x.quantum)) +Base.setindex!(x::State, v, idx) = idx <= length(x.quantum) ? setindex(x.quantum, v, idx) : setindex(x.classical, v, idx-length(x.quantum)) +Base.firstindex(x::State) = 1 +Base.lastindex(x::State) = length(x) normalize!(x::State) = (normalize!(x.quantum); x) normalize(x::State) = State(normalize(x.quantum),copy(x.classical)) LinearAlgebra.norm(x::State) = LinearAlgebra.norm(x.quantum) @@ -56,8 +61,8 @@ LinearAlgebra.norm(x::State, p::Int64) = LinearAlgebra.norm(x.quantum, p) -(x::State, y::State) = State(x.quantum-y.quantum, x.classical-y.classical) *(x::Number, y::State) = State(x*y.quantum, x*y.classical) *(x::State, y::Number) = y*x -/(x::State, y::State) = State(x.quantum ./ y.quantum, x.classical ./ y.classical) /(x::State, y::Number) = State(x.quantum/y, x.classical/y) +/(x::State, y::State) = State(x.quantum ./ y.quantum, x.classical ./ y.classical) isapprox(x::State{B}, y::State{B}; kwargs...) where {B} = isapprox(x.quantum,y.quantum) && isapprox(x.classical,y.classical) isapprox(x::State, y::State; kwargs...) = false @@ -83,7 +88,8 @@ Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {B,T<:Stat # Out-of-place broadcasting @inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B,Style<:StateStyle{B},Axes,F,Args<:Tuple} bcf = Broadcast.flatten(bc) - q, c = find_quantum(bcf), find_classical(bcf) + q = find_quantum(bcf) + c = find_classical(bcf) return State{B}(copy(q), copy(c)) end @@ -96,22 +102,22 @@ end find_basis(x::State, rest) = QuantumOpticsBase.find_basis(x.quantum) find_quantum(x::State, rest) = x.quantum find_classical(x::State, rest) = x.classical -@inline Base.getindex(x::State, idx) = getindex([vec(x.quantum); x.classical], idx) -Base.@propagate_inbounds Base.Broadcast._broadcast_getindex(x::State, i) = [vec(x.quantum); x.classical][i] # In-place broadcasting @inline function Base.copyto!(dest::State{B}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B,Style<:StateStyle{B},Axes,F,Args} bc′ = Base.Broadcast.preprocess(dest, bc) - q, c = find_quantum(bc), find_classical(bc) + q, c = find_quantum(bc′), find_classical(bc′) copyto!(dest.quantum, q) copyto!(dest.classical, c) return dest end @inline Base.copyto!(dest::State{B1}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B1,B2,Style<:StateStyle{B2},Axes,F,Args} = throw(IncompatibleBases()) + @inline Base.copyto!(dest::State, bc::Broadcast.Broadcasted) = print(bc) -Broadcast.similar(x::State, t) = State(similar(x.quantum), similar(x.classical)) +Base.@propagate_inbounds Base.Broadcast._broadcast_getindex(x::State, i) = Base.getindex(x, i) +Base.@propagate_inbounds Base.Broadcast._broadcast_getindex(bc::Broadcast.Broadcasted{Style,Axes,F,Args}, i) where {B,Style<:StateStyle{B},Axes,F,Args} = Base.getindex(bc) using RecursiveArrayTools RecursiveArrayTools.recursive_unitless_bottom_eltype(x::State) = eltype(x) From 00fde439f8384877dbca4dfe71e93c5ed82117d2 Mon Sep 17 00:00:00 2001 From: apkille Date: Mon, 29 Jul 2024 20:32:26 -0400 Subject: [PATCH 03/12] broadcasting with tests --- src/semiclassical.jl | 75 ++++++++++++++++++++++---------------- test/test_semiclassical.jl | 14 +++++++ 2 files changed, 58 insertions(+), 31 deletions(-) diff --git a/src/semiclassical.jl b/src/semiclassical.jl index 4167c91c..87fcfea0 100644 --- a/src/semiclassical.jl +++ b/src/semiclassical.jl @@ -33,22 +33,22 @@ mutable struct State{B,T,C} end State{B}(q::T, c::C) where {B,T<:QuantumState{B},C} = State(q,c) +# Standard interfaces Base.zero(x::State) = State(zero(x.quantum), zero(x.classical)) -Base.oneunit(x::State) = State(one.(x.quantum), one.(x.classical)) Base.length(x::State) = length(x.quantum) + length(x.classical) +Base.axes(x::State) = (Base.OneTo(length(x)),) Base.size(x::State) = size(x.quantum) Base.ndims(x::State) = ndims(x.quantum) -Base.axes(x::State) = axes(x.quantum) Base.ndims(x::Type{<:State{B,T,C}}) where {B,T<:QuantumState{B},C} = ndims(T) Base.copy(x::State) = State(copy(x.quantum), copy(x.classical)) Base.copyto!(x::State, y::State) = (copyto!(x.quantum, y.quantum); copyto!(x.classical, y.classical); x) Base.fill!(x::State, a) = (fill!(x.quantum, a), fill!(x.classical, a)) Base.eltype(x::State) = promote_type(eltype(x.quantum),eltype(x.classical)) -Base.similar(x::State) = State(similar(x.quantum), similar(x.classical)) +Base.eltype(x::State{B,T,C}) where {B,T<:QuantumState{B},C} = promote_type(eltype(T), eltype(C)) +Base.similar(x::State, ::Type{T} = eltype(x)) where {T} = State(similar(x.quantum, T), similar(x.classical, T)) Base.getindex(x::State, idx) = idx <= length(x.quantum) ? getindex(x.quantum, idx) : getindex(x.classical, idx-length(x.quantum)) Base.setindex!(x::State, v, idx) = idx <= length(x.quantum) ? setindex(x.quantum, v, idx) : setindex(x.classical, v, idx-length(x.quantum)) -Base.firstindex(x::State) = 1 -Base.lastindex(x::State) = length(x) + normalize!(x::State) = (normalize!(x.quantum); x) normalize(x::State) = State(normalize(x.quantum),copy(x.classical)) LinearAlgebra.norm(x::State) = LinearAlgebra.norm(x.quantum) @@ -57,23 +57,15 @@ LinearAlgebra.norm(x::State, p::Int64) = LinearAlgebra.norm(x.quantum, p) ==(x::State{B}, y::State{B}) where {B} = (x.classical==y.classical) && (x.quantum==y.quantum) ==(x::State, y::State) = false -+(x::State, y::State) = State(x.quantum+y.quantum, x.classical+y.classical) --(x::State, y::State) = State(x.quantum-y.quantum, x.classical-y.classical) -*(x::Number, y::State) = State(x*y.quantum, x*y.classical) -*(x::State, y::Number) = y*x -/(x::State, y::Number) = State(x.quantum/y, x.classical/y) -/(x::State, y::State) = State(x.quantum ./ y.quantum, x.classical ./ y.classical) - isapprox(x::State{B}, y::State{B}; kwargs...) where {B} = isapprox(x.quantum,y.quantum) && isapprox(x.classical,y.classical) isapprox(x::State, y::State; kwargs...) = false QuantumOpticsBase.expect(op, state::State) = expect(op, state.quantum) QuantumOpticsBase.variance(op, state::State) = variance(op, state.quantum) QuantumOpticsBase.ptrace(state::State, indices) = State(ptrace(state.quantum, indices), state.classical) - QuantumOpticsBase.dm(x::State) = State(dm(x.quantum), x.classical) -Base.broadcastable(x::State) = Ref(x) +Base.broadcastable(x::State) = x # Custom broadcasting style struct StateStyle{B} <: Broadcast.BroadcastStyle end @@ -81,16 +73,30 @@ struct StateStyle{B} <: Broadcast.BroadcastStyle end # Style precedence rules Broadcast.BroadcastStyle(::Type{<:State{B}}) where {B} = StateStyle{B}() Broadcast.BroadcastStyle(::StateStyle{B1}, ::StateStyle{B2}) where {B1,B2} = throw(IncompatibleBases()) - -# Broadcast with scalars -Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {B,T<:StateStyle{B}} = T() +Broadcast.BroadcastStyle(::StateStyle{B}, ::Broadcast.DefaultArrayStyle{0}) where {B} = StateStyle{B}() +Broadcast.BroadcastStyle(::Broadcast.DefaultArrayStyle{0}, ::StateStyle{B}) where {B} = StateStyle{B}() # Out-of-place broadcasting -@inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B,Style<:StateStyle{B},Axes,F,Args<:Tuple} +@inline function Base.copy(bc::Broadcast.Broadcasted{<:StateStyle{B},Axes,F,Args}) where {B,Axes,F,Args<:Tuple} bcf = Broadcast.flatten(bc) - q = find_quantum(bcf) - c = find_classical(bcf) - return State{B}(copy(q), copy(c)) + # extract quantum object from broadcast container + qobj = find_quantum(bcf) + data_q = zeros(eltype(qobj), size(qobj)...) + Nq = length(qobj) + # allocate quantum data from broadcast container + @inbounds @simd for I in 1:Nq + data_q[I] = bcf[I] + end + # extract classical object from broadcast container + cobj = find_classical(bcf) + data_c = zeros(eltype(cobj), length(cobj)) + Nc = length(cobj) + # allocate classical data from broadcast container + @inbounds @simd for I in 1:Nc + data_c[I] = bcf[I+Nq] + end + type = eval(nameof(typeof(qobj))) + return State{B}(type(basis(qobj), data_q), data_c) end for f ∈ [:find_quantum, :find_classical] @@ -99,27 +105,34 @@ for f ∈ [:find_quantum, :find_classical] @eval ($f)(x) = x @eval ($f)(::Any, rest) = ($f)(rest) end -find_basis(x::State, rest) = QuantumOpticsBase.find_basis(x.quantum) find_quantum(x::State, rest) = x.quantum find_classical(x::State, rest) = x.classical # In-place broadcasting -@inline function Base.copyto!(dest::State{B}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B,Style<:StateStyle{B},Axes,F,Args} +@inline function Base.copyto!(dest::State{B}, bc::Broadcast.Broadcasted{<:StateStyle{B},Axes,F,Args}) where {B,Axes,F,Args<:Tuple} + axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc)) bc′ = Base.Broadcast.preprocess(dest, bc) - q, c = find_quantum(bc′), find_classical(bc′) - copyto!(dest.quantum, q) - copyto!(dest.classical, c) + # write broadcasted quantum data to dest + qobj = dest.quantum + @inbounds @simd for I in 1:length(qobj) + qobj.data[I] = bc′[I] + end + # write broadcasted classical data to dest + cobj = dest.classical + @inbounds @simd for I in 1:length(c) + cobj[I] = bc′[I+length(q)] + end return dest end -@inline Base.copyto!(dest::State{B1}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B1,B2,Style<:StateStyle{B2},Axes,F,Args} = +@inline Base.copyto!(dest::State{B1}, bc::Broadcast.Broadcasted{<:StateStyle{B2},Axes,F,Args}) where {B1,B2,Axes,F,Args<:Tuple} = throw(IncompatibleBases()) -@inline Base.copyto!(dest::State, bc::Broadcast.Broadcasted) = print(bc) - Base.@propagate_inbounds Base.Broadcast._broadcast_getindex(x::State, i) = Base.getindex(x, i) -Base.@propagate_inbounds Base.Broadcast._broadcast_getindex(bc::Broadcast.Broadcasted{Style,Axes,F,Args}, i) where {B,Style<:StateStyle{B},Axes,F,Args} = Base.getindex(bc) using RecursiveArrayTools -RecursiveArrayTools.recursive_unitless_bottom_eltype(x::State) = eltype(x) +RecursiveArrayTools.recursive_unitless_bottom_eltype(x::State) = eltype(x) +RecursiveArrayTools.recursivecopy!(dest::State, src::State) = copyto!(dest, src) +RecursiveArrayTools.recursivecopy(x::State) = copy(x) +RecursiveArrayTools.recursivefill!(x::State, a) = fill!(x, a) """ semiclassical.schroedinger_dynamic(tspan, state0, fquantum, fclassical[; fout, ...]) diff --git a/test/test_semiclassical.jl b/test/test_semiclassical.jl index 72198612..0fcc07ad 100644 --- a/test/test_semiclassical.jl +++ b/test/test_semiclassical.jl @@ -175,4 +175,18 @@ after_jump = findlast(t-> !(t∈T), tout4) @test ψt4[before_jump].quantum == ψ0.quantum @test ψt4[after_jump].quantum == spindown(ba)⊗fockstate(bf,0) +# Test broadcasting interface +b = FockBasis(10) +u0 = ComplexF64[0.7, 0.2] +psi = fockstate(b, 2) +rho = dm(psi) + +sc_ket = semiclassical.State(psi, u0) +sc_dm = semiclassical.State(rho, u0) + +@test sc_ket .* 1.0 == sc_ket +@test sc_dm .* 1.0 == sc_dm +@test sc_ket .+ 2.0 == semiclassical.State(psi .+ 2.0, u0 .+ 2.0) +@test sc_dm .+ 2.0 == semiclassical.State(rho .+ 2.0, u0 .+ 2.0) + end # testsets From a5ec4ae606ead01d09ced1b329abead0d90e6982 Mon Sep 17 00:00:00 2001 From: apkille Date: Mon, 29 Jul 2024 20:40:36 -0400 Subject: [PATCH 04/12] typo fix --- src/semiclassical.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/semiclassical.jl b/src/semiclassical.jl index 87fcfea0..e3f92290 100644 --- a/src/semiclassical.jl +++ b/src/semiclassical.jl @@ -119,8 +119,8 @@ find_classical(x::State, rest) = x.classical end # write broadcasted classical data to dest cobj = dest.classical - @inbounds @simd for I in 1:length(c) - cobj[I] = bc′[I+length(q)] + @inbounds @simd for I in 1:length(cobj) + cobj[I] = bc′[I+length(qobj)] end return dest end From d4a8d971add96279f5e5fcb2eab340fffd8f846d Mon Sep 17 00:00:00 2001 From: apkille Date: Mon, 29 Jul 2024 21:02:43 -0400 Subject: [PATCH 05/12] eltype switch --- src/semiclassical.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/semiclassical.jl b/src/semiclassical.jl index e3f92290..6e246b7b 100644 --- a/src/semiclassical.jl +++ b/src/semiclassical.jl @@ -44,7 +44,7 @@ Base.copy(x::State) = State(copy(x.quantum), copy(x.classical)) Base.copyto!(x::State, y::State) = (copyto!(x.quantum, y.quantum); copyto!(x.classical, y.classical); x) Base.fill!(x::State, a) = (fill!(x.quantum, a), fill!(x.classical, a)) Base.eltype(x::State) = promote_type(eltype(x.quantum),eltype(x.classical)) -Base.eltype(x::State{B,T,C}) where {B,T<:QuantumState{B},C} = promote_type(eltype(T), eltype(C)) +Base.eltype(x::Type{<:State{B,T,C}}) where {B,T<:QuantumState{B},C} = promote_type(eltype(T), eltype(C)) Base.similar(x::State, ::Type{T} = eltype(x)) where {T} = State(similar(x.quantum, T), similar(x.classical, T)) Base.getindex(x::State, idx) = idx <= length(x.quantum) ? getindex(x.quantum, idx) : getindex(x.classical, idx-length(x.quantum)) Base.setindex!(x::State, v, idx) = idx <= length(x.quantum) ? setindex(x.quantum, v, idx) : setindex(x.classical, v, idx-length(x.quantum)) From c85e235bc454f25f391b32e07cd1e6906f787d35 Mon Sep 17 00:00:00 2001 From: apkille Date: Wed, 31 Jul 2024 11:23:35 -0400 Subject: [PATCH 06/12] add test and compat changes --- .github/workflows/ci.yml | 2 +- Project.toml | 2 +- src/QuantumOptics.jl | 1 + src/semiclassical.jl | 2 +- test/runtests.jl | 1 + test/test_sciml_broadcast_interfaces.jl | 25 +++++++++++++++++++++++++ 6 files changed, 30 insertions(+), 3 deletions(-) create mode 100644 test/test_sciml_broadcast_interfaces.jl diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a1e491d5..6acaeda6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,7 +11,7 @@ jobs: fail-fast: false matrix: version: - - '1' + - '1.10' os: - ubuntu-latest - windows-latest diff --git a/Project.toml b/Project.toml index c6080e37..a20c6559 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,7 @@ RecursiveArrayTools = "2, 3" Reexport = "0.2, 1.0" StochasticDiffEq = "6" WignerSymbols = "1, 2" -julia = "1.3" +julia = "1.10" [extras] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/QuantumOptics.jl b/src/QuantumOptics.jl index ed9e5b37..30e7d23e 100644 --- a/src/QuantumOptics.jl +++ b/src/QuantumOptics.jl @@ -3,6 +3,7 @@ module QuantumOptics using Reexport @reexport using QuantumOpticsBase using SparseArrays, LinearAlgebra +import RecursiveArrayTools export ylm, diff --git a/src/semiclassical.jl b/src/semiclassical.jl index 6e246b7b..b608a55c 100644 --- a/src/semiclassical.jl +++ b/src/semiclassical.jl @@ -5,6 +5,7 @@ import Base: ==, isapprox, +, -, *, / import ..timeevolution: integrate, recast!, jump, integrate_mcwf, jump_callback, JumpRNGState, threshold, roll!, as_vector, QO_CHECKS import LinearAlgebra: normalize, normalize! +import RecursiveArrayTools using Random, LinearAlgebra import OrdinaryDiffEq @@ -128,7 +129,6 @@ end throw(IncompatibleBases()) Base.@propagate_inbounds Base.Broadcast._broadcast_getindex(x::State, i) = Base.getindex(x, i) -using RecursiveArrayTools RecursiveArrayTools.recursive_unitless_bottom_eltype(x::State) = eltype(x) RecursiveArrayTools.recursivecopy!(dest::State, src::State) = copyto!(dest, src) RecursiveArrayTools.recursivecopy(x::State) = copy(x) diff --git a/test/runtests.jl b/test/runtests.jl index a5f92631..487c03fd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,6 +23,7 @@ names = [ "test_timeevolution_abstractdata.jl", + "test_sciml_broadcast_interfaces.jl", "test_ForwardDiff.jl" ] diff --git a/test/test_sciml_broadcast_interfaces.jl b/test/test_sciml_broadcast_interfaces.jl new file mode 100644 index 00000000..6f454161 --- /dev/null +++ b/test/test_sciml_broadcast_interfaces.jl @@ -0,0 +1,25 @@ +using Test +using QuantumOptics +using OrdinaryDiffEq + +@testset "sciml interface" begin + +# semiclassical ODE problem +b = SpinBasis(1//2) +psi0 = spindown(b) +u0 = ComplexF64[0.5, 0.75] +sc = semiclassical.State(psi0, u0) +t₀, t₁ = (0.0, pi) +σx = sigmax(b) + +fquantum(t, q, u) = σx + cos(u[1])*identityoperator(σx) +fclassical!(du, u, q, t) = (du[1] = sin(u[2]); du[2] = 2*u[1]) +f!(dstate, state, p, t) = semiclassical.dschroedinger_dynamic!(dstate, fquantum, fclassical!, state, t) +prob = ODEProblem(f!, sc, (t₀, t₁)) + +sol = solve(prob, DP5(); reltol = 1.0e-8, abstol = 1.0e-10, save_everystep=false) +tout, ψt = semiclassical.schroedinger_dynamic([t₀, t₁], sc, fquantum, fclassical!; reltol = 1.0e-8, abstol = 1.0e-10) + +@test sol[end] ≈ ψt[end] + +end \ No newline at end of file From 5cd9e45d1f418f2026a65930d5540178a8a04d41 Mon Sep 17 00:00:00 2001 From: Stefan Krastanov Date: Thu, 8 Aug 2024 16:59:40 -0400 Subject: [PATCH 07/12] ci fix --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6acaeda6..a1e491d5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,7 +11,7 @@ jobs: fail-fast: false matrix: version: - - '1.10' + - '1' os: - ubuntu-latest - windows-latest From 50dc8760df78372aae31f196bac165a5d9dc5937 Mon Sep 17 00:00:00 2001 From: Stefan Krastanov Date: Sat, 10 Aug 2024 16:14:15 -0400 Subject: [PATCH 08/12] trigger ci From 7129f378750cd21e11081e0111d27e1f310a0d09 Mon Sep 17 00:00:00 2001 From: apkille Date: Sat, 10 Aug 2024 22:51:37 -0400 Subject: [PATCH 09/12] more interface tests --- src/semiclassical.jl | 6 ++---- test/test_semiclassical.jl | 11 +++++++++++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/semiclassical.jl b/src/semiclassical.jl index b608a55c..73983dad 100644 --- a/src/semiclassical.jl +++ b/src/semiclassical.jl @@ -1,6 +1,7 @@ module semiclassical using QuantumOpticsBase +import QuantumOpticsBases: IncompatibleBases import Base: ==, isapprox, +, -, *, / import ..timeevolution: integrate, recast!, jump, integrate_mcwf, jump_callback, JumpRNGState, threshold, roll!, as_vector, QO_CHECKS @@ -39,7 +40,6 @@ Base.zero(x::State) = State(zero(x.quantum), zero(x.classical)) Base.length(x::State) = length(x.quantum) + length(x.classical) Base.axes(x::State) = (Base.OneTo(length(x)),) Base.size(x::State) = size(x.quantum) -Base.ndims(x::State) = ndims(x.quantum) Base.ndims(x::Type{<:State{B,T,C}}) where {B,T<:QuantumState{B},C} = ndims(T) Base.copy(x::State) = State(copy(x.quantum), copy(x.classical)) Base.copyto!(x::State, y::State) = (copyto!(x.quantum, y.quantum); copyto!(x.classical, y.classical); x) @@ -48,17 +48,15 @@ Base.eltype(x::State) = promote_type(eltype(x.quantum),eltype(x.classical)) Base.eltype(x::Type{<:State{B,T,C}}) where {B,T<:QuantumState{B},C} = promote_type(eltype(T), eltype(C)) Base.similar(x::State, ::Type{T} = eltype(x)) where {T} = State(similar(x.quantum, T), similar(x.classical, T)) Base.getindex(x::State, idx) = idx <= length(x.quantum) ? getindex(x.quantum, idx) : getindex(x.classical, idx-length(x.quantum)) -Base.setindex!(x::State, v, idx) = idx <= length(x.quantum) ? setindex(x.quantum, v, idx) : setindex(x.classical, v, idx-length(x.quantum)) normalize!(x::State) = (normalize!(x.quantum); x) normalize(x::State) = State(normalize(x.quantum),copy(x.classical)) LinearAlgebra.norm(x::State) = LinearAlgebra.norm(x.quantum) -LinearAlgebra.norm(x::State, p::Int64) = LinearAlgebra.norm(x.quantum, p) ==(x::State{B}, y::State{B}) where {B} = (x.classical==y.classical) && (x.quantum==y.quantum) ==(x::State, y::State) = false -isapprox(x::State{B}, y::State{B}; kwargs...) where {B} = isapprox(x.quantum,y.quantum) && isapprox(x.classical,y.classical) +isapprox(x::State{B}, y::State{B}; kwargs...) where {B} = isapprox(x.quantum,y.quantum; kwargs...) && isapprox(x.classical,y.classical; kwargs...) isapprox(x::State, y::State; kwargs...) = false QuantumOpticsBase.expect(op, state::State) = expect(op, state.quantum) diff --git a/test/test_semiclassical.jl b/test/test_semiclassical.jl index 0fcc07ad..fdc2c3c1 100644 --- a/test/test_semiclassical.jl +++ b/test/test_semiclassical.jl @@ -1,6 +1,7 @@ using Test using QuantumOptics using LinearAlgebra +using QuantumOpticsBase: IncompatibleBases @testset "semiclassical" begin @@ -177,16 +178,26 @@ after_jump = findlast(t-> !(t∈T), tout4) # Test broadcasting interface b = FockBasis(10) +bn = FockBasis(20) u0 = ComplexF64[0.7, 0.2] psi = fockstate(b, 2) +psin = fockstate(bn, 2) rho = dm(psi) sc_ket = semiclassical.State(psi, u0) +sc_ketn = semiclassical.State(psin, u0) sc_dm = semiclassical.State(rho, u0) +@test Base.size(sc_dm) == Base.size(rho) +@test (copy_sc = copy(sc_ket); Base.fill!(copy_sc, 0.0); copy_sc) == semiclassical.State(fill!(copy(psi), 0.0), fill!(copy(u0), 0.0)) +@test Base.similar(sc_ket, Int) isa semiclassical.State +@test normalize!(copy(sc_ket)) == semiclassical.State(normalize!(copy(psi)), u0) +@test !(sc_ket == sc_ketn) +@test !(isapprox(sc_ket, sc_ketn)) @test sc_ket .* 1.0 == sc_ket @test sc_dm .* 1.0 == sc_dm @test sc_ket .+ 2.0 == semiclassical.State(psi .+ 2.0, u0 .+ 2.0) @test sc_dm .+ 2.0 == semiclassical.State(rho .+ 2.0, u0 .+ 2.0) +@test_throws IncompatibleBases sc_ket .+ semiclassical.State(spinup(SpinBasis(10)), u0) end # testsets From 51271bb108796130f7ea6f5a867b2281e851cc3e Mon Sep 17 00:00:00 2001 From: apkille Date: Sat, 10 Aug 2024 22:54:22 -0400 Subject: [PATCH 10/12] version fix --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a1e491d5..6acaeda6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,7 +11,7 @@ jobs: fail-fast: false matrix: version: - - '1' + - '1.10' os: - ubuntu-latest - windows-latest From 88d2a6a03af3345a0536d3ad58ba92fd0561f649 Mon Sep 17 00:00:00 2001 From: apkille Date: Sat, 10 Aug 2024 23:14:37 -0400 Subject: [PATCH 11/12] ci fix again --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6acaeda6..a1e491d5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,7 +11,7 @@ jobs: fail-fast: false matrix: version: - - '1.10' + - '1' os: - ubuntu-latest - windows-latest From bb2d0f55c1b674f558bef6027309f1da2e07b50d Mon Sep 17 00:00:00 2001 From: apkille Date: Sun, 11 Aug 2024 02:05:59 -0400 Subject: [PATCH 12/12] add import QuantumOpticsBase --- src/semiclassical.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/semiclassical.jl b/src/semiclassical.jl index 73983dad..d70dcbad 100644 --- a/src/semiclassical.jl +++ b/src/semiclassical.jl @@ -1,7 +1,7 @@ module semiclassical using QuantumOpticsBase -import QuantumOpticsBases: IncompatibleBases +import QuantumOpticsBase: IncompatibleBases import Base: ==, isapprox, +, -, *, / import ..timeevolution: integrate, recast!, jump, integrate_mcwf, jump_callback, JumpRNGState, threshold, roll!, as_vector, QO_CHECKS