Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

indices kwarg in selectors #84

Merged
merged 11 commits into from
Sep 15, 2020
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ The Quantica.jl package provides an expressive API to build arbitrary quantum sy

# Exported API
- `lattice`, `sublat`, `bravais`: build lattices
- `dims`, `sites`: inspect lattices
- `dims`, `sitepositions`, `siteindices`: inspect lattices
- `hopping`, `onsite`, `siteselector`, `hopselector`: build tightbinding models
- `hamiltonian`: build a Hamiltonian from tightbinding model and a lattice
- `parametric`, `@onsite!`, `@hopping!`, `parameters`: build a parametric Hamiltonian
Expand Down
3 changes: 2 additions & 1 deletion src/Quantica.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ using ExprTools

using SparseArrays: getcolptr, AbstractSparseMatrix

export sublat, bravais, lattice, dims, sites, supercell, unitcell,
export sublat, bravais, lattice, dims, supercell, unitcell,
hopping, onsite, @onsite!, @hopping!, parameters, siteselector, hopselector,
sitepositions, siteindices,
ket, randomkets,
hamiltonian, parametric, bloch, bloch!, optimize!, similarmatrix,
flatten, wrap, transform!, combine,
Expand Down
110 changes: 60 additions & 50 deletions src/hamiltonian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ norbitals(h::Hamiltonian) = length.(h.orbitals)
function meandist(h::Hamiltonian)
distsum = 0.0
num = 0
ss = sites(h.lattice)
ss = allsitepositions(h.lattice)
br = h.lattice.bravais.matrix
for (dn, row, col) in nonzero_indices(h)
if row != col
Expand All @@ -349,6 +349,28 @@ end

# External API #

"""
sitepositions(lat::AbstractLattice; kw...)
sitepositions(h::Hamiltonian; kw...)

Build a generator of the positions of sites in the lattice unitcell. Only sites specified
by `siteselector(kw...)` are selected, see `siteselector` for details.

"""
sitepositions(lat::AbstractLattice; kw...) = sitepositions(lat, siteselector(;kw...))
sitepositions(h::Hamiltonian; kw...) = sitepositions(h.lattice, siteselector(;kw...))

"""
siteindices(lat::AbstractLattice; kw...)
siteindices(lat::Hamiltonian; kw...)

Build a generator of the unique indices of sites in the lattice unitcell. Only sites
specified by `siteselector(kw...)` are selected, see `siteselector` for details.

"""
siteindices(lat::AbstractLattice; kw...) = siteindices(lat, siteselector(;kw...))
siteindices(h::Hamiltonian; kw...) = siteindices(h.lattice, siteselector(;kw...))

"""
transform!(f::Function, h::Hamiltonian)

Expand All @@ -361,8 +383,6 @@ function transform!(f, h::Hamiltonian)
end

# Indexing #
"""
"""
Base.push!(h::Hamiltonian{<:Any,L}, dn::NTuple{L,Int}) where {L} = push!(h, SVector(dn...))
Base.push!(h::Hamiltonian{<:Any,L}, dn::Vararg{Int,L}) where {L} = push!(h, SVector(dn...))
function Base.push!(h::Hamiltonian{<:Any,L}, dn::SVector{L,Int}) where {L}
Expand Down Expand Up @@ -559,45 +579,40 @@ end
applyterms!(builder, terms...) = foreach(term -> applyterm!(builder, term), terms)

applyterm!(builder::IJVBuilder, term::Union{OnsiteTerm, HoppingTerm}) =
applyterm!(builder, term, sublats(term, builder.lat))
applyterm!(builder, term)

function applyterm!(builder::IJVBuilder{L}, term::OnsiteTerm, termsublats) where {L}
selector = term.selector
function applyterm!(builder::IJVBuilder{L}, term::OnsiteTerm) where {L}
lat = builder.lat
for s in termsublats
is = siterange(lat, s)
dn0 = zero(SVector{L,Int})
ijv = builder[dn0]
offset = lat.unitcell.offsets[s]
for i in is
isinregion(i, dn0, selector.region, lat) || continue
r = lat.unitcell.sites[i]
v = toeltype(term(r, r), eltype(builder), builder.orbs[s], builder.orbs[s])
push!(ijv, (i, i, v))
end
dn0 = zero(SVector{L,Int})
ijv = builder[dn0]
allpos = allsitepositions(lat)
rsel = resolve(term.selector, lat)
for s in sublats(rsel), i in siteindices(rsel, s)
r = allpos[i]
v = toeltype(term(r, r), eltype(builder), builder.orbs[s], builder.orbs[s])
push!(ijv, (i, i, v))
end
return nothing
end

function applyterm!(builder::IJVBuilder{L}, term::HoppingTerm, termsublats) where {L}
selector = term.selector
L > 0 && checkinfinite(selector)
function applyterm!(builder::IJVBuilder{L}, term::HoppingTerm) where {L}
lat = builder.lat
for (s2, s1) in termsublats # Each is a Pair s2 => s1
is, js = siterange(lat, s1), siterange(lat, s2)
dns = dniter(selector.dns, Val(L))
rsel = resolve(term.selector, lat)
L > 0 && checkinfinite(rsel)
allpos = allsitepositions(lat)
for (s2, s1) in sublats(rsel) # Each is a Pair s2 => s1
dns = dniter(rsel)
for dn in dns
foundlink = false
ijv = builder[dn]
for j in js
sitej = lat.unitcell.sites[j]
for j in source_candidates(rsel, s2)
sitej = allpos[j]
rsource = sitej - lat.bravais.matrix * dn
itargets = targets(builder, selector.range, rsource, s1)
for i in itargets
isselfhopping((i, j), (s1, s2), dn) && continue
isinregion((i, j), (dn, zero(dn)), selector.region, lat) || continue
is = targets(builder, rsel.selector.range, rsource, s1)
for i in is
((i, j), (dn, zero(dn))) in rsel || continue
foundlink = true
rtarget = lat.unitcell.sites[i]
rtarget = allsitepositions(lat)[i]
r, dr = _rdr(rsource, rtarget)
v = toeltype(term(r, dr), eltype(builder), builder.orbs[s1], builder.orbs[s2])
push!(ijv, (i, j, v))
Expand Down Expand Up @@ -629,14 +644,11 @@ toeltype(t::SVector{N}, ::Type{S}, t1::NTuple{N}) where {N,S<:SVector} = padtoty
toeltype(t::Array, x...) = throw(ArgumentError("Array input in model, please use StaticArrays instead (e.g. SA[1 0; 0 1] instead of [1 0; 0 1])"))
toeltype(t, x...) = throw(DimensionMismatch("Dimension mismatch between model and Hamiltonian. Does the `orbitals` kwarg in your `hamiltonian` match your model?"))

dniter(dns::Missing, ::Val{L}) where {L} = BoxIterator(zero(SVector{L,Int}))
dniter(dns, ::Val) = dns

function targets(builder, range::Real, rsource, s1)
!isfinite(range) && return targets(builder, missing, rsource, s1)
if !isassigned(builder.kdtrees, s1)
sites = view(builder.lat.unitcell.sites, siterange(builder.lat, s1))
(builder.kdtrees[s1] = KDTree(sites))
sitepos = sitepositions(builder.lat.unitcell, s1)
(builder.kdtrees[s1] = KDTree(sitepos))
end
targetlist = inrange(builder.kdtrees[s1], rsource, range)
targetlist .+= builder.lat.unitcell.offsets[s1]
Expand All @@ -645,12 +657,10 @@ end

targets(builder, range::Missing, rsource, s1) = siterange(builder.lat, s1)

checkinfinite(selector) =
selector.dns === missing && (selector.range === missing || !isfinite(selector.range)) &&
checkinfinite(rs) =
rs.selector.dns === missing && (rs.selector.range === missing || !isfinite(rs.selector.range)) &&
throw(ErrorException("Tried to implement an infinite-range hopping on an unbounded lattice"))

isselfhopping((i, j), (s1, s2), dn) = i == j && s1 == s2 && iszero(dn)

#######################################################################
# Matrix(::KetModel, ::Hamiltonian), and Vector
#######################################################################
Expand Down Expand Up @@ -691,7 +701,7 @@ end

function Base.Matrix(kms, h::Hamiltonian; orthogonal = false)
M = orbitaltype(h)
kmat = Matrix{M}(undef, size(h, 2), length(kms))
kmat = zeros(M, size(h, 2), length(kms))
for (j, km) in enumerate(kms)
kvec = view(kmat, :, j)
ket!(kvec, km, h)
Expand All @@ -703,15 +713,15 @@ end
function ket!(k, km::KetModel, h)
M = eltype(k)
fill!(k, zero(M))
hsites = sites(h.lattice)
hsites = allsitepositions(h.lattice)
for term in km.model.terms
region = term.selector.region
ss = sublats(term, h.lattice)
rs = resolve(term.selector, h.lattice)
ss = sublats(rs)
for s in ss
orbs = h.orbitals[s]
is = siterange(h.lattice, s)
for i in is
isinregion(i, region, h.lattice) || continue
i in rs || continue
r = hsites[i]
k[i] += generate_amplitude(km, term, r, M, orbs)
end
Expand Down Expand Up @@ -799,15 +809,15 @@ wrap_dn(olddn::SVector, newdn::SVector, supercell::SMatrix) = olddn - supercell
applymodifiers(val, lat, inds, dns) = val

function applymodifiers(val, lat, inds, dns, m::UniformModifier, ms...)
selected = m.selector(lat, inds, dns)
selected = (inds, dns) in m.selector
val´ = selected ? m.f(val) : val
return applymodifiers(val´, lat, inds, dns, ms...)
end

function applymodifiers(val, lat, (row, col), (dnrow, dncol), m::OnsiteModifier, ms...)
selected = m.selector(lat, (row, col), (dnrow, dncol))
selected = ((row, col), (dnrow, dncol)) in m.selector
if selected
r = sites(lat)[col] + bravais(lat) * dncol
r = allsitepositions(lat)[col] + bravais(lat) * dncol
val´ = selected ? m(val, r) : val
else
val´ = val
Expand All @@ -816,10 +826,10 @@ function applymodifiers(val, lat, (row, col), (dnrow, dncol), m::OnsiteModifier,
end

function applymodifiers(val, lat, (row, col), (dnrow, dncol), m::HoppingModifier, ms...)
selected = m.selector(lat, (row, col), (dnrow, dncol))
selected = ((row, col), (dnrow, dncol)) in m.selector
if selected
br = bravais(lat)
r, dr = _rdr(sites(lat)[col] + br * dncol, sites(lat)[row] + br * dnrow)
r, dr = _rdr(allsitepositions(lat)[col] + br * dncol, allsitepositions(lat)[row] + br * dnrow)
val´ = selected ? m(val, r, dr) : val
else
val´ = val
Expand Down Expand Up @@ -1393,7 +1403,7 @@ function flatten(unitcell::Unitcell, norbs::NTuple{S,Int}) where {S}
ns´ = last(offsets´)
sites´ = similar(unitcell.sites, ns´)
i = 1
for sl in 1:S, site in sites(unitcell, sl), rep in 1:norbs[sl]
for sl in 1:S, site in sitepositions(unitcell, sl), rep in 1:norbs[sl]
sites´[i] = site
i += 1
end
Expand Down
28 changes: 8 additions & 20 deletions src/lattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,14 +179,12 @@ function uniquename(allnames, name, i)
return newname in allnames ? uniquename(allnames, name, i + 1) : newname
end

sites(u::Unitcell) = u.sites
sites(u::Unitcell, s::Int) = view(u.sites, siterange(u, s))

siteindex(u::Unitcell, sublat, idx) = idx + u.offsets[sublat]
sitepositions(u::Unitcell) = u.sites
sitepositions(u::Unitcell, s::Int) = view(u.sites, siterange(u, s))

siterange(u::Unitcell, sublat) = (1+u.offsets[sublat]):u.offsets[sublat+1]

enumeratesites(u::Unitcell, sublat) = ((i, sites(u)[i]) for i in siterange(u, sublat))
enumeratesites(u::Unitcell, sublat) = ((i, sitepositions(u)[i]) for i in siterange(u, sublat))

nsites(u::Unitcell) = length(u.sites)
nsites(u::Unitcell, sublat) = sublatsites(u)[sublat]
Expand Down Expand Up @@ -472,7 +470,7 @@ sublats(lat::AbstractLattice) = sublats(lat.unitcell)

siterange(lat::AbstractLattice, sublat) = siterange(lat.unitcell, sublat)

siteindex(lat::AbstractLattice, sublat, idx) = siteindex(lat.unitcell, sublat, idx)
allsitepositions(lat::AbstractLattice) = sitepositions(lat.unitcell)

offsets(lat::AbstractLattice) = offsets(lat.unitcell)

Expand Down Expand Up @@ -505,16 +503,6 @@ ismasked(lat::Superlattice) = ismasked(lat.supercell)
maskranges(lat::Superlattice) = (1:nsites(lat), lat.supercell.cells.indices...)
maskranges(lat::Lattice) = (1:nsites(lat),)

# External API #

"""
sites(lat[, sublat::Int])

Extract the positions of all sites in a lattice, or in a specific sublattice
"""
sites(lat::AbstractLattice) = sites(lat.unitcell)
sites(lat::AbstractLattice, s) = sites(lat.unitcell, s)

"""
transform!(f::Function, lat::Lattice)

Expand Down Expand Up @@ -694,7 +682,7 @@ function _supercell(lat::AbstractLattice{E,L}, scmatrix::SMatrix{L,L´,Int}, reg
continue
end
r0 = brmatrix * dnvec
for (i, site) in enumerate(lat.unitcell.sites)
for (i, site) in enumerate(allsitepositions(lat))
r = site + r0
mask[i, dntup...] = in_supercell && regionfunc(r)
end
Expand Down Expand Up @@ -737,7 +725,7 @@ function supercell_cells(lat::Lattice{E,L}, regionfunc, in_supercell_func, seed)
throw(ArgumentError("`region` seems unbounded (after $TOOMANYITERS iterations)"))
in_supercell = in_supercell_func(toSVector(Int, dn))
r0 = bravais * toSVector(Int, dn)
for site in lat.unitcell.sites
for site in allsitepositions(lat)
r = r0 + site
found = in_supercell && regionfunc(r)
if found || !foundfirst
Expand Down Expand Up @@ -873,8 +861,8 @@ function supercell_offsets(lat::Superlattice)
end

function supercell_sites(lat::Superlattice)
newsites = similar(lat.unitcell.sites, nsites(lat.supercell))
oldsites = lat.unitcell.sites
oldsites = allsitepositions(lat)
newsites = similar(oldsites, nsites(lat.supercell))
bravais = lat.bravais.matrix
foreach_supersite((s, oldi, dn, newi) -> newsites[newi] = bravais * dn + oldsites[oldi], lat)
return newsites
Expand Down
Loading