Skip to content

Commit

Permalink
Remove excessive abstractions, implement AbstractVector
Browse files Browse the repository at this point in the history
  • Loading branch information
aryavorskiy committed Sep 18, 2023
1 parent 5db3d0c commit 56bf5f6
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions src/manybody.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
abstract type OccupationsIterator end

struct Occupations{T} <: OccupationsIterator
struct Occupations{T} <: AbstractVector{Vector{T}}
occupations::Vector{T}
function Occupations(occ::Vector{T}) where {T}
length(occ) > 0 && @assert all(==(length(first(occ))) length, occ)
Expand All @@ -13,21 +11,21 @@ struct Occupations{T} <: OccupationsIterator
Occupations(occ::Vector{T}, ::Val{:no_checks}) where {T} = new{T}(occ)
end
Base.:(==)(occ1::Occupations, occ2::Occupations) = occ1.occupations == occ2.occupations
Base.size(occ::Occupations) = (length(occ.occupations),)
Base.@propagate_inbounds function Base.getindex(occ::Occupations, i::Int)
@boundscheck !checkbounds(Bool, occ.occupations, i) && throw(BoundsError(occ, i))
return occ.occupations[i]
end
Base.iterate(occ::Occupations, s...) = iterate(occ.occupations, s...)
Base.length(occ::Occupations) = length(occ.occupations)
allocate_buffer(occ::Occupations) = similar(first(occ))
function state_index(occ::Occupations, state)
length(state) != length(first(occ)) && return nothing
ret = searchsortedfirst(occ.occupations, state, lt= >)
ret == length(occ) + 1 && return nothing
return occ.occupations[ret] == state ? ret : nothing
end
Base.union(occ1::Occupations{T}, occ2::Occupations{T}) where {T} =
Occupations(union(occ1.occupations, occ2.occupations))
state_index(occ::AbstractVector, state) = findfirst(==(state), occ)
Base.union(occ1::Occupations{T}, occs::Occupations{T}...) where {T} =
Occupations(union(occ1.occupations, (occ.occupations for occ in occs)...))

"""
ManyBodyBasis(b, occupations)
Expand All @@ -43,7 +41,7 @@ struct ManyBodyBasis{B,O,UT} <: Basis
onebodybasis::B
occupations::O
occupations_hash::UT
function ManyBodyBasis{B,O}(onebodybasis::B, occupations::O) where {B,O<:OccupationsIterator}
function ManyBodyBasis{B,O}(onebodybasis::B, occupations::O) where {B,O}
h = hash(hash.(occupations))
new{B,O,typeof(h)}(length(occupations), onebodybasis, occupations, h)
end
Expand Down

0 comments on commit 56bf5f6

Please sign in to comment.