diff --git a/torchani/aev.py b/torchani/aev.py index f3bdb7f7a..08e194c07 100644 --- a/torchani/aev.py +++ b/torchani/aev.py @@ -137,7 +137,7 @@ def neighbor_pairs(padding_mask: Tensor, coordinates: Tensor, cell: Tensor, cutoff (float): the cutoff inside which atoms are considered pairs shifts (:class:`torch.Tensor`): tensor of shape (?, 3) storing shifts """ - coordinates = coordinates.detach() + coordinates = coordinates.detach().masked_fill(padding_mask.unsqueeze(-1), math.nan) cell = cell.detach() num_atoms = padding_mask.shape[1] num_mols = padding_mask.shape[0] @@ -165,8 +165,6 @@ def neighbor_pairs(padding_mask: Tensor, coordinates: Tensor, cell: Tensor, # step 5, compute distances, and find all pairs within cutoff selected_coordinates = coordinates.index_select(1, p12_all.view(-1)).view(num_mols, 2, -1, 3) distances = (selected_coordinates[:, 0, ...] - selected_coordinates[:, 1, ...] + shift_values).norm(2, -1) - padding_mask = padding_mask.index_select(1, p12_all.view(-1)).view(2, -1).any(0) - distances.masked_fill_(padding_mask, math.inf) in_cutoff = (distances <= cutoff).nonzero() molecule_index, pair_index = in_cutoff.unbind(1) molecule_index *= num_atoms @@ -188,7 +186,7 @@ def neighbor_pairs_nopbc(padding_mask: Tensor, coordinates: Tensor, cutoff: floa (molecules, atoms, 3) for atom coordinates. cutoff (float): the cutoff inside which atoms are considered pairs """ - coordinates = coordinates.detach() + coordinates = coordinates.detach().masked_fill(padding_mask.unsqueeze(-1), math.nan) current_device = coordinates.device num_atoms = padding_mask.shape[1] num_mols = padding_mask.shape[0] @@ -197,8 +195,6 @@ def neighbor_pairs_nopbc(padding_mask: Tensor, coordinates: Tensor, cutoff: floa pair_coordinates = coordinates.index_select(1, p12_all_flattened).view(num_mols, 2, -1, 3) distances = (pair_coordinates[:, 0, ...] - pair_coordinates[:, 1, ...]).norm(2, -1) - padding_mask = padding_mask.index_select(1, p12_all_flattened).view(num_mols, 2, -1).any(dim=1) - distances.masked_fill_(padding_mask, math.inf) in_cutoff = (distances <= cutoff).nonzero() molecule_index, pair_index = in_cutoff.unbind(1) molecule_index *= num_atoms