diff --git a/torchani/aev.py b/torchani/aev.py index e15d36eac..f3bdb7f7a 100644 --- a/torchani/aev.py +++ b/torchani/aev.py @@ -66,10 +66,9 @@ def angular_terms(Rca: float, ShfZ: Tensor, EtaA: Tensor, Zeta: Tensor, vectors12 = vectors12.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) distances12 = vectors12.norm(2, dim=-5) - # 0.95 is multiplied to the cos values to prevent acos from - # returning NaN. - cos_angles = 0.95 * torch.nn.functional.cosine_similarity(vectors12[0], vectors12[1], dim=-5) - angles = torch.acos(cos_angles) + cos_angles = vectors12.prod(0).sum(1) / distances12.prod(0) + # 0.95 is multiplied to the cos values to prevent acos from returning NaN. + angles = torch.acos(0.95 * cos_angles) fcj12 = cutoff_cosine(distances12, Rca) factor1 = ((1 + torch.cos(angles - ShfZ)) / 2) ** Zeta