diff --git a/src/schnetpack/configs/callbacks/earlystopping.yaml b/src/schnetpack/configs/callbacks/earlystopping.yaml index fcb6a1064..05a8900c7 100644 --- a/src/schnetpack/configs/callbacks/earlystopping.yaml +++ b/src/schnetpack/configs/callbacks/earlystopping.yaml @@ -1,6 +1,6 @@ early_stopping: _target_: pytorch_lightning.callbacks.EarlyStopping monitor: "val_loss" # name of the logged metric which determines when model is improving - patience: 1000 # how many epochs of not improving until training stops + patience: 100 # how many epochs of not improving until training stops mode: "min" # can be "max" or "min" - min_delta: 0.0 # minimum change in the monitored metric needed to qualify as an improvement \ No newline at end of file + min_delta: 1e-5 # minimum change in the monitored metric needed to qualify as an improvement \ No newline at end of file diff --git a/src/schnetpack/representation/so3net.py b/src/schnetpack/representation/so3net.py index 35302146b..7a10ba0f7 100644 --- a/src/schnetpack/representation/so3net.py +++ b/src/schnetpack/representation/so3net.py @@ -57,7 +57,17 @@ def __init__( self.n_interactions, shared_interactions, ) - self.mixings = snn.replicate_module( + self.mixings1 = snn.replicate_module( + lambda: nn.Linear(n_atom_basis, n_atom_basis, bias=False), + self.n_interactions, + shared_interactions, + ) + self.mixings2 = snn.replicate_module( + lambda: nn.Linear(n_atom_basis, n_atom_basis, bias=False), + self.n_interactions, + shared_interactions, + ) + self.mixings3 = snn.replicate_module( lambda: nn.Linear(n_atom_basis, n_atom_basis, bias=False), self.n_interactions, shared_interactions, @@ -100,9 +110,11 @@ def forward(self, inputs: Dict[str, torch.Tensor]): for i in range(self.n_interactions): dx = self.so3convs[i](x, radial_ij, Yij, cutoff_ij, idx_i, idx_j) - ddx = self.mixings[i](dx) + ddx = self.mixings1[i](dx) dx = self.so3product(dx, ddx) + dx = self.mixings2[i](dx) dx = self.gatings[i](dx) + dx = self.mixings3[i](dx) x = x + dx inputs["scalar_representation"] = x[:, 0]