diff --git a/torchani/nn.py b/torchani/nn.py index 43e028826..a05568ab1 100644 --- a/torchani/nn.py +++ b/torchani/nn.py @@ -125,7 +125,7 @@ class SpeciesConverter(torch.nn.Module): def __init__(self, species): super().__init__() - rev_idx = {s: k for k, s in enumerate(utils.PERIODIC_TABLE, 1)} + rev_idx = {s: k for k, s in enumerate(utils.PERIODIC_TABLE)} maxidx = max(rev_idx.values()) self.register_buffer('conv_tensor', torch.full((maxidx + 2,), -1, dtype=torch.long)) for i, s in enumerate(species): diff --git a/torchani/utils.py b/torchani/utils.py index 4189bd1fb..56dd49a55 100644 --- a/torchani/utils.py +++ b/torchani/utils.py @@ -376,7 +376,10 @@ def get_atomic_masses(species): return masses -PERIODIC_TABLE = """ +# This constant, when indexed with the corresponding atomic number, gives the +# element associated with it. Note that there is no element with atomic number +# 0, so 'Dummy' returned in this case. +PERIODIC_TABLE = ['Dummy'] + """ H He Li Be B C N O F Ne Na Mg Al Si P S Cl Ar