-
Notifications
You must be signed in to change notification settings - Fork 215
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #448 from atomistic-machine-learning/kts/so3repres…
…entation SO3 representation
- Loading branch information
Showing
13 changed files
with
189 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
ema: | ||
_target_: schnetpack.train.ExponentialMovingAverage | ||
decay: 0.995 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
# TensorBoard | ||
# Aim Logger | ||
|
||
aim: | ||
_target_: aim.pytorch_lightning.AimLogger | ||
repo: ${hydra:runtime.cwd}/${run.path} | ||
experiment: ${run.id} | ||
experiment: ${run.experiment} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
_target_: schnetpack.representation.SO3net | ||
_recursive_: False | ||
n_atom_basis: 64 | ||
n_interactions: 3 | ||
lmax: 2 | ||
shared_interactions: False | ||
radial_basis: | ||
_target_: schnetpack.nn.radial.GaussianRBF | ||
n_rbf: 20 | ||
cutoff: ${globals.cutoff} | ||
cutoff_fn: | ||
_target_: schnetpack.nn.cutoff.CosineCutoff | ||
cutoff: ${globals.cutoff} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
work_dir: ${hydra:runtime.cwd} | ||
data_dir: ${run.work_dir}/data | ||
path: runs | ||
experiment: default | ||
id: ${uuid:1} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .schnet import * | ||
from .painn import * | ||
from .field_schnet import * | ||
from .so3net import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
from typing import Callable, Dict, Optional | ||
|
||
import hydra | ||
import torch | ||
import torch.nn as nn | ||
|
||
import schnetpack.nn as snn | ||
import schnetpack.nn.so3 as so3 | ||
import schnetpack.properties as properties | ||
|
||
__all__ = ["SO3net"] | ||
|
||
|
||
class SO3net(nn.Module): | ||
""" | ||
A simple SO3-equivariant representation using spherical harmonics and | ||
Clebsch-Gordon tensor products. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
n_atom_basis: int, | ||
n_interactions: int, | ||
lmax: int, | ||
radial_basis: nn.Module, | ||
cutoff_fn: Optional[Callable] = None, | ||
shared_interactions: bool = False, | ||
max_z: int = 100, | ||
): | ||
""" | ||
Args: | ||
n_atom_basis: number of features to describe atomic environments. | ||
This determines the size of each embedding vector; i.e. embeddings_dim. | ||
n_interactions: number of interaction blocks. | ||
lmax: maximum angular momentum of spherical harmonics basis | ||
radial_basis: layer for expanding interatomic distances in a basis set | ||
cutoff_fn: cutoff function | ||
shared_interactions: | ||
max_z: | ||
conv_layer: | ||
""" | ||
super(SO3net, self).__init__() | ||
|
||
self.n_atom_basis = n_atom_basis | ||
self.n_interactions = n_interactions | ||
self.lmax = lmax | ||
self.cutoff_fn = hydra.utils.instantiate(cutoff_fn) | ||
self.cutoff = cutoff_fn.cutoff | ||
self.radial_basis = hydra.utils.instantiate(radial_basis) | ||
|
||
self.embedding = nn.Embedding(max_z, n_atom_basis, padding_idx=0) | ||
self.sphharm = so3.RealSphericalHarmonics(lmax=lmax) | ||
|
||
self.so3convs = snn.replicate_module( | ||
lambda: so3.SO3Convolution(lmax, n_atom_basis, self.radial_basis.n_rbf), | ||
self.n_interactions, | ||
shared_interactions, | ||
) | ||
self.mixings = snn.replicate_module( | ||
lambda: nn.Linear(n_atom_basis, n_atom_basis, bias=False), | ||
self.n_interactions, | ||
shared_interactions, | ||
) | ||
self.gatings = snn.replicate_module( | ||
lambda: so3.SO3ParametricGatedNonlinearity(n_atom_basis, lmax), | ||
self.n_interactions, | ||
shared_interactions, | ||
) | ||
self.so3product = so3.SO3TensorProduct(lmax) | ||
|
||
def forward(self, inputs: Dict[str, torch.Tensor]): | ||
""" | ||
Compute atomic representations/embeddings. | ||
Args: | ||
inputs (dict of torch.Tensor): SchNetPack dictionary of input tensors. | ||
Returns: | ||
torch.Tensor: atom-wise representation. | ||
list of torch.Tensor: intermediate atom-wise representations, if | ||
return_intermediate=True was used. | ||
""" | ||
# get tensors from input dictionary | ||
atomic_numbers = inputs[properties.Z] | ||
r_ij = inputs[properties.Rij] | ||
idx_i = inputs[properties.idx_i] | ||
idx_j = inputs[properties.idx_j] | ||
|
||
# compute atom and pair features | ||
d_ij = torch.norm(r_ij, dim=1, keepdim=True) | ||
dir_ij = r_ij / d_ij | ||
|
||
Yij = self.sphharm(dir_ij) | ||
radial_ij = self.radial_basis(d_ij) | ||
cutoff_ij = self.cutoff_fn(d_ij)[..., None] | ||
|
||
x0 = self.embedding(atomic_numbers)[:, None] | ||
x = so3.scalar2rsh(x0, self.lmax) | ||
|
||
for i in range(self.n_interactions): | ||
dx = self.so3convs[i](x, radial_ij, Yij, cutoff_ij, idx_i, idx_j) | ||
ddx = self.mixings[i](dx) | ||
dx = self.so3product(dx, ddx) | ||
dx = self.gatings[i](dx) | ||
x = x + dx | ||
|
||
inputs["scalar_representation"] = x[:, 0] | ||
inputs["multipole_representation"] = x | ||
return inputs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters