Skip to content

Commit

Permalink
Merge pull request #409 from atomistic-machine-learning/mg/nbl_md
Browse files Browse the repository at this point in the history
Added routine to NeighborListMD to properly filter out pairs due to the buffer region
  • Loading branch information
mgastegger authored May 13, 2022
2 parents c2d9816 + 1489415 commit b83f0a1
Showing 1 changed file with 26 additions and 15 deletions.
41 changes: 26 additions & 15 deletions src/schnetpack/md/neighborlist_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def get_neighbors(self, inputs: Dict[str, torch.Tensor]):
Returns:
torch.tensor: indices of neighbors.
"""
# TODO: check if this is better or building Rij after the full indices have been generated
# TODO: check consistent wrapping
atom_types = inputs[properties.Z]
positions = inputs[properties.R]
n_atoms = inputs[properties.n_atoms]
Expand Down Expand Up @@ -153,29 +153,40 @@ def get_neighbors(self, inputs: Dict[str, torch.Tensor]):
# Move everything to correct device
neighbor_idx = {p: neighbor_idx[p].to(positions.device) for p in neighbor_idx}

# filter out all pairs in the buffer zone
neighbor_idx = self._filter_indices(positions, neighbor_idx)

return neighbor_idx

def _update_Rij(self, inputs: Dict[str, torch.tensor], mol_idx: torch.tensor):
def _filter_indices(
self, positions: torch.Tensor, neighbor_idx: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
"""
Update the interatomic distances.
Routine for filtering out pair indices and offets due to the buffer region, which would otherwise slow down
the calculators.
Args:
inputs (dict(str, torch.Tensor)): Input batch.
mol_idx (torch.Tensor): Molecule indices
"""
R = inputs[properties.R]
idx_i = self.molecular_indices[mol_idx][properties.idx_i]
idx_j = self.molecular_indices[mol_idx][properties.idx_j]
positions (torch.Tensor): Tensor of the Cartesian atom positions.
neighbor_idx (dict(str, torch.Tensor)): Dictionary containing pair indices and offets
new_Rij = R[idx_j] - R[idx_i]
Returns:
dict(str, torch.Tensor): Dictionary containing updated pair indices and offets
"""
offsets = neighbor_idx[properties.offsets]
idx_i = neighbor_idx[properties.idx_i]
idx_j = neighbor_idx[properties.idx_j]

cell = inputs[properties.cell]
Rij = positions[idx_j] - positions[idx_i] + offsets
d_ij = torch.linalg.norm(Rij, dim=1)
d_ij_filter = d_ij <= self.cutoff

if cell is not None:
offsets = self.molecular_indices[mol_idx][properties.offsets].to(cell.dtype)
new_Rij = new_Rij + offsets.mm(cell)
neighbor_idx[properties.idx_i] = neighbor_idx[properties.idx_i][d_ij_filter]
neighbor_idx[properties.idx_j] = neighbor_idx[properties.idx_j][d_ij_filter]
neighbor_idx[properties.offsets] = neighbor_idx[properties.offsets][
d_ij_filter, :
]

self.molecular_indices[mol_idx][properties.Rij] = new_Rij
return neighbor_idx

@staticmethod
def _split_batch(
Expand Down

0 comments on commit b83f0a1

Please sign in to comment.