Skip to content

Commit

Permalink
Merge pull request #99 from pablosanjose/ketmatrices
Browse files Browse the repository at this point in the history
Kets : support `SMatrix` eltype
  • Loading branch information
pablosanjose authored Sep 26, 2020
2 parents 334250f + a4dbc4d commit 743b928
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 92 deletions.
123 changes: 72 additions & 51 deletions src/hamiltonian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ Base.size(h::HamiltonianHarmonic) = size(h.h)
flatsize(h::Hamiltonian, n) = first(flatsize(h)) # h is always square

function flatsize(h::Hamiltonian)
n = sum(sublatsites(h.lattice) .* length.(h.orbitals))
n = sum(sublatlengths(h.lattice) .* length.(h.orbitals))
return (n, n)
end

Expand Down Expand Up @@ -383,7 +383,7 @@ end
# Indexing #
Base.push!(h::Hamiltonian{<:Any,L}, dn::NTuple{L,Int}) where {L} = push!(h, SVector(dn...))
Base.push!(h::Hamiltonian{<:Any,L}, dn::Vararg{Int,L}) where {L} = push!(h, SVector(dn...))
function Base.push!(h::Hamiltonian{<:Any,L}, dn::SVector{L,Int}) where {L}
function Base.push!(h::Hamiltonian{<:Any,L}, dn::SVector{L,Int}) where {L}
get_or_push!(h.harmonics, dn, size(h))
return h
end
Expand Down Expand Up @@ -639,6 +639,7 @@ toeltype(u::UniformScaling, ::Type{S}, t1::NTuple{N1}, t2::NTuple{N2}) where {N1
toeltype(t::Number, ::Type{T}, t1::NTuple{1}) where {T<:Number} = T(t)
toeltype(t::Number, ::Type{S}, t1::NTuple{1}) where {S<:SVector} = padtotype(t, S)
toeltype(t::SVector{N}, ::Type{S}, t1::NTuple{N}) where {N,S<:SVector} = padtotype(t, S)
toeltype(t::SMatrix{N}, ::Type{S}, t1::NTuple{N}) where {N,S<:SMatrix} = padtotype(t, S)

# Fallback to catch mismatched or undesired block types
toeltype(t::Array, x...) = throw(ArgumentError("Array input in model, please use StaticArrays instead (e.g. SA[1 0; 0 1] instead of [1 0; 0 1])"))
Expand Down Expand Up @@ -675,77 +676,97 @@ Construct a `Vector` representation of `km` applied to Hamiltonian `h`.
Base.Vector(km::KetModel, h::Hamiltonian) = vec(Matrix(km, h))

"""
Matrix(km::KetModel, h::Hamiltonian; orthogonal = false)
Matrix(kms::NTuple{N,KetModel}, h::Hamiltonian, orthogonal = false)
Matrix(kms::AbstractMatrix, h::Hamiltonian; orthogonal = false)
Matrix(kms::StochasticTraceKets, h::Hamiltonian; orthogonal = false)
Matrix(km::KetModel, h::Hamiltonian)
Matrix(kms::NTuple{N,KetModel}, h::Hamiltonian)
Matrix(kms::AbstractMatrix, h::Hamiltonian)
Matrix(kms::StochasticTraceKets, h::Hamiltonian)
Construct an `M×N` `Matrix` representation of the `N` kets `kms` applied to `M×M`
Hamiltonian `h`. If `orthogonal = true`, the columns are made orthogonal through a
Gram-Schmidt process. If `kms::StochasticTraceKets` for `n` random kets (constructed with
Hamiltonian `h`. If `kms::StochasticTraceKets` for `n` random kets (constructed with
`randomkets(n)`), a normalization `1/√n` required for stochastic traces is included.
"""
Base.Matrix(km::KetModel, h::Hamiltonian) = Matrix((km,), h)

function Base.Matrix(km::AbstractMatrix, h::Hamiltonian; orthogonal = false)
eltype(km) == orbitaltype(h) && size(km, 1) == size(h, 2) || throw(ArgumentError("ket vector or matrix is incompatible with Hamiltonian"))
function Base.Matrix(km::AbstractMatrix, h::Hamiltonian)
check_compatible_kets(km, h)
kmat = Matrix(km)
orthogonal && make_orthogonal!(kmat, kms)
return kmat
end

function Base.Matrix(rk::StochasticTraceKets, h::Hamiltonian)
ketmodels = Base.Iterators.repeated(rk.ketmodel, rk.repetitions)
kmat = Matrix(ketmodels, h; orthogonal = rk.orthogonal)
normk = sqrt(1/size(kmat,2))
kmat .*= normk # normalized for stochastic traces
# kmodels should be a Union{NTuple{N,KetModel},StochasticTraceKets}
function Base.Matrix(kmodels, h::Hamiltonian)
kmodels´ = resolve_tuple(kmodels, h.lattice)
allpos = allsitepositions(h.lattice)
T = guess_eltype(kmodels´, allpos, h)
orbs = h.orbitals
kmat = [generate_amplitude(km, i, allpos[i], T, orbs[s]) for (i, s) in sitesublats(h.lattice), km in kmodels´]
check_compatible_kets(kmat, h)
maybe_normalize!(kmat, kmodels)
return kmat
end

function Base.Matrix(kms, h::Hamiltonian; orthogonal = false)
M = orbitaltype(h)
kmat = zeros(M, size(h, 2), length(kms))
for (j, km) in enumerate(kms)
kvec = view(kmat, :, j)
ket!(kvec, km, h)
end
orthogonal && make_orthogonal!(kmat, kms)
resolve_tuple(ks::NTuple{N,KetModel}, lat) where {N} = resolve.(ks, Ref(lat))
resolve_tuple(ks::StochasticTraceKets, lat) = resolve(ks, lat)

function guess_eltype(kms, allpos, h)
km = first(kms)
term = first(km.model.terms)
rsel = term.selector
s = first(sublats(rsel))
i = first(siteindices(rsel, s))
r = allpos[i]
t = term(r, r)
z = zero(orbitaltype(h))
T = _guess_eltype(km.maporbitals, t, z)
return T
end

_guess_eltype(::Val{true}, t::Number, z) = typeof(t * z)
_guess_eltype(::Val{false}, t::Number, z) = typeof(t * z)
_guess_eltype(::Val{false}, t::SVector{<:Any,T}, z::SVector) where {T} = typeof(zero(T) * z)
_guess_eltype(::Val{false}, t::SMatrix{M,N,T}, z::SVector{M2,T2}) where {M,N,M2,T,T2} = typeof(SMatrix{M2,N}(zero(T) * zero(T2) * I))

function maybe_normalize!(kmat, kms::StochasticTraceKets)
kms.ketmodel.normalized && normalize_columns!(kmat)
kmat .*= sqrt(1/size(kmat,2))
return kmat
end

function ket!(k, km::KetModel, h)
M = eltype(k)
fill!(k, zero(M))
hsites = allsitepositions(h.lattice)
for term in km.model.terms
rs = resolve(term.selector, h.lattice)
ss = sublats(rs)
for s in ss
orbs = h.orbitals[s]
is = siterange(h.lattice, s)
for i in is
i in rs || continue
r = hsites[i]
k[i] += generate_amplitude(km, term, r, M, orbs)
end
end
function maybe_normalize!(kmat, kms::Tuple{KetModel})
for (i, km) in enumerate(kms)
km.normalized && normalize_columns!(kmat, i)
end
km.normalized && normalize!(k)
return k
return kmat
end

function make_orthogonal!(kmat::AbstractMatrix{<:Number}, kms)
q, r = qr!(kmat)
kmat .= Matrix(q)
for (j, km) in enumerate(kms)
if !km.normalized
kmat[:,j] .*= r[j, j]
end
check_compatible_kets(kmat::AbstractMatrix, h::Hamiltonian) =
comp_eltypes(h, kmat) && size(kmat, 1) == size(h, 2) ||
throw(ArgumentError("ket vector or matrix is incompatible with Hamiltonian"))

comp_eltypes(h::Hamiltonian, k::AbstractMatrix) = comp_eltypes(blocktype(h), eltype(k))
comp_eltypes(::Type{<:Number}, ::Type{<:Number}) = true
comp_eltypes(::Type{<:Number}, ::Type{<:SMatrix{1}}) = true
comp_eltypes(::Type{<:SMatrix{N,M}}, ::Type{<:SVector{M}}) where {N,M} = true
comp_eltypes(::Type{<:SMatrix{N,M}}, ::Type{<:SMatrix{M}}) where {N,M} = true
comp_eltypes(t1, t2) = false

### generate_amplitude (asssumes resolved selectors) ###

function generate_amplitude(ketmodel::KetModel, i, r, T, orbs)
amplitude = sum(ketmodel.model.terms) do term
i in term.selector ? maybe_maporbitals(ketmodel.maporbitals, T, orbs, term, r) : zero(T)
end
return kmat
return amplitude
end

make_orthogonal!(kmat, kms) = throw(ArgumentError("The orthogonalize option is only available for kets of scalar eltype, not for $(eltype(kmat))."))
function maybe_maporbitals(::Val{false}, T, orbs, term, r)
return toeltype(term(r, r), T, orbs)
end

function maybe_maporbitals(::Val{true}, T, orbs::NTuple{N}, term, r) where {N}
x = SVector{N}(ntuple(_ -> Number(term(r, r)), Val(N)))
return toeltype(x, T, orbs)
end

#######################################################################
# unitcell/supercell for Hamiltonians
Expand Down
21 changes: 21 additions & 0 deletions src/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -339,3 +339,24 @@ function Base.iterate(s::SparseMatrixReader, state = (1, 1))
end

enumerate_sparse(s::SparseMatrixCSC) = SparseMatrixReader(s)

#######################################################################
# SiteSublats
#######################################################################

struct SiteSublats{G,U}
g::G
u::U
end

sitesublats(u) = SiteSublats(((i, s) for s in sublats(u) for i in siterange(u, s)), u)

Base.iterate(s::SiteSublats, x...) = iterate(s.g, x...)

Base.IteratorSize(::SiteSublats) = Base.HasLength()

Base.IteratorEltype(::SiteSublats) = Base.HasEltype()

Base.eltype(::SiteSublats) = Tuple{Int,Int}

Base.length(s::SiteSublats) = nsites(s.u)
14 changes: 8 additions & 6 deletions src/lattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ struct Unitcell{E,T,N}
sites::Vector{SVector{E,T}}
names::NTuple{N,NameType}
offsets::Vector{Int} # Linear site number offsets for each sublat
end # so that diff(offset) == sublatsites
end # so that diff(offset) == sublatlengths

Unitcell(sublats::Sublat...; kw...) = Unitcell(promote(sublats...); kw...)

Expand Down Expand Up @@ -113,7 +113,7 @@ siterange(u::Unitcell, sublat) = (1+u.offsets[sublat]):u.offsets[sublat+1]
enumeratesites(u::Unitcell, sublat) = ((i, sitepositions(u)[i]) for i in siterange(u, sublat))

nsites(u::Unitcell) = length(u.sites)
nsites(u::Unitcell, sublat) = sublatsites(u)[sublat]
nsites(u::Unitcell, sublat) = sublatlengths(u)[sublat]

offsets(u::Unitcell) = u.offsets

Expand All @@ -125,7 +125,7 @@ function sublat(u::Unitcell, siteidx)
return l
end

sublatsites(u::Unitcell) = diff(u.offsets)
sublatlengths(u::Unitcell) = diff(u.offsets)

nsublats(u::Unitcell) = length(u.names)

Expand Down Expand Up @@ -205,7 +205,7 @@ function Base.show(io::IO, lat::Lattice)
"$i Bravais vectors : $(displayvectors(bravais(lat); digits = 6))
$i Sublattices : $(nsublats(lat))
$i Names : $(displaynames(lat))
$i Sites : $(display_as_tuple(sublatsites(lat))) --> $(nsites(lat)) total per unit cell")
$i Sites : $(display_as_tuple(sublatlengths(lat))) --> $(nsites(lat)) total per unit cell")
end

Base.summary(::Lattice{E,L,T}) where {E,L,T} =
Expand Down Expand Up @@ -376,7 +376,7 @@ function Base.show(io::IO, lat::Superlattice)
"$i Bravais vectors : $(displayvectors(bravais(lat); digits = 6))
$i Sublattices : $(nsublats(lat))
$i Names : $(displaynames(lat))
$i Sites : $(display_as_tuple(sublatsites(lat))) --> $(nsites(lat)) total per unit cell\n")
$i Sites : $(display_as_tuple(sublatlengths(lat))) --> $(nsites(lat)) total per unit cell\n")
print(ioindent, lat.supercell)
end

Expand Down Expand Up @@ -445,9 +445,11 @@ allsitepositions(lat::AbstractLattice) = sitepositions(lat.unitcell)
siteposition(i, lat::AbstractLattice) = allsitepositions(lat)[i]
siteposition(i, dn::SVector, lat::AbstractLattice) = siteposition(i, lat) + bravais(lat) * dn

sitesublats(lat::AbstractLattice) = sitesublats(lat.unitcell)

offsets(lat::AbstractLattice) = offsets(lat.unitcell)

sublatsites(lat::AbstractLattice) = sublatsites(lat.unitcell)
sublatlengths(lat::AbstractLattice) = sublatlengths(lat.unitcell)

enumeratesites(lat::AbstractLattice, sublat) = enumeratesites(lat.unitcell, sublat)

Expand Down
Loading

0 comments on commit 743b928

Please sign in to comment.