Embedding spin multiplicity and charge based on SpookyNet implementation #608

Merged
Changes from 23 commits

Commits (25)
fb188e1
initial commit
epens94 Feb 7, 2024
384f778
commit includes all files for
epens94 Feb 7, 2024
09fb039
electron embed for painn
epens94 Feb 12, 2024
bb27705
add several evaluation scripts but later refactor and remove if neces…
epens94 Feb 13, 2024
c1cfd98
cleanup
epens94 Feb 13, 2024
6c524c9
clean up and comments added
epens94 Feb 19, 2024
4266da0
add docstring to electron configuration py
epens94 Mar 6, 2024
b433573
clean up gitignore
epens94 Mar 6, 2024
4d4cc5c
fixing docstring in electronic embedding
epens94 Mar 6, 2024
f8494fd
adding further description to electron configuration
epens94 Mar 6, 2024
2d23890
add docstring to electronic embedding fix unclear naming
epens94 Mar 6, 2024
cd06b83
revert Z back to 100
epens94 Mar 6, 2024
800c3b0
fix docstring nuclear embedding
epens94 Mar 6, 2024
aff25bf
fix naming in nuclear embedding
epens94 Mar 6, 2024
f4ca4ee
move ssp to activations module and add docstring
epens94 Mar 6, 2024
c465ce4
change order to be equal in args in nn embedding
epens94 Mar 6, 2024
2156399
clear naming of vars and remove redundant code
epens94 Mar 6, 2024
c86b404
move all embedding classes into one module and delete not needed modules
epens94 Mar 6, 2024
1ebad7a
fix of init
epens94 Mar 6, 2024
f99a432
activation ssp trainable implement, pass nuclear embedding directly
epens94 Mar 6, 2024
3a399fa
bugfix nuclear embedding
epens94 Mar 6, 2024
64b5d2e
missed one replace string activation function
epens94 Mar 7, 2024
9517bd2
missed one replace string activation function in elec embedding
epens94 Mar 7, 2024
c503f6b
fix docstring, problem with NaN in activation fn, write docstring mor…
epens94 Mar 7, 2024
68dcf26
add electronic embedding to so3 net and bugfix painn and schnet rep
epens94 Mar 12, 2024
2 changes: 1 addition & 1 deletion .gitignore
@@ -125,4 +125,4 @@ interfaces/lammps/examples/*/*.dat
interfaces/lammps/examples/*/deployed_model

# batchwise optimizer examples
examples/howtos/howto_batchwise_relaxations_outputs/*
examples/howtos/howto_batchwise_relaxations_outputs/*
1 change: 1 addition & 0 deletions src/schnetpack/nn/__init__.py
@@ -12,3 +12,4 @@
from schnetpack.nn.scatter import *
from schnetpack.nn.radial import *
from schnetpack.nn.utils import *
from schnetpack.nn.embedding import *
62 changes: 61 additions & 1 deletion src/schnetpack/nn/activations.py
@@ -3,7 +3,7 @@

from torch.nn import functional

__all__ = ["shifted_softplus", "softplus_inverse"]
__all__ = ["shifted_softplus", "softplus_inverse", "ShiftedSoftplus"]


def shifted_softplus(x: torch.Tensor):
@@ -33,3 +33,63 @@ def softplus_inverse(x: torch.Tensor):
torch.Tensor: softplus inverse of input.
"""
return x + (torch.log(-torch.expm1(-x)))


class ShiftedSoftplus(torch.nn.Module):
"""
Shifted softplus activation function with learnable feature-wise parameters:
f(x) = alpha/beta * (softplus(beta*x) - log(2))
softplus(x) = log(exp(x) + 1)
For beta -> 0 : f(x) -> 0.5*alpha*x
For beta -> inf: f(x) -> max(0, alpha*x)

    With learnable parameters alpha and beta, the shifted softplus function can
    become equivalent to ReLU (if alpha equals 1 and beta approaches infinity) or to
    the identity function (if alpha equals 2 and beta equals 0).

Arguments:
num_features (int):
Dimensions of feature space.
initial_alpha (float):
Initial "scale" alpha of the softplus function.
initial_beta (float):
Initial "temperature" beta of the softplus function.
"""

def __init__(
self,
num_features: int,
initial_alpha: float = 1.0,
initial_beta: float = 1.0,
trainable: bool = False) -> None:

""" Initializes the ShiftedSoftplus class. """
super(ShiftedSoftplus, self).__init__()
initial_alpha = torch.tensor(initial_alpha)
initial_beta = torch.tensor(initial_beta)

        if trainable:
            # fill the feature-wise parameters with the initial values;
            # an uninitialized torch.Tensor(num_features) may contain arbitrary values (even NaN)
            self.alpha = torch.nn.Parameter(initial_alpha * torch.ones(num_features))
            self.beta = torch.nn.Parameter(initial_beta * torch.ones(num_features))
        else:
            self.register_buffer("alpha", initial_alpha)
            self.register_buffer("beta", initial_beta)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Evaluate activation function given the input features x.
num_features: Dimensions of feature space.

Arguments:
x (FloatTensor [:, num_features]):
Input features.

Returns:
y (FloatTensor [:, num_features]):
Activated features.
"""
return self.alpha * torch.where(
self.beta != 0,
(torch.nn.functional.softplus(self.beta * x) - math.log(2)) / self.beta,
0.5 * x,
)
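
A minimal usage sketch (not part of this diff; it assumes the module path src/schnetpack/nn/activations.py shown above is importable as schnetpack.nn.activations):

import torch
from schnetpack.nn.activations import ShiftedSoftplus

x = torch.linspace(-2.0, 2.0, steps=5)

# default: fixed scalar alpha = beta = 1, registered as buffers
ssp = ShiftedSoftplus(num_features=5)
print(ssp(x))

# learnable feature-wise parameters, one alpha/beta pair per feature
ssp_trainable = ShiftedSoftplus(num_features=5, trainable=True)
print(sorted(name for name, _ in ssp_trainable.named_parameters()))  # ['alpha', 'beta']

# limiting case beta = 0: f(x) = 0.5 * alpha * x, i.e. the identity for alpha = 2
identity_like = ShiftedSoftplus(num_features=5, initial_alpha=2.0, initial_beta=0.0)
assert torch.allclose(identity_like(x), x)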
146 changes: 146 additions & 0 deletions src/schnetpack/nn/blocks.py
@@ -4,6 +4,7 @@
import torch.nn as nn
import torch.nn.functional as F
import schnetpack.nn as snn
from schnetpack.nn.activations import shifted_softplus

__all__ = ["build_mlp", "build_gated_equivariant_mlp"]

@@ -153,3 +154,148 @@ def build_gated_equivariant_mlp(
# put all layers together to make the network
out_net = nn.Sequential(*layers)
return out_net


class Residual(nn.Module):
"""
Pre-activation residual block inspired by He, Kaiming, et al. "Identity
mappings in deep residual networks.".

Arguments:
num_features (int):
Dimensions of feature space.
        activation (Callable or nn.Module):
            Activation function, e.g. a shifted softplus module.
"""

def __init__(
self,
num_features: int,
activation: Union[Callable, nn.Module] = None,
bias: bool = True,
zero_init: bool = True,
) -> None:
""" Initializes the Residual class. """
super(Residual, self).__init__()
# initialize attributes

        self.activation1 = activation
        self.linear1 = nn.Linear(num_features, num_features, bias=bias)
        self.activation2 = activation
self.linear2 = nn.Linear(num_features, num_features, bias=bias)
self.reset_parameters(bias, zero_init)

def reset_parameters(self, bias: bool = True, zero_init: bool = True) -> None:
""" Initialize parameters to compute an identity mapping. """
nn.init.orthogonal_(self.linear1.weight)
if zero_init:
nn.init.zeros_(self.linear2.weight)
else:
nn.init.orthogonal_(self.linear2.weight)
if bias:
nn.init.zeros_(self.linear1.bias)
nn.init.zeros_(self.linear2.bias)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Apply residual block to input atomic features.
N: Number of atoms.
num_features: Dimensions of feature space.

Arguments:
x (FloatTensor [N, num_features]):
Input feature representations of atoms.

Returns:
y (FloatTensor [N, num_features]):
Output feature representations of atoms.
"""
y = self.activation1(x)
y = self.linear1(y)
y = self.activation2(y)
y = self.linear2(y)
return x + y
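
A quick sanity check (a sketch outside this diff, using the classes introduced above): with zero_init=True the second linear layer starts at zero, so a freshly constructed block realizes the identity mapping promised by reset_parameters.

import torch
from schnetpack.nn.activations import ShiftedSoftplus
from schnetpack.nn.blocks import Residual

block = Residual(num_features=8, activation=ShiftedSoftplus(8), zero_init=True)
x = torch.randn(4, 8)
# linear2 has zero weights and bias, so the residual branch adds nothing at initialization
assert torch.allclose(block(x), x)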


class ResidualStack(nn.Module):
"""
    Stack of num_residual pre-activation residual blocks evaluated in sequence.

    Arguments:
        num_features (int):
            Dimensions of feature space.
        num_residual (int):
            Number of residual blocks to be stacked in sequence.
        activation (Callable or nn.Module):
            Activation function, e.g. a shifted softplus module.
"""

def __init__(
self,
num_features: int,
num_residual: int,
activation: Union[Callable, nn.Module],
bias: bool = True,
zero_init: bool = True,
) -> None:
""" Initializes the ResidualStack class. """
super(ResidualStack, self).__init__()
self.stack = nn.ModuleList(
[
Residual(num_features, activation, bias, zero_init)
for i in range(num_residual)
]
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Applies all residual blocks to input features in sequence.
N: Number of inputs.
num_features: Dimensions of feature space.

Arguments:
x (FloatTensor [N, num_features]):
Input feature representations.

Returns:
y (FloatTensor [N, num_features]):
Output feature representations.
"""
for residual in self.stack:
x = residual(x)
return x


class ResidualMLP(nn.Module):
    """
    Residual stack followed by an activation and a final linear layer.

    If used with the learnable shifted softplus activation function, the
    activation callable needs to be instantiated with the same num_features.
    """

def __init__(
self,
num_features: int,
num_residual: int,
activation: Union[Callable, nn.Module],
bias: bool = True,
zero_init: bool = False,
) -> None:
super(ResidualMLP, self).__init__()
self.residual = ResidualStack(
num_features, num_residual, activation=activation, bias=bias, zero_init=True
)

self.linear = nn.Linear(num_features, num_features, bias=bias)
self.activation = activation
self.reset_parameters(bias, zero_init)

def reset_parameters(self, bias: bool = True, zero_init: bool = False) -> None:
if zero_init:
nn.init.zeros_(self.linear.weight)
else:
nn.init.orthogonal_(self.linear.weight)
if bias:
nn.init.zeros_(self.linear.bias)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(self.activation(self.residual(x)))
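
As the ResidualMLP docstring notes, the learnable shifted softplus carries feature-wise parameters, so the activation callable has to be constructed with the same num_features as the MLP. A hedged usage sketch (sizes are arbitrary):

import torch
from schnetpack.nn.activations import ShiftedSoftplus
from schnetpack.nn.blocks import ResidualMLP

num_features = 16
activation = ShiftedSoftplus(num_features, trainable=True)  # feature-wise alpha/beta
mlp = ResidualMLP(num_features, num_residual=2, activation=activation)

x = torch.randn(10, num_features)
print(mlp(x).shape)  # torch.Size([10, 16])

Note that, as written, the same activation instance (and hence its learnable parameters) is shared by every residual block and the final activation; a factory or per-block copies would be needed for independent parameters.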
124 changes: 124 additions & 0 deletions src/schnetpack/nn/electronic_embeeding.py
@@ -0,0 +1,124 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Callable, Optional, Union

from schnetpack.nn.activations import shifted_softplus
from schnetpack.nn.blocks import ResidualMLP

class ElectronicEmbedding(nn.Module):
"""
    Single-head self-attention-like block for updating atomic features through
    nonlocal interactions with the electrons.
    [Review comment (Collaborator): could you add 2 more sentences what happens in that class?]
    [Reply (Author): done in 2d23890]

    The embeddings map the total molecular charge or the molecular spin to a feature vector.
    Since these properties are not localized on a specific atom, they have to be delocalized
    over the whole molecule. The delocalization is achieved with a self-attention-like mechanism.


Arguments:
        num_features (int):
            Dimensions of feature space, i.e. the number of features used to describe
            atomic environments; determines the size of each embedding vector.
        num_residual (int):
            Number of residual blocks applied to the atomic features.
        activation (Callable or nn.Module):
            Activation function, e.g. a shifted softplus module.
        is_charged (bool):
            True: the embedding describes the total molecular charge and separate
            weights are used for positive and negative charges.
            False: the embedding describes the spin values and no separate
            weights are used.
"""

def __init__(
self,
num_features: int,
num_residual: int,
        activation: Union[Callable, nn.Module] = shifted_softplus,
is_charged: bool = False,
) -> None:
""" Initializes the ElectronicEmbedding class. """
super(ElectronicEmbedding, self).__init__()
self.is_charged = is_charged
self.linear_q = nn.Linear(num_features, num_features)
if is_charged: # charges are duplicated to use separate weights for +/-
self.linear_k = nn.Linear(2, num_features, bias=False)
self.linear_v = nn.Linear(2, num_features, bias=False)
else:
self.linear_k = nn.Linear(1, num_features, bias=False)
self.linear_v = nn.Linear(1, num_features, bias=False)
self.resblock = ResidualMLP(
num_features,
num_residual,
activation=activation,
zero_init=True,
bias=False,
)
self.reset_parameters()

def reset_parameters(self) -> None:
""" Initialize parameters. """
nn.init.orthogonal_(self.linear_k.weight)
nn.init.orthogonal_(self.linear_v.weight)
nn.init.orthogonal_(self.linear_q.weight)
nn.init.zeros_(self.linear_q.bias)

def forward(
self,
atomic_features: torch.Tensor,
electronic_feature: torch.Tensor,
num_batch: int,
batch_seg: torch.Tensor,
eps: float = 1e-8,
) -> torch.Tensor:
"""
Evaluate interaction block.

atomic_features (FloatTensor [N, num_features]):
Atomic feature vectors.
        electronic_feature (FloatTensor [num_batch]):
            total charge or spin value per molecular graph
        num_batch (int):
            number of molecular graphs in the batch
        batch_seg (LongTensor [N]):
            segment ids (aka _idx_m) used to assign each atom to its molecule in the batch
eps (float):
small number to avoid division by zero
"""

# queries (Batchsize x N_atoms, n_atom_basis)
q = self.linear_q(atomic_features)

# to account for negative and positive charge
if self.is_charged:
e = F.relu(torch.stack([electronic_feature, -electronic_feature], dim=-1))
# +/- spin is the same => abs
else:
e = torch.abs(electronic_feature).unsqueeze(-1)
enorm = torch.maximum(e, torch.ones_like(e))

# keys (Batchsize x N_atoms, n_atom_basis), the batch_seg ensures that the key is the same for all atoms belonging to the same graph
k = self.linear_k(e / enorm)[batch_seg]

# values (Batchsize x N_atoms, n_atom_basis) the batch_seg ensures that the value is the same for all atoms belonging to the same graph
v = self.linear_v(e)[batch_seg]

# unnormalized, scaled attention weights, obtained by dot product of queries and keys (are logits)
# scaling by square root of attention dimension
weights = torch.sum(k * q, dim=-1) / k.shape[-1] ** 0.5

        # non-negative attention weights: softplus is applied to the logits instead of the
        # exponential of a standard softmax; the per-graph normalization below completes the distribution
a = nn.functional.softplus(weights)

# normalization factor for every molecular graph, by adding up attention weights of every atom in the graph
anorm = a.new_zeros(num_batch).index_add_(0, batch_seg, a)

# make tensor filled with anorm value at the position of the corresponding molecular graph,
# indexing faster on CPU, gather faster on GPU
if a.device.type == "cpu":
anorm = anorm[batch_seg]
else:
anorm = torch.gather(anorm, 0, batch_seg)

        # weight the values with the normalized attention and pass the result through the
        # residual MLP; eps is added for numerical stability (avoids division by zero)
return self.resblock((a / (anorm + eps)).unsqueeze(-1) * v)
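
A hypothetical two-molecule batch to illustrate the calling convention (module and file names as in this diff; charges, segment ids, and sizes are made up for illustration):

import torch
from schnetpack.nn.activations import ShiftedSoftplus
from schnetpack.nn.electronic_embeeding import ElectronicEmbedding

num_features = 16
embedding = ElectronicEmbedding(
    num_features,
    num_residual=1,
    activation=ShiftedSoftplus(num_features, trainable=True),
    is_charged=True,
)

# molecule 0 has 3 atoms and total charge +1, molecule 1 has 2 atoms and total charge -1
atomic_features = torch.randn(5, num_features)
total_charge = torch.tensor([1.0, -1.0])   # one value per molecular graph
batch_seg = torch.tensor([0, 0, 0, 1, 1])  # molecule index (idx_m) of every atom

delta = embedding(atomic_features, total_charge, num_batch=2, batch_seg=batch_seg)
print(delta.shape)  # torch.Size([5, 16]): one update vector per atom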