Skip to content

Commit

Permalink
Simplify pair computation on AEV (#519)
Browse files Browse the repository at this point in the history
* Simplify pair computation on AEV

* save

Co-authored-by: Farhad Ramezanghorbani <farhadrgh@users.noreply.github.com>
  • Loading branch information
zasdfgbnm and farhadrgh authored Nov 2, 2020
1 parent d0ab8a8 commit 29606ef
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions torchani/aev.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def neighbor_pairs(padding_mask: Tensor, coordinates: Tensor, cell: Tensor,
cutoff (float): the cutoff inside which atoms are considered pairs
shifts (:class:`torch.Tensor`): tensor of shape (?, 3) storing shifts
"""
coordinates = coordinates.detach()
coordinates = coordinates.detach().masked_fill(padding_mask.unsqueeze(-1), math.nan)
cell = cell.detach()
num_atoms = padding_mask.shape[1]
num_mols = padding_mask.shape[0]
Expand Down Expand Up @@ -165,8 +165,6 @@ def neighbor_pairs(padding_mask: Tensor, coordinates: Tensor, cell: Tensor,
# step 5, compute distances, and find all pairs within cutoff
selected_coordinates = coordinates.index_select(1, p12_all.view(-1)).view(num_mols, 2, -1, 3)
distances = (selected_coordinates[:, 0, ...] - selected_coordinates[:, 1, ...] + shift_values).norm(2, -1)
padding_mask = padding_mask.index_select(1, p12_all.view(-1)).view(2, -1).any(0)
distances.masked_fill_(padding_mask, math.inf)
in_cutoff = (distances <= cutoff).nonzero()
molecule_index, pair_index = in_cutoff.unbind(1)
molecule_index *= num_atoms
Expand All @@ -188,7 +186,7 @@ def neighbor_pairs_nopbc(padding_mask: Tensor, coordinates: Tensor, cutoff: floa
(molecules, atoms, 3) for atom coordinates.
cutoff (float): the cutoff inside which atoms are considered pairs
"""
coordinates = coordinates.detach()
coordinates = coordinates.detach().masked_fill(padding_mask.unsqueeze(-1), math.nan)
current_device = coordinates.device
num_atoms = padding_mask.shape[1]
num_mols = padding_mask.shape[0]
Expand All @@ -197,8 +195,6 @@ def neighbor_pairs_nopbc(padding_mask: Tensor, coordinates: Tensor, cutoff: floa

pair_coordinates = coordinates.index_select(1, p12_all_flattened).view(num_mols, 2, -1, 3)
distances = (pair_coordinates[:, 0, ...] - pair_coordinates[:, 1, ...]).norm(2, -1)
padding_mask = padding_mask.index_select(1, p12_all_flattened).view(num_mols, 2, -1).any(dim=1)
distances.masked_fill_(padding_mask, math.inf)
in_cutoff = (distances <= cutoff).nonzero()
molecule_index, pair_index = in_cutoff.unbind(1)
molecule_index *= num_atoms
Expand Down

0 comments on commit 29606ef

Please sign in to comment.