Make transforms independently usable from datamodule (#464)
* Make transforms independently usable from datamodule

* Only initialize transform w/ datamodule if not manually set
ktschuett authored Dec 2, 2022
1 parent 44847a4 commit 1ca8678
Showing 4 changed files with 108 additions and 35 deletions.
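The point of the change: the offset/scale transforms can now be constructed with explicit statistics instead of relying on an AtomsDataModule to fill them in later. A minimal sketch of the two usage modes; the property key and values are illustrative, and `propery_mean` is spelled as in the code below:

import torch
import schnetpack.transform as trn

# Stand-alone usage: supply the statistics manually at construction time.
remove_offsets = trn.RemoveOffsets(
    property="energy_U0",
    remove_mean=True,
    propery_mean=torch.tensor([-5.0]),  # hypothetical per-atom mean
)

# Datamodule usage: leave the statistics unset; the AtomsDataModule fills them in
# via the transform's datamodule() hook, exactly as before.
remove_offsets = trn.RemoveOffsets(property="energy_U0", remove_mean=True)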
2 changes: 1 addition & 1 deletion docs/userguide/configs.rst
@@ -12,7 +12,7 @@ extensive information.
We will explain the structure of the config at the example of training PaiNN on QM9
with the command::

$ spktrain experiment=qm9_energy
$ spktrain experiment=qm9_atomwise

Before going through the config step-by-step, we show the full config as printed by
the command::
16 changes: 8 additions & 8 deletions src/schnetpack/atomistic/atomwise.py
@@ -117,20 +117,20 @@ def __init__(
Args:
n_in: input dimension of representation
n_hidden: size of hidden layers.
If an integer, the same number of nodes is used for all hidden layers, resulting
in a rectangular network.
If None, the number of neurons is divided by two after each layer, starting from
n_in, resulting in a pyramidal network.
If an integer, the same number of nodes is used for all hidden layers,
resulting in a rectangular network.
If None, the number of neurons is divided by two after each layer,
starting from n_in, resulting in a pyramidal network.
n_layers: number of layers.
activation: activation function
predict_magnitude: If true, calculate magnitude of dipole
return_charges: If true, return latent partial charges
dipole_key: the key under which the dipoles will be stored
charges_key: the key under which partial charges will be stored
correct_charges: If true, forces the sum of partial charges to be the total charge, if provided,
and zero otherwise.
use_vector_representation: If true, use vector representation to predict local,
atomic dipoles.
correct_charges: If true, forces the sum of partial charges to be the total
charge, if provided, and zero otherwise.
use_vector_representation: If true, use vector representation to predict
local, atomic dipoles.
"""
super().__init__()

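Judging from the argument list above, this docstring belongs to the dipole output head (presumably DipoleMoment in schnetpack.atomistic). A hedged construction sketch, with illustrative argument values:

import schnetpack as spk

# Sketch only: assumes the docstring above documents spk.atomistic.DipoleMoment.
dipole_head = spk.atomistic.DipoleMoment(
    n_in=128,                 # must match the feature size of the representation
    predict_magnitude=True,   # return the magnitude of the dipole
    return_charges=True,      # also expose the latent partial charges
    use_vector_representation=False,
)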
118 changes: 94 additions & 24 deletions src/schnetpack/transform/atomistic.py
@@ -56,7 +56,11 @@ def forward(

class RemoveOffsets(Transform):
"""
Remove offsets from property based on the mean of the training data and/or the single atom reference calculations.
Remove offsets from property based on the mean of the training data and/or the
single atom reference calculations.
The `mean` and/or `atomref` are automatically obtained from the AtomsDataModule,
when it is used. Otherwise, they have to be provided in the init manually.
"""

is_preprocessor: bool = True
@@ -69,7 +73,20 @@ def __init__(
remove_atomrefs: bool = False,
is_extensive: bool = True,
zmax: int = 100,
atomrefs: torch.Tensor = None,
propery_mean: torch.Tensor = None,
):
"""
Args:
property: The property to remove the offsets from.
remove_mean: If true, remove mean of the dataset from property.
remove_atomrefs: If true, remove single-atom references.
is_extensive: Set true if the property is extensive.
zmax: Set the maximum atomic number, to determine the size of the atomref
tensor.
atomrefs: Provide single-atom references directly.
propery_mean: Provide mean property value / n_atoms.
"""
super().__init__()
self._property = property
self.remove_mean = remove_mean
@@ -80,18 +97,32 @@ def __init__(
remove_atomrefs or remove_mean
), "You should set at least one of `remove_mean` and `remove_atomrefs` to true!"

if atomrefs is not None:
self._atomrefs_initialized = True
else:
self._atomrefs_initialized = False

if propery_mean is not None:
self._mean_initialized = True
else:
self._mean_initialized = False

if self.remove_atomrefs:
self.register_buffer("atomref", torch.zeros((zmax,)))
atomrefs = atomrefs if atomrefs is not None else torch.zeros((zmax,))
self.register_buffer("atomref", atomrefs)
if self.remove_mean:
self.register_buffer("mean", torch.zeros((1,)))
propery_mean = propery_mean if propery_mean is not None else torch.zeros((1,))
self.register_buffer("mean", propery_mean)

def datamodule(self, _datamodule):

if self.remove_atomrefs:
"""
Sets mean and atomref automatically when using PyTorchLightning integration.
"""
if self.remove_atomrefs and not self._atomrefs_initialized:
atrefs = _datamodule.train_dataset.atomrefs
self.atomref = atrefs[self._property].detach()

if self.remove_mean:
if self.remove_mean and not self._mean_initialized:
stats = _datamodule.get_stats(
self._property, self.is_extensive, self.remove_atomrefs
)
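With this change, RemoveOffsets can be used as a preprocessing transform without a datamodule by passing the single-atom references directly. A minimal sketch under that assumption; the reference values, property key, and zmax are illustrative:

import torch
import schnetpack.transform as trn

zmax = 100
atomrefs = torch.zeros((zmax,))
atomrefs[1] = -0.5    # hypothetical single-atom reference for H
atomrefs[6] = -38.0   # hypothetical single-atom reference for C

remove_offsets = trn.RemoveOffsets(
    property="energy_U0",
    remove_atomrefs=True,
    zmax=zmax,
    atomrefs=atomrefs,
)
# Because atomrefs was set manually, datamodule() will not overwrite it.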
@@ -112,24 +143,23 @@ def forward(

class ScaleProperty(Transform):
"""
Scale the energy outputs of the network without influencing the gradient.
This is equivalent to scaling the labels for training and rescaling afterwards.
Scale an entry of the input or results dictionary.
The `scale` can be automatically obtained from the AtomsDataModule,
when it is used. Otherwise, it has to be provided in the init manually.
Hint:
If you want to add a bias to the prediction, use the ``AddOffsets``
postprocessor and place it after casting to float64 for higher numerical
precision.
"""

is_preprocessor: bool = False
is_postprocessor: bool = False
is_preprocessor: bool = True
is_postprocessor: bool = True

def __init__(
self,
input_key: str,
target_key: str = None,
output_key: str = None,
scale_by_mean: bool = False,
scale: torch.Tensor = None,
):
"""
Args:
@@ -139,6 +169,7 @@ def __init__(
output_key: dict key for scaled output
scale_by_mean: if true, use the mean of the target variable for scaling,
otherwise use its standard deviation
scale: provide the scale of the property manually.
"""
super().__init__()
self.input_key = input_key
@@ -147,13 +178,19 @@ def __init__(
self._scale_by_mean = scale_by_mean
self.model_outputs = [self.output_key]

self.register_buffer("scale", torch.ones((1,)))
if scale is not None:
self._initialized = True
else:
self._initialized = False

def datamodule(self, _datamodule):
scale = scale if scale is not None else torch.ones((1,))
self.register_buffer("scale", scale)

stats = _datamodule.get_stats(self._target_key, True, False)
scale = stats[0] if self._scale_by_mean else stats[1]
self.scale = torch.abs(scale).detach()
def datamodule(self, _datamodule):
if not self._initialized:
stats = _datamodule.get_stats(self._target_key, True, False)
scale = stats[0] if self._scale_by_mean else stats[1]
self.scale = torch.abs(scale).detach()

def forward(
self,
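ScaleProperty follows the same pattern: when `scale` is passed at construction time, datamodule() leaves it untouched. A minimal sketch with an illustrative key and scale value:

import torch
import schnetpack.transform as trn

scale_energy = trn.ScaleProperty(
    input_key="energy_U0",
    scale=torch.tensor([27.2114]),  # hypothetical fixed scaling factor
)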
@@ -168,6 +205,9 @@ class AddOffsets(Transform):
Add offsets to property based on the mean of the training data and/or the single
atom reference calculations.
The `mean` and/or `atomref` are automatically obtained from the AtomsDataModule,
when it is used. Otherwise, they have to be provided in the init manually.
Hint:
Place this postprocessor after casting to float64 for higher numerical
precision.
@@ -184,7 +224,20 @@ def __init__(
add_atomrefs: bool = False,
is_extensive: bool = True,
zmax: int = 100,
atomrefs: torch.Tensor = None,
propery_mean: torch.Tensor = None,
):
"""
Args:
property: The property to add the offsets to.
add_mean: If true, add mean of the dataset.
add_atomrefs: If true, add single-atom references.
is_extensive: Set true if the property is extensive.
zmax: Set the maximum atomic number, to determine the size of the atomref
tensor.
atomrefs: Provide single-atom references directly.
propery_mean: Provide mean property value / n_atoms.
"""
super().__init__()
self._property = property
self.add_mean = add_mean
@@ -196,15 +249,27 @@ def __init__(
add_mean or add_atomrefs
), "You should set at least one of `add_mean` and `add_atomrefs` to true!"

self.register_buffer("atomref", torch.zeros((zmax,)))
self.register_buffer("mean", torch.zeros((1,)))
if atomrefs is not None:
self._atomrefs_initialized = True
else:
self._atomrefs_initialized = False

if propery_mean is not None:
self._mean_initialized = True
else:
self._mean_initialized = False

atomrefs = atomrefs if atomrefs is not None else torch.zeros((zmax,))
propery_mean = propery_mean if propery_mean is not None else torch.zeros((1,))
self.register_buffer("atomref", atomrefs)
self.register_buffer("mean", propery_mean)

def datamodule(self, value):
if self.add_atomrefs:
if self.add_atomrefs and not self._atomrefs_initialized:
atrefs = value.train_dataset.atomrefs
self.atomref = atrefs[self._property].detach()

if self.add_mean:
if self.add_mean and not self._mean_initialized:
stats = value.get_stats(
self._property, self.is_extensive, self.add_atomrefs
)
@@ -215,7 +280,12 @@ def forward(
inputs: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
if self.add_mean:
inputs[self._property] += self.mean * inputs[structure.n_atoms]
mean = (
self.mean * inputs[structure.n_atoms]
if self.is_extensive
else self.mean
)
inputs[self._property] += mean

if self.add_atomrefs:
idx_m = inputs[structure.idx_m]
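Analogously, AddOffsets can now serve as a postprocessor outside the PyTorch Lightning workflow when the statistics are supplied manually. A sketch with illustrative values, mirroring the RemoveOffsets example above:

import torch
import schnetpack.transform as trn

add_offsets = trn.AddOffsets(
    property="energy_U0",
    add_mean=True,
    is_extensive=True,
    propery_mean=torch.tensor([-4.2]),  # hypothetical per-atom mean
)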
7 changes: 5 additions & 2 deletions src/schnetpack/transform/base.py
@@ -18,7 +18,8 @@ class TransformException(Exception):
class Transform(nn.Module):
"""
Base class for all transforms.
The base class ensures that the reference to the data and datamodule attributes are initialized.
The base class ensures that the references to the data and datamodule attributes are
initialized.
Transforms can be used as pre- or post-processing layers.
They can also be used for other parts of a model that need to be
initialized based on data.
@@ -31,7 +32,9 @@ class Transform(nn.Module):

def datamodule(self, value):
"""
Extract all required information from data module.
Extract all required information from the data module automatically when using the
PyTorch Lightning integration. The transform should also provide a way to set these
values manually, to make it usable independently of PL.
Do not store the datamodule, as this does not work with torchscript conversion!
"""
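A minimal sketch of the pattern this suggests for a custom transform: remember in __init__ whether the value was set manually, and let datamodule() fill it in only otherwise. The class name, property key, and get_stats-based lookup are illustrative, not part of the library API beyond what the diff above shows:

import torch
from typing import Dict
from schnetpack.transform import Transform


class SubtractShift(Transform):
    is_preprocessor: bool = True
    is_postprocessor: bool = False

    def __init__(self, property: str, shift: torch.Tensor = None):
        super().__init__()
        self._property = property
        self._initialized = shift is not None
        self.register_buffer(
            "shift", shift if shift is not None else torch.zeros((1,))
        )

    def datamodule(self, value):
        # Only consult the datamodule if the shift was not provided manually.
        if not self._initialized:
            stats = value.get_stats(self._property, True, False)
            self.shift = stats[0].detach()

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        inputs[self._property] = inputs[self._property] - self.shift
        return inputs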
