Merge pull request #448 from atomistic-machine-learning/kts/so3representation

SO3 representation
mgastegger authored Oct 25, 2022
2 parents f1c3122 + 75c48a1 commit dac7ee9
Showing 13 changed files with 189 additions and 13 deletions.
1 change: 1 addition & 0 deletions setup.py
@@ -36,6 +36,7 @@ def read(fname):
         "rich",
         "fasteners",
         "dirsync",
+        "torch-ema",
         "matscipy @ git+https://github.com/libAtoms/matscipy.git",
     ],
     include_package_data=True,
3 changes: 3 additions & 0 deletions src/schnetpack/configs/callbacks/ema.yaml
@@ -0,0 +1,3 @@
+ema:
+  _target_: schnetpack.train.ExponentialMovingAverage
+  decay: 0.995
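
This is a new Hydra callback group entry; train.yaml below adds it to the callbacks defaults. As a rough sketch (assuming Hydra and SchNetPack are importable), such an entry is instantiated like this:

import hydra
from omegaconf import OmegaConf

# sketch: a callbacks group entry maps a name to a _target_ class plus kwargs
cfg = OmegaConf.create(
    {"ema": {"_target_": "schnetpack.train.ExponentialMovingAverage", "decay": 0.995}}
)
callbacks = [hydra.utils.instantiate(c) for c in cfg.values()]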
15 changes: 9 additions & 6 deletions src/schnetpack/configs/experiment/md17.yaml
@@ -4,11 +4,12 @@ defaults:
   - override /model: nnp
   - override /data: md17
 
-run.path: runs/md17_${data.molecule}
+run:
+  experiment: md17_${data.molecule}
 
 globals:
   cutoff: 5.
-  lr: 5e-4
+  lr: 1e-3
   energy_key: energy
   forces_key: forces
@@ -50,16 +51,18 @@ task:
       metrics:
         mae:
           _target_: torchmetrics.regression.MeanAbsoluteError
-        mse:
+        rmse:
           _target_: torchmetrics.regression.MeanSquaredError
-      loss_weight: 0.005
+          squared: False
+      loss_weight: 0.01
     - _target_: schnetpack.task.ModelOutput
       name: ${globals.forces_key}
       loss_fn:
        _target_: torch.nn.MSELoss
      metrics:
        mae:
          _target_: torchmetrics.regression.MeanAbsoluteError
-        mse:
+        rmse:
          _target_: torchmetrics.regression.MeanSquaredError
-      loss_weight: 0.995
+          squared: False
+      loss_weight: 0.99
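
The mse → rmse renaming pairs with the new squared: False flag: torchmetrics' MeanSquaredError then reports the root of the mean squared error. A minimal sketch of that behavior (assuming torchmetrics is installed):

import torch
from torchmetrics.regression import MeanSquaredError

rmse = MeanSquaredError(squared=False)  # squared=False -> RMSE instead of MSE
preds = torch.tensor([2.0, 4.0])
target = torch.tensor([1.0, 1.0])
print(rmse(preds, target))  # sqrt(((2-1)^2 + (4-1)^2) / 2) = sqrt(5.0) ≈ 2.2361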
6 changes: 4 additions & 2 deletions src/schnetpack/configs/experiment/qm9_energy.yaml
@@ -4,7 +4,8 @@ defaults:
   - override /model: nnp
   - override /data: qm9
 
-run.path: runs/qm9_${globals.property}
+run:
+  experiment: qm9_${globals.property}
 
 globals:
   cutoff: 5.
@@ -44,6 +45,7 @@ task:
       metrics:
         mae:
           _target_: torchmetrics.regression.MeanAbsoluteError
-        mse:
+        rmse:
           _target_: torchmetrics.regression.MeanSquaredError
+          squared: False
       loss_weight: 1.
4 changes: 2 additions & 2 deletions src/schnetpack/configs/logger/aim.yaml
@@ -1,6 +1,6 @@
-# TensorBoard
+# Aim Logger
 
 aim:
   _target_: aim.pytorch_lightning.AimLogger
   repo: ${hydra:runtime.cwd}/${run.path}
-  experiment: ${run.id}
+  experiment: ${run.experiment}
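
For reference, a sketch of the logger this config resolves to once run.path and run.experiment are filled in (the concrete values below are hypothetical):

from aim.pytorch_lightning import AimLogger

# hypothetical resolved values: repo from ${hydra:runtime.cwd}/${run.path},
# experiment from ${run.experiment}
logger = AimLogger(repo="/home/user/runs", experiment="md17_ethanol")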
13 changes: 13 additions & 0 deletions src/schnetpack/configs/model/representation/so3net.yaml
@@ -0,0 +1,13 @@
+_target_: schnetpack.representation.SO3net
+_recursive_: False
+n_atom_basis: 64
+n_interactions: 3
+lmax: 2
+shared_interactions: False
+radial_basis:
+  _target_: schnetpack.nn.radial.GaussianRBF
+  n_rbf: 20
+  cutoff: ${globals.cutoff}
+cutoff_fn:
+  _target_: schnetpack.nn.cutoff.CosineCutoff
+  cutoff: ${globals.cutoff}
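
A sketch of building the representation from an equivalent config outside a full experiment, with ${globals.cutoff} resolved by hand to 5.0; because of _recursive_: False, SO3net receives radial_basis and cutoff_fn as configs and instantiates them itself (see so3net.py below):

import hydra
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "_target_": "schnetpack.representation.SO3net",
        "_recursive_": False,
        "n_atom_basis": 64,
        "n_interactions": 3,
        "lmax": 2,
        "shared_interactions": False,
        "radial_basis": {
            "_target_": "schnetpack.nn.radial.GaussianRBF",
            "n_rbf": 20,
            "cutoff": 5.0,  # ${globals.cutoff} resolved by hand
        },
        "cutoff_fn": {
            "_target_": "schnetpack.nn.cutoff.CosineCutoff",
            "cutoff": 5.0,
        },
    }
)
so3net = hydra.utils.instantiate(cfg)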
1 change: 1 addition & 0 deletions src/schnetpack/configs/run/default_run.yaml
@@ -1,4 +1,5 @@
 work_dir: ${hydra:runtime.cwd}
 data_dir: ${run.work_dir}/data
 path: runs
+experiment: default
 id: ${uuid:1}
@@ -4,7 +4,7 @@ scheduler_monitor: val_loss
 scheduler_args:
   mode: min
   factor: 0.8
-  patience: 80
+  patience: 50
   threshold: 1e-4
   threshold_mode: rel
   cooldown: 10
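
The argument names match torch.optim.lr_scheduler.ReduceLROnPlateau, so the new patience value plausibly maps onto a construction like the sketch below (model and optimizer are placeholders; the scheduler class is inferred from the argument names, not shown in this diff):

import torch

model = torch.nn.Linear(8, 1)  # placeholder module
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.8,
    patience=50,
    threshold=1e-4,
    threshold_mode="rel",
    cooldown=10,
)
# scheduler.step(val_loss) would then be called once per validation epoch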
1 change: 1 addition & 0 deletions src/schnetpack/configs/train.yaml
@@ -8,6 +8,7 @@ defaults:
   - checkpoint
   - earlystopping
   - lrmonitor
+  - ema
   - task: default_task
   - model: null
   - data: custom
2 changes: 1 addition & 1 deletion src/schnetpack/data/splitting.py
@@ -99,7 +99,7 @@ def split(self, dataset, *split_sizes) -> List[torch.tensor]:
 class SubsamplePartitions(SplittingStrategy):
     """
     Strategy that splits the atoms dataset into predefined partitions as defined in the
-    metadata. If the split size is smaller then the predefined partition, a given
+    metadata. If the split size is smaller than the predefined partition, a given
     strategy will be used to subsample the partition (default: random).
     The metadata in the atoms dataset might look like this:
1 change: 1 addition & 0 deletions src/schnetpack/representation/__init__.py
@@ -1,3 +1,4 @@
 from .schnet import *
 from .painn import *
 from .field_schnet import *
+from .so3net import *
110 changes: 110 additions & 0 deletions src/schnetpack/representation/so3net.py
@@ -0,0 +1,110 @@
+from typing import Callable, Dict, Optional
+
+import hydra
+import torch
+import torch.nn as nn
+
+import schnetpack.nn as snn
+import schnetpack.nn.so3 as so3
+import schnetpack.properties as properties
+
+__all__ = ["SO3net"]
+
+
+class SO3net(nn.Module):
+    """
+    A simple SO3-equivariant representation using spherical harmonics and
+    Clebsch-Gordan tensor products.
+    """
+
+    def __init__(
+        self,
+        n_atom_basis: int,
+        n_interactions: int,
+        lmax: int,
+        radial_basis: nn.Module,
+        cutoff_fn: Optional[Callable] = None,
+        shared_interactions: bool = False,
+        max_z: int = 100,
+    ):
+        """
+        Args:
+            n_atom_basis: number of features to describe atomic environments.
+                This determines the size of each embedding vector; i.e. embeddings_dim.
+            n_interactions: number of interaction blocks.
+            lmax: maximum angular momentum of the spherical harmonics basis.
+            radial_basis: layer for expanding interatomic distances in a basis set.
+            cutoff_fn: cutoff function.
+            shared_interactions: if True, share the weights across all
+                interaction blocks.
+            max_z: maximal nuclear charge supported by the element embedding.
+        """
+        super(SO3net, self).__init__()
+
+        self.n_atom_basis = n_atom_basis
+        self.n_interactions = n_interactions
+        self.lmax = lmax
+        self.cutoff_fn = hydra.utils.instantiate(cutoff_fn)
+        self.cutoff = cutoff_fn.cutoff
+        self.radial_basis = hydra.utils.instantiate(radial_basis)
+
+        self.embedding = nn.Embedding(max_z, n_atom_basis, padding_idx=0)
+        self.sphharm = so3.RealSphericalHarmonics(lmax=lmax)
+
+        self.so3convs = snn.replicate_module(
+            lambda: so3.SO3Convolution(lmax, n_atom_basis, self.radial_basis.n_rbf),
+            self.n_interactions,
+            shared_interactions,
+        )
+        self.mixings = snn.replicate_module(
+            lambda: nn.Linear(n_atom_basis, n_atom_basis, bias=False),
+            self.n_interactions,
+            shared_interactions,
+        )
+        self.gatings = snn.replicate_module(
+            lambda: so3.SO3ParametricGatedNonlinearity(n_atom_basis, lmax),
+            self.n_interactions,
+            shared_interactions,
+        )
+        self.so3product = so3.SO3TensorProduct(lmax)
+
+    def forward(self, inputs: Dict[str, torch.Tensor]):
+        """
+        Compute atomic representations/embeddings.
+
+        Args:
+            inputs (dict of torch.Tensor): SchNetPack dictionary of input tensors.
+
+        Returns:
+            dict of torch.Tensor: the input dictionary, extended with the
+                atom-wise "scalar_representation" and the equivariant
+                "multipole_representation".
+        """
+        # get tensors from input dictionary
+        atomic_numbers = inputs[properties.Z]
+        r_ij = inputs[properties.Rij]
+        idx_i = inputs[properties.idx_i]
+        idx_j = inputs[properties.idx_j]
+
+        # compute atom and pair features
+        d_ij = torch.norm(r_ij, dim=1, keepdim=True)
+        dir_ij = r_ij / d_ij
+
+        Yij = self.sphharm(dir_ij)
+        radial_ij = self.radial_basis(d_ij)
+        cutoff_ij = self.cutoff_fn(d_ij)[..., None]
+
+        x0 = self.embedding(atomic_numbers)[:, None]
+        x = so3.scalar2rsh(x0, self.lmax)
+
+        for i in range(self.n_interactions):
+            dx = self.so3convs[i](x, radial_ij, Yij, cutoff_ij, idx_i, idx_j)
+            ddx = self.mixings[i](dx)
+            dx = self.so3product(dx, ddx)
+            dx = self.gatings[i](dx)
+            x = x + dx
+
+        inputs["scalar_representation"] = x[:, 0]
+        inputs["multipole_representation"] = x
+        return inputs
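
A minimal forward-pass sketch for a hypothetical two-atom system with a precomputed neighbor list, using the so3net instance built from the config sketch above:

import torch
import schnetpack.properties as properties

inputs = {
    properties.Z: torch.tensor([8, 1]),  # O and H
    properties.Rij: torch.tensor([[0.0, 0.0, 0.96], [0.0, 0.0, -0.96]]),
    properties.idx_i: torch.tensor([0, 1]),  # neighbor-list center indices
    properties.idx_j: torch.tensor([1, 0]),  # neighbor-list neighbor indices
}
outputs = so3net(inputs)
print(outputs["scalar_representation"].shape)     # torch.Size([2, 64])
print(outputs["multipole_representation"].shape)  # torch.Size([2, 9, 64]); 9 = (lmax + 1)**2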
43 changes: 42 additions & 1 deletion src/schnetpack/train/callbacks.py
@@ -1,15 +1,18 @@
 from copy import copy
 from typing import Dict
 
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.callbacks import ModelCheckpoint as BaseModelCheckpoint
 
+from torch_ema import ExponentialMovingAverage as EMA
+
 import torch
 import os
 from pytorch_lightning.callbacks import BasePredictionWriter
 from typing import List, Any
 from schnetpack.task import AtomisticTask
 
-__all__ = ["ModelCheckpoint", "PredictionWriter"]
+__all__ = ["ModelCheckpoint", "PredictionWriter", "ExponentialMovingAverage"]
 
 
 class PredictionWriter(BasePredictionWriter):
@@ -81,3 +84,41 @@ def _update_best_and_save(
         if self.trainer.strategy.local_rank == 0:
             # remove references to trainer and data loaders to avoid pickle error in ddp
             self.task.save_model(self.model_path, do_postprocessing=True)
+
+
+class ExponentialMovingAverage(Callback):
+    def __init__(self, decay, *args, **kwargs):
+        self.decay = decay
+        self.ema = None
+        self._to_load = None
+
+    def on_fit_start(self, trainer, pl_module: AtomisticTask):
+        if self.ema is None:
+            self.ema = EMA(pl_module.model.parameters(), decay=self.decay)
+        if self._to_load is not None:
+            self.ema.load_state_dict(self._to_load)
+            self._to_load = None
+
+    def on_train_batch_end(self, trainer, pl_module: AtomisticTask, *args, **kwargs):
+        self.ema.update()
+
+    def on_validation_start(
+        self, trainer: "pl.Trainer", pl_module: AtomisticTask, *args, **kwargs
+    ):
+        self.ema.store()
+        self.ema.copy_to()
+
+    def on_validation_end(
+        self, trainer: "pl.Trainer", pl_module: AtomisticTask, *args, **kwargs
+    ):
+        self.ema.restore()
+
+    def load_state_dict(self, state_dict):
+        if "ema" in state_dict:
+            if self.ema is None:
+                self._to_load = state_dict["ema"]
+            else:
+                self.ema.load_state_dict(state_dict["ema"])
+
+    def state_dict(self):
+        return {"ema": self.ema.state_dict()}
