Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refinement of density and forces #881

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/DFTK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ include("response/chi0.jl")
include("response/hessian.jl")
export compute_current
include("postprocess/current.jl")
include("postprocess/refine.jl")

# Workarounds
include("workarounds/dummy_inplace_fft.jl")
Expand Down
129 changes: 129 additions & 0 deletions src/postprocess/refine.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Refinement of some quantities of interest (density, forces) following the
# strategy described in [CDKL2021].
#
# [CDKL2021]:
# E. Cancès, G. Dusson, G. Kemlin, and A. Levitt
# *Practical error bounds for properties in plane-wave electronic structure
# calculations* Preprint, 2021. [arXiv](https://arxiv.org/abs/2111.01470)

@kwdef struct PreRefinementOutputs
basis_ref
ψr
ρr
occupation
schur_residual
δρ
end

function refine_scfres(scfres, basis_ref::PlaneWaveBasis{T}; ΩpK_tol,

Check warning on line 18 in src/postprocess/refine.jl

View check run for this annotation

Codecov / codecov/patch

src/postprocess/refine.jl#L18

Added line #L18 was not covered by tests
occ_threshold=default_occupation_threshold(T), kwargs...) where {T}
basis = scfres.basis

Check warning on line 20 in src/postprocess/refine.jl

View check run for this annotation

Codecov / codecov/patch

src/postprocess/refine.jl#L20

Added line #L20 was not covered by tests

@assert basis.model.lattice == basis_ref.model.lattice
@assert length(basis.kpoints) == length(basis_ref.kpoints)
@assert all(basis.kpoints[ik].coordinate == basis_ref.kpoints[ik].coordinate

Check warning on line 24 in src/postprocess/refine.jl

View check run for this annotation

Codecov / codecov/patch

src/postprocess/refine.jl#L22-L24

Added lines #L22 - L24 were not covered by tests
for ik in 1:length(basis.kpoints))

haskey(scfres, :pre_refinement) && error() # TODO decide how to handle this...

Check warning on line 27 in src/postprocess/refine.jl

View check run for this annotation

Codecov / codecov/patch

src/postprocess/refine.jl#L27

Added line #L27 was not covered by tests

ψ, occ = select_occupied_orbitals(basis, scfres.ψ, scfres.occupation; threshold=occ_threshold)
ψr = transfer_blochwave(ψ, basis, basis_ref)
ρr = transfer_density(scfres.ρ, basis, basis_ref)
_, ham = energy_hamiltonian(basis, ψ, occ; ρ=scfres.ρ)
_, hamr = energy_hamiltonian(basis_ref, ψr, occ; ρ=ρr )

Check warning on line 33 in src/postprocess/refine.jl

View check run for this annotation

Codecov / codecov/patch

src/postprocess/refine.jl#L29-L33

Added lines #L29 - L33 were not covered by tests

# Compute the residual R(P) and remove the virtual orbitals, as required
# in src/scf/newton.jl

# TODO fix compute_projected_gradient and replace
res = [proj_tangent_kpt(hamr.blocks[ik] * ψk, ψk) for (ik, ψk) in enumerate(ψr)]

Check warning on line 39 in src/postprocess/refine.jl

View check run for this annotation

Codecov / codecov/patch

src/postprocess/refine.jl#L39

Added line #L39 was not covered by tests

# Compute M^{-1} R(P), with M^{-1} defined in [CDKL2021]
P = [PreconditionerTPA(basis_ref, kpt) for kpt in basis_ref.kpoints]
map(zip(P, ψr)) do (Pk, ψk)
precondprep!(Pk, ψk)

Check warning on line 44 in src/postprocess/refine.jl

View check run for this annotation

Codecov / codecov/patch

src/postprocess/refine.jl#L42-L44

Added lines #L42 - L44 were not covered by tests
end

function apply_M(φk, Pk, δφnk, n)
proj_tangent_kpt!(δφnk, φk)
δφnk = sqrt.(Pk.mean_kin[n] .+ Pk.kin) .* δφnk
proj_tangent_kpt!(δφnk, φk)
δφnk = sqrt.(Pk.mean_kin[n] .+ Pk.kin) .* δφnk
proj_tangent_kpt!(δφnk, φk)

Check warning on line 52 in src/postprocess/refine.jl

View check run for this annotation

Codecov / codecov/patch

src/postprocess/refine.jl#L47-L52

Added lines #L47 - L52 were not covered by tests
end

function apply_inv_M(φk, Pk, δφnk, n)
proj_tangent_kpt!(δφnk, φk)
op(x) = apply_M(φk, Pk, x, n)
function f_ldiv!(x, y)
x .= proj_tangent_kpt(y, φk)
x ./= (Pk.mean_kin[n] .+ Pk.kin)
proj_tangent_kpt!(x, φk)

Check warning on line 61 in src/postprocess/refine.jl

View check run for this annotation

Codecov / codecov/patch

src/postprocess/refine.jl#L55-L61

Added lines #L55 - L61 were not covered by tests
end
J = LinearMap{eltype(φk)}(op, size(δφnk, 1))
δφnk = IterativeSolvers.cg(J, δφnk, Pl=FunctionPreconditioner(f_ldiv!),

Check warning on line 64 in src/postprocess/refine.jl

View check run for this annotation

Codecov / codecov/patch

src/postprocess/refine.jl#L63-L64

Added lines #L63 - L64 were not covered by tests
verbose=false, reltol=0, abstol=1e-15)
proj_tangent_kpt!(δφnk, φk)

Check warning on line 66 in src/postprocess/refine.jl

View check run for this annotation

Codecov / codecov/patch

src/postprocess/refine.jl#L66

Added line #L66 was not covered by tests
end

function apply_metric(φ, P, δφ, A::Function)
map(enumerate(δφ)) do (ik, δφk)
Aδφk = similar(δφk)
φk = φ[ik]
for n = 1:size(δφk,2)
Aδφk[:,n] = A(φk, P[ik], δφk[:,n], n)

Check warning on line 74 in src/postprocess/refine.jl

View check run for this annotation

Codecov / codecov/patch

src/postprocess/refine.jl#L69-L74

Added lines #L69 - L74 were not covered by tests
end
Aδφk

Check warning on line 76 in src/postprocess/refine.jl

View check run for this annotation

Codecov / codecov/patch

src/postprocess/refine.jl#L76

Added line #L76 was not covered by tests
end
end

# Compute the projection of the residual onto the high and low frequencies
resLF = transfer_blochwave(res, basis_ref, basis)
resHF = res - transfer_blochwave(resLF, basis, basis_ref)

Check warning on line 82 in src/postprocess/refine.jl

View check run for this annotation

Codecov / codecov/patch

src/postprocess/refine.jl#L81-L82

Added lines #L81 - L82 were not covered by tests

# - Compute M^{-1}_22 R_2(P)
e2 = apply_metric(ψr, P, resHF, apply_inv_M)

Check warning on line 85 in src/postprocess/refine.jl

View check run for this annotation

Codecov / codecov/patch

src/postprocess/refine.jl#L85

Added line #L85 was not covered by tests

# Apply Ω+K to M^{-1}_22 R_2(P)
Λ = map(enumerate(ψr)) do (ik, ψk)
Hk = hamr.blocks[ik]
Hψk = Hk * ψk
ψk'Hψk

Check warning on line 91 in src/postprocess/refine.jl

View check run for this annotation

Codecov / codecov/patch

src/postprocess/refine.jl#L88-L91

Added lines #L88 - L91 were not covered by tests
end # Rayleigh coefficients
ΩpKe2 = apply_Ω(e2, ψr, hamr, Λ) .+ apply_K(basis_ref, e2, ψr, ρr, occ)
ΩpKe2 = transfer_blochwave(ΩpKe2, basis_ref, basis)

Check warning on line 94 in src/postprocess/refine.jl

View check run for this annotation

Codecov / codecov/patch

src/postprocess/refine.jl#L93-L94

Added lines #L93 - L94 were not covered by tests

rhs = resLF - ΩpKe2

Check warning on line 96 in src/postprocess/refine.jl

View check run for this annotation

Codecov / codecov/patch

src/postprocess/refine.jl#L96

Added line #L96 was not covered by tests

# Invert Ω+K on the small space: for now, only solve_ΩplusK_split is MPI-compatible:
#e1 = solve_ΩplusK(basis, ψ, rhs, occ; tol=ΩpK_tol).δψ
e1 = solve_ΩplusK_split(ham, scfres.ρ, ψ, occ, scfres.εF, scfres.eigenvalues, rhs;

Check warning on line 100 in src/postprocess/refine.jl

View check run for this annotation

Codecov / codecov/patch

src/postprocess/refine.jl#L100

Added line #L100 was not covered by tests
tol=ΩpK_tol, occupation_threshold=zero(eltype(first(occ))),
kwargs...).δψ

e1 = transfer_blochwave(e1, basis, basis_ref)
schur_residual = e1 + e2

Check warning on line 105 in src/postprocess/refine.jl

View check run for this annotation

Codecov / codecov/patch

src/postprocess/refine.jl#L104-L105

Added lines #L104 - L105 were not covered by tests

# Use the Schur residual to compute (minus) the first-order correction to
# the density.
δρ = compute_δρ(basis_ref, ψr, schur_residual, occ)

Check warning on line 109 in src/postprocess/refine.jl

View check run for this annotation

Codecov / codecov/patch

src/postprocess/refine.jl#L109

Added line #L109 was not covered by tests

merge(scfres, (pre_refinement = PreRefinementOutputs(; basis_ref, ψr, ρr, occupation=occ,

Check warning on line 111 in src/postprocess/refine.jl

View check run for this annotation

Codecov / codecov/patch

src/postprocess/refine.jl#L111

Added line #L111 was not covered by tests
schur_residual, δρ),))
end

function refine_density(scfres)
haskey(scfres, :pre_refinement) || error() # TODO decide...
scfres.pre_refinement.ρr - scfres.pre_refinement.δρ

Check warning on line 117 in src/postprocess/refine.jl

View check run for this annotation

Codecov / codecov/patch

src/postprocess/refine.jl#L115-L117

Added lines #L115 - L117 were not covered by tests
end

function refine_forces(scfres; forces=nothing)
haskey(scfres, :pre_refinement) || error() # TODO decide...
isnothing(forces) && (forces = compute_forces(scfres)) # TODO use DiffResults?
pre_ref = scfres.pre_refinement
dF = ForwardDiff.derivative(ε -> compute_forces(pre_ref.basis_ref,

Check warning on line 124 in src/postprocess/refine.jl

View check run for this annotation

Codecov / codecov/patch

src/postprocess/refine.jl#L120-L124

Added lines #L120 - L124 were not covered by tests
pre_ref.ψr.+ε.*pre_ref.schur_residual,
pre_ref.occupation;
ρ=pre_ref.ρr+ε.*pre_ref.δρ), 0)
forces - dF

Check warning on line 128 in src/postprocess/refine.jl

View check run for this annotation

Codecov / codecov/patch

src/postprocess/refine.jl#L128

Added line #L128 was not covered by tests
end
Loading