Skip to content

Commit

Permalink
fix type piracies
Browse files Browse the repository at this point in the history
  • Loading branch information
apkille committed Jul 27, 2024
1 parent 08ac108 commit ecb593c
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions src/states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,10 +212,10 @@ for f ∈ [:find_basis,:find_dType]
@eval ($f)(::Any, rest) = ($f)(rest)
end

find_basis(a::StateVector, rest) = a.basis
find_dType(a::StateVector, rest) = eltype(a)
@inline Base.getindex(st::StateVector, idx) = getindex(st.data, idx)
Base.@propagate_inbounds Base.Broadcast._broadcast_getindex(x::StateVector, i) = x.data[i]
find_basis(x::T, rest) where {T<:Union{Ket, Bra}} = x.basis
find_dType(x::T, rest) where {T<:Union{Ket, Bra}} = eltype(x)
@inline Base.getindex(x::T, idx) where {T<:Union{Ket, Bra}} = getindex(x.data, idx)
Base.@propagate_inbounds Base.Broadcast._broadcast_getindex(x::T, i) where {T<:Union{Ket, Bra}} = x.data[i]

# In-place broadcasting for Kets
@inline function Base.copyto!(dest::Ket{B}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B,Style<:KetStyle{B},Axes,F,Args}
Expand Down Expand Up @@ -243,19 +243,18 @@ end
@inline Base.copyto!(dest::Bra{B1}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B1,B2,Style<:BraStyle{B2},Axes,F,Args} =
throw(IncompatibleBases())

@inline Base.copyto!(A::T,B::T) where T<:Union{Ket, Bra} = (copyto!(A.data,B.data); A) # Can not use T<:QuantumInterface.StateVector, because StateVector does not imply the existence of a data property
@inline Base.copyto!(dest::T,src::T) where {T<:Union{Ket, Bra}} = (copyto!(dest.data,src.data); dest) # Can not use T<:QuantumInterface.StateVector, because StateVector does not imply the existence of a data property

# A few more standard interfaces: These do not necessarily make sense for a StateVector, but enable transparent use of DifferentialEquations.jl
Base.eltype(::Type{Ket{B,A}}) where {B,N,A<:AbstractVector{N}} = N # ODE init
Base.eltype(::Type{Bra{B,A}}) where {B,N,A<:AbstractVector{N}} = N
Base.zero(k::StateVector) = typeof(k)(k.basis, zero(k.data)) # ODE init
Base.any(f::Function, x::StateVector; kwargs...) = any(f, x.data; kwargs...) # ODE nan checks
Base.all(f::Function, x::StateVector; kwargs...) = all(f, x.data; kwargs...)
Base.fill!(k::StateVector, a) = typeof(k)(k.basis, fill!(k.data, a))
Broadcast.similar(k::StateVector, t) = typeof(k)(k.basis, similar(k.data))
Base.any(f::Function, x::T; kwargs...) where {T<:Union{Ket, Bra}} = any(f, x.data; kwargs...) # ODE nan checks
Base.all(f::Function, x::T; kwargs...) where {T<:Union{Ket, Bra}} = all(f, x.data; kwargs...)
Base.fill!(x::T, a) where {T<:Union{Ket, Bra}} = typeof(x)(x.basis, fill!(x.data, a))
Broadcast.similar(x::T, t) where {T<:Union{Ket, Bra}} = typeof(x)(x.basis, similar(x.data))
using RecursiveArrayTools
RecursiveArrayTools.recursivecopy!(dst::Ket{B,A},src::Ket{B,A}) where {B,A} = copy!(dst.data,src.data) # ODE in-place equations
RecursiveArrayTools.recursivecopy!(dst::Bra{B,A},src::Bra{B,A}) where {B,A} = copy!(dst.data,src.data)
RecursiveArrayTools.recursivecopy(x::StateVector) = copy(x)
RecursiveArrayTools.recursivecopy(x::AbstractArray{T}) where {T<:StateVector} = copy(x)
RecursiveArrayTools.recursivefill!(x::StateVector, a) = fill!(x, a)
RecursiveArrayTools.recursivecopy!(dest::Ket{B,A},src::Ket{B,A}) where {B,A} = copyto!(dest, src) # ODE in-place equations
RecursiveArrayTools.recursivecopy!(dest::Bra{B,A},src::Bra{B,A}) where {B,A} = copyto!(dest, src)
RecursiveArrayTools.recursivecopy(x::T) where {T<:Union{Ket, Bra}} = copy(x)
RecursiveArrayTools.recursivecopy(x::AbstractArray{T}) where {T<:Union{Ket, Bra}} = copy(x)
RecursiveArrayTools.recursivefill!(x::T, a) where {T<:Union{Ket, Bra}} = fill!(x, a)

0 comments on commit ecb593c

Please sign in to comment.