Embedding spin multiplicity and charge based on SpookyNet implementation #608

Merged
Changes from 6 commits (25 commits total)
Commits
fb188e1
initial commit
epens94 Feb 7, 2024
384f778
commit includes all files for
epens94 Feb 7, 2024
09fb039
electron embed for painn
epens94 Feb 12, 2024
bb27705
add several evaluation scripts but later refactor and remove if neces…
epens94 Feb 13, 2024
c1cfd98
cleanup
epens94 Feb 13, 2024
6c524c9
clean up and comments added
epens94 Feb 19, 2024
4266da0
add docstring to electron configuration py
epens94 Mar 6, 2024
b433573
clean up gitignore
epens94 Mar 6, 2024
4d4cc5c
fixing docstring in electronic embedding
epens94 Mar 6, 2024
f8494fd
adding further description to electron configuration
epens94 Mar 6, 2024
2d23890
add docstring to electronic embedding fix unclear naming
epens94 Mar 6, 2024
cd06b83
revert Z back to 100
epens94 Mar 6, 2024
800c3b0
fix docstring nuclear embedding
epens94 Mar 6, 2024
aff25bf
fix naming in nuclear embedding
epens94 Mar 6, 2024
f4ca4ee
move ssp to activations module and add docstring
epens94 Mar 6, 2024
c465ce4
change order to be equal in args in nn embedding
epens94 Mar 6, 2024
2156399
clear naming of vars and remove redundant code
epens94 Mar 6, 2024
c86b404
move all embedding classes into one module and delete not needed modules
epens94 Mar 6, 2024
1ebad7a
fix of init
epens94 Mar 6, 2024
f99a432
activation ssp trainable implement,pass nuclear embedding directly
epens94 Mar 6, 2024
3a399fa
bugfix nuclear embedding
epens94 Mar 6, 2024
64b5d2e
missed one replace string activation function
epens94 Mar 7, 2024
9517bd2
missed one replace string activation function in elec embedding
epens94 Mar 7, 2024
c503f6b
fix docstring, problem with NaN in activation fn, write docstring mor…
epens94 Mar 7, 2024
68dcf26
add electronic embedding to so3 net and bugfix painn and schnet rep
epens94 Mar 12, 2024
6 changes: 6 additions & 0 deletions .gitignore
@@ -126,3 +126,9 @@ interfaces/lammps/examples/*/deployed_model

# batchwise optimizer examples
examples/howtos/howto_batchwise_relaxations_outputs/*
tests/electronic_embedding/pain_carbene.json
.vscode/launch.json
.vscode/settings.json
tests/ase_carbene_2200.db
tests/electronic_embedding/carbene.json
tests/electronic_embedding/carbene.json
1 change: 1 addition & 0 deletions src/schnetpack/__init__.py
@@ -4,6 +4,7 @@

from schnetpack import transform
from schnetpack import properties
from schnetpack import electron_configurations
from schnetpack import data
from schnetpack import datasets
from schnetpack import atomistic
99 changes: 99 additions & 0 deletions src/schnetpack/electron_configurations.py
@@ -0,0 +1,99 @@
#!/usr/bin/env python3
import numpy as np

# fmt: off
# entries up to Z = 86 are sufficient for the QMML project, since higher-Z elements are not present in the dataset
electron_config = np.array([
# Z 1s 2s 2p 3s 3p 4s 3d 4p 5s 4d 5p 6s 4f 5d 6p vs vp vd vf
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # n
[ 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], # H
[ 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], # He
[ 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], # Li
[ 4, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], # Be
[ 5, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0], # B
[ 6, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0], # C
[ 7, 2, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 0, 0], # N
[ 8, 2, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 4, 0, 0], # O
[ 9, 2, 2, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 5, 0, 0], # F
[ 10, 2, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 6, 0, 0], # Ne
[ 11, 2, 2, 6, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], # Na
[ 12, 2, 2, 6, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], # Mg
[ 13, 2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0], # Al
[ 14, 2, 2, 6, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0], # Si
[ 15, 2, 2, 6, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 0, 0], # P
[ 16, 2, 2, 6, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 4, 0, 0], # S
[ 17, 2, 2, 6, 2, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 5, 0, 0], # Cl
[ 18, 2, 2, 6, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 6, 0, 0], # Ar
[ 19, 2, 2, 6, 2, 6, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], # K
[ 20, 2, 2, 6, 2, 6, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], # Ca
[ 21, 2, 2, 6, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 1, 0], # Sc
[ 22, 2, 2, 6, 2, 6, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 2, 0], # Ti
[ 23, 2, 2, 6, 2, 6, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 3, 0], # V
[ 24, 2, 2, 6, 2, 6, 1, 5, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 5, 0], # Cr
[ 25, 2, 2, 6, 2, 6, 2, 5, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 5, 0], # Mn
[ 26, 2, 2, 6, 2, 6, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 6, 0], # Fe
[ 27, 2, 2, 6, 2, 6, 2, 7, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 7, 0], # Co
[ 28, 2, 2, 6, 2, 6, 2, 8, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 8, 0], # Ni
[ 29, 2, 2, 6, 2, 6, 1, 10, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 10, 0], # Cu
[ 30, 2, 2, 6, 2, 6, 2, 10, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 10, 0], # Zn
[ 31, 2, 2, 6, 2, 6, 2, 10, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 10, 0], # Ga
[ 32, 2, 2, 6, 2, 6, 2, 10, 2, 0, 0, 0, 0, 0, 0, 0, 2, 2, 10, 0], # Ge
[ 33, 2, 2, 6, 2, 6, 2, 10, 3, 0, 0, 0, 0, 0, 0, 0, 2, 3, 10, 0], # As
[ 34, 2, 2, 6, 2, 6, 2, 10, 4, 0, 0, 0, 0, 0, 0, 0, 2, 4, 10, 0], # Se
[ 35, 2, 2, 6, 2, 6, 2, 10, 5, 0, 0, 0, 0, 0, 0, 0, 2, 5, 10, 0], # Br
[ 36, 2, 2, 6, 2, 6, 2, 10, 6, 0, 0, 0, 0, 0, 0, 0, 2, 6, 10, 0], # Kr
[ 37, 2, 2, 6, 2, 6, 2, 10, 6, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], # Rb
[ 38, 2, 2, 6, 2, 6, 2, 10, 6, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], # Sr
[ 39, 2, 2, 6, 2, 6, 2, 10, 6, 2, 1, 0, 0, 0, 0, 0, 2, 0, 1, 0], # Y
[ 40, 2, 2, 6, 2, 6, 2, 10, 6, 2, 2, 0, 0, 0, 0, 0, 2, 0, 2, 0], # Zr
[ 41, 2, 2, 6, 2, 6, 2, 10, 6, 1, 4, 0, 0, 0, 0, 0, 1, 0, 4, 0], # Nb
[ 42, 2, 2, 6, 2, 6, 2, 10, 6, 1, 5, 0, 0, 0, 0, 0, 1, 0, 5, 0], # Mo
[ 43, 2, 2, 6, 2, 6, 2, 10, 6, 2, 5, 0, 0, 0, 0, 0, 2, 0, 5, 0], # Tc
[ 44, 2, 2, 6, 2, 6, 2, 10, 6, 1, 7, 0, 0, 0, 0, 0, 1, 0, 7, 0], # Ru
[ 45, 2, 2, 6, 2, 6, 2, 10, 6, 1, 8, 0, 0, 0, 0, 0, 1, 0, 8, 0], # Rh
[ 46, 2, 2, 6, 2, 6, 2, 10, 6, 0, 10, 0, 0, 0, 0, 0, 0, 0, 10, 0], # Pd
[ 47, 2, 2, 6, 2, 6, 2, 10, 6, 1, 10, 0, 0, 0, 0, 0, 1, 0, 10, 0], # Ag
[ 48, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 0, 0, 0, 0, 0, 2, 0, 10, 0], # Cd
[ 49, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 1, 0, 0, 0, 0, 2, 1, 10, 0], # In
[ 50, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 2, 0, 0, 0, 0, 2, 2, 10, 0], # Sn
[ 51, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 3, 0, 0, 0, 0, 2, 3, 10, 0], # Sb
[ 52, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 4, 0, 0, 0, 0, 2, 4, 10, 0], # Te
[ 53, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 5, 0, 0, 0, 0, 2, 5, 10, 0], # I
[ 54, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 0, 0, 0, 0, 2, 6, 10, 0], # Xe
[ 55, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 1, 0, 0, 0, 1, 0, 0, 0], # Cs
[ 56, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 0, 0, 0, 2, 0, 0, 0], # Ba
[ 57, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 0, 1, 0, 2, 0, 1, 0], # La
[ 58, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 1, 1, 0, 2, 0, 1, 1], # Ce
[ 59, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 3, 0, 0, 2, 0, 0, 3], # Pr
[ 60, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 4, 0, 0, 2, 0, 0, 4], # Nd
[ 61, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 5, 0, 0, 2, 0, 0, 5], # Pm
[ 62, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 6, 0, 0, 2, 0, 0, 6], # Sm
[ 63, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 7, 0, 0, 2, 0, 0, 7], # Eu
[ 64, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 7, 1, 0, 2, 0, 1, 7], # Gd
[ 65, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 9, 0, 0, 2, 0, 0, 9], # Tb
[ 66, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 10, 0, 0, 2, 0, 0, 10], # Dy
[ 67, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 11, 0, 0, 2, 0, 0, 11], # Ho
[ 68, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 12, 0, 0, 2, 0, 0, 12], # Er
[ 69, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 13, 0, 0, 2, 0, 0, 13], # Tm
[ 70, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 0, 0, 2, 0, 0, 14], # Yb
[ 71, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 1, 0, 2, 0, 1, 14], # Lu
[ 72, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 2, 0, 2, 0, 2, 14], # Hf
[ 73, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 3, 0, 2, 0, 3, 14], # Ta
[ 74, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 4, 0, 2, 0, 4, 14], # W
[ 75, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 5, 0, 2, 0, 5, 14], # Re
[ 76, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 6, 0, 2, 0, 6, 14], # Os
[ 77, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 7, 0, 2, 0, 7, 14], # Ir
[ 78, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 1, 14, 9, 0, 1, 0, 9, 14], # Pt
[ 79, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 1, 14, 10, 0, 1, 0, 10, 14], # Au
[ 80, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 0, 2, 0, 10, 14], # Hg
[ 81, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 1, 2, 1, 10, 14], # Tl
[ 82, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 2, 2, 2, 10, 14], # Pb
[ 83, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 3, 2, 3, 10, 14], # Bi
[ 84, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 4, 2, 4, 10, 14], # Po
[ 85, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 5, 2, 5, 10, 14], # At
[ 86, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 6, 2, 6, 10, 14] # Rn
], dtype=np.float32)
# fmt: on
# normalize entries column-wise to the range [0.0, 1.0]
# (normalization is purely for numerical reasons)
electron_config = electron_config / np.max(electron_config, axis=0)
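
As a quick sanity check, the normalized configuration of an element can be looked up by its atomic number (a minimal sketch, not part of this diff; the import path follows the __init__.py change above):

import numpy as np
from schnetpack.electron_configurations import electron_config

carbon = electron_config[6]      # row index equals the atomic number Z
print(carbon.shape)              # (20,): Z column + 15 subshell occupations + 4 valence counts
print(np.max(electron_config))   # 1.0, since every column is divided by its own maximum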
3 changes: 3 additions & 0 deletions src/schnetpack/nn/__init__.py
@@ -12,3 +12,6 @@
from schnetpack.nn.scatter import *
from schnetpack.nn.radial import *
from schnetpack.nn.utils import *
from schnetpack.nn.electronic_embeeding import *
from schnetpack.nn.residual_blocks import *
from schnetpack.nn.nuclear_embedding import *
123 changes: 123 additions & 0 deletions src/schnetpack/nn/electronic_embeeding.py
@@ -0,0 +1,123 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional
from schnetpack.nn.residual_blocks import ResidualMLP

class ElectronicEmbedding(nn.Module):
"""
Single-head self-attention block for updating atomic features through nonlocal interactions with the electrons.
Collaborator: could you add 2 more sentences what happens in that class?
Collaborator (Author): done in 2d23890

Arguments:
num_features (int):
Dimensions of feature space.
num_residual (int):
Number of residual blocks applied to the interaction features.
activation (str):
Kind of activation function. Possible values:
'swish': Swish activation function.
'ssp': Shifted softplus activation function.
is_charge (bool):
If True, the input electronic feature is the total charge (separate weights
are used for positive and negative values); if False, it is the spin
(number of unpaired electrons).
"""

def __init__(
self,
num_features: int,
num_residual: int,
activation: str = "swish",
is_charge: bool = False,
) -> None:
""" Initializes the ElectronicEmbedding class. """
super(ElectronicEmbedding, self).__init__()
self.is_charge = is_charge
self.linear_q = nn.Linear(num_features, num_features)
if is_charge: # charges are duplicated to use separate weights for +/-
self.linear_k = nn.Linear(2, num_features, bias=False)
self.linear_v = nn.Linear(2, num_features, bias=False)
else:
self.linear_k = nn.Linear(1, num_features, bias=False)
self.linear_v = nn.Linear(1, num_features, bias=False)
self.resblock = ResidualMLP(
num_features,
num_residual,
activation=activation,
zero_init=True,
bias=False,
)
self.reset_parameters()

def reset_parameters(self) -> None:
""" Initialize parameters. """
nn.init.orthogonal_(self.linear_k.weight)
nn.init.orthogonal_(self.linear_v.weight)
nn.init.orthogonal_(self.linear_q.weight)
nn.init.zeros_(self.linear_q.bias)

def forward(
self,
x: torch.Tensor,
Collaborator: can you make argument names more expressive? x --> atomic_features, E --> electronic_features ...
Collaborator (Author): done in 2d23890

E: torch.Tensor,
num_batch: int,
batch_seg: torch.Tensor,
eps: float = 1e-8,
) -> torch.Tensor:
"""
Evaluate interaction block.

x (FloatTensor [N, num_features]):
Atomic feature vectors.
E (FloatTensor [num_batch]):
Either the total charge or the spin value of each molecular graph.
num_batch (int):
Number of molecular graphs in the batch.
batch_seg (LongTensor [N]):
Segment ids (aka _idx_m) used to separate the different molecules in a batch.
eps (float):
Small number to avoid division by zero.
"""

# queries (Batchsize x N_atoms, n_atom_basis)
q = self.linear_q(x)

# to account for negative and positive charge
if self.is_charge:
e = F.relu(torch.stack([E, -E], dim=-1))
# +/- spin is the same => abs
else:
e = torch.abs(E).unsqueeze(-1)
enorm = torch.maximum(e, torch.ones_like(e))

# keys (Batchsize x N_atoms, n_atom_basis), the batch_seg ensures that the key is the same for all atoms belonging to the same graph
k = self.linear_k(e / enorm)[batch_seg]

# values (Batchsize x N_atoms, n_atom_basis) the batch_seg ensures that the value is the same for all atoms belonging to the same graph
v = self.linear_v(e)[batch_seg]

# unnormalized attention weights (logits): dot product of keys and queries,
# scaled by the square root of the attention dimension
weights = torch.sum(k * q, dim=-1) / k.shape[-1] ** 0.5

# non-negative attention weights, obtained by applying softplus to the logits (the per-graph normalization below makes them sum to 1)
a = nn.functional.softplus(weights)

# normalization factor for every molecular graph, by adding up attention weights of every atom in the graph
anorm = a.new_zeros(num_batch).index_add_(0, batch_seg, a)

# expand anorm to one value per atom (each atom gets the normalization factor of its molecular graph);
# indexing is faster on CPU, gather is faster on GPU
if a.device.type == "cpu":
anorm = anorm[batch_seg]
else:
anorm = torch.gather(anorm, 0, batch_seg)

# scale the values by the normalized attention weights (they sum to 1 within each molecular graph); eps is added for numerical stability
return self.resblock((a / (anorm + eps)).unsqueeze(-1) * v)
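
For reference, a minimal usage sketch of the class above (illustrative only, not part of the PR; tensor shapes and values are assumptions, and the module path follows this diff):

import torch
from schnetpack.nn.electronic_embeeding import ElectronicEmbedding

# two molecules in one batch: atoms 0-2 belong to graph 0, atoms 3-4 to graph 1
embedding = ElectronicEmbedding(num_features=8, num_residual=1, is_charge=True)
atomic_features = torch.randn(5, 8)        # [N, num_features]
total_charge = torch.tensor([0.0, -1.0])   # one value per molecular graph
batch_seg = torch.tensor([0, 0, 0, 1, 1])  # segment ids (_idx_m)
delta = embedding(atomic_features, total_charge, num_batch=2, batch_seg=batch_seg)
print(delta.shape)                         # torch.Size([5, 8])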
87 changes: 87 additions & 0 deletions src/schnetpack/nn/nuclear_embedding.py
@@ -0,0 +1,87 @@
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from schnetpack.electron_configurations import electron_config


class NuclearEmbedding(nn.Module):
"""
Embedding which maps scalar nuclear charges Z to vectors in a
(num_features)-dimensional feature space. The embedding consists of a freely
learnable parameter matrix [Zmax, num_features] and a learned linear mapping
from the electron configuration to a (num_features)-dimensional vector. The
latter part encourages alchemically meaningful representations without
restricting the expressivity of learned embeddings.
Note: using this more complex nuclear embedding can have a negative impact on model performance when the spin/charge embedding is activated.

Arguments:
num_features (int):
Dimensions of feature space.
Zmax (int):
Maximum nuclear charge +1 of atoms. The default is 87, so all
elements up to Rn (Z=86) are supported. Can be kept at the default
value (has minimal memory impact).
"""

def __init__(
self, num_features: int, Zmax: int = 87, zero_init: bool = True
) -> None:
""" Initializes the NuclearEmbedding class. """
super(NuclearEmbedding, self).__init__()
self.num_features = num_features
self.register_buffer("electron_config", torch.tensor(electron_config))
self.register_parameter(
"element_embedding", nn.Parameter(torch.Tensor(Zmax, self.num_features))
)
self.register_buffer(
"embedding", torch.Tensor(Zmax, self.num_features), persistent=False
)
self.config_linear = nn.Linear(
self.electron_config.size(1), self.num_features, bias=False
)
self.reset_parameters(zero_init)

def reset_parameters(self, zero_init: bool = True) -> None:
""" Initialize parameters. """
if zero_init:
nn.init.zeros_(self.element_embedding)
nn.init.zeros_(self.config_linear.weight)
else:
nn.init.uniform_(self.element_embedding, -math.sqrt(3), math.sqrt(3))
nn.init.orthogonal_(self.config_linear.weight)

def train(self, mode: bool = True) -> None:
""" Switch between training and evaluation mode. """
super(NuclearEmbedding, self).train(mode=mode)
if not self.training:
with torch.no_grad():
self.embedding = self.element_embedding + self.config_linear(
self.electron_config
)

def forward(self, Z: torch.Tensor) -> torch.Tensor:
"""
Assign corresponding embeddings to nuclear charges.
N: Number of atoms.
num_features: Dimensions of feature space.

Arguments:
Z (LongTensor [N]):
Nuclear charges (atomic numbers) of atoms.

Returns:
x (FloatTensor [N, num_features]):
Embeddings of all atoms.
"""
if self.training: # during training, the embedding needs to be recomputed
self.embedding = self.element_embedding + self.config_linear(
self.electron_config
)
if self.embedding.device.type == "cpu": # indexing is faster on CPUs
return self.embedding[Z]
else: # gathering is faster on GPUs
return torch.gather(
self.embedding, 0, Z.view(-1, 1).expand(-1, self.num_features)
)
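
And a corresponding minimal sketch for the nuclear embedding (illustrative only; the element choice and feature size are assumptions):

import torch
from schnetpack.nn.nuclear_embedding import NuclearEmbedding

embedding = NuclearEmbedding(num_features=8)
Z = torch.tensor([6, 1, 1, 1, 1])  # methane: one carbon, four hydrogens
x = embedding(Z)                   # in training mode, recomputed from element_embedding + config_linear
print(x.shape)                     # torch.Size([5, 8])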