Skip to content

Commit

Permalink
Implement gather and scatter operations
Browse files Browse the repository at this point in the history
  • Loading branch information
Joroks authored Jun 23, 2024
1 parent 7201aec commit 212af20
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 10 deletions.
151 changes: 141 additions & 10 deletions src/LAMMPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import MPI
include("api.jl")

export LMP, command, get_natoms, extract_atom, extract_compute, extract_global,
gather_atoms
gather, scatter!

using Preferences

Expand Down Expand Up @@ -361,19 +361,150 @@ function extract_variable(lmp::LMP, name::String, group=nothing)
end
end

function gather_atoms(lmp::LMP, name, T, count)
if T === Int32
dtype = 0
elseif T === Float64
dtype = 1
@deprecate gather_atoms(lmp::LMP, name, T, count) gather(lmp, name, T)


"""
gather(lmp::LMP, name::String, T::Union{Type{Int32}, Type{Float64}}, ids::Union{Nothing, Array{Int32}}=nothing)
Gather the named per-atom, per-atom fix, per-atom compute, or fix property/atom-based entities from all processes.
By default (when `ids=nothing`), this method collects data from all atoms in consecutive order according to their IDs.
The optional parameter `ids` determines for which subset of atoms the requested data will be gathered. The returned data will then be ordered according to `ids`
Compute entities have the prefix `c_`, fix entities use the prefix `f_`, and per-atom entites have no prefix.
The returned Array is decoupled from the internal state of the LAMMPS instance.
!!! warning "Type Verification"
Due to how the underlying C-API works, it's not possible to verify the element data-type of fix or compute style data.
Supplying the wrong data-type will not throw an error but will result in nonsensical output
!!! warning "ids"
The optional parameter `ids` only works, if there is a map defined. For example by doing:
`command(lmp, "atom_modify map yes")`
However, LAMMPS only issues a warning if that's the case, which unfortuately cannot be detected through the underlying API.
Starting form LAMMPS version `17 Apr 2024` this should no longer be an issue, as LAMMPS then throws an error instead of a warning.
"""
function gather(lmp::LMP, name::String, T::Union{Type{Int32}, Type{Float64}}, ids::Union{Nothing, Array{Int32}}=nothing)
name == "mass" && error("scattering/gathering mass is currently not supported! Use `extract_atom()` instead.")

count = _get_count(lmp, name)
_T = _get_T(lmp, name)

@assert ismissing(_T) || _T == T "Expected data type $_T got $T instead."

dtype = (T === Float64)
natoms = get_natoms(lmp)
ndata = isnothing(ids) ? natoms : length(ids)
data = Matrix{T}(undef, (count, ndata))

if isnothing(ids)
API.lammps_gather(lmp, name, dtype, count, data)
else
error("Only Int32 or Float64 allowed as T, got $T")
@assert all(1 <= id <= natoms for id in ids)
API.lammps_gather_subset(lmp, name, dtype, count, ndata, ids, data)
end
natoms = get_natoms(lmp)
data = Array{T, 2}(undef, (count, natoms))
API.lammps_gather_atoms(lmp, name, dtype, count, data)

check(lmp)
return data
end

"""
scatter!(lmp::LMP, name::String, data::VecOrMat{T}, ids::Union{Nothing, Array{Int32}}=nothing) where T<:Union{Int32, Float64}
Scatter the named per-atom, per-atom fix, per-atom compute, or fix property/atom-based entity in data to all processes.
By default (when `ids=nothing`), this method scatters data to all atoms in consecutive order according to their IDs.
The optional parameter `ids` determines to which subset of atoms the data will be scattered.
Compute entities have the prefix `c_`, fix entities use the prefix `f_`, and per-atom entites have no prefix.
!!! warning "Type Verification"
Due to how the underlying C-API works, it's not possible to verify the element data-type of fix or compute style data.
Supplying the wrong data-type will not throw an error but will result in nonsensical date being supplied to the LAMMPS instance.
!!! warning "ids"
The optional parameter `ids` only works, if there is a map defined. For example by doing:
`command(lmp, "atom_modify map yes")`
However, LAMMPS only issues a warning if that's the case, which unfortuately cannot be detected through the underlying API.
Starting form LAMMPS version `17 Apr 2024` this should no longer be an issue, as LAMMPS then throws an error instead of a warning.
"""
function scatter!(lmp::LMP, name::String, data::VecOrMat{T}, ids::Union{Nothing, Array{Int32}}=nothing) where T<:Union{Int32, Float64}
name == "mass" && error("scattering/gathering mass is currently not supported! Use `extract_atom()` instead.")

count = _get_count(lmp, name)
_T = _get_T(lmp, name)

@assert ismissing(_T) || _T == T "Expected data type $_T got $T instead."

dtype = (T === Float64)
natoms = get_natoms(lmp)
ndata = isnothing(ids) ? natoms : length(ids)

if data isa Vector
@assert count == 1
@assert ndata == lenght(data)
else
@assert count == size(data,1)
@assert ndata == size(data,2)
end

if isnothing(ids)
API.lammps_scatter(lmp, name, dtype, count, data)
else
@assert all(1 <= id <= natoms for id in ids)
API.lammps_scatter_subset(lmp, name, dtype, count, ndata, ids, data)
end

check(lmp)
end

function _get_count(lmp::LMP, name::String)
# values taken from: https://docs.lammps.org/Classes_atom.html#_CPPv4N9LAMMPS_NS4Atom7extractEPKc

if startswith(name, r"[f,c]_")
if name[1] == 'c'
API.lammps_has_id(lmp, "compute", name[3:end]) != 1 && error("Unknown per atom compute $name")

count_ptr = API.lammps_extract_compute(lmp::LMP, name[3:end], API.LMP_STYLE_ATOM, API.LMP_SIZE_COLS)
else
API.lammps_has_id(lmp, "fix", name[3:end]) != 1 && error("Unknown per atom fix $name")

count_ptr = API.lammps_extract_fix(lmp::LMP, name[3:end], API.LMP_STYLE_ATOM, API.LMP_SIZE_COLS, 0, 0)
end
check(lmp)

count_ptr = reinterpret(Ptr{Cint}, count_ptr)
count = unsafe_load(count_ptr)

# a count of 0 indicates that the entity is a vector. In order to perserve type stability we just treat that as a 1xN Matrix.
return count == 0 ? 1 : count
elseif name in ("mass", "id", "type", "mask", "image", "molecule", "q", "radius", "rmass", "ellipsoid", "line", "tri", "body", "temperature", "heatflow")
return 1
elseif name in ("x", "v", "f", "mu", "omega", "angmom", "torque")
return 3
elseif name == "quat"
return 4
else
error("Unknown per atom property $name")
end
end

function _get_T(lmp::LMP, name::String)
if startswith(name, r"[f,c]_")
return missing # As far as I know, it's not possible to determine the datatype of computes or fixes at runtime
end

type = API.lammps_extract_atom_datatype(lmp, name)
check(lmp)

if type in (API.LAMMPS_INT, API.LAMMPS_INT_2D)
return Int32
elseif type in (API.LAMMPS_DOUBLE, API.LAMMPS_DOUBLE_2D)
return Float64
else
error("Unkown per atom property $name")
end

end

end # module
60 changes: 60 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,64 @@ end
end
end

@testset "gather/scatter" begin
LMP(["-screen", "none"]) do lmp
# setting up example data
command(lmp, "atom_modify map yes")
command(lmp, "region cell block 0 3 0 3 0 3")
command(lmp, "create_box 1 cell")
command(lmp, "lattice sc 1")
command(lmp, "create_atoms 1 region cell")
command(lmp, "mass 1 1")

command(lmp, "compute pos all property/atom x y z")
command(lmp, "fix pos all ave/atom 10 1 10 c_pos[1] c_pos[2] c_pos[3]")

command(lmp, "run 10")
data = zeros(Float64, 3, 27)
subset = Int32.([2,5,10, 5])
data_subset = ones(Float64, 3, 4)

subset_bad1 = Int32.([28])
subset_bad2 = Int32.([0])
subset_bad_data = ones(Float64, 3,1)

@test_throws AssertionError gather(lmp, "x", Int32)
@test_throws AssertionError gather(lmp, "id", Float64)

@test_throws ErrorException gather(lmp, "nonesense", Float64)
@test_throws ErrorException gather(lmp, "c_nonsense", Float64)
@test_throws ErrorException gather(lmp, "f_nonesense", Float64)

@test_throws AssertionError gather(lmp, "x", Float64, subset_bad1)
@test_throws AssertionError gather(lmp, "x", Float64, subset_bad2)

@test_throws ErrorException scatter!(lmp, "nonesense", data)
@test_throws ErrorException scatter!(lmp, "c_nonsense", data)
@test_throws ErrorException scatter!(lmp, "f_nonesense", data)

@test_throws AssertionError scatter!(lmp, "x", subset_bad_data, subset_bad1)
@test_throws AssertionError scatter!(lmp, "x", subset_bad_data, subset_bad2)

@test gather(lmp, "x", Float64) == gather(lmp, "c_pos", Float64) == gather(lmp, "f_pos", Float64)

@test gather(lmp, "x", Float64)[:,subset] == gather(lmp, "x", Float64, subset)
@test gather(lmp, "c_pos", Float64)[:,subset] == gather(lmp, "c_pos", Float64, subset)
@test gather(lmp, "f_pos", Float64)[:,subset] == gather(lmp, "f_pos", Float64, subset)

scatter!(lmp, "x", data)
scatter!(lmp, "f_pos", data)
scatter!(lmp, "c_pos", data)

@test gather(lmp, "x", Float64) == gather(lmp, "c_pos", Float64) == gather(lmp, "f_pos", Float64) == data

scatter!(lmp, "x", data_subset, subset)
scatter!(lmp, "c_pos", data_subset, subset)
scatter!(lmp, "f_pos", data_subset, subset)

@test gather(lmp, "x", Float64, subset) == gather(lmp, "c_pos", Float64, subset) == gather(lmp, "f_pos", Float64, subset) == data_subset

end
end

@test success(pipeline(`$(MPI.mpiexec()) -n 2 $(Base.julia_cmd()) mpitest.jl`, stderr=stderr, stdout=stdout))

0 comments on commit 212af20

Please sign in to comment.