Skip to content

Commit

Permalink
s'more namedtuples
Browse files Browse the repository at this point in the history
  • Loading branch information
epolack committed Dec 7, 2023
1 parent 44782be commit 5969b9f
Show file tree
Hide file tree
Showing 27 changed files with 69 additions and 65 deletions.
6 changes: 3 additions & 3 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ DFTKREPO = DFTKGH * ".git"
# Setup julia dependencies for docs generation if not yet done
Pkg.activate(@__DIR__)
if !isfile(joinpath(@__DIR__, "Manifest.toml"))
Pkg.develop(Pkg.PackageSpec(path=ROOTPATH))
Pkg.develop(Pkg.PackageSpec(; path=ROOTPATH))
Pkg.instantiate()
end

Expand Down Expand Up @@ -152,9 +152,9 @@ end
# The examples go to docs/literate_build/examples, the .jl files stay where they are
literate_files = map(filter!(endswith(".jl"), extract_paths(PAGES))) do file
if startswith(file, "examples/")
(src=joinpath(ROOTPATH, file), dest=joinpath(SRCPATH, "examples"), example=true)
(; src=joinpath(ROOTPATH, file), dest=joinpath(SRCPATH, "examples"), example=true)
else
(src=joinpath(SRCPATH, file), dest=joinpath(SRCPATH, dirname(file)), example=false)
(; src=joinpath(SRCPATH, file), dest=joinpath(SRCPATH, dirname(file)), example=false)
end
end

Expand Down
4 changes: 2 additions & 2 deletions docs/src/assets/0_pregenerate.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using MKL
using DFTK
using LinearAlgebra
setup_threading(n_blas=2)
setup_threading(; n_blas=2)

let
include("../../../../examples/convergence_study.jl")
Expand Down Expand Up @@ -61,7 +61,7 @@ let
conv_hgh = converge_Ecut(Ecuts, psp_hgh, tol)
println("HGH: $(conv_hgh.Ecut_conv)")

plt = plot(yaxis=:log10, xlabel="Ecut [Eh]", ylabel="Error [Eh]")
plt = plot(; yaxis=:log10, xlabel="Ecut [Eh]", ylabel="Error [Eh]")
plot!(plt, conv_hgh.Ecuts, conv_hgh.errors, label="HGH",
markers=true, linewidth=3)
plot!(plt, conv_upf.Ecuts, conv_upf.errors, label="PseudoDojo NC SR LDA UPF",
Expand Down
2 changes: 1 addition & 1 deletion examples/geometry_optimization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ end;

x0 = vcat(lattice \ [0., 0., 0.], lattice \ [1.4, 0., 0.])
xres = optimize(Optim.only_fg!(fg!), x0, LBFGS(),
Optim.Options(show_trace=true, f_tol=tol))
Optim.Options(; show_trace=true, f_tol=tol))
xmin = Optim.minimizer(xres)
dmin = norm(lattice*xmin[1:3] - lattice*xmin[4:6])
@printf "\nOptimal bond length for Ecut=%.2f: %.3f Bohr\n" Ecut dmin
Expand Down
2 changes: 1 addition & 1 deletion examples/scf_callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ basis = PlaneWaveBasis(model; Ecut=5, kgrid=[3, 3, 3]);
# has finished. For this we first define the empty plot canvas
# and an empty container for all the density differences:
using Plots
p = plot(yaxis=:log)
p = plot(; yaxis=:log)
density_differences = Float64[];

# The callback function itself gets passed a named tuple
Expand Down
2 changes: 1 addition & 1 deletion src/common/threading.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import FFTW
using LinearAlgebra

function setup_threading(;n_fft=1, n_blas=Threads.nthreads())
function setup_threading(; n_fft=1, n_blas=Threads.nthreads())
n_julia = Threads.nthreads()
FFTW.set_num_threads(n_fft)
BLAS.set_num_threads(n_blas)
Expand Down
18 changes: 9 additions & 9 deletions src/elements.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,17 @@ or an element name (e.g. `"silicon"`)
function ElementCohenBergstresser(key; lattice_constant=nothing)
# Form factors from Cohen-Bergstresser paper Table 2
# Lattice constants from Table 1
data = Dict(:Si => (form_factors=Dict( 3 => -0.21u"Ry",
8 => 0.04u"Ry",
11 => 0.08u"Ry"),
data = Dict(:Si => (; form_factors=Dict( 3 => -0.21u"Ry",
8 => 0.04u"Ry",
11 => 0.08u"Ry"),
lattice_constant=5.43u"Å"),
:Ge => (form_factors=Dict( 3 => -0.23u"Ry",
8 => 0.01u"Ry",
11 => 0.06u"Ry"),
:Ge => (; form_factors=Dict( 3 => -0.23u"Ry",
8 => 0.01u"Ry",
11 => 0.06u"Ry"),
lattice_constant=5.66u"Å"),
:Sn => (form_factors=Dict( 3 => -0.20u"Ry",
8 => 0.00u"Ry",
11 => 0.04u"Ry"),
:Sn => (; form_factors=Dict( 3 => -0.20u"Ry",
8 => 0.00u"Ry",
11 => 0.04u"Ry"),
lattice_constant=6.49u"Å"),
)

Expand Down
2 changes: 1 addition & 1 deletion src/external/wannier90.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ function read_w90_nnkp(fileprefix::String)
# 1st: Index of k-point
# 2nd: Index of periodic image of k+b k-point
# 3rd: Shift vector to get k-point of ikpb to the actual k+b point required
(ik=splitted[1], ikpb=splitted[2], G_shift=splitted[3:5])
(; ik=splitted[1], ikpb=splitted[2], G_shift=splitted[3:5])
end
(; nntot, nnkpts)
end
Expand Down
6 changes: 3 additions & 3 deletions src/plotting.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# This is needed to flag that the plots-dependent code has been loaded
const PLOTS_LOADED = true

function ScfPlotTrace(plt=Plots.plot(yaxis=:log); kwargs...)
function ScfPlotTrace(plt=Plots.plot(; yaxis=:log); kwargs...)
energies = nothing
function callback(info)
if info.stage == :finalize
minenergy = minimum(energies[max(1, end-5):end])
error = abs.(energies .- minenergy)
error[error .== 0] .= NaN
extra = ifelse(:mark in keys(kwargs), (), (mark=:x, ))
extra = ifelse(:mark in keys(kwargs), (), (; mark=:x))
Plots.plot!(plt, error; extra..., kwargs...)
display(plt)
elseif info.n_iter == 1
Expand Down Expand Up @@ -42,7 +42,7 @@ function plot_band_data(kpath::KPathInterpolant, band_data;
to_unit = ustrip(auconvert(unit, 1.0))

# Plot all bands, spins and errors
p = Plots.plot(xlabel="wave vector")
p = Plots.plot(; xlabel="wave vector")
margs = length(kpath) < 70 ? (; markersize=2, markershape=:circle) : (; )
for σ in 1:data.n_spin, iband = 1:data.n_bands, branch in data.kbranches
yerror = nothing
Expand Down
2 changes: 1 addition & 1 deletion src/postprocess/band_structure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ function kdistances_and_ticks(kcoords, klabels::Dict, kbranches)
end
end
end
ticks = (distances=tick_distances, labels=tick_labels)
ticks = (; distances=tick_distances, labels=tick_labels)
(; kdistances, ticks)
end

Expand Down
12 changes: 6 additions & 6 deletions src/pseudo/list_psp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ to restrict the displayed files.
# Examples
```julia-repl
julia> list_psp(family="hgh")
julia> list_psp(; family="hgh")
```
will list all HGH-type pseudopotentials and
```julia-repl
julia> list_psp(family="hgh", functional="lda")
julia> list_psp(; family="hgh", functional="lda")
```
will only list those for LDA (also known as Pade in this context)
and
Expand Down Expand Up @@ -47,7 +47,7 @@ function list_psp(element=nothing; family=nothing, functional=nothing, core=noth

f_identifier = joinpath(root, file)
Sys.iswindows() && (f_identifier = replace(f_identifier, "\\" => "/"))
push!(psplist, (identifier=f_identifier, family=f_family,
push!(psplist, (; identifier=f_identifier, family=f_family,
functional=f_functional, element=f_element,
n_elec_valence=parse(Int, f_nvalence[2:end]),
path=joinpath(datadir_psp(), root, file)))
Expand All @@ -59,10 +59,10 @@ function list_psp(element=nothing; family=nothing, functional=nothing, core=noth
psp_per_element = map(per_elem) do elgroup
@assert length(elgroup) > 0
if length(elgroup) == 1
cores = [(core=:fullcore, )]
cores = [(; core=:fullcore)]
else
cores = append!(fill((core=:other, ), length(elgroup) - 2),
[(core=:fullcore, ), (core=:semicore, )])
cores = append!(fill((; core=:other), length(elgroup) - 2),
[(; core=:fullcore), (; core=:semicore)])
end
merge.(sort(elgroup, by=psp -> psp.n_elec_valence), cores)
end
Expand Down
2 changes: 1 addition & 1 deletion src/scf/mixing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ Important `kwargs` passed on to [`χ0Mixing`](@ref)
- `verbose`: Run the GMRES in verbose mode.
- `reltol`: Relative tolerance for GMRES
"""
function HybridMixing(;εr=1.0, kTF=0.8, localization=identity,
function HybridMixing(; εr=1.0, kTF=0.8, localization=identity,
adjust_temperature=IncreaseMixingTemperature(), kwargs...)
χ0terms = [DielectricModel(; εr, kTF, localization),
LdosModel(;adjust_temperature)]
Expand Down
4 changes: 2 additions & 2 deletions src/scf/scf_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ function CROP(f, x0, m::Int, max_iter::Int, tol::Real, warming=0)
# Cheat support for multidimensional arrays
if length(size(x0)) != 1
x, conv= CROP(x -> vec(f(reshape(x, size(x0)...))), vec(x0), m, max_iter, tol, warming)
return (fixpoint=reshape(x, size(x0)...), converged=conv)
return (; fixpoint=reshape(x, size(x0)...), converged=conv)
end
N = size(x0,1)
T = eltype(x0)
Expand Down Expand Up @@ -110,6 +110,6 @@ function CROP(f, x0, m::Int, max_iter::Int, tol::Real, warming=0)
fs[:,1] = ftnp1
# fs[:,1] = f(xs[:,1])
end
(fixpoint=xs[:, 1], converged=err < tol)
(; fixpoint=xs[:, 1], converged=err < tol)
end
scf_CROP_solver(m=10) = (f, x0, max_iter; tol=1e-6) -> CROP(x -> f(x) - x, x0, m, max_iter, tol)
2 changes: 1 addition & 1 deletion src/scf/self_consistent_field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ Overview of parameters:
# Diagonalize `ham` to get the new state
nextstate = next_density(ham, nbandsalg, fermialg; eigensolver, ψ, eigenvalues,
occupation, miniter=1, tol=determine_diagtol(info))
ψ, eigenvalues, occupation, εF, ρout = nextstate
(; ψ, eigenvalues, occupation, εF, ρout) = nextstate

# Update info with results gathered so far
info = (; ham, basis, converged, stage=:iterate, algorithm="SCF",
Expand Down
12 changes: 7 additions & 5 deletions src/terms/Hamiltonian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ Base.:*(H::Hamiltonian, ψ) = mul!(deepcopy(ψ), H, ψ)
ifft!(ψ_real, H.basis, H.kpoint, ψ[:, iband])
for op in H.optimized_operators
@timing "$(nameof(typeof(op)))" begin
apply!((fourier=Hψ_fourier, real=Hψ_real),
apply!((; fourier=Hψ_fourier, real=Hψ_real),
op,
(fourier=ψ[:, iband], real=ψ_real))
(; fourier=ψ[:, iband], real=ψ_real))
end
end
Hψ[:, iband] .= Hψ_fourier
Expand Down Expand Up @@ -144,9 +144,9 @@ end

if have_divAgrad
@timeit to "divAgrad" begin
apply!((fourier=Hψ[:, iband], real=nothing),
apply!((; fourier=Hψ[:, iband], real=nothing),
H.divAgrad_op,
(fourier=ψ[:, iband], real=nothing),
(; fourier=ψ[:, iband], real=nothing),
ψ_real) # ψ_real used as scratch
end
end
Expand All @@ -162,7 +162,9 @@ end
# Apply the nonlocal operator
if !isnothing(H.nonlocal_op)
@timing "nonlocal" begin
apply!((fourier=Hψ, real=nothing), H.nonlocal_op, (fourier=ψ, real=nothing))
apply!((; fourier=Hψ, real=nothing),
H.nonlocal_op,
(; fourier=ψ, real=nothing))
end
end

Expand Down
4 changes: 3 additions & 1 deletion src/terms/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ function LinearAlgebra.mul!(Hψ::AbstractVector, op::RealFourierOperator, ψ::Ab
Hψ_real = similar(ψ_real)
Hψ_fourier .= 0
Hψ_real .= 0
apply!((real=Hψ_real, fourier=Hψ_fourier), op, (real=ψ_real, fourier=ψ))
apply!((; real=Hψ_real, fourier=Hψ_fourier),
op,
(; real=ψ_real, fourier=ψ))
Hψ .= Hψ_fourier .+ fft(op.basis, op.kpoint, Hψ_real)
end
Expand Down
2 changes: 1 addition & 1 deletion src/terms/pairwise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ struct TermPairwisePotential{TV, Tparams, T} <:Term
end

function ene_ops(term::TermPairwisePotential, basis::PlaneWaveBasis, ψ, occupation; kwargs...)
(E=term.energy, ops=[NoopOperator(basis, kpt) for kpt in basis.kpoints])
(; E=term.energy, ops=[NoopOperator(basis, kpt) for kpt in basis.kpoints])
end
compute_forces(term::TermPairwisePotential, ::PlaneWaveBasis, ψ, occ; kwargs...) = term.forces

Expand Down
2 changes: 1 addition & 1 deletion src/terms/psp_correction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function TermPspCorrection(basis::PlaneWaveBasis)
end

function ene_ops(term::TermPspCorrection, basis::PlaneWaveBasis, ψ, occupation; kwargs...)
(E=term.energy, ops=[NoopOperator(basis, kpt) for kpt in basis.kpoints])
(; E=term.energy, ops=[NoopOperator(basis, kpt) for kpt in basis.kpoints])
end

"""
Expand Down
6 changes: 3 additions & 3 deletions src/terms/terms.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
include("operators.jl")

### Terms
# - A Term is something that, given a state, returns a named tuple (E, hams) with an energy
# - A Term is something that, given a state, returns a named tuple (; E, hams) with an energy
# and a list of RealFourierOperator (for each kpoint).
# - Each term must overload
# `ene_ops(term, basis, ψ, occupation; kwargs...)`
# -> (E::Real, ops::Vector{RealFourierOperator}).
# -> (; E::Real, ops::Vector{RealFourierOperator}).
# - Note that terms are allowed to hold on to references to ψ (eg Fock term),
# so ψ should not mutated after ene_ops

Expand All @@ -27,7 +27,7 @@ A term with a constant zero energy.
"""
struct TermNoop <: Term end
function ene_ops(term::TermNoop, basis::PlaneWaveBasis{T}, ψ, occupation; kwargs...) where {T}
(E=zero(eltype(T)), ops=[NoopOperator(basis, kpt) for kpt in basis.kpoints])
(; E=zero(eltype(T)), ops=[NoopOperator(basis, kpt) for kpt in basis.kpoints])
end

include("Hamiltonian.jl")
Expand Down
4 changes: 2 additions & 2 deletions test/PspUpf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ upf_pseudos = Dict(
:Cr => load_psp(artifact"pd_nc_sr_pbe_standard_0.4.1_upf", "Cu.upf"; rcut=12.0)
)
hgh_pseudos = [
(hgh=load_psp("hgh/pbe/si-q4.hgh"), upf=upf_pseudos[:Si]),
(hgh=load_psp("hgh/pbe/tl-q13.hgh"), upf=upf_pseudos[:Tl])
(; hgh=load_psp("hgh/pbe/si-q4.hgh"), upf=upf_pseudos[:Si]),
(; hgh=load_psp("hgh/pbe/tl-q13.hgh"), upf=upf_pseudos[:Tl])
]
end

Expand Down
4 changes: 2 additions & 2 deletions test/aqua.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@
Aqua.test_all(DFTK;
ambiguities=false,
piracies=false,
deps_compat=(check_extras=false, ),
stale_deps=(ignore=[:Primes, ], ))
deps_compat=(; check_extras=false),
stale_deps=(; ignore=[:Primes, ]))
end
10 changes: 5 additions & 5 deletions test/bzmesh_symmetry.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
using LinearAlgebra
testcase = TestCases.silicon

args = ((kgrid=[2, 2, 2], kshift=[1/2, 0, 0]),
(kgrid=[2, 2, 2], kshift=[1/2, 1/2, 0]),
(kgrid=[2, 2, 2], kshift=[0, 0, 0]),
(kgrid=[3, 2, 3], kshift=[0, 0, 0]),
(kgrid=[3, 2, 3], kshift=[0, 1/2, 1/2]))
args = ((; kgrid=[2, 2, 2], kshift=[1/2, 0, 0]),
(; kgrid=[2, 2, 2], kshift=[1/2, 1/2, 0]),
(; kgrid=[2, 2, 2], kshift=[0, 0, 0]),
(; kgrid=[3, 2, 3], kshift=[0, 0, 0]),
(; kgrid=[3, 2, 3], kshift=[0, 1/2, 1/2]))
for case in args
model_nosym = model_LDA(testcase.lattice, testcase.atoms, testcase.positions;
symmetries=false)
Expand Down
2 changes: 1 addition & 1 deletion test/chi0.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ function test_chi0(testcase; symmetries=false, temperature=0, spin_polarization=
ham0 = energy_hamiltonian(basis, nothing, nothing; ρ=ρ0).ham
nbandsalg = is_εF_fixed ? FixedBands(; n_bands_converge=6) : AdaptiveBands(model)
res = DFTK.next_density(ham0, nbandsalg; tol, eigensolver)
scfres = (ham=ham0, res...)
scfres = (; ham=ham0, res...)

# create external small perturbation εδV
n_spin = model.n_spin_components
Expand Down
2 changes: 1 addition & 1 deletion test/compute_density.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
kwargs = ()
n_bands = div(testcase.n_electrons, 2, RoundUp)
if testcase.temperature !== nothing
kwargs = (temperature=testcase.temperature, smearing=DFTK.Smearing.FermiDirac())
kwargs = (; testcase.temperature, smearing=DFTK.Smearing.FermiDirac())
n_bands = div(testcase.n_electrons, 2, RoundUp) + 4
end

Expand Down
8 changes: 4 additions & 4 deletions test/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
end
basis = PlaneWaveBasis(model; Ecut=5, kgrid=[2, 2, 2], kshift=[0, 0, 0])

response = ResponseOptions(verbose=true)
response = ResponseOptions(; verbose=true)
is_converged = DFTK.ScfConvergenceForce(tol)
scfres = self_consistent_field(basis; is_converged, response)
compute_forces_cart(scfres)
Expand Down Expand Up @@ -78,7 +78,7 @@ end
is_converged = DFTK.ScfConvergenceDensity(1e-10)
scfres = self_consistent_field(basis; is_converged, mixing=KerkerMixing(),
nbandsalg=FixedBands(; n_bands_converge=10),
damping=0.6, response=ResponseOptions(verbose=true))
damping=0.6, response=ResponseOptions(; verbose=true))

ComponentArray(
eigenvalues=hcat([ev[1:10] for ev in scfres.eigenvalues]...),
Expand Down Expand Up @@ -116,7 +116,7 @@ end

is_converged = DFTK.ScfConvergenceDensity(1e-10)
scfres = self_consistent_field(basis; is_converged,
response=ResponseOptions(verbose=true))
response=ResponseOptions(; verbose=true))
compute_forces_cart(scfres)
end

Expand Down Expand Up @@ -161,7 +161,7 @@ end
ρ = zeros(Float64, basis.fft_size..., 1)
is_converged = DFTK.ScfConvergenceDensity(1e-10)
scfres = self_consistent_field(basis; ρ, is_converged,
response=ResponseOptions(verbose=true))
response=ResponseOptions(; verbose=true))
compute_forces_cart(scfres)
end
derivative_ε = let ε = 1e-5
Expand Down
2 changes: 1 addition & 1 deletion test/helium_all_electron.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
scfres.energies.total, DFTK.compute_forces(scfres)
end

E, forces = energy_forces(Ecut=5, tol=1e-10)
E, forces = energy_forces(; Ecut=5, tol=1e-10)
@test E -1.5869009433016852 atol=1e-12
@test norm(forces) < 1e-7
end
Loading

0 comments on commit 5969b9f

Please sign in to comment.