Add a very-dense-geometry test for cuaev and correct the inner loop bound in
the cuAngularAEVs kernel.

tests/test_cuaev.py: testVeryDenseMolecule reads test_data/tripeptide-md/0.dat
through 99.dat with pickle, multiplies the coordinates by 0.1, and asserts that
the second outputs of self.cuaev_computer and self.aev_computer agree with
atol=5e-5 and rtol=5e-5.

torchani/cuaev/aev.cu: the old condition `kk_start + srcLane < min(32, jnum)`
capped the index kk_start + srcLane itself at 32, so the loop body could not
run once kk_start was 32 or more. The new condition bounds srcLane (the source
lane handed to __shfl_sync) by 32 and checks kk_start + srcLane against jnum
separately. NOTE(review): the enclosing loop that advances kk_start is outside
this hunk; presumably it steps by 32, confirm against the full kernel.

NOTE(review): this patch was recovered from a copy whose line breaks were
collapsed. The indentation of the context lines and the position of blank
context lines are reconstructed; check them against the target files (or apply
with whitespace-tolerant options) before relying on this patch.

diff --git a/tests/test_cuaev.py b/tests/test_cuaev.py
index 22c612431..d5a6e4f8b 100644
--- a/tests/test_cuaev.py
+++ b/tests/test_cuaev.py
@@ -94,6 +94,18 @@ def testNIST(self):
                 _, cu_aev = self.cuaev_computer((species, coordinates))
                 self.assertEqual(cu_aev, aev)
 
+    def testVeryDenseMolecule(self):
+        for i in range(100):
+            datafile = os.path.join(path, 'test_data/tripeptide-md/{}.dat'.format(i))
+            with open(datafile, 'rb') as f:
+                coordinates, species, _, _, _, _, _, _ = pickle.load(f)
+                # scale the angstrom coordinates down by 10x so the atoms sit far closer together
+                coordinates = 0.1 * torch.from_numpy(coordinates).float().unsqueeze(0).to(self.device)
+                species = torch.from_numpy(species).unsqueeze(0).to(self.device)
+                _, aev = self.aev_computer((species, coordinates))
+                _, cu_aev = self.cuaev_computer((species, coordinates))
+                self.assertEqual(cu_aev, aev, atol=5e-5, rtol=5e-5)
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/torchani/cuaev/aev.cu b/torchani/cuaev/aev.cu
index bffd7f326..7254b2345 100644
--- a/torchani/cuaev/aev.cu
+++ b/torchani/cuaev/aev.cu
@@ -209,7 +209,7 @@ __global__ void cuAngularAEVs(
         theta = acos(0.95 * (sdx[jj] * sdx[kk] + sdy[jj] * sdy[kk] + sdz[jj] * sdz[kk]) / (Rij * Rik));
       }
 
-      for (int srcLane = 0; kk_start + srcLane < min(32, jnum); ++srcLane) {
+      for (int srcLane = 0; srcLane < 32 && (kk_start + srcLane) < jnum; ++srcLane) {
         int kk = kk_start + srcLane;
         DataT theta_ijk = __shfl_sync(0xFFFFFFFF, theta, srcLane);
 