restructure embeddings to avoid issues with torch jit #634

Merged
merged 5 commits into from Jul 9, 2024
Changes from 4 commits
87 changes: 49 additions & 38 deletions src/schnetpack/nn/embedding.py
@@ -3,9 +3,15 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import schnetpack.properties as properties
from schnetpack.nn.activations import shifted_softplus
from schnetpack.nn.blocks import ResidualMLP
from typing import Callable, Union


__all__ = ["NuclearEmbedding", "ElectronicEmbedding"]


'''
The electron configuration is used to provide a shorthand descriptor. This descriptor encodes
information about the ground state of an atom, the nuclear charge and the number of electrons in the
@@ -157,8 +163,10 @@ class NuclearEmbedding(nn.Module):
from the electron configuration to a (num_features)-dimensional vector. The
latter part encourages alchemically meaningful representations without
restricting the expressivity of learned embeddings.
Using a complex nuclear embedding can have a negative impact on model
performance when the spin/charge embedding is activated.
The impact concerns convergence time: the model converges to a lower
value, but takes longer to do so.
"""

def __init__(
@@ -201,27 +209,26 @@ def train(self, mode: bool = True) -> None:
self.electron_config
)

def forward(self, atom_numbers: torch.Tensor) -> torch.Tensor:
def forward(self, atomic_numbers: torch.Tensor) -> torch.Tensor:
"""
Assign corresponding embeddings to nuclear charges.
N: Number of atoms.
num_features: Dimensions of feature space.

Args:
atom_numbers (LongTensor [N]): Nuclear charges (atomic numbers) of atoms.
atomic_numbers: nuclear charges

Returns:
x (FloatTensor [N, num_features]):Embeddings of all atoms.
Embeddings of all atoms.

"""
if self.training: # during training, the embedding needs to be recomputed
self.embedding = self.element_embedding + self.config_linear(
self.electron_config
)
if self.embedding.device.type == "cpu": # indexing is faster on CPUs
return self.embedding[atom_numbers]
return self.embedding[atomic_numbers]
else: # gathering is faster on GPUs
return torch.gather(
self.embedding, 0, atom_numbers.view(-1, 1).expand(-1, self.num_features)
self.embedding, 0, atomic_numbers.view(-1, 1).expand(-1, self.num_features)
)
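For reference, a minimal sketch of the device-dependent lookup pattern used in `NuclearEmbedding.forward` above (the table size, feature dimension, and atomic numbers below are made up for illustration):

```python
import torch

num_features = 4
embedding = torch.randn(119, num_features)   # one row per element, Z = 0..118
atomic_numbers = torch.tensor([1, 6, 8])     # e.g. H, C, O

if embedding.device.type == "cpu":
    # plain advanced indexing is faster on CPU
    x = embedding[atomic_numbers]
else:
    # torch.gather is faster on GPU; indices are expanded over the feature dim
    x = torch.gather(
        embedding, 0, atomic_numbers.view(-1, 1).expand(-1, num_features)
    )

assert x.shape == (3, num_features)
```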


@@ -236,24 +243,28 @@ class ElectronicEmbedding(nn.Module):

def __init__(
self,
property_key: str,
num_features: int,
num_residual: int,
activation: Union[Callable, nn.Module],
is_charged: bool = False):
is_charged: bool,
num_residual: int = 1,
activation: Union[Callable, nn.Module] = shifted_softplus,
epsilon: float = 1e-8,
):
"""
Args:
num_features:
Dimensions of feature space aka the number of features to describe atomic environments.
property_key: key of electronic property in the spk 'inputs' dictionary
num_features: Dimensions of feature space aka the number of features to describe atomic environments.
This determines the size of each embedding vector
num_residual:
Number of residual blocks applied to atomic features
num_residual: Number of residual blocks applied to atomic features
activation: activation function.
is_charged: True corresponds to building the embedding for molecular charge;
separate weights are used for positive and negative charges.
False corresponds to building the embedding for spin values;
no separate weights are used.
epsilon: numerical stability parameter
"""
super(ElectronicEmbedding, self).__init__()
self.property_key = property_key
self.is_charged = is_charged
self.linear_q = nn.Linear(num_features, num_features)
if is_charged: # charges are duplicated to use separate weights for +/-
@@ -269,6 +280,7 @@ def __init__(
zero_init=True,
bias=False,
)
self.epsilon = epsilon
self.reset_parameters()

def reset_parameters(self) -> None:
@@ -280,25 +292,24 @@ def reset_parameters(self) -> None:

def forward(
self,
atomic_features: torch.Tensor,
electronic_feature: torch.Tensor,
num_batch: int,
batch_seg: torch.Tensor,
eps: float = 1e-8,
embedding,
inputs,
) -> torch.Tensor:
"""
Evaluate interaction block.

Args:
atomic_features: Atomic feature vector of dimension [N, num_features]
electronic_feature: either charges or spin values per N molecular graph
either charges or spin values per molecular graph
num_batch: number of molecular graphs in the batch
batch_seg: segment ids (aka _idx_m) are used to separate different molecules in a batch
eps: small number to avoid division by zero
embedding: embedding of nuclear charges (and other electronic embeddings)
inputs: spk style input dictionary

"""


num_batch = len(inputs[properties.idx])
idx_m = inputs[properties.idx_m]
electronic_feature = inputs[self.property_key]

# queries (Batchsize x N_atoms, n_atom_basis)
q = self.linear_q(atomic_features)
q = self.linear_q(embedding)

# to account for negative and positive charge
if self.is_charged:
@@ -308,25 +319,25 @@ def forward(
e = torch.abs(electronic_feature).unsqueeze(-1)
enorm = torch.maximum(e, torch.ones_like(e))

# keys (Batchsize x N_atoms, n_atom_basis), the batch_seg ensures that the key is the same for all atoms belonging to the same graph
k = self.linear_k(e / enorm)[batch_seg]
# keys (Batchsize x N_atoms, n_atom_basis), the idx_m ensures that the key is the same for all atoms belonging to the same graph
k = self.linear_k(e / enorm)[idx_m]

# values (Batchsize x N_atoms, n_atom_basis) the batch_seg ensures that the value is the same for all atoms belonging to the same graph
v = self.linear_v(e)[batch_seg]
# values (Batchsize x N_atoms, n_atom_basis) the idx_m ensures that the value is the same for all atoms belonging to the same graph
v = self.linear_v(e)[idx_m]

# unnormalized, scaled attention weights, obtained by dot product of queries and keys (are logits)
# scaling by square root of attention dimension
weights = torch.sum(k * q, dim=-1) / k.shape[-1] ** 0.5

# probability distribution of scaled unnormalized attention weights, by applying softmax function
a = nn.functional.softmax(weights,dim=0) # nn.functional.softplus(weights) seems to function to but softmax might be more stable
a = nn.functional.softmax(weights, dim=0)  # nn.functional.softplus(weights) seems to work too, but softmax might be more stable
# normalization factor for every molecular graph, by adding up attention weights of every atom in the graph
anorm = a.new_zeros(num_batch).index_add_(0, batch_seg, a)
anorm = a.new_zeros(num_batch).index_add_(0, idx_m, a)
# make tensor filled with anorm value at the position of the corresponding molecular graph,
# indexing faster on CPU, gather faster on GPU
if a.device.type == "cpu":
anorm = anorm[batch_seg]
anorm = anorm[idx_m]
else:
anorm = torch.gather(anorm, 0, batch_seg)
anorm = torch.gather(anorm, 0, idx_m)
# return probability distribution of scaled normalized attention weights, eps is added for numerical stability (sum / batchsize equals 1)
return self.resblock((a / (anorm + eps)).unsqueeze(-1) * v)
return self.resblock((a / (anorm + self.epsilon)).unsqueeze(-1) * v)
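To illustrate how the restructured `ElectronicEmbedding` is now driven by the spk-style `inputs` dictionary (rather than by explicit `num_batch`/`batch_seg` arguments), here is a usage sketch assuming the constructor and forward signatures shown in this diff; the toy batch of two molecules and five atoms is made up for illustration:

```python
import torch
import schnetpack.properties as properties
from schnetpack.nn.embedding import ElectronicEmbedding

n_atom_basis = 8
charge_embedding = ElectronicEmbedding(
    property_key=properties.total_charge,
    num_features=n_atom_basis,
    is_charged=True,
)

inputs = {
    properties.idx: torch.arange(2),                     # two molecules in the batch
    properties.idx_m: torch.tensor([0, 0, 0, 1, 1]),     # atom-to-molecule assignment
    properties.total_charge: torch.tensor([0.0, -1.0]),  # one total charge per molecule
}
q = torch.randn(5, n_atom_basis)  # per-atom features, e.g. nuclear embeddings

delta = charge_embedding(q, inputs)  # attention-weighted charge contribution per atom
assert delta.shape == q.shape
```

Inside `PaiNN`, such modules are applied additively to the nuclear embedding, as the `painn.py` changes below show.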
72 changes: 21 additions & 51 deletions src/schnetpack/representation/painn.py
@@ -1,12 +1,12 @@
from typing import Callable, Dict, Optional, Union
from typing import Callable, Dict, Optional, Union, List

import torch
import torch.nn as nn
import torch.nn.functional as F

import schnetpack.properties as properties
import schnetpack.nn as snn
from schnetpack.nn.embedding import NuclearEmbedding, ElectronicEmbedding


__all__ = ["PaiNN", "PaiNNInteraction", "PaiNNMixing"]

@@ -135,12 +135,11 @@ def __init__(
radial_basis: nn.Module,
cutoff_fn: Optional[Callable] = None,
activation: Optional[Callable] = F.silu,
max_z: int = 101,
shared_interactions: bool = False,
shared_filters: bool = False,
epsilon: float = 1e-8,
activate_charge_spin_embedding: bool = False,
embedding: Union[Callable, nn.Module] = None,
nuclear_embedding: Optional[nn.Module] = None,
electronic_embeddings: Optional[List] = None,
):
"""
Args:
@@ -149,16 +148,15 @@
n_interactions: number of interaction blocks.
radial_basis: layer for expanding interatomic distances in a basis set
cutoff_fn: cutoff function
max_z: maximal nuclear charge
activation: activation function
shared_interactions: if True, share the weights across
interaction blocks.
shared_filters: if True, share the weights across
filter-generating networks.
epsilon: stability constant added in norm to prevent numerical instabilities
activate_charge_spin_embedding: if True, charge and spin embeddings are added
to nuclear embeddings taken from SpookyNet Implementation
embedding: custom nuclear embedding
epsilon: numerical stability parameter
nuclear_embedding: custom nuclear embedding (e.g. spk.nn.embeddings.NuclearEmbedding)
electronic_embeddings: list of electronic embeddings. E.g. for spin and
charge (see spk.nn.embeddings.ElectronicEmbedding)
"""
super(PaiNN, self).__init__()

@@ -167,25 +165,15 @@ def __init__(
self.cutoff_fn = cutoff_fn
self.cutoff = cutoff_fn.cutoff
self.radial_basis = radial_basis
self.activate_charge_spin_embedding = activate_charge_spin_embedding

# initialize nuclear embedding
self.embedding = embedding
if self.embedding is None:
self.embedding = nn.Embedding(max_z, self.n_atom_basis, padding_idx=0)

# initialize spin and charge embeddings
if self.activate_charge_spin_embedding:
self.charge_embedding = ElectronicEmbedding(
self.n_atom_basis,
num_residual=1,
activation=activation,
is_charged=True)
self.spin_embedding = ElectronicEmbedding(
self.n_atom_basis,
num_residual=1,
activation=activation,
is_charged=False)

# initialize embeddings
if nuclear_embedding is None:
nuclear_embedding = nn.Embedding(100, n_atom_basis)
self.embedding = nuclear_embedding
if electronic_embeddings is None:
electronic_embeddings = []
electronic_embeddings = nn.ModuleList(electronic_embeddings)
self.electronic_embeddings = electronic_embeddings

# initialize filter layers
self.share_filters = shared_filters
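Putting the `__init__` changes above together, a construction sketch assuming the signatures shown in this diff (the radial basis, cutoff, and sizes are illustrative choices, not defaults):

```python
import schnetpack.properties as properties
from schnetpack.nn import GaussianRBF, CosineCutoff
from schnetpack.nn.embedding import ElectronicEmbedding
from schnetpack.representation import PaiNN

n_atom_basis = 128
cutoff = 5.0

painn = PaiNN(
    n_atom_basis=n_atom_basis,
    n_interactions=3,
    radial_basis=GaussianRBF(n_rbf=20, cutoff=cutoff),
    cutoff_fn=CosineCutoff(cutoff),
    # nuclear_embedding=None falls back to nn.Embedding(100, n_atom_basis)
    electronic_embeddings=[
        ElectronicEmbedding(properties.total_charge, n_atom_basis, is_charged=True),
        ElectronicEmbedding(properties.spin_multiplicity, n_atom_basis, is_charged=False),
    ],
)
```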
@@ -248,28 +236,10 @@ def forward(self, inputs: Dict[str, torch.Tensor]):
filter_list = torch.split(filters, 3 * self.n_atom_basis, dim=-1)

# compute initial embeddings
q = self.embedding(atomic_numbers)[:, None]

# add spin and charge embeddings
if hasattr(self, "activate_charge_spin_embedding") and self.activate_charge_spin_embedding:
# get tensors from input dictionary
total_charge = inputs[properties.total_charge]
spin = inputs[properties.spin_multiplicity]
num_batch = len(inputs[properties.idx])
idx_m = inputs[properties.idx_m]

charge_embedding = self.charge_embedding(
q.squeeze(),
total_charge,
num_batch,
idx_m
)[:, None]
spin_embedding = self.spin_embedding(
q.squeeze(), spin, num_batch, idx_m
)[:, None]

# additive combining of nuclear, charge and spin embedding
q = (q + charge_embedding + spin_embedding)
q = self.embedding(atomic_numbers)
for embedding in self.electronic_embeddings:
q = q + embedding(q, inputs)
q = q.unsqueeze(1)

# compute interaction blocks and update atomic embeddings
qs = q.shape
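A plausible reading of the PR title is that the TorchScript issues stemmed from the removed `hasattr`/`activate_charge_spin_embedding` guard and the conditionally created submodules; the new code instead registers the electronic embeddings in an `nn.ModuleList` and applies them in a plain loop. A minimal toy sketch of that pattern (the modules below are made up for illustration) that compiles under `torch.jit.script`:

```python
import torch
import torch.nn as nn

class ToyRepresentation(nn.Module):
    """Mirrors the structure of the new PaiNN embedding handling (toy version)."""

    def __init__(self, n_atom_basis: int, electronic_embeddings=None):
        super().__init__()
        self.embedding = nn.Embedding(100, n_atom_basis)
        if electronic_embeddings is None:
            electronic_embeddings = []
        # nn.ModuleList registers the submodules and is iterable under TorchScript
        self.electronic_embeddings = nn.ModuleList(electronic_embeddings)

    def forward(self, atomic_numbers: torch.Tensor) -> torch.Tensor:
        q = self.embedding(atomic_numbers)
        for embedding in self.electronic_embeddings:
            q = q + embedding(q)
        return q

scripted = torch.jit.script(ToyRepresentation(8, [nn.Linear(8, 8)]))
out = scripted(torch.tensor([1, 6, 8]))
assert out.shape == (3, 8)
```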