Skip to content

Commit

Permalink
Performance improvements (#62)
Browse files Browse the repository at this point in the history
* simplify and add threading

* add combination call for forces&virial
  • Loading branch information
tjjarvinen authored Nov 30, 2023
1 parent a3302c7 commit 9127bcf
Showing 1 changed file with 68 additions and 27 deletions.
95 changes: 68 additions & 27 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ for ace_method in [ :ace_energy, :ace_forces, :ace_virial, :ace_atom_energies ]
cutoff_unit=calc.cutoff_unit,
kwargs...
)
tmp = asyncmap( calc ) do V
$ace_method(V, at;
tmp = map( calc ) do V
Threads.@spawn $ace_method(V, at;
domain=domain,
executor=executor,
ntasks=ntasks,
Expand All @@ -89,7 +89,7 @@ for ace_method in [ :ace_energy, :ace_forces, :ace_virial, :ace_atom_energies ]
kwargs...
)
end
return sum( tmp )
return sum(fetch, tmp)
end
end
end
Expand Down Expand Up @@ -130,29 +130,18 @@ function ace_forces(
kwargs...
)
nlist = neighborlist(at, get_cutoff(V; cutoff_unit=cutoff_unit) )
F = Folds.sum( collect(chunks(domain, ntasks)), executor ) do (d, _)
ace_forces(V, at, nlist; domain=d)
end
return F * (energy_unit / length_unit)
end


function ace_forces(
V, at, nlist;
domain=1:length(at),
kwargs...
)
f = zeros(SVector{3, Float64}, length(at))
for i in domain
j, R, Z = neigsz(nlist, at, i)
_, tmp = ace_evaluate_d(V, R, Z, _atomic_number(at,i))

for k in eachindex(j)
f[j[k]] -= tmp.dV[k]
F = Folds.sum( collect(chunks(domain, ntasks)), executor ) do (sub_domain, _)
f = zeros(SVector{3, Float64}, length(at))
for i in sub_domain
j, R, Z = neigsz(nlist, at, i)
_, tmp = ace_evaluate_d(V, R, Z, _atomic_number(at,i))

f[j] -= tmp.dV
f[i] += sum(tmp.dV)
end
f[i] += sum(tmp.dV)
f
end
return f
return F * (energy_unit / length_unit)
end


Expand Down Expand Up @@ -232,9 +221,7 @@ end

function ace_energy_forces_virial(pot, data; kwargs...)
E = ace_energy(pot, data; kwargs...)
F = ace_forces(pot, data; kwargs...)
V = ace_virial(pot, data; kwargs...)
#return Dict("energy"=>E, "forces"=>F, "virial"=>V)
F, V = ace_forces_virial(pot, data; kwargs...)
return (; :energy=>E, :forces=>F, :virial=>V)
end

Expand All @@ -244,6 +231,60 @@ function ace_forces_virial(pot, data; kwargs...)
return (; :forces=>F, :virial=>V)
end

function ace_forces_virial(pot::ACEpotential, data; kwargs...)
tmp = map( pot.potentials ) do pot
Threads.@spawn _ace_forces_virial( pot, data, kwargs...)
end
F_V = sum( tmp ) do t
f, v = fetch(t)
[f, v]
end
return (; :forces=>F_V[1], :virial=>F_V[2] )
end


function _ace_forces_virial(::ACE1.OneBody, as::AbstractSystem; energy_unit=default_energy, length_unit=default_length, kwargs...)
T = eltype( ustrip.( position(as, 1) ) )
F = [ SVector{3}( zeros(T, 3) ) * (energy_unit / length_unit) for _ in 1:length(as) ]
V = SMatrix{3,3}(zeros(T, 3,3)) * energy_unit
return (; :forces => F, :virial => V)
end

function _ace_forces_virial(
V,
at;
domain=1:length(at),
executor=ThreadedEx(),
ntasks=Threads.nthreads(),
energy_unit=default_energy,
length_unit=default_length,
cutoff_unit=default_length,
kwargs...
)
nlist = neighborlist(at, get_cutoff(V; cutoff_unit=cutoff_unit) )
F_V = Folds.sum( collect(chunks(domain, ntasks)), executor ) do (sub_domain, _)
f = zeros(SVector{3, Float64}, length(at))
virial_sum = zeros(SMatrix{3,3,Float64})
for i in sub_domain
j, R, Z = neigsz(nlist, at, i)
_, tmp = ace_evaluate_d(V, R, Z, _atomic_number(at,i))

f[j] -= tmp.dV
f[i] += sum(tmp.dV)

virial_sum -= sum( zip(R, tmp.dV) ) do (Rⱼ, dVⱼ)
dVⱼ * Rⱼ'
end
end
[f, virial_sum]
end
return (;
:forces => F_V[1] * energy_unit / length_unit,
:virial => F_V[2] * energy_unit
)
end



## Individual atom energies

Expand Down

0 comments on commit 9127bcf

Please sign in to comment.