diff --git a/docs/userguide/configs.rst b/docs/userguide/configs.rst
index c5b75846f..291700849 100644
--- a/docs/userguide/configs.rst
+++ b/docs/userguide/configs.rst
@@ -12,7 +12,7 @@ extensive information.
 We will explain the structure of the config at the example of training PaiNN on QM9
 with the command::

-   $ spktrain experiment=qm9_energy
+   $ spktrain experiment=qm9_atomwise

 Before going through the config step-by-step, we show the full config as printed by
 the command::
diff --git a/src/schnetpack/atomistic/atomwise.py b/src/schnetpack/atomistic/atomwise.py
index c1ce4f6a3..502e4725b 100644
--- a/src/schnetpack/atomistic/atomwise.py
+++ b/src/schnetpack/atomistic/atomwise.py
@@ -117,20 +117,20 @@ def __init__(

         Args:
             n_in: input dimension of representation
             n_hidden: size of hidden layers.
-                If an integer, same number of node is used for all hidden layers resulting
-                in a rectangular network.
-                If None, the number of neurons is divided by two after each layer starting
-                n_in resulting in a pyramidal network.
+                If an integer, the same number of nodes is used for all hidden layers,
+                resulting in a rectangular network.
+                If None, the number of neurons is divided by two after each layer,
+                starting with n_in, resulting in a pyramidal network.
             n_layers: number of layers.
             activation: activation function
             predict_magnitude: If true, calculate magnitude of dipole
             return_charges: If true, return latent partial charges
             dipole_key: the key under which the dipoles will be stored
             charges_key: the key under which partial charges will be stored
-            correct_charges: If true, forces the sum of partial charges to be the total charge, if provided,
-                and zero otherwise.
-            use_vector_representation: If true, use vector representation to predict local,
-                atomic dipoles.
+            correct_charges: If true, forces the sum of partial charges to be the total
+                charge, if provided, and zero otherwise.
+            use_vector_representation: If true, use vector representation to predict
+                local, atomic dipoles.
         """
         super().__init__()
diff --git a/src/schnetpack/transform/atomistic.py b/src/schnetpack/transform/atomistic.py
index 62d183f2f..858ae0f63 100644
--- a/src/schnetpack/transform/atomistic.py
+++ b/src/schnetpack/transform/atomistic.py
@@ -56,7 +56,11 @@ def forward(

 class RemoveOffsets(Transform):
     """
-    Remove offsets from property based on the mean of the training data and/or the single atom reference calculations.
+    Remove offsets from property based on the mean of the training data and/or the
+    single atom reference calculations.
+
+    The `mean` and/or `atomref` are automatically obtained from the AtomsDataModule
+    when it is used. Otherwise, they have to be provided manually in the init.
     """

     is_preprocessor: bool = True
@@ -69,7 +73,20 @@ def __init__(
         remove_atomrefs: bool = False,
         is_extensive: bool = True,
         zmax: int = 100,
+        atomrefs: torch.Tensor = None,
+        propery_mean: torch.Tensor = None,
     ):
+        """
+        Args:
+            property: The property from which the offsets are removed.
+            remove_mean: If true, remove mean of the dataset from property.
+            remove_atomrefs: If true, remove single-atom references.
+            is_extensive: Set true if the property is extensive.
+            zmax: Set the maximum atomic number to determine the size of the atomref
+                tensor.
+            atomrefs: Provide single-atom references directly.
+            propery_mean: Provide the mean property value divided by n_atoms.
+ """ super().__init__() self._property = property self.remove_mean = remove_mean @@ -80,18 +97,32 @@ def __init__( remove_atomrefs or remove_mean ), "You should set at least one of `remove_mean` and `remove_atomrefs` to true!" + if atomrefs is not None: + self._atomrefs_initialized = True + else: + self._atomrefs_initialized = False + + if propery_mean is not None: + self._mean_initialized = True + else: + self._mean_initialized = False + if self.remove_atomrefs: - self.register_buffer("atomref", torch.zeros((zmax,))) + atomrefs = atomrefs or torch.zeros((zmax,)) + self.register_buffer("atomref", atomrefs) if self.remove_mean: - self.register_buffer("mean", torch.zeros((1,))) + propery_mean = propery_mean or torch.zeros((1,)) + self.register_buffer("mean", propery_mean) def datamodule(self, _datamodule): - - if self.remove_atomrefs: + """ + Sets mean and atomref automatically when using PyTorchLightning integration. + """ + if self.remove_atomrefs and not self._atomrefs_initialized: atrefs = _datamodule.train_dataset.atomrefs self.atomref = atrefs[self._property].detach() - if self.remove_mean: + if self.remove_mean and not self._mean_initialized: stats = _datamodule.get_stats( self._property, self.is_extensive, self.remove_atomrefs ) @@ -112,17 +143,15 @@ def forward( class ScaleProperty(Transform): """ - Scale the energy outputs of the network without influencing the gradient. - This is equivalent to scaling the labels for training and rescaling afterwards. + Scale an entry of the input or results dioctionary. + + The `scale` can be automatically obtained from the AtomsDataModule, + when it is used. Otherwise, it has to be provided in the init manually. - Hint: - If you want to add a bias to the prediction, use the ``AddOffsets`` - postprocessor and place it after casting to float64 for higher numerical - precision. """ - is_preprocessor: bool = False - is_postprocessor: bool = False + is_preprocessor: bool = True + is_postprocessor: bool = True def __init__( self, @@ -130,6 +159,7 @@ def __init__( target_key: str = None, output_key: str = None, scale_by_mean: bool = False, + scale: torch.Tensor = None, ): """ Args: @@ -139,6 +169,7 @@ def __init__( output_key: dict key for scaled output scale_by_mean: if true, use the mean of the target variable for scaling, otherwise use its standard deviation + scale: provide the scale of the property manually. """ super().__init__() self.input_key = input_key @@ -147,13 +178,19 @@ def __init__( self._scale_by_mean = scale_by_mean self.model_outputs = [self.output_key] - self.register_buffer("scale", torch.ones((1,))) + if scale is not None: + self._initialized = True + else: + self._initialized = False - def datamodule(self, _datamodule): + scale = scale or torch.ones((1,)) + self.register_buffer("scale", scale) - stats = _datamodule.get_stats(self._target_key, True, False) - scale = stats[0] if self._scale_by_mean else stats[1] - self.scale = torch.abs(scale).detach() + def datamodule(self, _datamodule): + if not self._initialized: + stats = _datamodule.get_stats(self._target_key, True, False) + scale = stats[0] if self._scale_by_mean else stats[1] + self.scale = torch.abs(scale).detach() def forward( self, @@ -168,6 +205,9 @@ class AddOffsets(Transform): Add offsets to property based on the mean of the training data and/or the single atom reference calculations. + The `mean` and/or `atomref` are automatically obtained from the AtomsDataModule, + when it is used. Otherwise, they have to be provided in the init manually. 
+
     Hint:
         Place this postprocessor after casting to float64 for higher numerical
         precision.
@@ -184,7 +224,20 @@ def __init__(
         add_atomrefs: bool = False,
         is_extensive: bool = True,
         zmax: int = 100,
+        atomrefs: torch.Tensor = None,
+        propery_mean: torch.Tensor = None,
     ):
+        """
+        Args:
+            property: The property to which the offsets are added.
+            add_mean: If true, add mean of the dataset.
+            add_atomrefs: If true, add single-atom references.
+            is_extensive: Set true if the property is extensive.
+            zmax: Set the maximum atomic number to determine the size of the atomref
+                tensor.
+            atomrefs: Provide single-atom references directly.
+            propery_mean: Provide the mean property value divided by n_atoms.
+        """
         super().__init__()
         self._property = property
         self.add_mean = add_mean
@@ -196,15 +249,27 @@ def __init__(
             add_mean or add_atomrefs
         ), "You should set at least one of `add_mean` and `add_atomrefs` to true!"

-        self.register_buffer("atomref", torch.zeros((zmax,)))
-        self.register_buffer("mean", torch.zeros((1,)))
+        if atomrefs is not None:
+            self._atomrefs_initialized = True
+        else:
+            self._atomrefs_initialized = False
+
+        if propery_mean is not None:
+            self._mean_initialized = True
+        else:
+            self._mean_initialized = False
+
+        atomrefs = atomrefs if atomrefs is not None else torch.zeros((zmax,))
+        propery_mean = propery_mean if propery_mean is not None else torch.zeros((1,))
+        self.register_buffer("atomref", atomrefs)
+        self.register_buffer("mean", propery_mean)

     def datamodule(self, value):
-        if self.add_atomrefs:
+        if self.add_atomrefs and not self._atomrefs_initialized:
             atrefs = value.train_dataset.atomrefs
             self.atomref = atrefs[self._property].detach()

-        if self.add_mean:
+        if self.add_mean and not self._mean_initialized:
             stats = value.get_stats(
                 self._property, self.is_extensive, self.add_atomrefs
             )
@@ -215,7 +280,12 @@ def forward(
         inputs: Dict[str, torch.Tensor],
     ) -> Dict[str, torch.Tensor]:
         if self.add_mean:
-            inputs[self._property] += self.mean * inputs[structure.n_atoms]
+            mean = (
+                self.mean * inputs[structure.n_atoms]
+                if self.is_extensive
+                else self.mean
+            )
+            inputs[self._property] += mean

         if self.add_atomrefs:
             idx_m = inputs[structure.idx_m]
diff --git a/src/schnetpack/transform/base.py b/src/schnetpack/transform/base.py
index 6ce6d6335..77535c200 100644
--- a/src/schnetpack/transform/base.py
+++ b/src/schnetpack/transform/base.py
@@ -18,7 +18,8 @@ class TransformException(Exception):
 class Transform(nn.Module):
     """
     Base class for all transforms.
-    The base class ensures that the reference to the data and datamodule attributes are initialized.
+    The base class ensures that the references to the data and datamodule attributes
+    are initialized.

     Transforms can be used as pre- or post-processing layers. They can also be used for
     other parts of a model, that need to be initialized based on data.
@@ -31,7 +32,9 @@ class Transform(nn.Module):

     def datamodule(self, value):
         """
-        Extract all required information from data module.
+        Extract all required information from the data module automatically when using
+        the PyTorch Lightning integration. The transform should also provide a way to
+        set these values manually, to keep it usable independently of PL.

         Do not store the datamodule, as this does not work with torchscript conversion!
         """
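
For reference, a minimal usage sketch of the changed transforms when the statistics are supplied manually in the init instead of being pulled from an AtomsDataModule. The property key "energy", the output key "energy_scaled", and all numerical values below are placeholders for illustration, not values taken from this patch:

import torch
import schnetpack.transform as trn

# Placeholder statistics; in practice these come from your own dataset analysis.
atomrefs = torch.zeros(100)           # single-atom reference values, indexed by atomic number
mean_per_atom = torch.tensor([-1.5])  # mean property value divided by n_atoms

# With the new init arguments, the offsets no longer have to come from a datamodule.
remove_offsets = trn.RemoveOffsets(
    "energy",
    remove_mean=True,
    remove_atomrefs=True,
    atomrefs=atomrefs,
    propery_mean=mean_per_atom,
)

add_offsets = trn.AddOffsets(
    "energy",
    add_mean=True,
    add_atomrefs=True,
    atomrefs=atomrefs,
    propery_mean=mean_per_atom,
)

# ScaleProperty with a manually provided scale instead of one derived from dataset stats.
scale_energy = trn.ScaleProperty(
    input_key="energy",
    output_key="energy_scaled",
    scale=torch.tensor([2.0]),
)

When the values are passed in the init, the `_initialized` flags are set and the `datamodule` hooks leave them untouched, so the transforms behave the same with and without the PyTorch Lightning integration.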