Postprocessors and TorchScript #285

Merged · 3 commits · May 11, 2021
1 change: 0 additions & 1 deletion src/schnetpack/__init__.py
@@ -9,4 +9,3 @@
 from schnetpack import nn
 from schnetpack import units
 from schnetpack import atomistic
-from schnetpack import transforms
21 changes: 8 additions & 13 deletions src/schnetpack/atomistic/atomwise.py
@@ -17,16 +17,14 @@ def __init__(
         n_in: int,
         n_out: int = 1,
         aggregation_mode: str = "sum",
-        return_contributions: bool = False,
         custom_outnet: Callable = None,
-        outnet_inputs: Union[str, Sequence[str]] = "scalar_representation",
+        outnet_input: Union[str, Sequence[str]] = "scalar_representation",
     ):
         """
         Args:
             n_in: input dimension of representation
             n_out: output dimension of target property (default: 1)
             aggregation_mode: one of {sum, avg} (default: sum)
-            return_contributions: If true, returns also atomwise contributions.
             custom_outnet: Network used for atomistic outputs. Takes schnetpack input
                 dictionary as input. Output is not normalized. If set to None,
                 a pyramidal network is generated automatically.
@@ -39,24 +37,21 @@
         self.outnet = custom_outnet or spk.nn.MLP(
             n_in=n_in, n_out=n_out, n_layers=2, activation=spk.nn.shifted_softplus
         )
-        self.outnet_inputs = (
-            [outnet_inputs] if isinstance(outnet_inputs, str) else outnet_inputs
-        )
+        self.outnet_input = outnet_input

         self.aggregation_mode = aggregation_mode
-        self.return_contributions = return_contributions

     def forward(self, inputs: Dict[str, torch.Tensor]):
         # predict atomwise contributions
-        outins = [inputs[k] for k in self.outnet_inputs]
-        yi = self.outnet(*outins)
+        yi = self.outnet(inputs[self.outnet_input])

         if self.aggregation_mode == "avg":
             yi = yi / inputs[structure.n_atoms][:, None]

         # aggregate
         idx_m = inputs[structure.idx_m]
-        tmp = torch.zeros((idx_m[-1] + 1, self.n_out), dtype=yi.dtype, device=yi.device)
+        maxm = int(idx_m[-1]) + 1
+        tmp = torch.zeros((maxm, self.n_out), dtype=yi.dtype, device=yi.device)
         y = tmp.index_add(0, idx_m, yi)
         y = torch.squeeze(y, -1)

-        if self.return_contributions:
-            return y, yi
         return y
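
The TorchScript-relevant change above is `maxm = int(idx_m[-1]) + 1`: indexing with `[-1]` yields a 0-d tensor, and TorchScript will not implicitly convert it to the plain int that `torch.zeros` expects for a shape. A minimal scriptable sketch of the same scatter-sum aggregation (toy shapes for illustration, not the schnetpack module itself):

import torch

@torch.jit.script
def aggregate_atomwise(yi: torch.Tensor, idx_m: torch.Tensor) -> torch.Tensor:
    """Sum per-atom predictions yi (n_atoms, n_out) into per-structure totals,
    where idx_m maps each atom to its structure index."""
    # explicit int() cast: a 0-d tensor is not accepted as a shape under TorchScript
    maxm = int(idx_m[-1]) + 1
    tmp = torch.zeros((maxm, yi.shape[1]), dtype=yi.dtype, device=yi.device)
    # index_add scatters row i of yi into row idx_m[i] of tmp
    return tmp.index_add(0, idx_m, yi)

# two structures: atoms 0-2 belong to structure 0, atoms 3-4 to structure 1
yi = torch.ones(5, 1)
idx_m = torch.tensor([0, 0, 0, 1, 1])
print(aggregate_atomwise(yi, idx_m))  # tensor([[3.], [2.]])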
17 changes: 12 additions & 5 deletions src/schnetpack/configs/experiment/md17.yaml
@@ -10,19 +10,26 @@ n_rbf: 20
 lr: 1e-4

 name: md17_${data.molecule}
+seed: 1234

 data:
   distance_unit: Ang
   property_units:
     energy: kcal/mol
     forces: kcal/mol/Ang
   transforms:
-    - _target_: schnetpack.transforms.SubtractCenterOfMass
-    - _target_: schnetpack.transforms.RemoveOffsets
+    - _target_: schnetpack.transform.SubtractCenterOfMass
+    - _target_: schnetpack.transform.RemoveOffsets
      property: energy
      remove_mean: True
-    - _target_: schnetpack.transforms.TorchNeighborList
+    - _target_: schnetpack.transform.TorchNeighborList
      cutoff: ${cutoff}
-    - _target_: schnetpack.transforms.CastTo32
+    - _target_: schnetpack.transform.CastTo32


+model:
+  postprocess:
+    - _target_: schnetpack.transform.CastTo64
+    - _target_: schnetpack.transform.AddOffsets
+      property: energy
+      add_mean: True
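
The new `model.postprocess` section mirrors the preprocessing transforms: `RemoveOffsets` subtracts the mean energy when the data is loaded, while `CastTo64` and `AddOffsets` restore double precision and the offset at inference time. As a sketch of the mechanics (a standalone illustration; in this PR the equivalent happens in `build_postprocess` in model/base.py below), each list entry is an ordinary hydra `_target_` block:

import hydra.utils
from omegaconf import OmegaConf

# the model.postprocess list from this config, as it would look once loaded
cfg = OmegaConf.create(
    """
    postprocess:
      - _target_: schnetpack.transform.CastTo64
      - _target_: schnetpack.transform.AddOffsets
        property: energy
        add_mean: True
    """
)

# hydra instantiates the class named by _target_, passing the remaining
# keys as constructor arguments
postprocessors = [hydra.utils.instantiate(pp) for pp in cfg.postprocess]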

8 changes: 4 additions & 4 deletions src/schnetpack/configs/experiment/qm9.yaml
@@ -20,12 +20,12 @@ data:
   property_units:
     energy_U0: eV
   transforms:
-    - _target_: schnetpack.transforms.SubtractCenterOfMass
-    - _target_: schnetpack.transforms.RemoveOffsets
+    - _target_: schnetpack.transform.SubtractCenterOfMass
+    - _target_: schnetpack.transform.RemoveOffsets
      property: ${property}
      remove_atomrefs: True
      remove_mean: True
-    - _target_: schnetpack.transforms.TorchNeighborList
+    - _target_: schnetpack.transform.TorchNeighborList
      cutoff: ${cutoff}
-    - _target_: schnetpack.transforms.CastTo32
+    - _target_: schnetpack.transform.CastTo32

3 changes: 2 additions & 1 deletion src/schnetpack/configs/model/pes.yaml
@@ -16,7 +16,7 @@ schedule:
   factor: 0.5
   patience: 50
   min_lr: 1e-6
-  smoothing_factor: 0.9
+  smoothing_factor: 0.8

 output:
   module:
@@ -52,3 +52,4 @@ output:
     mse:
       _target_: pytorch_lightning.metrics.MeanSquaredError

+postprocess: null
5 changes: 1 addition & 4 deletions src/schnetpack/data/atoms.py
@@ -23,7 +23,7 @@

 import schnetpack as spk
 import schnetpack.structure as structure
-from schnetpack.transforms import Transform
+from schnetpack.transform import Transform

 logger = logging.getLogger(__name__)

@@ -86,9 +86,6 @@ def transforms(self, value: Optional[List[Transform]]):

         if value is not None:
             for tf in value:
-                if tf.data is not None:
-                    tf = copy.copy(tf)
-                tf.data = self
                 self._transforms.append(tf)
         self._transform_module = torch.nn.Sequential(*self._transforms)
19 changes: 8 additions & 11 deletions src/schnetpack/data/datamodule.py
@@ -76,9 +76,9 @@ def __init__(
             val_transforms=val_transforms or copy(transforms) or [],
             test_transforms=test_transforms or copy(transforms) or [],
         )
-        self._check_transforms(self.train_transforms)
-        self._check_transforms(self.val_transforms)
-        self._check_transforms(self.test_transforms)
+        self._init_transforms(self.train_transforms)
+        self._init_transforms(self.val_transforms)
+        self._init_transforms(self.test_transforms)

         self.batch_size = batch_size
         self.val_batch_size = val_batch_size or test_batch_size or batch_size
@@ -96,12 +96,9 @@ def __init__(
         self.distance_unit = distance_unit
         self._stats = {}

-    def _check_transforms(self, transforms):
+    def _init_transforms(self, transforms):
         for t in transforms:
-            if not t.is_preprocessor:
-                raise AtomsDataModuleError(
-                    f"Transform of type {t} is not a preprocessor (is_preprocessor=False)!"
-                )
+            t.preprocessor()

     def setup(self, stage: Optional[str] = None):
         self.load_data()
@@ -162,11 +159,11 @@ def partition(self):
     def setup_transforms(self):
         # setup transforms
         for t in self.train_transforms:
-            t.datamodule = self
+            t.datamodule(self)
         for t in self.val_transforms:
-            t.datamodule = self
+            t.datamodule(self)
         for t in self.test_transforms:
-            t.datamodule = self
+            t.datamodule(self)
         self._train_dataset.transforms = self.train_transforms
         self._val_dataset.transforms = self.val_transforms
         self._test_dataset.transforms = self.test_transforms
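
Replacing the attribute assignments (`t.datamodule = self`) with method calls (`t.datamodule(self)`) suggests the transforms now expose setters that can be excluded from scripting, instead of free-form attributes that TorchScript would have to type. A rough sketch of a base class matching the calls in this diff (an assumption inferred from `t.preprocessor()` and `t.datamodule(self)`, not the actual schnetpack.transform.Transform source):

import torch

class Transform(torch.nn.Module):
    # assumed flags; subclasses declare where they may run
    is_preprocessor: bool = False
    is_postprocessor: bool = False

    def preprocessor(self):
        # called once when the transform is used before batching;
        # replaces the old _check_transforms validation
        if not self.is_preprocessor:
            raise ValueError(f"{self} cannot be used as a preprocessor.")
        self.mode = "pre"

    def postprocessor(self):
        if not self.is_postprocessor:
            raise ValueError(f"{self} cannot be used as a postprocessor.")
        self.mode = "post"

    @torch.jit.unused
    def datamodule(self, value):
        # hook for transforms that need dataset statistics (e.g. offsets);
        # @torch.jit.unused keeps the datamodule reference out of the
        # scripted module
        pass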
46 changes: 41 additions & 5 deletions src/schnetpack/model/base.py
@@ -1,5 +1,9 @@
 from abc import abstractmethod
+from pathlib import Path
+from typing import Any, Dict, Optional, Union

+import hydra.utils
 import torch
+from omegaconf import DictConfig
 from pytorch_lightning import LightningModule

@@ -22,20 +26,52 @@ def __init__(
         output: DictConfig,
         schedule: DictConfig,
         optimizer: DictConfig,
+        postprocess: Optional[DictConfig] = None,
     ):
         super().__init__()
-        self.save_hyperparameters("representation", "output", "optimizer", "schedule")
-        self.datamodule = datamodule
+        self.save_hyperparameters(
+            "representation", "output", "optimizer", "schedule", "postprocess"
+        )
         self._representation_cfg = representation
         self._output_cfg = output
         self._schedule_cfg = schedule
         self._optimizer_cfg = optimizer
+        self._postproc_cfg = postprocess or []
+        self.inference_mode = False

         self.build_model()
+        self.build_postprocess(datamodule)

     @abstractmethod
-    def build_model(
-        self,
-    ):
+    def build_model(self):
         """Parse dict configs and instantiate the model"""
         pass

+    def build_postprocess(self, datamodule: spk.data.AtomsDataModule):
+        self.postprocessors = torch.nn.ModuleList()
+        for pp in self._postproc_cfg:
+            pp = hydra.utils.instantiate(pp)
+            pp.postprocessor()
+            pp.datamodule(datamodule)
+            self.postprocessors.append(pp)
+
+    def postprocess(
+        self, inputs: Dict[str, torch.Tensor], results: Dict[str, torch.Tensor]
+    ) -> Dict[str, torch.Tensor]:
+        if self.inference_mode:
+            for pp in self.postprocessors:
+                results = pp(results, inputs)
+        return results
+
+    def to_torchscript(
+        self,
+        file_path: Optional[Union[str, Path]] = None,
+        method: Optional[str] = "script",
+        example_inputs: Optional[Any] = None,
+        **kwargs,
+    ) -> Union[torch.ScriptModule, Dict[str, torch.ScriptModule]]:
+        imode = self.inference_mode
+        self.inference_mode = True
+        script = super().to_torchscript(file_path, method, example_inputs, **kwargs)
+        self.inference_mode = imode
+        return script
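
With this override, postprocessing is baked into the exported module while training-time behavior is unchanged: `inference_mode` is switched on only for the duration of the export. A usage sketch, assuming `model` is a trained instance of this class and `batch` a Dict[str, torch.Tensor] of model inputs:

import torch

# inference_mode is True during export, so the scripted module applies
# the postprocessors (e.g. CastTo64, AddOffsets) on every call
script = model.to_torchscript(method="script")
torch.jit.save(script, "model.pt")

deployed = torch.jit.load("model.pt")
results = deployed(batch)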
30 changes: 20 additions & 10 deletions src/schnetpack/model/pes.py
@@ -3,10 +3,13 @@
 import hydra
 import torch
 from torch.autograd import grad
+from typing import Dict, Optional, List

 from schnetpack import structure
 from schnetpack.model.base import AtomisticModel

+import schnetpack as spk
+
 log = logging.getLogger(__name__)


@@ -15,9 +18,7 @@ class PESModel(AtomisticModel):
     AtomisticModel for potential energy surfaces
     """

-    def build_model(
-        self,
-    ):
+    def build_model(self):
         self.representation = hydra.utils.instantiate(self._representation_cfg)

         self.props = {}
@@ -70,20 +71,26 @@ def _collect_metrics(self):
         self.stress_metrics = torch.nn.ModuleDict(stress_metrics)
         self.metrics["stress"] = stress_metrics

-    def forward(self, inputs):
+    def forward(
+        self,
+        inputs: Dict[str, torch.Tensor],
+    ):
         R = inputs[structure.R]
         inputs[structure.Rij].requires_grad_()
         inputs.update(self.representation(inputs))
         Epred = self.output(inputs)
-        result = {"energy": Epred}
+        results = {"energy": Epred}

         if self.predict_forces:
+            go: List[Optional[torch.Tensor]] = [torch.ones_like(Epred)]
             dEdRij = grad(
-                Epred,
-                inputs[structure.Rij],
-                grad_outputs=torch.ones_like(Epred),
+                [Epred],
+                [inputs[structure.Rij]],
+                grad_outputs=go,
                 create_graph=self.training,
             )[0]
+            if dEdRij is None:
+                dEdRij = torch.zeros_like(inputs[structure.Rij])

             Fpred_i = torch.zeros_like(R)
             Fpred_i = Fpred_i.index_add(
@@ -99,9 +106,11 @@ def forward(self, inputs):
                 dEdRij,
             )
             Fpred = Fpred_i - Fpred_j
-            result["forces"] = Fpred
+            results["forces"] = Fpred

+        results = self.postprocess(inputs, results)

-        return result
+        return results

     def loss_fn(self, pred, batch):
         loss = 0.0
@@ -143,6 +152,7 @@ def validation_step(self, batch, batch_idx):
             on_epoch=True,
             prog_bar=False,
         )
+
         return {"val_loss": loss}

     def test_step(self, batch, batch_idx):
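
The reworked `grad` call is what makes `forward` scriptable: under TorchScript, `torch.autograd.grad` takes lists of tensors, `grad_outputs` must be annotated as List[Optional[torch.Tensor]], and the result is Optional, hence the explicit None check before `dEdRij` is used. A self-contained toy illustration of the same pattern (a hypothetical energy function, not PESModel):

from typing import Dict, List, Optional

import torch
from torch.autograd import grad

@torch.jit.script
def energy_and_forces(r: torch.Tensor) -> Dict[str, torch.Tensor]:
    r.requires_grad_()
    e = (r ** 2).sum()  # toy potential energy
    # TorchScript requires list arguments and an explicitly typed
    # List[Optional[torch.Tensor]] for grad_outputs
    go: List[Optional[torch.Tensor]] = [torch.ones_like(e)]
    dedr = grad([e], [r], grad_outputs=go, create_graph=False)[0]
    # under TorchScript, grad returns Optional[torch.Tensor]
    if dedr is None:
        dedr = torch.zeros_like(r)
    return {"energy": e.unsqueeze(0), "forces": -dedr}

out = energy_and_forces(torch.randn(5, 3))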
13 changes: 13 additions & 0 deletions src/schnetpack/transform/__init__.py
@@ -0,0 +1,13 @@
+"""
+Transforms are applied before and/or after the model. They can be used, e.g., for calculating
+neighbor lists, casting, unit conversion or data augmentation. Some can be applied before batching,
+i.e. to single systems, when loading the data. This is necessary for pre-processing and includes
+neighbor lists, for example. On the other hand, transforms need to be able to handle batches
+for post-processing. The flags `is_postprocessor` and `is_preprocessor` indicate how the transforms
+may be used. The attribute `mode` of a transform is set automatically to either "pre" or "post".
+"""
+
+from .atomistic import *
+from .casting import *
+from .neighborlist import *
+from .transform import *
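
Since this module is new in the PR, a concrete example of the two flags may help. A hedged sketch of a casting transform consistent with the docstring above and with the `pp(results, inputs)` call in model/base.py (the real schnetpack.transform.CastTo64 may be implemented differently):

from typing import Dict

import torch

class CastTo64(torch.nn.Module):
    # casting is cheap and batch-agnostic, so it is valid in both stages
    is_preprocessor = True
    is_postprocessor = True

    def forward(
        self,
        results: Dict[str, torch.Tensor],
        inputs: Dict[str, torch.Tensor],  # unused here; matches the postprocessor call signature
    ) -> Dict[str, torch.Tensor]:
        # cast floating-point outputs back to double precision, e.g. so that
        # large energy offsets can be added without float32 rounding error
        return {
            k: v.double() if v.is_floating_point() else v
            for k, v in results.items()
        }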