diff --git a/torchmdnet/module.py b/torchmdnet/module.py index 108a1915..e40fafb8 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -48,6 +48,16 @@ def forward(self, data): return data +def extract_representation_model(state_dict): + representation_model = {} + prefix = "model.representation_model." + for key, value in state_dict.items(): + if key.startswith(prefix): + new_key = key[len(prefix) :] + representation_model[new_key] = value + return representation_model + + class LNNP(LightningModule): """ Lightning wrapper for the Neural Network Potentials in TorchMD-Net. @@ -65,14 +75,31 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None): hparams["charge"] = False if "spin" not in hparams: hparams["spin"] = False + if "overwrite_representation" not in hparams: + hparams["overwrite_representation"] = None + if "freeze_representation" not in hparams: + hparams["freeze_representation"] = False + if "reset_output_model" not in hparams: + hparams["reset_output_model"] = False self.save_hyperparameters(hparams) if self.hparams.load_model: self.model = load_model(self.hparams.load_model, args=self.hparams) + if self.hparams.reset_output_model: + self.model.output_model.reset_parameters() else: self.model = create_model(self.hparams, prior_model, mean, std) + if self.hparams.overwrite_representation is not None: + ckpt = torch.load(self.hparams.overwrite_representation, map_location="cpu") + state_dict = extract_representation_model(ckpt["state_dict"]) + self.model.representation_model.load_state_dict(state_dict) + + if self.hparams.freeze_representation: + for p in self.model.representation_model.parameters(): + p.requires_grad = False + # initialize exponential smoothing self.ema = None self._reset_ema_dict() diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 2e69212b..70d8d3ba 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -121,7 +121,9 @@ def get_argparse(): 
parser.add_argument('--wandb-project', default='training_', type=str, help='Define what wandb Project to log to') parser.add_argument('--wandb-resume-from-id', default=None, type=str, help='Resume a wandb run from a given run id. The id can be retrieved from the wandb dashboard') parser.add_argument('--tensorboard-use', default=False, type=bool, help='Defines if tensor board is used or not') - + parser.add_argument('--freeze-representation', help='Freeze the representation model parameters during training.', action='store_true') + parser.add_argument('--reset-output-model', help='Reset the parameters (randomize) of the output models before starting a training. This option is only used if the training is not starting from scratch, otherwise the parameters are always randomized.', action='store_true') + parser.add_argument('--overwrite-representation', type=str, help='After loading/creating the model, overwrite the weights of the representation model using the ones stored in the checkpoint provided in this argument.') # fmt: on return parser